
Merge remote-tracking branch 'origin' into frontendrefactor

Branch: pull/49/head
Author: pythongosssss, 2 years ago
Commit: 65c432e6ee
  1. comfy/ldm/modules/attention.py (56 changed lines)
  2. execution.py (2 changed lines)
  3. nodes.py (4 changed lines)

comfy/ldm/modules/attention.py (56 changed lines)

@@ -442,14 +442,64 @@ class MemoryEfficientCrossAttention(nn.Module):
         )
         return self.to_out(out)
+
+class CrossAttentionPytorch(nn.Module):
+    def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.):
+        super().__init__()
+        inner_dim = dim_head * heads
+        context_dim = default(context_dim, query_dim)
+
+        self.heads = heads
+        self.dim_head = dim_head
+
+        self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
+        self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
+        self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
+
+        self.to_out = nn.Sequential(nn.Linear(inner_dim, query_dim), nn.Dropout(dropout))
+        self.attention_op: Optional[Any] = None
+
+    def forward(self, x, context=None, mask=None):
+        q = self.to_q(x)
+        context = default(context, x)
+        k = self.to_k(context)
+        v = self.to_v(context)
+
+        b, _, _ = q.shape
+        q, k, v = map(
+            lambda t: t.unsqueeze(3)
+            .reshape(b, t.shape[1], self.heads, self.dim_head)
+            .permute(0, 2, 1, 3)
+            .reshape(b * self.heads, t.shape[1], self.dim_head)
+            .contiguous(),
+            (q, k, v),
+        )
+
+        out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False)
+
+        if exists(mask):
+            raise NotImplementedError
+        out = (
+            out.unsqueeze(0)
+            .reshape(b, self.heads, out.shape[1], self.dim_head)
+            .permute(0, 2, 1, 3)
+            .reshape(b, out.shape[1], self.heads * self.dim_head)
+        )
+
+        return self.to_out(out)
+
 import sys
-if XFORMERS_IS_AVAILBLE == False:
+if XFORMERS_IS_AVAILBLE == False or "--disable-xformers" in sys.argv:
     if "--use-split-cross-attention" in sys.argv:
         print("Using split optimization for cross attention")
         CrossAttention = CrossAttentionDoggettx
     else:
-        print("Using sub quadratic optimization for cross attention, if you have memory or speed issues try using: --use-split-cross-attention")
-        CrossAttention = CrossAttentionBirchSan
+        if "--use-pytorch-cross-attention" in sys.argv:
+            print("Using pytorch cross attention")
+            torch.backends.cuda.enable_math_sdp(False)
+            CrossAttention = CrossAttentionPytorch
+        else:
+            print("Using sub quadratic optimization for cross attention, if you have memory or speed issues try using: --use-split-cross-attention")
+            CrossAttention = CrossAttentionBirchSan
 else:
     print("Using xformers cross attention")
     CrossAttention = MemoryEfficientCrossAttention
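
For reference, here is a minimal standalone sketch of the head split/merge pattern that CrossAttentionPytorch wraps around torch.nn.functional.scaled_dot_product_attention. The tensor sizes and the split_heads helper are illustrative stand-ins, not part of the commit:

# Illustrative only: (b, n, heads*dim_head) <-> (b*heads, n, dim_head) reshapes
import torch

b, n, heads, dim_head = 2, 16, 8, 64
q = torch.randn(b, n, heads * dim_head)
k = torch.randn(b, n, heads * dim_head)
v = torch.randn(b, n, heads * dim_head)

def split_heads(t):
    # (b, n, heads*dim_head) -> (b*heads, n, dim_head)
    return (t.reshape(b, t.shape[1], heads, dim_head)
             .permute(0, 2, 1, 3)
             .reshape(b * heads, t.shape[1], dim_head)
             .contiguous())

q, k, v = map(split_heads, (q, k, v))
out = torch.nn.functional.scaled_dot_product_attention(
    q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False)

# (b*heads, n, dim_head) -> (b, n, heads*dim_head)
out = (out.reshape(b, heads, out.shape[1], dim_head)
          .permute(0, 2, 1, 3)
          .reshape(b, out.shape[1], heads * dim_head))
print(out.shape)  # torch.Size([2, 16, 512])

As the dispatch block above shows, this backend is selected by launching with --use-pytorch-cross-attention, and xformers can be skipped entirely with --disable-xformers.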

execution.py (2 changed lines)

@@ -135,6 +135,8 @@ class PromptExecutor:
         self.server = server
 
     def execute(self, prompt, extra_data={}):
+        nodes.interrupt_processing(False)
+
         if "client_id" in extra_data:
             self.server.client_id = extra_data["client_id"]
         else:

nodes.py (4 changed lines)

@@ -45,8 +45,8 @@ def filter_files_extensions(files, extensions):
 def before_node_execution():
     model_management.throw_exception_if_processing_interrupted()
 
-def interrupt_processing():
-    model_management.interrupt_current_processing()
+def interrupt_processing(value=True):
+    model_management.interrupt_current_processing(value)
 
 class CLIPTextEncode:
     @classmethod
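
The two hunks above work together: interrupt_processing() now accepts a value, and PromptExecutor.execute() calls nodes.interrupt_processing(False) so that a leftover interrupt from a previous prompt is cleared before execution starts. A minimal sketch of that flag pattern, using a hypothetical module-level flag rather than the project's actual model_management internals:

# Illustrative only: hypothetical stand-in for the interrupt flag in model_management
interrupt_flag = False

def interrupt_processing(value=True):
    # value=False lets the executor clear a stale interrupt before a new prompt
    global interrupt_flag
    interrupt_flag = value

def throw_exception_if_processing_interrupted():
    if interrupt_flag:
        raise RuntimeError("Processing interrupted")

interrupt_processing()                        # UI requests an interrupt
interrupt_processing(False)                   # executor resets it at the start of execute()
throw_exception_if_processing_interrupted()   # no exception: the run proceeds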
