From 4201181b35402e0a992b861f8d2f0e0b267f52fa Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Tue, 9 Apr 2024 04:25:45 -0400 Subject: [PATCH] Add ModelMergeSD1, ModelMergeSD2 and ModelMergeSDXL. --- .../nodes_model_merging_model_specific.py | 60 +++++++++++++++++++ nodes.py | 1 + 2 files changed, 61 insertions(+) create mode 100644 comfy_extras/nodes_model_merging_model_specific.py diff --git a/comfy_extras/nodes_model_merging_model_specific.py b/comfy_extras/nodes_model_merging_model_specific.py new file mode 100644 index 00000000..f2d008d8 --- /dev/null +++ b/comfy_extras/nodes_model_merging_model_specific.py @@ -0,0 +1,60 @@ +import comfy_extras.nodes_model_merging + +class ModelMergeSD1(comfy_extras.nodes_model_merging.ModelMergeBlocks): + CATEGORY = "advanced/model_merging/model_specific" + @classmethod + def INPUT_TYPES(s): + arg_dict = { "model1": ("MODEL",), + "model2": ("MODEL",)} + + argument = ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}) + + arg_dict["time_embed."] = argument + arg_dict["label_emb."] = argument + + for i in range(12): + arg_dict["input_blocks.{}.".format(i)] = argument + + for i in range(3): + arg_dict["middle_block.{}.".format(i)] = argument + + for i in range(12): + arg_dict["output_blocks.{}.".format(i)] = argument + + arg_dict["out."] = argument + + return {"required": arg_dict} + + +class ModelMergeSDXL(comfy_extras.nodes_model_merging.ModelMergeBlocks): + CATEGORY = "advanced/model_merging/model_specific" + + @classmethod + def INPUT_TYPES(s): + arg_dict = { "model1": ("MODEL",), + "model2": ("MODEL",)} + + argument = ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}) + + arg_dict["time_embed."] = argument + arg_dict["label_emb."] = argument + + for i in range(9): + arg_dict["input_blocks.{}".format(i)] = argument + + for i in range(3): + arg_dict["middle_block.{}".format(i)] = argument + + for i in range(9): + arg_dict["output_blocks.{}".format(i)] = argument + + arg_dict["out."] = argument + + return {"required": arg_dict} + + +NODE_CLASS_MAPPINGS = { + "ModelMergeSD1": ModelMergeSD1, + "ModelMergeSD2": ModelMergeSD1, #SD1 and SD2 have the same blocks + "ModelMergeSDXL": ModelMergeSDXL, +} diff --git a/nodes.py b/nodes.py index a1baa98a..78e0cf11 100644 --- a/nodes.py +++ b/nodes.py @@ -1941,6 +1941,7 @@ def init_custom_nodes(): "nodes_stable_cascade.py", "nodes_differential_diffusion.py", "nodes_ip2p.py", + "nodes_model_merging_model_specific.py", ] import_failed = []