@@ -273,6 +273,7 @@ class LoadedModel:
     def __init__(self, model):
         self.model = model
         self.device = model.load_device
+        self.weights_loaded = False

     def model_memory(self):
         return self.model.model_size()
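This first hunk adds a single piece of state: `weights_loaded` records whether the wrapped model's weights are currently patched in. The hunks below set it on the load path (the method named `model_load` in ComfyUI) and clear it in `model_unload`. A minimal, self-contained sketch of the intended lifecycle, using a toy class rather than the real `LoadedModel`/`ModelPatcher` pair:

    class TinyLoadedModel:
        # Toy stand-in, not ComfyUI code; only the flag logic is real.
        def __init__(self):
            self.weights_loaded = False             # fresh wrapper: nothing patched yet

        def model_load(self):
            load_weights = not self.weights_loaded  # skip re-patching if still resident
            self.weights_loaded = True
            return load_weights

        def model_unload(self, unpatch_weights=True):
            self.weights_loaded = self.weights_loaded and not unpatch_weights

    m = TinyLoadedModel()
    assert m.model_load() is True          # first load patches the weights
    m.model_unload(unpatch_weights=False)  # unload the entry, keep weights in place
    assert m.model_load() is False         # second load can skip weight patching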
@@ -289,11 +290,13 @@ class LoadedModel:
         self.model.model_patches_to(self.device)
         self.model.model_patches_to(self.model.model_dtype())

+        load_weights = not self.weights_loaded
+
         try:
-            if lowvram_model_memory > 0:
+            if lowvram_model_memory > 0 and load_weights:
                 self.real_model = self.model.patch_model_lowvram(device_to=patch_model_to, lowvram_model_memory=lowvram_model_memory)
             else:
-                self.real_model = self.model.patch_model(device_to=patch_model_to)
+                self.real_model = self.model.patch_model(device_to=patch_model_to, patch_weights=load_weights)
         except Exception as e:
-            self.model.unpatch_model(self.model.offload_device)
+            self.model_unload()
             raise e
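`load_weights`, derived from the new flag, now gates both branches: the lowvram patching path is only taken when weights genuinely need (re)loading, and the plain `patch_model` call passes `patch_weights=load_weights` so a still-resident model only gets its object patches re-applied. The failure path routes through `model_unload()` so the flag is updated consistently. A stand-alone sketch of how the branch resolves (stub function and return strings are illustrative only):

    def pick_patch_path(lowvram_model_memory, weights_loaded):
        # Mirrors the branch in the hunk above.
        load_weights = not weights_loaded
        if lowvram_model_memory > 0 and load_weights:
            return "patch_model_lowvram"
        return "patch_model(patch_weights=%s)" % load_weights

    assert pick_patch_path(512, weights_loaded=False) == "patch_model_lowvram"
    assert pick_patch_path(512, weights_loaded=True) == "patch_model(patch_weights=False)"
    assert pick_patch_path(0, weights_loaded=False) == "patch_model(patch_weights=True)"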
@@ -302,11 +305,13 @@ class LoadedModel:
         if is_intel_xpu() and not args.disable_ipex_optimize:
             self.real_model = torch.xpu.optimize(self.real_model.eval(), inplace=True, auto_kernel_selection=True, graph_mode=True)

+        self.weights_loaded = True
         return self.real_model

-    def model_unload(self):
-        self.model.unpatch_model(self.model.offload_device)
+    def model_unload(self, unpatch_weights=True):
+        self.model.unpatch_model(self.model.offload_device, unpatch_weights=unpatch_weights)
         self.model.model_patches_to(self.model.offload_device)
+        self.weights_loaded = self.weights_loaded and not unpatch_weights

     def __eq__(self, other):
         return self.model is other.model
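`model_unload` now takes `unpatch_weights` so callers can tear down a `LoadedModel` entry without evicting its weights. The flag update on the last added line is the key invariant: `weights_loaded` survives an unload only when the weights were deliberately left in place. Checked exhaustively in plain Python (toy helper, not the real method):

    def next_flag(weights_loaded, unpatch_weights):
        # The expression from the hunk above, isolated for inspection.
        return weights_loaded and not unpatch_weights

    assert next_flag(True, unpatch_weights=True) is False    # weights really unloaded
    assert next_flag(True, unpatch_weights=False) is True    # weights kept resident
    assert next_flag(False, unpatch_weights=True) is False   # nothing was loaded anyway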
@@ -314,15 +319,35 @@ class LoadedModel:
 def minimum_inference_memory():
     return (1024 * 1024 * 1024)

-def unload_model_clones(model):
+def unload_model_clones(loaded_model, unload_weights_only=True):
+    model = loaded_model.model
+
     to_unload = []
     for i in range(len(current_loaded_models)):
         if model.is_clone(current_loaded_models[i].model):
             to_unload = [i] + to_unload

+    if len(to_unload) == 0:
+        return
+
+    same_weights = 0
     for i in to_unload:
-        logging.debug("unload clone {}".format(i))
-        current_loaded_models.pop(i).model_unload()
+        if model.clone_has_same_weights(current_loaded_models[i].model):
+            same_weights += 1
+
+    if same_weights == len(to_unload):
+        unload_weight = False
+    else:
+        unload_weight = True
+
+    if unload_weights_only and unload_weight == False:
+        return
+
+    for i in to_unload:
+        logging.debug("unload clone {} {}".format(i, unload_weight))
+        current_loaded_models.pop(i).model_unload(unpatch_weights=unload_weight)
+
+    loaded_model.weights_loaded = not unload_weight

 def free_memory(memory_required, device, keep_loaded=[]):
     unloaded_model = False
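`unload_model_clones` now distinguishes two cases. If every clone reports `clone_has_same_weights`, the clones' weights are interchangeable with the incoming model's, so their entries can be popped while the weights stay resident (`unpatch_weights=False`); any mismatch forces a real unload. The `unload_weights_only=True` pass returns early in the keep-weights case, deferring it to the second pass in `load_models_gpu` below. The decision logic, re-enacted with a stand-alone stub (illustrative; the real `is_clone`/`clone_has_same_weights` checks live on `ModelPatcher`):

    def decide(same_weight_flags, unload_weights_only):
        # One bool per detected clone: did it report clone_has_same_weights?
        if not same_weight_flags:                    # no clones found
            return "nothing to do"
        unload_weight = not all(same_weight_flags)   # any mismatch -> really unload
        if unload_weights_only and not unload_weight:
            return "deferred to second pass"
        return "model_unload(unpatch_weights=%s)" % unload_weight

    assert decide([True, True], unload_weights_only=True) == "deferred to second pass"
    assert decide([True, False], unload_weights_only=True) == "model_unload(unpatch_weights=True)"
    assert decide([True, True], unload_weights_only=False) == "model_unload(unpatch_weights=False)"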
@@ -377,13 +402,16 @@ def load_models_gpu(models, memory_required=0):

     total_memory_required = {}
     for loaded_model in models_to_load:
-        unload_model_clones(loaded_model.model)
+        unload_model_clones(loaded_model, unload_weights_only=True) #unload clones where the weights are different
         total_memory_required[loaded_model.device] = total_memory_required.get(loaded_model.device, 0) + loaded_model.model_memory_required(loaded_model.device)

     for device in total_memory_required:
         if device != torch.device("cpu"):
             free_memory(total_memory_required[device] * 1.3 + extra_mem, device, models_already_loaded)

+    for loaded_model in models_to_load:
+        unload_model_clones(loaded_model, unload_weights_only=False) #unload the rest of the clones where the weights can stay loaded
+
     for loaded_model in models_to_load:
         model = loaded_model.model
         torch_dev = model.load_device
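Taken together, `load_models_gpu` now runs a two-pass eviction: pass one, before freeing memory, really unloads clones whose weights differ so their VRAM counts as reclaimable; pass two, after `free_memory`, pops the remaining clone entries while leaving their identical weights resident, so the subsequent load sees `weights_loaded == True` and skips re-patching. A stub-level trace of the resulting call order (names are stand-ins for the surrounding locals):

    def two_pass_trace(models_to_load):
        # Records the order of operations only; no real loading happens.
        trace = []
        for lm in models_to_load:
            trace.append("unload_model_clones(%s, unload_weights_only=True)" % lm)
        trace.append("free_memory(...) per non-cpu device")
        for lm in models_to_load:
            trace.append("unload_model_clones(%s, unload_weights_only=False)" % lm)
        for lm in models_to_load:
            trace.append("%s.model_load(...)" % lm)
        return trace

    print("\n".join(two_pass_trace(["unet", "clip"])))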