From 5272fd4b0389c6e702493d193a1f824f9fa4c7b8 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Thu, 4 Apr 2024 14:57:44 -0400 Subject: [PATCH] Add DualCFGGuider used in IP2P models for example. --- comfy_extras/nodes_custom_sampler.py | 36 ++++++++++++++++++++++++++++ 1 file changed, 36 insertions(+) diff --git a/comfy_extras/nodes_custom_sampler.py b/comfy_extras/nodes_custom_sampler.py index 5da437a6..fbd8cd25 100644 --- a/comfy_extras/nodes_custom_sampler.py +++ b/comfy_extras/nodes_custom_sampler.py @@ -428,6 +428,41 @@ class CFGGuider: guider.set_cfg(cfg) return (guider,) +class Guider_DualCFG(comfy.samplers.CFGGuider): + def set_cfg(self, cfg1, cfg2): + self.cfg1 = cfg1 + self.cfg2 = cfg2 + + def set_conds(self, positive, middle, negative): + self.inner_set_conds({"positive": positive, "middle": middle, "negative": negative}) + + def predict_noise(self, x, timestep, model_options={}, seed=None): + out = comfy.samplers.calc_cond_batch(self.inner_model, [self.conds.get("negative", None), self.conds.get("middle", None), self.conds.get("positive", None)], x, timestep, model_options) + return comfy.samplers.cfg_function(self.inner_model, out[1], out[0], self.cfg2, x, timestep, model_options=model_options) + (out[2] - out[1]) * self.cfg1 + +class DualCFGGuider: + @classmethod + def INPUT_TYPES(s): + return {"required": + {"model": ("MODEL",), + "cond1": ("CONDITIONING", ), + "cond2": ("CONDITIONING", ), + "negative": ("CONDITIONING", ), + "cfg_conds": ("FLOAT", {"default": 8.0, "min": 0.0, "max": 100.0, "step":0.1, "round": 0.01}), + "cfg_cond2_negative": ("FLOAT", {"default": 8.0, "min": 0.0, "max": 100.0, "step":0.1, "round": 0.01}), + } + } + + RETURN_TYPES = ("GUIDER",) + + FUNCTION = "get_guider" + CATEGORY = "sampling/custom_sampling/guiders" + + def get_guider(self, model, cond1, cond2, negative, cfg_conds, cfg_cond2_negative): + guider = Guider_DualCFG(model) + guider.set_conds(cond1, cond2, negative) + guider.set_cfg(cfg_conds, cfg_cond2_negative) + return (guider,) class DisableNoise: @classmethod @@ -518,6 +553,7 @@ NODE_CLASS_MAPPINGS = { "FlipSigmas": FlipSigmas, "CFGGuider": CFGGuider, + "DualCFGGuider": DualCFGGuider, "BasicGuider": BasicGuider, "RandomNoise": RandomNoise, "DisableNoise": DisableNoise,