@ -546,10 +546,8 @@ def text_encoder_dtype(device=None):
if is_device_cpu(device):
return torch.float16
if should_use_fp16(device, prioritize_performance=False):
else:
return torch.float32
def intermediate_device():
if args.gpu_only: