import torch
from torch import einsum
import torch.nn.functional as F
import math

from einops import rearrange, repeat
import os
from comfy.ldm.modules.attention import optimized_attention, _ATTN_PRECISION
import comfy.samplers


# Adapted from comfy/ldm/modules/attention.py, modified to return the
# softmaxed attention scores as well as the attention output.
def attention_basic_with_sim(q, k, v, heads, mask=None):
    """Compute multi-head scaled dot-product attention and also return the scores.

    Args:
        q, k, v: 3-D tensors of shape (b, tokens, heads * dim_head); the last
            dimension is split evenly into ``heads`` heads.
        heads: number of attention heads.
        mask: optional boolean mask with leading batch dim ``b`` (flattened to
            (b, k_tokens) — assumed; TODO confirm against callers). Positions where
            it is False get the most negative finite score before the softmax.

    Returns:
        (out, sim) where ``out`` has shape (b, q_tokens, heads * dim_head) and
        ``sim`` is the softmaxed score tensor of shape (b * heads, q_tokens, k_tokens).
    """
    b, _, dim_head = q.shape
    dim_head //= heads
    scale = dim_head ** -0.5

    h = heads
    # (b, n, heads*d) -> (b*heads, n, d), with the heads of each batch item contiguous.
    q, k, v = map(
        lambda t: t.unsqueeze(3)
        .reshape(b, -1, heads, dim_head)
        .permute(0, 2, 1, 3)
        .reshape(b * heads, -1, dim_head)
        .contiguous(),
        (q, k, v),
    )

    # force cast to fp32 to avoid overflowing
    if _ATTN_PRECISION == "fp32":
        sim = einsum('b i d, b j d -> b i j', q.float(), k.float()) * scale
    else:
        sim = einsum('b i d, b j d -> b i j', q, k) * scale

    del q, k

    if mask is not None:
        mask = rearrange(mask, 'b ... -> b (...)')
        max_neg_value = -torch.finfo(sim.dtype).max
        mask = repeat(mask, 'b j -> (b h) () j', h=h)
        sim.masked_fill_(~mask, max_neg_value)

    # attention, what we cannot get enough of
    sim = sim.softmax(dim=-1)

    out = einsum('b i j, b j d -> b i d', sim.to(v.dtype), v)
    # (b*heads, n, d) -> (b, n, heads*d)
    out = (
        out.unsqueeze(0)
        .reshape(b, heads, -1, dim_head)
        .permute(0, 2, 1, 3)
        .reshape(b, -1, heads * dim_head)
    )
    return (out, sim)


