From 059773a6df310d44026be12140310688a16e3735 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Sun, 28 Apr 2024 12:50:22 -0400 Subject: [PATCH] Add some nodes to multiply the attention in UNet and Clip models. --- comfy_extras/nodes_attention_multiply.py | 120 +++++++++++++++++++++++ nodes.py | 1 + 2 files changed, 121 insertions(+) create mode 100644 comfy_extras/nodes_attention_multiply.py diff --git a/comfy_extras/nodes_attention_multiply.py b/comfy_extras/nodes_attention_multiply.py new file mode 100644 index 00000000..4747eb39 --- /dev/null +++ b/comfy_extras/nodes_attention_multiply.py @@ -0,0 +1,120 @@ + +def attention_multiply(attn, model, q, k, v, out): + m = model.clone() + sd = model.model_state_dict() + + for key in sd: + if key.endswith("{}.to_q.bias".format(attn)) or key.endswith("{}.to_q.weight".format(attn)): + m.add_patches({key: (None,)}, 0.0, q) + if key.endswith("{}.to_k.bias".format(attn)) or key.endswith("{}.to_k.weight".format(attn)): + m.add_patches({key: (None,)}, 0.0, k) + if key.endswith("{}.to_v.bias".format(attn)) or key.endswith("{}.to_v.weight".format(attn)): + m.add_patches({key: (None,)}, 0.0, v) + if key.endswith("{}.to_out.0.bias".format(attn)) or key.endswith("{}.to_out.0.weight".format(attn)): + m.add_patches({key: (None,)}, 0.0, out) + + return m + + +class UNetSelfAttentionMultiply: + @classmethod + def INPUT_TYPES(s): + return {"required": { "model": ("MODEL",), + "q": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}), + "k": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}), + "v": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}), + "out": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}), + }} + RETURN_TYPES = ("MODEL",) + FUNCTION = "patch" + + CATEGORY = "_for_testing/attention_experiments" + + def patch(self, model, q, k, v, out): + m = attention_multiply("attn1", model, q, k, v, out) + return (m, ) + +class UNetCrossAttentionMultiply: + @classmethod + def INPUT_TYPES(s): + return {"required": { "model": ("MODEL",), + "q": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}), + "k": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}), + "v": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}), + "out": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}), + }} + RETURN_TYPES = ("MODEL",) + FUNCTION = "patch" + + CATEGORY = "_for_testing/attention_experiments" + + def patch(self, model, q, k, v, out): + m = attention_multiply("attn2", model, q, k, v, out) + return (m, ) + +class CLIPAttentionMultiply: + @classmethod + def INPUT_TYPES(s): + return {"required": { "clip": ("CLIP",), + "q": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}), + "k": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}), + "v": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}), + "out": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}), + }} + RETURN_TYPES = ("CLIP",) + FUNCTION = "patch" + + CATEGORY = "_for_testing/attention_experiments" + + def patch(self, clip, q, k, v, out): + m = clip.clone() + sd = m.patcher.model_state_dict() + + for key in sd: + if key.endswith("self_attn.q_proj.weight") or key.endswith("self_attn.q_proj.bias"): + m.add_patches({key: (None,)}, 0.0, q) + if key.endswith("self_attn.k_proj.weight") or key.endswith("self_attn.k_proj.bias"): + m.add_patches({key: (None,)}, 0.0, k) + if key.endswith("self_attn.v_proj.weight") or key.endswith("self_attn.v_proj.bias"): + m.add_patches({key: (None,)}, 0.0, v) + if key.endswith("self_attn.out_proj.weight") or key.endswith("self_attn.out_proj.bias"): + m.add_patches({key: (None,)}, 0.0, out) + return (m, ) + +class UNetTemporalAttentionMultiply: + @classmethod + def INPUT_TYPES(s): + return {"required": { "model": ("MODEL",), + "self_structural": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}), + "self_temporal": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}), + "cross_structural": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}), + "cross_temporal": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}), + }} + RETURN_TYPES = ("MODEL",) + FUNCTION = "patch" + + CATEGORY = "_for_testing/attention_experiments" + + def patch(self, model, self_structural, self_temporal, cross_structural, cross_temporal): + m = model.clone() + sd = model.model_state_dict() + + for k in sd: + if (k.endswith("attn1.to_out.0.bias") or k.endswith("attn1.to_out.0.weight")): + if '.time_stack.' in k: + m.add_patches({k: (None,)}, 0.0, self_temporal) + else: + m.add_patches({k: (None,)}, 0.0, self_structural) + elif (k.endswith("attn2.to_out.0.bias") or k.endswith("attn2.to_out.0.weight")): + if '.time_stack.' in k: + m.add_patches({k: (None,)}, 0.0, cross_temporal) + else: + m.add_patches({k: (None,)}, 0.0, cross_structural) + return (m, ) + +NODE_CLASS_MAPPINGS = { + "UNetSelfAttentionMultiply": UNetSelfAttentionMultiply, + "UNetCrossAttentionMultiply": UNetCrossAttentionMultiply, + "CLIPAttentionMultiply": CLIPAttentionMultiply, + "UNetTemporalAttentionMultiply": UNetTemporalAttentionMultiply, +} diff --git a/nodes.py b/nodes.py index a1cfd636..1651a71c 100644 --- a/nodes.py +++ b/nodes.py @@ -1944,6 +1944,7 @@ def init_custom_nodes(): "nodes_model_merging_model_specific.py", "nodes_pag.py", "nodes_align_your_steps.py", + "nodes_attention_multiply.py", ] import_failed = []