|
|
|
@ -496,7 +496,7 @@ def unet_dtype(device=None, model_params=0):
|
|
|
|
|
return torch.float8_e4m3fn |
|
|
|
|
if args.fp8_e5m2_unet: |
|
|
|
|
return torch.float8_e5m2 |
|
|
|
|
if should_use_fp16(device=device, model_params=model_params): |
|
|
|
|
if should_use_fp16(device=device, model_params=model_params, manual_cast=True): |
|
|
|
|
return torch.float16 |
|
|
|
|
return torch.float32 |
|
|
|
|
|
|
|
|
@ -696,7 +696,7 @@ def is_device_mps(device):
|
|
|
|
|
return True |
|
|
|
|
return False |
|
|
|
|
|
|
|
|
|
def should_use_fp16(device=None, model_params=0, prioritize_performance=True): |
|
|
|
|
def should_use_fp16(device=None, model_params=0, prioritize_performance=True, manual_cast=False): |
|
|
|
|
global directml_enabled |
|
|
|
|
|
|
|
|
|
if device is not None: |
|
|
|
@ -738,7 +738,7 @@ def should_use_fp16(device=None, model_params=0, prioritize_performance=True):
|
|
|
|
|
if x in props.name.lower(): |
|
|
|
|
fp16_works = True |
|
|
|
|
|
|
|
|
|
if fp16_works: |
|
|
|
|
if fp16_works or manual_cast: |
|
|
|
|
free_model_memory = (get_free_memory() * 0.9 - minimum_inference_memory()) |
|
|
|
|
if (not prioritize_performance) or model_params * 4 > free_model_memory: |
|
|
|
|
return True |
|
|
|
|