|
|
|
@ -605,3 +605,46 @@ def sample_dpmpp_2m(model, x, sigmas, extra_args=None, callback=None, disable=No
|
|
|
|
|
x = (sigma_fn(t_next) / sigma_fn(t)) * x - (-h).expm1() * denoised_d |
|
|
|
|
old_denoised = denoised |
|
|
|
|
return x |
|
|
|
|
|
|
|
|
|
@torch.no_grad() |
|
|
|
|
def sample_dpmpp_2m_sde(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, solver_type='midpoint'): |
|
|
|
|
"""DPM-Solver++(2M) SDE.""" |
|
|
|
|
|
|
|
|
|
if solver_type not in {'heun', 'midpoint'}: |
|
|
|
|
raise ValueError('solver_type must be \'heun\' or \'midpoint\'') |
|
|
|
|
|
|
|
|
|
sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max() |
|
|
|
|
noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max) if noise_sampler is None else noise_sampler |
|
|
|
|
extra_args = {} if extra_args is None else extra_args |
|
|
|
|
s_in = x.new_ones([x.shape[0]]) |
|
|
|
|
|
|
|
|
|
old_denoised = None |
|
|
|
|
h_last = None |
|
|
|
|
|
|
|
|
|
for i in trange(len(sigmas) - 1, disable=disable): |
|
|
|
|
denoised = model(x, sigmas[i] * s_in, **extra_args) |
|
|
|
|
if callback is not None: |
|
|
|
|
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised}) |
|
|
|
|
if sigmas[i + 1] == 0: |
|
|
|
|
# Denoising step |
|
|
|
|
x = denoised |
|
|
|
|
else: |
|
|
|
|
# DPM-Solver++(2M) SDE |
|
|
|
|
t, s = -sigmas[i].log(), -sigmas[i + 1].log() |
|
|
|
|
h = s - t |
|
|
|
|
eta_h = eta * h |
|
|
|
|
|
|
|
|
|
x = sigmas[i + 1] / sigmas[i] * (-eta_h).exp() * x + (-h - eta_h).expm1().neg() * denoised |
|
|
|
|
|
|
|
|
|
if old_denoised is not None: |
|
|
|
|
r = h_last / h |
|
|
|
|
if solver_type == 'heun': |
|
|
|
|
x = x + ((-h - eta_h).expm1().neg() / (-h - eta_h) + 1) * (1 / r) * (denoised - old_denoised) |
|
|
|
|
elif solver_type == 'midpoint': |
|
|
|
|
x = x + 0.5 * (-h - eta_h).expm1().neg() * (1 / r) * (denoised - old_denoised) |
|
|
|
|
|
|
|
|
|
x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * sigmas[i + 1] * (-2 * eta_h).expm1().neg().sqrt() * s_noise |
|
|
|
|
|
|
|
|
|
old_denoised = denoised |
|
|
|
|
h_last = h |
|
|
|
|
return x |
|
|
|
|