
Fix some performance issues with weight loading and unloading.

Lower peak memory usage when changing models.

Fix a case where model weights would be needlessly unloaded and reloaded.
pull/3168/head
comfyanonymous 8 months ago
commit 5d8898c056
Changed files:
  1. comfy/model_management.py (16 changed lines)
  2. execution.py (1 changed line)
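
The reload case in the commit message comes down to a new keyword argument on cleanup_models(), shown in the model_management.py diff below: a model whose weights are still shared by a loaded clone can now survive cleanup instead of being unloaded and loaded again on the next prompt. A minimal usage sketch; the comments are interpretation, and only cleanup_models() and its keep_clone_weights_loaded keyword come from this commit:

    import comfy.model_management

    # Default behaviour: drop every LoadedModel whose only remaining
    # reference is the current_loaded_models list itself.
    comfy.model_management.cleanup_models()

    # New in this commit: keep entries whose real model is still referenced
    # elsewhere (for example through a weight-sharing clone), so the next
    # prompt does not have to reload the weights.
    comfy.model_management.cleanup_models(keep_clone_weights_loaded=True)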

comfy/model_management.py (16 changed lines)

@@ -274,6 +274,7 @@ class LoadedModel:
         self.model = model
         self.device = model.load_device
         self.weights_loaded = False
+        self.real_model = None

     def model_memory(self):
         return self.model.model_size()
@@ -312,6 +313,7 @@ class LoadedModel:
         self.model.unpatch_model(self.model.offload_device, unpatch_weights=unpatch_weights)
         self.model.model_patches_to(self.model.offload_device)
         self.weights_loaded = self.weights_loaded and not unpatch_weights
+        self.real_model = None

     def __eq__(self, other):
         return self.model is other.model
@@ -326,7 +328,7 @@ def unload_model_clones(model, unload_weights_only=True, force_unload=True):
             to_unload = [i] + to_unload

     if len(to_unload) == 0:
-        return None
+        return True

     same_weights = 0
     for i in to_unload:
@@ -408,8 +410,8 @@ def load_models_gpu(models, memory_required=0):
     total_memory_required = {}
     for loaded_model in models_to_load:
-        unload_model_clones(loaded_model.model, unload_weights_only=True, force_unload=False) #unload clones where the weights are different
-        total_memory_required[loaded_model.device] = total_memory_required.get(loaded_model.device, 0) + loaded_model.model_memory_required(loaded_model.device)
+        if unload_model_clones(loaded_model.model, unload_weights_only=True, force_unload=False) == True:#unload clones where the weights are different
+            total_memory_required[loaded_model.device] = total_memory_required.get(loaded_model.device, 0) + loaded_model.model_memory_required(loaded_model.device)

     for device in total_memory_required:
         if device != torch.device("cpu"):
@@ -448,11 +450,15 @@ def load_models_gpu(models, memory_required=0):
 def load_model_gpu(model):
     return load_models_gpu([model])

-def cleanup_models():
+def cleanup_models(keep_clone_weights_loaded=False):
     to_delete = []
     for i in range(len(current_loaded_models)):
         if sys.getrefcount(current_loaded_models[i].model) <= 2:
-            to_delete = [i] + to_delete
+            if not keep_clone_weights_loaded:
+                to_delete = [i] + to_delete
+            #TODO: find a less fragile way to do this.
+            elif sys.getrefcount(current_loaded_models[i].real_model) <= 3: #references from .real_model + the .model
+                to_delete = [i] + to_delete

     for i in to_delete:
         x = current_loaded_models.pop(i)
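
The new elif leans on a CPython detail, which is what the added TODO calls fragile: sys.getrefcount() counts its own argument as one extra reference, so an object whose only other holder is the current_loaded_models slot reports 2, and the real model is allowed one additional expected holder (the inline comment attributes it to .real_model plus the .model). A standalone illustration with plain objects, no ComfyUI code involved:

    import sys

    class FakeWeights:
        pass

    loaded = []              # stand-in for current_loaded_models
    w = FakeWeights()
    loaded.append(w)

    # Three references: the name `w`, the list slot, and the temporary
    # reference getrefcount holds on its own argument.
    print(sys.getrefcount(loaded[0]))   # 3

    del w                    # drop the outside reference
    # Two references left (list slot plus getrefcount's argument), meaning
    # nothing outside the list still uses the object: the `<= 2` case above.
    print(sys.getrefcount(loaded[0]))   # 2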

execution.py (1 changed line)

@@ -368,6 +368,7 @@ class PromptExecutor:
                     d = self.outputs_ui.pop(x)
                     del d

+            comfy.model_management.cleanup_models(keep_clone_weights_loaded=True)
             self.add_message("execution_cached",
                               { "nodes": list(current_outputs) , "prompt_id": prompt_id},
                               broadcast=False)
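
On the executor side the new keyword is passed as True while stale cached outputs are pruned, so re-running a prompt keeps clone-shared weights loaded rather than unloading them only to reload them moments later. A rough sketch of the resulting flow; run_prompt_demo and models_for_prompt are made-up names, and only load_models_gpu() and cleanup_models() come from the code above:

    import comfy.model_management as mm

    def run_prompt_demo(models_for_prompt):
        # Load whatever this prompt needs (load_models_gpu is the real
        # entry point touched by this commit; the argument is a stand-in).
        mm.load_models_gpu(models_for_prompt)

        # ... sampling would happen here ...

        # Between prompts, only drop models that nothing still references,
        # not even a weight-sharing clone.
        mm.cleanup_models(keep_clone_weights_loaded=True)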
