Browse Source

added optional layer filters

Comma separated, like this:
3,4,7,8

And if nothing is entered then it applies globally! :)
pull/3493/head
Extraltodeus 6 months ago committed by GitHub
parent
commit
8b41699da3
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
  1. 36
      comfy_extras/nodes_attention_multiply.py

36
comfy_extras/nodes_attention_multiply.py

@ -1,9 +1,11 @@
def attention_multiply(attn, model, q, k, v, out):
def attention_multiply(attn, model, q, k, v, out, block, layer):
m = model.clone()
sd = model.model_state_dict()
for key in sd:
if f"{block}_blocks.{layer}" not in key and block != "all":
continue
if key.endswith("{}.to_q.bias".format(attn)) or key.endswith("{}.to_q.weight".format(attn)):
m.add_patches({key: (None,)}, 0.0, q)
if key.endswith("{}.to_k.bias".format(attn)) or key.endswith("{}.to_k.weight".format(attn)):
@ -24,14 +26,26 @@ class UNetSelfAttentionMultiply:
"k": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
"v": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
"out": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
"block_id_input": ("STRING", {"multiline": False}, {"default": ""}),
"block_id_middle": ("STRING", {"multiline": False}, {"default": ""}),
"block_id_output": ("STRING", {"multiline": False}, {"default": ""}),
}}
RETURN_TYPES = ("MODEL",)
FUNCTION = "patch"
CATEGORY = "_for_testing/attention_experiments"
def patch(self, model, q, k, v, out):
m = attention_multiply("attn1", model, q, k, v, out)
def patch(self, model, q, k, v, out, block_id_input, block_id_middle, block_id_output):
block_layers = {"input": block_id_input, "middle": block_id_middle, "output": block_id_output}
m = model.clone()
if all(value == "" for value in block_layers.values()):
m = attention_multiply("attn1", m, q, k, v, out, "all", "all")
else:
for block in block_layers:
for block_id in block_layers[block].split(","):
if block_id != "":
block_id = int(block_id)
m = attention_multiply("attn1", m, q, k, v, out, block, block_id)
return (m, )
class UNetCrossAttentionMultiply:
@ -42,14 +56,26 @@ class UNetCrossAttentionMultiply:
"k": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
"v": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
"out": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
"block_id_input": ("STRING", {"multiline": False}, {"default": ""}),
"block_id_middle": ("STRING", {"multiline": False}, {"default": ""}),
"block_id_output": ("STRING", {"multiline": False}, {"default": ""}),
}}
RETURN_TYPES = ("MODEL",)
FUNCTION = "patch"
CATEGORY = "_for_testing/attention_experiments"
def patch(self, model, q, k, v, out):
m = attention_multiply("attn2", model, q, k, v, out)
def patch(self, model, q, k, v, out, block_id_input, block_id_middle, block_id_output):
block_layers = {"input": block_id_input, "middle": block_id_middle, "output": block_id_output}
m = model.clone()
if all(value == "" for value in block_layers.values()):
m = attention_multiply("attn2", m, q, k, v, out, "all", "all")
else:
for block in block_layers:
for block_id in block_layers[block].split(","):
if block_id != "":
block_id = int(block_id)
m = attention_multiply("attn2", m, q, k, v, out, block, block_id)
return (m, )
class CLIPAttentionMultiply:

Loading…
Cancel
Save