@ -418,6 +418,8 @@ def load_gligen(ckpt_path):
return comfy . model_patcher . ModelPatcher ( model , load_device = model_management . get_torch_device ( ) , offload_device = model_management . unet_offload_device ( ) )
def load_checkpoint ( config_path = None , ckpt_path = None , output_vae = True , output_clip = True , embedding_directory = None , state_dict = None , config = None ) :
logging . warning ( " Warning: The load checkpoint with config function is deprecated and will eventually be removed, please use the other one. " )
model , clip , vae , _ = load_checkpoint_guess_config ( ckpt_path , output_vae = output_vae , output_clip = output_clip , output_clipvision = False , embedding_directory = embedding_directory , output_model = True )
#TODO: this function is a mess and should be removed eventually
if config is None :
with open ( config_path , ' r ' ) as stream :
@ -425,81 +427,20 @@ def load_checkpoint(config_path=None, ckpt_path=None, output_vae=True, output_cl
model_config_params = config [ ' model ' ] [ ' params ' ]
clip_config = model_config_params [ ' cond_stage_config ' ]
scale_factor = model_config_params [ ' scale_factor ' ]
vae_config = model_config_params [ ' first_stage_config ' ]
fp16 = False
if " unet_config " in model_config_params :
if " params " in model_config_params [ " unet_config " ] :
unet_config = model_config_params [ " unet_config " ] [ " params " ]
if " use_fp16 " in unet_config :
fp16 = unet_config . pop ( " use_fp16 " )
if fp16 :
unet_config [ " dtype " ] = torch . float16
noise_aug_config = None
if " noise_aug_config " in model_config_params :
noise_aug_config = model_config_params [ " noise_aug_config " ]
model_type = model_base . ModelType . EPS
if " parameterization " in model_config_params :
if model_config_params [ " parameterization " ] == " v " :
model_type = model_base . ModelType . V_PREDICTION
clip = None
vae = None
class WeightsLoader ( torch . nn . Module ) :
pass
m = model . clone ( )
class ModelSamplingAdvanced ( comfy . model_sampling . ModelSamplingDiscrete , comfy . model_sampling . V_PREDICTION ) :
pass
m . add_object_patch ( " model_sampling " , ModelSamplingAdvanced ( model . model . model_config ) )
model = m
if state_dict is None :
state_dict = comfy . utils . load_torch_file ( ckpt_path )
class EmptyClass :
pass
model_config = comfy . supported_models_base . BASE ( { } )
from . import latent_formats
model_config . latent_format = latent_formats . SD15 ( scale_factor = scale_factor )
model_config . unet_config = model_detection . convert_config ( unet_config )
if config [ ' model ' ] [ " target " ] . endswith ( " ImageEmbeddingConditionedLatentDiffusion " ) :
model = model_base . SD21UNCLIP ( model_config , noise_aug_config [ " params " ] , model_type = model_type )
else :
model = model_base . BaseModel ( model_config , model_type = model_type )
if config [ ' model ' ] [ " target " ] . endswith ( " LatentInpaintDiffusion " ) :
model . set_inpaint ( )
if fp16 :
model = model . half ( )
offload_device = model_management . unet_offload_device ( )
model = model . to ( offload_device )
model . load_model_weights ( state_dict , " model.diffusion_model. " )
if output_vae :
vae_sd = comfy . utils . state_dict_prefix_replace ( state_dict , { " first_stage_model. " : " " } , filter_keys = True )
vae = VAE ( sd = vae_sd , config = vae_config )
if output_clip :
w = WeightsLoader ( )
clip_target = EmptyClass ( )
clip_target . params = clip_config . get ( " params " , { } )
if clip_config [ " target " ] . endswith ( " FrozenOpenCLIPEmbedder " ) :
clip_target . clip = sd2_clip . SD2ClipModel
clip_target . tokenizer = sd2_clip . SD2Tokenizer
clip = CLIP ( clip_target , embedding_directory = embedding_directory )
w . cond_stage_model = clip . cond_stage_model . clip_h
elif clip_config [ " target " ] . endswith ( " FrozenCLIPEmbedder " ) :
clip_target . clip = sd1_clip . SD1ClipModel
clip_target . tokenizer = sd1_clip . SD1Tokenizer
clip = CLIP ( clip_target , embedding_directory = embedding_directory )
w . cond_stage_model = clip . cond_stage_model . clip_l
load_clip_weights ( w , state_dict )
layer_idx = clip_config . get ( " params " , { } ) . get ( " layer_idx " , None )
if layer_idx is not None :
clip . clip_layer ( layer_idx )
return ( comfy . model_patcher . ModelPatcher ( model , load_device = model_management . get_torch_device ( ) , offload_device = offload_device ) , clip , vae )
return ( model , clip , vae )
def load_checkpoint_guess_config ( ckpt_path , output_vae = True , output_clip = True , output_clipvision = False , embedding_directory = None , output_model = True ) :
sd = comfy . utils . load_torch_file ( ckpt_path )