From 571ea8cdcc2d1bf4fa7f398dad68415dacfff02f Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Mon, 18 Dec 2023 17:03:32 -0500 Subject: [PATCH] Fix SAG not working with cfg 1.0 --- comfy/model_patcher.py | 8 ++++++-- comfy/samplers.py | 2 +- comfy_extras/nodes_sag.py | 2 +- 3 files changed, 8 insertions(+), 4 deletions(-) diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index e0acdc96..6acb2d64 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -55,14 +55,18 @@ class ModelPatcher: def memory_required(self, input_shape): return self.model.memory_required(input_shape=input_shape) - def set_model_sampler_cfg_function(self, sampler_cfg_function): + def set_model_sampler_cfg_function(self, sampler_cfg_function, disable_cfg1_optimization=False): if len(inspect.signature(sampler_cfg_function).parameters) == 3: self.model_options["sampler_cfg_function"] = lambda args: sampler_cfg_function(args["cond"], args["uncond"], args["cond_scale"]) #Old way else: self.model_options["sampler_cfg_function"] = sampler_cfg_function + if disable_cfg1_optimization: + self.model_options["disable_cfg1_optimization"] = True - def set_model_sampler_post_cfg_function(self, post_cfg_function): + def set_model_sampler_post_cfg_function(self, post_cfg_function, disable_cfg1_optimization=False): self.model_options["sampler_post_cfg_function"] = self.model_options.get("sampler_post_cfg_function", []) + [post_cfg_function] + if disable_cfg1_optimization: + self.model_options["disable_cfg1_optimization"] = True def set_model_unet_function_wrapper(self, unet_wrapper_function): self.model_options["model_function_wrapper"] = unet_wrapper_function diff --git a/comfy/samplers.py b/comfy/samplers.py index 18bd75ef..47f34778 100644 --- a/comfy/samplers.py +++ b/comfy/samplers.py @@ -244,7 +244,7 @@ def calc_cond_uncond_batch(model, cond, uncond, x_in, timestep, model_options): #The main sampling function shared by all the samplers #Returns denoised def sampling_function(model, x, timestep, uncond, cond, cond_scale, model_options={}, seed=None): - if math.isclose(cond_scale, 1.0): + if math.isclose(cond_scale, 1.0) and model_options.get("disable_cfg1_optimization", False) == False: uncond_ = None else: uncond_ = uncond diff --git a/comfy_extras/nodes_sag.py b/comfy_extras/nodes_sag.py index fea673d6..450ac3ee 100644 --- a/comfy_extras/nodes_sag.py +++ b/comfy_extras/nodes_sag.py @@ -151,7 +151,7 @@ class SelfAttentionGuidance: (sag, _) = comfy.samplers.calc_cond_uncond_batch(model, uncond, None, degraded_noised, sigma, model_options) return cfg_result + (degraded - sag) * sag_scale - m.set_model_sampler_post_cfg_function(post_cfg_function) + m.set_model_sampler_post_cfg_function(post_cfg_function, disable_cfg1_optimization=True) # from diffusers: # unet.mid_block.attentions[0].transformer_blocks[0].attn1.patch