Browse Source

Fix SAG.

pull/2569/head
comfyanonymous 6 months ago
parent
commit
ec6f16adb6
  1. 6
      comfy/ldm/modules/attention.py
  2. 10
      comfy_extras/nodes_sag.py

6
comfy/ldm/modules/attention.py

@ -420,6 +420,7 @@ class BasicTransformerBlock(nn.Module):
inner_dim = dim inner_dim = dim
self.is_res = inner_dim == dim self.is_res = inner_dim == dim
self.attn_precision = attn_precision
if self.ff_in: if self.ff_in:
self.norm_in = operations.LayerNorm(dim, dtype=dtype, device=device) self.norm_in = operations.LayerNorm(dim, dtype=dtype, device=device)
@ -427,7 +428,7 @@ class BasicTransformerBlock(nn.Module):
self.disable_self_attn = disable_self_attn self.disable_self_attn = disable_self_attn
self.attn1 = CrossAttention(query_dim=inner_dim, heads=n_heads, dim_head=d_head, dropout=dropout, self.attn1 = CrossAttention(query_dim=inner_dim, heads=n_heads, dim_head=d_head, dropout=dropout,
context_dim=context_dim if self.disable_self_attn else None, attn_precision=attn_precision, dtype=dtype, device=device, operations=operations) # is a self-attention if not self.disable_self_attn context_dim=context_dim if self.disable_self_attn else None, attn_precision=self.attn_precision, dtype=dtype, device=device, operations=operations) # is a self-attention if not self.disable_self_attn
self.ff = FeedForward(inner_dim, dim_out=dim, dropout=dropout, glu=gated_ff, dtype=dtype, device=device, operations=operations) self.ff = FeedForward(inner_dim, dim_out=dim, dropout=dropout, glu=gated_ff, dtype=dtype, device=device, operations=operations)
if disable_temporal_crossattention: if disable_temporal_crossattention:
@ -441,7 +442,7 @@ class BasicTransformerBlock(nn.Module):
context_dim_attn2 = context_dim context_dim_attn2 = context_dim
self.attn2 = CrossAttention(query_dim=inner_dim, context_dim=context_dim_attn2, self.attn2 = CrossAttention(query_dim=inner_dim, context_dim=context_dim_attn2,
heads=n_heads, dim_head=d_head, dropout=dropout, attn_precision=attn_precision, dtype=dtype, device=device, operations=operations) # is self-attn if context is none heads=n_heads, dim_head=d_head, dropout=dropout, attn_precision=self.attn_precision, dtype=dtype, device=device, operations=operations) # is self-attn if context is none
self.norm2 = operations.LayerNorm(inner_dim, dtype=dtype, device=device) self.norm2 = operations.LayerNorm(inner_dim, dtype=dtype, device=device)
self.norm1 = operations.LayerNorm(inner_dim, dtype=dtype, device=device) self.norm1 = operations.LayerNorm(inner_dim, dtype=dtype, device=device)
@ -471,6 +472,7 @@ class BasicTransformerBlock(nn.Module):
extra_options["n_heads"] = self.n_heads extra_options["n_heads"] = self.n_heads
extra_options["dim_head"] = self.d_head extra_options["dim_head"] = self.d_head
extra_options["attn_precision"] = self.attn_precision
if self.ff_in: if self.ff_in:
x_skip = x x_skip = x

10
comfy_extras/nodes_sag.py

@ -5,12 +5,12 @@ import math
from einops import rearrange, repeat from einops import rearrange, repeat
import os import os
from comfy.ldm.modules.attention import optimized_attention, _ATTN_PRECISION from comfy.ldm.modules.attention import optimized_attention
import comfy.samplers import comfy.samplers
# from comfy/ldm/modules/attention.py # from comfy/ldm/modules/attention.py
# but modified to return attention scores as well as output # but modified to return attention scores as well as output
def attention_basic_with_sim(q, k, v, heads, mask=None): def attention_basic_with_sim(q, k, v, heads, mask=None, attn_precision=None):
b, _, dim_head = q.shape b, _, dim_head = q.shape
dim_head //= heads dim_head //= heads
scale = dim_head ** -0.5 scale = dim_head ** -0.5
@ -26,7 +26,7 @@ def attention_basic_with_sim(q, k, v, heads, mask=None):
) )
# force cast to fp32 to avoid overflowing # force cast to fp32 to avoid overflowing
if _ATTN_PRECISION =="fp32": if attn_precision == torch.float32:
sim = einsum('b i d, b j d -> b i j', q.float(), k.float()) * scale sim = einsum('b i d, b j d -> b i j', q.float(), k.float()) * scale
else: else:
sim = einsum('b i d, b j d -> b i j', q, k) * scale sim = einsum('b i d, b j d -> b i j', q, k) * scale
@ -121,13 +121,13 @@ class SelfAttentionGuidance:
if 1 in cond_or_uncond: if 1 in cond_or_uncond:
uncond_index = cond_or_uncond.index(1) uncond_index = cond_or_uncond.index(1)
# do the entire attention operation, but save the attention scores to attn_scores # do the entire attention operation, but save the attention scores to attn_scores
(out, sim) = attention_basic_with_sim(q, k, v, heads=heads) (out, sim) = attention_basic_with_sim(q, k, v, heads=heads, attn_precision=extra_options["attn_precision"])
# when using a higher batch size, I BELIEVE the result batch dimension is [uc1, ... ucn, c1, ... cn] # when using a higher batch size, I BELIEVE the result batch dimension is [uc1, ... ucn, c1, ... cn]
n_slices = heads * b n_slices = heads * b
attn_scores = sim[n_slices * uncond_index:n_slices * (uncond_index+1)] attn_scores = sim[n_slices * uncond_index:n_slices * (uncond_index+1)]
return out return out
else: else:
return optimized_attention(q, k, v, heads=heads) return optimized_attention(q, k, v, heads=heads, attn_precision=extra_options["attn_precision"])
def post_cfg_function(args): def post_cfg_function(args):
nonlocal attn_scores nonlocal attn_scores

Loading…
Cancel
Save