@ -5,6 +5,7 @@ import collections
from comfy import model_management
import math
import logging
import comfy . sampler_helpers
def get_area_and_mult ( conds , x_in , timestep_in ) :
area = ( x_in . shape [ 2 ] , x_in . shape [ 3 ] , 0 , 0 )
@ -127,30 +128,23 @@ def cond_cat(c_list):
return out
def calc_cond_uncond_batch ( model , cond , uncond , x_in , timestep , model_options ) :
out_cond = torch . zeros_like ( x_in )
out_count = torch . ones_like ( x_in ) * 1e-37
out_uncond = torch . zeros_like ( x_in )
out_uncond_count = torch . ones_like ( x_in ) * 1e-37
def calc_cond_batch ( model , conds , x_in , timestep , model_options ) :
out_conds = [ ]
out_counts = [ ]
to_run = [ ]
COND = 0
UNCOND = 1
for i in range ( len ( conds ) ) :
out_conds . append ( torch . zeros_like ( x_in ) )
out_counts . append ( torch . ones_like ( x_in ) * 1e-37 )
to_run = [ ]
cond = conds [ i ]
if cond is not None :
for x in cond :
p = get_area_and_mult ( x , x_in , timestep )
if p is None :
continue
to_run + = [ ( p , COND ) ]
if uncond is not None :
for x in uncond :
p = get_area_and_mult ( x , x_in , timestep )
if p is None :
continue
to_run + = [ ( p , UNCOND ) ]
to_run + = [ ( p , i ) ]
while len ( to_run ) > 0 :
first = to_run [ 0 ]
@ -222,32 +216,22 @@ def calc_cond_uncond_batch(model, cond, uncond, x_in, timestep, model_options):
output = model_options [ ' model_function_wrapper ' ] ( model . apply_model , { " input " : input_x , " timestep " : timestep_ , " c " : c , " cond_or_uncond " : cond_or_uncond } ) . chunk ( batch_chunks )
else :
output = model . apply_model ( input_x , timestep_ , * * c ) . chunk ( batch_chunks )
del input_x
for o in range ( batch_chunks ) :
if cond_or_uncond [ o ] == COND :
out_cond [ : , : , area [ o ] [ 2 ] : area [ o ] [ 0 ] + area [ o ] [ 2 ] , area [ o ] [ 3 ] : area [ o ] [ 1 ] + area [ o ] [ 3 ] ] + = output [ o ] * mult [ o ]
out_count [ : , : , area [ o ] [ 2 ] : area [ o ] [ 0 ] + area [ o ] [ 2 ] , area [ o ] [ 3 ] : area [ o ] [ 1 ] + area [ o ] [ 3 ] ] + = mult [ o ]
else :
out_uncond [ : , : , area [ o ] [ 2 ] : area [ o ] [ 0 ] + area [ o ] [ 2 ] , area [ o ] [ 3 ] : area [ o ] [ 1 ] + area [ o ] [ 3 ] ] + = output [ o ] * mult [ o ]
out_uncond_count [ : , : , area [ o ] [ 2 ] : area [ o ] [ 0 ] + area [ o ] [ 2 ] , area [ o ] [ 3 ] : area [ o ] [ 1 ] + area [ o ] [ 3 ] ] + = mult [ o ]
del mult
cond_index = cond_or_uncond [ o ]
out_conds [ cond_index ] [ : , : , area [ o ] [ 2 ] : area [ o ] [ 0 ] + area [ o ] [ 2 ] , area [ o ] [ 3 ] : area [ o ] [ 1 ] + area [ o ] [ 3 ] ] + = output [ o ] * mult [ o ]
out_counts [ cond_index ] [ : , : , area [ o ] [ 2 ] : area [ o ] [ 0 ] + area [ o ] [ 2 ] , area [ o ] [ 3 ] : area [ o ] [ 1 ] + area [ o ] [ 3 ] ] + = mult [ o ]
out_cond / = out_count
del out_count
out_uncond / = out_uncond_count
del out_uncond_count
return out_cond , out_uncond
for i in range ( len ( out_conds ) ) :
out_conds [ i ] / = out_counts [ i ]
#The main sampling function shared by all the samplers
#Returns denoised
def sampling_function ( model , x , timestep , uncond , cond , cond_scale , model_options = { } , seed = None ) :
if math . isclose ( cond_scale , 1.0 ) and model_options . get ( " disable_cfg1_optimization " , False ) == False :
uncond_ = None
else :
uncond_ = uncond
return out_conds
def calc_cond_uncond_batch ( model , cond , uncond , x_in , timestep , model_options ) : #TODO: remove
logging . warning ( " WARNING: The comfy.samplers.calc_cond_uncond_batch function is deprecated please use the calc_cond_batch one instead. " )
return tuple ( calc_cond_batch ( model , [ cond , uncond ] , x_in , timestep , model_options ) )
cond_pred , uncond_pred = calc_cond_uncond_batch ( model , cond , uncond_ , x , timestep , model_options )
def cfg_function ( model , cond_pred , uncond_pred , cond_scale , x , timestep , model_options = { } , cond = None , uncond = None ) :
if " sampler_cfg_function " in model_options :
args = { " cond " : x - cond_pred , " uncond " : x - uncond_pred , " cond_scale " : cond_scale , " timestep " : timestep , " input " : x , " sigma " : timestep ,
" cond_denoised " : cond_pred , " uncond_denoised " : uncond_pred , " model " : model , " model_options " : model_options }
@ -262,34 +246,36 @@ def sampling_function(model, x, timestep, uncond, cond, cond_scale, model_option
return cfg_result
class CFGNoisePredictor ( torch . nn . Module ) :
def __init__ ( self , model ) :
super ( ) . __init__ ( )
self . inner_model = model
def apply_model ( self , x , timestep , cond , uncond , cond_scale , model_options = { } , seed = None ) :
out = sampling_function ( self . inner_model , x , timestep , uncond , cond , cond_scale , model_options = model_options , seed = seed )
return out
def forward ( self , * args , * * kwargs ) :
return self . apply_model ( * args , * * kwargs )
#The main sampling function shared by all the samplers
#Returns denoised
def sampling_function ( model , x , timestep , uncond , cond , cond_scale , model_options = { } , seed = None ) :
if math . isclose ( cond_scale , 1.0 ) and model_options . get ( " disable_cfg1_optimization " , False ) == False :
uncond_ = None
else :
uncond_ = uncond
conds = [ cond , uncond_ ]
out = calc_cond_batch ( model , conds , x , timestep , model_options )
return cfg_function ( model , out [ 0 ] , out [ 1 ] , cond_scale , x , timestep , model_options = model_options , cond = cond , uncond = uncond_ )
class KSamplerX0Inpaint ( torch . nn . Module ) :
class KSamplerX0Inpaint :
def __init__ ( self , model , sigmas ) :
super ( ) . __init__ ( )
self . inner_model = model
self . sigmas = sigmas
def forward ( self , x , sigma , uncond , cond , cond_scale , denoise_mask , model_options = { } , seed = None ) :
def __call__ ( self , x , sigma , denoise_mask , model_options = { } , seed = None ) :
if denoise_mask is not None :
if " denoise_mask_function " in model_options :
denoise_mask = model_options [ " denoise_mask_function " ] ( sigma , denoise_mask , extra_options = { " model " : self . inner_model , " sigmas " : self . sigmas } )
latent_mask = 1. - denoise_mask
x = x * denoise_mask + self . inner_model . inner_model . model_sampling . noise_scaling ( sigma . reshape ( [ sigma . shape [ 0 ] ] + [ 1 ] * ( len ( self . noise . shape ) - 1 ) ) , self . noise , self . latent_image ) * latent_mask
out = self . inner_model ( x , sigma , cond = cond , uncond = uncond , cond_scale = cond_scale , model_options = model_options , seed = seed )
out = self . inner_model ( x , sigma , model_options = model_options , seed = seed )
if denoise_mask is not None :
out = out * denoise_mask + self . latent_image * latent_mask
return out
def simple_scheduler ( model , steps ) :
s = model . model _sampling
def simple_scheduler ( model_sampling , steps ) :
s = model_sampling
sigs = [ ]
ss = len ( s . sigmas ) / steps
for x in range ( steps ) :
@ -297,8 +283,8 @@ def simple_scheduler(model, steps):
sigs + = [ 0.0 ]
return torch . FloatTensor ( sigs )
def ddim_scheduler ( model , steps ) :
s = model . model _sampling
def ddim_scheduler ( model_sampling , steps ) :
s = model_sampling
sigs = [ ]
ss = max ( len ( s . sigmas ) / / steps , 1 )
x = 1
@ -309,8 +295,8 @@ def ddim_scheduler(model, steps):
sigs + = [ 0.0 ]
return torch . FloatTensor ( sigs )
def normal_scheduler ( model , steps , sgm = False , floor = False ) :
s = model . model _sampling
def normal_scheduler ( model_sampling , steps , sgm = False , floor = False ) :
s = model_sampling
start = s . timestep ( s . sigma_max )
end = s . timestep ( s . sigma_min )
@ -546,6 +532,7 @@ class KSAMPLER(Sampler):
k_callback = lambda x : callback ( x [ " i " ] , x [ " denoised " ] , x [ " x " ] , total_steps )
samples = self . sampler_function ( model_k , noise , sigmas , extra_args = extra_args , callback = k_callback , disable = disable_pbar , * * self . extra_options )
samples = model_wrap . inner_model . model_sampling . inverse_noise_scaling ( sigmas [ - 1 ] , samples )
return samples
@ -559,72 +546,133 @@ def ksampler(sampler_name, extra_options={}, inpaint_options={}):
return k_diffusion_sampling . sample_dpm_fast ( model , noise , sigma_min , sigmas [ 0 ] , total_steps , extra_args = extra_args , callback = callback , disable = disable )
sampler_function = dpm_fast_function
elif sampler_name == " dpm_adaptive " :
def dpm_adaptive_function ( model , noise , sigmas , extra_args , callback , disable ) :
def dpm_adaptive_function ( model , noise , sigmas , extra_args , callback , disable , * * extra_options ) :
sigma_min = sigmas [ - 1 ]
if sigma_min == 0 :
sigma_min = sigmas [ - 2 ]
return k_diffusion_sampling . sample_dpm_adaptive ( model , noise , sigma_min , sigmas [ 0 ] , extra_args = extra_args , callback = callback , disable = disable )
return k_diffusion_sampling . sample_dpm_adaptive ( model , noise , sigma_min , sigmas [ 0 ] , extra_args = extra_args , callback = callback , disable = disable , * * extra_options )
sampler_function = dpm_adaptive_function
else :
sampler_function = getattr ( k_diffusion_sampling , " sample_ {} " . format ( sampler_name ) )
return KSAMPLER ( sampler_function , extra_options , inpaint_options )
def wrap_model ( model ) :
model_denoise = CFGNoisePredictor ( model )
return model_denoise
def sample ( model , noise , positive , negative , cfg , device , sampler , sigmas , model_options = { } , latent_image = None , denoise_mask = None , callback = None , disable_pbar = False , seed = None ) :
positive = positive [ : ]
negative = negative [ : ]
def process_conds ( model , noise , conds , device , latent_image = None , denoise_mask = None , seed = None ) :
for k in conds :
conds [ k ] = conds [ k ] [ : ]
resolve_areas_and_cond_masks ( conds [ k ] , noise . shape [ 2 ] , noise . shape [ 3 ] , device )
for k in conds :
calculate_start_end_timesteps ( model , conds [ k ] )
if hasattr ( model , ' extra_conds ' ) :
for k in conds :
conds [ k ] = encode_model_conds ( model . extra_conds , conds [ k ] , noise , device , k , latent_image = latent_image , denoise_mask = denoise_mask , seed = seed )
#make sure each cond area has an opposite one with the same area
for k in conds :
for c in conds [ k ] :
for kk in conds :
if k != kk :
create_cond_with_same_area_if_none ( conds [ kk ] , c )
for k in conds :
pre_run_control ( model , conds [ k ] )
if " positive " in conds :
positive = conds [ " positive " ]
for k in conds :
if k != " positive " :
apply_empty_x_to_equal_area ( list ( filter ( lambda c : c . get ( ' control_apply_to_uncond ' , False ) == True , positive ) ) , conds [ k ] , ' control ' , lambda cond_cnets , x : cond_cnets [ x ] )
apply_empty_x_to_equal_area ( positive , conds [ k ] , ' gligen ' , lambda cond_cnets , x : cond_cnets [ x ] )
return conds
class CFGGuider :
def __init__ ( self , model_patcher ) :
self . model_patcher = model_patcher
self . model_options = model_patcher . model_options
self . original_conds = { }
self . cfg = 1.0
def set_conds ( self , positive , negative ) :
self . inner_set_conds ( { " positive " : positive , " negative " : negative } )
def set_cfg ( self , cfg ) :
self . cfg = cfg
resolve_areas_and_cond_masks ( positive , noise . shape [ 2 ] , noise . shape [ 3 ] , device )
resolve_areas_and_cond_masks ( negative , noise . shape [ 2 ] , noise . shape [ 3 ] , device )
def inner_set_conds ( self , conds ) :
for k in conds :
self . original_conds [ k ] = comfy . sampler_helpers . convert_cond ( conds [ k ] )
model_wrap = wrap_model ( model )
def __call__ ( self , * args , * * kwargs ) :
return self . predict_noise ( * args , * * kwargs )
calculate_start_end_timesteps ( model , negative )
calculate_start_end_timesteps ( model , positive )
def predict_noise ( self , x , timestep , model_options = { } , seed = None ) :
return sampling_function ( self . inner_model , x , timestep , self . conds . get ( " negative " , None ) , self . conds . get ( " positive " , None ) , self . cfg , model_options = model_options , seed = seed )
def inner_sample ( self , noise , latent_image , device , sampler , sigmas , denoise_mask , callback , disable_pbar , seed ) :
if latent_image is not None and torch . count_nonzero ( latent_image ) > 0 : #Don't shift the empty latent image.
latent_image = model . process_latent_in ( latent_image )
latent_image = self . inner_ model. process_latent_in ( latent_image )
if hasattr ( model , ' extra_conds ' ) :
positive = encode_model_conds ( model . extra_conds , positive , noise , device , " positive " , latent_image = latent_image , denoise_mask = denoise_mask , seed = seed )
negative = encode_model_conds ( model . extra_conds , negative , noise , device , " negative " , latent_image = latent_image , denoise_mask = denoise_mask , seed = seed )
self . conds = process_conds ( self . inner_model , noise , self . conds , device , latent_image , denoise_mask , seed )
#make sure each cond area has an opposite one with the same area
for c in positive :
create_cond_with_same_area_if_none ( negative , c )
for c in negative :
create_cond_with_same_area_if_none ( positive , c )
extra_args = { " model_options " : self . model_options , " seed " : seed }
pre_run_control ( model , negative + positive )
samples = sampler . sample ( self , sigmas , extra_args , callback , noise , latent_image , denoise_mask , disable_pbar )
return self . inner_model . process_latent_out ( samples . to ( torch . float32 ) )
def sample ( self , noise , latent_image , sampler , sigmas , denoise_mask = None , callback = None , disable_pbar = False , seed = None ) :
if sigmas . shape [ - 1 ] == 0 :
return latent_image
apply_empty_x_to_equal_area ( list ( filter ( lambda c : c . get ( ' control_apply_to_uncond ' , False ) == True , positive ) ) , negative , ' control ' , lambda cond_cnets , x : cond_cnets [ x ] )
apply_empty_x_to_equal_area ( positive , negative , ' gligen ' , lambda cond_cnets , x : cond_cnets [ x ] )
self . conds = { }
for k in self . original_conds :
self . conds [ k ] = list ( map ( lambda a : a . copy ( ) , self . original_conds [ k ] ) )
self . inner_model , self . conds , self . loaded_models = comfy . sampler_helpers . prepare_sampling ( self . model_patcher , noise . shape , self . conds )
device = self . model_patcher . load_device
if denoise_mask is not None :
denoise_mask = comfy . sampler_helpers . prepare_mask ( denoise_mask , noise . shape , device )
extra_args = { " cond " : positive , " uncond " : negative , " cond_scale " : cfg , " model_options " : model_options , " seed " : seed }
noise = noise . to ( device )
latent_image = latent_image . to ( device )
sigmas = sigmas . to ( device )
output = self . inner_sample ( noise , latent_image , device , sampler , sigmas , denoise_mask , callback , disable_pbar , seed )
comfy . sampler_helpers . cleanup_models ( self . conds , self . loaded_models )
del self . inner_model
del self . conds
del self . loaded_models
return output
def sample ( model , noise , positive , negative , cfg , device , sampler , sigmas , model_options = { } , latent_image = None , denoise_mask = None , callback = None , disable_pbar = False , seed = None ) :
cfg_guider = CFGGuider ( model )
cfg_guider . set_conds ( positive , negative )
cfg_guider . set_cfg ( cfg )
return cfg_guider . sample ( noise , latent_image , sampler , sigmas , denoise_mask , callback , disable_pbar , seed )
samples = sampler . sample ( model_wrap , sigmas , extra_args , callback , noise , latent_image , denoise_mask , disable_pbar )
return model . process_latent_out ( samples . to ( torch . float32 ) )
SCHEDULER_NAMES = [ " normal " , " karras " , " exponential " , " sgm_uniform " , " simple " , " ddim_uniform " ]
SAMPLER_NAMES = KSAMPLER_NAMES + [ " ddim " , " uni_pc " , " uni_pc_bh2 " ]
def calculate_sigmas_scheduler ( model , scheduler_name , steps ) :
def calculate_sigmas ( model_sampling , scheduler_name , steps ) :
if scheduler_name == " karras " :
sigmas = k_diffusion_sampling . get_sigmas_karras ( n = steps , sigma_min = float ( model . model_sampling . sigma_min ) , sigma_max = float ( model . model_sampling . sigma_max ) )
sigmas = k_diffusion_sampling . get_sigmas_karras ( n = steps , sigma_min = float ( model_sampling . sigma_min ) , sigma_max = float ( model_sampling . sigma_max ) )
elif scheduler_name == " exponential " :
sigmas = k_diffusion_sampling . get_sigmas_exponential ( n = steps , sigma_min = float ( model . model_sampling . sigma_min ) , sigma_max = float ( model . model_sampling . sigma_max ) )
sigmas = k_diffusion_sampling . get_sigmas_exponential ( n = steps , sigma_min = float ( model_sampling . sigma_min ) , sigma_max = float ( model_sampling . sigma_max ) )
elif scheduler_name == " normal " :
sigmas = normal_scheduler ( model , steps )
sigmas = normal_scheduler ( model_sampling , steps )
elif scheduler_name == " simple " :
sigmas = simple_scheduler ( model , steps )
sigmas = simple_scheduler ( model_sampling , steps )
elif scheduler_name == " ddim_uniform " :
sigmas = ddim_scheduler ( model , steps )
sigmas = ddim_scheduler ( model_sampling , steps )
elif scheduler_name == " sgm_uniform " :
sigmas = normal_scheduler ( model , steps , sgm = True )
sigmas = normal_scheduler ( model_sampling , steps , sgm = True )
else :
logging . error ( " error invalid scheduler {} " . format ( scheduler_name ) )
return sigmas
@ -666,7 +714,7 @@ class KSampler:
steps + = 1
discard_penultimate_sigma = True
sigmas = calculate_sigmas_scheduler ( self . model , self . scheduler , steps )
sigmas = calculate_sigmas ( self . model . get_model_object ( " model_sampling " ) , self . scheduler , steps )
if discard_penultimate_sigma :
sigmas = torch . cat ( [ sigmas [ : - 2 ] , sigmas [ - 1 : ] ] )
@ -676,6 +724,9 @@ class KSampler:
self . steps = steps
if denoise is None or denoise > 0.9999 :
self . sigmas = self . calculate_sigmas ( steps ) . to ( self . device )
else :
if denoise < = 0.0 :
self . sigmas = torch . FloatTensor ( [ ] )
else :
new_steps = int ( steps / denoise )
sigmas = self . calculate_sigmas ( new_steps ) . to ( self . device )