def _attention_grid_shape(lh, lw, tokens):
    """Return the [height, width] token grid holding ``tokens`` tokens for an lh x lw latent.

    Assumes the attention layer sits behind stride-2 downsamplers that each map a
    side n to ceil(n / 2), so the grid is ceil(lh / r) x ceil(lw / r) for some power
    of two r (TODO confirm for non-standard UNets). The smallest such r whose grid
    holds exactly ``tokens`` tokens is used. Because both sides only shrink as r
    grows, any r giving the same token count gives the same grid. If no r matches,
    this falls back to the original sqrt-based ratio estimate.
    """
    ratio = 1
    while True:
        h = -(-lh // ratio)
        w = -(-lw // ratio)
        if h * w == tokens:
            return [h, w]
        if h * w < tokens or (h == 1 and w == 1):
            break
        ratio *= 2
    # Fallback: the previous heuristic (a mismatch will surface in the caller's reshape).
    ratio = 2 ** (math.ceil(math.sqrt(lh * lw / tokens)) - 1).bit_length()
    return [math.ceil(lh / ratio), math.ceil(lw / ratio)]


def create_blur_map(x0, attn, sigma=3.0, threshold=1.0):
    """Blur the parts of ``x0`` that receive high attention, leaving the rest untouched.

    Args:
        x0: 4-D latent tensor (b, c, lh, lw).
        attn: attention scores of shape (b * n, hw1, hw2), where n is presumably the
            head count; the leading dim must be divisible by ``b``.
        sigma: standard deviation of the Gaussian blur.
        threshold: a key token is selected when its attention (averaged over heads,
            summed over query tokens) exceeds this value.

    Returns:
        Tensor shaped like ``x0``: the blurred latent where the upsampled mask is 1
        and ``x0`` where it is 0.
    """
    # reshape and GAP the attention map
    _, hw1, hw2 = attn.shape
    b, _, lh, lw = x0.shape
    attn = attn.reshape(b, -1, hw1, hw2)
    # Global Average Pool over heads, then total attention each key token receives.
    mask = attn.mean(1, keepdim=False).sum(1, keepdim=False) > threshold
    # The mask has one entry per key token, so its grid is derived from hw2.
    mid_shape = _attention_grid_shape(lh, lw, hw2)

    # Reshape to (b, 1, h, w) so it can be interpolated as an image.
    mask = (
        mask.reshape(b, *mid_shape)
        .unsqueeze(1)
        .type(attn.dtype)
    )
    # Upsample (nearest) to the latent resolution.
    mask = F.interpolate(mask, (lh, lw))

    blurred = gaussian_blur_2d(x0, kernel_size=9, sigma=sigma)
    blurred = blurred * mask + x0 * (1 - mask)
    return blurred


def gaussian_blur_2d(img, kernel_size, sigma):
    """Blur each channel of a 4-D tensor (b, c, h, w) with a Gaussian kernel.

    The 2-D kernel is the outer product of a normalized 1-D Gaussian and is applied
    depthwise (one kernel per channel via ``groups``). Output keeps the input's
    spatial size for odd ``kernel_size``.

    Args:
        img: input tensor; the kernel is created on its device and dtype.
        kernel_size: side length of the square kernel.
        sigma: standard deviation of the Gaussian. ``0`` is treated as "no blur"
            and returns ``img`` unchanged (the kernel formula would divide by zero
            and produce NaNs).
    """
    if sigma == 0:
        # Limit of a Gaussian as sigma -> 0 is the identity.
        return img
    ksize_half = (kernel_size - 1) * 0.5

    x = torch.linspace(-ksize_half, ksize_half, steps=kernel_size)

    pdf = torch.exp(-0.5 * (x / sigma).pow(2))

    x_kernel = pdf / pdf.sum()
    x_kernel = x_kernel.to(device=img.device, dtype=img.dtype)

    kernel2d = torch.mm(x_kernel[:, None], x_kernel[None, :])
    kernel2d = kernel2d.expand(img.shape[-3], 1, kernel2d.shape[0], kernel2d.shape[1])

    pad = kernel_size // 2
    padding = [pad, pad, pad, pad]

    # Reflect padding requires pad < each padded dimension; very small latents
    # (e.g. side <= 4 with a 9x9 kernel) would raise, so use replicate there.
    mode = "reflect" if pad < min(img.shape[-2], img.shape[-1]) else "replicate"
    img = F.pad(img, padding, mode=mode)
    img = F.conv2d(img, kernel2d, groups=img.shape[-3])

    return img


class SelfAttentionGuidance:
    """ComfyUI node that patches a model to apply Self-Attention Guidance (SAG).

    Attention scores of the unconditional pass are recorded at the first self-attention
    ("attn1") of the middle block. Highly-attended regions of the unconditional
    prediction are blurred and re-denoised, and the difference is added to the CFG
    result scaled by ``scale``.
    """

    @classmethod
    def INPUT_TYPES(s):
        return {"required": {
                    "model": ("MODEL",),
                    "scale": ("FLOAT", {"default": 0.5, "min": -2.0, "max": 5.0, "step": 0.1}),
                    "blur_sigma": ("FLOAT", {"default": 2.0, "min": 0.0, "max": 10.0, "step": 0.1}),
                }}

    RETURN_TYPES = ("MODEL",)
    FUNCTION = "patch"

    CATEGORY = "_for_testing"

    def patch(self, model, scale, blur_sigma):
        """Return a one-tuple containing a clone of ``model`` patched for SAG.

        Args:
            model: the ComfyUI model patcher to clone (the input is not modified).
            scale: multiplier applied to the SAG correction term.
            blur_sigma: Gaussian sigma used to degrade high-attention regions.
        """
        m = model.clone()

        # Most recently recorded unconditional attention scores (or None).
        attn_scores = None

        # TODO: make this work properly with chunked batches
        #       currently, we can only save the attn from one UNet call
        def attn_and_record(q, k, v, extra_options):
            """attn1 replacement: record the uncond attention scores when uncond is present."""
            nonlocal attn_scores
            heads = extra_options["n_heads"]
            cond_or_uncond = extra_options["cond_or_uncond"]
            # Per-entry batch size; assumes the batch is split evenly across cond_or_uncond.
            b = q.shape[0] // len(cond_or_uncond)
            if 1 in cond_or_uncond:
                uncond_index = cond_or_uncond.index(1)
                # do the entire attention operation, but save the attention scores to attn_scores
                (out, sim) = attention_basic_with_sim(q, k, v, heads=heads)
                # Assumes the batch is grouped per cond_or_uncond entry, i.e.
                # [uc1, ... ucn, c1, ... cn]; each entry occupies heads * b rows of sim.
                n_slices = heads * b
                attn_scores = sim[n_slices * uncond_index:n_slices * (uncond_index + 1)]
                return out
            else:
                return optimized_attention(q, k, v, heads=heads)

        def post_cfg_function(args):
            """Add the SAG correction to the CFG result using the recorded attention."""
            nonlocal attn_scores
            uncond_attn = attn_scores

            sag_scale = scale
            sag_sigma = blur_sigma
            sag_threshold = 1.0
            model = args["model"]
            uncond_pred = args["uncond_denoised"]
            uncond = args["uncond"]
            cfg_result = args["denoised"]
            sigma = args["sigma"]
            model_options = args["model_options"]
            x = args["input"]

            if uncond_attn is None:
                # No unconditional pass recorded attention (e.g. no uncond); SAG has
                # nothing to work with, so return the plain CFG result.
                return cfg_result

            # create the adversarially blurred image
            degraded = create_blur_map(uncond_pred, uncond_attn, sag_sigma, sag_threshold)
            degraded_noised = degraded + x - uncond_pred
            # call into the UNet
            (sag, _) = comfy.samplers.calc_cond_uncond_batch(model, uncond, None, degraded_noised, sigma, model_options)
            # Drop the scores (the SAG pass above re-recorded them) so a large tensor
            # isn't kept alive by the patched model and can't be reused by a later step.
            attn_scores = None
            return cfg_result + (degraded - sag) * sag_scale

        # Presumably forces the uncond pass to run even when cfg == 1 — confirm in ComfyUI.
        m.set_model_sampler_post_cfg_function(post_cfg_function, disable_cfg1_optimization=True)

        # from diffusers:
        # unet.mid_block.attentions[0].transformer_blocks[0].attn1.patch
        m.set_model_attn1_replace(attn_and_record, "middle", 0, 0)

        return (m, )


NODE_CLASS_MAPPINGS = {
    "SelfAttentionGuidance": SelfAttentionGuidance,
}

NODE_DISPLAY_NAME_MAPPINGS = {
    "SelfAttentionGuidance": "Self-Attention Guidance",
}