Browse Source

Fix attention masks properly for multiple batches.

pull/2820/head
comfyanonymous 9 months ago
parent
commit
6bcf57ff10
  1. 6
      comfy/ldm/modules/attention.py

6
comfy/ldm/modules/attention.py

@ -118,7 +118,7 @@ def attention_basic(q, k, v, heads, mask=None):
bs = 1
else:
bs = mask.shape[0]
mask = mask.reshape(bs, -1, mask.shape[-2], mask.shape[-1]).expand(-1, heads, -1, -1).reshape(-1, mask.shape[-2], mask.shape[-1])
mask = mask.reshape(bs, -1, mask.shape[-2], mask.shape[-1]).expand(b, heads, -1, -1).reshape(-1, mask.shape[-2], mask.shape[-1])
sim.add_(mask)
# attention, what we cannot get enough of
@ -175,7 +175,7 @@ def attention_sub_quad(query, key, value, heads, mask=None):
bs = 1
else:
bs = mask.shape[0]
mask = mask.reshape(bs, -1, mask.shape[-2], mask.shape[-1]).expand(-1, heads, -1, -1).reshape(-1, mask.shape[-2], mask.shape[-1])
mask = mask.reshape(bs, -1, mask.shape[-2], mask.shape[-1]).expand(b, heads, -1, -1).reshape(-1, mask.shape[-2], mask.shape[-1])
hidden_states = efficient_dot_product_attention(
query,
@ -240,7 +240,7 @@ def attention_split(q, k, v, heads, mask=None):
bs = 1
else:
bs = mask.shape[0]
mask = mask.reshape(bs, -1, mask.shape[-2], mask.shape[-1]).expand(-1, heads, -1, -1).reshape(-1, mask.shape[-2], mask.shape[-1])
mask = mask.reshape(bs, -1, mask.shape[-2], mask.shape[-1]).expand(b, heads, -1, -1).reshape(-1, mask.shape[-2], mask.shape[-1])
# print("steps", steps, mem_required, mem_free_total, modifier, q.element_size(), tensor_size)
first_op_done = False

Loading…
Cancel
Save