
Add pytorch attention support to VAE.

Branch: pull/61/head
Author: comfyanonymous, 2 years ago
Commit: 83f23f82b8

Changed files:
  comfy/ldm/modules/attention.py (22 lines changed)
  comfy/ldm/modules/diffusionmodules/model.py (62 lines changed)
  comfy/model_management.py (11 lines changed)
  main.py (1 line changed)

comfy/ldm/modules/attention.py (22 lines changed)

@@ -479,23 +479,19 @@ class CrossAttentionPytorch(nn.Module):
         return self.to_out(out)
 
 import sys
-if model_management.xformers_enabled() == False:
+if model_management.xformers_enabled():
+    print("Using xformers cross attention")
+    CrossAttention = MemoryEfficientCrossAttention
+elif model_management.pytorch_attention_enabled():
+    print("Using pytorch cross attention")
+    CrossAttention = CrossAttentionPytorch
+else:
     if "--use-split-cross-attention" in sys.argv:
         print("Using split optimization for cross attention")
         CrossAttention = CrossAttentionDoggettx
     else:
-        if "--use-pytorch-cross-attention" in sys.argv:
-            print("Using pytorch cross attention")
-            torch.backends.cuda.enable_math_sdp(False)
-            torch.backends.cuda.enable_flash_sdp(True)
-            torch.backends.cuda.enable_mem_efficient_sdp(True)
-            CrossAttention = CrossAttentionPytorch
-        else:
-            print("Using sub quadratic optimization for cross attention, if you have memory or speed issues try using: --use-split-cross-attention")
-            CrossAttention = CrossAttentionBirchSan
-else:
-    print("Using xformers cross attention")
-    CrossAttention = MemoryEfficientCrossAttention
+        print("Using sub quadratic optimization for cross attention, if you have memory or speed issues try using: --use-split-cross-attention")
+        CrossAttention = CrossAttentionBirchSan
 
 
 class BasicTransformerBlock(nn.Module):
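
The new branch selects CrossAttentionPytorch, whose implementation is not part of this hunk. As rough orientation only, a cross attention layer built on PyTorch 2.0's fused kernel typically looks like the sketch below; the projection names and the heads/dim_head parameters are assumptions, not the code from this commit.

    # Hedged sketch of a scaled_dot_product_attention-based cross attention layer.
    # Not the committed CrossAttentionPytorch; layer names and defaults are assumed.
    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    class SDPACrossAttention(nn.Module):
        def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64):
            super().__init__()
            inner_dim = heads * dim_head
            context_dim = query_dim if context_dim is None else context_dim
            self.heads = heads
            self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
            self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
            self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
            self.to_out = nn.Linear(inner_dim, query_dim)

        def forward(self, x, context=None, mask=None):
            context = x if context is None else context
            b, n, _ = x.shape
            # (b, tokens, heads*dim_head) -> (b, heads, tokens, dim_head)
            q = self.to_q(x).view(b, n, self.heads, -1).transpose(1, 2)
            k = self.to_k(context).view(b, context.shape[1], self.heads, -1).transpose(1, 2)
            v = self.to_v(context).view(b, context.shape[1], self.heads, -1).transpose(1, 2)
            # fused attention; the backend is governed by the enable_*_sdp switches
            # set in comfy/model_management.py when --use-pytorch-cross-attention is passed
            out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False)
            out = out.transpose(1, 2).reshape(b, n, -1)
            return self.to_out(out)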

comfy/ldm/modules/diffusionmodules/model.py (62 lines changed)

@@ -299,6 +299,64 @@ class MemoryEfficientAttnBlock(nn.Module):
         out = self.proj_out(out)
         return x+out
 
+class MemoryEfficientAttnBlockPytorch(nn.Module):
+    def __init__(self, in_channels):
+        super().__init__()
+        self.in_channels = in_channels
+
+        self.norm = Normalize(in_channels)
+        self.q = torch.nn.Conv2d(in_channels,
+                                 in_channels,
+                                 kernel_size=1,
+                                 stride=1,
+                                 padding=0)
+        self.k = torch.nn.Conv2d(in_channels,
+                                 in_channels,
+                                 kernel_size=1,
+                                 stride=1,
+                                 padding=0)
+        self.v = torch.nn.Conv2d(in_channels,
+                                 in_channels,
+                                 kernel_size=1,
+                                 stride=1,
+                                 padding=0)
+        self.proj_out = torch.nn.Conv2d(in_channels,
+                                        in_channels,
+                                        kernel_size=1,
+                                        stride=1,
+                                        padding=0)
+        self.attention_op: Optional[Any] = None
+
+    def forward(self, x):
+        h_ = x
+        h_ = self.norm(h_)
+        q = self.q(h_)
+        k = self.k(h_)
+        v = self.v(h_)
+
+        # compute attention
+        B, C, H, W = q.shape
+        q, k, v = map(lambda x: rearrange(x, 'b c h w -> b (h w) c'), (q, k, v))
+        q, k, v = map(
+            lambda t: t.unsqueeze(3)
+            .reshape(B, t.shape[1], 1, C)
+            .permute(0, 2, 1, 3)
+            .reshape(B * 1, t.shape[1], C)
+            .contiguous(),
+            (q, k, v),
+        )
+        out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False)
+
+        out = (
+            out.unsqueeze(0)
+            .reshape(B, 1, out.shape[1], C)
+            .permute(0, 2, 1, 3)
+            .reshape(B, out.shape[1], C)
+        )
+        out = rearrange(out, 'b (h w) c -> b c h w', b=B, h=H, w=W, c=C)
+        out = self.proj_out(out)
+        return x+out
+
 
 class MemoryEfficientCrossAttentionWrapper(MemoryEfficientCrossAttention):
     def forward(self, x, context=None, mask=None):
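
The added MemoryEfficientAttnBlockPytorch mirrors the existing xformers block: 1x1 convolutions produce q, k, v from the normalized feature map, the spatial grid is flattened into a token axis, and the unsqueeze/reshape/permute sequence just inserts a singleton head dimension before calling the fused kernel. A hedged, simplified sketch of what that forward pass computes (equivalent shapes, not the committed code):

    # Sketch: the reshaping in the block above amounts to single-head attention
    # over the H*W spatial positions. Assumed equivalent, not the committed code.
    import torch
    import torch.nn.functional as F

    def vae_self_attention(q, k, v):
        # q, k, v: (B, C, H, W) projections of the normalized feature map
        B, C, H, W = q.shape
        # (B, C, H, W) -> (B, 1, H*W, C): flatten space, add a singleton head dim
        q, k, v = (t.flatten(2).transpose(1, 2).unsqueeze(1) for t in (q, k, v))
        out = F.scaled_dot_product_attention(q, k, v)   # softmax(q k^T / sqrt(C)) v
        # back to the image layout: (B, 1, H*W, C) -> (B, C, H, W)
        return out.squeeze(1).transpose(1, 2).reshape(B, C, H, W)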
@@ -313,6 +371,8 @@ def make_attn(in_channels, attn_type="vanilla", attn_kwargs=None):
     assert attn_type in ["vanilla", "vanilla-xformers", "memory-efficient-cross-attn", "linear", "none"], f'attn_type {attn_type} unknown'
     if model_management.xformers_enabled() and attn_type == "vanilla":
         attn_type = "vanilla-xformers"
+    if model_management.pytorch_attention_enabled() and attn_type == "vanilla":
+        attn_type = "vanilla-pytorch"
     print(f"making attention of type '{attn_type}' with {in_channels} in_channels")
     if attn_type == "vanilla":
         assert attn_kwargs is None
@@ -320,6 +380,8 @@ def make_attn(in_channels, attn_type="vanilla", attn_kwargs=None):
     elif attn_type == "vanilla-xformers":
         print(f"building MemoryEfficientAttnBlock with {in_channels} in_channels...")
         return MemoryEfficientAttnBlock(in_channels)
+    elif attn_type == "vanilla-pytorch":
+        return MemoryEfficientAttnBlockPytorch(in_channels)
     elif type == "memory-efficient-cross-attn":
         attn_kwargs["query_dim"] = in_channels
         return MemoryEfficientCrossAttentionWrapper(**attn_kwargs)
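
With the two make_attn changes, a VAE built with the default attn_type="vanilla" is transparently upgraded to the new block when pytorch attention is enabled (and xformers is not). A small sketch of how that dispatch can be observed, assuming the flag was passed before comfy.model_management was imported:

    # Sketch: observing the new dispatch (module path taken from the file layout above).
    from comfy.ldm.modules.diffusionmodules.model import make_attn

    block = make_attn(512, attn_type="vanilla")
    # With --use-pytorch-cross-attention this prints
    # "making attention of type 'vanilla-pytorch' with 512 in_channels"
    # and returns a MemoryEfficientAttnBlockPytorch; otherwise the previous
    # vanilla / vanilla-xformers selection is unchanged.
    print(type(block).__name__)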

comfy/model_management.py (11 lines changed)

@@ -41,6 +41,14 @@ else:
     except:
         XFORMERS_IS_AVAILBLE = False
 
+ENABLE_PYTORCH_ATTENTION = False
+if "--use-pytorch-cross-attention" in sys.argv:
+    torch.backends.cuda.enable_math_sdp(True)
+    torch.backends.cuda.enable_flash_sdp(True)
+    torch.backends.cuda.enable_mem_efficient_sdp(True)
+    ENABLE_PYTORCH_ATTENTION = True
+    XFORMERS_IS_AVAILBLE = False
+
 
 if "--cpu" in sys.argv:
     vram_state = CPU

@@ -175,6 +183,9 @@ def xformers_enabled():
         return False
     return XFORMERS_IS_AVAILBLE
 
+def pytorch_attention_enabled():
+    return ENABLE_PYTORCH_ATTENTION
+
 def get_free_memory(dev=None, torch_free_too=False):
     if dev is None:
         dev = get_torch_device()
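
At startup the flag turns on all three scaled_dot_product_attention backends and forces the xformers path off, so callers that check xformers_enabled() fall through to the new pytorch branch. A hedged sanity-check sketch (the *_sdp_enabled() queries are PyTorch 2.0 APIs and are assumed available):

    # Sketch: inspecting the state set up by --use-pytorch-cross-attention.
    import torch
    import comfy.model_management as model_management

    print(model_management.pytorch_attention_enabled())    # True when the flag was passed
    print(model_management.xformers_enabled())              # False: xformers is bypassed
    print(torch.backends.cuda.math_sdp_enabled())           # True
    print(torch.backends.cuda.flash_sdp_enabled())          # True
    print(torch.backends.cuda.mem_efficient_sdp_enabled())  # True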

main.py (1 line changed)

@@ -15,6 +15,7 @@ if __name__ == "__main__":
         print("\t--port 8188\t\t\tSet the listen port.")
         print("\t--dont-upcast-attention\t\tDisable upcasting of attention \n\t\t\t\t\tcan boost speed but increase the chances of black images.\n")
         print("\t--use-split-cross-attention\tUse the split cross attention optimization instead of the sub-quadratic one.\n\t\t\t\t\tIgnored when xformers is used.")
+        print("\t--use-pytorch-cross-attention\tUse the new pytorch 2.0 cross attention function.")
         print("\t--disable-xformers\t\tdisables xformers")
         print()
         print("\t--highvram\t\t\tBy default models will be unloaded to CPU memory after being used.\n\t\t\t\t\tThis option keeps them in GPU memory.\n")
