Browse Source

Support properly saving CosXL checkpoints.

pull/3227/head
comfyanonymous 7 months ago
parent
commit
30abc324c2
  1. 5
      comfy/sd.py
  2. 11
      comfy_extras/nodes_model_merging.py

5
comfy/sd.py

@ -600,7 +600,7 @@ def load_unet(unet_path):
raise RuntimeError("ERROR: Could not detect model type of: {}".format(unet_path)) raise RuntimeError("ERROR: Could not detect model type of: {}".format(unet_path))
return model return model
def save_checkpoint(output_path, model, clip=None, vae=None, clip_vision=None, metadata=None): def save_checkpoint(output_path, model, clip=None, vae=None, clip_vision=None, metadata=None, extra_keys={}):
clip_sd = None clip_sd = None
load_models = [model] load_models = [model]
if clip is not None: if clip is not None:
@ -610,4 +610,7 @@ def save_checkpoint(output_path, model, clip=None, vae=None, clip_vision=None, m
model_management.load_models_gpu(load_models) model_management.load_models_gpu(load_models)
clip_vision_sd = clip_vision.get_sd() if clip_vision is not None else None clip_vision_sd = clip_vision.get_sd() if clip_vision is not None else None
sd = model.model.state_dict_for_saving(clip_sd, vae.get_sd(), clip_vision_sd) sd = model.model.state_dict_for_saving(clip_sd, vae.get_sd(), clip_vision_sd)
for k in extra_keys:
sd[k] = extra_keys[k]
comfy.utils.save_torch_file(sd, output_path, metadata=metadata) comfy.utils.save_torch_file(sd, output_path, metadata=metadata)

11
comfy_extras/nodes_model_merging.py

@ -2,7 +2,9 @@ import comfy.sd
import comfy.utils import comfy.utils
import comfy.model_base import comfy.model_base
import comfy.model_management import comfy.model_management
import comfy.model_sampling
import torch
import folder_paths import folder_paths
import json import json
import os import os
@ -189,6 +191,13 @@ def save_checkpoint(model, clip=None, vae=None, clip_vision=None, filename_prefi
# "stable-diffusion-v2-768-v", "stable-diffusion-v2-unclip-l", "stable-diffusion-v2-unclip-h", # "stable-diffusion-v2-768-v", "stable-diffusion-v2-unclip-l", "stable-diffusion-v2-unclip-h",
# "v2-inpainting" # "v2-inpainting"
extra_keys = {}
model_sampling = model.get_model_object("model_sampling")
if isinstance(model_sampling, comfy.model_sampling.ModelSamplingContinuousEDM):
if isinstance(model_sampling, comfy.model_sampling.V_PREDICTION):
extra_keys["edm_vpred.sigma_max"] = torch.tensor(model_sampling.sigma_max).float()
extra_keys["edm_vpred.sigma_min"] = torch.tensor(model_sampling.sigma_min).float()
if model.model.model_type == comfy.model_base.ModelType.EPS: if model.model.model_type == comfy.model_base.ModelType.EPS:
metadata["modelspec.predict_key"] = "epsilon" metadata["modelspec.predict_key"] = "epsilon"
elif model.model.model_type == comfy.model_base.ModelType.V_PREDICTION: elif model.model.model_type == comfy.model_base.ModelType.V_PREDICTION:
@ -203,7 +212,7 @@ def save_checkpoint(model, clip=None, vae=None, clip_vision=None, filename_prefi
output_checkpoint = f"{filename}_{counter:05}_.safetensors" output_checkpoint = f"{filename}_{counter:05}_.safetensors"
output_checkpoint = os.path.join(full_output_folder, output_checkpoint) output_checkpoint = os.path.join(full_output_folder, output_checkpoint)
comfy.sd.save_checkpoint(output_checkpoint, model, clip, vae, clip_vision, metadata=metadata) comfy.sd.save_checkpoint(output_checkpoint, model, clip, vae, clip_vision, metadata=metadata, extra_keys=extra_keys)
class CheckpointSave: class CheckpointSave:
def __init__(self): def __init__(self):

Loading…
Cancel
Save