@@ -121,9 +121,20 @@ def model_config_from_unet(state_dict, unet_key_prefix, use_fp16):
     return model_config_from_unet_config(unet_config)


-def model_config_from_diffusers_unet(state_dict, use_fp16):
+def unet_config_from_diffusers_unet(state_dict, use_fp16):
     match = {}
-    match["context_dim"] = state_dict["down_blocks.1.attentions.1.transformer_blocks.0.attn2.to_k.weight"].shape[1]
+    attention_resolutions = []
+
+    attn_res = 1
+    for i in range(5):
+        k = "down_blocks.{}.attentions.1.transformer_blocks.0.attn2.to_k.weight".format(i)
+        if k in state_dict:
+            match["context_dim"] = state_dict[k].shape[1]
+            attention_resolutions.append(attn_res)
+        attn_res *= 2
+
+    match["attention_resolutions"] = attention_resolutions
+
     match["model_channels"] = state_dict["conv_in.weight"].shape[0]
     match["in_channels"] = state_dict["conv_in.weight"].shape[1]
     match["adm_in_channels"] = None
@@ -135,22 +146,22 @@ def model_config_from_diffusers_unet(state_dict, use_fp16):
     SDXL = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
             'num_classes': 'sequential', 'adm_in_channels': 2816, 'use_fp16': use_fp16, 'in_channels': 4, 'model_channels': 320,
             'num_res_blocks': 2, 'attention_resolutions': [2, 4], 'transformer_depth': [0, 2, 10], 'channel_mult': [1, 2, 4],
-            'transformer_depth_middle': 10, 'use_linear_in_transformer': True, 'context_dim': 2048}
+            'transformer_depth_middle': 10, 'use_linear_in_transformer': True, 'context_dim': 2048, "num_head_channels": 64}

     SDXL_refiner = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
                     'num_classes': 'sequential', 'adm_in_channels': 2560, 'use_fp16': use_fp16, 'in_channels': 4, 'model_channels': 384,
                     'num_res_blocks': 2, 'attention_resolutions': [2, 4], 'transformer_depth': [0, 4, 4, 0], 'channel_mult': [1, 2, 4, 4],
-                    'transformer_depth_middle': 4, 'use_linear_in_transformer': True, 'context_dim': 1280}
+                    'transformer_depth_middle': 4, 'use_linear_in_transformer': True, 'context_dim': 1280, "num_head_channels": 64}

     SD21 = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
             'adm_in_channels': None, 'use_fp16': use_fp16, 'in_channels': 4, 'model_channels': 320, 'num_res_blocks': 2,
             'attention_resolutions': [1, 2, 4], 'transformer_depth': [1, 1, 1, 0], 'channel_mult': [1, 2, 4, 4],
-            'transformer_depth_middle': 1, 'use_linear_in_transformer': True, 'context_dim': 1024}
+            'transformer_depth_middle': 1, 'use_linear_in_transformer': True, 'context_dim': 1024, "num_head_channels": 64}

     SD21_uncliph = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
                     'num_classes': 'sequential', 'adm_in_channels': 2048, 'use_fp16': use_fp16, 'in_channels': 4, 'model_channels': 320,
                     'num_res_blocks': 2, 'attention_resolutions': [1, 2, 4], 'transformer_depth': [1, 1, 1, 0], 'channel_mult': [1, 2, 4, 4],
-                    'transformer_depth_middle': 1, 'use_linear_in_transformer': True, 'context_dim': 1024}
+                    'transformer_depth_middle': 1, 'use_linear_in_transformer': True, 'context_dim': 1024, "num_head_channels": 64}

     SD21_unclipl = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
                     'num_classes': 'sequential', 'adm_in_channels': 1536, 'use_fp16': use_fp16, 'in_channels': 4, 'model_channels': 320,
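
Note (my reading, not part of the patch): the added "num_head_channels": 64 / "num_heads": 8 entries are never compared during detection -- the matching loop below only checks keys present in match -- they ride along so the returned config can build the attention layers with the right head layout. With fixed 64-channel heads the head count scales with channel width per level, unlike SD1.5's fixed 8 heads. A hypothetical helper showing the arithmetic:

    # Hypothetical helper, only to illustrate the head-count arithmetic
    # implied by these configs.
    def heads_at_level(model_channels, mult, num_head_channels=None, num_heads=None):
        ch = model_channels * mult
        if num_head_channels is not None:
            return ch // num_head_channels  # SDXL/SD2.x style: fixed head size
        return num_heads                    # SD1.5 style: fixed head count

    # SDXL: 320 * [1, 2, 4] channels, 64-channel heads -> [5, 10, 20]
    print([heads_at_level(320, m, num_head_channels=64) for m in (1, 2, 4)])
    # SD1.5: always 8 heads regardless of width
    print([heads_at_level(320, m, num_heads=8) for m in (1, 2, 4)])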
@@ -160,9 +171,14 @@ def model_config_from_diffusers_unet(state_dict, use_fp16):
     SD15 = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
             'adm_in_channels': None, 'use_fp16': use_fp16, 'in_channels': 4, 'model_channels': 320, 'num_res_blocks': 2,
             'attention_resolutions': [1, 2, 4], 'transformer_depth': [1, 1, 1, 0], 'channel_mult': [1, 2, 4, 4],
-            'transformer_depth_middle': 1, 'use_linear_in_transformer': False, 'context_dim': 768}
+            'transformer_depth_middle': 1, 'use_linear_in_transformer': False, 'context_dim': 768, "num_heads": 8}

-    supported_models = [SDXL, SDXL_refiner, SD21, SD15, SD21_uncliph, SD21_unclipl]
+    SDXL_mini_cnet = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
+                      'num_classes': 'sequential', 'adm_in_channels': 2816, 'use_fp16': use_fp16, 'in_channels': 4, 'model_channels': 320,
+                      'num_res_blocks': 2, 'attention_resolutions': [4], 'transformer_depth': [0, 0, 1], 'channel_mult': [1, 2, 4],
+                      'transformer_depth_middle': 1, 'use_linear_in_transformer': True, 'context_dim': 2048, "num_head_channels": 64}
+
+    supported_models = [SDXL, SDXL_refiner, SD21, SD15, SD21_uncliph, SD21_unclipl, SDXL_mini_cnet]

     for unet_config in supported_models:
         matches = True
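
Note (sketch, not part of the patch): this loop, which concludes in the next hunk, is a subset match -- every key the function stored in match (context_dim, attention_resolutions, model_channels, in_channels, adm_in_channels) must equal the candidate's value, while extra keys in the candidate are ignored. Standalone, with made-up dicts:

    # Made-up fingerprint and candidate, mirroring the match semantics.
    match = {"context_dim": 2048, "attention_resolutions": [2, 4],
             "model_channels": 320, "in_channels": 4, "adm_in_channels": 2816}

    def config_matches(match, unet_config):
        # Keys only in unet_config (e.g. num_head_channels) never participate.
        return all(match[k] == unet_config[k] for k in match)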
@@ -171,5 +187,11 @@ def model_config_from_diffusers_unet(state_dict, use_fp16):
                 matches = False
                 break
         if matches:
-            return model_config_from_unet_config(unet_config)
+            return unet_config
     return None
+
+def model_config_from_diffusers_unet(state_dict, use_fp16):
+    unet_config = unet_config_from_diffusers_unet(state_dict, use_fp16)
+    if unet_config is not None:
+        return model_config_from_unet_config(unet_config)
+    return None
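
Note (usage sketch, not part of the patch; the file path and safetensors loading are assumptions): with this split, callers that only need the raw UNet config -- e.g. controlnet loading for a diffusers-format model -- can stop at unet_config_from_diffusers_unet, while checkpoint loading keeps the previous behavior via the wrapper:

    # "diffusion_pytorch_model.safetensors" is a placeholder path.
    import safetensors.torch

    state_dict = safetensors.torch.load_file("diffusion_pytorch_model.safetensors")

    unet_config = unet_config_from_diffusers_unet(state_dict, use_fp16=True)
    if unet_config is None:
        raise RuntimeError("unrecognized diffusers UNet")

    # Full model config object when instantiating the model itself:
    model_config = model_config_from_diffusers_unet(state_dict, use_fp16=True)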