|
|
@ -1,5 +1,5 @@ |
|
|
|
|
|
|
|
import comfy.supported_models |
|
|
|
from . import supported_models |
|
|
|
import comfy.supported_models_base |
|
|
|
|
|
|
|
|
|
|
|
def count_blocks(state_dict_keys, prefix_string): |
|
|
|
def count_blocks(state_dict_keys, prefix_string): |
|
|
|
count = 0 |
|
|
|
count = 0 |
|
|
@ -109,17 +109,20 @@ def detect_unet_config(state_dict, key_prefix, use_fp16): |
|
|
|
return unet_config |
|
|
|
return unet_config |
|
|
|
|
|
|
|
|
|
|
|
def model_config_from_unet_config(unet_config): |
|
|
|
def model_config_from_unet_config(unet_config): |
|
|
|
for model_config in supported_models.models: |
|
|
|
for model_config in comfy.supported_models.models: |
|
|
|
if model_config.matches(unet_config): |
|
|
|
if model_config.matches(unet_config): |
|
|
|
return model_config(unet_config) |
|
|
|
return model_config(unet_config) |
|
|
|
|
|
|
|
|
|
|
|
print("no match", unet_config) |
|
|
|
print("no match", unet_config) |
|
|
|
return None |
|
|
|
return None |
|
|
|
|
|
|
|
|
|
|
|
def model_config_from_unet(state_dict, unet_key_prefix, use_fp16): |
|
|
|
def model_config_from_unet(state_dict, unet_key_prefix, use_fp16, use_base_if_no_match=False): |
|
|
|
unet_config = detect_unet_config(state_dict, unet_key_prefix, use_fp16) |
|
|
|
unet_config = detect_unet_config(state_dict, unet_key_prefix, use_fp16) |
|
|
|
return model_config_from_unet_config(unet_config) |
|
|
|
model_config = model_config_from_unet_config(unet_config) |
|
|
|
|
|
|
|
if model_config is None and use_base_if_no_match: |
|
|
|
|
|
|
|
return comfy.supported_models_base.BASE(unet_config) |
|
|
|
|
|
|
|
else: |
|
|
|
|
|
|
|
return model_config |
|
|
|
|
|
|
|
|
|
|
|
def unet_config_from_diffusers_unet(state_dict, use_fp16): |
|
|
|
def unet_config_from_diffusers_unet(state_dict, use_fp16): |
|
|
|
match = {} |
|
|
|
match = {} |
|
|
|