|
|
|
@ -684,17 +684,20 @@ def mps_mode():
|
|
|
|
|
global cpu_state |
|
|
|
|
return cpu_state == CPUState.MPS |
|
|
|
|
|
|
|
|
|
def is_device_cpu(device): |
|
|
|
|
def is_device_type(device, type): |
|
|
|
|
if hasattr(device, 'type'): |
|
|
|
|
if (device.type == 'cpu'): |
|
|
|
|
if (device.type == type): |
|
|
|
|
return True |
|
|
|
|
return False |
|
|
|
|
|
|
|
|
|
def is_device_cpu(device): |
|
|
|
|
return is_device_type(device, 'cpu') |
|
|
|
|
|
|
|
|
|
def is_device_mps(device): |
|
|
|
|
if hasattr(device, 'type'): |
|
|
|
|
if (device.type == 'mps'): |
|
|
|
|
return True |
|
|
|
|
return False |
|
|
|
|
return is_device_type(device, 'mps') |
|
|
|
|
|
|
|
|
|
def is_device_cuda(device): |
|
|
|
|
return is_device_type(device, 'cuda') |
|
|
|
|
|
|
|
|
|
def should_use_fp16(device=None, model_params=0, prioritize_performance=True, manual_cast=False): |
|
|
|
|
global directml_enabled |
|
|
|
|