@@ -83,7 +83,7 @@ def get_torch_device():
         return torch.device("cpu")
     else:
         if is_intel_xpu():
-            return torch.device("xpu")
+            return torch.device("xpu", torch.xpu.current_device())
         else:
             return torch.device(torch.cuda.current_device())

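On the first hunk: torch.device("xpu") carries no device index (its .index is None), while the CUDA branch below already resolves to an explicit index via torch.cuda.current_device(). Passing torch.xpu.current_device() makes the XPU branch behave the same way. A minimal sketch, assuming a PyTorch build where the torch.xpu namespace is available (e.g. with intel_extension_for_pytorch loaded):

    import torch

    # torch.device("xpu") leaves the index unset; the two-argument form
    # pins the device to a concrete card, matching the CUDA code path.
    ambient = torch.device("xpu")                              # index is None
    pinned = torch.device("xpu", torch.xpu.current_device())   # e.g. xpu:0
    assert ambient.index is None and pinned.index is not None
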
@@ -304,7 +304,7 @@ class LoadedModel:
             raise e

         if is_intel_xpu() and not args.disable_ipex_optimize:
-            self.real_model = torch.xpu.optimize(self.real_model.eval(), inplace=True, auto_kernel_selection=True, graph_mode=True)
+            self.real_model = ipex.optimize(self.real_model.eval(), graph_mode=True, concat_linear=True)

         self.weights_loaded = True
         return self.real_model
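On the second hunk: the call moves from torch.xpu.optimize to ipex.optimize, the frontend documented by Intel Extension for PyTorch; this assumes the module imports intel_extension_for_pytorch as ipex when an XPU build is detected. A standalone sketch of the new call under that assumption:

    import torch
    import intel_extension_for_pytorch as ipex  # assumed installed

    model = torch.nn.Sequential(torch.nn.Linear(64, 64), torch.nn.ReLU())
    # graph_mode=True lets IPEX trace the eager model into graphs where it
    # can; concat_linear=True allows folding of eligible Linear layers.
    model = ipex.optimize(model.eval(), graph_mode=True, concat_linear=True)

Note that inplace=True and auto_kernel_selection=True from the old call are not carried over, so the optimized model is returned as a new object with default kernel selection.
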
@@ -552,8 +552,6 @@ def text_encoder_device():
     if args.gpu_only:
         return get_torch_device()
     elif vram_state == VRAMState.HIGH_VRAM or vram_state == VRAMState.NORMAL_VRAM:
-        if is_intel_xpu():
-            return torch.device("cpu")
         if should_use_fp16(prioritize_performance=False):
             return get_torch_device()
         else:
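On the third hunk: removing the is_intel_xpu() early return means the text encoder is no longer pinned to the CPU on XPU systems; under HIGH_VRAM or NORMAL_VRAM the device now falls through to the same should_use_fp16 check used for other GPUs. A control-flow sketch of the post-patch branching, with the globals and helpers replaced by stand-in parameters of mine (the branches past the final else: are outside the hunk and assumed to return the CPU device):

    def text_encoder_device_sketch(gpu_only, high_or_normal_vram, fp16_ok, gpu, cpu):
        # Mirrors the branch order of the patched function.
        if gpu_only:
            return gpu
        elif high_or_normal_vram:
            if fp16_ok:       # XPU now reaches this check instead of
                return gpu    # unconditionally falling back to the CPU
            else:
                return cpu    # assumed: continuation outside the hunk
        return cpu            # assumed: continuation outside the hunk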