|
|
|
@ -527,6 +527,9 @@ def sample_dpmpp_2s_ancestral(model, x, sigmas, extra_args=None, callback=None,
|
|
|
|
|
@torch.no_grad() |
|
|
|
|
def sample_dpmpp_sde(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, r=1 / 2): |
|
|
|
|
"""DPM-Solver++ (stochastic).""" |
|
|
|
|
if len(sigmas) <= 1: |
|
|
|
|
return x |
|
|
|
|
|
|
|
|
|
sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max() |
|
|
|
|
seed = extra_args.get("seed", None) |
|
|
|
|
noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=seed, cpu=True) if noise_sampler is None else noise_sampler |
|
|
|
@ -595,6 +598,8 @@ def sample_dpmpp_2m(model, x, sigmas, extra_args=None, callback=None, disable=No
|
|
|
|
|
@torch.no_grad() |
|
|
|
|
def sample_dpmpp_2m_sde(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, solver_type='midpoint'): |
|
|
|
|
"""DPM-Solver++(2M) SDE.""" |
|
|
|
|
if len(sigmas) <= 1: |
|
|
|
|
return x |
|
|
|
|
|
|
|
|
|
if solver_type not in {'heun', 'midpoint'}: |
|
|
|
|
raise ValueError('solver_type must be \'heun\' or \'midpoint\'') |
|
|
|
@ -642,6 +647,9 @@ def sample_dpmpp_2m_sde(model, x, sigmas, extra_args=None, callback=None, disabl
|
|
|
|
|
def sample_dpmpp_3m_sde(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None): |
|
|
|
|
"""DPM-Solver++(3M) SDE.""" |
|
|
|
|
|
|
|
|
|
if len(sigmas) <= 1: |
|
|
|
|
return x |
|
|
|
|
|
|
|
|
|
seed = extra_args.get("seed", None) |
|
|
|
|
sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max() |
|
|
|
|
noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=seed, cpu=True) if noise_sampler is None else noise_sampler |
|
|
|
@ -690,18 +698,27 @@ def sample_dpmpp_3m_sde(model, x, sigmas, extra_args=None, callback=None, disabl
|
|
|
|
|
|
|
|
|
|
@torch.no_grad() |
|
|
|
|
def sample_dpmpp_3m_sde_gpu(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None): |
|
|
|
|
if len(sigmas) <= 1: |
|
|
|
|
return x |
|
|
|
|
|
|
|
|
|
sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max() |
|
|
|
|
noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=extra_args.get("seed", None), cpu=False) if noise_sampler is None else noise_sampler |
|
|
|
|
return sample_dpmpp_3m_sde(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, eta=eta, s_noise=s_noise, noise_sampler=noise_sampler) |
|
|
|
|
|
|
|
|
|
@torch.no_grad() |
|
|
|
|
def sample_dpmpp_2m_sde_gpu(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, solver_type='midpoint'): |
|
|
|
|
if len(sigmas) <= 1: |
|
|
|
|
return x |
|
|
|
|
|
|
|
|
|
sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max() |
|
|
|
|
noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=extra_args.get("seed", None), cpu=False) if noise_sampler is None else noise_sampler |
|
|
|
|
return sample_dpmpp_2m_sde(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, eta=eta, s_noise=s_noise, noise_sampler=noise_sampler, solver_type=solver_type) |
|
|
|
|
|
|
|
|
|
@torch.no_grad() |
|
|
|
|
def sample_dpmpp_sde_gpu(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, r=1 / 2): |
|
|
|
|
if len(sigmas) <= 1: |
|
|
|
|
return x |
|
|
|
|
|
|
|
|
|
sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max() |
|
|
|
|
noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=extra_args.get("seed", None), cpu=False) if noise_sampler is None else noise_sampler |
|
|
|
|
return sample_dpmpp_sde(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, eta=eta, s_noise=s_noise, noise_sampler=noise_sampler, r=r) |
|
|
|
|