@ -4,6 +4,7 @@ from enum import Enum
from comfy import model_management
from . ldm . models . autoencoder import AutoencoderKL , AutoencodingEngine
from . ldm . cascade . stage_a import StageA
from . ldm . cascade . stage_c_coder import StageC_coder
import yaml
@ -51,7 +52,7 @@ def load_clip_weights(model, sd):
if ids . dtype == torch . float32 :
sd [ ' cond_stage_model.transformer.text_model.embeddings.position_ids ' ] = ids . round ( )
sd = comfy . utils . transformers_convert ( sd , " cond_stage_model.model. " , " cond_stage_model.transformer.text_model. " , 24 )
sd = comfy . utils . clip_text_ transformers_convert( sd , " cond_stage_model.model. " , " cond_stage_model.transformer. " )
return load_model_weights ( model , sd )
@ -122,10 +123,13 @@ class CLIP:
return self . tokenizer . tokenize_with_weights ( text , return_word_ids )
def encode_from_tokens ( self , tokens , return_pooled = False ) :
self . cond_stage_model . reset_clip_options ( )
if self . layer_idx is not None :
self . cond_stage_model . clip_layer ( self . layer_idx )
else :
self . cond_stage_model . reset_clip_layer ( )
self . cond_stage_model . set_clip_options ( { " layer " : self . layer_idx } )
if return_pooled == " unprojected " :
self . cond_stage_model . set_clip_options ( { " projected_pooled " : False } )
self . load_model ( )
cond , pooled = self . cond_stage_model . encode_token_weights ( tokens )
@ -137,8 +141,11 @@ class CLIP:
tokens = self . tokenize ( text )
return self . encode_from_tokens ( tokens )
def load_sd ( self , sd ) :
return self . cond_stage_model . load_sd ( sd )
def load_sd ( self , sd , full_model = False ) :
if full_model :
return self . cond_stage_model . load_state_dict ( sd , strict = False )
else :
return self . cond_stage_model . load_sd ( sd )
def get_sd ( self ) :
return self . cond_stage_model . state_dict ( )
@ -158,6 +165,7 @@ class VAE:
self . memory_used_encode = lambda shape , dtype : ( 1767 * shape [ 2 ] * shape [ 3 ] ) * model_management . dtype_size ( dtype ) #These are for AutoencoderKL and need tweaking (should be lower)
self . memory_used_decode = lambda shape , dtype : ( 2178 * shape [ 2 ] * shape [ 3 ] * 64 ) * model_management . dtype_size ( dtype )
self . downscale_ratio = 8
self . upscale_ratio = 8
self . latent_channels = 4
self . process_input = lambda image : image * 2.0 - 1.0
self . process_output = lambda image : torch . clamp ( ( image + 1.0 ) / 2.0 , min = 0.0 , max = 1.0 )
@ -176,11 +184,31 @@ class VAE:
elif " vquantizer.codebook.weight " in sd : #VQGan: stage a of stable cascade
self . first_stage_model = StageA ( )
self . downscale_ratio = 4
self . upscale_ratio = 4
#TODO
#self.memory_used_encode
#self.memory_used_decode
self . process_input = lambda image : image
self . process_output = lambda image : image
elif " backbone.1.0.block.0.1.num_batches_tracked " in sd : #effnet: encoder for stage c latent of stable cascade
self . first_stage_model = StageC_coder ( )
self . downscale_ratio = 32
self . latent_channels = 16
new_sd = { }
for k in sd :
new_sd [ " encoder. {} " . format ( k ) ] = sd [ k ]
sd = new_sd
elif " blocks.11.num_batches_tracked " in sd : #previewer: decoder for stage c latent of stable cascade
self . first_stage_model = StageC_coder ( )
self . latent_channels = 16
new_sd = { }
for k in sd :
new_sd [ " previewer. {} " . format ( k ) ] = sd [ k ]
sd = new_sd
elif " encoder.backbone.1.0.block.0.1.num_batches_tracked " in sd : #combined effnet and previewer for stable cascade
self . first_stage_model = StageC_coder ( )
self . downscale_ratio = 32
self . latent_channels = 16
else :
#default SD1.x/SD2.x VAE parameters
ddconfig = { ' double_z ' : True , ' z_channels ' : 4 , ' resolution ' : 256 , ' in_channels ' : 3 , ' out_ch ' : 3 , ' ch ' : 128 , ' ch_mult ' : [ 1 , 2 , 4 , 4 ] , ' num_res_blocks ' : 2 , ' attn_resolutions ' : [ ] , ' dropout ' : 0.0 }
@ -188,6 +216,7 @@ class VAE:
if ' encoder.down.2.downsample.conv.weight ' not in sd : #Stable diffusion x4 upscaler VAE
ddconfig [ ' ch_mult ' ] = [ 1 , 2 , 4 ]
self . downscale_ratio = 4
self . upscale_ratio = 4
self . first_stage_model = AutoencoderKL ( ddconfig = ddconfig , embed_dim = 4 )
else :
@ -213,6 +242,15 @@ class VAE:
self . patcher = comfy . model_patcher . ModelPatcher ( self . first_stage_model , load_device = self . device , offload_device = offload_device )
def vae_encode_crop_pixels ( self , pixels ) :
x = ( pixels . shape [ 1 ] / / self . downscale_ratio ) * self . downscale_ratio
y = ( pixels . shape [ 2 ] / / self . downscale_ratio ) * self . downscale_ratio
if pixels . shape [ 1 ] != x or pixels . shape [ 2 ] != y :
x_offset = ( pixels . shape [ 1 ] % self . downscale_ratio ) / / 2
y_offset = ( pixels . shape [ 2 ] % self . downscale_ratio ) / / 2
pixels = pixels [ : , x_offset : x + x_offset , y_offset : y + y_offset , : ]
return pixels
def decode_tiled_ ( self , samples , tile_x = 64 , tile_y = 64 , overlap = 16 ) :
steps = samples . shape [ 0 ] * comfy . utils . get_tiled_scale_steps ( samples . shape [ 3 ] , samples . shape [ 2 ] , tile_x , tile_y , overlap )
steps + = samples . shape [ 0 ] * comfy . utils . get_tiled_scale_steps ( samples . shape [ 3 ] , samples . shape [ 2 ] , tile_x / / 2 , tile_y * 2 , overlap )
@ -221,9 +259,9 @@ class VAE:
decode_fn = lambda a : self . first_stage_model . decode ( a . to ( self . vae_dtype ) . to ( self . device ) ) . float ( )
output = self . process_output (
( comfy . utils . tiled_scale ( samples , decode_fn , tile_x / / 2 , tile_y * 2 , overlap , upscale_amount = self . down scale_ratio, output_device = self . output_device , pbar = pbar ) +
comfy . utils . tiled_scale ( samples , decode_fn , tile_x * 2 , tile_y / / 2 , overlap , upscale_amount = self . down scale_ratio, output_device = self . output_device , pbar = pbar ) +
comfy . utils . tiled_scale ( samples , decode_fn , tile_x , tile_y , overlap , upscale_amount = self . down scale_ratio, output_device = self . output_device , pbar = pbar ) )
( comfy . utils . tiled_scale ( samples , decode_fn , tile_x / / 2 , tile_y * 2 , overlap , upscale_amount = self . up scale_ratio, output_device = self . output_device , pbar = pbar ) +
comfy . utils . tiled_scale ( samples , decode_fn , tile_x * 2 , tile_y / / 2 , overlap , upscale_amount = self . up scale_ratio, output_device = self . output_device , pbar = pbar ) +
comfy . utils . tiled_scale ( samples , decode_fn , tile_x , tile_y , overlap , upscale_amount = self . up scale_ratio, output_device = self . output_device , pbar = pbar ) )
/ 3.0 )
return output
@ -248,7 +286,7 @@ class VAE:
batch_number = int ( free_memory / memory_used )
batch_number = max ( 1 , batch_number )
pixel_samples = torch . empty ( ( samples_in . shape [ 0 ] , 3 , round ( samples_in . shape [ 2 ] * self . down scale_ratio) , round ( samples_in . shape [ 3 ] * self . down scale_ratio) ) , device = self . output_device )
pixel_samples = torch . empty ( ( samples_in . shape [ 0 ] , 3 , round ( samples_in . shape [ 2 ] * self . up scale_ratio) , round ( samples_in . shape [ 3 ] * self . up scale_ratio) ) , device = self . output_device )
for x in range ( 0 , samples_in . shape [ 0 ] , batch_number ) :
samples = samples_in [ x : x + batch_number ] . to ( self . vae_dtype ) . to ( self . device )
pixel_samples [ x : x + batch_number ] = self . process_output ( self . first_stage_model . decode ( samples ) . to ( self . output_device ) . float ( ) )
@ -265,6 +303,7 @@ class VAE:
return output . movedim ( 1 , - 1 )
def encode ( self , pixel_samples ) :
pixel_samples = self . vae_encode_crop_pixels ( pixel_samples )
pixel_samples = pixel_samples . movedim ( - 1 , 1 )
try :
memory_used = self . memory_used_encode ( pixel_samples . shape , self . vae_dtype )
@ -284,6 +323,7 @@ class VAE:
return samples
def encode_tiled ( self , pixel_samples , tile_x = 512 , tile_y = 512 , overlap = 64 ) :
pixel_samples = self . vae_encode_crop_pixels ( pixel_samples )
model_management . load_model_gpu ( self . patcher )
pixel_samples = pixel_samples . movedim ( - 1 , 1 )
samples = self . encode_tiled_ ( pixel_samples , tile_x = tile_x , tile_y = tile_y , overlap = overlap )
@ -324,7 +364,10 @@ def load_clip(ckpt_paths, embedding_directory=None, clip_type=CLIPType.STABLE_DI
for i in range ( len ( clip_data ) ) :
if " transformer.resblocks.0.ln_1.weight " in clip_data [ i ] :
clip_data [ i ] = comfy . utils . transformers_convert ( clip_data [ i ] , " " , " text_model. " , 32 )
clip_data [ i ] = comfy . utils . clip_text_transformers_convert ( clip_data [ i ] , " " , " " )
else :
if " text_projection " in clip_data [ i ] :
clip_data [ i ] [ " text_projection.weight " ] = clip_data [ i ] [ " text_projection " ] . transpose ( 0 , 1 ) #old models saved with the CLIPSave node
clip_target = EmptyClass ( )
clip_target . params = { }
@ -460,9 +503,6 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
parameters = comfy . utils . calculate_parameters ( sd , " model.diffusion_model. " )
load_device = model_management . get_torch_device ( )
class WeightsLoader ( torch . nn . Module ) :
pass
model_config = model_detection . model_config_from_unet ( sd , " model.diffusion_model. " )
unet_dtype = model_management . unet_dtype ( model_params = parameters , supported_dtypes = model_config . supported_inference_dtypes )
manual_cast_dtype = model_management . unet_manual_cast ( unet_dtype , load_device , model_config . supported_inference_dtypes )
@ -487,14 +527,17 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
vae = VAE ( sd = vae_sd )
if output_clip :
w = WeightsLoader ( )
clip_target = model_config . clip_target ( )
if clip_target is not None :
sd = model_config . process_clip_state_dict ( sd )
if any ( k . startswith ( ' cond_stage_model. ' ) for k in sd ) :
clip_ sd = model_config . process_clip_state_dict ( sd )
if len ( clip_sd ) > 0 :
clip = CLIP ( clip_target , embedding_directory = embedding_directory )
w . cond_stage_model = clip . cond_stage_model
load_model_weights ( w , sd )
m , u = clip . load_sd ( clip_sd , full_model = True )
if len ( m ) > 0 :
print ( " clip missing: " , m )
if len ( u ) > 0 :
print ( " clip unexpected: " , u )
else :
print ( " no CLIP/text encoder weights in checkpoint, the text encoder model will not be loaded. " )