diff --git a/comfy/k_diffusion/sampling.py b/comfy/k_diffusion/sampling.py index 65d06199..4cc2534f 100644 --- a/comfy/k_diffusion/sampling.py +++ b/comfy/k_diffusion/sampling.py @@ -66,6 +66,9 @@ class BatchedBrownianTree: """A wrapper around torchsde.BrownianTree that enables batches of entropy.""" def __init__(self, x, t0, t1, seed=None, **kwargs): + self.cpu_tree = True + if "cpu" in kwargs: + self.cpu_tree = kwargs.pop("cpu") t0, t1, self.sign = self.sort(t0, t1) w0 = kwargs.get('w0', torch.zeros_like(x)) if seed is None: @@ -77,7 +80,10 @@ class BatchedBrownianTree: except TypeError: seed = [seed] self.batched = False - self.trees = [torchsde.BrownianTree(t0.cpu(), w0.cpu(), t1.cpu(), entropy=s, **kwargs) for s in seed] + if self.cpu_tree: + self.trees = [torchsde.BrownianTree(t0.cpu(), w0.cpu(), t1.cpu(), entropy=s, **kwargs) for s in seed] + else: + self.trees = [torchsde.BrownianTree(t0, w0, t1, entropy=s, **kwargs) for s in seed] @staticmethod def sort(a, b): @@ -85,7 +91,11 @@ class BatchedBrownianTree: def __call__(self, t0, t1): t0, t1, sign = self.sort(t0, t1) - w = torch.stack([tree(t0.cpu().float(), t1.cpu().float()).to(t0.dtype).to(t0.device) for tree in self.trees]) * (self.sign * sign) + if self.cpu_tree: + w = torch.stack([tree(t0.cpu().float(), t1.cpu().float()).to(t0.dtype).to(t0.device) for tree in self.trees]) * (self.sign * sign) + else: + w = torch.stack([tree(t0, t1) for tree in self.trees]) * (self.sign * sign) + return w if self.batched else w[0] @@ -104,10 +114,10 @@ class BrownianTreeNoiseSampler: internal timestep. """ - def __init__(self, x, sigma_min, sigma_max, seed=None, transform=lambda x: x): + def __init__(self, x, sigma_min, sigma_max, seed=None, transform=lambda x: x, cpu=False): self.transform = transform t0, t1 = self.transform(torch.as_tensor(sigma_min)), self.transform(torch.as_tensor(sigma_max)) - self.tree = BatchedBrownianTree(x, t0, t1, seed) + self.tree = BatchedBrownianTree(x, t0, t1, seed, cpu=cpu) def __call__(self, sigma, sigma_next): t0, t1 = self.transform(torch.as_tensor(sigma)), self.transform(torch.as_tensor(sigma_next)) @@ -544,7 +554,7 @@ def sample_dpmpp_sde(model, x, sigmas, extra_args=None, callback=None, disable=N """DPM-Solver++ (stochastic).""" sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max() seed = extra_args.get("seed", None) - noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=seed) if noise_sampler is None else noise_sampler + noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=seed, cpu=True) if noise_sampler is None else noise_sampler extra_args = {} if extra_args is None else extra_args s_in = x.new_ones([x.shape[0]]) sigma_fn = lambda t: t.neg().exp() @@ -616,7 +626,7 @@ def sample_dpmpp_2m_sde(model, x, sigmas, extra_args=None, callback=None, disabl seed = extra_args.get("seed", None) sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max() - noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=seed) if noise_sampler is None else noise_sampler + noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=seed, cpu=True) if noise_sampler is None else noise_sampler extra_args = {} if extra_args is None else extra_args s_in = x.new_ones([x.shape[0]]) @@ -651,3 +661,18 @@ def sample_dpmpp_2m_sde(model, x, sigmas, extra_args=None, callback=None, disabl old_denoised = denoised h_last = h return x + +@torch.no_grad() +def sample_dpmpp_2m_sde_gpu(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, solver_type='midpoint'): + sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max() + noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=extra_args.get("seed", None), cpu=False) if noise_sampler is None else noise_sampler + return sample_dpmpp_2m_sde(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, eta=eta, s_noise=s_noise, noise_sampler=noise_sampler, solver_type=solver_type) + + +@torch.no_grad() +def sample_dpmpp_sde_gpu(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, r=1 / 2): + sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max() + noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=extra_args.get("seed", None), cpu=False) if noise_sampler is None else noise_sampler + return sample_dpmpp_sde(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, eta=eta, s_noise=s_noise, noise_sampler=noise_sampler, r=r) + + diff --git a/comfy/samplers.py b/comfy/samplers.py index 3aaf8ac4..ea952559 100644 --- a/comfy/samplers.py +++ b/comfy/samplers.py @@ -483,8 +483,8 @@ def encode_adm(model, conds, batch_size, width, height, device, prompt_type): class KSampler: SCHEDULERS = ["normal", "karras", "exponential", "simple", "ddim_uniform"] SAMPLERS = ["euler", "euler_ancestral", "heun", "dpm_2", "dpm_2_ancestral", - "lms", "dpm_fast", "dpm_adaptive", "dpmpp_2s_ancestral", "dpmpp_sde", - "dpmpp_2m", "dpmpp_2m_sde", "ddim", "uni_pc", "uni_pc_bh2"] + "lms", "dpm_fast", "dpm_adaptive", "dpmpp_2s_ancestral", "dpmpp_sde", "dpmpp_sde_gpu", + "dpmpp_2m", "dpmpp_2m_sde", "dpmpp_2m_sde_gpu", "ddim", "uni_pc", "uni_pc_bh2"] def __init__(self, model, steps, device, sampler=None, scheduler=None, denoise=None, model_options={}): self.model = model