From 8b41699da39cbcbb9048c8aa977c9ceaef9b7bf0 Mon Sep 17 00:00:00 2001
From: Extraltodeus
Date: Thu, 16 May 2024 00:55:06 +0200
Subject: [PATCH] added optional layer filters

Comma separated, like this: 3,4,7,8
And if nothing is entered then it applies globally! :)
---
 comfy_extras/nodes_attention_multiply.py | 36 ++++++++++++++++++++----
 1 file changed, 31 insertions(+), 5 deletions(-)

diff --git a/comfy_extras/nodes_attention_multiply.py b/comfy_extras/nodes_attention_multiply.py
index 4747eb39..628ec713 100644
--- a/comfy_extras/nodes_attention_multiply.py
+++ b/comfy_extras/nodes_attention_multiply.py
@@ -1,9 +1,11 @@
-def attention_multiply(attn, model, q, k, v, out):
+def attention_multiply(attn, model, q, k, v, out, block, layer):
     m = model.clone()
     sd = model.model_state_dict()
 
     for key in sd:
+        if f"{block}_blocks.{layer}" not in key and block != "all":
+            continue
         if key.endswith("{}.to_q.bias".format(attn)) or key.endswith("{}.to_q.weight".format(attn)):
             m.add_patches({key: (None,)}, 0.0, q)
         if key.endswith("{}.to_k.bias".format(attn)) or key.endswith("{}.to_k.weight".format(attn)):
             m.add_patches({key: (None,)}, 0.0, k)
@@ -24,14 +26,26 @@ class UNetSelfAttentionMultiply:
                               "k": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
                               "v": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
                               "out": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
+                              "block_id_input": ("STRING", {"multiline": False, "default": ""}),
+                              "block_id_middle": ("STRING", {"multiline": False, "default": ""}),
+                              "block_id_output": ("STRING", {"multiline": False, "default": ""}),
                               }}
     RETURN_TYPES = ("MODEL",)
     FUNCTION = "patch"
 
     CATEGORY = "_for_testing/attention_experiments"
 
-    def patch(self, model, q, k, v, out):
-        m = attention_multiply("attn1", model, q, k, v, out)
+    def patch(self, model, q, k, v, out, block_id_input, block_id_middle, block_id_output):
+        block_layers = {"input": block_id_input, "middle": block_id_middle, "output": block_id_output}
+        m = model.clone()
+        if all(value == "" for value in block_layers.values()):
+            m = attention_multiply("attn1", m, q, k, v, out, "all", "all")
+        else:
+            for block in block_layers:
+                for block_id in block_layers[block].split(","):
+                    if block_id != "":
+                        block_id = int(block_id)
+                        m = attention_multiply("attn1", m, q, k, v, out, block, block_id)
         return (m, )
 
 class UNetCrossAttentionMultiply:
@@ -42,14 +56,26 @@ class UNetCrossAttentionMultiply:
                               "k": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
                               "v": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
                               "out": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
+                              "block_id_input": ("STRING", {"multiline": False, "default": ""}),
+                              "block_id_middle": ("STRING", {"multiline": False, "default": ""}),
+                              "block_id_output": ("STRING", {"multiline": False, "default": ""}),
                               }}
     RETURN_TYPES = ("MODEL",)
     FUNCTION = "patch"
 
     CATEGORY = "_for_testing/attention_experiments"
 
-    def patch(self, model, q, k, v, out):
-        m = attention_multiply("attn2", model, q, k, v, out)
+    def patch(self, model, q, k, v, out, block_id_input, block_id_middle, block_id_output):
+        block_layers = {"input": block_id_input, "middle": block_id_middle, "output": block_id_output}
+        m = model.clone()
+        if all(value == "" for value in block_layers.values()):
+            m = attention_multiply("attn2", m, q, k, v, out, "all", "all")
+        else:
+            for block in block_layers:
+                for block_id in block_layers[block].split(","):
+                    if block_id != "":
+                        block_id = int(block_id)
+                        m = attention_multiply("attn2", m, q, k, v, out, block, block_id)
         return (m, )
 
 class CLIPAttentionMultiply:
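
For context on how the new block/layer filter selects weights, here is a minimal standalone sketch, not part of the patch, that runs the same substring check against a couple of example UNet state-dict key names. The key strings and the matches() helper are illustrative assumptions; real key names come from model.model_state_dict() and vary by model.

    # Illustration only: assumed example keys, mirroring the filter added to attention_multiply().
    example_keys = [
        "diffusion_model.input_blocks.4.1.transformer_blocks.0.attn1.to_q.weight",
        "diffusion_model.output_blocks.8.1.transformer_blocks.0.attn1.to_q.weight",
    ]

    def matches(key, block, layer):
        # Keep the key when block == "all", otherwise require "<block>_blocks.<layer>" in the key,
        # which is the same condition the patch uses to skip non-matching weights.
        return block == "all" or f"{block}_blocks.{layer}" in key

    print([k for k in example_keys if matches(k, "input", 4)])    # only the input_blocks.4 key
    print([k for k in example_keys if matches(k, "all", "all")])  # every key (global behaviour)

Entering a comma-separated list such as 3,4,7,8 in one of the block_id_* fields therefore applies the multipliers only to those layer indices of that block group, and leaving every field empty applies them globally, as described in the commit message.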