Browse Source

Properly fix attention masks in CLIP with batches.

pull/2820/head
comfyanonymous 9 months ago
parent
commit
3b9969c1c5
  1. 2
      comfy/clip_model.py
  2. 9
      comfy/ldm/modules/attention.py

2
comfy/clip_model.py

@ -97,7 +97,7 @@ class CLIPTextModel_(torch.nn.Module):
x = self.embeddings(input_tokens)
mask = None
if attention_mask is not None:
mask = 1.0 - attention_mask.to(x.dtype).reshape((attention_mask.shape[0], -1, attention_mask.shape[-1])).expand(attention_mask.shape[0], attention_mask.shape[-1], attention_mask.shape[-1])
mask = 1.0 - attention_mask.to(x.dtype).reshape((attention_mask.shape[0], 1, -1, attention_mask.shape[-1])).expand(attention_mask.shape[0], 1, attention_mask.shape[-1], attention_mask.shape[-1])
mask = mask.masked_fill(mask.to(torch.bool), float("-inf"))
causal_mask = torch.empty(x.shape[1], x.shape[1], dtype=x.dtype, device=x.device).fill_(float("-inf")).triu_(1)

9
comfy/ldm/modules/attention.py

@ -114,7 +114,8 @@ def attention_basic(q, k, v, heads, mask=None):
mask = repeat(mask, 'b j -> (b h) () j', h=h)
sim.masked_fill_(~mask, max_neg_value)
else:
sim += mask
mask = mask.reshape(mask.shape[0], -1, mask.shape[-2], mask.shape[-1]).expand(-1, heads, -1, -1).reshape(sim.shape)
sim.add_(mask)
# attention, what we cannot get enough of
sim = sim.softmax(dim=-1)
@ -165,6 +166,9 @@ def attention_sub_quad(query, key, value, heads, mask=None):
if query_chunk_size is None:
query_chunk_size = 512
if mask is not None:
mask = mask.reshape(mask.shape[0], -1, mask.shape[-2], mask.shape[-1]).expand(-1, heads, -1, -1).reshape(-1, mask.shape[-2], mask.shape[-1])
hidden_states = efficient_dot_product_attention(
query,
key,
@ -223,6 +227,9 @@ def attention_split(q, k, v, heads, mask=None):
raise RuntimeError(f'Not enough memory, use lower resolution (max approx. {max_res}x{max_res}). '
f'Need: {mem_required/64/gb:0.1f}GB free, Have:{mem_free_total/gb:0.1f}GB free')
if mask is not None:
mask = mask.reshape(mask.shape[0], -1, mask.shape[-2], mask.shape[-1]).expand(-1, heads, -1, -1).reshape(-1, mask.shape[-2], mask.shape[-1])
# print("steps", steps, mem_required, mem_free_total, modifier, q.element_size(), tensor_size)
first_op_done = False
cleared_cache = False

Loading…
Cancel
Save