|
|
|
@ -331,25 +331,13 @@ class MemoryEfficientAttnBlockPytorch(nn.Module):
|
|
|
|
|
|
|
|
|
|
# compute attention |
|
|
|
|
B, C, H, W = q.shape |
|
|
|
|
q, k, v = map(lambda x: rearrange(x, 'b c h w -> b (h w) c'), (q, k, v)) |
|
|
|
|
|
|
|
|
|
q, k, v = map( |
|
|
|
|
lambda t: t.unsqueeze(3) |
|
|
|
|
.reshape(B, t.shape[1], 1, C) |
|
|
|
|
.permute(0, 2, 1, 3) |
|
|
|
|
.reshape(B * 1, t.shape[1], C) |
|
|
|
|
.contiguous(), |
|
|
|
|
lambda t: t.view(B, 1, C, -1).transpose(2, 3).contiguous(), |
|
|
|
|
(q, k, v), |
|
|
|
|
) |
|
|
|
|
out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False) |
|
|
|
|
|
|
|
|
|
out = ( |
|
|
|
|
out.unsqueeze(0) |
|
|
|
|
.reshape(B, 1, out.shape[1], C) |
|
|
|
|
.permute(0, 2, 1, 3) |
|
|
|
|
.reshape(B, out.shape[1], C) |
|
|
|
|
) |
|
|
|
|
out = rearrange(out, 'b (h w) c -> b c h w', b=B, h=H, w=W, c=C) |
|
|
|
|
out = out.transpose(2, 3).reshape(B, C, H, W) |
|
|
|
|
out = self.proj_out(out) |
|
|
|
|
return x+out |
|
|
|
|
|
|
|
|
|