|
|
|
@ -638,32 +638,6 @@ def sample_dpmpp_2m_sde(model, x, sigmas, extra_args=None, callback=None, disabl
|
|
|
|
|
h_last = h |
|
|
|
|
return x |
|
|
|
|
|
|
|
|
|
@torch.no_grad() |
|
|
|
|
def sample_dpmpp_3m(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None): |
|
|
|
|
"""DPM-Solver++(3M) without SDE-specific parts.""" |
|
|
|
|
|
|
|
|
|
sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max() |
|
|
|
|
noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max) if noise_sampler is None else noise_sampler |
|
|
|
|
extra_args = {} if extra_args is None else extra_args |
|
|
|
|
s_in = x.new_ones([x.shape[0]]) |
|
|
|
|
|
|
|
|
|
for i in trange(len(sigmas) - 1, disable=disable): |
|
|
|
|
denoised = model(x, sigmas[i] * s_in, **extra_args) |
|
|
|
|
if callback is not None: |
|
|
|
|
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised}) |
|
|
|
|
|
|
|
|
|
# Update x using the DPM-Solver++(3M) update rule |
|
|
|
|
t, s = -sigmas[i].log(), -sigmas[i + 1].log() |
|
|
|
|
h = s - t |
|
|
|
|
h_eta = h * (eta + 1) |
|
|
|
|
|
|
|
|
|
x = torch.exp(-h_eta) * x + (-h_eta).expm1().neg() * denoised |
|
|
|
|
|
|
|
|
|
if eta: |
|
|
|
|
x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * sigmas[i + 1] * (-2 * h * eta).expm1().neg().sqrt() * s_noise |
|
|
|
|
|
|
|
|
|
return x |
|
|
|
|
|
|
|
|
|
@torch.no_grad() |
|
|
|
|
def sample_dpmpp_3m_sde(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None): |
|
|
|
|
"""DPM-Solver++(3M) SDE.""" |
|
|
|
|