@@ -23,6 +23,7 @@ import comfy.model_patcher
 import comfy.lora
 import comfy.t2i_adapter.adapter
 import comfy.supported_models_base
+import comfy.taesd.taesd
 
 def load_model_weights(model, sd):
     m, u = model.load_state_dict(sd, strict=False)
@@ -154,10 +155,16 @@ class VAE:
         if 'decoder.up_blocks.0.resnets.0.norm1.weight' in sd.keys(): #diffusers format
             sd = diffusers_convert.convert_vae_state_dict(sd)
 
+        self.memory_used_encode = lambda shape: (2078 * shape[2] * shape[3]) * 1.7 #These are for AutoencoderKL and need tweaking
+        self.memory_used_decode = lambda shape: (2562 * shape[2] * shape[3] * 64) * 1.7
+
         if config is None:
-            #default SD1.x/SD2.x VAE parameters
-            ddconfig = {'double_z': True, 'z_channels': 4, 'resolution': 256, 'in_channels': 3, 'out_ch': 3, 'ch': 128, 'ch_mult': [1, 2, 4, 4], 'num_res_blocks': 2, 'attn_resolutions': [], 'dropout': 0.0}
-            self.first_stage_model = AutoencoderKL(ddconfig=ddconfig, embed_dim=4)
+            if "taesd_decoder.1.weight" in sd:
+                self.first_stage_model = comfy.taesd.taesd.TAESD()
+            else:
+                #default SD1.x/SD2.x VAE parameters
+                ddconfig = {'double_z': True, 'z_channels': 4, 'resolution': 256, 'in_channels': 3, 'out_ch': 3, 'ch': 128, 'ch_mult': [1, 2, 4, 4], 'num_res_blocks': 2, 'attn_resolutions': [], 'dropout': 0.0}
+                self.first_stage_model = AutoencoderKL(ddconfig=ddconfig, embed_dim=4)
         else:
             self.first_stage_model = AutoencoderKL(**(config['params']))
         self.first_stage_model = self.first_stage_model.eval()
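
Note on the hunk above: the hard-coded VRAM estimates become per-instance lambdas, and a state-dict key check routes TAESD checkpoints to the lightweight decoder instead of the full AutoencoderKL. A minimal standalone sketch of that routing, with placeholder return values standing in for the real model classes:

    # Sketch only, not ComfyUI code: mirrors the checkpoint routing in the patch.
    def pick_first_stage_model(sd, config=None):
        if config is None:
            if "taesd_decoder.1.weight" in sd:  # key present only in TAESD checkpoints
                return "TAESD"                  # patch builds comfy.taesd.taesd.TAESD()
            return "AutoencoderKL (default SD1.x/SD2.x ddconfig)"
        return "AutoencoderKL (**config['params'])"

    print(pick_first_stage_model({"taesd_decoder.1.weight": None}))  # -> TAESD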
@@ -206,7 +213,7 @@ class VAE:
     def decode(self, samples_in):
         self.first_stage_model = self.first_stage_model.to(self.device)
         try:
-            memory_used = (2562 * samples_in.shape[2] * samples_in.shape[3] * 64) * 1.7
+            memory_used = self.memory_used_decode(samples_in.shape)
             model_management.free_memory(memory_used, self.device)
             free_memory = model_management.get_free_memory(self.device)
             batch_number = int(free_memory / memory_used)
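
The context lines show what the estimate is for: free VRAM divided by the per-image cost gives how many latents can be decoded in one pass. A rough sketch of that arithmetic under the patch's constants (the clamp to 1 and the chunked decode loop live outside this hunk):

    # Standalone sketch; free_memory here is a plain argument, not the
    # model_management API.
    def decode_batch_size(latent_shape, free_memory):
        # per-image estimate, matching memory_used_decode above
        memory_used = (2562 * latent_shape[2] * latent_shape[3] * 64) * 1.7
        return max(1, int(free_memory / memory_used))  # decode at least one image

    # 8 latents of 64x64 with ~4 GiB free -> chunks of 3
    print(decode_batch_size((8, 4, 64, 64), 4 * 1024**3))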
@@ -234,7 +241,7 @@ class VAE:
         self.first_stage_model = self.first_stage_model.to(self.device)
         pixel_samples = pixel_samples.movedim(-1, 1)
         try:
-            memory_used = (2078 * pixel_samples.shape[2] * pixel_samples.shape[3]) * 1.7 #NOTE: this constant along with the one in the decode above are estimated from the mem usage for the VAE and could change.
+            memory_used = self.memory_used_encode(pixel_samples.shape) #NOTE: this constant along with the one in the decode above are estimated from the mem usage for the VAE and could change.
             model_management.free_memory(memory_used, self.device)
             free_memory = model_management.get_free_memory(self.device)
             batch_number = int(free_memory / memory_used)
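
Plugging typical shapes into the two lambdas shows the scale of these AutoencoderKL estimates: roughly 0.93 GB to encode a 512x512 image and 1.14 GB to decode its 64x64 latent. A quick worked check (sketch, not ComfyUI code):

    # The lambdas copied from the patch; shapes are NCHW.
    memory_used_encode = lambda shape: (2078 * shape[2] * shape[3]) * 1.7
    memory_used_decode = lambda shape: (2562 * shape[2] * shape[3] * 64) * 1.7

    pixels = (1, 3, 512, 512)  # pixel input after movedim(-1, 1)
    latent = (1, 4, 64, 64)    # SD1.x/SD2.x latents are 8x smaller per side
    print(f"encode: {memory_used_encode(pixels) / 1e9:.2f} GB")  # ~0.93 GB
    print(f"decode: {memory_used_decode(latent) / 1e9:.2f} GB")  # ~1.14 GB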