Browse Source

Add CheckpointSave node to save checkpoints.

The created checkpoints contain workflow metadata that can be loaded by
dragging them on top of the UI or loading them with the "Load" button.

Checkpoints will be saved in fp16 or fp32 depending on the format ComfyUI
is using for inference on your hardware. To force fp32 use: --force-fp32

Anything that patches the model weights like merging or loras will be
saved.

The output directory is currently set to: output/checkpoints but that might
change in the future.
pull/809/head
comfyanonymous 1 year ago
parent
commit
9b93b920be
  1. 4
      comfy/diffusers_convert.py
  2. 12
      comfy/model_base.py
  3. 32
      comfy/sd.py
  4. 29
      comfy/supported_models.py
  5. 12
      comfy/supported_models_base.py
  6. 14
      comfy/utils.py
  7. 44
      comfy_extras/nodes_model_merging.py
  8. 3
      nodes.py
  9. 1
      notebooks/comfyui_colab.ipynb
  10. 2
      web/scripts/app.js
  11. 5
      web/scripts/pnginfo.js
  12. 2
      web/scripts/ui.js

4
comfy/diffusers_convert.py

@ -202,11 +202,13 @@ textenc_pattern = re.compile("|".join(protected.keys()))
code2idx = {"q": 0, "k": 1, "v": 2}
def convert_text_enc_state_dict_v20(text_enc_dict):
def convert_text_enc_state_dict_v20(text_enc_dict, prefix=""):
new_state_dict = {}
capture_qkv_weight = {}
capture_qkv_bias = {}
for k, v in text_enc_dict.items():
if not k.startswith(prefix):
continue
if (
k.endswith(".self_attn.q_proj.weight")
or k.endswith(".self_attn.k_proj.weight")

12
comfy/model_base.py

@ -4,6 +4,7 @@ from comfy.ldm.modules.encoders.noise_aug_modules import CLIPEmbeddingNoiseAugme
from comfy.ldm.modules.diffusionmodules.util import make_beta_schedule
from comfy.ldm.modules.diffusionmodules.openaimodel import Timestep
import numpy as np
from . import utils
class BaseModel(torch.nn.Module):
def __init__(self, model_config, v_prediction=False):
@ -11,6 +12,7 @@ class BaseModel(torch.nn.Module):
unet_config = model_config.unet_config
self.latent_format = model_config.latent_format
self.model_config = model_config
self.register_schedule(given_betas=None, beta_schedule="linear", timesteps=1000, linear_start=0.00085, linear_end=0.012, cosine_s=8e-3)
self.diffusion_model = UNetModel(**unet_config)
self.v_prediction = v_prediction
@ -83,6 +85,16 @@ class BaseModel(torch.nn.Module):
def process_latent_out(self, latent):
return self.latent_format.process_out(latent)
def state_dict_for_saving(self, clip_state_dict, vae_state_dict):
clip_state_dict = self.model_config.process_clip_state_dict_for_saving(clip_state_dict)
unet_state_dict = self.diffusion_model.state_dict()
unet_state_dict = self.model_config.process_unet_state_dict_for_saving(unet_state_dict)
vae_state_dict = self.model_config.process_vae_state_dict_for_saving(vae_state_dict)
if self.get_dtype() == torch.float16:
clip_state_dict = utils.convert_sd_to(clip_state_dict, torch.float16)
vae_state_dict = utils.convert_sd_to(vae_state_dict, torch.float16)
return {**unet_state_dict, **vae_state_dict, **clip_state_dict}
class SD21UNCLIP(BaseModel):
def __init__(self, model_config, noise_aug_config, v_prediction=True):

32
comfy/sd.py

@ -545,11 +545,11 @@ class CLIP:
if self.layer_idx is not None:
self.cond_stage_model.clip_layer(self.layer_idx)
try:
self.patcher.patch_model()
self.patch_model()
cond, pooled = self.cond_stage_model.encode_token_weights(tokens)
self.patcher.unpatch_model()
self.unpatch_model()
except Exception as e:
self.patcher.unpatch_model()
self.unpatch_model()
raise e
cond_out = cond
@ -564,6 +564,15 @@ class CLIP:
def load_sd(self, sd):
return self.cond_stage_model.load_sd(sd)
def get_sd(self):
return self.cond_stage_model.state_dict()
def patch_model(self):
self.patcher.patch_model()
def unpatch_model(self):
self.patcher.unpatch_model()
class VAE:
def __init__(self, ckpt_path=None, device=None, config=None):
if config is None:
@ -665,6 +674,10 @@ class VAE:
self.first_stage_model = self.first_stage_model.cpu()
return samples
def get_sd(self):
return self.first_stage_model.state_dict()
def broadcast_image_to(tensor, target_batch_size, batched_number):
current_batch_size = tensor.shape[0]
#print(current_batch_size, target_batch_size)
@ -1135,3 +1148,16 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
print("left over keys:", left_over)
return (ModelPatcher(model), clip, vae, clipvision)
def save_checkpoint(output_path, model, clip, vae, metadata=None):
try:
model.patch_model()
clip.patch_model()
sd = model.model.state_dict_for_saving(clip.get_sd(), vae.get_sd())
utils.save_torch_file(sd, output_path, metadata=metadata)
model.unpatch_model()
clip.unpatch_model()
except Exception as e:
model.unpatch_model()
clip.unpatch_model()
raise e

29
comfy/supported_models.py

@ -9,6 +9,8 @@ from . import sdxl_clip
from . import supported_models_base
from . import latent_formats
from . import diffusers_convert
class SD15(supported_models_base.BASE):
unet_config = {
"context_dim": 768,
@ -63,6 +65,13 @@ class SD20(supported_models_base.BASE):
state_dict = utils.transformers_convert(state_dict, "cond_stage_model.model.", "cond_stage_model.transformer.text_model.", 24)
return state_dict
def process_clip_state_dict_for_saving(self, state_dict):
replace_prefix = {}
replace_prefix[""] = "cond_stage_model.model."
state_dict = supported_models_base.state_dict_prefix_replace(state_dict, replace_prefix)
state_dict = diffusers_convert.convert_text_enc_state_dict_v20(state_dict)
return state_dict
def clip_target(self):
return supported_models_base.ClipTarget(sd2_clip.SD2Tokenizer, sd2_clip.SD2ClipModel)
@ -113,6 +122,13 @@ class SDXLRefiner(supported_models_base.BASE):
state_dict = supported_models_base.state_dict_key_replace(state_dict, keys_to_replace)
return state_dict
def process_clip_state_dict_for_saving(self, state_dict):
replace_prefix = {}
state_dict_g = diffusers_convert.convert_text_enc_state_dict_v20(state_dict, "clip_g")
replace_prefix["clip_g"] = "conditioner.embedders.0.model"
state_dict_g = supported_models_base.state_dict_prefix_replace(state_dict_g, replace_prefix)
return state_dict_g
def clip_target(self):
return supported_models_base.ClipTarget(sdxl_clip.SDXLTokenizer, sdxl_clip.SDXLRefinerClipModel)
@ -142,6 +158,19 @@ class SDXL(supported_models_base.BASE):
state_dict = supported_models_base.state_dict_key_replace(state_dict, keys_to_replace)
return state_dict
def process_clip_state_dict_for_saving(self, state_dict):
replace_prefix = {}
keys_to_replace = {}
state_dict_g = diffusers_convert.convert_text_enc_state_dict_v20(state_dict, "clip_g")
for k in state_dict:
if k.startswith("clip_l"):
state_dict_g[k] = state_dict[k]
replace_prefix["clip_g"] = "conditioner.embedders.1.model"
replace_prefix["clip_l"] = "conditioner.embedders.0"
state_dict_g = supported_models_base.state_dict_prefix_replace(state_dict_g, replace_prefix)
return state_dict_g
def clip_target(self):
return supported_models_base.ClipTarget(sdxl_clip.SDXLTokenizer, sdxl_clip.SDXLClipModel)

12
comfy/supported_models_base.py

@ -64,3 +64,15 @@ class BASE:
def process_clip_state_dict(self, state_dict):
return state_dict
def process_clip_state_dict_for_saving(self, state_dict):
replace_prefix = {"": "cond_stage_model."}
return state_dict_prefix_replace(state_dict, replace_prefix)
def process_unet_state_dict_for_saving(self, state_dict):
replace_prefix = {"": "model.diffusion_model."}
return state_dict_prefix_replace(state_dict, replace_prefix)
def process_vae_state_dict_for_saving(self, state_dict):
replace_prefix = {"": "first_stage_model."}
return state_dict_prefix_replace(state_dict, replace_prefix)

14
comfy/utils.py

@ -2,10 +2,10 @@ import torch
import math
import struct
import comfy.checkpoint_pickle
import safetensors.torch
def load_torch_file(ckpt, safe_load=False):
if ckpt.lower().endswith(".safetensors"):
import safetensors.torch
sd = safetensors.torch.load_file(ckpt, device="cpu")
else:
if safe_load:
@ -24,6 +24,12 @@ def load_torch_file(ckpt, safe_load=False):
sd = pl_sd
return sd
def save_torch_file(sd, ckpt, metadata=None):
if metadata is not None:
safetensors.torch.save_file(sd, ckpt, metadata=metadata)
else:
safetensors.torch.save_file(sd, ckpt)
def transformers_convert(sd, prefix_from, prefix_to, number):
keys_to_replace = {
"{}positional_embedding": "{}embeddings.position_embedding.weight",
@ -64,6 +70,12 @@ def transformers_convert(sd, prefix_from, prefix_to, number):
sd[k_to] = weights[shape_from*x:shape_from*(x + 1)]
return sd
def convert_sd_to(state_dict, dtype):
keys = list(state_dict.keys())
for k in keys:
state_dict[k] = state_dict[k].to(dtype)
return state_dict
def safetensors_header(safetensors_path, max_size=100*1024*1024):
with open(safetensors_path, "rb") as f:
header = f.read(8)

44
comfy_extras/nodes_model_merging.py

@ -1,4 +1,8 @@
import comfy.sd
import comfy.utils
import folder_paths
import json
import os
class ModelMergeSimple:
@classmethod
@ -49,7 +53,43 @@ class ModelMergeBlocks:
m.add_patches({k: (sd[k], )}, 1.0 - ratio, ratio)
return (m, )
class CheckpointSave:
def __init__(self):
self.output_dir = folder_paths.get_output_directory()
@classmethod
def INPUT_TYPES(s):
return {"required": { "model": ("MODEL",),
"clip": ("CLIP",),
"vae": ("VAE",),
"filename_prefix": ("STRING", {"default": "checkpoints/ComfyUI"}),},
"hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"},}
RETURN_TYPES = ()
FUNCTION = "save"
OUTPUT_NODE = True
CATEGORY = "_for_testing/model_merging"
def save(self, model, clip, vae, filename_prefix, prompt=None, extra_pnginfo=None):
full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, self.output_dir)
prompt_info = ""
if prompt is not None:
prompt_info = json.dumps(prompt)
metadata = {"prompt": prompt_info}
if extra_pnginfo is not None:
for x in extra_pnginfo:
metadata[x] = json.dumps(extra_pnginfo[x])
output_checkpoint = f"{filename}_{counter:05}_.safetensors"
output_checkpoint = os.path.join(full_output_folder, output_checkpoint)
comfy.sd.save_checkpoint(output_checkpoint, model, clip, vae, metadata=metadata)
return {}
NODE_CLASS_MAPPINGS = {
"ModelMergeSimple": ModelMergeSimple,
"ModelMergeBlocks": ModelMergeBlocks
"ModelMergeBlocks": ModelMergeBlocks,
"CheckpointSave": CheckpointSave,
}

3
nodes.py

@ -286,8 +286,7 @@ class SaveLatent:
output["latent_tensor"] = samples["samples"]
output["latent_format_version_0"] = torch.tensor([])
safetensors.torch.save_file(output, file, metadata=metadata)
comfy.utils.save_torch_file(output, file, metadata=metadata)
return {}

1
notebooks/comfyui_colab.ipynb

@ -144,6 +144,7 @@
"\n",
"\n",
"# ESRGAN upscale model\n",
"#!wget -c https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth -P ./models/upscale_models/\n",
"#!wget -c https://huggingface.co/sberbank-ai/Real-ESRGAN/resolve/main/RealESRGAN_x2.pth -P ./models/upscale_models/\n",
"#!wget -c https://huggingface.co/sberbank-ai/Real-ESRGAN/resolve/main/RealESRGAN_x4.pth -P ./models/upscale_models/\n",
"\n",

2
web/scripts/app.js

@ -1468,7 +1468,7 @@ export class ComfyApp {
this.loadGraphData(JSON.parse(reader.result));
};
reader.readAsText(file);
} else if (file.name?.endsWith(".latent")) {
} else if (file.name?.endsWith(".latent") || file.name?.endsWith(".safetensors")) {
const info = await getLatentMetadata(file);
if (info.workflow) {
this.loadGraphData(JSON.parse(info.workflow));

5
web/scripts/pnginfo.js

@ -55,11 +55,12 @@ export function getLatentMetadata(file) {
const dataView = new DataView(safetensorsData.buffer);
let header_size = dataView.getUint32(0, true);
let offset = 8;
let header = JSON.parse(String.fromCharCode(...safetensorsData.slice(offset, offset + header_size)));
let header = JSON.parse(new TextDecoder().decode(safetensorsData.slice(offset, offset + header_size)));
r(header.__metadata__);
};
reader.readAsArrayBuffer(file);
var slice = file.slice(0, 1024 * 1024 * 4);
reader.readAsArrayBuffer(slice);
});
}

2
web/scripts/ui.js

@ -545,7 +545,7 @@ export class ComfyUI {
const fileInput = $el("input", {
id: "comfy-file-input",
type: "file",
accept: ".json,image/png,.latent",
accept: ".json,image/png,.latent,.safetensors",
style: {display: "none"},
parent: document.body,
onchange: () => {

Loading…
Cancel
Save