From 57eea0efbb07a48d4810b477b29d44ba5425a742 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Tue, 14 Nov 2023 23:45:36 -0500 Subject: [PATCH] heunpp2 sampler. --- comfy/k_diffusion/sampling.py | 58 +++++++++++++++++++++++++++++++++++ comfy/samplers.py | 2 +- 2 files changed, 59 insertions(+), 1 deletion(-) diff --git a/comfy/k_diffusion/sampling.py b/comfy/k_diffusion/sampling.py index dd6f7bbe..761c2e0e 100644 --- a/comfy/k_diffusion/sampling.py +++ b/comfy/k_diffusion/sampling.py @@ -750,3 +750,61 @@ def sample_lcm(model, x, sigmas, extra_args=None, callback=None, disable=None, n if sigmas[i + 1] > 0: x += sigmas[i + 1] * noise_sampler(sigmas[i], sigmas[i + 1]) return x + + + +@torch.no_grad() +def sample_heunpp2(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.): + # From MIT licensed: https://github.com/Carzit/sd-webui-samplers-scheduler/ + extra_args = {} if extra_args is None else extra_args + s_in = x.new_ones([x.shape[0]]) + s_end = sigmas[-1] + for i in trange(len(sigmas) - 1, disable=disable): + gamma = min(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0. + eps = torch.randn_like(x) * s_noise + sigma_hat = sigmas[i] * (gamma + 1) + if gamma > 0: + x = x + eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5 + denoised = model(x, sigma_hat * s_in, **extra_args) + d = to_d(x, sigma_hat, denoised) + if callback is not None: + callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised}) + dt = sigmas[i + 1] - sigma_hat + if sigmas[i + 1] == s_end: + # Euler method + x = x + d * dt + elif sigmas[i + 2] == s_end: + + # Heun's method + x_2 = x + d * dt + denoised_2 = model(x_2, sigmas[i + 1] * s_in, **extra_args) + d_2 = to_d(x_2, sigmas[i + 1], denoised_2) + + w = 2 * sigmas[0] + w2 = sigmas[i+1]/w + w1 = 1 - w2 + + d_prime = d * w1 + d_2 * w2 + + + x = x + d_prime * dt + + else: + # Heun++ + x_2 = x + d * dt + denoised_2 = model(x_2, sigmas[i + 1] * s_in, **extra_args) + d_2 = to_d(x_2, sigmas[i + 1], denoised_2) + dt_2 = sigmas[i + 2] - sigmas[i + 1] + + x_3 = x_2 + d_2 * dt_2 + denoised_3 = model(x_3, sigmas[i + 2] * s_in, **extra_args) + d_3 = to_d(x_3, sigmas[i + 2], denoised_3) + + w = 3 * sigmas[0] + w2 = sigmas[i + 1] / w + w3 = sigmas[i + 2] / w + w1 = 1 - w2 - w3 + + d_prime = w1 * d + w2 * d_2 + w3 * d_3 + x = x + d_prime * dt + return x diff --git a/comfy/samplers.py b/comfy/samplers.py index 65c44791..d8037d8e 100644 --- a/comfy/samplers.py +++ b/comfy/samplers.py @@ -518,7 +518,7 @@ class UNIPCBH2(Sampler): def sample(self, model_wrap, sigmas, extra_args, callback, noise, latent_image=None, denoise_mask=None, disable_pbar=False): return uni_pc.sample_unipc(model_wrap, noise, latent_image, sigmas, max_denoise=self.max_denoise(model_wrap, sigmas), extra_args=extra_args, noise_mask=denoise_mask, callback=callback, variant='bh2', disable=disable_pbar) -KSAMPLER_NAMES = ["euler", "euler_ancestral", "heun", "dpm_2", "dpm_2_ancestral", +KSAMPLER_NAMES = ["euler", "euler_ancestral", "heun", "heunpp2","dpm_2", "dpm_2_ancestral", "lms", "dpm_fast", "dpm_adaptive", "dpmpp_2s_ancestral", "dpmpp_sde", "dpmpp_sde_gpu", "dpmpp_2m", "dpmpp_2m_sde", "dpmpp_2m_sde_gpu", "dpmpp_3m_sde", "dpmpp_3m_sde_gpu", "ddpm", "lcm"]