|
|
|
@ -394,15 +394,6 @@ class CrossAttention(nn.Module):
|
|
|
|
|
out = rearrange(out, '(b h) n d -> b n (h d)', h=h) |
|
|
|
|
return self.to_out(out) |
|
|
|
|
|
|
|
|
|
import sys |
|
|
|
|
if XFORMERS_IS_AVAILBLE == False: |
|
|
|
|
if "--use-split-cross-attention" in sys.argv: |
|
|
|
|
print("Using split optimization for cross attention") |
|
|
|
|
CrossAttention = CrossAttentionDoggettx |
|
|
|
|
else: |
|
|
|
|
print("Using sub quadratic optimization for cross attention, if you have memory or speed issues try using: --use-split-cross-attention") |
|
|
|
|
CrossAttention = CrossAttentionBirchSan |
|
|
|
|
|
|
|
|
|
class MemoryEfficientCrossAttention(nn.Module): |
|
|
|
|
# https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223 |
|
|
|
|
def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0): |
|
|
|
@ -451,23 +442,27 @@ class MemoryEfficientCrossAttention(nn.Module):
|
|
|
|
|
) |
|
|
|
|
return self.to_out(out) |
|
|
|
|
|
|
|
|
|
import sys |
|
|
|
|
if XFORMERS_IS_AVAILBLE == False: |
|
|
|
|
if "--use-split-cross-attention" in sys.argv: |
|
|
|
|
print("Using split optimization for cross attention") |
|
|
|
|
CrossAttention = CrossAttentionDoggettx |
|
|
|
|
else: |
|
|
|
|
print("Using sub quadratic optimization for cross attention, if you have memory or speed issues try using: --use-split-cross-attention") |
|
|
|
|
CrossAttention = CrossAttentionBirchSan |
|
|
|
|
else: |
|
|
|
|
print("Using xformers cross attention") |
|
|
|
|
CrossAttention = MemoryEfficientCrossAttention |
|
|
|
|
|
|
|
|
|
class BasicTransformerBlock(nn.Module): |
|
|
|
|
ATTENTION_MODES = { |
|
|
|
|
"softmax": CrossAttention, # vanilla attention |
|
|
|
|
"softmax-xformers": MemoryEfficientCrossAttention |
|
|
|
|
} |
|
|
|
|
def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True, |
|
|
|
|
disable_self_attn=False): |
|
|
|
|
super().__init__() |
|
|
|
|
attn_mode = "softmax-xformers" if XFORMERS_IS_AVAILBLE else "softmax" |
|
|
|
|
assert attn_mode in self.ATTENTION_MODES |
|
|
|
|
attn_cls = self.ATTENTION_MODES[attn_mode] |
|
|
|
|
self.disable_self_attn = disable_self_attn |
|
|
|
|
self.attn1 = attn_cls(query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout, |
|
|
|
|
self.attn1 = CrossAttention(query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout, |
|
|
|
|
context_dim=context_dim if self.disable_self_attn else None) # is a self-attention if not self.disable_self_attn |
|
|
|
|
self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff) |
|
|
|
|
self.attn2 = attn_cls(query_dim=dim, context_dim=context_dim, |
|
|
|
|
self.attn2 = CrossAttention(query_dim=dim, context_dim=context_dim, |
|
|
|
|
heads=n_heads, dim_head=d_head, dropout=dropout) # is self-attn if context is none |
|
|
|
|
self.norm1 = nn.LayerNorm(dim) |
|
|
|
|
self.norm2 = nn.LayerNorm(dim) |
|
|
|
|