|
|
|
@ -1049,7 +1049,7 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
|
|
|
|
|
clipvision = clip_vision.load_clipvision_from_sd(sd, model_config.clip_vision_prefix, True) |
|
|
|
|
|
|
|
|
|
offload_device = model_management.unet_offload_device() |
|
|
|
|
model = model_config.get_model(sd) |
|
|
|
|
model = model_config.get_model(sd, "model.diffusion_model.") |
|
|
|
|
model = model.to(offload_device) |
|
|
|
|
model.load_model_weights(sd, "model.diffusion_model.") |
|
|
|
|
|
|
|
|
@ -1073,6 +1073,73 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
|
|
|
|
|
|
|
|
|
|
return (ModelPatcher(model, load_device=model_management.get_torch_device(), offload_device=offload_device), clip, vae, clipvision) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def load_unet(unet_path): #load unet in diffusers format |
|
|
|
|
sd = utils.load_torch_file(unet_path) |
|
|
|
|
parameters = calculate_parameters(sd, "") |
|
|
|
|
fp16 = model_management.should_use_fp16(model_params=parameters) |
|
|
|
|
|
|
|
|
|
match = {} |
|
|
|
|
match["context_dim"] = sd["down_blocks.0.attentions.1.transformer_blocks.0.attn2.to_k.weight"].shape[1] |
|
|
|
|
match["model_channels"] = sd["conv_in.weight"].shape[0] |
|
|
|
|
match["in_channels"] = sd["conv_in.weight"].shape[1] |
|
|
|
|
match["adm_in_channels"] = None |
|
|
|
|
if "class_embedding.linear_1.weight" in sd: |
|
|
|
|
match["adm_in_channels"] = sd["class_embedding.linear_1.weight"].shape[1] |
|
|
|
|
|
|
|
|
|
SDXL = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False, |
|
|
|
|
'num_classes': 'sequential', 'adm_in_channels': 2816, 'use_fp16': fp16, 'in_channels': 4, 'model_channels': 320, |
|
|
|
|
'num_res_blocks': 2, 'attention_resolutions': [2, 4], 'transformer_depth': [0, 2, 10], 'channel_mult': [1, 2, 4], |
|
|
|
|
'transformer_depth_middle': 10, 'use_linear_in_transformer': True, 'context_dim': 2048} |
|
|
|
|
|
|
|
|
|
SDXL_refiner = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False, |
|
|
|
|
'num_classes': 'sequential', 'adm_in_channels': 2560, 'use_fp16': fp16, 'in_channels': 4, 'model_channels': 384, |
|
|
|
|
'num_res_blocks': 2, 'attention_resolutions': [2, 4], 'transformer_depth': [0, 4, 4, 0], 'channel_mult': [1, 2, 4, 4], |
|
|
|
|
'transformer_depth_middle': 4, 'use_linear_in_transformer': True, 'context_dim': 1280} |
|
|
|
|
|
|
|
|
|
SD21 = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False, |
|
|
|
|
'adm_in_channels': None, 'use_fp16': fp16, 'in_channels': 4, 'model_channels': 320, 'num_res_blocks': 2, |
|
|
|
|
'attention_resolutions': [1, 2, 4], 'transformer_depth': [1, 1, 1, 0], 'channel_mult': [1, 2, 4, 4], |
|
|
|
|
'transformer_depth_middle': 1, 'use_linear_in_transformer': True, 'context_dim': 1024} |
|
|
|
|
|
|
|
|
|
SD21_uncliph = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False, |
|
|
|
|
'num_classes': 'sequential', 'adm_in_channels': 2048, 'use_fp16': True, 'in_channels': 4, 'model_channels': 320, |
|
|
|
|
'num_res_blocks': 2, 'attention_resolutions': [1, 2, 4], 'transformer_depth': [1, 1, 1, 0], 'channel_mult': [1, 2, 4, 4], |
|
|
|
|
'transformer_depth_middle': 1, 'use_linear_in_transformer': True, 'context_dim': 1024} |
|
|
|
|
|
|
|
|
|
SD21_unclipl = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False, |
|
|
|
|
'num_classes': 'sequential', 'adm_in_channels': 1536, 'use_fp16': True, 'in_channels': 4, 'model_channels': 320, |
|
|
|
|
'num_res_blocks': 2, 'attention_resolutions': [1, 2, 4], 'transformer_depth': [1, 1, 1, 0], 'channel_mult': [1, 2, 4, 4], |
|
|
|
|
'transformer_depth_middle': 1, 'use_linear_in_transformer': True, 'context_dim': 1024} |
|
|
|
|
|
|
|
|
|
SD15 = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False, |
|
|
|
|
'adm_in_channels': None, 'use_fp16': True, 'in_channels': 4, 'model_channels': 320, 'num_res_blocks': 2, |
|
|
|
|
'attention_resolutions': [1, 2, 4], 'transformer_depth': [1, 1, 1, 0], 'channel_mult': [1, 2, 4, 4], |
|
|
|
|
'transformer_depth_middle': 1, 'use_linear_in_transformer': False, 'context_dim': 768} |
|
|
|
|
|
|
|
|
|
supported_models = [SDXL, SDXL_refiner, SD21, SD15, SD21_uncliph, SD21_unclipl] |
|
|
|
|
print("match", match) |
|
|
|
|
for unet_config in supported_models: |
|
|
|
|
matches = True |
|
|
|
|
for k in match: |
|
|
|
|
if match[k] != unet_config[k]: |
|
|
|
|
matches = False |
|
|
|
|
break |
|
|
|
|
if matches: |
|
|
|
|
diffusers_keys = utils.unet_to_diffusers(unet_config) |
|
|
|
|
new_sd = {} |
|
|
|
|
for k in diffusers_keys: |
|
|
|
|
if k in sd: |
|
|
|
|
new_sd[diffusers_keys[k]] = sd.pop(k) |
|
|
|
|
else: |
|
|
|
|
print(diffusers_keys[k], k) |
|
|
|
|
offload_device = model_management.unet_offload_device() |
|
|
|
|
model_config = model_detection.model_config_from_unet_config(unet_config) |
|
|
|
|
model = model_config.get_model(new_sd, "") |
|
|
|
|
model = model.to(offload_device) |
|
|
|
|
model.load_model_weights(new_sd, "") |
|
|
|
|
return ModelPatcher(model, load_device=model_management.get_torch_device(), offload_device=offload_device) |
|
|
|
|
|
|
|
|
|
def save_checkpoint(output_path, model, clip, vae, metadata=None): |
|
|
|
|
try: |
|
|
|
|
model.patch_model() |
|
|
|
|