diff --git a/comfy/cli_args.py b/comfy/cli_args.py index e7ce256b..29e5fb15 100644 --- a/comfy/cli_args.py +++ b/comfy/cli_args.py @@ -40,7 +40,10 @@ parser.add_argument("--extra-model-paths-config", type=str, default=None, metava parser.add_argument("--output-directory", type=str, default=None, help="Set the ComfyUI output directory.") parser.add_argument("--auto-launch", action="store_true", help="Automatically launch ComfyUI in the default browser.") parser.add_argument("--cuda-device", type=int, default=None, metavar="DEVICE_ID", help="Set the id of the cuda device this instance will use.") -parser.add_argument("--cuda-malloc", action="store_true", help="Enable cudaMallocAsync.") +cm_group = parser.add_mutually_exclusive_group() +cm_group.add_argument("--cuda-malloc", action="store_true", help="Enable cudaMallocAsync (enabled by default for torch 2.0 and up).") +cm_group.add_argument("--disable-cuda-malloc", action="store_true", help="Enable cudaMallocAsync.") + parser.add_argument("--dont-upcast-attention", action="store_true", help="Disable upcasting of attention. Can boost speed but increase the chances of black images.") fp_group = parser.add_mutually_exclusive_group() @@ -85,4 +88,3 @@ args = parser.parse_args() if args.windows_standalone_build: args.auto_launch = True - args.cuda_malloc = True #work around memory issue in nvidia drivers > 531 diff --git a/main.py b/main.py index a2254557..61ba9e8e 100644 --- a/main.py +++ b/main.py @@ -61,7 +61,23 @@ if __name__ == "__main__": os.environ['CUDA_VISIBLE_DEVICES'] = str(args.cuda_device) print("Set cuda device to:", args.cuda_device) - if args.cuda_malloc: + if not args.cuda_malloc: + try: #if there's a better way to check the torch version without importing it let me know + version = "" + torch_spec = importlib.util.find_spec("torch") + for folder in torch_spec.submodule_search_locations: + ver_file = os.path.join(folder, "version.py") + if os.path.isfile(ver_file): + spec = importlib.util.spec_from_file_location("torch_version_import", ver_file) + module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(module) + version = module.__version__ + if int(version[0]) >= 2: #enable by default for torch version 2.0 and up + args.cuda_malloc = True + except: + pass + + if args.cuda_malloc and not args.disable_cuda_malloc: env_var = os.environ.get('PYTORCH_CUDA_ALLOC_CONF', None) if env_var is None: env_var = "backend:cudaMallocAsync"