diff --git a/comfy/sd.py b/comfy/sd.py index 61d1916d..0eba58fc 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -527,8 +527,10 @@ def load_controlnet(ckpt_path, model=None): elif key in controlnet_data: pass else: - print("error checkpoint does not contain controlnet data", ckpt_path) - return None + net = load_t2i_adapter(controlnet_data) + if net is None: + print("error checkpoint does not contain controlnet or t2i adapter data", ckpt_path) + return net context_dim = controlnet_data[key].shape[1] @@ -682,15 +684,16 @@ class T2IAdapter: out += self.previous_controlnet.get_control_models() return out -def load_t2i_adapter(ckpt_path, model=None): - t2i_data = load_torch_file(ckpt_path) +def load_t2i_adapter(t2i_data): keys = t2i_data.keys() if "body.0.in_conv.weight" in keys: cin = t2i_data['body.0.in_conv.weight'].shape[1] model_ad = adapter.Adapter_light(cin=cin, channels=[320, 640, 1280, 1280], nums_rb=4) - else: + elif 'conv_in.weight' in keys: cin = t2i_data['conv_in.weight'].shape[1] model_ad = adapter.Adapter(cin=cin, channels=[320, 640, 1280, 1280][:4], nums_rb=2, ksize=1, sk=True, use_conv=False) + else: + return None model_ad.load_state_dict(t2i_data) return T2IAdapter(model_ad, cin // 64) diff --git a/comfy_extras/nodes_upscale_model.py b/comfy_extras/nodes_upscale_model.py index bfa787d3..b79b7851 100644 --- a/comfy_extras/nodes_upscale_model.py +++ b/comfy_extras/nodes_upscale_model.py @@ -2,7 +2,6 @@ import os from comfy_extras.chainner_models import model_loading from comfy.sd import load_torch_file import model_management -from nodes import filter_files_extensions, recursive_search, supported_ckpt_extensions import torch import comfy.utils import folder_paths diff --git a/models/controlnet/put_controlnets_here b/models/controlnet/put_controlnets_and_t2i_here similarity index 100% rename from models/controlnet/put_controlnets_here rename to models/controlnet/put_controlnets_and_t2i_here diff --git a/models/t2i_adapter/put_t2i_adapter_models_here b/models/t2i_adapter/put_t2i_adapter_models_here deleted file mode 100644 index e69de29b..00000000 diff --git a/nodes.py b/nodes.py index 93268240..0beacee1 100644 --- a/nodes.py +++ b/nodes.py @@ -24,26 +24,6 @@ import model_management import importlib import folder_paths -supported_ckpt_extensions = ['.ckpt', '.pth'] -supported_pt_extensions = ['.ckpt', '.pt', '.bin', '.pth'] -try: - import safetensors.torch - supported_ckpt_extensions += ['.safetensors'] - supported_pt_extensions += ['.safetensors'] -except: - print("Could not import safetensors, safetensors support disabled.") - -def recursive_search(directory): - result = [] - for root, subdir, file in os.walk(directory, followlinks=True): - for filepath in file: - #we os.path,join directory with a blank string to generate a path separator at the end. - result.append(os.path.join(root, filepath).replace(os.path.join(directory,''),'')) - return result - -def filter_files_extensions(files, extensions): - return sorted(list(filter(lambda a: os.path.splitext(a)[-1].lower() in extensions, files))) - def before_node_execution(): model_management.throw_exception_if_processing_interrupted() @@ -348,23 +328,6 @@ class ControlNetApply: c.append(n) return (c, ) -class T2IAdapterLoader: - models_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "models") - t2i_adapter_dir = os.path.join(models_dir, "t2i_adapter") - @classmethod - def INPUT_TYPES(s): - return {"required": { "t2i_adapter_name": (filter_files_extensions(recursive_search(s.t2i_adapter_dir), supported_pt_extensions), )}} - - RETURN_TYPES = ("CONTROL_NET",) - FUNCTION = "load_t2i_adapter" - - CATEGORY = "loaders" - - def load_t2i_adapter(self, t2i_adapter_name): - t2i_path = os.path.join(self.t2i_adapter_dir, t2i_adapter_name) - t2i_adapter = comfy.sd.load_t2i_adapter(t2i_path) - return (t2i_adapter,) - class CLIPLoader: @classmethod def INPUT_TYPES(s): @@ -963,7 +926,6 @@ NODE_CLASS_MAPPINGS = { "ControlNetApply": ControlNetApply, "ControlNetLoader": ControlNetLoader, "DiffControlNetLoader": DiffControlNetLoader, - "T2IAdapterLoader": T2IAdapterLoader, "StyleModelLoader": StyleModelLoader, "CLIPVisionLoader": CLIPVisionLoader, "VAEDecodeTiled": VAEDecodeTiled, diff --git a/web/scripts/app.js b/web/scripts/app.js index 86e485eb..5e5f4f56 100644 --- a/web/scripts/app.js +++ b/web/scripts/app.js @@ -614,6 +614,12 @@ class ComfyApp { if (!graphData) { graphData = defaultGraph; } + + // Patch T2IAdapterLoader to ControlNetLoader since they are the same node now + for (let n of graphData.nodes) { + if (n.type == "T2IAdapterLoader") n.type = "ControlNetLoader"; + } + this.graph.configure(graphData); for (const node of this.graph._nodes) {