Browse Source

Fix lowvram issue with saving checkpoints.

The previous fix didn't cover the case where the model was loaded in
lowvram mode right before.
pull/3449/merge
comfyanonymous 6 months ago
parent
commit
fa6dd7e5bb
  1. 21
      comfy/model_management.py
  2. 6
      comfy/model_patcher.py

21
comfy/model_management.py

@ -309,6 +309,11 @@ class LoadedModel:
self.weights_loaded = True self.weights_loaded = True
return self.real_model return self.real_model
def should_reload_model(self, force_patch_weights=False):
if force_patch_weights and self.model.lowvram_patch_counter > 0:
return True
return False
def model_unload(self, unpatch_weights=True): def model_unload(self, unpatch_weights=True):
self.model.unpatch_model(self.model.offload_device, unpatch_weights=unpatch_weights) self.model.unpatch_model(self.model.offload_device, unpatch_weights=unpatch_weights)
self.model.model_patches_to(self.model.offload_device) self.model.model_patches_to(self.model.offload_device)
@ -391,10 +396,22 @@ def load_models_gpu(models, memory_required=0, force_patch_weights=False):
models_already_loaded = [] models_already_loaded = []
for x in models: for x in models:
loaded_model = LoadedModel(x) loaded_model = LoadedModel(x)
loaded = None
if loaded_model in current_loaded_models: try:
models_already_loaded.append(loaded_model) loaded_model_index = current_loaded_models.index(loaded_model)
except:
loaded_model_index = None
if loaded_model_index is not None:
loaded = current_loaded_models[loaded_model_index]
if loaded.should_reload_model(force_patch_weights=force_patch_weights): #TODO: cleanup this model reload logic
current_loaded_models.pop(loaded_model_index).model_unload(unpatch_weights=True)
loaded = None
else: else:
models_already_loaded.append(loaded)
if loaded is None:
if hasattr(x, "model"): if hasattr(x, "model"):
logging.info(f"Requested to load {x.model.__class__.__name__}") logging.info(f"Requested to load {x.model.__class__.__name__}")
models_to_load.append(loaded_model) models_to_load.append(loaded_model)

6
comfy/model_patcher.py

@ -58,6 +58,7 @@ class ModelPatcher:
self.weight_inplace_update = weight_inplace_update self.weight_inplace_update = weight_inplace_update
self.model_lowvram = False self.model_lowvram = False
self.lowvram_patch_counter = 0
self.patches_uuid = uuid.uuid4() self.patches_uuid = uuid.uuid4()
def model_size(self): def model_size(self):
@ -284,6 +285,7 @@ class ModelPatcher:
return self.model_patcher.calculate_weight(self.model_patcher.patches[self.key], weight, self.key) return self.model_patcher.calculate_weight(self.model_patcher.patches[self.key], weight, self.key)
mem_counter = 0 mem_counter = 0
patch_counter = 0
for n, m in self.model.named_modules(): for n, m in self.model.named_modules():
lowvram_weight = False lowvram_weight = False
if hasattr(m, "comfy_cast_weights"): if hasattr(m, "comfy_cast_weights"):
@ -300,11 +302,13 @@ class ModelPatcher:
self.patch_weight_to_device(weight_key) self.patch_weight_to_device(weight_key)
else: else:
m.weight_function = LowVramPatch(weight_key, self) m.weight_function = LowVramPatch(weight_key, self)
patch_counter += 1
if bias_key in self.patches: if bias_key in self.patches:
if force_patch_weights: if force_patch_weights:
self.patch_weight_to_device(bias_key) self.patch_weight_to_device(bias_key)
else: else:
m.bias_function = LowVramPatch(bias_key, self) m.bias_function = LowVramPatch(bias_key, self)
patch_counter += 1
m.prev_comfy_cast_weights = m.comfy_cast_weights m.prev_comfy_cast_weights = m.comfy_cast_weights
m.comfy_cast_weights = True m.comfy_cast_weights = True
@ -317,6 +321,7 @@ class ModelPatcher:
logging.debug("lowvram: loaded module regularly {}".format(m)) logging.debug("lowvram: loaded module regularly {}".format(m))
self.model_lowvram = True self.model_lowvram = True
self.lowvram_patch_counter = patch_counter
return self.model return self.model
def calculate_weight(self, patches, weight, key): def calculate_weight(self, patches, weight, key):
@ -468,6 +473,7 @@ class ModelPatcher:
m.bias_function = None m.bias_function = None
self.model_lowvram = False self.model_lowvram = False
self.lowvram_patch_counter = 0
keys = list(self.backup.keys()) keys = list(self.backup.keys())

Loading…
Cancel
Save