@ -182,9 +182,9 @@ def detect_unet_config(state_dict, key_prefix):
return unet_config
def model_config_from_unet_config ( unet_config ) :
def model_config_from_unet_config ( unet_config , state_dict = None ) :
for model_config in comfy . supported_models . models :
if model_config . matches ( unet_config ) :
if model_config . matches ( unet_config , state_dict ) :
return model_config ( unet_config )
logging . error ( " no match {} " . format ( unet_config ) )
@ -192,7 +192,7 @@ def model_config_from_unet_config(unet_config):
def model_config_from_unet ( state_dict , unet_key_prefix , use_base_if_no_match = False ) :
unet_config = detect_unet_config ( state_dict , unet_key_prefix )
model_config = model_config_from_unet_config ( unet_config )
model_config = model_config_from_unet_config ( unet_config , state_dict )
if model_config is None and use_base_if_no_match :
return comfy . supported_models_base . BASE ( unet_config )
else :
@ -321,6 +321,12 @@ def unet_config_from_diffusers_unet(state_dict, dtype=None):
' use_linear_in_transformer ' : True , ' context_dim ' : 2048 , ' num_head_channels ' : 64 , ' transformer_depth_output ' : [ 0 , 0 , 0 , 2 , 2 , 2 , 10 , 10 , 10 ] ,
' use_temporal_attention ' : False , ' use_temporal_resblock ' : False }
SDXL_diffusers_ip2p = { ' use_checkpoint ' : False , ' image_size ' : 32 , ' out_channels ' : 4 , ' use_spatial_transformer ' : True , ' legacy ' : False ,
' num_classes ' : ' sequential ' , ' adm_in_channels ' : 2816 , ' dtype ' : dtype , ' in_channels ' : 8 , ' model_channels ' : 320 ,
' num_res_blocks ' : [ 2 , 2 , 2 ] , ' transformer_depth ' : [ 0 , 0 , 2 , 2 , 10 , 10 ] , ' channel_mult ' : [ 1 , 2 , 4 ] , ' transformer_depth_middle ' : 10 ,
' use_linear_in_transformer ' : True , ' context_dim ' : 2048 , ' num_head_channels ' : 64 , ' transformer_depth_output ' : [ 0 , 0 , 0 , 2 , 2 , 2 , 10 , 10 , 10 ] ,
' use_temporal_attention ' : False , ' use_temporal_resblock ' : False }
SSD_1B = { ' use_checkpoint ' : False , ' image_size ' : 32 , ' out_channels ' : 4 , ' use_spatial_transformer ' : True , ' legacy ' : False ,
' num_classes ' : ' sequential ' , ' adm_in_channels ' : 2816 , ' dtype ' : dtype , ' in_channels ' : 4 , ' model_channels ' : 320 ,
' num_res_blocks ' : [ 2 , 2 , 2 ] , ' transformer_depth ' : [ 0 , 0 , 2 , 2 , 4 , 4 ] , ' transformer_depth_output ' : [ 0 , 0 , 0 , 1 , 1 , 2 , 10 , 4 , 4 ] ,
@ -351,7 +357,7 @@ def unet_config_from_diffusers_unet(state_dict, dtype=None):
' context_dim ' : 1024 , ' num_head_channels ' : 64 , ' transformer_depth_output ' : [ 1 , 1 , 1 , 1 , 1 , 1 ] ,
' use_temporal_attention ' : False , ' use_temporal_resblock ' : False , ' disable_self_attentions ' : [ True , False , False ] }
supported_models = [ SDXL , SDXL_refiner , SD21 , SD15 , SD21_uncliph , SD21_unclipl , SDXL_mid_cnet , SDXL_small_cnet , SDXL_diffusers_inpaint , SSD_1B , Segmind_Vega , KOALA_700M , KOALA_1B , SD09_XS ]
supported_models = [ SDXL , SDXL_refiner , SD21 , SD15 , SD21_uncliph , SD21_unclipl , SDXL_mid_cnet , SDXL_small_cnet , SDXL_diffusers_inpaint , SSD_1B , Segmind_Vega , KOALA_700M , KOALA_1B , SD09_XS , SDXL_diffusers_ip2p ]
for unet_config in supported_models :
matches = True