# code adapted from https://github.com/exx8/differential-diffusion
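"""ComfyUI node for Differential Diffusion style masking.

`apply` clones the model and installs `forward` as its `denoise_mask_function`.
During sampling, `forward` converts the soft denoise mask into a binary mask by
comparing it against a threshold that falls from 1 to 0 over the sigma
schedule, so areas with higher mask values are allowed to change for a larger
portion of the denoising process.
"""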

import torch
import inspect


class DifferentialDiffusion():
    @classmethod
    def INPUT_TYPES(s):
        return {"required": {"model": ("MODEL", ),
                             }}

    RETURN_TYPES = ("MODEL",)
    FUNCTION = "apply"
    CATEGORY = "_for_testing"
    INIT = False  # class-level flag flipped by IS_CHANGED to force re-initialization

    @classmethod
    def IS_CHANGED(s, *args, **kwargs):
        # Used as a hook: flip INIT so that forward() rebuilds its cached sigma
        # schedule and mask iterator on the next sampling run.
        DifferentialDiffusion.INIT = s.INIT = True
        return ""

    def __init__(self) -> None:
        DifferentialDiffusion.INIT = False
        self.sigmas: torch.Tensor = None
        self.thresholds: torch.Tensor = None
        self.mask_i = None
        self.valid_sigmas = False
        # Samplers that evaluate the model at sigmas in between the scheduled
        # steps; with these, the precomputed per-step mask iterator would fall
        # out of sync, so forward() uses the sigma-threshold fallback instead.
        self.varying_sigmas_samplers = ["dpmpp_2s", "dpmpp_sde", "dpm_2", "heun", "restart"]

    def apply(self, model):
        # Install forward() as the denoise-mask hook on a cloned model.
        model = model.clone()
        model.model_options["denoise_mask_function"] = self.forward
        return (model,)

    def init_sigmas(self, sigma: torch.Tensor, denoise_mask: torch.Tensor):
        # Reset state, then try to locate the sampler's full sigma schedule by
        # walking up the call stack.
        self.__init__()
        self.sigmas, sampler = find_outer_instance("sigmas", callback=get_sigmas_and_sampler) or (None, "")
        # The per-step mask iterator is only safe when the sampler steps through
        # the schedule one sigma at a time.
        self.valid_sigmas = not ("sample_" not in sampler or any(s in sampler for s in self.varying_sigmas_samplers)) or "generic" in sampler
        if self.sigmas is None:
            # Fall back to a two-element schedule [current sigma, 0].
            self.sigmas = sigma[:1].repeat(2)
            self.sigmas[-1].zero_()
        self.sigmas_min = self.sigmas.min()
        self.sigmas_max = self.sigmas.max()
        # One threshold per scheduled sigma, decreasing linearly from 1 to 0.
        self.thresholds = torch.linspace(1, 0, self.sigmas.shape[0], dtype=sigma.dtype, device=sigma.device)
        self.thresholds_min_len = self.thresholds.shape[0] - 1
        if self.valid_sigmas:
            # Precompute one binary mask per step and hand them out in order.
            thresholds = self.thresholds[:-1].reshape(-1, 1, 1, 1, 1)
            mask = denoise_mask.unsqueeze(0)
            mask = (mask >= thresholds).to(denoise_mask.dtype)
            self.mask_i = iter(mask)

    def forward(self, sigma: torch.Tensor, denoise_mask: torch.Tensor):
        # Rebuild cached state on the first call of a run (or after IS_CHANGED fired).
        if self.sigmas is None or DifferentialDiffusion.INIT:
            self.init_sigmas(sigma, denoise_mask)
        if self.valid_sigmas:
            # Preferred path: hand out the next precomputed per-step mask.
            try:
                return next(self.mask_i)
            except StopIteration:
                self.valid_sigmas = False
        # Fallback path: derive the threshold from the current sigma.
        if self.thresholds_min_len > 1:
            nearest_idx = (self.sigmas - sigma[0]).abs().argmin()
            if not self.thresholds_min_len > nearest_idx:
                nearest_idx = -2
            threshold = self.thresholds[nearest_idx]
        else:
            # Only two sigmas are known: normalize the current sigma into [0, 1].
            threshold = (sigma[0] - self.sigmas_min) / (self.sigmas_max - self.sigmas_min)
        return (denoise_mask >= threshold).to(denoise_mask.dtype)


def get_sigmas_and_sampler(frame, target):
    # Callback for find_outer_instance: accept a frame if its local `target`
    # looks like a sigma schedule (a tensor ending near zero) and report the
    # name of the sampler function that owns it.
    found = frame.f_locals[target]
    if isinstance(found, torch.Tensor) and found[-1] < 0.1:
        return found, frame.f_code.co_name
    return False


def find_outer_instance(target: str, target_type=None, callback=None):
    # Walk up the call stack (at most 100 frames) looking for a local variable
    # named `target`; return it via `callback`, or by matching `target_type`.
    frame = inspect.currentframe()
    i = 0
    while frame and i < 100:
        if target in frame.f_locals:
            if callback is not None:
                res = callback(frame, target)
                if res:
                    return res
            else:
                found = frame.f_locals[target]
                if isinstance(found, target_type):
                    return found
        frame = frame.f_back
        i += 1
    return None


NODE_CLASS_MAPPINGS = {
    "DifferentialDiffusion": DifferentialDiffusion,
}

NODE_DISPLAY_NAME_MAPPINGS = {
    "DifferentialDiffusion": "Differential Diffusion",
}
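

# The block below is an illustrative sketch only, not part of the node: it
# shows how the thresholding in forward() turns a soft mask into a sequence of
# hard masks as the threshold falls from 1 to 0. The mask values and schedule
# length are made-up example numbers.
if __name__ == "__main__":
    soft_mask = torch.tensor([0.1, 0.35, 0.6, 0.9])  # hypothetical per-pixel change strengths
    num_sigmas = 5                                   # hypothetical schedule length
    thresholds = torch.linspace(1, 0, num_sigmas)    # one threshold per scheduled sigma
    for step, threshold in enumerate(thresholds[:-1]):
        hard_mask = (soft_mask >= threshold).to(soft_mask.dtype)
        print(f"step {step}: threshold={threshold.item():.2f} mask={hard_mask.tolist()}")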