From c6bd456c45fd24818223bd4f61a6840e281ac82f Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Thu, 4 Apr 2024 11:38:25 -0400 Subject: [PATCH] Make zero denoise a NOP. --- comfy/samplers.py | 12 +++++++++--- comfy_extras/nodes_custom_sampler.py | 5 +++++ 2 files changed, 14 insertions(+), 3 deletions(-) diff --git a/comfy/samplers.py b/comfy/samplers.py index a89e3a6c..475b1aad 100644 --- a/comfy/samplers.py +++ b/comfy/samplers.py @@ -624,6 +624,9 @@ class CFGGuider: return self.inner_model.process_latent_out(samples.to(torch.float32)) def sample(self, noise, latent_image, sampler, sigmas, denoise_mask=None, callback=None, disable_pbar=False, seed=None): + if sigmas.shape[-1] == 0: + return latent_image + self.conds = {} for k in self.original_conds: self.conds[k] = list(map(lambda a: a.copy(), self.original_conds[k])) @@ -722,9 +725,12 @@ class KSampler: if denoise is None or denoise > 0.9999: self.sigmas = self.calculate_sigmas(steps).to(self.device) else: - new_steps = int(steps/denoise) - sigmas = self.calculate_sigmas(new_steps).to(self.device) - self.sigmas = sigmas[-(steps + 1):] + if denoise <= 0.0: + self.sigmas = torch.FloatTensor([]) + else: + new_steps = int(steps/denoise) + sigmas = self.calculate_sigmas(new_steps).to(self.device) + self.sigmas = sigmas[-(steps + 1):] def sample(self, noise, positive, negative, cfg, latent_image=None, start_step=None, last_step=None, force_full_denoise=False, denoise_mask=None, sigmas=None, callback=None, disable_pbar=False, seed=None): if sigmas is None: diff --git a/comfy_extras/nodes_custom_sampler.py b/comfy_extras/nodes_custom_sampler.py index e9dc3bd9..a99dbcee 100644 --- a/comfy_extras/nodes_custom_sampler.py +++ b/comfy_extras/nodes_custom_sampler.py @@ -24,6 +24,8 @@ class BasicScheduler: def get_sigmas(self, model, scheduler, steps, denoise): total_steps = steps if denoise < 1.0: + if denoise <= 0.0: + return (torch.FloatTensor([]),) total_steps = int(steps/denoise) comfy.model_management.load_models_gpu([model]) @@ -160,6 +162,9 @@ class FlipSigmas: FUNCTION = "get_sigmas" def get_sigmas(self, sigmas): + if len(sigmas) == 0: + return (sigmas,) + sigmas = sigmas.flip(0) if sigmas[0] == 0: sigmas[0] = 0.0001