Browse Source

Use relative imports for k_diffusion.

pull/4/head
comfyanonymous 2 years ago
parent
commit
bbdcf0b737
  1. 16
      comfy/samplers.py

16
comfy/samplers.py

@ -1,5 +1,5 @@
import k_diffusion.sampling from .k_diffusion import sampling as k_diffusion_sampling
import k_diffusion.external from .k_diffusion import external as k_diffusion_external
import torch import torch
import contextlib import contextlib
import model_management import model_management
@ -185,9 +185,9 @@ class KSampler:
def __init__(self, model, steps, device, sampler=None, scheduler=None, denoise=None): def __init__(self, model, steps, device, sampler=None, scheduler=None, denoise=None):
self.model = model self.model = model
if self.model.parameterization == "v": if self.model.parameterization == "v":
self.model_wrap = k_diffusion.external.CompVisVDenoiser(self.model, quantize=True) self.model_wrap = k_diffusion_external.CompVisVDenoiser(self.model, quantize=True)
else: else:
self.model_wrap = k_diffusion.external.CompVisDenoiser(self.model, quantize=True) self.model_wrap = k_diffusion_external.CompVisDenoiser(self.model, quantize=True)
self.model_k = CFGDenoiserComplex(self.model_wrap) self.model_k = CFGDenoiserComplex(self.model_wrap)
self.device = device self.device = device
if scheduler not in self.SCHEDULERS: if scheduler not in self.SCHEDULERS:
@ -209,7 +209,7 @@ class KSampler:
discard_penultimate_sigma = True discard_penultimate_sigma = True
if self.scheduler == "karras": if self.scheduler == "karras":
sigmas = k_diffusion.sampling.get_sigmas_karras(n=steps, sigma_min=self.sigma_min, sigma_max=self.sigma_max, device=self.device) sigmas = k_diffusion_sampling.get_sigmas_karras(n=steps, sigma_min=self.sigma_min, sigma_max=self.sigma_max, device=self.device)
elif self.scheduler == "normal": elif self.scheduler == "normal":
sigmas = self.model_wrap.get_sigmas(steps).to(self.device) sigmas = self.model_wrap.get_sigmas(steps).to(self.device)
elif self.scheduler == "simple": elif self.scheduler == "simple":
@ -269,9 +269,9 @@ class KSampler:
with precision_scope(self.device): with precision_scope(self.device):
if self.sampler == "sample_dpm_fast": if self.sampler == "sample_dpm_fast":
samples = k_diffusion.sampling.sample_dpm_fast(self.model_k, noise, sigma_min, sigmas[0], self.steps, extra_args={"cond":positive, "uncond":negative, "cond_scale": cfg}) samples = k_diffusion_sampling.sample_dpm_fast(self.model_k, noise, sigma_min, sigmas[0], self.steps, extra_args={"cond":positive, "uncond":negative, "cond_scale": cfg})
elif self.sampler == "sample_dpm_adaptive": elif self.sampler == "sample_dpm_adaptive":
samples = k_diffusion.sampling.sample_dpm_adaptive(self.model_k, noise, sigma_min, sigmas[0], extra_args={"cond":positive, "uncond":negative, "cond_scale": cfg}) samples = k_diffusion_sampling.sample_dpm_adaptive(self.model_k, noise, sigma_min, sigmas[0], extra_args={"cond":positive, "uncond":negative, "cond_scale": cfg})
else: else:
samples = getattr(k_diffusion.sampling, self.sampler)(self.model_k, noise, sigmas, extra_args={"cond":positive, "uncond":negative, "cond_scale": cfg}) samples = getattr(k_diffusion_sampling, self.sampler)(self.model_k, noise, sigmas, extra_args={"cond":positive, "uncond":negative, "cond_scale": cfg})
return samples.to(torch.float32) return samples.to(torch.float32)

Loading…
Cancel
Save