diff --git a/comfy/sd.py b/comfy/sd.py index 16dc0b73..ceb080b3 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -418,6 +418,8 @@ def load_gligen(ckpt_path): return comfy.model_patcher.ModelPatcher(model, load_device=model_management.get_torch_device(), offload_device=model_management.unet_offload_device()) def load_checkpoint(config_path=None, ckpt_path=None, output_vae=True, output_clip=True, embedding_directory=None, state_dict=None, config=None): + logging.warning("Warning: The load checkpoint with config function is deprecated and will eventually be removed, please use the other one.") + model, clip, vae, _ = load_checkpoint_guess_config(ckpt_path, output_vae=output_vae, output_clip=output_clip, output_clipvision=False, embedding_directory=embedding_directory, output_model=True) #TODO: this function is a mess and should be removed eventually if config is None: with open(config_path, 'r') as stream: @@ -425,81 +427,20 @@ def load_checkpoint(config_path=None, ckpt_path=None, output_vae=True, output_cl model_config_params = config['model']['params'] clip_config = model_config_params['cond_stage_config'] scale_factor = model_config_params['scale_factor'] - vae_config = model_config_params['first_stage_config'] - - fp16 = False - if "unet_config" in model_config_params: - if "params" in model_config_params["unet_config"]: - unet_config = model_config_params["unet_config"]["params"] - if "use_fp16" in unet_config: - fp16 = unet_config.pop("use_fp16") - if fp16: - unet_config["dtype"] = torch.float16 - - noise_aug_config = None - if "noise_aug_config" in model_config_params: - noise_aug_config = model_config_params["noise_aug_config"] - - model_type = model_base.ModelType.EPS if "parameterization" in model_config_params: if model_config_params["parameterization"] == "v": - model_type = model_base.ModelType.V_PREDICTION - - clip = None - vae = None - - class WeightsLoader(torch.nn.Module): - pass + m = model.clone() + class ModelSamplingAdvanced(comfy.model_sampling.ModelSamplingDiscrete, comfy.model_sampling.V_PREDICTION): + pass + m.add_object_patch("model_sampling", ModelSamplingAdvanced(model.model.model_config)) + model = m - if state_dict is None: - state_dict = comfy.utils.load_torch_file(ckpt_path) - - class EmptyClass: - pass - - model_config = comfy.supported_models_base.BASE({}) - - from . import latent_formats - model_config.latent_format = latent_formats.SD15(scale_factor=scale_factor) - model_config.unet_config = model_detection.convert_config(unet_config) - - if config['model']["target"].endswith("ImageEmbeddingConditionedLatentDiffusion"): - model = model_base.SD21UNCLIP(model_config, noise_aug_config["params"], model_type=model_type) - else: - model = model_base.BaseModel(model_config, model_type=model_type) - - if config['model']["target"].endswith("LatentInpaintDiffusion"): - model.set_inpaint() - - if fp16: - model = model.half() - - offload_device = model_management.unet_offload_device() - model = model.to(offload_device) - model.load_model_weights(state_dict, "model.diffusion_model.") - - if output_vae: - vae_sd = comfy.utils.state_dict_prefix_replace(state_dict, {"first_stage_model.": ""}, filter_keys=True) - vae = VAE(sd=vae_sd, config=vae_config) - - if output_clip: - w = WeightsLoader() - clip_target = EmptyClass() - clip_target.params = clip_config.get("params", {}) - if clip_config["target"].endswith("FrozenOpenCLIPEmbedder"): - clip_target.clip = sd2_clip.SD2ClipModel - clip_target.tokenizer = sd2_clip.SD2Tokenizer - clip = CLIP(clip_target, embedding_directory=embedding_directory) - w.cond_stage_model = clip.cond_stage_model.clip_h - elif clip_config["target"].endswith("FrozenCLIPEmbedder"): - clip_target.clip = sd1_clip.SD1ClipModel - clip_target.tokenizer = sd1_clip.SD1Tokenizer - clip = CLIP(clip_target, embedding_directory=embedding_directory) - w.cond_stage_model = clip.cond_stage_model.clip_l - load_clip_weights(w, state_dict) + layer_idx = clip_config.get("params", {}).get("layer_idx", None) + if layer_idx is not None: + clip.clip_layer(layer_idx) - return (comfy.model_patcher.ModelPatcher(model, load_device=model_management.get_torch_device(), offload_device=offload_device), clip, vae) + return (model, clip, vae) def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, output_clipvision=False, embedding_directory=None, output_model=True): sd = comfy.utils.load_torch_file(ckpt_path)