|
|
|
@ -31,17 +31,6 @@ def load_model_weights(model, sd, verbose=False, load_state_dict_to=[]):
|
|
|
|
|
if ids.dtype == torch.float32: |
|
|
|
|
sd['cond_stage_model.transformer.text_model.embeddings.position_ids'] = ids.round() |
|
|
|
|
|
|
|
|
|
keys_to_replace = { |
|
|
|
|
"cond_stage_model.model.positional_embedding": "cond_stage_model.transformer.text_model.embeddings.position_embedding.weight", |
|
|
|
|
"cond_stage_model.model.token_embedding.weight": "cond_stage_model.transformer.text_model.embeddings.token_embedding.weight", |
|
|
|
|
"cond_stage_model.model.ln_final.weight": "cond_stage_model.transformer.text_model.final_layer_norm.weight", |
|
|
|
|
"cond_stage_model.model.ln_final.bias": "cond_stage_model.transformer.text_model.final_layer_norm.bias", |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
for x in keys_to_replace: |
|
|
|
|
if x in sd: |
|
|
|
|
sd[keys_to_replace[x]] = sd.pop(x) |
|
|
|
|
|
|
|
|
|
sd = utils.transformers_convert(sd, "cond_stage_model.model", "cond_stage_model.transformer.text_model", 24) |
|
|
|
|
|
|
|
|
|
for x in load_state_dict_to: |
|
|
|
@ -1073,13 +1062,13 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
|
|
|
|
|
"legacy": False |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
if len(sd['model.diffusion_model.input_blocks.1.1.proj_in.weight'].shape) == 2: |
|
|
|
|
if len(sd['model.diffusion_model.input_blocks.4.1.proj_in.weight'].shape) == 2: |
|
|
|
|
unet_config['use_linear_in_transformer'] = True |
|
|
|
|
|
|
|
|
|
unet_config["use_fp16"] = fp16 |
|
|
|
|
unet_config["model_channels"] = sd['model.diffusion_model.input_blocks.0.0.weight'].shape[0] |
|
|
|
|
unet_config["in_channels"] = sd['model.diffusion_model.input_blocks.0.0.weight'].shape[1] |
|
|
|
|
unet_config["context_dim"] = sd['model.diffusion_model.input_blocks.1.1.transformer_blocks.0.attn2.to_k.weight'].shape[1] |
|
|
|
|
unet_config["context_dim"] = sd['model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn2.to_k.weight'].shape[1] |
|
|
|
|
|
|
|
|
|
sd_config["unet_config"] = {"target": "comfy.ldm.modules.diffusionmodules.openaimodel.UNetModel", "params": unet_config} |
|
|
|
|
model_config = {"target": "comfy.ldm.models.diffusion.ddpm.LatentDiffusion", "params": sd_config} |
|
|
|
@ -1097,10 +1086,10 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
|
|
|
|
|
else: |
|
|
|
|
sd_config["conditioning_key"] = "crossattn" |
|
|
|
|
|
|
|
|
|
if unet_config["context_dim"] == 1024: |
|
|
|
|
unet_config["num_head_channels"] = 64 #SD2.x |
|
|
|
|
else: |
|
|
|
|
if unet_config["context_dim"] == 768: |
|
|
|
|
unet_config["num_heads"] = 8 #SD1.x |
|
|
|
|
else: |
|
|
|
|
unet_config["num_head_channels"] = 64 #SD2.x |
|
|
|
|
|
|
|
|
|
unclip = 'model.diffusion_model.label_emb.0.0.weight' |
|
|
|
|
if unclip in sd_keys: |
|
|
|
|