@@ -442,11 +442,61 @@ class MemoryEfficientCrossAttention(nn.Module):
         )
         return self.to_out(out)
 
+class CrossAttentionPytorch(nn.Module):
+    def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.):
+        super().__init__()
+        inner_dim = dim_head * heads
+        context_dim = default(context_dim, query_dim)
+
+        self.heads = heads
+        self.dim_head = dim_head
+
+        self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
+        self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
+        self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
+
+        self.to_out = nn.Sequential(nn.Linear(inner_dim, query_dim), nn.Dropout(dropout))
+        self.attention_op: Optional[Any] = None
+
+    def forward(self, x, context=None, mask=None):
+        q = self.to_q(x)
+        context = default(context, x)
+        k = self.to_k(context)
+        v = self.to_v(context)
+
+        b, _, _ = q.shape
+        q, k, v = map(
+            lambda t: t.unsqueeze(3)
+            .reshape(b, t.shape[1], self.heads, self.dim_head)
+            .permute(0, 2, 1, 3)
+            .reshape(b * self.heads, t.shape[1], self.dim_head)
+            .contiguous(),
+            (q, k, v),
+        )
+
+        out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False)
+
+        if exists(mask):
+            raise NotImplementedError
+        out = (
+            out.unsqueeze(0)
+            .reshape(b, self.heads, out.shape[1], self.dim_head)
+            .permute(0, 2, 1, 3)
+            .reshape(b, out.shape[1], self.heads * self.dim_head)
+        )
+
+        return self.to_out(out)
+
 import sys
-if XFORMERS_IS_AVAILBLE == False:
+if XFORMERS_IS_AVAILBLE == False or "--disable-xformers" in sys.argv:
if "--use-split-cross-attention" in sys.argv: |
|
|
|
if "--use-split-cross-attention" in sys.argv: |
|
|
|
print("Using split optimization for cross attention") |
|
|
|
print("Using split optimization for cross attention") |
|
|
|
CrossAttention = CrossAttentionDoggettx |
|
|
|
CrossAttention = CrossAttentionDoggettx |
|
|
|
|
|
|
|
     else:
-        print("Using sub quadratic optimization for cross attention, if you have memory or speed issues try using: --use-split-cross-attention")
-        CrossAttention = CrossAttentionBirchSan
+        if "--use-pytorch-cross-attention" in sys.argv:
+            print("Using pytorch cross attention")
+            torch.backends.cuda.enable_math_sdp(False)
+            CrossAttention = CrossAttentionPytorch
+        else:
+            print("Using sub quadratic optimization for cross attention, if you have memory or speed issues try using: --use-split-cross-attention")
+            CrossAttention = CrossAttentionBirchSan
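
For reference, the sketch below (not part of the patch) shows the torch.nn.functional.scaled_dot_product_attention call that CrossAttentionPytorch relies on, fed tensors in the same (batch * heads, tokens, dim_head) layout the class builds before the call. The shapes are illustrative and the snippet assumes PyTorch 2.0 or newer.

import torch

b, heads, dim_head = 2, 8, 40                # illustrative sizes only
q = torch.randn(b * heads, 64, dim_head)     # queries: latent/image tokens
k = torch.randn(b * heads, 77, dim_head)     # keys: e.g. text-conditioning tokens
v = torch.randn(b * heads, 77, dim_head)     # values: same token count as keys

# Fused attention: softmax(q @ k^T / sqrt(dim_head)) @ v, computed without
# materializing the full attention matrix when a fused kernel is available.
out = torch.nn.functional.scaled_dot_product_attention(
    q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False
)
print(out.shape)  # torch.Size([16, 64, 40]); the class then reshapes back to (b, 64, heads * dim_head)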