|
|
|
@ -5,6 +5,7 @@ import collections
|
|
|
|
|
from comfy import model_management |
|
|
|
|
import math |
|
|
|
|
import logging |
|
|
|
|
import comfy.sampler_helpers |
|
|
|
|
|
|
|
|
|
def get_area_and_mult(conds, x_in, timestep_in): |
|
|
|
|
area = (x_in.shape[2], x_in.shape[3], 0, 0) |
|
|
|
@ -230,58 +231,45 @@ def calc_cond_uncond_batch(model, cond, uncond, x_in, timestep, model_options):
|
|
|
|
|
logging.warning("WARNING: The comfy.samplers.calc_cond_uncond_batch function is deprecated please use the calc_cond_batch one instead.") |
|
|
|
|
return tuple(calc_cond_batch(model, [cond, uncond], x_in, timestep, model_options)) |
|
|
|
|
|
|
|
|
|
#The main sampling function shared by all the samplers |
|
|
|
|
#Returns denoised |
|
|
|
|
def sampling_function(model, x, timestep, uncond, cond, cond_scale, model_options={}, seed=None): |
|
|
|
|
if math.isclose(cond_scale, 1.0) and model_options.get("disable_cfg1_optimization", False) == False: |
|
|
|
|
uncond_ = None |
|
|
|
|
else: |
|
|
|
|
uncond_ = uncond |
|
|
|
|
|
|
|
|
|
def cfg_function(model, cond_pred, uncond_pred, cond_scale, x, timestep, model_options={}): |
|
|
|
|
if "sampler_cfg_function" in model_options: |
|
|
|
|
args = {"cond": x - cond_pred, "uncond": x - uncond_pred, "cond_scale": cond_scale, "timestep": timestep, "input": x, "sigma": timestep, |
|
|
|
|
"cond_denoised": cond_pred, "uncond_denoised": uncond_pred, "model": model, "model_options": model_options} |
|
|
|
|
cfg_result = x - model_options["sampler_cfg_function"](args) |
|
|
|
|
else: |
|
|
|
|
cfg_result = uncond_pred + (cond_pred - uncond_pred) * cond_scale |
|
|
|
|
|
|
|
|
|
conds = [cond, uncond_] |
|
|
|
|
for fn in model_options.get("sampler_post_cfg_function", []): |
|
|
|
|
args = {"denoised": cfg_result, "cond": cond, "uncond": uncond, "model": model, "uncond_denoised": uncond_pred, "cond_denoised": cond_pred, |
|
|
|
|
"sigma": timestep, "model_options": model_options, "input": x} |
|
|
|
|
cfg_result = fn(args) |
|
|
|
|
|
|
|
|
|
out = calc_cond_batch(model, conds, x, timestep, model_options) |
|
|
|
|
cond_pred = out[0] |
|
|
|
|
uncond_pred = out[1] |
|
|
|
|
return cfg_result |
|
|
|
|
|
|
|
|
|
if "sampler_cfg_function" in model_options: |
|
|
|
|
args = {"cond": x - cond_pred, "uncond": x - uncond_pred, "cond_scale": cond_scale, "timestep": timestep, "input": x, "sigma": timestep, |
|
|
|
|
"cond_denoised": cond_pred, "uncond_denoised": uncond_pred, "model": model, "model_options": model_options} |
|
|
|
|
cfg_result = x - model_options["sampler_cfg_function"](args) |
|
|
|
|
else: |
|
|
|
|
cfg_result = uncond_pred + (cond_pred - uncond_pred) * cond_scale |
|
|
|
|
|
|
|
|
|
for fn in model_options.get("sampler_post_cfg_function", []): |
|
|
|
|
args = {"denoised": cfg_result, "cond": cond, "uncond": uncond, "model": model, "uncond_denoised": uncond_pred, "cond_denoised": cond_pred, |
|
|
|
|
"sigma": timestep, "model_options": model_options, "input": x} |
|
|
|
|
cfg_result = fn(args) |
|
|
|
|
#The main sampling function shared by all the samplers |
|
|
|
|
#Returns denoised |
|
|
|
|
def sampling_function(model, x, timestep, uncond, cond, cond_scale, model_options={}, seed=None): |
|
|
|
|
if math.isclose(cond_scale, 1.0) and model_options.get("disable_cfg1_optimization", False) == False: |
|
|
|
|
uncond_ = None |
|
|
|
|
else: |
|
|
|
|
uncond_ = uncond |
|
|
|
|
|
|
|
|
|
return cfg_result |
|
|
|
|
conds = [cond, uncond_] |
|
|
|
|
out = calc_cond_batch(model, conds, x, timestep, model_options) |
|
|
|
|
return cfg_function(model, out[0], out[1], cond_scale, x, timestep, model_options=model_options) |
|
|
|
|
|
|
|
|
|
class CFGNoisePredictor(torch.nn.Module): |
|
|
|
|
def __init__(self, model, cond_scale=1.0): |
|
|
|
|
super().__init__() |
|
|
|
|
self.inner_model = model |
|
|
|
|
self.cond_scale = cond_scale |
|
|
|
|
def apply_model(self, x, timestep, conds, model_options={}, seed=None): |
|
|
|
|
out = sampling_function(self.inner_model, x, timestep, conds.get("negative", None), conds.get("positive", None), self.cond_scale, model_options=model_options, seed=seed) |
|
|
|
|
return out |
|
|
|
|
def forward(self, *args, **kwargs): |
|
|
|
|
return self.apply_model(*args, **kwargs) |
|
|
|
|
|
|
|
|
|
class KSamplerX0Inpaint(torch.nn.Module): |
|
|
|
|
class KSamplerX0Inpaint: |
|
|
|
|
def __init__(self, model, sigmas): |
|
|
|
|
super().__init__() |
|
|
|
|
self.inner_model = model |
|
|
|
|
self.sigmas = sigmas |
|
|
|
|
def forward(self, x, sigma, conds, denoise_mask, model_options={}, seed=None): |
|
|
|
|
def __call__(self, x, sigma, denoise_mask, model_options={}, seed=None): |
|
|
|
|
if denoise_mask is not None: |
|
|
|
|
if "denoise_mask_function" in model_options: |
|
|
|
|
denoise_mask = model_options["denoise_mask_function"](sigma, denoise_mask, extra_options={"model": self.inner_model, "sigmas": self.sigmas}) |
|
|
|
|
latent_mask = 1. - denoise_mask |
|
|
|
|
x = x * denoise_mask + self.inner_model.inner_model.model_sampling.noise_scaling(sigma.reshape([sigma.shape[0]] + [1] * (len(self.noise.shape) - 1)), self.noise, self.latent_image) * latent_mask |
|
|
|
|
out = self.inner_model(x, sigma, conds=conds, model_options=model_options, seed=seed) |
|
|
|
|
out = self.inner_model(x, sigma, model_options=model_options, seed=seed) |
|
|
|
|
if denoise_mask is not None: |
|
|
|
|
out = out * denoise_mask + self.latent_image * latent_mask |
|
|
|
|
return out |
|
|
|
@ -601,22 +589,66 @@ def process_conds(model, noise, conds, device, latent_image=None, denoise_mask=N
|
|
|
|
|
|
|
|
|
|
return conds |
|
|
|
|
|
|
|
|
|
class CFGGuider: |
|
|
|
|
def __init__(self, model_patcher): |
|
|
|
|
self.model_patcher = model_patcher |
|
|
|
|
self.model_options = model_patcher.model_options |
|
|
|
|
self.original_conds = {} |
|
|
|
|
self.cfg = 1.0 |
|
|
|
|
|
|
|
|
|
def set_conds(self, conds): |
|
|
|
|
for k in conds: |
|
|
|
|
self.original_conds[k] = comfy.sampler_helpers.convert_cond(conds[k]) |
|
|
|
|
|
|
|
|
|
def set_cfg(self, cfg): |
|
|
|
|
self.cfg = cfg |
|
|
|
|
|
|
|
|
|
def __call__(self, *args, **kwargs): |
|
|
|
|
return self.predict_noise(*args, **kwargs) |
|
|
|
|
|
|
|
|
|
def predict_noise(self, x, timestep, model_options={}, seed=None): |
|
|
|
|
return sampling_function(self.inner_model, x, timestep, self.conds.get("negative", None), self.conds.get("positive", None), self.cfg, model_options=model_options, seed=seed) |
|
|
|
|
|
|
|
|
|
def inner_sample(self, noise, latent_image, device, sampler, sigmas, denoise_mask, callback, disable_pbar, seed): |
|
|
|
|
if latent_image is not None and torch.count_nonzero(latent_image) > 0: #Don't shift the empty latent image. |
|
|
|
|
latent_image = self.inner_model.process_latent_in(latent_image) |
|
|
|
|
|
|
|
|
|
def sample_advanced(model, noise, conds, guider_class, device, sampler, sigmas, model_options={}, latent_image=None, denoise_mask=None, callback=None, disable_pbar=False, seed=None): |
|
|
|
|
if latent_image is not None and torch.count_nonzero(latent_image) > 0: #Don't shift the empty latent image. |
|
|
|
|
latent_image = model.process_latent_in(latent_image) |
|
|
|
|
self.conds = process_conds(self.inner_model, noise, self.conds, device, latent_image, denoise_mask, seed) |
|
|
|
|
|
|
|
|
|
extra_args = {"model_options": self.model_options, "seed":seed} |
|
|
|
|
|
|
|
|
|
samples = sampler.sample(self, sigmas, extra_args, callback, noise, latent_image, denoise_mask, disable_pbar) |
|
|
|
|
return self.inner_model.process_latent_out(samples.to(torch.float32)) |
|
|
|
|
|
|
|
|
|
def sample(self, noise, latent_image, sampler, sigmas, denoise_mask=None, callback=None, disable_pbar=False, seed=None): |
|
|
|
|
self.conds = {} |
|
|
|
|
for k in self.original_conds: |
|
|
|
|
self.conds[k] = list(map(lambda a: a.copy(), self.original_conds[k])) |
|
|
|
|
|
|
|
|
|
self.inner_model, self.conds, self.loaded_models = comfy.sampler_helpers.prepare_sampling(self.model_patcher, noise.shape, self.conds) |
|
|
|
|
device = self.model_patcher.load_device |
|
|
|
|
|
|
|
|
|
if denoise_mask is not None: |
|
|
|
|
denoise_mask = comfy.sampler_helpers.prepare_mask(denoise_mask, noise.shape, device) |
|
|
|
|
|
|
|
|
|
conds = process_conds(model, noise, conds, device, latent_image, denoise_mask, seed) |
|
|
|
|
model_wrap = guider_class(model) |
|
|
|
|
noise = noise.to(device) |
|
|
|
|
latent_image = latent_image.to(device) |
|
|
|
|
sigmas = sigmas.to(device) |
|
|
|
|
|
|
|
|
|
extra_args = {"conds": conds, "model_options": model_options, "seed":seed} |
|
|
|
|
output = self.inner_sample(noise, latent_image, device, sampler, sigmas, denoise_mask, callback, disable_pbar, seed) |
|
|
|
|
|
|
|
|
|
samples = sampler.sample(model_wrap, sigmas, extra_args, callback, noise, latent_image, denoise_mask, disable_pbar) |
|
|
|
|
return model.process_latent_out(samples.to(torch.float32)) |
|
|
|
|
comfy.sampler_helpers.cleanup_models(self.conds, self.loaded_models) |
|
|
|
|
del self.inner_model |
|
|
|
|
del self.conds |
|
|
|
|
del self.loaded_models |
|
|
|
|
return output |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def sample(model, noise, positive, negative, cfg, device, sampler, sigmas, model_options={}, latent_image=None, denoise_mask=None, callback=None, disable_pbar=False, seed=None): |
|
|
|
|
return sample_advanced(model, noise, {"positive": positive, "negative": negative}, lambda a: CFGNoisePredictor(a, cfg), device, sampler, sigmas, model_options, latent_image, denoise_mask, callback, disable_pbar, seed) |
|
|
|
|
cfg_guider = CFGGuider(model) |
|
|
|
|
cfg_guider.set_conds({"positive": positive, "negative": negative}) |
|
|
|
|
cfg_guider.set_cfg(cfg) |
|
|
|
|
return cfg_guider.sample(noise, latent_image, sampler, sigmas, denoise_mask, callback, disable_pbar, seed) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
SCHEDULER_NAMES = ["normal", "karras", "exponential", "sgm_uniform", "simple", "ddim_uniform"] |
|
|
|
@ -676,7 +708,7 @@ class KSampler:
|
|
|
|
|
steps += 1 |
|
|
|
|
discard_penultimate_sigma = True |
|
|
|
|
|
|
|
|
|
sigmas = calculate_sigmas_scheduler(self.model, self.scheduler, steps) |
|
|
|
|
sigmas = calculate_sigmas_scheduler(self.model.model, self.scheduler, steps) |
|
|
|
|
|
|
|
|
|
if discard_penultimate_sigma: |
|
|
|
|
sigmas = torch.cat([sigmas[:-2], sigmas[-1:]]) |
|
|
|
|