Browse Source

Refactor sampling code for more advanced sampler nodes.

pull/3216/head
comfyanonymous 8 months ago
parent
commit
57753c964a
  1. 38
      comfy/sample.py
  2. 72
      comfy/samplers.py

38
comfy/sample.py

@ -52,9 +52,16 @@ def convert_cond(cond):
out.append(temp)
return out
def get_additional_models(positive, negative, dtype):
"""loads additional models in positive and negative conditioning"""
control_nets = set(get_models_from_cond(positive, "control") + get_models_from_cond(negative, "control"))
def get_additional_models(conds, dtype):
"""loads additional models in conditioning"""
cnets = []
gligen = []
for i in range(len(conds)):
cnets += get_models_from_cond(conds[i], "control")
gligen += get_models_from_cond(conds[i], "gligen")
control_nets = set(cnets)
inference_memory = 0
control_models = []
@ -62,7 +69,6 @@ def get_additional_models(positive, negative, dtype):
control_models += m.get_models()
inference_memory += m.inference_memory_requirements(dtype)
gligen = get_models_from_cond(positive, "gligen") + get_models_from_cond(negative, "gligen")
gligen = [x[1] for x in gligen]
models = control_models + gligen
return models, inference_memory
@ -73,24 +79,25 @@ def cleanup_additional_models(models):
if hasattr(m, 'cleanup'):
m.cleanup()
def prepare_sampling(model, noise_shape, positive, negative, noise_mask):
def prepare_sampling(model, noise_shape, conds, noise_mask):
device = model.load_device
positive = convert_cond(positive)
negative = convert_cond(negative)
for i in range(len(conds)):
conds[i] = convert_cond(conds[i])
if noise_mask is not None:
noise_mask = prepare_mask(noise_mask, noise_shape, device)
real_model = None
models, inference_memory = get_additional_models(positive, negative, model.model_dtype())
models, inference_memory = get_additional_models(conds, model.model_dtype())
comfy.model_management.load_models_gpu([model] + models, model.memory_required([noise_shape[0] * 2] + list(noise_shape[1:])) + inference_memory)
real_model = model.model
return real_model, positive, negative, noise_mask, models
return real_model, conds, noise_mask, models
def sample(model, noise, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, denoise=1.0, disable_noise=False, start_step=None, last_step=None, force_full_denoise=False, noise_mask=None, sigmas=None, callback=None, disable_pbar=False, seed=None):
real_model, positive_copy, negative_copy, noise_mask, models = prepare_sampling(model, noise.shape, positive, negative, noise_mask)
real_model, conds_copy, noise_mask, models = prepare_sampling(model, noise.shape, [positive, negative], noise_mask)
positive_copy, negative_copy = conds_copy
noise = noise.to(model.load_device)
latent_image = latent_image.to(model.load_device)
@ -105,14 +112,19 @@ def sample(model, noise, steps, cfg, sampler_name, scheduler, positive, negative
return samples
def sample_custom(model, noise, cfg, sampler, sigmas, positive, negative, latent_image, noise_mask=None, callback=None, disable_pbar=False, seed=None):
real_model, positive_copy, negative_copy, noise_mask, models = prepare_sampling(model, noise.shape, positive, negative, noise_mask)
real_model, conds, noise_mask, models = prepare_sampling(model, noise.shape, [positive, negative], noise_mask)
noise = noise.to(model.load_device)
latent_image = latent_image.to(model.load_device)
sigmas = sigmas.to(model.load_device)
samples = comfy.samplers.sample(real_model, noise, positive_copy, negative_copy, cfg, model.load_device, sampler, sigmas, model_options=model.model_options, latent_image=latent_image, denoise_mask=noise_mask, callback=callback, disable_pbar=disable_pbar, seed=seed)
samples = comfy.samplers.sample(real_model, noise, conds[0], conds[1], cfg, model.load_device, sampler, sigmas, model_options=model.model_options, latent_image=latent_image, denoise_mask=noise_mask, callback=callback, disable_pbar=disable_pbar, seed=seed)
samples = samples.to(comfy.model_management.intermediate_device())
cleanup_additional_models(models)
cleanup_additional_models(set(get_models_from_cond(positive_copy, "control") + get_models_from_cond(negative_copy, "control")))
control_cleanup = []
for i in range(len(conds)):
control_cleanup += get_models_from_cond(conds[i], "control")
cleanup_additional_models(set(control_cleanup))
return samples

72
comfy/samplers.py

@ -260,11 +260,12 @@ def sampling_function(model, x, timestep, uncond, cond, cond_scale, model_option
return cfg_result
class CFGNoisePredictor(torch.nn.Module):
def __init__(self, model):
def __init__(self, model, cond_scale=1.0):
super().__init__()
self.inner_model = model
def apply_model(self, x, timestep, cond, uncond, cond_scale, model_options={}, seed=None):
out = sampling_function(self.inner_model, x, timestep, uncond, cond, cond_scale, model_options=model_options, seed=seed)
self.cond_scale = cond_scale
def apply_model(self, x, timestep, conds, model_options={}, seed=None):
out = sampling_function(self.inner_model, x, timestep, conds.get("negative", None), conds.get("positive", None), self.cond_scale, model_options=model_options, seed=seed)
return out
def forward(self, *args, **kwargs):
return self.apply_model(*args, **kwargs)
@ -274,13 +275,13 @@ class KSamplerX0Inpaint(torch.nn.Module):
super().__init__()
self.inner_model = model
self.sigmas = sigmas
def forward(self, x, sigma, uncond, cond, cond_scale, denoise_mask, model_options={}, seed=None):
def forward(self, x, sigma, conds, denoise_mask, model_options={}, seed=None):
if denoise_mask is not None:
if "denoise_mask_function" in model_options:
denoise_mask = model_options["denoise_mask_function"](sigma, denoise_mask, extra_options={"model": self.inner_model, "sigmas": self.sigmas})
latent_mask = 1. - denoise_mask
x = x * denoise_mask + self.inner_model.inner_model.model_sampling.noise_scaling(sigma.reshape([sigma.shape[0]] + [1] * (len(self.noise.shape) - 1)), self.noise, self.latent_image) * latent_mask
out = self.inner_model(x, sigma, cond=cond, uncond=uncond, cond_scale=cond_scale, model_options=model_options, seed=seed)
out = self.inner_model(x, sigma, conds=conds, model_options=model_options, seed=seed)
if denoise_mask is not None:
out = out * denoise_mask + self.latent_image * latent_mask
return out
@ -568,44 +569,55 @@ def ksampler(sampler_name, extra_options={}, inpaint_options={}):
return KSAMPLER(sampler_function, extra_options, inpaint_options)
def wrap_model(model):
model_denoise = CFGNoisePredictor(model)
return model_denoise
def sample(model, noise, positive, negative, cfg, device, sampler, sigmas, model_options={}, latent_image=None, denoise_mask=None, callback=None, disable_pbar=False, seed=None):
positive = positive[:]
negative = negative[:]
def process_conds(model, noise, conds, device, latent_image=None, denoise_mask=None, seed=None):
for k in conds:
conds[k] = conds[k][:]
resolve_areas_and_cond_masks(conds[k], noise.shape[2], noise.shape[3], device)
resolve_areas_and_cond_masks(positive, noise.shape[2], noise.shape[3], device)
resolve_areas_and_cond_masks(negative, noise.shape[2], noise.shape[3], device)
for k in conds:
calculate_start_end_timesteps(model, conds[k])
model_wrap = wrap_model(model)
if hasattr(model, 'extra_conds'):
for k in conds:
conds[k] = encode_model_conds(model.extra_conds, conds[k], noise, device, k, latent_image=latent_image, denoise_mask=denoise_mask, seed=seed)
calculate_start_end_timesteps(model, negative)
calculate_start_end_timesteps(model, positive)
#make sure each cond area has an opposite one with the same area
for k in conds:
for c in conds[k]:
for kk in conds:
if k != kk:
create_cond_with_same_area_if_none(conds[kk], c)
for k in conds:
pre_run_control(model, conds[k])
if "positive" in conds:
positive = conds["positive"]
for k in conds:
if k != "positive":
apply_empty_x_to_equal_area(list(filter(lambda c: c.get('control_apply_to_uncond', False) == True, positive)), conds[k], 'control', lambda cond_cnets, x: cond_cnets[x])
apply_empty_x_to_equal_area(positive, conds[k], 'gligen', lambda cond_cnets, x: cond_cnets[x])
return conds
def sample_advanced(model, noise, conds, guider_class, device, sampler, sigmas, model_options={}, latent_image=None, denoise_mask=None, callback=None, disable_pbar=False, seed=None):
if latent_image is not None and torch.count_nonzero(latent_image) > 0: #Don't shift the empty latent image.
latent_image = model.process_latent_in(latent_image)
if hasattr(model, 'extra_conds'):
positive = encode_model_conds(model.extra_conds, positive, noise, device, "positive", latent_image=latent_image, denoise_mask=denoise_mask, seed=seed)
negative = encode_model_conds(model.extra_conds, negative, noise, device, "negative", latent_image=latent_image, denoise_mask=denoise_mask, seed=seed)
conds = process_conds(model, noise, conds, device, latent_image, denoise_mask, seed)
model_wrap = guider_class(model)
#make sure each cond area has an opposite one with the same area
for c in positive:
create_cond_with_same_area_if_none(negative, c)
for c in negative:
create_cond_with_same_area_if_none(positive, c)
extra_args = {"conds": conds, "model_options": model_options, "seed":seed}
pre_run_control(model, negative + positive)
samples = sampler.sample(model_wrap, sigmas, extra_args, callback, noise, latent_image, denoise_mask, disable_pbar)
return model.process_latent_out(samples.to(torch.float32))
apply_empty_x_to_equal_area(list(filter(lambda c: c.get('control_apply_to_uncond', False) == True, positive)), negative, 'control', lambda cond_cnets, x: cond_cnets[x])
apply_empty_x_to_equal_area(positive, negative, 'gligen', lambda cond_cnets, x: cond_cnets[x])
extra_args = {"cond":positive, "uncond":negative, "cond_scale": cfg, "model_options": model_options, "seed":seed}
def sample(model, noise, positive, negative, cfg, device, sampler, sigmas, model_options={}, latent_image=None, denoise_mask=None, callback=None, disable_pbar=False, seed=None):
return sample_advanced(model, noise, {"positive": positive, "negative": negative}, lambda a: CFGNoisePredictor(a, cfg), device, sampler, sigmas, model_options, latent_image, denoise_mask, callback, disable_pbar, seed)
samples = sampler.sample(model_wrap, sigmas, extra_args, callback, noise, latent_image, denoise_mask, disable_pbar)
return model.process_latent_out(samples.to(torch.float32))
SCHEDULER_NAMES = ["normal", "karras", "exponential", "sgm_uniform", "simple", "ddim_uniform"]
SAMPLER_NAMES = KSAMPLER_NAMES + ["ddim", "uni_pc", "uni_pc_bh2"]

Loading…
Cancel
Save