@@ -4,6 +4,8 @@ from enum import Enum
 from comfy.cli_args import args
 import comfy.utils
 import torch
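+# Importing torch_npu registers the "npu" device type with PyTorch.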
+if args.use_npu:
+    import torch_npu
 import sys

 class VRAMState(Enum):
@@ -71,6 +73,12 @@ def is_intel_xpu():
             return True
     return False

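+# torch.npu comes from the torch_npu package imported above when args.use_npu is set.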
+def is_ascend_npu():
+    if args.use_npu and torch.npu.is_available():
+        return True
+    return False
+
 def get_torch_device():
     global directml_enabled
     global cpu_state
@@ -84,6 +92,8 @@ def get_torch_device():
     else:
         if is_intel_xpu():
             return torch.device("xpu", torch.xpu.current_device())
+        elif is_ascend_npu():
+            return torch.device("npu", torch.npu.current_device())
         else:
             return torch.device(torch.cuda.current_device())

@@ -104,6 +114,11 @@ def get_total_memory(dev=None, torch_total_too=False):
             mem_reserved = stats['reserved_bytes.all.current']
             mem_total = torch.xpu.get_device_properties(dev).total_memory
             mem_total_torch = mem_reserved
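+        # torch.npu.memory_stats() and get_device_properties() mirror their torch.cuda counterparts.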
+        elif is_ascend_npu():
+            stats = torch.npu.memory_stats(dev)
+            mem_reserved = stats['reserved_bytes.all.current']
+            mem_total = torch.npu.get_device_properties(dev).total_memory
+            mem_total_torch = mem_reserved
         else:
             stats = torch.cuda.memory_stats(dev)
             mem_reserved = stats['reserved_bytes.all.current']
@@ -179,6 +194,10 @@ try:
     if is_intel_xpu():
         if args.use_split_cross_attention == False and args.use_quad_cross_attention == False:
             ENABLE_PYTORCH_ATTENTION = True
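+    # As on XPU, default to PyTorch scaled dot product attention unless another backend was requested.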
+    if is_ascend_npu():
+        if args.use_split_cross_attention == False and args.use_quad_cross_attention == False:
+            ENABLE_PYTORCH_ATTENTION = True
 except:
     pass

@@ -249,6 +268,8 @@ def get_torch_device_name(device):
             return "{}".format(device.type)
     elif is_intel_xpu():
         return "{} {}".format(device, torch.xpu.get_device_name(device))
+    elif is_ascend_npu():
+        return "{} {}".format(device, torch.npu.get_device_name(device))
     else:
         return "CUDA {}: {}".format(device, torch.cuda.get_device_name(device))

@@ -306,6 +327,11 @@ class LoadedModel:
         if is_intel_xpu() and not args.disable_ipex_optimize:
             self.real_model = ipex.optimize(self.real_model.eval(), graph_mode=True, concat_linear=True)

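+        # TorchAir supplies a torch.compile backend that lowers the graph to Ascend hardware.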
+        if is_ascend_npu() and not args.disable_torchair_optimize:
+            import torchair as tng
+            npu_backend = tng.get_npu_backend()
+            self.real_model = torch.compile(self.real_model.eval(), backend=npu_backend, dynamic=False)
+
         self.weights_loaded = True
         return self.real_model

@@ -649,6 +675,8 @@ def xformers_enabled():
         return False
     if directml_enabled:
         return False
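+    # xformers has no Ascend kernels, so report it unavailable on NPU.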
+    if is_ascend_npu():
+        return False
     return XFORMERS_IS_AVAILABLE

@@ -690,6 +718,13 @@ def get_free_memory(dev=None, torch_free_too=False):
             mem_reserved = stats['reserved_bytes.all.current']
             mem_free_torch = mem_reserved - mem_active
             mem_free_total = torch.xpu.get_device_properties(dev).total_memory - mem_allocated
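+        # mem_get_info() reports device-wide free memory; torch's reserved-but-inactive cache is added back since it is reusable.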
+        elif is_ascend_npu():
+            stats = torch.npu.memory_stats(dev)
+            mem_active = stats['active_bytes.all.current']
+            mem_reserved = stats['reserved_bytes.all.current']
+            mem_free_npu, _ = torch.npu.mem_get_info(dev)
+            mem_free_torch = mem_reserved - mem_active
+            mem_free_total = mem_free_npu + mem_free_torch
         else:
             stats = torch.cuda.memory_stats(dev)
             mem_active = stats['active_bytes.all.current']
@@ -755,6 +790,9 @@ def should_use_fp16(device=None, model_params=0, prioritize_performance=True, ma
     if is_intel_xpu():
         return True

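+    # Ascend NPUs execute fp16 natively, so prefer it as on XPU.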
+    if is_ascend_npu():
+        return True
+
     if torch.version.hip:
         return True

@@ -833,6 +871,8 @@ def soft_empty_cache(force=False):
         torch.mps.empty_cache()
     elif is_intel_xpu():
         torch.xpu.empty_cache()
+    elif is_ascend_npu():
+        torch.npu.empty_cache()
     elif torch.cuda.is_available():
         if force or is_nvidia(): #This seems to make things worse on ROCm so I only do it for cuda
             torch.cuda.empty_cache()