From a373367b0cf37e9c67b30d21c207417dedfffd4f Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Wed, 25 Oct 2023 20:17:28 -0400 Subject: [PATCH] Fix some OOM issues with split and sub quad attention. --- comfy/ldm/modules/attention.py | 9 +++++++-- comfy/ldm/modules/sub_quadratic_attention.py | 3 ++- 2 files changed, 9 insertions(+), 3 deletions(-) diff --git a/comfy/ldm/modules/attention.py b/comfy/ldm/modules/attention.py index dcf46748..4f10bbc3 100644 --- a/comfy/ldm/modules/attention.py +++ b/comfy/ldm/modules/attention.py @@ -222,9 +222,14 @@ def attention_split(q, k, v, heads, mask=None): mem_free_total = model_management.get_free_memory(q.device) + if _ATTN_PRECISION =="fp32": + element_size = 4 + else: + element_size = q.element_size() + gb = 1024 ** 3 - tensor_size = q.shape[0] * q.shape[1] * k.shape[1] * q.element_size() - modifier = 3 if q.element_size() == 2 else 2.5 + tensor_size = q.shape[0] * q.shape[1] * k.shape[1] * element_size + modifier = 3 if element_size == 2 else 2.5 mem_required = tensor_size * modifier steps = 1 diff --git a/comfy/ldm/modules/sub_quadratic_attention.py b/comfy/ldm/modules/sub_quadratic_attention.py index 4d42059b..8e8e8054 100644 --- a/comfy/ldm/modules/sub_quadratic_attention.py +++ b/comfy/ldm/modules/sub_quadratic_attention.py @@ -83,7 +83,8 @@ def _summarize_chunk( ) max_score, _ = torch.max(attn_weights, -1, keepdim=True) max_score = max_score.detach() - torch.exp(attn_weights - max_score, out=attn_weights) + attn_weights -= max_score + torch.exp(attn_weights, out=attn_weights) exp_weights = attn_weights.to(value.dtype) exp_values = torch.bmm(exp_weights, value) max_score = max_score.squeeze(-1)