Browse Source

Remove unused code and torchdiffeq dependency.

pull/1015/head
comfyanonymous 1 year ago
parent
commit
c910b4a01c
  1. 25
      comfy/k_diffusion/sampling.py
  2. 1
      requirements.txt

25
comfy/k_diffusion/sampling.py

@ -3,7 +3,6 @@ import math
from scipy import integrate from scipy import integrate
import torch import torch
from torch import nn from torch import nn
from torchdiffeq import odeint
import torchsde import torchsde
from tqdm.auto import trange, tqdm from tqdm.auto import trange, tqdm
@ -287,30 +286,6 @@ def sample_lms(model, x, sigmas, extra_args=None, callback=None, disable=None, o
return x return x
@torch.no_grad()
def log_likelihood(model, x, sigma_min, sigma_max, extra_args=None, atol=1e-4, rtol=1e-4):
extra_args = {} if extra_args is None else extra_args
s_in = x.new_ones([x.shape[0]])
v = torch.randint_like(x, 2) * 2 - 1
fevals = 0
def ode_fn(sigma, x):
nonlocal fevals
with torch.enable_grad():
x = x[0].detach().requires_grad_()
denoised = model(x, sigma * s_in, **extra_args)
d = to_d(x, sigma, denoised)
fevals += 1
grad = torch.autograd.grad((d * v).sum(), x)[0]
d_ll = (v * grad).flatten(1).sum(1)
return d.detach(), d_ll
x_min = x, x.new_zeros([x.shape[0]])
t = x.new_tensor([sigma_min, sigma_max])
sol = odeint(ode_fn, x_min, t, atol=atol, rtol=rtol, method='dopri5')
latent, delta_ll = sol[0][-1], sol[1][-1]
ll_prior = torch.distributions.Normal(0, sigma_max).log_prob(latent).flatten(1).sum(1)
return ll_prior + delta_ll, {'fevals': fevals}
class PIDStepSizeController: class PIDStepSizeController:
"""A PID controller for ODE adaptive step size control.""" """A PID controller for ODE adaptive step size control."""
def __init__(self, h, pcoeff, icoeff, dcoeff, order=1, accept_safety=0.81, eps=1e-8): def __init__(self, h, pcoeff, icoeff, dcoeff, order=1, accept_safety=0.81, eps=1e-8):

1
requirements.txt

@ -1,5 +1,4 @@
torch torch
torchdiffeq
torchsde torchsde
einops einops
transformers>=4.25.1 transformers>=4.25.1

Loading…
Cancel
Save