From e1489ad2576651a1384e9b82fd1991ac2f8764e0 Mon Sep 17 00:00:00 2001
From: comfyanonymous
Date: Sat, 11 May 2024 21:46:05 -0400
Subject: [PATCH] Fix issue with lowvram mode breaking model saving.

---
 comfy/model_management.py           |  8 ++++----
 comfy/model_patcher.py              | 12 +++++++++---
 comfy/sd.py                         |  2 +-
 comfy_extras/nodes_model_merging.py |  2 +-
 4 files changed, 15 insertions(+), 9 deletions(-)

diff --git a/comfy/model_management.py b/comfy/model_management.py
index 913b6844..15dd73a6 100644
--- a/comfy/model_management.py
+++ b/comfy/model_management.py
@@ -285,7 +285,7 @@ class LoadedModel:
         else:
             return self.model_memory()
 
-    def model_load(self, lowvram_model_memory=0):
+    def model_load(self, lowvram_model_memory=0, force_patch_weights=False):
         patch_model_to = self.device
 
         self.model.model_patches_to(self.device)
@@ -295,7 +295,7 @@ class LoadedModel:
 
         try:
             if lowvram_model_memory > 0 and load_weights:
-                self.real_model = self.model.patch_model_lowvram(device_to=patch_model_to, lowvram_model_memory=lowvram_model_memory)
+                self.real_model = self.model.patch_model_lowvram(device_to=patch_model_to, lowvram_model_memory=lowvram_model_memory, force_patch_weights=force_patch_weights)
             else:
                 self.real_model = self.model.patch_model(device_to=patch_model_to, patch_weights=load_weights)
         except Exception as e:
@@ -379,7 +379,7 @@ def free_memory(memory_required, device, keep_loaded=[]):
         if mem_free_torch > mem_free_total * 0.25:
             soft_empty_cache()
 
-def load_models_gpu(models, memory_required=0):
+def load_models_gpu(models, memory_required=0, force_patch_weights=False):
     global vram_state
 
     inference_memory = minimum_inference_memory()
@@ -444,7 +444,7 @@ def load_models_gpu(models, memory_required=0):
             if vram_set_state == VRAMState.NO_VRAM:
                 lowvram_model_memory = 64 * 1024 * 1024
 
-            cur_loaded_model = loaded_model.model_load(lowvram_model_memory)
+            cur_loaded_model = loaded_model.model_load(lowvram_model_memory, force_patch_weights=force_patch_weights)
             current_loaded_models.insert(0, loaded_model)
 
     return
diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py
index cf51c4ad..48e5be31 100644
--- a/comfy/model_patcher.py
+++ b/comfy/model_patcher.py
@@ -272,7 +272,7 @@ class ModelPatcher:
 
         return self.model
 
-    def patch_model_lowvram(self, device_to=None, lowvram_model_memory=0):
+    def patch_model_lowvram(self, device_to=None, lowvram_model_memory=0, force_patch_weights=False):
         self.patch_model(device_to, patch_weights=False)
 
         logging.info("loading in lowvram mode {}".format(lowvram_model_memory/(1024 * 1024)))
@@ -296,9 +296,15 @@ class ModelPatcher:
 
             if lowvram_weight:
                 if weight_key in self.patches:
-                    m.weight_function = LowVramPatch(weight_key, self)
+                    if force_patch_weights:
+                        self.patch_weight_to_device(weight_key)
+                    else:
+                        m.weight_function = LowVramPatch(weight_key, self)
                 if bias_key in self.patches:
-                    m.bias_function = LowVramPatch(bias_key, self)
+                    if force_patch_weights:
+                        self.patch_weight_to_device(bias_key)
+                    else:
+                        m.bias_function = LowVramPatch(bias_key, self)
 
                 m.prev_comfy_cast_weights = m.comfy_cast_weights
                 m.comfy_cast_weights = True
diff --git a/comfy/sd.py b/comfy/sd.py
index 9671e4ae..8044c184 100644
--- a/comfy/sd.py
+++ b/comfy/sd.py
@@ -562,7 +562,7 @@ def save_checkpoint(output_path, model, clip=None, vae=None, clip_vision=None, m
         load_models.append(clip.load_model())
         clip_sd = clip.get_sd()
 
-    model_management.load_models_gpu(load_models)
+    model_management.load_models_gpu(load_models, force_patch_weights=True)
     clip_vision_sd = clip_vision.get_sd() if clip_vision is not None else None
     sd = model.model.state_dict_for_saving(clip_sd, vae.get_sd(), clip_vision_sd)
     for k in extra_keys:
diff --git a/comfy_extras/nodes_model_merging.py b/comfy_extras/nodes_model_merging.py
index 2a431f65..8c5dc985 100644
--- a/comfy_extras/nodes_model_merging.py
+++ b/comfy_extras/nodes_model_merging.py
@@ -262,7 +262,7 @@ class CLIPSave:
             for x in extra_pnginfo:
                 metadata[x] = json.dumps(extra_pnginfo[x])
 
-        comfy.model_management.load_models_gpu([clip.load_model()])
+        comfy.model_management.load_models_gpu([clip.load_model()], force_patch_weights=True)
         clip_sd = clip.get_sd()
 
         for prefix in ["clip_l.", "clip_g.", ""]:
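
Why this fixes saving: in lowvram mode, patch_model_lowvram() defers weight
patches by installing LowVramPatch objects as m.weight_function and
m.bias_function, so patches are applied on the fly only when weights are cast
for inference. The tensors actually stored on the modules stay unpatched, so a
state dict taken for saving silently drops any merged or LoRA-patched weights.
With force_patch_weights=True, the save paths call patch_weight_to_device()
instead, baking the patches into the stored weights before the state dict is
read.

A minimal sketch of the intended call pattern; save_patched_state_dict() is a
hypothetical helper for illustration, and only load_models_gpu() and its
force_patch_weights parameter come from this patch:

    import comfy.model_management as model_management

    def save_patched_state_dict(model_patcher):
        # Hypothetical helper: load the model with patches forced into the
        # module weights, even when the lowvram path is taken, so the
        # returned tensors include any merges/LoRAs.
        model_management.load_models_gpu([model_patcher], force_patch_weights=True)
        return model_patcher.model.state_dict()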