
Fix some OOM issues with split and sub quad attention.

Branch: pull/1861/head
Author: comfyanonymous (1 year ago)
Commit: a373367b0c
Changed files:
1. comfy/ldm/modules/attention.py (9)
2. comfy/ldm/modules/sub_quadratic_attention.py (3)

comfy/ldm/modules/attention.py

@@ -222,9 +222,14 @@ def attention_split(q, k, v, heads, mask=None):
     mem_free_total = model_management.get_free_memory(q.device)

+    if _ATTN_PRECISION =="fp32":
+        element_size = 4
+    else:
+        element_size = q.element_size()
+
     gb = 1024 ** 3
-    tensor_size = q.shape[0] * q.shape[1] * k.shape[1] * q.element_size()
-    modifier = 3 if q.element_size() == 2 else 2.5
+    tensor_size = q.shape[0] * q.shape[1] * k.shape[1] * element_size
+    modifier = 3 if element_size == 2 else 2.5
     mem_required = tensor_size * modifier
     steps = 1
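The point of this hunk: when _ATTN_PRECISION is "fp32", the attention score matrix is computed in float32 even if q itself is fp16, so sizing the estimate with q.element_size() (2 bytes for fp16) undercounts the score tensor by half, the split logic picks too few steps, and the op runs out of memory. Below is a minimal sketch of the corrected estimate, not the actual attention_split function; the helper name estimate_attention_mem, the attn_precision parameter (standing in for the module-level _ATTN_PRECISION), and the example shapes are illustrative assumptions.

    import torch

    def estimate_attention_mem(q, k, attn_precision="fp32"):
        # Score matrix has shape (q.shape[0], q.shape[1], k.shape[1]).
        # If scores are forced to fp32, each element costs 4 bytes
        # regardless of q's dtype.
        if attn_precision == "fp32":
            element_size = 4
        else:
            element_size = q.element_size()

        gb = 1024 ** 3
        tensor_size = q.shape[0] * q.shape[1] * k.shape[1] * element_size
        # Headroom factor for temporaries, as in the diff: 3 for 2-byte
        # elements, 2.5 otherwise.
        modifier = 3 if element_size == 2 else 2.5
        mem_required = tensor_size * modifier
        return mem_required / gb  # estimate in GiB

    # Example: fp16 q/k but fp32 attention scores.
    q = torch.zeros(8, 4096, 64, dtype=torch.float16)
    k = torch.zeros(8, 4096, 64, dtype=torch.float16)
    print(estimate_attention_mem(q, k, attn_precision="fp32"))  # 1.25 GiB

With the old q.element_size()-based estimate the same case would report 0.75 GiB, which is why the chunking under-split and could OOM.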

comfy/ldm/modules/sub_quadratic_attention.py

@@ -83,7 +83,8 @@ def _summarize_chunk(
     )
     max_score, _ = torch.max(attn_weights, -1, keepdim=True)
     max_score = max_score.detach()
-    torch.exp(attn_weights - max_score, out=attn_weights)
+    attn_weights -= max_score
+    torch.exp(attn_weights, out=attn_weights)
     exp_weights = attn_weights.to(value.dtype)
     exp_values = torch.bmm(exp_weights, value)
     max_score = max_score.squeeze(-1)
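The second hunk trims peak memory in the sub-quadratic attention chunk: torch.exp(attn_weights - max_score, out=attn_weights) materializes the temporary attn_weights - max_score (a full extra chunk-sized tensor) before the result is written back, whereas subtracting in place first keeps everything in a single buffer. A minimal sketch of the pattern is below; exp_normalize_inplace is a hypothetical helper, not a function from this file, and the example shapes are arbitrary.

    import torch

    def exp_normalize_inplace(attn_weights: torch.Tensor, max_score: torch.Tensor) -> torch.Tensor:
        # In-place subtraction broadcasts max_score over the last dim
        # without allocating a second chunk-sized tensor.
        attn_weights -= max_score
        # Exponentiate in place as well; peak memory stays at one buffer.
        torch.exp(attn_weights, out=attn_weights)
        return attn_weights

    # Example usage on a dummy score chunk.
    scores = torch.randn(2, 1024, 1024)
    max_score, _ = torch.max(scores, -1, keepdim=True)
    exp_normalize_inplace(scores, max_score.detach())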
