diff --git a/comfy/model_management.py b/comfy/model_management.py index 94d59696..3588d350 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -550,12 +550,12 @@ def cast_to_device(tensor, device, dtype, copy=False): if device_supports_cast: if copy: if tensor.device == device: - return tensor.to(dtype, copy=copy) - return tensor.to(device, copy=copy).to(dtype) + return tensor.to(dtype, copy=copy, non_blocking=True) + return tensor.to(device, copy=copy, non_blocking=True).to(dtype, non_blocking=True) else: - return tensor.to(device).to(dtype) + return tensor.to(device, non_blocking=True).to(dtype, non_blocking=True) else: - return tensor.to(dtype).to(device, copy=copy) + return tensor.to(device, dtype, copy=copy, non_blocking=True) def xformers_enabled(): global directml_enabled