diff --git a/README.md b/README.md index 2636ce14..e6af3c12 100644 --- a/README.md +++ b/README.md @@ -118,6 +118,19 @@ This is the command to install pytorch nightly instead which might have performa ```pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu121``` +### ASCEND + +Ascend users should install ```cann>=7.0.0 torch+cpu>=2.1.0``` and ```torch_npu```,below are the installation reference documents. + +https://www.hiascend.com/document/detail/zh/CANNCommunityEdition/80RC2alpha001/softwareinst/instg/instg_0001.html + +https://gitee.com/ascend/pytorch + +This is the command to lanuch ComfyUI using Ascend backend. + +```python main.py --use-npu``` + + #### Troubleshooting If you get the "Torch not compiled with CUDA enabled" error, uninstall torch with: diff --git a/comfy/cli_args.py b/comfy/cli_args.py index 569c7938..543e49b5 100644 --- a/comfy/cli_args.py +++ b/comfy/cli_args.py @@ -81,6 +81,10 @@ parser.add_argument("--directml", type=int, nargs="?", metavar="DIRECTML_DEVICE" parser.add_argument("--disable-ipex-optimize", action="store_true", help="Disables ipex.optimize when loading models with Intel GPUs.") +parser.add_argument("--disable-torchair-optimize", action="store_true", help="Disables torchair graph modee optimize when loading models with Ascend NPUs.") + +parser.add_argument("--use-npu", action="store_true", help="use Huawei Ascend NPUs backend.") + class LatentPreviewMethod(enum.Enum): NoPreviews = "none" Auto = "auto" diff --git a/comfy/model_management.py b/comfy/model_management.py index 913b6844..0b751181 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -4,6 +4,8 @@ from enum import Enum from comfy.cli_args import args import comfy.utils import torch +if args.use_npu == True: + import torch_npu import sys class VRAMState(Enum): @@ -71,6 +73,12 @@ def is_intel_xpu(): return True return False +def is_ascend_npu(): + if args.use_npu and torch.npu.is_available(): + return True + return False + + def get_torch_device(): global directml_enabled global cpu_state @@ -84,6 +92,8 @@ def get_torch_device(): else: if is_intel_xpu(): return torch.device("xpu", torch.xpu.current_device()) + elif is_ascend_npu(): + return torch.device("npu", torch.npu.current_device()) else: return torch.device(torch.cuda.current_device()) @@ -104,6 +114,11 @@ def get_total_memory(dev=None, torch_total_too=False): mem_reserved = stats['reserved_bytes.all.current'] mem_total = torch.xpu.get_device_properties(dev).total_memory mem_total_torch = mem_reserved + elif is_ascend_npu(): + stats = torch.npu.memory_stats(dev) + mem_reserved = stats['reserved_bytes.all.current'] + mem_total = torch.npu.get_device_properties(dev).total_memory + mem_total_torch = mem_reserved else: stats = torch.cuda.memory_stats(dev) mem_reserved = stats['reserved_bytes.all.current'] @@ -179,6 +194,10 @@ try: if is_intel_xpu(): if args.use_split_cross_attention == False and args.use_quad_cross_attention == False: ENABLE_PYTORCH_ATTENTION = True + if is_ascend_npu(): + if args.use_split_cross_attention == False and args.use_quad_cross_attention == False: + ENABLE_PYTORCH_ATTENTION = True + except: pass @@ -249,6 +268,8 @@ def get_torch_device_name(device): return "{}".format(device.type) elif is_intel_xpu(): return "{} {}".format(device, torch.xpu.get_device_name(device)) + elif is_ascend_npu(): + return "{} {}".format(device, torch.npu.get_device_name(device)) else: return "CUDA {}: {}".format(device, torch.cuda.get_device_name(device)) @@ -306,6 +327,11 @@ class LoadedModel: if is_intel_xpu() and not args.disable_ipex_optimize: self.real_model = ipex.optimize(self.real_model.eval(), graph_mode=True, concat_linear=True) + if is_ascend_npu() and not args.disable_torchair_optimize: # torchair optimize + import torchair as tng + npu_backend = tng.get_npu_backend() + self.real_model = torch.compile(self.real_model.eval(), backend=npu_backend, dynamic=False) + self.weights_loaded = True return self.real_model @@ -649,6 +675,8 @@ def xformers_enabled(): return False if directml_enabled: return False + if is_ascend_npu(): + return False return XFORMERS_IS_AVAILABLE @@ -690,6 +718,13 @@ def get_free_memory(dev=None, torch_free_too=False): mem_reserved = stats['reserved_bytes.all.current'] mem_free_torch = mem_reserved - mem_active mem_free_total = torch.xpu.get_device_properties(dev).total_memory - mem_allocated + elif is_ascend_npu(): + stats = torch.npu.memory_stats(dev) + mem_active = stats['active_bytes.all.current'] + mem_reserved = stats['reserved_bytes.all.current'] + mem_free_npu, _ = torch.npu.mem_get_info(dev) + mem_free_torch = mem_reserved - mem_active + mem_free_total = mem_free_npu + mem_free_torch else: stats = torch.cuda.memory_stats(dev) mem_active = stats['active_bytes.all.current'] @@ -755,6 +790,9 @@ def should_use_fp16(device=None, model_params=0, prioritize_performance=True, ma if is_intel_xpu(): return True + if is_ascend_npu(): + return True + if torch.version.hip: return True @@ -833,6 +871,8 @@ def soft_empty_cache(force=False): torch.mps.empty_cache() elif is_intel_xpu(): torch.xpu.empty_cache() + elif is_ascend_npu(): + torch.npu.empty_cache() elif torch.cuda.is_available(): if force or is_nvidia(): #This seems to make things worse on ROCm so I only do it for cuda torch.cuda.empty_cache()