|
|
|
@ -274,6 +274,7 @@ class LoadedModel:
|
|
|
|
|
self.model = model |
|
|
|
|
self.device = model.load_device |
|
|
|
|
self.weights_loaded = False |
|
|
|
|
self.real_model = None |
|
|
|
|
|
|
|
|
|
def model_memory(self): |
|
|
|
|
return self.model.model_size() |
|
|
|
@ -312,6 +313,7 @@ class LoadedModel:
|
|
|
|
|
self.model.unpatch_model(self.model.offload_device, unpatch_weights=unpatch_weights) |
|
|
|
|
self.model.model_patches_to(self.model.offload_device) |
|
|
|
|
self.weights_loaded = self.weights_loaded and not unpatch_weights |
|
|
|
|
self.real_model = None |
|
|
|
|
|
|
|
|
|
def __eq__(self, other): |
|
|
|
|
return self.model is other.model |
|
|
|
@ -326,7 +328,7 @@ def unload_model_clones(model, unload_weights_only=True, force_unload=True):
|
|
|
|
|
to_unload = [i] + to_unload |
|
|
|
|
|
|
|
|
|
if len(to_unload) == 0: |
|
|
|
|
return None |
|
|
|
|
return True |
|
|
|
|
|
|
|
|
|
same_weights = 0 |
|
|
|
|
for i in to_unload: |
|
|
|
@ -408,8 +410,8 @@ def load_models_gpu(models, memory_required=0):
|
|
|
|
|
|
|
|
|
|
total_memory_required = {} |
|
|
|
|
for loaded_model in models_to_load: |
|
|
|
|
unload_model_clones(loaded_model.model, unload_weights_only=True, force_unload=False) #unload clones where the weights are different |
|
|
|
|
total_memory_required[loaded_model.device] = total_memory_required.get(loaded_model.device, 0) + loaded_model.model_memory_required(loaded_model.device) |
|
|
|
|
if unload_model_clones(loaded_model.model, unload_weights_only=True, force_unload=False) == True:#unload clones where the weights are different |
|
|
|
|
total_memory_required[loaded_model.device] = total_memory_required.get(loaded_model.device, 0) + loaded_model.model_memory_required(loaded_model.device) |
|
|
|
|
|
|
|
|
|
for device in total_memory_required: |
|
|
|
|
if device != torch.device("cpu"): |
|
|
|
@ -448,11 +450,15 @@ def load_models_gpu(models, memory_required=0):
|
|
|
|
|
def load_model_gpu(model): |
|
|
|
|
return load_models_gpu([model]) |
|
|
|
|
|
|
|
|
|
def cleanup_models(): |
|
|
|
|
def cleanup_models(keep_clone_weights_loaded=False): |
|
|
|
|
to_delete = [] |
|
|
|
|
for i in range(len(current_loaded_models)): |
|
|
|
|
if sys.getrefcount(current_loaded_models[i].model) <= 2: |
|
|
|
|
to_delete = [i] + to_delete |
|
|
|
|
if not keep_clone_weights_loaded: |
|
|
|
|
to_delete = [i] + to_delete |
|
|
|
|
#TODO: find a less fragile way to do this. |
|
|
|
|
elif sys.getrefcount(current_loaded_models[i].real_model) <= 3: #references from .real_model + the .model |
|
|
|
|
to_delete = [i] + to_delete |
|
|
|
|
|
|
|
|
|
for i in to_delete: |
|
|
|
|
x = current_loaded_models.pop(i) |
|
|
|
|