diff --git a/comfy/sd.py b/comfy/sd.py index 85821120..3919f4bf 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -600,7 +600,7 @@ def load_unet(unet_path): raise RuntimeError("ERROR: Could not detect model type of: {}".format(unet_path)) return model -def save_checkpoint(output_path, model, clip=None, vae=None, clip_vision=None, metadata=None): +def save_checkpoint(output_path, model, clip=None, vae=None, clip_vision=None, metadata=None, extra_keys={}): clip_sd = None load_models = [model] if clip is not None: @@ -610,4 +610,7 @@ def save_checkpoint(output_path, model, clip=None, vae=None, clip_vision=None, m model_management.load_models_gpu(load_models) clip_vision_sd = clip_vision.get_sd() if clip_vision is not None else None sd = model.model.state_dict_for_saving(clip_sd, vae.get_sd(), clip_vision_sd) + for k in extra_keys: + sd[k] = extra_keys[k] + comfy.utils.save_torch_file(sd, output_path, metadata=metadata) diff --git a/comfy_extras/nodes_model_merging.py b/comfy_extras/nodes_model_merging.py index a25b73ca..2a431f65 100644 --- a/comfy_extras/nodes_model_merging.py +++ b/comfy_extras/nodes_model_merging.py @@ -2,7 +2,9 @@ import comfy.sd import comfy.utils import comfy.model_base import comfy.model_management +import comfy.model_sampling +import torch import folder_paths import json import os @@ -189,6 +191,13 @@ def save_checkpoint(model, clip=None, vae=None, clip_vision=None, filename_prefi # "stable-diffusion-v2-768-v", "stable-diffusion-v2-unclip-l", "stable-diffusion-v2-unclip-h", # "v2-inpainting" + extra_keys = {} + model_sampling = model.get_model_object("model_sampling") + if isinstance(model_sampling, comfy.model_sampling.ModelSamplingContinuousEDM): + if isinstance(model_sampling, comfy.model_sampling.V_PREDICTION): + extra_keys["edm_vpred.sigma_max"] = torch.tensor(model_sampling.sigma_max).float() + extra_keys["edm_vpred.sigma_min"] = torch.tensor(model_sampling.sigma_min).float() + if model.model.model_type == comfy.model_base.ModelType.EPS: metadata["modelspec.predict_key"] = "epsilon" elif model.model.model_type == comfy.model_base.ModelType.V_PREDICTION: @@ -203,7 +212,7 @@ def save_checkpoint(model, clip=None, vae=None, clip_vision=None, filename_prefi output_checkpoint = f"{filename}_{counter:05}_.safetensors" output_checkpoint = os.path.join(full_output_folder, output_checkpoint) - comfy.sd.save_checkpoint(output_checkpoint, model, clip, vae, clip_vision, metadata=metadata) + comfy.sd.save_checkpoint(output_checkpoint, model, clip, vae, clip_vision, metadata=metadata, extra_keys=extra_keys) class CheckpointSave: def __init__(self):