Browse Source

Refactor sampler code for more advanced sampler nodes part 2.

pull/3216/head
comfyanonymous 8 months ago
parent
commit
0542088ef8
  1. 102
      comfy/sample.py
  2. 76
      comfy/sampler_helpers.py
  3. 108
      comfy/samplers.py

102
comfy/sample.py

@ -1,10 +1,9 @@
import torch
import comfy.model_management
import comfy.samplers
import comfy.conds
import comfy.utils
import math
import numpy as np
import logging
def prepare_noise(latent_image, seed, noise_inds=None):
"""
@ -25,106 +24,21 @@ def prepare_noise(latent_image, seed, noise_inds=None):
noises = torch.cat(noises, axis=0)
return noises
def prepare_mask(noise_mask, shape, device):
"""ensures noise mask is of proper dimensions"""
noise_mask = torch.nn.functional.interpolate(noise_mask.reshape((-1, 1, noise_mask.shape[-2], noise_mask.shape[-1])), size=(shape[2], shape[3]), mode="bilinear")
noise_mask = torch.cat([noise_mask] * shape[1], dim=1)
noise_mask = comfy.utils.repeat_to_batch_size(noise_mask, shape[0])
noise_mask = noise_mask.to(device)
return noise_mask
def get_models_from_cond(cond, model_type):
models = []
for c in cond:
if model_type in c:
models += [c[model_type]]
return models
def convert_cond(cond):
out = []
for c in cond:
temp = c[1].copy()
model_conds = temp.get("model_conds", {})
if c[0] is not None:
model_conds["c_crossattn"] = comfy.conds.CONDCrossAttn(c[0]) #TODO: remove
temp["cross_attn"] = c[0]
temp["model_conds"] = model_conds
out.append(temp)
return out
def get_additional_models(conds, dtype):
"""loads additional models in conditioning"""
cnets = []
gligen = []
for i in range(len(conds)):
cnets += get_models_from_cond(conds[i], "control")
gligen += get_models_from_cond(conds[i], "gligen")
control_nets = set(cnets)
inference_memory = 0
control_models = []
for m in control_nets:
control_models += m.get_models()
inference_memory += m.inference_memory_requirements(dtype)
gligen = [x[1] for x in gligen]
models = control_models + gligen
return models, inference_memory
def prepare_sampling(model, noise_shape, positive, negative, noise_mask):
logging.warning("Warning: comfy.sample.prepare_sampling isn't used anymore and can be removed")
return model, positive, negative, noise_mask, []
def cleanup_additional_models(models):
"""cleanup additional models that were loaded"""
for m in models:
if hasattr(m, 'cleanup'):
m.cleanup()
def prepare_sampling(model, noise_shape, conds, noise_mask):
device = model.load_device
for i in range(len(conds)):
conds[i] = convert_cond(conds[i])
if noise_mask is not None:
noise_mask = prepare_mask(noise_mask, noise_shape, device)
real_model = None
models, inference_memory = get_additional_models(conds, model.model_dtype())
comfy.model_management.load_models_gpu([model] + models, model.memory_required([noise_shape[0] * 2] + list(noise_shape[1:])) + inference_memory)
real_model = model.model
return real_model, conds, noise_mask, models
logging.warning("Warning: comfy.sample.cleanup_additional_models isn't used anymore and can be removed")
def sample(model, noise, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, denoise=1.0, disable_noise=False, start_step=None, last_step=None, force_full_denoise=False, noise_mask=None, sigmas=None, callback=None, disable_pbar=False, seed=None):
real_model, conds_copy, noise_mask, models = prepare_sampling(model, noise.shape, [positive, negative], noise_mask)
positive_copy, negative_copy = conds_copy
noise = noise.to(model.load_device)
latent_image = latent_image.to(model.load_device)
sampler = comfy.samplers.KSampler(model, steps=steps, device=model.load_device, sampler=sampler_name, scheduler=scheduler, denoise=denoise, model_options=model.model_options)
sampler = comfy.samplers.KSampler(real_model, steps=steps, device=model.load_device, sampler=sampler_name, scheduler=scheduler, denoise=denoise, model_options=model.model_options)
samples = sampler.sample(noise, positive_copy, negative_copy, cfg=cfg, latent_image=latent_image, start_step=start_step, last_step=last_step, force_full_denoise=force_full_denoise, denoise_mask=noise_mask, sigmas=sigmas, callback=callback, disable_pbar=disable_pbar, seed=seed)
samples = sampler.sample(noise, positive, negative, cfg=cfg, latent_image=latent_image, start_step=start_step, last_step=last_step, force_full_denoise=force_full_denoise, denoise_mask=noise_mask, sigmas=sigmas, callback=callback, disable_pbar=disable_pbar, seed=seed)
samples = samples.to(comfy.model_management.intermediate_device())
cleanup_additional_models(models)
cleanup_additional_models(set(get_models_from_cond(positive_copy, "control") + get_models_from_cond(negative_copy, "control")))
return samples
def sample_custom(model, noise, cfg, sampler, sigmas, positive, negative, latent_image, noise_mask=None, callback=None, disable_pbar=False, seed=None):
real_model, conds, noise_mask, models = prepare_sampling(model, noise.shape, [positive, negative], noise_mask)
noise = noise.to(model.load_device)
latent_image = latent_image.to(model.load_device)
sigmas = sigmas.to(model.load_device)
samples = comfy.samplers.sample(real_model, noise, conds[0], conds[1], cfg, model.load_device, sampler, sigmas, model_options=model.model_options, latent_image=latent_image, denoise_mask=noise_mask, callback=callback, disable_pbar=disable_pbar, seed=seed)
samples = comfy.samplers.sample(model, noise, positive, negative, cfg, model.load_device, sampler, sigmas, model_options=model.model_options, latent_image=latent_image, denoise_mask=noise_mask, callback=callback, disable_pbar=disable_pbar, seed=seed)
samples = samples.to(comfy.model_management.intermediate_device())
cleanup_additional_models(models)
control_cleanup = []
for i in range(len(conds)):
control_cleanup += get_models_from_cond(conds[i], "control")
cleanup_additional_models(set(control_cleanup))
return samples

76
comfy/sampler_helpers.py

@ -0,0 +1,76 @@
import torch
import comfy.model_management
import comfy.conds
def prepare_mask(noise_mask, shape, device):
"""ensures noise mask is of proper dimensions"""
noise_mask = torch.nn.functional.interpolate(noise_mask.reshape((-1, 1, noise_mask.shape[-2], noise_mask.shape[-1])), size=(shape[2], shape[3]), mode="bilinear")
noise_mask = torch.cat([noise_mask] * shape[1], dim=1)
noise_mask = comfy.utils.repeat_to_batch_size(noise_mask, shape[0])
noise_mask = noise_mask.to(device)
return noise_mask
def get_models_from_cond(cond, model_type):
models = []
for c in cond:
if model_type in c:
models += [c[model_type]]
return models
def convert_cond(cond):
out = []
for c in cond:
temp = c[1].copy()
model_conds = temp.get("model_conds", {})
if c[0] is not None:
model_conds["c_crossattn"] = comfy.conds.CONDCrossAttn(c[0]) #TODO: remove
temp["cross_attn"] = c[0]
temp["model_conds"] = model_conds
out.append(temp)
return out
def get_additional_models(conds, dtype):
"""loads additional models in conditioning"""
cnets = []
gligen = []
for k in conds:
cnets += get_models_from_cond(conds[k], "control")
gligen += get_models_from_cond(conds[k], "gligen")
control_nets = set(cnets)
inference_memory = 0
control_models = []
for m in control_nets:
control_models += m.get_models()
inference_memory += m.inference_memory_requirements(dtype)
gligen = [x[1] for x in gligen]
models = control_models + gligen
return models, inference_memory
def cleanup_additional_models(models):
"""cleanup additional models that were loaded"""
for m in models:
if hasattr(m, 'cleanup'):
m.cleanup()
def prepare_sampling(model, noise_shape, conds):
device = model.load_device
real_model = None
models, inference_memory = get_additional_models(conds, model.model_dtype())
comfy.model_management.load_models_gpu([model] + models, model.memory_required([noise_shape[0] * 2] + list(noise_shape[1:])) + inference_memory)
real_model = model.model
return real_model, conds, models
def cleanup_models(conds, models):
cleanup_additional_models(models)
control_cleanup = []
for k in conds:
control_cleanup += get_models_from_cond(conds[k], "control")
cleanup_additional_models(set(control_cleanup))

108
comfy/samplers.py

@ -5,6 +5,7 @@ import collections
from comfy import model_management
import math
import logging
import comfy.sampler_helpers
def get_area_and_mult(conds, x_in, timestep_in):
area = (x_in.shape[2], x_in.shape[3], 0, 0)
@ -230,21 +231,7 @@ def calc_cond_uncond_batch(model, cond, uncond, x_in, timestep, model_options):
logging.warning("WARNING: The comfy.samplers.calc_cond_uncond_batch function is deprecated please use the calc_cond_batch one instead.")
return tuple(calc_cond_batch(model, [cond, uncond], x_in, timestep, model_options))
#The main sampling function shared by all the samplers
#Returns denoised
def sampling_function(model, x, timestep, uncond, cond, cond_scale, model_options={}, seed=None):
if math.isclose(cond_scale, 1.0) and model_options.get("disable_cfg1_optimization", False) == False:
uncond_ = None
else:
uncond_ = uncond
conds = [cond, uncond_]
out = calc_cond_batch(model, conds, x, timestep, model_options)
cond_pred = out[0]
uncond_pred = out[1]
def cfg_function(model, cond_pred, uncond_pred, cond_scale, x, timestep, model_options={}):
if "sampler_cfg_function" in model_options:
args = {"cond": x - cond_pred, "uncond": x - uncond_pred, "cond_scale": cond_scale, "timestep": timestep, "input": x, "sigma": timestep,
"cond_denoised": cond_pred, "uncond_denoised": uncond_pred, "model": model, "model_options": model_options}
@ -259,29 +246,30 @@ def sampling_function(model, x, timestep, uncond, cond, cond_scale, model_option
return cfg_result
class CFGNoisePredictor(torch.nn.Module):
def __init__(self, model, cond_scale=1.0):
super().__init__()
self.inner_model = model
self.cond_scale = cond_scale
def apply_model(self, x, timestep, conds, model_options={}, seed=None):
out = sampling_function(self.inner_model, x, timestep, conds.get("negative", None), conds.get("positive", None), self.cond_scale, model_options=model_options, seed=seed)
return out
def forward(self, *args, **kwargs):
return self.apply_model(*args, **kwargs)
#The main sampling function shared by all the samplers
#Returns denoised
def sampling_function(model, x, timestep, uncond, cond, cond_scale, model_options={}, seed=None):
if math.isclose(cond_scale, 1.0) and model_options.get("disable_cfg1_optimization", False) == False:
uncond_ = None
else:
uncond_ = uncond
conds = [cond, uncond_]
out = calc_cond_batch(model, conds, x, timestep, model_options)
return cfg_function(model, out[0], out[1], cond_scale, x, timestep, model_options=model_options)
class KSamplerX0Inpaint(torch.nn.Module):
class KSamplerX0Inpaint:
def __init__(self, model, sigmas):
super().__init__()
self.inner_model = model
self.sigmas = sigmas
def forward(self, x, sigma, conds, denoise_mask, model_options={}, seed=None):
def __call__(self, x, sigma, denoise_mask, model_options={}, seed=None):
if denoise_mask is not None:
if "denoise_mask_function" in model_options:
denoise_mask = model_options["denoise_mask_function"](sigma, denoise_mask, extra_options={"model": self.inner_model, "sigmas": self.sigmas})
latent_mask = 1. - denoise_mask
x = x * denoise_mask + self.inner_model.inner_model.model_sampling.noise_scaling(sigma.reshape([sigma.shape[0]] + [1] * (len(self.noise.shape) - 1)), self.noise, self.latent_image) * latent_mask
out = self.inner_model(x, sigma, conds=conds, model_options=model_options, seed=seed)
out = self.inner_model(x, sigma, model_options=model_options, seed=seed)
if denoise_mask is not None:
out = out * denoise_mask + self.latent_image * latent_mask
return out
@ -601,22 +589,66 @@ def process_conds(model, noise, conds, device, latent_image=None, denoise_mask=N
return conds
class CFGGuider:
def __init__(self, model_patcher):
self.model_patcher = model_patcher
self.model_options = model_patcher.model_options
self.original_conds = {}
self.cfg = 1.0
def set_conds(self, conds):
for k in conds:
self.original_conds[k] = comfy.sampler_helpers.convert_cond(conds[k])
def set_cfg(self, cfg):
self.cfg = cfg
def sample_advanced(model, noise, conds, guider_class, device, sampler, sigmas, model_options={}, latent_image=None, denoise_mask=None, callback=None, disable_pbar=False, seed=None):
def __call__(self, *args, **kwargs):
return self.predict_noise(*args, **kwargs)
def predict_noise(self, x, timestep, model_options={}, seed=None):
return sampling_function(self.inner_model, x, timestep, self.conds.get("negative", None), self.conds.get("positive", None), self.cfg, model_options=model_options, seed=seed)
def inner_sample(self, noise, latent_image, device, sampler, sigmas, denoise_mask, callback, disable_pbar, seed):
if latent_image is not None and torch.count_nonzero(latent_image) > 0: #Don't shift the empty latent image.
latent_image = model.process_latent_in(latent_image)
latent_image = self.inner_model.process_latent_in(latent_image)
self.conds = process_conds(self.inner_model, noise, self.conds, device, latent_image, denoise_mask, seed)
extra_args = {"model_options": self.model_options, "seed":seed}
samples = sampler.sample(self, sigmas, extra_args, callback, noise, latent_image, denoise_mask, disable_pbar)
return self.inner_model.process_latent_out(samples.to(torch.float32))
def sample(self, noise, latent_image, sampler, sigmas, denoise_mask=None, callback=None, disable_pbar=False, seed=None):
self.conds = {}
for k in self.original_conds:
self.conds[k] = list(map(lambda a: a.copy(), self.original_conds[k]))
self.inner_model, self.conds, self.loaded_models = comfy.sampler_helpers.prepare_sampling(self.model_patcher, noise.shape, self.conds)
device = self.model_patcher.load_device
if denoise_mask is not None:
denoise_mask = comfy.sampler_helpers.prepare_mask(denoise_mask, noise.shape, device)
conds = process_conds(model, noise, conds, device, latent_image, denoise_mask, seed)
model_wrap = guider_class(model)
noise = noise.to(device)
latent_image = latent_image.to(device)
sigmas = sigmas.to(device)
extra_args = {"conds": conds, "model_options": model_options, "seed":seed}
output = self.inner_sample(noise, latent_image, device, sampler, sigmas, denoise_mask, callback, disable_pbar, seed)
samples = sampler.sample(model_wrap, sigmas, extra_args, callback, noise, latent_image, denoise_mask, disable_pbar)
return model.process_latent_out(samples.to(torch.float32))
comfy.sampler_helpers.cleanup_models(self.conds, self.loaded_models)
del self.inner_model
del self.conds
del self.loaded_models
return output
def sample(model, noise, positive, negative, cfg, device, sampler, sigmas, model_options={}, latent_image=None, denoise_mask=None, callback=None, disable_pbar=False, seed=None):
return sample_advanced(model, noise, {"positive": positive, "negative": negative}, lambda a: CFGNoisePredictor(a, cfg), device, sampler, sigmas, model_options, latent_image, denoise_mask, callback, disable_pbar, seed)
cfg_guider = CFGGuider(model)
cfg_guider.set_conds({"positive": positive, "negative": negative})
cfg_guider.set_cfg(cfg)
return cfg_guider.sample(noise, latent_image, sampler, sigmas, denoise_mask, callback, disable_pbar, seed)
SCHEDULER_NAMES = ["normal", "karras", "exponential", "sgm_uniform", "simple", "ddim_uniform"]
@ -676,7 +708,7 @@ class KSampler:
steps += 1
discard_penultimate_sigma = True
sigmas = calculate_sigmas_scheduler(self.model, self.scheduler, steps)
sigmas = calculate_sigmas_scheduler(self.model.model, self.scheduler, steps)
if discard_penultimate_sigma:
sigmas = torch.cat([sigmas[:-2], sigmas[-1:]])

Loading…
Cancel
Save