|
|
|
@ -2,7 +2,9 @@ import comfy.sd
|
|
|
|
|
import comfy.utils |
|
|
|
|
import comfy.model_base |
|
|
|
|
import comfy.model_management |
|
|
|
|
import comfy.model_sampling |
|
|
|
|
|
|
|
|
|
import torch |
|
|
|
|
import folder_paths |
|
|
|
|
import json |
|
|
|
|
import os |
|
|
|
@ -189,6 +191,13 @@ def save_checkpoint(model, clip=None, vae=None, clip_vision=None, filename_prefi
|
|
|
|
|
# "stable-diffusion-v2-768-v", "stable-diffusion-v2-unclip-l", "stable-diffusion-v2-unclip-h", |
|
|
|
|
# "v2-inpainting" |
|
|
|
|
|
|
|
|
|
extra_keys = {} |
|
|
|
|
model_sampling = model.get_model_object("model_sampling") |
|
|
|
|
if isinstance(model_sampling, comfy.model_sampling.ModelSamplingContinuousEDM): |
|
|
|
|
if isinstance(model_sampling, comfy.model_sampling.V_PREDICTION): |
|
|
|
|
extra_keys["edm_vpred.sigma_max"] = torch.tensor(model_sampling.sigma_max).float() |
|
|
|
|
extra_keys["edm_vpred.sigma_min"] = torch.tensor(model_sampling.sigma_min).float() |
|
|
|
|
|
|
|
|
|
if model.model.model_type == comfy.model_base.ModelType.EPS: |
|
|
|
|
metadata["modelspec.predict_key"] = "epsilon" |
|
|
|
|
elif model.model.model_type == comfy.model_base.ModelType.V_PREDICTION: |
|
|
|
@ -203,7 +212,7 @@ def save_checkpoint(model, clip=None, vae=None, clip_vision=None, filename_prefi
|
|
|
|
|
output_checkpoint = f"{filename}_{counter:05}_.safetensors" |
|
|
|
|
output_checkpoint = os.path.join(full_output_folder, output_checkpoint) |
|
|
|
|
|
|
|
|
|
comfy.sd.save_checkpoint(output_checkpoint, model, clip, vae, clip_vision, metadata=metadata) |
|
|
|
|
comfy.sd.save_checkpoint(output_checkpoint, model, clip, vae, clip_vision, metadata=metadata, extra_keys=extra_keys) |
|
|
|
|
|
|
|
|
|
class CheckpointSave: |
|
|
|
|
def __init__(self): |
|
|
|
|