|
|
@ -1,5 +1,5 @@ |
|
|
|
import k_diffusion.sampling |
|
|
|
from .k_diffusion import sampling as k_diffusion_sampling |
|
|
|
import k_diffusion.external |
|
|
|
from .k_diffusion import external as k_diffusion_external |
|
|
|
import torch |
|
|
|
import torch |
|
|
|
import contextlib |
|
|
|
import contextlib |
|
|
|
import model_management |
|
|
|
import model_management |
|
|
@ -185,9 +185,9 @@ class KSampler: |
|
|
|
def __init__(self, model, steps, device, sampler=None, scheduler=None, denoise=None): |
|
|
|
def __init__(self, model, steps, device, sampler=None, scheduler=None, denoise=None): |
|
|
|
self.model = model |
|
|
|
self.model = model |
|
|
|
if self.model.parameterization == "v": |
|
|
|
if self.model.parameterization == "v": |
|
|
|
self.model_wrap = k_diffusion.external.CompVisVDenoiser(self.model, quantize=True) |
|
|
|
self.model_wrap = k_diffusion_external.CompVisVDenoiser(self.model, quantize=True) |
|
|
|
else: |
|
|
|
else: |
|
|
|
self.model_wrap = k_diffusion.external.CompVisDenoiser(self.model, quantize=True) |
|
|
|
self.model_wrap = k_diffusion_external.CompVisDenoiser(self.model, quantize=True) |
|
|
|
self.model_k = CFGDenoiserComplex(self.model_wrap) |
|
|
|
self.model_k = CFGDenoiserComplex(self.model_wrap) |
|
|
|
self.device = device |
|
|
|
self.device = device |
|
|
|
if scheduler not in self.SCHEDULERS: |
|
|
|
if scheduler not in self.SCHEDULERS: |
|
|
@ -209,7 +209,7 @@ class KSampler: |
|
|
|
discard_penultimate_sigma = True |
|
|
|
discard_penultimate_sigma = True |
|
|
|
|
|
|
|
|
|
|
|
if self.scheduler == "karras": |
|
|
|
if self.scheduler == "karras": |
|
|
|
sigmas = k_diffusion.sampling.get_sigmas_karras(n=steps, sigma_min=self.sigma_min, sigma_max=self.sigma_max, device=self.device) |
|
|
|
sigmas = k_diffusion_sampling.get_sigmas_karras(n=steps, sigma_min=self.sigma_min, sigma_max=self.sigma_max, device=self.device) |
|
|
|
elif self.scheduler == "normal": |
|
|
|
elif self.scheduler == "normal": |
|
|
|
sigmas = self.model_wrap.get_sigmas(steps).to(self.device) |
|
|
|
sigmas = self.model_wrap.get_sigmas(steps).to(self.device) |
|
|
|
elif self.scheduler == "simple": |
|
|
|
elif self.scheduler == "simple": |
|
|
@ -269,9 +269,9 @@ class KSampler: |
|
|
|
|
|
|
|
|
|
|
|
with precision_scope(self.device): |
|
|
|
with precision_scope(self.device): |
|
|
|
if self.sampler == "sample_dpm_fast": |
|
|
|
if self.sampler == "sample_dpm_fast": |
|
|
|
samples = k_diffusion.sampling.sample_dpm_fast(self.model_k, noise, sigma_min, sigmas[0], self.steps, extra_args={"cond":positive, "uncond":negative, "cond_scale": cfg}) |
|
|
|
samples = k_diffusion_sampling.sample_dpm_fast(self.model_k, noise, sigma_min, sigmas[0], self.steps, extra_args={"cond":positive, "uncond":negative, "cond_scale": cfg}) |
|
|
|
elif self.sampler == "sample_dpm_adaptive": |
|
|
|
elif self.sampler == "sample_dpm_adaptive": |
|
|
|
samples = k_diffusion.sampling.sample_dpm_adaptive(self.model_k, noise, sigma_min, sigmas[0], extra_args={"cond":positive, "uncond":negative, "cond_scale": cfg}) |
|
|
|
samples = k_diffusion_sampling.sample_dpm_adaptive(self.model_k, noise, sigma_min, sigmas[0], extra_args={"cond":positive, "uncond":negative, "cond_scale": cfg}) |
|
|
|
else: |
|
|
|
else: |
|
|
|
samples = getattr(k_diffusion.sampling, self.sampler)(self.model_k, noise, sigmas, extra_args={"cond":positive, "uncond":negative, "cond_scale": cfg}) |
|
|
|
samples = getattr(k_diffusion_sampling, self.sampler)(self.model_k, noise, sigmas, extra_args={"cond":positive, "uncond":negative, "cond_scale": cfg}) |
|
|
|
return samples.to(torch.float32) |
|
|
|
return samples.to(torch.float32) |
|
|
|