From ec6f16adb607fa8d14b26670106e1a09d8401e20 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Tue, 14 May 2024 18:02:27 -0400 Subject: [PATCH] Fix SAG. --- comfy/ldm/modules/attention.py | 6 ++++-- comfy_extras/nodes_sag.py | 10 +++++----- 2 files changed, 9 insertions(+), 7 deletions(-) diff --git a/comfy/ldm/modules/attention.py b/comfy/ldm/modules/attention.py index 2515bac5..1d5cf0da 100644 --- a/comfy/ldm/modules/attention.py +++ b/comfy/ldm/modules/attention.py @@ -420,6 +420,7 @@ class BasicTransformerBlock(nn.Module): inner_dim = dim self.is_res = inner_dim == dim + self.attn_precision = attn_precision if self.ff_in: self.norm_in = operations.LayerNorm(dim, dtype=dtype, device=device) @@ -427,7 +428,7 @@ class BasicTransformerBlock(nn.Module): self.disable_self_attn = disable_self_attn self.attn1 = CrossAttention(query_dim=inner_dim, heads=n_heads, dim_head=d_head, dropout=dropout, - context_dim=context_dim if self.disable_self_attn else None, attn_precision=attn_precision, dtype=dtype, device=device, operations=operations) # is a self-attention if not self.disable_self_attn + context_dim=context_dim if self.disable_self_attn else None, attn_precision=self.attn_precision, dtype=dtype, device=device, operations=operations) # is a self-attention if not self.disable_self_attn self.ff = FeedForward(inner_dim, dim_out=dim, dropout=dropout, glu=gated_ff, dtype=dtype, device=device, operations=operations) if disable_temporal_crossattention: @@ -441,7 +442,7 @@ class BasicTransformerBlock(nn.Module): context_dim_attn2 = context_dim self.attn2 = CrossAttention(query_dim=inner_dim, context_dim=context_dim_attn2, - heads=n_heads, dim_head=d_head, dropout=dropout, attn_precision=attn_precision, dtype=dtype, device=device, operations=operations) # is self-attn if context is none + heads=n_heads, dim_head=d_head, dropout=dropout, attn_precision=self.attn_precision, dtype=dtype, device=device, operations=operations) # is self-attn if context is none self.norm2 = operations.LayerNorm(inner_dim, dtype=dtype, device=device) self.norm1 = operations.LayerNorm(inner_dim, dtype=dtype, device=device) @@ -471,6 +472,7 @@ class BasicTransformerBlock(nn.Module): extra_options["n_heads"] = self.n_heads extra_options["dim_head"] = self.d_head + extra_options["attn_precision"] = self.attn_precision if self.ff_in: x_skip = x diff --git a/comfy_extras/nodes_sag.py b/comfy_extras/nodes_sag.py index 69084e91..8d786db5 100644 --- a/comfy_extras/nodes_sag.py +++ b/comfy_extras/nodes_sag.py @@ -5,12 +5,12 @@ import math from einops import rearrange, repeat import os -from comfy.ldm.modules.attention import optimized_attention, _ATTN_PRECISION +from comfy.ldm.modules.attention import optimized_attention import comfy.samplers # from comfy/ldm/modules/attention.py # but modified to return attention scores as well as output -def attention_basic_with_sim(q, k, v, heads, mask=None): +def attention_basic_with_sim(q, k, v, heads, mask=None, attn_precision=None): b, _, dim_head = q.shape dim_head //= heads scale = dim_head ** -0.5 @@ -26,7 +26,7 @@ def attention_basic_with_sim(q, k, v, heads, mask=None): ) # force cast to fp32 to avoid overflowing - if _ATTN_PRECISION =="fp32": + if attn_precision == torch.float32: sim = einsum('b i d, b j d -> b i j', q.float(), k.float()) * scale else: sim = einsum('b i d, b j d -> b i j', q, k) * scale @@ -121,13 +121,13 @@ class SelfAttentionGuidance: if 1 in cond_or_uncond: uncond_index = cond_or_uncond.index(1) # do the entire attention operation, but save the attention scores to attn_scores - (out, sim) = attention_basic_with_sim(q, k, v, heads=heads) + (out, sim) = attention_basic_with_sim(q, k, v, heads=heads, attn_precision=extra_options["attn_precision"]) # when using a higher batch size, I BELIEVE the result batch dimension is [uc1, ... ucn, c1, ... cn] n_slices = heads * b attn_scores = sim[n_slices * uncond_index:n_slices * (uncond_index+1)] return out else: - return optimized_attention(q, k, v, heads=heads) + return optimized_attention(q, k, v, heads=heads, attn_precision=extra_options["attn_precision"]) def post_cfg_function(args): nonlocal attn_scores