diff --git a/comfy/model_management.py b/comfy/model_management.py
index b8fd8796..1301f746 100644
--- a/comfy/model_management.py
+++ b/comfy/model_management.py
@@ -102,6 +102,12 @@ def load_model_gpu(model):
 
 def load_controlnet_gpu(models):
     global current_gpu_controlnets
+    global vram_state
+
+    if vram_state == LOW_VRAM or vram_state == NO_VRAM:
+        #don't load controlnets like this if low vram because they will be loaded right before running and unloaded right after
+        return
+
     for m in current_gpu_controlnets:
         if m not in models:
             m.cpu()
@@ -111,6 +117,19 @@ def load_controlnet_gpu(models):
         current_gpu_controlnets.append(m.cuda())
 
 
+def load_if_low_vram(model):
+    global vram_state
+    if vram_state == LOW_VRAM or vram_state == NO_VRAM:
+        return model.cuda()
+    return model
+
+def unload_if_low_vram(model):
+    global vram_state
+    if vram_state == LOW_VRAM or vram_state == NO_VRAM:
+        return model.cpu()
+    return model
+
+
 def get_free_memory():
     dev = torch.cuda.current_device()
     stats = torch.cuda.memory_stats(dev)
diff --git a/comfy/sd.py b/comfy/sd.py
index 9f46595e..bf67f128 100644
--- a/comfy/sd.py
+++ b/comfy/sd.py
@@ -349,7 +349,9 @@ class ControlNet:
             precision_scope = contextlib.nullcontext
 
         with precision_scope(self.device):
+            self.control_model = model_management.load_if_low_vram(self.control_model)
             control = self.control_model(x=x_noisy, hint=self.cond_hint, timesteps=t, context=cond_txt)
+            self.control_model = model_management.unload_if_low_vram(self.control_model)
         out = []
         autocast_enabled = torch.is_autocast_enabled()
         for x in control:
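
For reference, the pattern this patch implements is "GPU residency only for the duration of one forward pass": in LOW_VRAM / NO_VRAM modes the ControlNet weights stay on the CPU between sampling steps and are moved to the GPU just around each call, which is why load_controlnet_gpu now skips them entirely. The sketch below is a minimal, self-contained illustration of that pattern, assuming a CUDA device is available; the VRAMState enum and the run_low_vram wrapper are illustrative names invented here, not part of ComfyUI's API (the real code uses module-level constants and a vram_state global in model_management.py).

import enum

import torch.nn as nn

class VRAMState(enum.Enum):
    NORMAL = 0
    LOW_VRAM = 1
    NO_VRAM = 2

# Illustrative stand-in for the vram_state global; in ComfyUI it is set at
# startup from the low-VRAM command-line options.
vram_state = VRAMState.LOW_VRAM

def load_if_low_vram(model: nn.Module) -> nn.Module:
    # In low/no-VRAM modes the model lives on the CPU between calls,
    # so move it to the GPU only when it is about to run.
    if vram_state in (VRAMState.LOW_VRAM, VRAMState.NO_VRAM):
        return model.cuda()
    return model

def unload_if_low_vram(model: nn.Module) -> nn.Module:
    # ...and move it back to the CPU right afterwards to free the VRAM.
    if vram_state in (VRAMState.LOW_VRAM, VRAMState.NO_VRAM):
        return model.cpu()
    return model

def run_low_vram(model: nn.Module, *args, **kwargs):
    # Hypothetical wrapper around a single forward pass, mirroring what the
    # patch does around self.control_model(...) in sd.py: the model is
    # GPU-resident only while it runs.
    model = load_if_low_vram(model)
    try:
        return model(*args, **kwargs)
    finally:
        unload_if_low_vram(model)

Note that nn.Module.cuda() / .cpu() move the parameters in place and return the module itself, which is why the patch can simply reassign self.control_model from the helpers' return value. The trade-off is a host-to-device copy of the ControlNet weights on every sampling step in exchange for a much smaller steady-state VRAM footprint.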