|
|
@ -617,7 +617,8 @@ def supports_dtype(device, dtype): #TODO |
|
|
|
def device_supports_non_blocking(device): |
|
|
|
def device_supports_non_blocking(device): |
|
|
|
if is_device_mps(device): |
|
|
|
if is_device_mps(device): |
|
|
|
return False #pytorch bug? mps doesn't support non blocking |
|
|
|
return False #pytorch bug? mps doesn't support non blocking |
|
|
|
return True |
|
|
|
return False |
|
|
|
|
|
|
|
# return True #TODO: figure out why this causes issues |
|
|
|
|
|
|
|
|
|
|
|
def cast_to_device(tensor, device, dtype, copy=False): |
|
|
|
def cast_to_device(tensor, device, dtype, copy=False): |
|
|
|
device_supports_cast = False |
|
|
|
device_supports_cast = False |
|
|
|