@@ -127,6 +127,32 @@ if args.cpu:
print(f"Set vram state to: {vram_state.name}") |
|
|
|
|
|
|
|
|
|
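# Select the torch device to run on: DirectML takes priority, then the MPS/CPU VRAM states, then XPU, falling back to CUDA.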
def get_torch_device():
    global xpu_available
    global directml_enabled
    if directml_enabled:
        global directml_device
        return directml_device
    if vram_state == VRAMState.MPS:
        return torch.device("mps")
    if vram_state == VRAMState.CPU:
        return torch.device("cpu")
    else:
        if xpu_available:
            return torch.device("xpu")
        else:
            return torch.cuda.current_device()

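# Produce a readable name for the chosen device; plain device indices are reported as CUDA devices.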
def get_torch_device_name(device):
    if hasattr(device, 'type'):
        return "{}".format(device.type)
    return "CUDA {}: {}".format(device, torch.cuda.get_device_name(device))

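# Report the selected device at import time; if device selection fails, fall back to a warning message.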
try:
    print("Using device:", get_torch_device_name(get_torch_device()))
except:
    print("Could not pick default device.")


current_loaded_model = None
current_gpu_controlnets = []
@@ -233,22 +259,6 @@ def unload_if_low_vram(model):
        return model.cpu()
    return model

def get_torch_device():
    global xpu_available
    global directml_enabled
    if directml_enabled:
        global directml_device
        return directml_device
    if vram_state == VRAMState.MPS:
        return torch.device("mps")
    if vram_state == VRAMState.CPU:
        return torch.device("cpu")
    else:
        if xpu_available:
            return torch.device("xpu")
        else:
            return torch.cuda.current_device()

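# Device type string (e.g. "cuda", "mps") for autocast calls.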
def get_autocast_device(dev):
    if hasattr(dev, 'type'):
        return dev.type