@@ -318,11 +318,7 @@ def attention_xformers(q, k, v, heads, mask=None, attn_precision=None):
             return attention_pytorch(q, k, v, heads, mask)
 
     q, k, v = map(
-        lambda t: t.unsqueeze(3)
-        .reshape(b, -1, heads, dim_head)
-        .permute(0, 2, 1, 3)
-        .reshape(b * heads, -1, dim_head)
-        .contiguous(),
+        lambda t: t.reshape(b, -1, heads, dim_head),
         (q, k, v),
     )
 
@@ -335,10 +331,7 @@ def attention_xformers(q, k, v, heads, mask=None, attn_precision=None):
     out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=mask)
 
     out = (
-        out.unsqueeze(0)
-        .reshape(b, heads, -1, dim_head)
-        .permute(0, 2, 1, 3)
-        .reshape(b, -1, heads * dim_head)
+        out.reshape(b, -1, heads * dim_head)
     )
     return out
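
For reference, a minimal sketch of why the simpler reshapes suffice: xformers.ops.memory_efficient_attention accepts the 4-D [batch, seq, heads, dim_head] ("BMHK") layout directly, so the old pack-heads-into-batch recipe and the new single reshape hand it the same data, and one reshape merges the heads again on the way out. The snippet below is illustrative only, not from this change: the helper names and tensor sizes are made up, and PyTorch's scaled_dot_product_attention stands in for the xformers call so it runs without xformers installed.

# Sketch only: compares the old and new head-splitting recipes from the diff.
import torch
import torch.nn.functional as F

b, seq, heads, dim_head = 2, 16, 4, 8
q, k, v = (torch.randn(b, seq, heads * dim_head) for _ in range(3))

def old_split(t):
    # Old path: split heads, then fold them into the batch dimension.
    return (t.unsqueeze(3)
             .reshape(b, -1, heads, dim_head)
             .permute(0, 2, 1, 3)
             .reshape(b * heads, -1, dim_head)
             .contiguous())

def new_split(t):
    # New path: a single reshape to [b, seq, heads, dim_head] (xformers' BMHK).
    return t.reshape(b, -1, heads, dim_head)

# Same values, just arranged differently:
assert torch.equal(
    old_split(q),
    new_split(q).permute(0, 2, 1, 3).reshape(b * heads, -1, dim_head))

# SDPA (stand-in for the xformers call) expects [b, heads, seq, dim_head].
attn = F.scaled_dot_product_attention(
    *(new_split(t).permute(0, 2, 1, 3) for t in (q, k, v)))

# New output recipe: xformers returns [b, seq, heads, dim_head], so a single
# reshape merges the heads, as in the diff: out.reshape(b, -1, heads * dim_head).
out_new = attn.permute(0, 2, 1, 3).reshape(b, -1, heads * dim_head)

# Old output recipe, applied to the packed [b * heads, seq, dim_head] result:
attn_packed = attn.reshape(b * heads, -1, dim_head)
out_old = (attn_packed.unsqueeze(0)
           .reshape(b, heads, -1, dim_head)
           .permute(0, 2, 1, 3)
           .reshape(b, -1, heads * dim_head))

print(torch.allclose(out_old, out_new))  # True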