@@ -1,6 +1,7 @@
 import psutil
 from enum import Enum
 from comfy.cli_args import args
+import torch
 
 class VRAMState(Enum):
     CPU = 0
@@ -33,28 +34,67 @@ if args.directml is not None:
     lowvram_available = False #TODO: need to find a way to get free memory in directml before this can be enabled by default.
 
 try:
-    import torch
-    if directml_enabled:
-        pass #TODO
-    else:
-        try:
-            import intel_extension_for_pytorch as ipex
-            if torch.xpu.is_available():
-                xpu_available = True
-                total_vram = torch.xpu.get_device_properties(torch.xpu.current_device()).total_memory / (1024 * 1024)
-        except:
-            total_vram = torch.cuda.mem_get_info(torch.cuda.current_device())[1] / (1024 * 1024)
-    total_ram = psutil.virtual_memory().total / (1024 * 1024)
-    if not args.normalvram and not args.cpu:
-        if lowvram_available and total_vram <= 4096:
-            print("Trying to enable lowvram mode because your GPU seems to have 4GB or less. If you don't want this use: --normalvram")
-            set_vram_to = VRAMState.LOW_VRAM
-        elif total_vram > total_ram * 1.1 and total_vram > 14336:
-            print("Enabling highvram mode because your GPU has more vram than your computer has ram. If you don't want this use: --normalvram")
-            vram_state = VRAMState.HIGH_VRAM
+    import intel_extension_for_pytorch as ipex
+    if torch.xpu.is_available():
+        xpu_available = True
 except:
     pass
 
+def get_torch_device():
+    global xpu_available
+    global directml_enabled
+    if directml_enabled:
+        global directml_device
+        return directml_device
+    if vram_state == VRAMState.MPS:
+        return torch.device("mps")
+    if vram_state == VRAMState.CPU:
+        return torch.device("cpu")
+    else:
+        if xpu_available:
+            return torch.device("xpu")
+        else:
+            return torch.device(torch.cuda.current_device())
+
+def get_total_memory(dev=None, torch_total_too=False):
+    global xpu_available
+    global directml_enabled
+    if dev is None:
+        dev = get_torch_device()
+
+    if hasattr(dev, 'type') and (dev.type == 'cpu' or dev.type == 'mps'):
+        mem_total = psutil.virtual_memory().total
+        mem_total_torch = mem_total
+    else:
+        if directml_enabled:
+            mem_total = 1024 * 1024 * 1024 #TODO
+            mem_total_torch = mem_total
+        elif xpu_available:
+            mem_total = torch.xpu.get_device_properties(dev).total_memory
+            mem_total_torch = mem_total
+        else:
+            stats = torch.cuda.memory_stats(dev)
+            mem_reserved = stats['reserved_bytes.all.current']
+            _, mem_total_cuda = torch.cuda.mem_get_info(dev)
+            mem_total_torch = mem_reserved
+            mem_total = mem_total_cuda
+
+    if torch_total_too:
+        return (mem_total, mem_total_torch)
+    else:
+        return mem_total
+
+total_vram = get_total_memory(get_torch_device()) / (1024 * 1024)
+total_ram = psutil.virtual_memory().total / (1024 * 1024)
+print("Total VRAM {:0.0f} MB, total RAM {:0.0f} MB".format(total_vram, total_ram))
+if not args.normalvram and not args.cpu:
+    if lowvram_available and total_vram <= 4096:
+        print("Trying to enable lowvram mode because your GPU seems to have 4GB or less. If you don't want this use: --normalvram")
+        set_vram_to = VRAMState.LOW_VRAM
+    elif total_vram > total_ram * 1.1 and total_vram > 14336:
+        print("Enabling highvram mode because your GPU has more vram than your computer has ram. If you don't want this use: --normalvram")
+        vram_state = VRAMState.HIGH_VRAM
+
 try:
     OOM_EXCEPTION = torch.cuda.OutOfMemoryError
 except:
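Note on the hunk above: the copy of get_total_memory() added here also assigns mem_total_torch on the cpu/mps branch; the old copy removed in the final hunk below left that name unbound, so calling it with torch_total_too=True on those devices would have raised a NameError. As a quick illustration of what the CUDA branch measures (a minimal sketch, not part of the patch; it assumes a CUDA-enabled PyTorch build and uses the same MB conversion as the module):

import torch

# mem_get_info() returns (free_bytes, total_bytes) for the whole device,
# while memory_stats() reports what torch's caching allocator has reserved.
if torch.cuda.is_available():
    dev = torch.device(torch.cuda.current_device())
    stats = torch.cuda.memory_stats(dev)
    mem_reserved = stats['reserved_bytes.all.current']
    _, mem_total_cuda = torch.cuda.mem_get_info(dev)
    # Same unit handling as the patch: bytes / (1024 * 1024) -> MB.
    print("Total VRAM {:0.0f} MB, torch reserved {:0.0f} MB".format(
        mem_total_cuda / (1024 * 1024), mem_reserved / (1024 * 1024)))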
@@ -128,29 +168,17 @@ if args.cpu:
 
 print(f"Set vram state to: {vram_state.name}")
 
-def get_torch_device():
-    global xpu_available
-    global directml_enabled
-    if directml_enabled:
-        global directml_device
-        return directml_device
-    if vram_state == VRAMState.MPS:
-        return torch.device("mps")
-    if vram_state == VRAMState.CPU:
-        return torch.device("cpu")
-    else:
-        if xpu_available:
-            return torch.device("xpu")
-        else:
-            return torch.cuda.current_device()
-
 def get_torch_device_name(device):
     if hasattr(device, 'type'):
-        return "{}".format(device.type)
-    return "CUDA {}: {}".format(device, torch.cuda.get_device_name(device))
+        if device.type == "cuda":
+            return "{} {}".format(device, torch.cuda.get_device_name(device))
+        else:
+            return "{}".format(device.type)
+    else:
+        return "CUDA {}: {}".format(device, torch.cuda.get_device_name(device))
 
 try:
-    print("Using device:", get_torch_device_name(get_torch_device()))
+    print("Device:", get_torch_device_name(get_torch_device()))
 except:
     print("Could not pick default device.")
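Note on the hunk above: besides dropping the duplicate get_torch_device() (the surviving copy near the top of the file now wraps the CUDA index in torch.device() instead of returning a bare int from torch.cuda.current_device()), it extends get_torch_device_name() so CUDA devices report their actual adapter name. A hedged sketch of the expected output (the GPU name below is an example only):

import torch

# Mirrors the new formatting branch; assumes a CUDA device is present.
device = torch.device("cuda:0")
if hasattr(device, 'type') and device.type == "cuda":
    # e.g. "cuda:0 NVIDIA GeForce RTX 3090" on a 3090 system
    print("{} {}".format(device, torch.cuda.get_device_name(device)))
else:
    print("{}".format(device.type))  # e.g. "mps" or "cpu"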
@@ -308,33 +336,6 @@ def pytorch_attention_flash_attention():
             return True
     return False
 
-def get_total_memory(dev=None, torch_total_too=False):
-    global xpu_available
-    global directml_enabled
-    if dev is None:
-        dev = get_torch_device()
-
-    if hasattr(dev, 'type') and (dev.type == 'cpu' or dev.type == 'mps'):
-        mem_total = psutil.virtual_memory().total
-    else:
-        if directml_enabled:
-            mem_total = 1024 * 1024 * 1024 #TODO
-            mem_total_torch = mem_total
-        elif xpu_available:
-            mem_total = torch.xpu.get_device_properties(dev).total_memory
-            mem_total_torch = mem_total
-        else:
-            stats = torch.cuda.memory_stats(dev)
-            mem_reserved = stats['reserved_bytes.all.current']
-            _, mem_total_cuda = torch.cuda.mem_get_info(dev)
-            mem_total_torch = mem_reserved
-            mem_total = mem_total_cuda
-
-    if torch_total_too:
-        return (mem_total, mem_total_torch)
-    else:
-        return mem_total
-
 def get_free_memory(dev=None, torch_free_too=False):
     global xpu_available
     global directml_enabled
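Taken together with the second hunk, this final hunk is a move rather than a deletion: get_total_memory() now lives above its first call site, because total_vram is computed at module import time. A worked example of the startup thresholds from that hunk (hypothetical values; 4096 MB is 4 GB and 14336 MB is 14 GB):

# Hypothetical machine: 4 GB GPU, 16 GB of system RAM.
total_vram = 4096.0
total_ram = 16384.0

if total_vram <= 4096:
    pass  # lowvram is attempted: the card has 4 GB or less
elif total_vram > total_ram * 1.1 and total_vram > 14336:
    pass  # highvram needs both >14 GB of VRAM and >1.1x the system RAM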