|
|
@ -3,7 +3,6 @@ import math |
|
|
|
from scipy import integrate |
|
|
|
from scipy import integrate |
|
|
|
import torch |
|
|
|
import torch |
|
|
|
from torch import nn |
|
|
|
from torch import nn |
|
|
|
from torchdiffeq import odeint |
|
|
|
|
|
|
|
import torchsde |
|
|
|
import torchsde |
|
|
|
from tqdm.auto import trange, tqdm |
|
|
|
from tqdm.auto import trange, tqdm |
|
|
|
|
|
|
|
|
|
|
@ -287,30 +286,6 @@ def sample_lms(model, x, sigmas, extra_args=None, callback=None, disable=None, o |
|
|
|
return x |
|
|
|
return x |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@torch.no_grad() |
|
|
|
|
|
|
|
def log_likelihood(model, x, sigma_min, sigma_max, extra_args=None, atol=1e-4, rtol=1e-4): |
|
|
|
|
|
|
|
extra_args = {} if extra_args is None else extra_args |
|
|
|
|
|
|
|
s_in = x.new_ones([x.shape[0]]) |
|
|
|
|
|
|
|
v = torch.randint_like(x, 2) * 2 - 1 |
|
|
|
|
|
|
|
fevals = 0 |
|
|
|
|
|
|
|
def ode_fn(sigma, x): |
|
|
|
|
|
|
|
nonlocal fevals |
|
|
|
|
|
|
|
with torch.enable_grad(): |
|
|
|
|
|
|
|
x = x[0].detach().requires_grad_() |
|
|
|
|
|
|
|
denoised = model(x, sigma * s_in, **extra_args) |
|
|
|
|
|
|
|
d = to_d(x, sigma, denoised) |
|
|
|
|
|
|
|
fevals += 1 |
|
|
|
|
|
|
|
grad = torch.autograd.grad((d * v).sum(), x)[0] |
|
|
|
|
|
|
|
d_ll = (v * grad).flatten(1).sum(1) |
|
|
|
|
|
|
|
return d.detach(), d_ll |
|
|
|
|
|
|
|
x_min = x, x.new_zeros([x.shape[0]]) |
|
|
|
|
|
|
|
t = x.new_tensor([sigma_min, sigma_max]) |
|
|
|
|
|
|
|
sol = odeint(ode_fn, x_min, t, atol=atol, rtol=rtol, method='dopri5') |
|
|
|
|
|
|
|
latent, delta_ll = sol[0][-1], sol[1][-1] |
|
|
|
|
|
|
|
ll_prior = torch.distributions.Normal(0, sigma_max).log_prob(latent).flatten(1).sum(1) |
|
|
|
|
|
|
|
return ll_prior + delta_ll, {'fevals': fevals} |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class PIDStepSizeController: |
|
|
|
class PIDStepSizeController: |
|
|
|
"""A PID controller for ODE adaptive step size control.""" |
|
|
|
"""A PID controller for ODE adaptive step size control.""" |
|
|
|
def __init__(self, h, pcoeff, icoeff, dcoeff, order=1, accept_safety=0.81, eps=1e-8): |
|
|
|
def __init__(self, h, pcoeff, icoeff, dcoeff, order=1, accept_safety=0.81, eps=1e-8): |
|
|
|