|
|
|
@ -455,11 +455,7 @@ class CrossAttentionPytorch(nn.Module):
|
|
|
|
|
|
|
|
|
|
b, _, _ = q.shape |
|
|
|
|
q, k, v = map( |
|
|
|
|
lambda t: t.unsqueeze(3) |
|
|
|
|
.reshape(b, t.shape[1], self.heads, self.dim_head) |
|
|
|
|
.permute(0, 2, 1, 3) |
|
|
|
|
.reshape(b * self.heads, t.shape[1], self.dim_head) |
|
|
|
|
.contiguous(), |
|
|
|
|
lambda t: t.view(b, -1, self.heads, self.dim_head).transpose(1, 2), |
|
|
|
|
(q, k, v), |
|
|
|
|
) |
|
|
|
|
|
|
|
|
@ -468,10 +464,7 @@ class CrossAttentionPytorch(nn.Module):
|
|
|
|
|
if exists(mask): |
|
|
|
|
raise NotImplementedError |
|
|
|
|
out = ( |
|
|
|
|
out.unsqueeze(0) |
|
|
|
|
.reshape(b, self.heads, out.shape[1], self.dim_head) |
|
|
|
|
.permute(0, 2, 1, 3) |
|
|
|
|
.reshape(b, out.shape[1], self.heads * self.dim_head) |
|
|
|
|
out.transpose(1, 2).reshape(b, -1, self.heads * self.dim_head) |
|
|
|
|
) |
|
|
|
|
|
|
|
|
|
return self.to_out(out) |
|
|
|
|