From 76cdc809bfe562dc1026784f26ae0b9582016d6b Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Sat, 23 Sep 2023 18:47:46 -0400 Subject: [PATCH] Support more controlnet models. --- comfy/controlnet.py | 2 +- comfy/model_detection.py | 15 +++++++++------ 2 files changed, 10 insertions(+), 7 deletions(-) diff --git a/comfy/controlnet.py b/comfy/controlnet.py index af0df103..ea219c7e 100644 --- a/comfy/controlnet.py +++ b/comfy/controlnet.py @@ -354,7 +354,7 @@ def load_controlnet(ckpt_path, model=None): if controlnet_config is None: use_fp16 = comfy.model_management.should_use_fp16() - controlnet_config = comfy.model_detection.model_config_from_unet(controlnet_data, prefix, use_fp16).unet_config + controlnet_config = comfy.model_detection.model_config_from_unet(controlnet_data, prefix, use_fp16, True).unet_config controlnet_config.pop("out_channels") controlnet_config["hint_channels"] = controlnet_data["{}input_hint_block.0.weight".format(prefix)].shape[1] control_model = comfy.cldm.cldm.ControlNet(**controlnet_config) diff --git a/comfy/model_detection.py b/comfy/model_detection.py index 372d5a2d..787c7857 100644 --- a/comfy/model_detection.py +++ b/comfy/model_detection.py @@ -1,5 +1,5 @@ - -from . import supported_models +import comfy.supported_models +import comfy.supported_models_base def count_blocks(state_dict_keys, prefix_string): count = 0 @@ -109,17 +109,20 @@ def detect_unet_config(state_dict, key_prefix, use_fp16): return unet_config def model_config_from_unet_config(unet_config): - for model_config in supported_models.models: + for model_config in comfy.supported_models.models: if model_config.matches(unet_config): return model_config(unet_config) print("no match", unet_config) return None -def model_config_from_unet(state_dict, unet_key_prefix, use_fp16): +def model_config_from_unet(state_dict, unet_key_prefix, use_fp16, use_base_if_no_match=False): unet_config = detect_unet_config(state_dict, unet_key_prefix, use_fp16) - return model_config_from_unet_config(unet_config) - + model_config = model_config_from_unet_config(unet_config) + if model_config is None and use_base_if_no_match: + return comfy.supported_models_base.BASE(unet_config) + else: + return model_config def unet_config_from_diffusers_unet(state_dict, use_fp16): match = {}