From 6135a21ee813bd7bcb11fdfd9f363b469d5dabe1 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Thu, 16 Feb 2023 18:08:01 -0500 Subject: [PATCH] Add a way to control controlnet strength. --- comfy/sd.py | 7 ++++++- nodes.py | 10 +++++++--- 2 files changed, 13 insertions(+), 4 deletions(-) diff --git a/comfy/sd.py b/comfy/sd.py index d37e5316..61a01dea 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -331,6 +331,7 @@ class ControlNet: self.control_model = control_model self.cond_hint_original = None self.cond_hint = None + self.strength = 1.0 def get_control(self, x_noisy, t, cond_txt): if self.cond_hint is None or x_noisy.shape[2] * 8 != self.cond_hint.shape[2] or x_noisy.shape[3] * 8 != self.cond_hint.shape[3]: @@ -340,10 +341,13 @@ class ControlNet: self.cond_hint = utils.common_upscale(self.cond_hint_original, x_noisy.shape[3] * 8, x_noisy.shape[2] * 8, 'nearest-exact', "center").to(x_noisy.device) print("set cond_hint", self.cond_hint.shape) control = self.control_model(x=x_noisy, hint=self.cond_hint, timesteps=t, context=cond_txt) + for x in control: + x *= self.strength return control - def set_cond_hint(self, cond_hint): + def set_cond_hint(self, cond_hint, strength=1.0): self.cond_hint_original = cond_hint + self.strength = strength return self def cleanup(self): @@ -354,6 +358,7 @@ class ControlNet: def copy(self): c = ControlNet(self.control_model) c.cond_hint_original = self.cond_hint_original + c.strength = self.strength return c def load_controlnet(ckpt_path): diff --git a/nodes.py b/nodes.py index 9aec9235..be3952ac 100644 --- a/nodes.py +++ b/nodes.py @@ -234,19 +234,23 @@ class ControlNetLoader: class ControlNetApply: @classmethod def INPUT_TYPES(s): - return {"required": {"conditioning": ("CONDITIONING", ), "control_net": ("CONTROL_NET", ), "image": ("IMAGE", )}} + return {"required": {"conditioning": ("CONDITIONING", ), + "control_net": ("CONTROL_NET", ), + "image": ("IMAGE", ), + "strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}) + }} RETURN_TYPES = ("CONDITIONING",) FUNCTION = "apply_controlnet" CATEGORY = "conditioning" - def apply_controlnet(self, conditioning, control_net, image): + def apply_controlnet(self, conditioning, control_net, image, strength): c = [] control_hint = image.movedim(-1,1) print(control_hint.shape) for t in conditioning: n = [t[0], t[1].copy()] - n[1]['control'] = control_net.copy().set_cond_hint(control_hint) + n[1]['control'] = control_net.copy().set_cond_hint(control_hint, strength) c.append(n) return (c, )