diff --git a/comfy/k_diffusion/sampling.py b/comfy/k_diffusion/sampling.py index 7af01682..f9b28189 100644 --- a/comfy/k_diffusion/sampling.py +++ b/comfy/k_diffusion/sampling.py @@ -527,6 +527,9 @@ def sample_dpmpp_2s_ancestral(model, x, sigmas, extra_args=None, callback=None, @torch.no_grad() def sample_dpmpp_sde(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, r=1 / 2): """DPM-Solver++ (stochastic).""" + if len(sigmas) <= 1: + return x + sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max() seed = extra_args.get("seed", None) noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=seed, cpu=True) if noise_sampler is None else noise_sampler @@ -595,6 +598,8 @@ def sample_dpmpp_2m(model, x, sigmas, extra_args=None, callback=None, disable=No @torch.no_grad() def sample_dpmpp_2m_sde(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, solver_type='midpoint'): """DPM-Solver++(2M) SDE.""" + if len(sigmas) <= 1: + return x if solver_type not in {'heun', 'midpoint'}: raise ValueError('solver_type must be \'heun\' or \'midpoint\'') @@ -642,6 +647,9 @@ def sample_dpmpp_2m_sde(model, x, sigmas, extra_args=None, callback=None, disabl def sample_dpmpp_3m_sde(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None): """DPM-Solver++(3M) SDE.""" + if len(sigmas) <= 1: + return x + seed = extra_args.get("seed", None) sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max() noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=seed, cpu=True) if noise_sampler is None else noise_sampler @@ -690,18 +698,27 @@ def sample_dpmpp_3m_sde(model, x, sigmas, extra_args=None, callback=None, disabl @torch.no_grad() def sample_dpmpp_3m_sde_gpu(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None): + if len(sigmas) <= 1: + return x + sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max() noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=extra_args.get("seed", None), cpu=False) if noise_sampler is None else noise_sampler return sample_dpmpp_3m_sde(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, eta=eta, s_noise=s_noise, noise_sampler=noise_sampler) @torch.no_grad() def sample_dpmpp_2m_sde_gpu(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, solver_type='midpoint'): + if len(sigmas) <= 1: + return x + sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max() noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=extra_args.get("seed", None), cpu=False) if noise_sampler is None else noise_sampler return sample_dpmpp_2m_sde(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, eta=eta, s_noise=s_noise, noise_sampler=noise_sampler, solver_type=solver_type) @torch.no_grad() def sample_dpmpp_sde_gpu(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, r=1 / 2): + if len(sigmas) <= 1: + return x + sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max() noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=extra_args.get("seed", None), cpu=False) if noise_sampler is None else noise_sampler return sample_dpmpp_sde(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, eta=eta, s_noise=s_noise, noise_sampler=noise_sampler, r=r) diff --git a/comfy/samplers.py b/comfy/samplers.py index b12b0fd1..29962a91 100644 --- a/comfy/samplers.py +++ b/comfy/samplers.py @@ -539,6 +539,9 @@ class KSAMPLER(Sampler): def ksampler(sampler_name, extra_options={}, inpaint_options={}): if sampler_name == "dpm_fast": def dpm_fast_function(model, noise, sigmas, extra_args, callback, disable): + if len(sigmas) <= 1: + return noise + sigma_min = sigmas[-1] if sigma_min == 0: sigma_min = sigmas[-2] @@ -547,6 +550,9 @@ def ksampler(sampler_name, extra_options={}, inpaint_options={}): sampler_function = dpm_fast_function elif sampler_name == "dpm_adaptive": def dpm_adaptive_function(model, noise, sigmas, extra_args, callback, disable, **extra_options): + if len(sigmas) <= 1: + return noise + sigma_min = sigmas[-1] if sigma_min == 0: sigma_min = sigmas[-2]