|
|
|
@ -1,7 +1,6 @@
|
|
|
|
|
# code adapted from https://github.com/exx8/differential-diffusion |
|
|
|
|
|
|
|
|
|
import torch |
|
|
|
|
import inspect |
|
|
|
|
|
|
|
|
|
class DifferentialDiffusion(): |
|
|
|
|
@classmethod |
|
|
|
@ -13,80 +12,26 @@ class DifferentialDiffusion():
|
|
|
|
|
CATEGORY = "_for_testing" |
|
|
|
|
INIT = False |
|
|
|
|
|
|
|
|
|
@classmethod |
|
|
|
|
def IS_CHANGED(s, *args, **kwargs): |
|
|
|
|
DifferentialDiffusion.INIT = s.INIT = True |
|
|
|
|
return "" |
|
|
|
|
|
|
|
|
|
def __init__(self) -> None: |
|
|
|
|
DifferentialDiffusion.INIT = False |
|
|
|
|
self.sigmas: torch.Tensor = None |
|
|
|
|
self.thresholds: torch.Tensor = None |
|
|
|
|
self.mask_i = None |
|
|
|
|
self.valid_sigmas = False |
|
|
|
|
self.varying_sigmas_samplers = ["dpmpp_2s", "dpmpp_sde", "dpm_2", "heun", "restart"] |
|
|
|
|
|
|
|
|
|
def apply(self, model): |
|
|
|
|
model = model.clone() |
|
|
|
|
model.model_options["denoise_mask_function"] = self.forward |
|
|
|
|
model.set_model_denoise_mask_function(self.forward) |
|
|
|
|
return (model,) |
|
|
|
|
|
|
|
|
|
def init_sigmas(self, sigma: torch.Tensor, denoise_mask: torch.Tensor): |
|
|
|
|
self.__init__() |
|
|
|
|
self.sigmas, sampler = find_outer_instance("sigmas", callback=get_sigmas_and_sampler) or (None, "") |
|
|
|
|
self.valid_sigmas = not ("sample_" not in sampler or any(s in sampler for s in self.varying_sigmas_samplers)) or "generic" in sampler |
|
|
|
|
if self.sigmas is None: |
|
|
|
|
self.sigmas = sigma[:1].repeat(2) |
|
|
|
|
self.sigmas[-1].zero_() |
|
|
|
|
self.sigmas_min = self.sigmas.min() |
|
|
|
|
self.sigmas_max = self.sigmas.max() |
|
|
|
|
self.thresholds = torch.linspace(1, 0, self.sigmas.shape[0], dtype=sigma.dtype, device=sigma.device) |
|
|
|
|
self.thresholds_min_len = self.thresholds.shape[0] - 1 |
|
|
|
|
if self.valid_sigmas: |
|
|
|
|
thresholds = self.thresholds[:-1].reshape(-1, 1, 1, 1, 1) |
|
|
|
|
mask = denoise_mask.unsqueeze(0) |
|
|
|
|
mask = (mask >= thresholds).to(denoise_mask.dtype) |
|
|
|
|
self.mask_i = iter(mask) |
|
|
|
|
def forward(self, sigma: torch.Tensor, denoise_mask: torch.Tensor, extra_options: dict): |
|
|
|
|
model = extra_options["model"] |
|
|
|
|
step_sigmas = extra_options["sigmas"] |
|
|
|
|
sigma_to = model.inner_model.model_sampling.sigma_min |
|
|
|
|
if step_sigmas[-1] > sigma_to: |
|
|
|
|
sigma_to = step_sigmas[-1] |
|
|
|
|
sigma_from = step_sigmas[0] |
|
|
|
|
|
|
|
|
|
def forward(self, sigma: torch.Tensor, denoise_mask: torch.Tensor): |
|
|
|
|
if self.sigmas is None or DifferentialDiffusion.INIT: |
|
|
|
|
self.init_sigmas(sigma, denoise_mask) |
|
|
|
|
if self.valid_sigmas: |
|
|
|
|
try: |
|
|
|
|
return next(self.mask_i) |
|
|
|
|
except StopIteration: |
|
|
|
|
self.valid_sigmas = False |
|
|
|
|
if self.thresholds_min_len > 1: |
|
|
|
|
nearest_idx = (self.sigmas - sigma[0]).abs().argmin() |
|
|
|
|
if not self.thresholds_min_len > nearest_idx: |
|
|
|
|
nearest_idx = -2 |
|
|
|
|
threshold = self.thresholds[nearest_idx] |
|
|
|
|
else: |
|
|
|
|
threshold = (sigma[0] - self.sigmas_min) / (self.sigmas_max - self.sigmas_min) |
|
|
|
|
return (denoise_mask >= threshold).to(denoise_mask.dtype) |
|
|
|
|
ts_from = model.inner_model.model_sampling.timestep(sigma_from) |
|
|
|
|
ts_to = model.inner_model.model_sampling.timestep(sigma_to) |
|
|
|
|
current_ts = model.inner_model.model_sampling.timestep(sigma) |
|
|
|
|
|
|
|
|
|
def get_sigmas_and_sampler(frame, target): |
|
|
|
|
found = frame.f_locals[target] |
|
|
|
|
if isinstance(found, torch.Tensor) and found[-1] < 0.1: |
|
|
|
|
return found, frame.f_code.co_name |
|
|
|
|
return False |
|
|
|
|
threshold = (current_ts - ts_to) / (ts_from - ts_to) |
|
|
|
|
|
|
|
|
|
def find_outer_instance(target: str, target_type=None, callback=None): |
|
|
|
|
frame = inspect.currentframe() |
|
|
|
|
i = 0 |
|
|
|
|
while frame and i < 100: |
|
|
|
|
if target in frame.f_locals: |
|
|
|
|
if callback is not None: |
|
|
|
|
res = callback(frame, target) |
|
|
|
|
if res: |
|
|
|
|
return res |
|
|
|
|
else: |
|
|
|
|
found = frame.f_locals[target] |
|
|
|
|
if isinstance(found, target_type): |
|
|
|
|
return found |
|
|
|
|
frame = frame.f_back |
|
|
|
|
i += 1 |
|
|
|
|
return None |
|
|
|
|
return (denoise_mask >= threshold).to(denoise_mask.dtype) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
NODE_CLASS_MAPPINGS = { |
|
|
|
|