From fa6dd7e5bbee031defa640534b0924313757676f Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Sun, 12 May 2024 06:13:45 -0400 Subject: [PATCH] Fix lowvram issue with saving checkpoints. The previous fix didn't cover the case where the model was loaded in lowvram mode right before. --- comfy/model_management.py | 23 ++++++++++++++++++++--- comfy/model_patcher.py | 6 ++++++ 2 files changed, 26 insertions(+), 3 deletions(-) diff --git a/comfy/model_management.py b/comfy/model_management.py index 5a66a383..3d01e8a2 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -309,6 +309,11 @@ class LoadedModel: self.weights_loaded = True return self.real_model + def should_reload_model(self, force_patch_weights=False): + if force_patch_weights and self.model.lowvram_patch_counter > 0: + return True + return False + def model_unload(self, unpatch_weights=True): self.model.unpatch_model(self.model.offload_device, unpatch_weights=unpatch_weights) self.model.model_patches_to(self.model.offload_device) @@ -391,10 +396,22 @@ def load_models_gpu(models, memory_required=0, force_patch_weights=False): models_already_loaded = [] for x in models: loaded_model = LoadedModel(x) + loaded = None - if loaded_model in current_loaded_models: - models_already_loaded.append(loaded_model) - else: + try: + loaded_model_index = current_loaded_models.index(loaded_model) + except: + loaded_model_index = None + + if loaded_model_index is not None: + loaded = current_loaded_models[loaded_model_index] + if loaded.should_reload_model(force_patch_weights=force_patch_weights): #TODO: cleanup this model reload logic + current_loaded_models.pop(loaded_model_index).model_unload(unpatch_weights=True) + loaded = None + else: + models_already_loaded.append(loaded) + + if loaded is None: if hasattr(x, "model"): logging.info(f"Requested to load {x.model.__class__.__name__}") models_to_load.append(loaded_model) diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index 48e5be31..c38b2f79 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -58,6 +58,7 @@ class ModelPatcher: self.weight_inplace_update = weight_inplace_update self.model_lowvram = False + self.lowvram_patch_counter = 0 self.patches_uuid = uuid.uuid4() def model_size(self): @@ -284,6 +285,7 @@ class ModelPatcher: return self.model_patcher.calculate_weight(self.model_patcher.patches[self.key], weight, self.key) mem_counter = 0 + patch_counter = 0 for n, m in self.model.named_modules(): lowvram_weight = False if hasattr(m, "comfy_cast_weights"): @@ -300,11 +302,13 @@ class ModelPatcher: self.patch_weight_to_device(weight_key) else: m.weight_function = LowVramPatch(weight_key, self) + patch_counter += 1 if bias_key in self.patches: if force_patch_weights: self.patch_weight_to_device(bias_key) else: m.bias_function = LowVramPatch(bias_key, self) + patch_counter += 1 m.prev_comfy_cast_weights = m.comfy_cast_weights m.comfy_cast_weights = True @@ -317,6 +321,7 @@ class ModelPatcher: logging.debug("lowvram: loaded module regularly {}".format(m)) self.model_lowvram = True + self.lowvram_patch_counter = patch_counter return self.model def calculate_weight(self, patches, weight, key): @@ -468,6 +473,7 @@ class ModelPatcher: m.bias_function = None self.model_lowvram = False + self.lowvram_patch_counter = 0 keys = list(self.backup.keys())