
add ascend backend

pull/3449/head
root 6 months ago
parent
commit
ba59f7b4e1
  1. README.md (13)
  2. comfy/cli_args.py (4)
  3. comfy/model_management.py (40)

README.md (13)

@@ -118,6 +118,19 @@ This is the command to install pytorch nightly instead which might have performance improvements:
```pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu121```
### ASCEND
Ascend users should install ```cann>=7.0.0```, ```torch+cpu>=2.1.0```, and ```torch_npu```. Below are the installation reference documents:
https://www.hiascend.com/document/detail/zh/CANNCommunityEdition/80RC2alpha001/softwareinst/instg/instg_0001.html
https://gitee.com/ascend/pytorch
This is the command to launch ComfyUI using the Ascend backend:
```python main.py --use-npu```
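To verify the setup before launching, you can check that ```torch_npu``` is installed and the NPU is visible. This is only a minimal sketch; it relies on ```import torch_npu``` registering the ```torch.npu``` namespace, which is the same assumption the backend code in ```comfy/model_management.py``` makes.
```
# Sanity check for the Ascend backend (run in the same Python environment as ComfyUI).
import torch
import torch_npu  # registers the torch.npu namespace

print("NPU available:", torch.npu.is_available())
if torch.npu.is_available():
    device = torch.device("npu", torch.npu.current_device())
    print("Device name:", torch.npu.get_device_name(device))
```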
#### Troubleshooting
If you get the "Torch not compiled with CUDA enabled" error, uninstall torch with:

comfy/cli_args.py (4)

@@ -81,6 +81,10 @@ parser.add_argument("--directml", type=int, nargs="?", metavar="DIRECTML_DEVICE"
parser.add_argument("--disable-ipex-optimize", action="store_true", help="Disables ipex.optimize when loading models with Intel GPUs.")
parser.add_argument("--disable-torchair-optimize", action="store_true", help="Disables torchair graph modee optimize when loading models with Ascend NPUs.")
parser.add_argument("--use-npu", action="store_true", help="use Huawei Ascend NPUs backend.")
class LatentPreviewMethod(enum.Enum):
    NoPreviews = "none"
    Auto = "auto"

comfy/model_management.py (40)

@@ -4,6 +4,8 @@ from enum import Enum
from comfy.cli_args import args
import comfy.utils
import torch
if args.use_npu:
    import torch_npu
import sys
class VRAMState(Enum):
@@ -71,6 +73,12 @@ def is_intel_xpu():
            return True
    return False

def is_ascend_npu():
    if args.use_npu and torch.npu.is_available():
        return True
    return False

def get_torch_device():
    global directml_enabled
    global cpu_state
@@ -84,6 +92,8 @@ def get_torch_device():
    else:
        if is_intel_xpu():
            return torch.device("xpu", torch.xpu.current_device())
        elif is_ascend_npu():
            return torch.device("npu", torch.npu.current_device())
        else:
            return torch.device(torch.cuda.current_device())
@@ -104,6 +114,11 @@ def get_total_memory(dev=None, torch_total_too=False):
            mem_reserved = stats['reserved_bytes.all.current']
            mem_total = torch.xpu.get_device_properties(dev).total_memory
            mem_total_torch = mem_reserved
        elif is_ascend_npu():
            stats = torch.npu.memory_stats(dev)
            mem_reserved = stats['reserved_bytes.all.current']
            mem_total = torch.npu.get_device_properties(dev).total_memory
            mem_total_torch = mem_reserved
        else:
            stats = torch.cuda.memory_stats(dev)
            mem_reserved = stats['reserved_bytes.all.current']
@@ -179,6 +194,10 @@ try:
    if is_intel_xpu():
        if args.use_split_cross_attention == False and args.use_quad_cross_attention == False:
            ENABLE_PYTORCH_ATTENTION = True
    if is_ascend_npu():
        if args.use_split_cross_attention == False and args.use_quad_cross_attention == False:
            ENABLE_PYTORCH_ATTENTION = True
except:
    pass
@@ -249,6 +268,8 @@ def get_torch_device_name(device):
            return "{}".format(device.type)
    elif is_intel_xpu():
        return "{} {}".format(device, torch.xpu.get_device_name(device))
    elif is_ascend_npu():
        return "{} {}".format(device, torch.npu.get_device_name(device))
    else:
        return "CUDA {}: {}".format(device, torch.cuda.get_device_name(device))
@@ -306,6 +327,11 @@ class LoadedModel:
        if is_intel_xpu() and not args.disable_ipex_optimize:
            self.real_model = ipex.optimize(self.real_model.eval(), graph_mode=True, concat_linear=True)

        if is_ascend_npu() and not args.disable_torchair_optimize:  # torchair optimize
            import torchair as tng
            npu_backend = tng.get_npu_backend()
            self.real_model = torch.compile(self.real_model.eval(), backend=npu_backend, dynamic=False)

        self.weights_loaded = True
        return self.real_model
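For context, the same torchair path can be exercised outside ComfyUI on any module. A minimal sketch, assuming `torch_npu` and `torchair` are installed and an NPU is available; the `Linear` module and the `.npu()` move are illustrative, not taken from this commit:
```
# Sketch: compile an arbitrary module with the Ascend torchair backend,
# mirroring the optimization applied to self.real_model above.
import torch
import torch_npu       # registers the "npu" device type
import torchair as tng

model = torch.nn.Linear(64, 64).eval().npu()   # illustrative module on the NPU
npu_backend = tng.get_npu_backend()            # same call as in the commit
compiled = torch.compile(model, backend=npu_backend, dynamic=False)

with torch.no_grad():
    out = compiled(torch.randn(1, 64).npu())
```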
@@ -649,6 +675,8 @@ def xformers_enabled():
        return False
    if directml_enabled:
        return False
    if is_ascend_npu():
        return False
    return XFORMERS_IS_AVAILABLE
@@ -690,6 +718,13 @@ def get_free_memory(dev=None, torch_free_too=False):
            mem_reserved = stats['reserved_bytes.all.current']
            mem_free_torch = mem_reserved - mem_active
            mem_free_total = torch.xpu.get_device_properties(dev).total_memory - mem_allocated
        elif is_ascend_npu():
            stats = torch.npu.memory_stats(dev)
            mem_active = stats['active_bytes.all.current']
            mem_reserved = stats['reserved_bytes.all.current']
            mem_free_npu, _ = torch.npu.mem_get_info(dev)
            mem_free_torch = mem_reserved - mem_active
            mem_free_total = mem_free_npu + mem_free_torch
        else:
            stats = torch.cuda.memory_stats(dev)
            mem_active = stats['active_bytes.all.current']
@@ -755,6 +790,9 @@ def should_use_fp16(device=None, model_params=0, prioritize_performance=True, manual_cast=False):
    if is_intel_xpu():
        return True

    if is_ascend_npu():
        return True

    if torch.version.hip:
        return True
@@ -833,6 +871,8 @@ def soft_empty_cache(force=False):
        torch.mps.empty_cache()
    elif is_intel_xpu():
        torch.xpu.empty_cache()
    elif is_ascend_npu():
        torch.npu.empty_cache()
    elif torch.cuda.is_available():
        if force or is_nvidia(): #This seems to make things worse on ROCm so I only do it for cuda
            torch.cuda.empty_cache()
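Taken together, the new branches route device selection, memory queries, and cache cleanup through the NPU whenever `--use-npu` is set. A rough usage sketch against `comfy.model_management`, assuming ComfyUI is importable and was started with the `--use-npu` flag so that `args.use_npu` is true:
```
# Sketch: exercising the Ascend code paths added above.
from comfy import model_management as mm

if mm.is_ascend_npu():
    dev = mm.get_torch_device()            # torch.device("npu", <index>)
    print(mm.get_torch_device_name(dev))   # formatted via torch.npu.get_device_name
    print(mm.get_free_memory(dev))         # backed by torch.npu.memory_stats / mem_get_info
    mm.soft_empty_cache(force=True)        # calls torch.npu.empty_cache()
```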
