From 0a03009808a5ad13fa3a44edbabcae68576c3982 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Sat, 6 Apr 2024 18:38:39 -0400 Subject: [PATCH] Fix issue with controlnet models getting loaded multiple times. --- comfy/controlnet.py | 10 +++++++--- comfy/model_management.py | 2 ++ 2 files changed, 9 insertions(+), 3 deletions(-) diff --git a/comfy/controlnet.py b/comfy/controlnet.py index b6941d8c..8cf4a61a 100644 --- a/comfy/controlnet.py +++ b/comfy/controlnet.py @@ -138,11 +138,13 @@ class ControlBase: return out class ControlNet(ControlBase): - def __init__(self, control_model, global_average_pooling=False, device=None, load_device=None, manual_cast_dtype=None): + def __init__(self, control_model=None, global_average_pooling=False, device=None, load_device=None, manual_cast_dtype=None): super().__init__(device) self.control_model = control_model self.load_device = load_device - self.control_model_wrapped = comfy.model_patcher.ModelPatcher(self.control_model, load_device=load_device, offload_device=comfy.model_management.unet_offload_device()) + if control_model is not None: + self.control_model_wrapped = comfy.model_patcher.ModelPatcher(self.control_model, load_device=load_device, offload_device=comfy.model_management.unet_offload_device()) + self.global_average_pooling = global_average_pooling self.model_sampling_current = None self.manual_cast_dtype = manual_cast_dtype @@ -183,7 +185,9 @@ class ControlNet(ControlBase): return self.control_merge(None, control, control_prev, output_dtype) def copy(self): - c = ControlNet(self.control_model, global_average_pooling=self.global_average_pooling, load_device=self.load_device, manual_cast_dtype=self.manual_cast_dtype) + c = ControlNet(None, global_average_pooling=self.global_average_pooling, load_device=self.load_device, manual_cast_dtype=self.manual_cast_dtype) + c.control_model = self.control_model + c.control_model_wrapped = self.control_model_wrapped self.copy_to(c) return c diff --git a/comfy/model_management.py b/comfy/model_management.py index 26216432..310ec253 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -385,6 +385,8 @@ def load_models_gpu(models, memory_required=0): inference_memory = minimum_inference_memory() extra_mem = max(inference_memory, memory_required) + models = set(models) + models_to_load = [] models_already_loaded = [] for x in models: