@@ -1,9 +1,11 @@
-def attention_multiply(attn, model, q, k, v, out):
+def attention_multiply(attn, model, q, k, v, out, block, layer):
     m = model.clone()
     sd = model.model_state_dict()
 
     for key in sd:
+        if f"{block}_blocks.{layer}" not in key and block != "all":
+            continue
         if key.endswith("{}.to_q.bias".format(attn)) or key.endswith("{}.to_q.weight".format(attn)):
             m.add_patches({key: (None,)}, 0.0, q)
         if key.endswith("{}.to_k.bias".format(attn)) or key.endswith("{}.to_k.weight".format(attn)):
             m.add_patches({key: (None,)}, 0.0, k)
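The new block/layer filter works purely on state-dict key names. A minimal sketch of the matching logic, assuming ComfyUI's usual "diffusion_model.<block>_blocks.<layer>..." key layout (the example keys below are illustrative, not taken from this diff):

keys = [
    "diffusion_model.input_blocks.1.1.transformer_blocks.0.attn1.to_q.weight",
    "diffusion_model.output_blocks.3.1.transformer_blocks.0.attn1.to_q.weight",
]

def matches(key, block, layer):
    # block == "all" disables filtering, mirroring the early continue above
    return block == "all" or f"{block}_blocks.{layer}" in key

print([k for k in keys if matches(k, "input", 1)])   # input_blocks.1 key only
print([k for k in keys if matches(k, "all", "all")]) # keeps everything

Two things worth double-checking against real checkpoints: the substring test also matches deeper layers (f"input_blocks.1" is contained in "input_blocks.10" keys too), and on standard SD UNets the middle block's keys are typically named "middle_block" (singular), which f"middle_blocks.{layer}" would never match.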
@@ -24,14 +26,26 @@ class UNetSelfAttentionMultiply:
                               "k": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
                               "v": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
                               "out": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
+                              "block_id_input": ("STRING", {"multiline": False, "default": ""}),
+                              "block_id_middle": ("STRING", {"multiline": False, "default": ""}),
+                              "block_id_output": ("STRING", {"multiline": False, "default": ""}),
                               }}
     RETURN_TYPES = ("MODEL",)
     FUNCTION = "patch"
 
     CATEGORY = "_for_testing/attention_experiments"
 
-    def patch(self, model, q, k, v, out):
-        m = attention_multiply("attn1", model, q, k, v, out)
+    def patch(self, model, q, k, v, out, block_id_input, block_id_middle, block_id_output):
+        block_layers = {"input": block_id_input, "middle": block_id_middle, "output": block_id_output}
+        m = model.clone()
+        if all(value == "" for value in block_layers.values()):
+            m = attention_multiply("attn1", m, q, k, v, out, "all", "all")
+        else:
+            for block in block_layers:
+                for block_id in block_layers[block].split(","):
+                    if block_id != "":
+                        block_id = int(block_id)
+                        m = attention_multiply("attn1", m, q, k, v, out, block, block_id)
         return (m, )
 
 class UNetCrossAttentionMultiply:
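The three new STRING inputs take comma-separated layer indices per UNet section. A quick simulation of how they fan out into per-block patch calls (hypothetical input values; the actual patching is stubbed out):

# Hypothetical node inputs: patch input blocks 1 and 2 plus output block 3.
block_layers = {"input": "1,2", "middle": "", "output": "3"}

calls = []
for block, ids in block_layers.items():
    for block_id in ids.split(","):
        if block_id != "":
            calls.append((block, int(block_id)))

print(calls)  # [('input', 1), ('input', 2), ('output', 3)]

Note that "".split(",") yields [""], so the block_id != "" guard is what keeps empty fields from raising in int(); leaving all three fields empty takes the separate "all"/"all" path instead.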
@@ -42,14 +56,26 @@ class UNetCrossAttentionMultiply:
                               "k": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
                               "v": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
                               "out": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
+                              "block_id_input": ("STRING", {"multiline": False, "default": ""}),
+                              "block_id_middle": ("STRING", {"multiline": False, "default": ""}),
+                              "block_id_output": ("STRING", {"multiline": False, "default": ""}),
                               }}
     RETURN_TYPES = ("MODEL",)
     FUNCTION = "patch"
 
     CATEGORY = "_for_testing/attention_experiments"
 
-    def patch(self, model, q, k, v, out):
-        m = attention_multiply("attn2", model, q, k, v, out)
+    def patch(self, model, q, k, v, out, block_id_input, block_id_middle, block_id_output):
+        block_layers = {"input": block_id_input, "middle": block_id_middle, "output": block_id_output}
+        m = model.clone()
+        if all(value == "" for value in block_layers.values()):
+            m = attention_multiply("attn2", m, q, k, v, out, "all", "all")
+        else:
+            for block in block_layers:
+                for block_id in block_layers[block].split(","):
+                    if block_id != "":
+                        block_id = int(block_id)
+                        m = attention_multiply("attn2", m, q, k, v, out, block, block_id)
         return (m, )
 
 class CLIPAttentionMultiply:
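The cross-attention node is the same machinery with "attn2" passed through: in SD-style transformer blocks attn1 is self-attention and attn2 is cross-attention, so the two nodes select disjoint key sets. A small sketch of the endswith selection (key names assumed, as above):

keys = [
    "diffusion_model.input_blocks.1.1.transformer_blocks.0.attn1.to_q.weight",
    "diffusion_model.input_blocks.1.1.transformer_blocks.0.attn2.to_q.weight",
]

def targeted(key, attn):
    # same endswith test the shared helper applies to the q projection
    return key.endswith("{}.to_q.weight".format(attn)) or key.endswith("{}.to_q.bias".format(attn))

print([k for k in keys if targeted(k, "attn1")])  # self-attention key only
print([k for k in keys if targeted(k, "attn2")])  # cross-attention key only

As for the patch itself, add_patches({key: (None,)}, 0.0, q) passes q as the model-strength argument with a zero-strength patch; if that matches ComfyUI's existing semantics for these nodes, the net effect is scaling the stored weight by q rather than adding anything to it.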