
This makes pytorch2.0 attention perform a bit faster.

comfyanonymous · 2 years ago · commit 6908f9c949 · pull/553/head
1 changed file, 11 changes

comfy/ldm/modules/attention.py
@@ -455,11 +455,7 @@ class CrossAttentionPytorch(nn.Module):
         b, _, _ = q.shape
         q, k, v = map(
-            lambda t: t.unsqueeze(3)
-            .reshape(b, t.shape[1], self.heads, self.dim_head)
-            .permute(0, 2, 1, 3)
-            .reshape(b * self.heads, t.shape[1], self.dim_head)
-            .contiguous(),
+            lambda t: t.view(b, -1, self.heads, self.dim_head).transpose(1, 2),
             (q, k, v),
         )
@@ -468,10 +464,7 @@ class CrossAttentionPytorch(nn.Module):
         if exists(mask):
             raise NotImplementedError
         out = (
-            out.unsqueeze(0)
-            .reshape(b, self.heads, out.shape[1], self.dim_head)
-            .permute(0, 2, 1, 3)
-            .reshape(b, out.shape[1], self.heads * self.dim_head)
+            out.transpose(1, 2).reshape(b, -1, self.heads * self.dim_head)
         )
         return self.to_out(out)
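
The speedup comes from keeping the tensors in the 4-D (batch, heads, seq, head_dim) layout that PyTorch 2.0's torch.nn.functional.scaled_dot_product_attention accepts directly: the old path flattened to (batch * heads, seq, head_dim) and forced a .contiguous() copy of q, k, and v, while the new path is just a view plus a transpose. Below is a minimal standalone sketch (not the ComfyUI module itself; the tensor sizes b, seq, heads, and dim_head are made up for illustration) showing that both reshaping paths produce the same attention output, with the new one avoiding the extra copies.

```python
import torch
import torch.nn.functional as F

b, seq, heads, dim_head = 2, 16, 8, 64
q = torch.randn(b, seq, heads * dim_head)
k = torch.randn(b, seq, heads * dim_head)
v = torch.randn(b, seq, heads * dim_head)

# Old path: unsqueeze/reshape/permute/reshape plus .contiguous(),
# which materializes a copy of each of q, k, v.
def split_heads_old(t):
    return (
        t.unsqueeze(3)
        .reshape(b, t.shape[1], heads, dim_head)
        .permute(0, 2, 1, 3)
        .reshape(b * heads, t.shape[1], dim_head)
        .contiguous()
    )

# New path: a view plus a transpose, no copy; SDPA consumes the
# (batch, heads, seq, head_dim) layout as-is.
def split_heads_new(t):
    return t.view(b, -1, heads, dim_head).transpose(1, 2)

# Old output reshaping: undo the (b * heads) flattening step by step.
out_old = F.scaled_dot_product_attention(*(split_heads_old(t) for t in (q, k, v)))
out_old = (
    out_old.unsqueeze(0)
    .reshape(b, heads, out_old.shape[1], dim_head)
    .permute(0, 2, 1, 3)
    .reshape(b, out_old.shape[1], heads * dim_head)
)

# New output reshaping: one transpose and one reshape back to (b, seq, heads * dim_head).
out_new = F.scaled_dot_product_attention(*(split_heads_new(t) for t in (q, k, v)))
out_new = out_new.transpose(1, 2).reshape(b, -1, heads * dim_head)

print(out_old.shape, out_new.shape)              # both (b, seq, heads * dim_head)
print(torch.allclose(out_old, out_new, atol=1e-5))  # same result, fewer copies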
