|
|
|
@ -129,7 +129,6 @@ def load_model_gpu(model):
|
|
|
|
|
global current_loaded_model |
|
|
|
|
global vram_state |
|
|
|
|
global model_accelerated |
|
|
|
|
global xpu_available |
|
|
|
|
|
|
|
|
|
if model is current_loaded_model: |
|
|
|
|
return |
|
|
|
@ -148,17 +147,14 @@ def load_model_gpu(model):
|
|
|
|
|
pass |
|
|
|
|
elif vram_state == VRAMState.NORMAL_VRAM or vram_state == VRAMState.HIGH_VRAM: |
|
|
|
|
model_accelerated = False |
|
|
|
|
if xpu_available: |
|
|
|
|
real_model.to("xpu") |
|
|
|
|
else: |
|
|
|
|
real_model.cuda() |
|
|
|
|
real_model.to(get_torch_device()) |
|
|
|
|
else: |
|
|
|
|
if vram_state == VRAMState.NO_VRAM: |
|
|
|
|
device_map = accelerate.infer_auto_device_map(real_model, max_memory={0: "256MiB", "cpu": "16GiB"}) |
|
|
|
|
elif vram_state == VRAMState.LOW_VRAM: |
|
|
|
|
device_map = accelerate.infer_auto_device_map(real_model, max_memory={0: "{}MiB".format(total_vram_available_mb), "cpu": "16GiB"}) |
|
|
|
|
|
|
|
|
|
accelerate.dispatch_model(real_model, device_map=device_map, main_device="xpu" if xpu_available else "cuda") |
|
|
|
|
accelerate.dispatch_model(real_model, device_map=device_map, main_device=get_torch_device()) |
|
|
|
|
model_accelerated = True |
|
|
|
|
return current_loaded_model |
|
|
|
|
|
|
|
|
@ -184,12 +180,8 @@ def load_controlnet_gpu(models):
|
|
|
|
|
|
|
|
|
|
def load_if_low_vram(model): |
|
|
|
|
global vram_state |
|
|
|
|
global xpu_available |
|
|
|
|
if vram_state == VRAMState.LOW_VRAM or vram_state == VRAMState.NO_VRAM: |
|
|
|
|
if xpu_available: |
|
|
|
|
return model.to("xpu") |
|
|
|
|
else: |
|
|
|
|
return model.cuda() |
|
|
|
|
return model.to(get_torch_device()) |
|
|
|
|
return model |
|
|
|
|
|
|
|
|
|
def unload_if_low_vram(model): |
|
|
|
|