From 12c1080ebc9095b3878d7b3aa994a6b75b308e0a Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Sun, 3 Mar 2024 15:11:13 -0500 Subject: [PATCH] Simplify differential diffusion code. --- comfy/model_patcher.py | 3 + comfy/samplers.py | 7 +- comfy_extras/nodes_differential_diffusion.py | 87 ++++---------------- 3 files changed, 23 insertions(+), 74 deletions(-) diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index 604e3477..4a5d42b0 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -67,6 +67,9 @@ class ModelPatcher: def set_model_unet_function_wrapper(self, unet_wrapper_function): self.model_options["model_function_wrapper"] = unet_wrapper_function + def set_model_denoise_mask_function(self, denoise_mask_function): + self.model_options["denoise_mask_function"] = denoise_mask_function + def set_model_patch(self, patch, name): to = self.model_options["transformer_options"] if "patches" not in to: diff --git a/comfy/samplers.py b/comfy/samplers.py index b7ef6b96..6863be4e 100644 --- a/comfy/samplers.py +++ b/comfy/samplers.py @@ -272,13 +272,14 @@ class CFGNoisePredictor(torch.nn.Module): return self.apply_model(*args, **kwargs) class KSamplerX0Inpaint(torch.nn.Module): - def __init__(self, model): + def __init__(self, model, sigmas): super().__init__() self.inner_model = model + self.sigmas = sigmas def forward(self, x, sigma, uncond, cond, cond_scale, denoise_mask, model_options={}, seed=None): if denoise_mask is not None: if "denoise_mask_function" in model_options: - denoise_mask = model_options["denoise_mask_function"](sigma, denoise_mask) + denoise_mask = model_options["denoise_mask_function"](sigma, denoise_mask, extra_options={"model": self.inner_model, "sigmas": self.sigmas}) latent_mask = 1. - denoise_mask x = x * denoise_mask + self.inner_model.inner_model.model_sampling.noise_scaling(sigma.reshape([sigma.shape[0]] + [1] * (len(self.noise.shape) - 1)), self.noise, self.latent_image) * latent_mask out = self.inner_model(x, sigma, cond=cond, uncond=uncond, cond_scale=cond_scale, model_options=model_options, seed=seed) @@ -528,7 +529,7 @@ class KSAMPLER(Sampler): def sample(self, model_wrap, sigmas, extra_args, callback, noise, latent_image=None, denoise_mask=None, disable_pbar=False): extra_args["denoise_mask"] = denoise_mask - model_k = KSamplerX0Inpaint(model_wrap) + model_k = KSamplerX0Inpaint(model_wrap, sigmas) model_k.latent_image = latent_image if self.inpaint_options.get("random", False): #TODO: Should this be the default? generator = torch.manual_seed(extra_args.get("seed", 41) + 1) diff --git a/comfy_extras/nodes_differential_diffusion.py b/comfy_extras/nodes_differential_diffusion.py index 48c95602..7e858a71 100644 --- a/comfy_extras/nodes_differential_diffusion.py +++ b/comfy_extras/nodes_differential_diffusion.py @@ -1,7 +1,6 @@ # code adapted from https://github.com/exx8/differential-diffusion import torch -import inspect class DifferentialDiffusion(): @classmethod @@ -13,82 +12,28 @@ class DifferentialDiffusion(): CATEGORY = "_for_testing" INIT = False - @classmethod - def IS_CHANGED(s, *args, **kwargs): - DifferentialDiffusion.INIT = s.INIT = True - return "" - - def __init__(self) -> None: - DifferentialDiffusion.INIT = False - self.sigmas: torch.Tensor = None - self.thresholds: torch.Tensor = None - self.mask_i = None - self.valid_sigmas = False - self.varying_sigmas_samplers = ["dpmpp_2s", "dpmpp_sde", "dpm_2", "heun", "restart"] - def apply(self, model): model = model.clone() - model.model_options["denoise_mask_function"] = self.forward + model.set_model_denoise_mask_function(self.forward) return (model,) - - def init_sigmas(self, sigma: torch.Tensor, denoise_mask: torch.Tensor): - self.__init__() - self.sigmas, sampler = find_outer_instance("sigmas", callback=get_sigmas_and_sampler) or (None, "") - self.valid_sigmas = not ("sample_" not in sampler or any(s in sampler for s in self.varying_sigmas_samplers)) or "generic" in sampler - if self.sigmas is None: - self.sigmas = sigma[:1].repeat(2) - self.sigmas[-1].zero_() - self.sigmas_min = self.sigmas.min() - self.sigmas_max = self.sigmas.max() - self.thresholds = torch.linspace(1, 0, self.sigmas.shape[0], dtype=sigma.dtype, device=sigma.device) - self.thresholds_min_len = self.thresholds.shape[0] - 1 - if self.valid_sigmas: - thresholds = self.thresholds[:-1].reshape(-1, 1, 1, 1, 1) - mask = denoise_mask.unsqueeze(0) - mask = (mask >= thresholds).to(denoise_mask.dtype) - self.mask_i = iter(mask) - - def forward(self, sigma: torch.Tensor, denoise_mask: torch.Tensor): - if self.sigmas is None or DifferentialDiffusion.INIT: - self.init_sigmas(sigma, denoise_mask) - if self.valid_sigmas: - try: - return next(self.mask_i) - except StopIteration: - self.valid_sigmas = False - if self.thresholds_min_len > 1: - nearest_idx = (self.sigmas - sigma[0]).abs().argmin() - if not self.thresholds_min_len > nearest_idx: - nearest_idx = -2 - threshold = self.thresholds[nearest_idx] - else: - threshold = (sigma[0] - self.sigmas_min) / (self.sigmas_max - self.sigmas_min) - return (denoise_mask >= threshold).to(denoise_mask.dtype) -def get_sigmas_and_sampler(frame, target): - found = frame.f_locals[target] - if isinstance(found, torch.Tensor) and found[-1] < 0.1: - return found, frame.f_code.co_name - return False + def forward(self, sigma: torch.Tensor, denoise_mask: torch.Tensor, extra_options: dict): + model = extra_options["model"] + step_sigmas = extra_options["sigmas"] + sigma_to = model.inner_model.model_sampling.sigma_min + if step_sigmas[-1] > sigma_to: + sigma_to = step_sigmas[-1] + sigma_from = step_sigmas[0] + + ts_from = model.inner_model.model_sampling.timestep(sigma_from) + ts_to = model.inner_model.model_sampling.timestep(sigma_to) + current_ts = model.inner_model.model_sampling.timestep(sigma) + + threshold = (current_ts - ts_to) / (ts_from - ts_to) + + return (denoise_mask >= threshold).to(denoise_mask.dtype) -def find_outer_instance(target: str, target_type=None, callback=None): - frame = inspect.currentframe() - i = 0 - while frame and i < 100: - if target in frame.f_locals: - if callback is not None: - res = callback(frame, target) - if res: - return res - else: - found = frame.f_locals[target] - if isinstance(found, target_type): - return found - frame = frame.f_back - i += 1 - return None - NODE_CLASS_MAPPINGS = { "DifferentialDiffusion": DifferentialDiffusion, }