Browse Source

Support more controlnet models.

pull/1590/head
comfyanonymous 1 year ago
parent
commit
76cdc809bf
  1. 2
      comfy/controlnet.py
  2. 15
      comfy/model_detection.py

2
comfy/controlnet.py

@ -354,7 +354,7 @@ def load_controlnet(ckpt_path, model=None):
if controlnet_config is None: if controlnet_config is None:
use_fp16 = comfy.model_management.should_use_fp16() use_fp16 = comfy.model_management.should_use_fp16()
controlnet_config = comfy.model_detection.model_config_from_unet(controlnet_data, prefix, use_fp16).unet_config controlnet_config = comfy.model_detection.model_config_from_unet(controlnet_data, prefix, use_fp16, True).unet_config
controlnet_config.pop("out_channels") controlnet_config.pop("out_channels")
controlnet_config["hint_channels"] = controlnet_data["{}input_hint_block.0.weight".format(prefix)].shape[1] controlnet_config["hint_channels"] = controlnet_data["{}input_hint_block.0.weight".format(prefix)].shape[1]
control_model = comfy.cldm.cldm.ControlNet(**controlnet_config) control_model = comfy.cldm.cldm.ControlNet(**controlnet_config)

15
comfy/model_detection.py

@ -1,5 +1,5 @@
import comfy.supported_models
from . import supported_models import comfy.supported_models_base
def count_blocks(state_dict_keys, prefix_string): def count_blocks(state_dict_keys, prefix_string):
count = 0 count = 0
@ -109,17 +109,20 @@ def detect_unet_config(state_dict, key_prefix, use_fp16):
return unet_config return unet_config
def model_config_from_unet_config(unet_config): def model_config_from_unet_config(unet_config):
for model_config in supported_models.models: for model_config in comfy.supported_models.models:
if model_config.matches(unet_config): if model_config.matches(unet_config):
return model_config(unet_config) return model_config(unet_config)
print("no match", unet_config) print("no match", unet_config)
return None return None
def model_config_from_unet(state_dict, unet_key_prefix, use_fp16): def model_config_from_unet(state_dict, unet_key_prefix, use_fp16, use_base_if_no_match=False):
unet_config = detect_unet_config(state_dict, unet_key_prefix, use_fp16) unet_config = detect_unet_config(state_dict, unet_key_prefix, use_fp16)
return model_config_from_unet_config(unet_config) model_config = model_config_from_unet_config(unet_config)
if model_config is None and use_base_if_no_match:
return comfy.supported_models_base.BASE(unet_config)
else:
return model_config
def unet_config_from_diffusers_unet(state_dict, use_fp16): def unet_config_from_diffusers_unet(state_dict, use_fp16):
match = {} match = {}

Loading…
Cancel
Save