diff --git a/.ci/nightly/update_windows/update_comfyui_and_python_dependencies.bat b/.ci/nightly/update_windows/update_comfyui_and_python_dependencies.bat deleted file mode 100755 index 94f5d102..00000000 --- a/.ci/nightly/update_windows/update_comfyui_and_python_dependencies.bat +++ /dev/null @@ -1,3 +0,0 @@ -..\python_embeded\python.exe .\update.py ..\ComfyUI\ -..\python_embeded\python.exe -s -m pip install --upgrade --pre torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/nightly/cu121 -r ../ComfyUI/requirements.txt pygit2 -pause diff --git a/.ci/nightly/windows_base_files/run_nvidia_gpu.bat b/.ci/nightly/windows_base_files/run_nvidia_gpu.bat deleted file mode 100755 index 8ee2f340..00000000 --- a/.ci/nightly/windows_base_files/run_nvidia_gpu.bat +++ /dev/null @@ -1,2 +0,0 @@ -.\python_embeded\python.exe -s ComfyUI\main.py --windows-standalone-build --use-pytorch-cross-attention -pause diff --git a/.ci/update_windows/update.py b/.ci/update_windows/update.py index ef9374c4..127247b2 100755 --- a/.ci/update_windows/update.py +++ b/.ci/update_windows/update.py @@ -1,6 +1,9 @@ import pygit2 from datetime import datetime import sys +import os +import shutil +import filecmp def pull(repo, remote_name='origin', branch='master'): for remote in repo.remotes: @@ -42,7 +45,8 @@ def pull(repo, remote_name='origin', branch='master'): raise AssertionError('Unknown merge analysis result') pygit2.option(pygit2.GIT_OPT_SET_OWNER_VALIDATION, 0) -repo = pygit2.Repository(str(sys.argv[1])) +repo_path = str(sys.argv[1]) +repo = pygit2.Repository(repo_path) ident = pygit2.Signature('comfyui', 'comfy@ui') try: print("stashing current changes") @@ -51,7 +55,10 @@ except KeyError: print("nothing to stash") backup_branch_name = 'backup_branch_{}'.format(datetime.today().strftime('%Y-%m-%d_%H_%M_%S')) print("creating backup branch: {}".format(backup_branch_name)) -repo.branches.local.create(backup_branch_name, repo.head.peel()) +try: + repo.branches.local.create(backup_branch_name, repo.head.peel()) +except: + pass print("checking out master branch") branch = repo.lookup_branch('master') @@ -63,3 +70,41 @@ pull(repo) print("Done!") +self_update = True +if len(sys.argv) > 2: + self_update = '--skip_self_update' not in sys.argv + +update_py_path = os.path.realpath(__file__) +repo_update_py_path = os.path.join(repo_path, ".ci/update_windows/update.py") + +cur_path = os.path.dirname(update_py_path) + + +req_path = os.path.join(cur_path, "current_requirements.txt") +repo_req_path = os.path.join(repo_path, "requirements.txt") + + +def files_equal(file1, file2): + try: + return filecmp.cmp(file1, file2, shallow=False) + except: + return False + +def file_size(f): + try: + return os.path.getsize(f) + except: + return 0 + + +if self_update and not files_equal(update_py_path, repo_update_py_path) and file_size(repo_update_py_path) > 10: + shutil.copy(repo_update_py_path, os.path.join(cur_path, "update_new.py")) + exit() + +if not os.path.exists(req_path) or not files_equal(repo_req_path, req_path): + import subprocess + try: + subprocess.check_call([sys.executable, '-s', '-m', 'pip', 'install', '-r', repo_req_path]) + shutil.copy(repo_req_path, req_path) + except: + pass diff --git a/.ci/update_windows/update_comfyui.bat b/.ci/update_windows/update_comfyui.bat index 60d1e694..bb08c0de 100755 --- a/.ci/update_windows/update_comfyui.bat +++ b/.ci/update_windows/update_comfyui.bat @@ -1,2 +1,8 @@ +@echo off ..\python_embeded\python.exe .\update.py ..\ComfyUI\ -pause +if exist update_new.py ( + move /y update_new.py update.py + echo Running updater again since it got updated. + ..\python_embeded\python.exe .\update.py ..\ComfyUI\ --skip_self_update +) +if "%~1"=="" pause diff --git a/.ci/update_windows/update_comfyui_and_python_dependencies.bat b/.ci/update_windows/update_comfyui_and_python_dependencies.bat deleted file mode 100755 index b7308550..00000000 --- a/.ci/update_windows/update_comfyui_and_python_dependencies.bat +++ /dev/null @@ -1,3 +0,0 @@ -..\python_embeded\python.exe .\update.py ..\ComfyUI\ -..\python_embeded\python.exe -s -m pip install --upgrade torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu117 xformers -r ../ComfyUI/requirements.txt pygit2 -pause diff --git a/.ci/update_windows_cu118/update_comfyui_and_python_dependencies.bat b/.ci/update_windows_cu118/update_comfyui_and_python_dependencies.bat deleted file mode 100755 index c33adc0a..00000000 --- a/.ci/update_windows_cu118/update_comfyui_and_python_dependencies.bat +++ /dev/null @@ -1,11 +0,0 @@ -@echo off -..\python_embeded\python.exe .\update.py ..\ComfyUI\ -echo -echo This will try to update pytorch and all python dependencies, if you get an error wait for pytorch/xformers to fix their stuff -echo You should not be running this anyways unless you really have to -echo -echo If you just want to update normally, close this and run update_comfyui.bat instead. -echo -pause -..\python_embeded\python.exe -s -m pip install --upgrade torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu118 xformers -r ../ComfyUI/requirements.txt pygit2 -pause diff --git a/.github/workflows/test-ui.yaml b/.github/workflows/test-ui.yaml index 95069175..4b8b9793 100644 --- a/.github/workflows/test-ui.yaml +++ b/.github/workflows/test-ui.yaml @@ -22,5 +22,5 @@ jobs: run: | npm ci npm run test:generate - npm test + npm test -- --verbose working-directory: ./tests-ui diff --git a/.github/workflows/windows_release_cu118_dependencies.yml b/.github/workflows/windows_release_cu118_dependencies.yml deleted file mode 100644 index 75c42b62..00000000 --- a/.github/workflows/windows_release_cu118_dependencies.yml +++ /dev/null @@ -1,71 +0,0 @@ -name: "Windows Release cu118 dependencies" - -on: - workflow_dispatch: -# push: -# branches: -# - master - -jobs: - build_dependencies: - env: - # you need at least cuda 5.0 for some of the stuff compiled here. - TORCH_CUDA_ARCH_LIST: "5.0+PTX 6.0 6.1 7.0 7.5 8.0 8.6 8.9" - FORCE_CUDA: 1 - MAX_JOBS: 1 # will crash otherwise - DISTUTILS_USE_SDK: 1 # otherwise distutils will complain on windows about multiple versions of msvc - XFORMERS_BUILD_TYPE: "Release" - runs-on: windows-latest - steps: - - name: Cache Built Dependencies - uses: actions/cache@v3 - id: cache-cu118_python_stuff - with: - path: cu118_python_deps.tar - key: ${{ runner.os }}-build-cu118 - - - if: steps.cache-cu118_python_stuff.outputs.cache-hit != 'true' - uses: actions/checkout@v3 - - - if: steps.cache-cu118_python_stuff.outputs.cache-hit != 'true' - uses: actions/setup-python@v4 - with: - python-version: '3.10.9' - - - if: steps.cache-cu118_python_stuff.outputs.cache-hit != 'true' - uses: comfyanonymous/cuda-toolkit@test - id: cuda-toolkit - with: - cuda: '11.8.0' - # copied from xformers github - - name: Setup MSVC - uses: ilammy/msvc-dev-cmd@v1 - - name: Configure Pagefile - # windows runners will OOM with many CUDA architectures - # we cheat here with a page file - uses: al-cheb/configure-pagefile-action@v1.3 - with: - minimum-size: 2GB - # really unfortunate: https://github.com/ilammy/msvc-dev-cmd#name-conflicts-with-shell-bash - - name: Remove link.exe - shell: bash - run: rm /usr/bin/link - - - if: steps.cache-cu118_python_stuff.outputs.cache-hit != 'true' - shell: bash - run: | - python -m pip wheel --no-cache-dir torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu118 -r requirements.txt pygit2 -w ./temp_wheel_dir - python -m pip install --no-cache-dir ./temp_wheel_dir/* - echo installed basic - git clone --recurse-submodules https://github.com/facebookresearch/xformers.git - cd xformers - python -m pip install --no-cache-dir wheel setuptools twine - echo building xformers - python setup.py bdist_wheel -d ../temp_wheel_dir/ - cd .. - rm -rf xformers - ls -lah temp_wheel_dir - mv temp_wheel_dir cu118_python_deps - tar cf cu118_python_deps.tar cu118_python_deps - - diff --git a/.github/workflows/windows_release_cu118_dependencies_2.yml b/.github/workflows/windows_release_cu118_dependencies_2.yml deleted file mode 100644 index a7760b21..00000000 --- a/.github/workflows/windows_release_cu118_dependencies_2.yml +++ /dev/null @@ -1,37 +0,0 @@ -name: "Windows Release cu118 dependencies 2" - -on: - workflow_dispatch: - inputs: - xformers: - description: 'xformers version' - required: true - type: string - default: "xformers" - -# push: -# branches: -# - master - -jobs: - build_dependencies: - runs-on: windows-latest - steps: - - uses: actions/checkout@v3 - - uses: actions/setup-python@v4 - with: - python-version: '3.10.9' - - - shell: bash - run: | - python -m pip wheel --no-cache-dir torch torchvision torchaudio ${{ inputs.xformers }} --extra-index-url https://download.pytorch.org/whl/cu118 -r requirements.txt pygit2 -w ./temp_wheel_dir - python -m pip install --no-cache-dir ./temp_wheel_dir/* - echo installed basic - ls -lah temp_wheel_dir - mv temp_wheel_dir cu118_python_deps - tar cf cu118_python_deps.tar cu118_python_deps - - - uses: actions/cache/save@v3 - with: - path: cu118_python_deps.tar - key: ${{ runner.os }}-build-cu118 diff --git a/.github/workflows/windows_release_cu118_package.yml b/.github/workflows/windows_release_cu118_package.yml deleted file mode 100644 index 0f0fbf28..00000000 --- a/.github/workflows/windows_release_cu118_package.yml +++ /dev/null @@ -1,79 +0,0 @@ -name: "Windows Release cu118 packaging" - -on: - workflow_dispatch: -# push: -# branches: -# - master - -jobs: - package_comfyui: - permissions: - contents: "write" - packages: "write" - pull-requests: "read" - runs-on: windows-latest - steps: - - uses: actions/cache/restore@v3 - id: cache - with: - path: cu118_python_deps.tar - key: ${{ runner.os }}-build-cu118 - - shell: bash - run: | - mv cu118_python_deps.tar ../ - cd .. - tar xf cu118_python_deps.tar - pwd - ls - - - uses: actions/checkout@v3 - with: - fetch-depth: 0 - persist-credentials: false - - shell: bash - run: | - cd .. - cp -r ComfyUI ComfyUI_copy - curl https://www.python.org/ftp/python/3.10.9/python-3.10.9-embed-amd64.zip -o python_embeded.zip - unzip python_embeded.zip -d python_embeded - cd python_embeded - echo 'import site' >> ./python310._pth - curl https://bootstrap.pypa.io/get-pip.py -o get-pip.py - ./python.exe get-pip.py - ./python.exe -s -m pip install ../cu118_python_deps/* - sed -i '1i../ComfyUI' ./python310._pth - cd .. - - git clone https://github.com/comfyanonymous/taesd - cp taesd/*.pth ./ComfyUI_copy/models/vae_approx/ - - mkdir ComfyUI_windows_portable - mv python_embeded ComfyUI_windows_portable - mv ComfyUI_copy ComfyUI_windows_portable/ComfyUI - - cd ComfyUI_windows_portable - - mkdir update - cp -r ComfyUI/.ci/update_windows/* ./update/ - cp -r ComfyUI/.ci/update_windows_cu118/* ./update/ - cp -r ComfyUI/.ci/windows_base_files/* ./ - - cd .. - - "C:\Program Files\7-Zip\7z.exe" a -t7z -m0=lzma -mx=8 -mfb=64 -md=32m -ms=on -mf=BCJ2 ComfyUI_windows_portable.7z ComfyUI_windows_portable - mv ComfyUI_windows_portable.7z ComfyUI/new_ComfyUI_windows_portable_nvidia_cu118_or_cpu.7z - - cd ComfyUI_windows_portable - python_embeded/python.exe -s ComfyUI/main.py --quick-test-for-ci --cpu - - ls - - - name: Upload binaries to release - uses: svenstaro/upload-release-action@v2 - with: - repo_token: ${{ secrets.GITHUB_TOKEN }} - file: new_ComfyUI_windows_portable_nvidia_cu118_or_cpu.7z - tag: "latest" - overwrite: true - diff --git a/.github/workflows/windows_release_dependencies.yml b/.github/workflows/windows_release_dependencies.yml index aafe8a21..ffd3e221 100644 --- a/.github/workflows/windows_release_dependencies.yml +++ b/.github/workflows/windows_release_dependencies.yml @@ -24,7 +24,7 @@ on: description: 'python patch version' required: true type: string - default: "6" + default: "8" # push: # branches: # - master @@ -41,10 +41,9 @@ jobs: - shell: bash run: | echo "@echo off - ..\python_embeded\python.exe .\update.py ..\ComfyUI\\ + call update_comfyui.bat nopause echo - - echo This will try to update pytorch and all python dependencies, if you get an error wait for pytorch/xformers to fix their stuff - echo You should not be running this anyways unless you really have to + echo This will try to update pytorch and all python dependencies. echo - echo If you just want to update normally, close this and run update_comfyui.bat instead. echo - diff --git a/.github/workflows/windows_release_nightly_pytorch.yml b/.github/workflows/windows_release_nightly_pytorch.yml index b793f7fe..672a7f22 100644 --- a/.github/workflows/windows_release_nightly_pytorch.yml +++ b/.github/workflows/windows_release_nightly_pytorch.yml @@ -2,6 +2,24 @@ name: "Windows Release Nightly pytorch" on: workflow_dispatch: + inputs: + cu: + description: 'cuda version' + required: true + type: string + default: "121" + + python_minor: + description: 'python minor version' + required: true + type: string + default: "12" + + python_patch: + description: 'python patch version' + required: true + type: string + default: "2" # push: # branches: # - master @@ -20,21 +38,21 @@ jobs: persist-credentials: false - uses: actions/setup-python@v4 with: - python-version: '3.11.6' + python-version: 3.${{ inputs.python_minor }}.${{ inputs.python_patch }} - shell: bash run: | cd .. cp -r ComfyUI ComfyUI_copy - curl https://www.python.org/ftp/python/3.11.6/python-3.11.6-embed-amd64.zip -o python_embeded.zip + curl https://www.python.org/ftp/python/3.${{ inputs.python_minor }}.${{ inputs.python_patch }}/python-3.${{ inputs.python_minor }}.${{ inputs.python_patch }}-embed-amd64.zip -o python_embeded.zip unzip python_embeded.zip -d python_embeded cd python_embeded - echo 'import site' >> ./python311._pth + echo 'import site' >> ./python3${{ inputs.python_minor }}._pth curl https://bootstrap.pypa.io/get-pip.py -o get-pip.py ./python.exe get-pip.py - python -m pip wheel torch torchvision torchaudio aiohttp==3.8.5 --pre --extra-index-url https://download.pytorch.org/whl/nightly/cu121 -r ../ComfyUI/requirements.txt pygit2 -w ../temp_wheel_dir + python -m pip wheel torch torchvision torchaudio mpmath==1.3.0 --pre --extra-index-url https://download.pytorch.org/whl/nightly/cu${{ inputs.cu }} -r ../ComfyUI/requirements.txt pygit2 -w ../temp_wheel_dir ls ../temp_wheel_dir ./python.exe -s -m pip install --pre ../temp_wheel_dir/* - sed -i '1i../ComfyUI' ./python311._pth + sed -i '1i../ComfyUI' ./python3${{ inputs.python_minor }}._pth cd .. git clone https://github.com/comfyanonymous/taesd @@ -49,9 +67,10 @@ jobs: mkdir update cp -r ComfyUI/.ci/update_windows/* ./update/ cp -r ComfyUI/.ci/windows_base_files/* ./ - cp -r ComfyUI/.ci/nightly/update_windows/* ./update/ - cp -r ComfyUI/.ci/nightly/windows_base_files/* ./ + echo "call update_comfyui.bat nopause + ..\python_embeded\python.exe -s -m pip install --upgrade --pre torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/nightly/cu${{ inputs.cu }} -r ../ComfyUI/requirements.txt pygit2 + pause" > ./update/update_comfyui_and_python_dependencies.bat cd .. "C:\Program Files\7-Zip\7z.exe" a -t7z -m0=lzma -mx=8 -mfb=64 -md=32m -ms=on -mf=BCJ2 ComfyUI_windows_portable_nightly_pytorch.7z ComfyUI_windows_portable_nightly_pytorch diff --git a/.github/workflows/windows_release_package.yml b/.github/workflows/windows_release_package.yml index 87d37c24..4e3cdabd 100644 --- a/.github/workflows/windows_release_package.yml +++ b/.github/workflows/windows_release_package.yml @@ -19,7 +19,7 @@ on: description: 'python patch version' required: true type: string - default: "6" + default: "8" # push: # branches: # - master diff --git a/.gitignore b/.gitignore index 43c038e4..9f038924 100644 --- a/.gitignore +++ b/.gitignore @@ -14,4 +14,5 @@ venv/ /web/extensions/* !/web/extensions/logging.js.example !/web/extensions/core/ -/tests-ui/data/object_info.json \ No newline at end of file +/tests-ui/data/object_info.json +/user/ \ No newline at end of file diff --git a/.vscode/settings.json b/.vscode/settings.json deleted file mode 100644 index 202121e1..00000000 --- a/.vscode/settings.json +++ /dev/null @@ -1,9 +0,0 @@ -{ - "path-intellisense.mappings": { - "../": "${workspaceFolder}/web/extensions/core" - }, - "[python]": { - "editor.defaultFormatter": "ms-python.autopep8" - }, - "python.formatting.provider": "none" -} diff --git a/README.md b/README.md index 450a012b..a94a212a 100644 --- a/README.md +++ b/README.md @@ -11,7 +11,7 @@ This ui will let you design and execute advanced stable diffusion pipelines usin ## Features - Nodes/graph/flowchart interface to experiment and create complex Stable Diffusion workflows without needing to code anything. -- Fully supports SD1.x, SD2.x, [SDXL](https://comfyanonymous.github.io/ComfyUI_examples/sdxl/) and [Stable Video Diffusion](https://comfyanonymous.github.io/ComfyUI_examples/video/) +- Fully supports SD1.x, SD2.x, [SDXL](https://comfyanonymous.github.io/ComfyUI_examples/sdxl/), [Stable Video Diffusion](https://comfyanonymous.github.io/ComfyUI_examples/video/) and [Stable Cascade](https://comfyanonymous.github.io/ComfyUI_examples/stable_cascade/) - Asynchronous Queue system - Many optimizations: Only re-executes the parts of the workflow that changes between executions. - Command line option: ```--lowvram``` to make it work on GPUs with less than 3GB vram (enabled automatically on GPUs with low vram) @@ -77,6 +77,8 @@ There is a portable standalone build for Windows that should work for running on Simply download, extract with [7-Zip](https://7-zip.org) and run. Make sure you put your Stable Diffusion checkpoints/models (the huge ckpt/safetensors files) in: ComfyUI\models\checkpoints +If you have trouble extracting it, right click the file -> properties -> unblock + #### How do I share models between another UI and ComfyUI? See the [Config file](extra_model_paths.yaml.example) to set the search paths for models. In the standalone windows build you can find this file in the ComfyUI directory. Rename this file to extra_model_paths.yaml and edit it with your favorite text editor. @@ -93,23 +95,26 @@ Put your SD checkpoints (the huge ckpt/safetensors files) in: models/checkpoints Put your VAE in: models/vae -Note: pytorch does not support python 3.12 yet so make sure your python version is 3.11 or earlier. ### AMD GPUs (Linux only) AMD users can install rocm and pytorch with pip if you don't have it already installed, this is the command to install the stable version: -```pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/rocm5.6``` +```pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/rocm5.7``` -This is the command to install the nightly with ROCm 5.7 that might have some performance improvements: +This is the command to install the nightly with ROCm 6.0 which might have some performance improvements: -```pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/rocm5.7``` +```pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/rocm6.0``` ### NVIDIA -Nvidia users should install pytorch using this command: +Nvidia users should install stable pytorch using this command: ```pip install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu121``` +This is the command to install pytorch nightly instead which might have performance improvements: + +```pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu121``` + #### Troubleshooting If you get the "Torch not compiled with CUDA enabled" error, uninstall torch with: diff --git a/app/app_settings.py b/app/app_settings.py new file mode 100644 index 00000000..8c6edc56 --- /dev/null +++ b/app/app_settings.py @@ -0,0 +1,54 @@ +import os +import json +from aiohttp import web + + +class AppSettings(): + def __init__(self, user_manager): + self.user_manager = user_manager + + def get_settings(self, request): + file = self.user_manager.get_request_user_filepath( + request, "comfy.settings.json") + if os.path.isfile(file): + with open(file) as f: + return json.load(f) + else: + return {} + + def save_settings(self, request, settings): + file = self.user_manager.get_request_user_filepath( + request, "comfy.settings.json") + with open(file, "w") as f: + f.write(json.dumps(settings, indent=4)) + + def add_routes(self, routes): + @routes.get("/settings") + async def get_settings(request): + return web.json_response(self.get_settings(request)) + + @routes.get("/settings/{id}") + async def get_setting(request): + value = None + settings = self.get_settings(request) + setting_id = request.match_info.get("id", None) + if setting_id and setting_id in settings: + value = settings[setting_id] + return web.json_response(value) + + @routes.post("/settings") + async def post_settings(request): + settings = self.get_settings(request) + new_settings = await request.json() + self.save_settings(request, {**settings, **new_settings}) + return web.Response(status=200) + + @routes.post("/settings/{id}") + async def post_setting(request): + setting_id = request.match_info.get("id", None) + if not setting_id: + return web.Response(status=400) + settings = self.get_settings(request) + settings[setting_id] = await request.json() + self.save_settings(request, settings) + return web.Response(status=200) \ No newline at end of file diff --git a/app/user_manager.py b/app/user_manager.py new file mode 100644 index 00000000..209094af --- /dev/null +++ b/app/user_manager.py @@ -0,0 +1,140 @@ +import json +import os +import re +import uuid +from aiohttp import web +from comfy.cli_args import args +from folder_paths import user_directory +from .app_settings import AppSettings + +default_user = "default" +users_file = os.path.join(user_directory, "users.json") + + +class UserManager(): + def __init__(self): + global user_directory + + self.settings = AppSettings(self) + if not os.path.exists(user_directory): + os.mkdir(user_directory) + if not args.multi_user: + print("****** User settings have been changed to be stored on the server instead of browser storage. ******") + print("****** For multi-user setups add the --multi-user CLI argument to enable multiple user profiles. ******") + + if args.multi_user: + if os.path.isfile(users_file): + with open(users_file) as f: + self.users = json.load(f) + else: + self.users = {} + else: + self.users = {"default": "default"} + + def get_request_user_id(self, request): + user = "default" + if args.multi_user and "comfy-user" in request.headers: + user = request.headers["comfy-user"] + + if user not in self.users: + raise KeyError("Unknown user: " + user) + + return user + + def get_request_user_filepath(self, request, file, type="userdata", create_dir=True): + global user_directory + + if type == "userdata": + root_dir = user_directory + else: + raise KeyError("Unknown filepath type:" + type) + + user = self.get_request_user_id(request) + path = user_root = os.path.abspath(os.path.join(root_dir, user)) + + # prevent leaving /{type} + if os.path.commonpath((root_dir, user_root)) != root_dir: + return None + + parent = user_root + + if file is not None: + # prevent leaving /{type}/{user} + path = os.path.abspath(os.path.join(user_root, file)) + if os.path.commonpath((user_root, path)) != user_root: + return None + + if create_dir and not os.path.exists(parent): + os.mkdir(parent) + + return path + + def add_user(self, name): + name = name.strip() + if not name: + raise ValueError("username not provided") + user_id = re.sub("[^a-zA-Z0-9-_]+", '-', name) + user_id = user_id + "_" + str(uuid.uuid4()) + + self.users[user_id] = name + + global users_file + with open(users_file, "w") as f: + json.dump(self.users, f) + + return user_id + + def add_routes(self, routes): + self.settings.add_routes(routes) + + @routes.get("/users") + async def get_users(request): + if args.multi_user: + return web.json_response({"storage": "server", "users": self.users}) + else: + user_dir = self.get_request_user_filepath(request, None, create_dir=False) + return web.json_response({ + "storage": "server", + "migrated": os.path.exists(user_dir) + }) + + @routes.post("/users") + async def post_users(request): + body = await request.json() + username = body["username"] + if username in self.users.values(): + return web.json_response({"error": "Duplicate username."}, status=400) + + user_id = self.add_user(username) + return web.json_response(user_id) + + @routes.get("/userdata/{file}") + async def getuserdata(request): + file = request.match_info.get("file", None) + if not file: + return web.Response(status=400) + + path = self.get_request_user_filepath(request, file) + if not path: + return web.Response(status=403) + + if not os.path.exists(path): + return web.Response(status=404) + + return web.FileResponse(path) + + @routes.post("/userdata/{file}") + async def post_userdata(request): + file = request.match_info.get("file", None) + if not file: + return web.Response(status=400) + + path = self.get_request_user_filepath(request, file) + if not path: + return web.Response(status=403) + + body = await request.read() + with open(path, "wb") as f: + f.write(body) + + return web.Response(status=200) diff --git a/comfy/cldm/cldm.py b/comfy/cldm/cldm.py index 76a525b3..5eee5a51 100644 --- a/comfy/cldm/cldm.py +++ b/comfy/cldm/cldm.py @@ -53,7 +53,7 @@ class ControlNet(nn.Module): transformer_depth_middle=None, transformer_depth_output=None, device=None, - operations=comfy.ops, + operations=comfy.ops.disable_weight_init, **kwargs, ): super().__init__() @@ -141,24 +141,24 @@ class ControlNet(nn.Module): ) ] ) - self.zero_convs = nn.ModuleList([self.make_zero_conv(model_channels, operations=operations)]) + self.zero_convs = nn.ModuleList([self.make_zero_conv(model_channels, operations=operations, dtype=self.dtype, device=device)]) self.input_hint_block = TimestepEmbedSequential( - operations.conv_nd(dims, hint_channels, 16, 3, padding=1), + operations.conv_nd(dims, hint_channels, 16, 3, padding=1, dtype=self.dtype, device=device), nn.SiLU(), - operations.conv_nd(dims, 16, 16, 3, padding=1), + operations.conv_nd(dims, 16, 16, 3, padding=1, dtype=self.dtype, device=device), nn.SiLU(), - operations.conv_nd(dims, 16, 32, 3, padding=1, stride=2), + operations.conv_nd(dims, 16, 32, 3, padding=1, stride=2, dtype=self.dtype, device=device), nn.SiLU(), - operations.conv_nd(dims, 32, 32, 3, padding=1), + operations.conv_nd(dims, 32, 32, 3, padding=1, dtype=self.dtype, device=device), nn.SiLU(), - operations.conv_nd(dims, 32, 96, 3, padding=1, stride=2), + operations.conv_nd(dims, 32, 96, 3, padding=1, stride=2, dtype=self.dtype, device=device), nn.SiLU(), - operations.conv_nd(dims, 96, 96, 3, padding=1), + operations.conv_nd(dims, 96, 96, 3, padding=1, dtype=self.dtype, device=device), nn.SiLU(), - operations.conv_nd(dims, 96, 256, 3, padding=1, stride=2), + operations.conv_nd(dims, 96, 256, 3, padding=1, stride=2, dtype=self.dtype, device=device), nn.SiLU(), - zero_module(operations.conv_nd(dims, 256, model_channels, 3, padding=1)) + operations.conv_nd(dims, 256, model_channels, 3, padding=1, dtype=self.dtype, device=device) ) self._feature_size = model_channels @@ -206,7 +206,7 @@ class ControlNet(nn.Module): ) ) self.input_blocks.append(TimestepEmbedSequential(*layers)) - self.zero_convs.append(self.make_zero_conv(ch, operations=operations)) + self.zero_convs.append(self.make_zero_conv(ch, operations=operations, dtype=self.dtype, device=device)) self._feature_size += ch input_block_chans.append(ch) if level != len(channel_mult) - 1: @@ -234,7 +234,7 @@ class ControlNet(nn.Module): ) ch = out_ch input_block_chans.append(ch) - self.zero_convs.append(self.make_zero_conv(ch, operations=operations)) + self.zero_convs.append(self.make_zero_conv(ch, operations=operations, dtype=self.dtype, device=device)) ds *= 2 self._feature_size += ch @@ -276,14 +276,14 @@ class ControlNet(nn.Module): operations=operations )] self.middle_block = TimestepEmbedSequential(*mid_block) - self.middle_block_out = self.make_zero_conv(ch, operations=operations) + self.middle_block_out = self.make_zero_conv(ch, operations=operations, dtype=self.dtype, device=device) self._feature_size += ch - def make_zero_conv(self, channels, operations=None): - return TimestepEmbedSequential(zero_module(operations.conv_nd(self.dims, channels, channels, 1, padding=0))) + def make_zero_conv(self, channels, operations=None, dtype=None, device=None): + return TimestepEmbedSequential(operations.conv_nd(self.dims, channels, channels, 1, padding=0, dtype=dtype, device=device)) def forward(self, x, hint, timesteps, context, y=None, **kwargs): - t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False).to(self.dtype) + t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False).to(x.dtype) emb = self.time_embed(t_emb) guided_hint = self.input_hint_block(hint, emb, context) @@ -295,7 +295,7 @@ class ControlNet(nn.Module): assert y.shape[0] == x.shape[0] emb = emb + self.label_emb(y) - h = x.type(self.dtype) + h = x for module, zero_conv in zip(self.input_blocks, self.zero_convs): if guided_hint is not None: h = module(h, emb, context) diff --git a/comfy/cli_args.py b/comfy/cli_args.py index 72fce108..353bb51e 100644 --- a/comfy/cli_args.py +++ b/comfy/cli_args.py @@ -55,13 +55,19 @@ fp_group = parser.add_mutually_exclusive_group() fp_group.add_argument("--force-fp32", action="store_true", help="Force fp32 (If this makes your GPU work better please report it).") fp_group.add_argument("--force-fp16", action="store_true", help="Force fp16.") -parser.add_argument("--bf16-unet", action="store_true", help="Run the UNET in bf16. This should only be used for testing stuff.") +fpunet_group = parser.add_mutually_exclusive_group() +fpunet_group.add_argument("--bf16-unet", action="store_true", help="Run the UNET in bf16. This should only be used for testing stuff.") +fpunet_group.add_argument("--fp16-unet", action="store_true", help="Store unet weights in fp16.") +fpunet_group.add_argument("--fp8_e4m3fn-unet", action="store_true", help="Store unet weights in fp8_e4m3fn.") +fpunet_group.add_argument("--fp8_e5m2-unet", action="store_true", help="Store unet weights in fp8_e5m2.") fpvae_group = parser.add_mutually_exclusive_group() fpvae_group.add_argument("--fp16-vae", action="store_true", help="Run the VAE in fp16, might cause black images.") fpvae_group.add_argument("--fp32-vae", action="store_true", help="Run the VAE in full precision fp32.") fpvae_group.add_argument("--bf16-vae", action="store_true", help="Run the VAE in bf16.") +parser.add_argument("--cpu-vae", action="store_true", help="Run the VAE on the CPU.") + fpte_group = parser.add_mutually_exclusive_group() fpte_group.add_argument("--fp8_e4m3fn-text-enc", action="store_true", help="Store text encoder weights in fp8 (e4m3fn variant).") fpte_group.add_argument("--fp8_e5m2-text-enc", action="store_true", help="Store text encoder weights in fp8 (e5m2 variant).") @@ -98,7 +104,7 @@ vram_group.add_argument("--cpu", action="store_true", help="To use the CPU for e parser.add_argument("--disable-smart-memory", action="store_true", help="Force ComfyUI to agressively offload to regular ram instead of keeping models in vram when it can.") - +parser.add_argument("--deterministic", action="store_true", help="Make pytorch use slower deterministic algorithms when it can. Note that this might not make images deterministic in all cases.") parser.add_argument("--dont-print-server", action="store_true", help="Don't print server output.") parser.add_argument("--quick-test-for-ci", action="store_true", help="Quick test for CI.") @@ -106,6 +112,11 @@ parser.add_argument("--windows-standalone-build", action="store_true", help="Win parser.add_argument("--disable-metadata", action="store_true", help="Disable saving prompt metadata in files.") +parser.add_argument("--multi-user", action="store_true", help="Enables per-user storage.") + +parser.add_argument("--verbose", action="store_true", help="Enables more debug prints.") + + if comfy.options.args_parsing: args = parser.parse_args() else: @@ -116,3 +127,10 @@ if args.windows_standalone_build: if args.disable_auto_launch: args.auto_launch = False + +import logging +logging_level = logging.INFO +if args.verbose: + logging_level = logging.DEBUG + +logging.basicConfig(format="%(message)s", level=logging_level) diff --git a/comfy/clip_model.py b/comfy/clip_model.py new file mode 100644 index 00000000..14f43c56 --- /dev/null +++ b/comfy/clip_model.py @@ -0,0 +1,194 @@ +import torch +from comfy.ldm.modules.attention import optimized_attention_for_device + +class CLIPAttention(torch.nn.Module): + def __init__(self, embed_dim, heads, dtype, device, operations): + super().__init__() + + self.heads = heads + self.q_proj = operations.Linear(embed_dim, embed_dim, bias=True, dtype=dtype, device=device) + self.k_proj = operations.Linear(embed_dim, embed_dim, bias=True, dtype=dtype, device=device) + self.v_proj = operations.Linear(embed_dim, embed_dim, bias=True, dtype=dtype, device=device) + + self.out_proj = operations.Linear(embed_dim, embed_dim, bias=True, dtype=dtype, device=device) + + def forward(self, x, mask=None, optimized_attention=None): + q = self.q_proj(x) + k = self.k_proj(x) + v = self.v_proj(x) + + out = optimized_attention(q, k, v, self.heads, mask) + return self.out_proj(out) + +ACTIVATIONS = {"quick_gelu": lambda a: a * torch.sigmoid(1.702 * a), + "gelu": torch.nn.functional.gelu, +} + +class CLIPMLP(torch.nn.Module): + def __init__(self, embed_dim, intermediate_size, activation, dtype, device, operations): + super().__init__() + self.fc1 = operations.Linear(embed_dim, intermediate_size, bias=True, dtype=dtype, device=device) + self.activation = ACTIVATIONS[activation] + self.fc2 = operations.Linear(intermediate_size, embed_dim, bias=True, dtype=dtype, device=device) + + def forward(self, x): + x = self.fc1(x) + x = self.activation(x) + x = self.fc2(x) + return x + +class CLIPLayer(torch.nn.Module): + def __init__(self, embed_dim, heads, intermediate_size, intermediate_activation, dtype, device, operations): + super().__init__() + self.layer_norm1 = operations.LayerNorm(embed_dim, dtype=dtype, device=device) + self.self_attn = CLIPAttention(embed_dim, heads, dtype, device, operations) + self.layer_norm2 = operations.LayerNorm(embed_dim, dtype=dtype, device=device) + self.mlp = CLIPMLP(embed_dim, intermediate_size, intermediate_activation, dtype, device, operations) + + def forward(self, x, mask=None, optimized_attention=None): + x += self.self_attn(self.layer_norm1(x), mask, optimized_attention) + x += self.mlp(self.layer_norm2(x)) + return x + + +class CLIPEncoder(torch.nn.Module): + def __init__(self, num_layers, embed_dim, heads, intermediate_size, intermediate_activation, dtype, device, operations): + super().__init__() + self.layers = torch.nn.ModuleList([CLIPLayer(embed_dim, heads, intermediate_size, intermediate_activation, dtype, device, operations) for i in range(num_layers)]) + + def forward(self, x, mask=None, intermediate_output=None): + optimized_attention = optimized_attention_for_device(x.device, mask=mask is not None, small_input=True) + + if intermediate_output is not None: + if intermediate_output < 0: + intermediate_output = len(self.layers) + intermediate_output + + intermediate = None + for i, l in enumerate(self.layers): + x = l(x, mask, optimized_attention) + if i == intermediate_output: + intermediate = x.clone() + return x, intermediate + +class CLIPEmbeddings(torch.nn.Module): + def __init__(self, embed_dim, vocab_size=49408, num_positions=77, dtype=None, device=None): + super().__init__() + self.token_embedding = torch.nn.Embedding(vocab_size, embed_dim, dtype=dtype, device=device) + self.position_embedding = torch.nn.Embedding(num_positions, embed_dim, dtype=dtype, device=device) + + def forward(self, input_tokens): + return self.token_embedding(input_tokens) + self.position_embedding.weight + + +class CLIPTextModel_(torch.nn.Module): + def __init__(self, config_dict, dtype, device, operations): + num_layers = config_dict["num_hidden_layers"] + embed_dim = config_dict["hidden_size"] + heads = config_dict["num_attention_heads"] + intermediate_size = config_dict["intermediate_size"] + intermediate_activation = config_dict["hidden_act"] + + super().__init__() + self.embeddings = CLIPEmbeddings(embed_dim, dtype=torch.float32, device=device) + self.encoder = CLIPEncoder(num_layers, embed_dim, heads, intermediate_size, intermediate_activation, dtype, device, operations) + self.final_layer_norm = operations.LayerNorm(embed_dim, dtype=dtype, device=device) + + def forward(self, input_tokens, attention_mask=None, intermediate_output=None, final_layer_norm_intermediate=True): + x = self.embeddings(input_tokens) + mask = None + if attention_mask is not None: + mask = 1.0 - attention_mask.to(x.dtype).reshape((attention_mask.shape[0], 1, -1, attention_mask.shape[-1])).expand(attention_mask.shape[0], 1, attention_mask.shape[-1], attention_mask.shape[-1]) + mask = mask.masked_fill(mask.to(torch.bool), float("-inf")) + + causal_mask = torch.empty(x.shape[1], x.shape[1], dtype=x.dtype, device=x.device).fill_(float("-inf")).triu_(1) + if mask is not None: + mask += causal_mask + else: + mask = causal_mask + + x, i = self.encoder(x, mask=mask, intermediate_output=intermediate_output) + x = self.final_layer_norm(x) + if i is not None and final_layer_norm_intermediate: + i = self.final_layer_norm(i) + + pooled_output = x[torch.arange(x.shape[0], device=x.device), input_tokens.to(dtype=torch.int, device=x.device).argmax(dim=-1),] + return x, i, pooled_output + +class CLIPTextModel(torch.nn.Module): + def __init__(self, config_dict, dtype, device, operations): + super().__init__() + self.num_layers = config_dict["num_hidden_layers"] + self.text_model = CLIPTextModel_(config_dict, dtype, device, operations) + embed_dim = config_dict["hidden_size"] + self.text_projection = operations.Linear(embed_dim, embed_dim, bias=False, dtype=dtype, device=device) + self.text_projection.weight.copy_(torch.eye(embed_dim)) + self.dtype = dtype + + def get_input_embeddings(self): + return self.text_model.embeddings.token_embedding + + def set_input_embeddings(self, embeddings): + self.text_model.embeddings.token_embedding = embeddings + + def forward(self, *args, **kwargs): + x = self.text_model(*args, **kwargs) + out = self.text_projection(x[2]) + return (x[0], x[1], out, x[2]) + + +class CLIPVisionEmbeddings(torch.nn.Module): + def __init__(self, embed_dim, num_channels=3, patch_size=14, image_size=224, dtype=None, device=None, operations=None): + super().__init__() + self.class_embedding = torch.nn.Parameter(torch.empty(embed_dim, dtype=dtype, device=device)) + + self.patch_embedding = operations.Conv2d( + in_channels=num_channels, + out_channels=embed_dim, + kernel_size=patch_size, + stride=patch_size, + bias=False, + dtype=dtype, + device=device + ) + + num_patches = (image_size // patch_size) ** 2 + num_positions = num_patches + 1 + self.position_embedding = torch.nn.Embedding(num_positions, embed_dim, dtype=dtype, device=device) + + def forward(self, pixel_values): + embeds = self.patch_embedding(pixel_values).flatten(2).transpose(1, 2) + return torch.cat([self.class_embedding.to(embeds.device).expand(pixel_values.shape[0], 1, -1), embeds], dim=1) + self.position_embedding.weight.to(embeds.device) + + +class CLIPVision(torch.nn.Module): + def __init__(self, config_dict, dtype, device, operations): + super().__init__() + num_layers = config_dict["num_hidden_layers"] + embed_dim = config_dict["hidden_size"] + heads = config_dict["num_attention_heads"] + intermediate_size = config_dict["intermediate_size"] + intermediate_activation = config_dict["hidden_act"] + + self.embeddings = CLIPVisionEmbeddings(embed_dim, config_dict["num_channels"], config_dict["patch_size"], config_dict["image_size"], dtype=torch.float32, device=device, operations=operations) + self.pre_layrnorm = operations.LayerNorm(embed_dim) + self.encoder = CLIPEncoder(num_layers, embed_dim, heads, intermediate_size, intermediate_activation, dtype, device, operations) + self.post_layernorm = operations.LayerNorm(embed_dim) + + def forward(self, pixel_values, attention_mask=None, intermediate_output=None): + x = self.embeddings(pixel_values) + x = self.pre_layrnorm(x) + #TODO: attention_mask? + x, i = self.encoder(x, mask=None, intermediate_output=intermediate_output) + pooled_output = self.post_layernorm(x[:, 0, :]) + return x, i, pooled_output + +class CLIPVisionModelProjection(torch.nn.Module): + def __init__(self, config_dict, dtype, device, operations): + super().__init__() + self.vision_model = CLIPVision(config_dict, dtype, device, operations) + self.visual_projection = operations.Linear(config_dict["hidden_size"], config_dict["projection_dim"], bias=False) + + def forward(self, *args, **kwargs): + x = self.vision_model(*args, **kwargs) + out = self.visual_projection(x[2]) + return (x[0], x[1], out) diff --git a/comfy/clip_vision.py b/comfy/clip_vision.py index 9e2e03d7..acc86be8 100644 --- a/comfy/clip_vision.py +++ b/comfy/clip_vision.py @@ -1,64 +1,62 @@ -from transformers import CLIPVisionModelWithProjection, CLIPVisionConfig, modeling_utils -from .utils import load_torch_file, transformers_convert, common_upscale +from .utils import load_torch_file, transformers_convert, state_dict_prefix_replace import os import torch -import contextlib +import json +import logging import comfy.ops import comfy.model_patcher import comfy.model_management import comfy.utils +import comfy.clip_model + +class Output: + def __getitem__(self, key): + return getattr(self, key) + def __setitem__(self, key, item): + setattr(self, key, item) def clip_preprocess(image, size=224): mean = torch.tensor([ 0.48145466,0.4578275,0.40821073], device=image.device, dtype=image.dtype) std = torch.tensor([0.26862954,0.26130258,0.27577711], device=image.device, dtype=image.dtype) - scale = (size / min(image.shape[1], image.shape[2])) - image = torch.nn.functional.interpolate(image.movedim(-1, 1), size=(round(scale * image.shape[1]), round(scale * image.shape[2])), mode="bicubic", antialias=True) - h = (image.shape[2] - size)//2 - w = (image.shape[3] - size)//2 - image = image[:,:,h:h+size,w:w+size] + image = image.movedim(-1, 1) + if not (image.shape[2] == size and image.shape[3] == size): + scale = (size / min(image.shape[2], image.shape[3])) + image = torch.nn.functional.interpolate(image, size=(round(scale * image.shape[2]), round(scale * image.shape[3])), mode="bicubic", antialias=True) + h = (image.shape[2] - size)//2 + w = (image.shape[3] - size)//2 + image = image[:,:,h:h+size,w:w+size] image = torch.clip((255. * image), 0, 255).round() / 255.0 return (image - mean.view([3,1,1])) / std.view([3,1,1]) class ClipVisionModel(): def __init__(self, json_config): - config = CLIPVisionConfig.from_json_file(json_config) + with open(json_config) as f: + config = json.load(f) + self.load_device = comfy.model_management.text_encoder_device() offload_device = comfy.model_management.text_encoder_offload_device() - self.dtype = torch.float32 - if comfy.model_management.should_use_fp16(self.load_device, prioritize_performance=False): - self.dtype = torch.float16 - - with comfy.ops.use_comfy_ops(offload_device, self.dtype): - with modeling_utils.no_init_weights(): - self.model = CLIPVisionModelWithProjection(config) - self.model.to(self.dtype) + self.dtype = comfy.model_management.text_encoder_dtype(self.load_device) + self.model = comfy.clip_model.CLIPVisionModelProjection(config, self.dtype, offload_device, comfy.ops.manual_cast) + self.model.eval() self.patcher = comfy.model_patcher.ModelPatcher(self.model, load_device=self.load_device, offload_device=offload_device) + def load_sd(self, sd): return self.model.load_state_dict(sd, strict=False) + def get_sd(self): + return self.model.state_dict() + def encode_image(self, image): comfy.model_management.load_model_gpu(self.patcher) - pixel_values = clip_preprocess(image.to(self.load_device)) - - if self.dtype != torch.float32: - precision_scope = torch.autocast - else: - precision_scope = lambda a, b: contextlib.nullcontext(a) - - with precision_scope(comfy.model_management.get_autocast_device(self.load_device), torch.float32): - outputs = self.model(pixel_values=pixel_values, output_hidden_states=True) - - for k in outputs: - t = outputs[k] - if t is not None: - if k == 'hidden_states': - outputs["penultimate_hidden_states"] = t[-2].cpu() - outputs["hidden_states"] = None - else: - outputs[k] = t.cpu() + pixel_values = clip_preprocess(image.to(self.load_device)).float() + out = self.model(pixel_values=pixel_values, intermediate_output=-2) + outputs = Output() + outputs["last_hidden_state"] = out[0].to(comfy.model_management.intermediate_device()) + outputs["image_embeds"] = out[2].to(comfy.model_management.intermediate_device()) + outputs["penultimate_hidden_states"] = out[1].to(comfy.model_management.intermediate_device()) return outputs def convert_to_transformers(sd, prefix): @@ -82,6 +80,9 @@ def convert_to_transformers(sd, prefix): sd['visual_projection.weight'] = sd.pop("{}proj".format(prefix)).transpose(0, 1) sd = transformers_convert(sd, prefix, "vision_model.", 48) + else: + replace_prefix = {prefix: ""} + sd = state_dict_prefix_replace(sd, replace_prefix) return sd def load_clipvision_from_sd(sd, prefix="", convert_keys=False): @@ -99,7 +100,7 @@ def load_clipvision_from_sd(sd, prefix="", convert_keys=False): clip = ClipVisionModel(json_config) m, u = clip.load_sd(sd) if len(m) > 0: - print("missing clip vision:", m) + logging.warning("missing clip vision: {}".format(m)) u = set(u) keys = list(sd.keys()) for k in keys: diff --git a/comfy/conds.py b/comfy/conds.py index 6cff2518..23fa4887 100644 --- a/comfy/conds.py +++ b/comfy/conds.py @@ -1,4 +1,3 @@ -import enum import torch import math import comfy.utils diff --git a/comfy/controlnet.py b/comfy/controlnet.py index 433381df..b6941d8c 100644 --- a/comfy/controlnet.py +++ b/comfy/controlnet.py @@ -1,13 +1,16 @@ import torch import math import os +import logging import comfy.utils import comfy.model_management import comfy.model_detection import comfy.model_patcher +import comfy.ops import comfy.cldm.cldm import comfy.t2i_adapter.adapter +import comfy.ldm.cascade.controlnet def broadcast_image_to(tensor, target_batch_size, batched_number): @@ -34,13 +37,15 @@ class ControlBase: self.cond_hint = None self.strength = 1.0 self.timestep_percent_range = (0.0, 1.0) + self.global_average_pooling = False self.timestep_range = None + self.compression_ratio = 8 + self.upscale_algorithm = 'nearest-exact' if device is None: device = comfy.model_management.get_torch_device() self.device = device self.previous_controlnet = None - self.global_average_pooling = False def set_cond_hint(self, cond_hint, strength=1.0, timestep_percent_range=(0.0, 1.0)): self.cond_hint_original = cond_hint @@ -75,6 +80,9 @@ class ControlBase: c.cond_hint_original = self.cond_hint_original c.strength = self.strength c.timestep_percent_range = self.timestep_percent_range + c.global_average_pooling = self.global_average_pooling + c.compression_ratio = self.compression_ratio + c.upscale_algorithm = self.upscale_algorithm def inference_memory_requirements(self, dtype): if self.previous_controlnet is not None: @@ -123,16 +131,21 @@ class ControlBase: if o[i] is None: o[i] = prev_val else: - o[i] += prev_val + if o[i].shape[0] < prev_val.shape[0]: + o[i] = prev_val + o[i] + else: + o[i] += prev_val return out class ControlNet(ControlBase): - def __init__(self, control_model, global_average_pooling=False, device=None): + def __init__(self, control_model, global_average_pooling=False, device=None, load_device=None, manual_cast_dtype=None): super().__init__(device) self.control_model = control_model - self.control_model_wrapped = comfy.model_patcher.ModelPatcher(self.control_model, load_device=comfy.model_management.get_torch_device(), offload_device=comfy.model_management.unet_offload_device()) + self.load_device = load_device + self.control_model_wrapped = comfy.model_patcher.ModelPatcher(self.control_model, load_device=load_device, offload_device=comfy.model_management.unet_offload_device()) self.global_average_pooling = global_average_pooling self.model_sampling_current = None + self.manual_cast_dtype = manual_cast_dtype def get_control(self, x_noisy, t, cond, batched_number): control_prev = None @@ -146,28 +159,31 @@ class ControlNet(ControlBase): else: return None + dtype = self.control_model.dtype + if self.manual_cast_dtype is not None: + dtype = self.manual_cast_dtype + output_dtype = x_noisy.dtype - if self.cond_hint is None or x_noisy.shape[2] * 8 != self.cond_hint.shape[2] or x_noisy.shape[3] * 8 != self.cond_hint.shape[3]: + if self.cond_hint is None or x_noisy.shape[2] * self.compression_ratio != self.cond_hint.shape[2] or x_noisy.shape[3] * self.compression_ratio != self.cond_hint.shape[3]: if self.cond_hint is not None: del self.cond_hint self.cond_hint = None - self.cond_hint = comfy.utils.common_upscale(self.cond_hint_original, x_noisy.shape[3] * 8, x_noisy.shape[2] * 8, 'nearest-exact', "center").to(self.control_model.dtype).to(self.device) + self.cond_hint = comfy.utils.common_upscale(self.cond_hint_original, x_noisy.shape[3] * self.compression_ratio, x_noisy.shape[2] * self.compression_ratio, self.upscale_algorithm, "center").to(dtype).to(self.device) if x_noisy.shape[0] != self.cond_hint.shape[0]: self.cond_hint = broadcast_image_to(self.cond_hint, x_noisy.shape[0], batched_number) - - context = cond['c_crossattn'] + context = cond.get('crossattn_controlnet', cond['c_crossattn']) y = cond.get('y', None) if y is not None: - y = y.to(self.control_model.dtype) + y = y.to(dtype) timestep = self.model_sampling_current.timestep(t) x_noisy = self.model_sampling_current.calculate_input(t, x_noisy) - control = self.control_model(x=x_noisy.to(self.control_model.dtype), hint=self.cond_hint, timesteps=timestep.float(), context=context.to(self.control_model.dtype), y=y) + control = self.control_model(x=x_noisy.to(dtype), hint=self.cond_hint, timesteps=timestep.float(), context=context.to(dtype), y=y) return self.control_merge(None, control, control_prev, output_dtype) def copy(self): - c = ControlNet(self.control_model, global_average_pooling=self.global_average_pooling) + c = ControlNet(self.control_model, global_average_pooling=self.global_average_pooling, load_device=self.load_device, manual_cast_dtype=self.manual_cast_dtype) self.copy_to(c) return c @@ -185,7 +201,7 @@ class ControlNet(ControlBase): super().cleanup() class ControlLoraOps: - class Linear(torch.nn.Module): + class Linear(torch.nn.Module, comfy.ops.CastWeightBiasOp): def __init__(self, in_features: int, out_features: int, bias: bool = True, device=None, dtype=None) -> None: factory_kwargs = {'device': device, 'dtype': dtype} @@ -198,12 +214,13 @@ class ControlLoraOps: self.bias = None def forward(self, input): + weight, bias = comfy.ops.cast_bias_weight(self, input) if self.up is not None: - return torch.nn.functional.linear(input, self.weight.to(input.device) + (torch.mm(self.up.flatten(start_dim=1), self.down.flatten(start_dim=1))).reshape(self.weight.shape).type(input.dtype), self.bias) + return torch.nn.functional.linear(input, weight + (torch.mm(self.up.flatten(start_dim=1), self.down.flatten(start_dim=1))).reshape(self.weight.shape).type(input.dtype), bias) else: - return torch.nn.functional.linear(input, self.weight.to(input.device), self.bias) + return torch.nn.functional.linear(input, weight, bias) - class Conv2d(torch.nn.Module): + class Conv2d(torch.nn.Module, comfy.ops.CastWeightBiasOp): def __init__( self, in_channels, @@ -237,16 +254,11 @@ class ControlLoraOps: def forward(self, input): + weight, bias = comfy.ops.cast_bias_weight(self, input) if self.up is not None: - return torch.nn.functional.conv2d(input, self.weight.to(input.device) + (torch.mm(self.up.flatten(start_dim=1), self.down.flatten(start_dim=1))).reshape(self.weight.shape).type(input.dtype), self.bias, self.stride, self.padding, self.dilation, self.groups) + return torch.nn.functional.conv2d(input, weight + (torch.mm(self.up.flatten(start_dim=1), self.down.flatten(start_dim=1))).reshape(self.weight.shape).type(input.dtype), bias, self.stride, self.padding, self.dilation, self.groups) else: - return torch.nn.functional.conv2d(input, self.weight.to(input.device), self.bias, self.stride, self.padding, self.dilation, self.groups) - - def conv_nd(self, dims, *args, **kwargs): - if dims == 2: - return self.Conv2d(*args, **kwargs) - else: - raise ValueError(f"unsupported dimensions: {dims}") + return torch.nn.functional.conv2d(input, weight, bias, self.stride, self.padding, self.dilation, self.groups) class ControlLora(ControlNet): @@ -260,25 +272,34 @@ class ControlLora(ControlNet): controlnet_config = model.model_config.unet_config.copy() controlnet_config.pop("out_channels") controlnet_config["hint_channels"] = self.control_weights["input_hint_block.0.weight"].shape[1] - controlnet_config["operations"] = ControlLoraOps() - self.control_model = comfy.cldm.cldm.ControlNet(**controlnet_config) + self.manual_cast_dtype = model.manual_cast_dtype dtype = model.get_dtype() - self.control_model.to(dtype) + if self.manual_cast_dtype is None: + class control_lora_ops(ControlLoraOps, comfy.ops.disable_weight_init): + pass + else: + class control_lora_ops(ControlLoraOps, comfy.ops.manual_cast): + pass + dtype = self.manual_cast_dtype + + controlnet_config["operations"] = control_lora_ops + controlnet_config["dtype"] = dtype + self.control_model = comfy.cldm.cldm.ControlNet(**controlnet_config) self.control_model.to(comfy.model_management.get_torch_device()) diffusion_model = model.diffusion_model sd = diffusion_model.state_dict() cm = self.control_model.state_dict() for k in sd: - weight = comfy.model_management.resolve_lowvram_weight(sd[k], diffusion_model, k) + weight = sd[k] try: - comfy.utils.set_attr(self.control_model, k, weight) + comfy.utils.set_attr_param(self.control_model, k, weight) except: pass for k in self.control_weights: if k not in {"lora_controlnet"}: - comfy.utils.set_attr(self.control_model, k, self.control_weights[k].to(dtype).to(comfy.model_management.get_torch_device())) + comfy.utils.set_attr_param(self.control_model, k, self.control_weights[k].to(dtype).to(comfy.model_management.get_torch_device())) def copy(self): c = ControlLora(self.control_weights, global_average_pooling=self.global_average_pooling) @@ -303,9 +324,10 @@ def load_controlnet(ckpt_path, model=None): return ControlLora(controlnet_data) controlnet_config = None + supported_inference_dtypes = None + if "controlnet_cond_embedding.conv_in.weight" in controlnet_data: #diffusers format - unet_dtype = comfy.model_management.unet_dtype() - controlnet_config = comfy.model_detection.unet_config_from_diffusers_unet(controlnet_data, unet_dtype) + controlnet_config = comfy.model_detection.unet_config_from_diffusers_unet(controlnet_data) diffusers_keys = comfy.utils.unet_to_diffusers(controlnet_config) diffusers_keys["controlnet_mid_block.weight"] = "middle_block_out.0.weight" diffusers_keys["controlnet_mid_block.bias"] = "middle_block_out.0.bias" @@ -346,7 +368,7 @@ def load_controlnet(ckpt_path, model=None): leftover_keys = controlnet_data.keys() if len(leftover_keys) > 0: - print("leftover keys:", leftover_keys) + logging.warning("leftover keys: {}".format(leftover_keys)) controlnet_data = new_sd pth_key = 'control_model.zero_convs.0.0.weight' @@ -361,12 +383,24 @@ def load_controlnet(ckpt_path, model=None): else: net = load_t2i_adapter(controlnet_data) if net is None: - print("error checkpoint does not contain controlnet or t2i adapter data", ckpt_path) + logging.error("error checkpoint does not contain controlnet or t2i adapter data {}".format(ckpt_path)) return net if controlnet_config is None: + model_config = comfy.model_detection.model_config_from_unet(controlnet_data, prefix, True) + supported_inference_dtypes = model_config.supported_inference_dtypes + controlnet_config = model_config.unet_config + + load_device = comfy.model_management.get_torch_device() + if supported_inference_dtypes is None: unet_dtype = comfy.model_management.unet_dtype() - controlnet_config = comfy.model_detection.model_config_from_unet(controlnet_data, prefix, unet_dtype, True).unet_config + else: + unet_dtype = comfy.model_management.unet_dtype(supported_dtypes=supported_inference_dtypes) + + manual_cast_dtype = comfy.model_management.unet_manual_cast(unet_dtype, load_device) + if manual_cast_dtype is not None: + controlnet_config["operations"] = comfy.ops.manual_cast + controlnet_config["dtype"] = unet_dtype controlnet_config.pop("out_channels") controlnet_config["hint_channels"] = controlnet_data["{}input_hint_block.0.weight".format(prefix)].shape[1] control_model = comfy.cldm.cldm.ControlNet(**controlnet_config) @@ -384,7 +418,7 @@ def load_controlnet(ckpt_path, model=None): cd = controlnet_data[x] cd += model_sd[sd_key].type(cd.dtype).to(cd.device) else: - print("WARNING: Loaded a diff controlnet without a model. It will very likely not work.") + logging.warning("WARNING: Loaded a diff controlnet without a model. It will very likely not work.") class WeightsLoader(torch.nn.Module): pass @@ -393,24 +427,29 @@ def load_controlnet(ckpt_path, model=None): missing, unexpected = w.load_state_dict(controlnet_data, strict=False) else: missing, unexpected = control_model.load_state_dict(controlnet_data, strict=False) - print(missing, unexpected) - control_model = control_model.to(unet_dtype) + if len(missing) > 0: + logging.warning("missing controlnet keys: {}".format(missing)) + + if len(unexpected) > 0: + logging.debug("unexpected controlnet keys: {}".format(unexpected)) global_average_pooling = False filename = os.path.splitext(ckpt_path)[0] if filename.endswith("_shuffle") or filename.endswith("_shuffle_fp16"): #TODO: smarter way of enabling global_average_pooling global_average_pooling = True - control = ControlNet(control_model, global_average_pooling=global_average_pooling) + control = ControlNet(control_model, global_average_pooling=global_average_pooling, load_device=load_device, manual_cast_dtype=manual_cast_dtype) return control class T2IAdapter(ControlBase): - def __init__(self, t2i_model, channels_in, device=None): + def __init__(self, t2i_model, channels_in, compression_ratio, upscale_algorithm, device=None): super().__init__(device) self.t2i_model = t2i_model self.channels_in = channels_in self.control_input = None + self.compression_ratio = compression_ratio + self.upscale_algorithm = upscale_algorithm def scale_image_to(self, width, height): unshuffle_amount = self.t2i_model.unshuffle_amount @@ -430,13 +469,13 @@ class T2IAdapter(ControlBase): else: return None - if self.cond_hint is None or x_noisy.shape[2] * 8 != self.cond_hint.shape[2] or x_noisy.shape[3] * 8 != self.cond_hint.shape[3]: + if self.cond_hint is None or x_noisy.shape[2] * self.compression_ratio != self.cond_hint.shape[2] or x_noisy.shape[3] * self.compression_ratio != self.cond_hint.shape[3]: if self.cond_hint is not None: del self.cond_hint self.control_input = None self.cond_hint = None - width, height = self.scale_image_to(x_noisy.shape[3] * 8, x_noisy.shape[2] * 8) - self.cond_hint = comfy.utils.common_upscale(self.cond_hint_original, width, height, 'nearest-exact', "center").float().to(self.device) + width, height = self.scale_image_to(x_noisy.shape[3] * self.compression_ratio, x_noisy.shape[2] * self.compression_ratio) + self.cond_hint = comfy.utils.common_upscale(self.cond_hint_original, width, height, self.upscale_algorithm, "center").float().to(self.device) if self.channels_in == 1 and self.cond_hint.shape[1] > 1: self.cond_hint = torch.mean(self.cond_hint, 1, keepdim=True) if x_noisy.shape[0] != self.cond_hint.shape[0]: @@ -455,11 +494,14 @@ class T2IAdapter(ControlBase): return self.control_merge(control_input, mid, control_prev, x_noisy.dtype) def copy(self): - c = T2IAdapter(self.t2i_model, self.channels_in) + c = T2IAdapter(self.t2i_model, self.channels_in, self.compression_ratio, self.upscale_algorithm) self.copy_to(c) return c def load_t2i_adapter(t2i_data): + compression_ratio = 8 + upscale_algorithm = 'nearest-exact' + if 'adapter' in t2i_data: t2i_data = t2i_data['adapter'] if 'adapter.body.0.resnets.0.block1.weight' in t2i_data: #diffusers format @@ -487,13 +529,22 @@ def load_t2i_adapter(t2i_data): if cin == 256 or cin == 768: xl = True model_ad = comfy.t2i_adapter.adapter.Adapter(cin=cin, channels=[channel, channel*2, channel*4, channel*4][:4], nums_rb=2, ksize=ksize, sk=True, use_conv=use_conv, xl=xl) + elif "backbone.0.0.weight" in keys: + model_ad = comfy.ldm.cascade.controlnet.ControlNet(c_in=t2i_data['backbone.0.0.weight'].shape[1], proj_blocks=[0, 4, 8, 12, 51, 55, 59, 63]) + compression_ratio = 32 + upscale_algorithm = 'bilinear' + elif "backbone.10.blocks.0.weight" in keys: + model_ad = comfy.ldm.cascade.controlnet.ControlNet(c_in=t2i_data['backbone.0.weight'].shape[1], bottleneck_mode="large", proj_blocks=[0, 4, 8, 12, 51, 55, 59, 63]) + compression_ratio = 1 + upscale_algorithm = 'nearest-exact' else: return None + missing, unexpected = model_ad.load_state_dict(t2i_data) if len(missing) > 0: - print("t2i missing", missing) + logging.warning("t2i missing {}".format(missing)) if len(unexpected) > 0: - print("t2i unexpected", unexpected) + logging.debug("t2i unexpected {}".format(unexpected)) - return T2IAdapter(model_ad, model_ad.input_channels) + return T2IAdapter(model_ad, model_ad.input_channels, compression_ratio, upscale_algorithm) diff --git a/comfy/diffusers_convert.py b/comfy/diffusers_convert.py index a9eb9302..08018c54 100644 --- a/comfy/diffusers_convert.py +++ b/comfy/diffusers_convert.py @@ -1,5 +1,6 @@ import re import torch +import logging # conversion code from https://github.com/huggingface/diffusers/blob/main/scripts/convert_diffusers_to_original_stable_diffusion.py @@ -177,7 +178,7 @@ def convert_vae_state_dict(vae_state_dict): for k, v in new_state_dict.items(): for weight_name in weights_to_convert: if f"mid.attn_1.{weight_name}.weight" in k: - print(f"Reshaping {k} for SD format") + logging.debug(f"Reshaping {k} for SD format") new_state_dict[k] = reshape_weight_for_sd(v) return new_state_dict @@ -237,8 +238,12 @@ def convert_text_enc_state_dict_v20(text_enc_dict, prefix=""): capture_qkv_bias[k_pre][code2idx[k_code]] = v continue - relabelled_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], k) - new_state_dict[relabelled_key] = v + text_proj = "transformer.text_projection.weight" + if k.endswith(text_proj): + new_state_dict[k.replace(text_proj, "text_projection")] = v.transpose(0, 1).contiguous() + else: + relabelled_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], k) + new_state_dict[relabelled_key] = v for k_pre, tensors in capture_qkv_weight.items(): if None in tensors: diff --git a/comfy/diffusers_load.py b/comfy/diffusers_load.py index c0b420e7..98b888a1 100644 --- a/comfy/diffusers_load.py +++ b/comfy/diffusers_load.py @@ -1,4 +1,3 @@ -import json import os import comfy.sd diff --git a/comfy/extra_samplers/uni_pc.py b/comfy/extra_samplers/uni_pc.py index 08bf0fc9..a30d1d03 100644 --- a/comfy/extra_samplers/uni_pc.py +++ b/comfy/extra_samplers/uni_pc.py @@ -358,9 +358,6 @@ class UniPC: thresholding=False, max_val=1., variant='bh1', - noise_mask=None, - masked_image=None, - noise=None, ): """Construct a UniPC. @@ -372,9 +369,6 @@ class UniPC: self.predict_x0 = predict_x0 self.thresholding = thresholding self.max_val = max_val - self.noise_mask = noise_mask - self.masked_image = masked_image - self.noise = noise def dynamic_thresholding_fn(self, x0, t=None): """ @@ -391,10 +385,7 @@ class UniPC: """ Return the noise prediction model. """ - if self.noise_mask is not None: - return self.model(x, t) * self.noise_mask - else: - return self.model(x, t) + return self.model(x, t) def data_prediction_fn(self, x, t): """ @@ -409,8 +400,6 @@ class UniPC: s = torch.quantile(torch.abs(x0).reshape((x0.shape[0], -1)), p, dim=1) s = expand_dims(torch.maximum(s, self.max_val * torch.ones_like(s).to(s.device)), dims) x0 = torch.clamp(x0, -s, s) / s - if self.noise_mask is not None: - x0 = x0 * self.noise_mask + (1. - self.noise_mask) * self.masked_image return x0 def model_fn(self, x, t): @@ -723,8 +712,6 @@ class UniPC: assert timesteps.shape[0] - 1 == steps # with torch.no_grad(): for step_index in trange(steps, disable=disable_pbar): - if self.noise_mask is not None: - x = x * self.noise_mask + (1. - self.noise_mask) * (self.masked_image * self.noise_schedule.marginal_alpha(timesteps[step_index]) + self.noise * self.noise_schedule.marginal_std(timesteps[step_index])) if step_index == 0: vec_t = timesteps[0].expand((x.shape[0])) model_prev_list = [self.model_fn(x, vec_t)] @@ -766,7 +753,7 @@ class UniPC: model_x = self.model_fn(x, vec_t) model_prev_list[-1] = model_x if callback is not None: - callback(step_index, model_prev_list[-1], x, steps) + callback({'x': x, 'i': step_index, 'denoised': model_prev_list[-1]}) else: raise NotImplementedError() # if denoise_to_zero: @@ -858,7 +845,7 @@ def predict_eps_sigma(model, input, sigma_in, **kwargs): return (input - model(input, sigma_in, **kwargs)) / sigma -def sample_unipc(model, noise, image, sigmas, max_denoise, extra_args=None, callback=None, disable=False, noise_mask=None, variant='bh1'): +def sample_unipc(model, noise, sigmas, extra_args=None, callback=None, disable=False, variant='bh1'): timesteps = sigmas.clone() if sigmas[-1] == 0: timesteps = sigmas[:] @@ -867,16 +854,7 @@ def sample_unipc(model, noise, image, sigmas, max_denoise, extra_args=None, call timesteps = sigmas.clone() ns = SigmaConvert() - if image is not None: - img = image * ns.marginal_alpha(timesteps[0]) - if max_denoise: - noise_mult = 1.0 - else: - noise_mult = ns.marginal_std(timesteps[0]) - img += noise * noise_mult - else: - img = noise - + noise = noise / torch.sqrt(1.0 + timesteps[0] ** 2.0) model_type = "noise" model_fn = model_wrapper( @@ -888,7 +866,10 @@ def sample_unipc(model, noise, image, sigmas, max_denoise, extra_args=None, call ) order = min(3, len(timesteps) - 2) - uni_pc = UniPC(model_fn, ns, predict_x0=True, thresholding=False, noise_mask=noise_mask, masked_image=image, noise=noise, variant=variant) - x = uni_pc.sample(img, timesteps=timesteps, skip_type="time_uniform", method="multistep", order=order, lower_order_final=True, callback=callback, disable_pbar=disable) + uni_pc = UniPC(model_fn, ns, predict_x0=True, thresholding=False, variant=variant) + x = uni_pc.sample(noise, timesteps=timesteps, skip_type="time_uniform", method="multistep", order=order, lower_order_final=True, callback=callback, disable_pbar=disable) x /= ns.marginal_alpha(timesteps[-1]) return x + +def sample_unipc_bh2(model, noise, sigmas, extra_args=None, callback=None, disable=False): + return sample_unipc(model, noise, sigmas, extra_args, callback, disable, variant='bh2') \ No newline at end of file diff --git a/comfy/gligen.py b/comfy/gligen.py index 8d182839..59252276 100644 --- a/comfy/gligen.py +++ b/comfy/gligen.py @@ -1,8 +1,9 @@ import torch -from torch import nn, einsum +from torch import nn from .ldm.modules.attention import CrossAttention from inspect import isfunction - +import comfy.ops +ops = comfy.ops.manual_cast def exists(val): return val is not None @@ -22,7 +23,7 @@ def default(val, d): class GEGLU(nn.Module): def __init__(self, dim_in, dim_out): super().__init__() - self.proj = nn.Linear(dim_in, dim_out * 2) + self.proj = ops.Linear(dim_in, dim_out * 2) def forward(self, x): x, gate = self.proj(x).chunk(2, dim=-1) @@ -35,14 +36,14 @@ class FeedForward(nn.Module): inner_dim = int(dim * mult) dim_out = default(dim_out, dim) project_in = nn.Sequential( - nn.Linear(dim, inner_dim), + ops.Linear(dim, inner_dim), nn.GELU() ) if not glu else GEGLU(dim, inner_dim) self.net = nn.Sequential( project_in, nn.Dropout(dropout), - nn.Linear(inner_dim, dim_out) + ops.Linear(inner_dim, dim_out) ) def forward(self, x): @@ -57,11 +58,12 @@ class GatedCrossAttentionDense(nn.Module): query_dim=query_dim, context_dim=context_dim, heads=n_heads, - dim_head=d_head) + dim_head=d_head, + operations=ops) self.ff = FeedForward(query_dim, glu=True) - self.norm1 = nn.LayerNorm(query_dim) - self.norm2 = nn.LayerNorm(query_dim) + self.norm1 = ops.LayerNorm(query_dim) + self.norm2 = ops.LayerNorm(query_dim) self.register_parameter('alpha_attn', nn.Parameter(torch.tensor(0.))) self.register_parameter('alpha_dense', nn.Parameter(torch.tensor(0.))) @@ -87,17 +89,18 @@ class GatedSelfAttentionDense(nn.Module): # we need a linear projection since we need cat visual feature and obj # feature - self.linear = nn.Linear(context_dim, query_dim) + self.linear = ops.Linear(context_dim, query_dim) self.attn = CrossAttention( query_dim=query_dim, context_dim=query_dim, heads=n_heads, - dim_head=d_head) + dim_head=d_head, + operations=ops) self.ff = FeedForward(query_dim, glu=True) - self.norm1 = nn.LayerNorm(query_dim) - self.norm2 = nn.LayerNorm(query_dim) + self.norm1 = ops.LayerNorm(query_dim) + self.norm2 = ops.LayerNorm(query_dim) self.register_parameter('alpha_attn', nn.Parameter(torch.tensor(0.))) self.register_parameter('alpha_dense', nn.Parameter(torch.tensor(0.))) @@ -126,14 +129,14 @@ class GatedSelfAttentionDense2(nn.Module): # we need a linear projection since we need cat visual feature and obj # feature - self.linear = nn.Linear(context_dim, query_dim) + self.linear = ops.Linear(context_dim, query_dim) self.attn = CrossAttention( - query_dim=query_dim, context_dim=query_dim, dim_head=d_head) + query_dim=query_dim, context_dim=query_dim, dim_head=d_head, operations=ops) self.ff = FeedForward(query_dim, glu=True) - self.norm1 = nn.LayerNorm(query_dim) - self.norm2 = nn.LayerNorm(query_dim) + self.norm1 = ops.LayerNorm(query_dim) + self.norm2 = ops.LayerNorm(query_dim) self.register_parameter('alpha_attn', nn.Parameter(torch.tensor(0.))) self.register_parameter('alpha_dense', nn.Parameter(torch.tensor(0.))) @@ -201,11 +204,11 @@ class PositionNet(nn.Module): self.position_dim = fourier_freqs * 2 * 4 # 2 is sin&cos, 4 is xyxy self.linears = nn.Sequential( - nn.Linear(self.in_dim + self.position_dim, 512), + ops.Linear(self.in_dim + self.position_dim, 512), nn.SiLU(), - nn.Linear(512, 512), + ops.Linear(512, 512), nn.SiLU(), - nn.Linear(512, out_dim), + ops.Linear(512, out_dim), ) self.null_positive_feature = torch.nn.Parameter( @@ -215,16 +218,15 @@ class PositionNet(nn.Module): def forward(self, boxes, masks, positive_embeddings): B, N, _ = boxes.shape - dtype = self.linears[0].weight.dtype - masks = masks.unsqueeze(-1).to(dtype) - positive_embeddings = positive_embeddings.to(dtype) + masks = masks.unsqueeze(-1) + positive_embeddings = positive_embeddings # embedding position (it may includes padding as placeholder) - xyxy_embedding = self.fourier_embedder(boxes.to(dtype)) # B*N*4 --> B*N*C + xyxy_embedding = self.fourier_embedder(boxes) # B*N*4 --> B*N*C # learnable null embedding - positive_null = self.null_positive_feature.view(1, 1, -1) - xyxy_null = self.null_position_feature.view(1, 1, -1) + positive_null = self.null_positive_feature.to(device=boxes.device, dtype=boxes.dtype).view(1, 1, -1) + xyxy_null = self.null_position_feature.to(device=boxes.device, dtype=boxes.dtype).view(1, 1, -1) # replace padding with learnable null embedding positive_embeddings = positive_embeddings * \ @@ -251,7 +253,7 @@ class Gligen(nn.Module): def func(x, extra_options): key = extra_options["transformer_index"] module = self.module_list[key] - return module(x, objs) + return module(x, objs.to(device=x.device, dtype=x.dtype)) return func def set_position(self, latent_image_shape, position_params, device): diff --git a/comfy/k_diffusion/sampling.py b/comfy/k_diffusion/sampling.py index 761c2e0e..7af01682 100644 --- a/comfy/k_diffusion/sampling.py +++ b/comfy/k_diffusion/sampling.py @@ -748,7 +748,7 @@ def sample_lcm(model, x, sigmas, extra_args=None, callback=None, disable=None, n x = denoised if sigmas[i + 1] > 0: - x += sigmas[i + 1] * noise_sampler(sigmas[i], sigmas[i + 1]) + x = model.inner_model.inner_model.model_sampling.noise_scaling(sigmas[i + 1], noise_sampler(sigmas[i], sigmas[i + 1]), x) return x diff --git a/comfy/latent_formats.py b/comfy/latent_formats.py index c209087e..4ca466d9 100644 --- a/comfy/latent_formats.py +++ b/comfy/latent_formats.py @@ -1,3 +1,4 @@ +import torch class LatentFormat: scale_factor = 1.0 @@ -33,3 +34,71 @@ class SDXL(LatentFormat): [-0.3112, -0.2359, -0.2076] ] self.taesd_decoder_name = "taesdxl_decoder" + +class SDXL_Playground_2_5(LatentFormat): + def __init__(self): + self.scale_factor = 0.5 + self.latents_mean = torch.tensor([-1.6574, 1.886, -1.383, 2.5155]).view(1, 4, 1, 1) + self.latents_std = torch.tensor([8.4927, 5.9022, 6.5498, 5.2299]).view(1, 4, 1, 1) + + self.latent_rgb_factors = [ + # R G B + [ 0.3920, 0.4054, 0.4549], + [-0.2634, -0.0196, 0.0653], + [ 0.0568, 0.1687, -0.0755], + [-0.3112, -0.2359, -0.2076] + ] + self.taesd_decoder_name = "taesdxl_decoder" + + def process_in(self, latent): + latents_mean = self.latents_mean.to(latent.device, latent.dtype) + latents_std = self.latents_std.to(latent.device, latent.dtype) + return (latent - latents_mean) * self.scale_factor / latents_std + + def process_out(self, latent): + latents_mean = self.latents_mean.to(latent.device, latent.dtype) + latents_std = self.latents_std.to(latent.device, latent.dtype) + return latent * latents_std / self.scale_factor + latents_mean + + +class SD_X4(LatentFormat): + def __init__(self): + self.scale_factor = 0.08333 + self.latent_rgb_factors = [ + [-0.2340, -0.3863, -0.3257], + [ 0.0994, 0.0885, -0.0908], + [-0.2833, -0.2349, -0.3741], + [ 0.2523, -0.0055, -0.1651] + ] + +class SC_Prior(LatentFormat): + def __init__(self): + self.scale_factor = 1.0 + self.latent_rgb_factors = [ + [-0.0326, -0.0204, -0.0127], + [-0.1592, -0.0427, 0.0216], + [ 0.0873, 0.0638, -0.0020], + [-0.0602, 0.0442, 0.1304], + [ 0.0800, -0.0313, -0.1796], + [-0.0810, -0.0638, -0.1581], + [ 0.1791, 0.1180, 0.0967], + [ 0.0740, 0.1416, 0.0432], + [-0.1745, -0.1888, -0.1373], + [ 0.2412, 0.1577, 0.0928], + [ 0.1908, 0.0998, 0.0682], + [ 0.0209, 0.0365, -0.0092], + [ 0.0448, -0.0650, -0.1728], + [-0.1658, -0.1045, -0.1308], + [ 0.0542, 0.1545, 0.1325], + [-0.0352, -0.1672, -0.2541] + ] + +class SC_B(LatentFormat): + def __init__(self): + self.scale_factor = 1.0 / 0.43 + self.latent_rgb_factors = [ + [ 0.1121, 0.2006, 0.1023], + [-0.2093, -0.0222, -0.0195], + [-0.3087, -0.1535, 0.0366], + [ 0.0290, -0.1574, -0.4078] + ] diff --git a/comfy/ldm/cascade/common.py b/comfy/ldm/cascade/common.py new file mode 100644 index 00000000..124902c0 --- /dev/null +++ b/comfy/ldm/cascade/common.py @@ -0,0 +1,161 @@ +""" + This file is part of ComfyUI. + Copyright (C) 2024 Stability AI + + This program is free software: you can redistribute it and/or modify + it under the terms of the GNU General Public License as published by + the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + This program is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU General Public License for more details. + + You should have received a copy of the GNU General Public License + along with this program. If not, see . +""" + +import torch +import torch.nn as nn +from comfy.ldm.modules.attention import optimized_attention + +class Linear(torch.nn.Linear): + def reset_parameters(self): + return None + +class Conv2d(torch.nn.Conv2d): + def reset_parameters(self): + return None + +class OptimizedAttention(nn.Module): + def __init__(self, c, nhead, dropout=0.0, dtype=None, device=None, operations=None): + super().__init__() + self.heads = nhead + + self.to_q = operations.Linear(c, c, bias=True, dtype=dtype, device=device) + self.to_k = operations.Linear(c, c, bias=True, dtype=dtype, device=device) + self.to_v = operations.Linear(c, c, bias=True, dtype=dtype, device=device) + + self.out_proj = operations.Linear(c, c, bias=True, dtype=dtype, device=device) + + def forward(self, q, k, v): + q = self.to_q(q) + k = self.to_k(k) + v = self.to_v(v) + + out = optimized_attention(q, k, v, self.heads) + + return self.out_proj(out) + +class Attention2D(nn.Module): + def __init__(self, c, nhead, dropout=0.0, dtype=None, device=None, operations=None): + super().__init__() + self.attn = OptimizedAttention(c, nhead, dtype=dtype, device=device, operations=operations) + # self.attn = nn.MultiheadAttention(c, nhead, dropout=dropout, bias=True, batch_first=True, dtype=dtype, device=device) + + def forward(self, x, kv, self_attn=False): + orig_shape = x.shape + x = x.view(x.size(0), x.size(1), -1).permute(0, 2, 1) # Bx4xHxW -> Bx(HxW)x4 + if self_attn: + kv = torch.cat([x, kv], dim=1) + # x = self.attn(x, kv, kv, need_weights=False)[0] + x = self.attn(x, kv, kv) + x = x.permute(0, 2, 1).view(*orig_shape) + return x + + +def LayerNorm2d_op(operations): + class LayerNorm2d(operations.LayerNorm): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def forward(self, x): + return super().forward(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) + return LayerNorm2d + +class GlobalResponseNorm(nn.Module): + "from https://github.com/facebookresearch/ConvNeXt-V2/blob/3608f67cc1dae164790c5d0aead7bf2d73d9719b/models/utils.py#L105" + def __init__(self, dim, dtype=None, device=None): + super().__init__() + self.gamma = nn.Parameter(torch.zeros(1, 1, 1, dim, dtype=dtype, device=device)) + self.beta = nn.Parameter(torch.zeros(1, 1, 1, dim, dtype=dtype, device=device)) + + def forward(self, x): + Gx = torch.norm(x, p=2, dim=(1, 2), keepdim=True) + Nx = Gx / (Gx.mean(dim=-1, keepdim=True) + 1e-6) + return self.gamma.to(device=x.device, dtype=x.dtype) * (x * Nx) + self.beta.to(device=x.device, dtype=x.dtype) + x + + +class ResBlock(nn.Module): + def __init__(self, c, c_skip=0, kernel_size=3, dropout=0.0, dtype=None, device=None, operations=None): # , num_heads=4, expansion=2): + super().__init__() + self.depthwise = operations.Conv2d(c, c, kernel_size=kernel_size, padding=kernel_size // 2, groups=c, dtype=dtype, device=device) + # self.depthwise = SAMBlock(c, num_heads, expansion) + self.norm = LayerNorm2d_op(operations)(c, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device) + self.channelwise = nn.Sequential( + operations.Linear(c + c_skip, c * 4, dtype=dtype, device=device), + nn.GELU(), + GlobalResponseNorm(c * 4, dtype=dtype, device=device), + nn.Dropout(dropout), + operations.Linear(c * 4, c, dtype=dtype, device=device) + ) + + def forward(self, x, x_skip=None): + x_res = x + x = self.norm(self.depthwise(x)) + if x_skip is not None: + x = torch.cat([x, x_skip], dim=1) + x = self.channelwise(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) + return x + x_res + + +class AttnBlock(nn.Module): + def __init__(self, c, c_cond, nhead, self_attn=True, dropout=0.0, dtype=None, device=None, operations=None): + super().__init__() + self.self_attn = self_attn + self.norm = LayerNorm2d_op(operations)(c, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device) + self.attention = Attention2D(c, nhead, dropout, dtype=dtype, device=device, operations=operations) + self.kv_mapper = nn.Sequential( + nn.SiLU(), + operations.Linear(c_cond, c, dtype=dtype, device=device) + ) + + def forward(self, x, kv): + kv = self.kv_mapper(kv) + x = x + self.attention(self.norm(x), kv, self_attn=self.self_attn) + return x + + +class FeedForwardBlock(nn.Module): + def __init__(self, c, dropout=0.0, dtype=None, device=None, operations=None): + super().__init__() + self.norm = LayerNorm2d_op(operations)(c, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device) + self.channelwise = nn.Sequential( + operations.Linear(c, c * 4, dtype=dtype, device=device), + nn.GELU(), + GlobalResponseNorm(c * 4, dtype=dtype, device=device), + nn.Dropout(dropout), + operations.Linear(c * 4, c, dtype=dtype, device=device) + ) + + def forward(self, x): + x = x + self.channelwise(self.norm(x).permute(0, 2, 3, 1)).permute(0, 3, 1, 2) + return x + + +class TimestepBlock(nn.Module): + def __init__(self, c, c_timestep, conds=['sca'], dtype=None, device=None, operations=None): + super().__init__() + self.mapper = operations.Linear(c_timestep, c * 2, dtype=dtype, device=device) + self.conds = conds + for cname in conds: + setattr(self, f"mapper_{cname}", operations.Linear(c_timestep, c * 2, dtype=dtype, device=device)) + + def forward(self, x, t): + t = t.chunk(len(self.conds) + 1, dim=1) + a, b = self.mapper(t[0])[:, :, None, None].chunk(2, dim=1) + for i, c in enumerate(self.conds): + ac, bc = getattr(self, f"mapper_{c}")(t[i + 1])[:, :, None, None].chunk(2, dim=1) + a, b = a + ac, b + bc + return x * (1 + a) + b diff --git a/comfy/ldm/cascade/controlnet.py b/comfy/ldm/cascade/controlnet.py new file mode 100644 index 00000000..5dac5939 --- /dev/null +++ b/comfy/ldm/cascade/controlnet.py @@ -0,0 +1,93 @@ +""" + This file is part of ComfyUI. + Copyright (C) 2024 Stability AI + + This program is free software: you can redistribute it and/or modify + it under the terms of the GNU General Public License as published by + the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + This program is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU General Public License for more details. + + You should have received a copy of the GNU General Public License + along with this program. If not, see . +""" + +import torch +import torchvision +from torch import nn +from .common import LayerNorm2d_op + + +class CNetResBlock(nn.Module): + def __init__(self, c, dtype=None, device=None, operations=None): + super().__init__() + self.blocks = nn.Sequential( + LayerNorm2d_op(operations)(c, dtype=dtype, device=device), + nn.GELU(), + operations.Conv2d(c, c, kernel_size=3, padding=1), + LayerNorm2d_op(operations)(c, dtype=dtype, device=device), + nn.GELU(), + operations.Conv2d(c, c, kernel_size=3, padding=1), + ) + + def forward(self, x): + return x + self.blocks(x) + + +class ControlNet(nn.Module): + def __init__(self, c_in=3, c_proj=2048, proj_blocks=None, bottleneck_mode=None, dtype=None, device=None, operations=nn): + super().__init__() + if bottleneck_mode is None: + bottleneck_mode = 'effnet' + self.proj_blocks = proj_blocks + if bottleneck_mode == 'effnet': + embd_channels = 1280 + self.backbone = torchvision.models.efficientnet_v2_s().features.eval() + if c_in != 3: + in_weights = self.backbone[0][0].weight.data + self.backbone[0][0] = operations.Conv2d(c_in, 24, kernel_size=3, stride=2, bias=False, dtype=dtype, device=device) + if c_in > 3: + # nn.init.constant_(self.backbone[0][0].weight, 0) + self.backbone[0][0].weight.data[:, :3] = in_weights[:, :3].clone() + else: + self.backbone[0][0].weight.data = in_weights[:, :c_in].clone() + elif bottleneck_mode == 'simple': + embd_channels = c_in + self.backbone = nn.Sequential( + operations.Conv2d(embd_channels, embd_channels * 4, kernel_size=3, padding=1, dtype=dtype, device=device), + nn.LeakyReLU(0.2, inplace=True), + operations.Conv2d(embd_channels * 4, embd_channels, kernel_size=3, padding=1, dtype=dtype, device=device), + ) + elif bottleneck_mode == 'large': + self.backbone = nn.Sequential( + operations.Conv2d(c_in, 4096 * 4, kernel_size=1, dtype=dtype, device=device), + nn.LeakyReLU(0.2, inplace=True), + operations.Conv2d(4096 * 4, 1024, kernel_size=1, dtype=dtype, device=device), + *[CNetResBlock(1024, dtype=dtype, device=device, operations=operations) for _ in range(8)], + operations.Conv2d(1024, 1280, kernel_size=1, dtype=dtype, device=device), + ) + embd_channels = 1280 + else: + raise ValueError(f'Unknown bottleneck mode: {bottleneck_mode}') + self.projections = nn.ModuleList() + for _ in range(len(proj_blocks)): + self.projections.append(nn.Sequential( + operations.Conv2d(embd_channels, embd_channels, kernel_size=1, bias=False, dtype=dtype, device=device), + nn.LeakyReLU(0.2, inplace=True), + operations.Conv2d(embd_channels, c_proj, kernel_size=1, bias=False, dtype=dtype, device=device), + )) + # nn.init.constant_(self.projections[-1][-1].weight, 0) # zero output projection + self.xl = False + self.input_channels = c_in + self.unshuffle_amount = 8 + + def forward(self, x): + x = self.backbone(x) + proj_outputs = [None for _ in range(max(self.proj_blocks) + 1)] + for i, idx in enumerate(self.proj_blocks): + proj_outputs[idx] = self.projections[i](x) + return proj_outputs diff --git a/comfy/ldm/cascade/stage_a.py b/comfy/ldm/cascade/stage_a.py new file mode 100644 index 00000000..ca8867ea --- /dev/null +++ b/comfy/ldm/cascade/stage_a.py @@ -0,0 +1,255 @@ +""" + This file is part of ComfyUI. + Copyright (C) 2024 Stability AI + + This program is free software: you can redistribute it and/or modify + it under the terms of the GNU General Public License as published by + the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + This program is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU General Public License for more details. + + You should have received a copy of the GNU General Public License + along with this program. If not, see . +""" + +import torch +from torch import nn +from torch.autograd import Function + +class vector_quantize(Function): + @staticmethod + def forward(ctx, x, codebook): + with torch.no_grad(): + codebook_sqr = torch.sum(codebook ** 2, dim=1) + x_sqr = torch.sum(x ** 2, dim=1, keepdim=True) + + dist = torch.addmm(codebook_sqr + x_sqr, x, codebook.t(), alpha=-2.0, beta=1.0) + _, indices = dist.min(dim=1) + + ctx.save_for_backward(indices, codebook) + ctx.mark_non_differentiable(indices) + + nn = torch.index_select(codebook, 0, indices) + return nn, indices + + @staticmethod + def backward(ctx, grad_output, grad_indices): + grad_inputs, grad_codebook = None, None + + if ctx.needs_input_grad[0]: + grad_inputs = grad_output.clone() + if ctx.needs_input_grad[1]: + # Gradient wrt. the codebook + indices, codebook = ctx.saved_tensors + + grad_codebook = torch.zeros_like(codebook) + grad_codebook.index_add_(0, indices, grad_output) + + return (grad_inputs, grad_codebook) + + +class VectorQuantize(nn.Module): + def __init__(self, embedding_size, k, ema_decay=0.99, ema_loss=False): + """ + Takes an input of variable size (as long as the last dimension matches the embedding size). + Returns one tensor containing the nearest neigbour embeddings to each of the inputs, + with the same size as the input, vq and commitment components for the loss as a touple + in the second output and the indices of the quantized vectors in the third: + quantized, (vq_loss, commit_loss), indices + """ + super(VectorQuantize, self).__init__() + + self.codebook = nn.Embedding(k, embedding_size) + self.codebook.weight.data.uniform_(-1./k, 1./k) + self.vq = vector_quantize.apply + + self.ema_decay = ema_decay + self.ema_loss = ema_loss + if ema_loss: + self.register_buffer('ema_element_count', torch.ones(k)) + self.register_buffer('ema_weight_sum', torch.zeros_like(self.codebook.weight)) + + def _laplace_smoothing(self, x, epsilon): + n = torch.sum(x) + return ((x + epsilon) / (n + x.size(0) * epsilon) * n) + + def _updateEMA(self, z_e_x, indices): + mask = nn.functional.one_hot(indices, self.ema_element_count.size(0)).float() + elem_count = mask.sum(dim=0) + weight_sum = torch.mm(mask.t(), z_e_x) + + self.ema_element_count = (self.ema_decay * self.ema_element_count) + ((1-self.ema_decay) * elem_count) + self.ema_element_count = self._laplace_smoothing(self.ema_element_count, 1e-5) + self.ema_weight_sum = (self.ema_decay * self.ema_weight_sum) + ((1-self.ema_decay) * weight_sum) + + self.codebook.weight.data = self.ema_weight_sum / self.ema_element_count.unsqueeze(-1) + + def idx2vq(self, idx, dim=-1): + q_idx = self.codebook(idx) + if dim != -1: + q_idx = q_idx.movedim(-1, dim) + return q_idx + + def forward(self, x, get_losses=True, dim=-1): + if dim != -1: + x = x.movedim(dim, -1) + z_e_x = x.contiguous().view(-1, x.size(-1)) if len(x.shape) > 2 else x + z_q_x, indices = self.vq(z_e_x, self.codebook.weight.detach()) + vq_loss, commit_loss = None, None + if self.ema_loss and self.training: + self._updateEMA(z_e_x.detach(), indices.detach()) + # pick the graded embeddings after updating the codebook in order to have a more accurate commitment loss + z_q_x_grd = torch.index_select(self.codebook.weight, dim=0, index=indices) + if get_losses: + vq_loss = (z_q_x_grd - z_e_x.detach()).pow(2).mean() + commit_loss = (z_e_x - z_q_x_grd.detach()).pow(2).mean() + + z_q_x = z_q_x.view(x.shape) + if dim != -1: + z_q_x = z_q_x.movedim(-1, dim) + return z_q_x, (vq_loss, commit_loss), indices.view(x.shape[:-1]) + + +class ResBlock(nn.Module): + def __init__(self, c, c_hidden): + super().__init__() + # depthwise/attention + self.norm1 = nn.LayerNorm(c, elementwise_affine=False, eps=1e-6) + self.depthwise = nn.Sequential( + nn.ReplicationPad2d(1), + nn.Conv2d(c, c, kernel_size=3, groups=c) + ) + + # channelwise + self.norm2 = nn.LayerNorm(c, elementwise_affine=False, eps=1e-6) + self.channelwise = nn.Sequential( + nn.Linear(c, c_hidden), + nn.GELU(), + nn.Linear(c_hidden, c), + ) + + self.gammas = nn.Parameter(torch.zeros(6), requires_grad=True) + + # Init weights + def _basic_init(module): + if isinstance(module, nn.Linear) or isinstance(module, nn.Conv2d): + torch.nn.init.xavier_uniform_(module.weight) + if module.bias is not None: + nn.init.constant_(module.bias, 0) + + self.apply(_basic_init) + + def _norm(self, x, norm): + return norm(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) + + def forward(self, x): + mods = self.gammas + + x_temp = self._norm(x, self.norm1) * (1 + mods[0]) + mods[1] + try: + x = x + self.depthwise(x_temp) * mods[2] + except: #operation not implemented for bf16 + x_temp = self.depthwise[0](x_temp.float()).to(x.dtype) + x = x + self.depthwise[1](x_temp) * mods[2] + + x_temp = self._norm(x, self.norm2) * (1 + mods[3]) + mods[4] + x = x + self.channelwise(x_temp.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) * mods[5] + + return x + + +class StageA(nn.Module): + def __init__(self, levels=2, bottleneck_blocks=12, c_hidden=384, c_latent=4, codebook_size=8192): + super().__init__() + self.c_latent = c_latent + c_levels = [c_hidden // (2 ** i) for i in reversed(range(levels))] + + # Encoder blocks + self.in_block = nn.Sequential( + nn.PixelUnshuffle(2), + nn.Conv2d(3 * 4, c_levels[0], kernel_size=1) + ) + down_blocks = [] + for i in range(levels): + if i > 0: + down_blocks.append(nn.Conv2d(c_levels[i - 1], c_levels[i], kernel_size=4, stride=2, padding=1)) + block = ResBlock(c_levels[i], c_levels[i] * 4) + down_blocks.append(block) + down_blocks.append(nn.Sequential( + nn.Conv2d(c_levels[-1], c_latent, kernel_size=1, bias=False), + nn.BatchNorm2d(c_latent), # then normalize them to have mean 0 and std 1 + )) + self.down_blocks = nn.Sequential(*down_blocks) + self.down_blocks[0] + + self.codebook_size = codebook_size + self.vquantizer = VectorQuantize(c_latent, k=codebook_size) + + # Decoder blocks + up_blocks = [nn.Sequential( + nn.Conv2d(c_latent, c_levels[-1], kernel_size=1) + )] + for i in range(levels): + for j in range(bottleneck_blocks if i == 0 else 1): + block = ResBlock(c_levels[levels - 1 - i], c_levels[levels - 1 - i] * 4) + up_blocks.append(block) + if i < levels - 1: + up_blocks.append( + nn.ConvTranspose2d(c_levels[levels - 1 - i], c_levels[levels - 2 - i], kernel_size=4, stride=2, + padding=1)) + self.up_blocks = nn.Sequential(*up_blocks) + self.out_block = nn.Sequential( + nn.Conv2d(c_levels[0], 3 * 4, kernel_size=1), + nn.PixelShuffle(2), + ) + + def encode(self, x, quantize=False): + x = self.in_block(x) + x = self.down_blocks(x) + if quantize: + qe, (vq_loss, commit_loss), indices = self.vquantizer.forward(x, dim=1) + return qe, x, indices, vq_loss + commit_loss * 0.25 + else: + return x + + def decode(self, x): + x = self.up_blocks(x) + x = self.out_block(x) + return x + + def forward(self, x, quantize=False): + qe, x, _, vq_loss = self.encode(x, quantize) + x = self.decode(qe) + return x, vq_loss + + +class Discriminator(nn.Module): + def __init__(self, c_in=3, c_cond=0, c_hidden=512, depth=6): + super().__init__() + d = max(depth - 3, 3) + layers = [ + nn.utils.spectral_norm(nn.Conv2d(c_in, c_hidden // (2 ** d), kernel_size=3, stride=2, padding=1)), + nn.LeakyReLU(0.2), + ] + for i in range(depth - 1): + c_in = c_hidden // (2 ** max((d - i), 0)) + c_out = c_hidden // (2 ** max((d - 1 - i), 0)) + layers.append(nn.utils.spectral_norm(nn.Conv2d(c_in, c_out, kernel_size=3, stride=2, padding=1))) + layers.append(nn.InstanceNorm2d(c_out)) + layers.append(nn.LeakyReLU(0.2)) + self.encoder = nn.Sequential(*layers) + self.shuffle = nn.Conv2d((c_hidden + c_cond) if c_cond > 0 else c_hidden, 1, kernel_size=1) + self.logits = nn.Sigmoid() + + def forward(self, x, cond=None): + x = self.encoder(x) + if cond is not None: + cond = cond.view(cond.size(0), cond.size(1), 1, 1, ).expand(-1, -1, x.size(-2), x.size(-1)) + x = torch.cat([x, cond], dim=1) + x = self.shuffle(x) + x = self.logits(x) + return x diff --git a/comfy/ldm/cascade/stage_b.py b/comfy/ldm/cascade/stage_b.py new file mode 100644 index 00000000..6d2c2223 --- /dev/null +++ b/comfy/ldm/cascade/stage_b.py @@ -0,0 +1,257 @@ +""" + This file is part of ComfyUI. + Copyright (C) 2024 Stability AI + + This program is free software: you can redistribute it and/or modify + it under the terms of the GNU General Public License as published by + the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + This program is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU General Public License for more details. + + You should have received a copy of the GNU General Public License + along with this program. If not, see . +""" + +import math +import numpy as np +import torch +from torch import nn +from .common import AttnBlock, LayerNorm2d_op, ResBlock, FeedForwardBlock, TimestepBlock + +class StageB(nn.Module): + def __init__(self, c_in=4, c_out=4, c_r=64, patch_size=2, c_cond=1280, c_hidden=[320, 640, 1280, 1280], + nhead=[-1, -1, 20, 20], blocks=[[2, 6, 28, 6], [6, 28, 6, 2]], + block_repeat=[[1, 1, 1, 1], [3, 3, 2, 2]], level_config=['CT', 'CT', 'CTA', 'CTA'], c_clip=1280, + c_clip_seq=4, c_effnet=16, c_pixels=3, kernel_size=3, dropout=[0, 0, 0.0, 0.0], self_attn=True, + t_conds=['sca'], stable_cascade_stage=None, dtype=None, device=None, operations=None): + super().__init__() + self.dtype = dtype + self.c_r = c_r + self.t_conds = t_conds + self.c_clip_seq = c_clip_seq + if not isinstance(dropout, list): + dropout = [dropout] * len(c_hidden) + if not isinstance(self_attn, list): + self_attn = [self_attn] * len(c_hidden) + + # CONDITIONING + self.effnet_mapper = nn.Sequential( + operations.Conv2d(c_effnet, c_hidden[0] * 4, kernel_size=1, dtype=dtype, device=device), + nn.GELU(), + operations.Conv2d(c_hidden[0] * 4, c_hidden[0], kernel_size=1, dtype=dtype, device=device), + LayerNorm2d_op(operations)(c_hidden[0], elementwise_affine=False, eps=1e-6, dtype=dtype, device=device) + ) + self.pixels_mapper = nn.Sequential( + operations.Conv2d(c_pixels, c_hidden[0] * 4, kernel_size=1, dtype=dtype, device=device), + nn.GELU(), + operations.Conv2d(c_hidden[0] * 4, c_hidden[0], kernel_size=1, dtype=dtype, device=device), + LayerNorm2d_op(operations)(c_hidden[0], elementwise_affine=False, eps=1e-6, dtype=dtype, device=device) + ) + self.clip_mapper = operations.Linear(c_clip, c_cond * c_clip_seq, dtype=dtype, device=device) + self.clip_norm = operations.LayerNorm(c_cond, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device) + + self.embedding = nn.Sequential( + nn.PixelUnshuffle(patch_size), + operations.Conv2d(c_in * (patch_size ** 2), c_hidden[0], kernel_size=1, dtype=dtype, device=device), + LayerNorm2d_op(operations)(c_hidden[0], elementwise_affine=False, eps=1e-6, dtype=dtype, device=device) + ) + + def get_block(block_type, c_hidden, nhead, c_skip=0, dropout=0, self_attn=True): + if block_type == 'C': + return ResBlock(c_hidden, c_skip, kernel_size=kernel_size, dropout=dropout, dtype=dtype, device=device, operations=operations) + elif block_type == 'A': + return AttnBlock(c_hidden, c_cond, nhead, self_attn=self_attn, dropout=dropout, dtype=dtype, device=device, operations=operations) + elif block_type == 'F': + return FeedForwardBlock(c_hidden, dropout=dropout, dtype=dtype, device=device, operations=operations) + elif block_type == 'T': + return TimestepBlock(c_hidden, c_r, conds=t_conds, dtype=dtype, device=device, operations=operations) + else: + raise Exception(f'Block type {block_type} not supported') + + # BLOCKS + # -- down blocks + self.down_blocks = nn.ModuleList() + self.down_downscalers = nn.ModuleList() + self.down_repeat_mappers = nn.ModuleList() + for i in range(len(c_hidden)): + if i > 0: + self.down_downscalers.append(nn.Sequential( + LayerNorm2d_op(operations)(c_hidden[i - 1], elementwise_affine=False, eps=1e-6, dtype=dtype, device=device), + operations.Conv2d(c_hidden[i - 1], c_hidden[i], kernel_size=2, stride=2, dtype=dtype, device=device), + )) + else: + self.down_downscalers.append(nn.Identity()) + down_block = nn.ModuleList() + for _ in range(blocks[0][i]): + for block_type in level_config[i]: + block = get_block(block_type, c_hidden[i], nhead[i], dropout=dropout[i], self_attn=self_attn[i]) + down_block.append(block) + self.down_blocks.append(down_block) + if block_repeat is not None: + block_repeat_mappers = nn.ModuleList() + for _ in range(block_repeat[0][i] - 1): + block_repeat_mappers.append(operations.Conv2d(c_hidden[i], c_hidden[i], kernel_size=1, dtype=dtype, device=device)) + self.down_repeat_mappers.append(block_repeat_mappers) + + # -- up blocks + self.up_blocks = nn.ModuleList() + self.up_upscalers = nn.ModuleList() + self.up_repeat_mappers = nn.ModuleList() + for i in reversed(range(len(c_hidden))): + if i > 0: + self.up_upscalers.append(nn.Sequential( + LayerNorm2d_op(operations)(c_hidden[i], elementwise_affine=False, eps=1e-6, dtype=dtype, device=device), + operations.ConvTranspose2d(c_hidden[i], c_hidden[i - 1], kernel_size=2, stride=2, dtype=dtype, device=device), + )) + else: + self.up_upscalers.append(nn.Identity()) + up_block = nn.ModuleList() + for j in range(blocks[1][::-1][i]): + for k, block_type in enumerate(level_config[i]): + c_skip = c_hidden[i] if i < len(c_hidden) - 1 and j == k == 0 else 0 + block = get_block(block_type, c_hidden[i], nhead[i], c_skip=c_skip, dropout=dropout[i], + self_attn=self_attn[i]) + up_block.append(block) + self.up_blocks.append(up_block) + if block_repeat is not None: + block_repeat_mappers = nn.ModuleList() + for _ in range(block_repeat[1][::-1][i] - 1): + block_repeat_mappers.append(operations.Conv2d(c_hidden[i], c_hidden[i], kernel_size=1, dtype=dtype, device=device)) + self.up_repeat_mappers.append(block_repeat_mappers) + + # OUTPUT + self.clf = nn.Sequential( + LayerNorm2d_op(operations)(c_hidden[0], elementwise_affine=False, eps=1e-6, dtype=dtype, device=device), + operations.Conv2d(c_hidden[0], c_out * (patch_size ** 2), kernel_size=1, dtype=dtype, device=device), + nn.PixelShuffle(patch_size), + ) + + # --- WEIGHT INIT --- + # self.apply(self._init_weights) # General init + # nn.init.normal_(self.clip_mapper.weight, std=0.02) # conditionings + # nn.init.normal_(self.effnet_mapper[0].weight, std=0.02) # conditionings + # nn.init.normal_(self.effnet_mapper[2].weight, std=0.02) # conditionings + # nn.init.normal_(self.pixels_mapper[0].weight, std=0.02) # conditionings + # nn.init.normal_(self.pixels_mapper[2].weight, std=0.02) # conditionings + # torch.nn.init.xavier_uniform_(self.embedding[1].weight, 0.02) # inputs + # nn.init.constant_(self.clf[1].weight, 0) # outputs + # + # # blocks + # for level_block in self.down_blocks + self.up_blocks: + # for block in level_block: + # if isinstance(block, ResBlock) or isinstance(block, FeedForwardBlock): + # block.channelwise[-1].weight.data *= np.sqrt(1 / sum(blocks[0])) + # elif isinstance(block, TimestepBlock): + # for layer in block.modules(): + # if isinstance(layer, nn.Linear): + # nn.init.constant_(layer.weight, 0) + # + # def _init_weights(self, m): + # if isinstance(m, (nn.Conv2d, nn.Linear)): + # torch.nn.init.xavier_uniform_(m.weight) + # if m.bias is not None: + # nn.init.constant_(m.bias, 0) + + def gen_r_embedding(self, r, max_positions=10000): + r = r * max_positions + half_dim = self.c_r // 2 + emb = math.log(max_positions) / (half_dim - 1) + emb = torch.arange(half_dim, device=r.device).float().mul(-emb).exp() + emb = r[:, None] * emb[None, :] + emb = torch.cat([emb.sin(), emb.cos()], dim=1) + if self.c_r % 2 == 1: # zero pad + emb = nn.functional.pad(emb, (0, 1), mode='constant') + return emb + + def gen_c_embeddings(self, clip): + if len(clip.shape) == 2: + clip = clip.unsqueeze(1) + clip = self.clip_mapper(clip).view(clip.size(0), clip.size(1) * self.c_clip_seq, -1) + clip = self.clip_norm(clip) + return clip + + def _down_encode(self, x, r_embed, clip): + level_outputs = [] + block_group = zip(self.down_blocks, self.down_downscalers, self.down_repeat_mappers) + for down_block, downscaler, repmap in block_group: + x = downscaler(x) + for i in range(len(repmap) + 1): + for block in down_block: + if isinstance(block, ResBlock) or ( + hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module, + ResBlock)): + x = block(x) + elif isinstance(block, AttnBlock) or ( + hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module, + AttnBlock)): + x = block(x, clip) + elif isinstance(block, TimestepBlock) or ( + hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module, + TimestepBlock)): + x = block(x, r_embed) + else: + x = block(x) + if i < len(repmap): + x = repmap[i](x) + level_outputs.insert(0, x) + return level_outputs + + def _up_decode(self, level_outputs, r_embed, clip): + x = level_outputs[0] + block_group = zip(self.up_blocks, self.up_upscalers, self.up_repeat_mappers) + for i, (up_block, upscaler, repmap) in enumerate(block_group): + for j in range(len(repmap) + 1): + for k, block in enumerate(up_block): + if isinstance(block, ResBlock) or ( + hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module, + ResBlock)): + skip = level_outputs[i] if k == 0 and i > 0 else None + if skip is not None and (x.size(-1) != skip.size(-1) or x.size(-2) != skip.size(-2)): + x = torch.nn.functional.interpolate(x, skip.shape[-2:], mode='bilinear', + align_corners=True) + x = block(x, skip) + elif isinstance(block, AttnBlock) or ( + hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module, + AttnBlock)): + x = block(x, clip) + elif isinstance(block, TimestepBlock) or ( + hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module, + TimestepBlock)): + x = block(x, r_embed) + else: + x = block(x) + if j < len(repmap): + x = repmap[j](x) + x = upscaler(x) + return x + + def forward(self, x, r, effnet, clip, pixels=None, **kwargs): + if pixels is None: + pixels = x.new_zeros(x.size(0), 3, 8, 8) + + # Process the conditioning embeddings + r_embed = self.gen_r_embedding(r).to(dtype=x.dtype) + for c in self.t_conds: + t_cond = kwargs.get(c, torch.zeros_like(r)) + r_embed = torch.cat([r_embed, self.gen_r_embedding(t_cond).to(dtype=x.dtype)], dim=1) + clip = self.gen_c_embeddings(clip) + + # Model Blocks + x = self.embedding(x) + x = x + self.effnet_mapper( + nn.functional.interpolate(effnet, size=x.shape[-2:], mode='bilinear', align_corners=True)) + x = x + nn.functional.interpolate(self.pixels_mapper(pixels), size=x.shape[-2:], mode='bilinear', + align_corners=True) + level_outputs = self._down_encode(x, r_embed, clip) + x = self._up_decode(level_outputs, r_embed, clip) + return self.clf(x) + + def update_weights_ema(self, src_model, beta=0.999): + for self_params, src_params in zip(self.parameters(), src_model.parameters()): + self_params.data = self_params.data * beta + src_params.data.clone().to(self_params.device) * (1 - beta) + for self_buffers, src_buffers in zip(self.buffers(), src_model.buffers()): + self_buffers.data = self_buffers.data * beta + src_buffers.data.clone().to(self_buffers.device) * (1 - beta) diff --git a/comfy/ldm/cascade/stage_c.py b/comfy/ldm/cascade/stage_c.py new file mode 100644 index 00000000..67c1e52b --- /dev/null +++ b/comfy/ldm/cascade/stage_c.py @@ -0,0 +1,274 @@ +""" + This file is part of ComfyUI. + Copyright (C) 2024 Stability AI + + This program is free software: you can redistribute it and/or modify + it under the terms of the GNU General Public License as published by + the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + This program is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU General Public License for more details. + + You should have received a copy of the GNU General Public License + along with this program. If not, see . +""" + +import torch +from torch import nn +import numpy as np +import math +from .common import AttnBlock, LayerNorm2d_op, ResBlock, FeedForwardBlock, TimestepBlock +# from .controlnet import ControlNetDeliverer + +class UpDownBlock2d(nn.Module): + def __init__(self, c_in, c_out, mode, enabled=True, dtype=None, device=None, operations=None): + super().__init__() + assert mode in ['up', 'down'] + interpolation = nn.Upsample(scale_factor=2 if mode == 'up' else 0.5, mode='bilinear', + align_corners=True) if enabled else nn.Identity() + mapping = operations.Conv2d(c_in, c_out, kernel_size=1, dtype=dtype, device=device) + self.blocks = nn.ModuleList([interpolation, mapping] if mode == 'up' else [mapping, interpolation]) + + def forward(self, x): + for block in self.blocks: + x = block(x) + return x + + +class StageC(nn.Module): + def __init__(self, c_in=16, c_out=16, c_r=64, patch_size=1, c_cond=2048, c_hidden=[2048, 2048], nhead=[32, 32], + blocks=[[8, 24], [24, 8]], block_repeat=[[1, 1], [1, 1]], level_config=['CTA', 'CTA'], + c_clip_text=1280, c_clip_text_pooled=1280, c_clip_img=768, c_clip_seq=4, kernel_size=3, + dropout=[0.0, 0.0], self_attn=True, t_conds=['sca', 'crp'], switch_level=[False], stable_cascade_stage=None, + dtype=None, device=None, operations=None): + super().__init__() + self.dtype = dtype + self.c_r = c_r + self.t_conds = t_conds + self.c_clip_seq = c_clip_seq + if not isinstance(dropout, list): + dropout = [dropout] * len(c_hidden) + if not isinstance(self_attn, list): + self_attn = [self_attn] * len(c_hidden) + + # CONDITIONING + self.clip_txt_mapper = operations.Linear(c_clip_text, c_cond, dtype=dtype, device=device) + self.clip_txt_pooled_mapper = operations.Linear(c_clip_text_pooled, c_cond * c_clip_seq, dtype=dtype, device=device) + self.clip_img_mapper = operations.Linear(c_clip_img, c_cond * c_clip_seq, dtype=dtype, device=device) + self.clip_norm = operations.LayerNorm(c_cond, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device) + + self.embedding = nn.Sequential( + nn.PixelUnshuffle(patch_size), + operations.Conv2d(c_in * (patch_size ** 2), c_hidden[0], kernel_size=1, dtype=dtype, device=device), + LayerNorm2d_op(operations)(c_hidden[0], elementwise_affine=False, eps=1e-6) + ) + + def get_block(block_type, c_hidden, nhead, c_skip=0, dropout=0, self_attn=True): + if block_type == 'C': + return ResBlock(c_hidden, c_skip, kernel_size=kernel_size, dropout=dropout, dtype=dtype, device=device, operations=operations) + elif block_type == 'A': + return AttnBlock(c_hidden, c_cond, nhead, self_attn=self_attn, dropout=dropout, dtype=dtype, device=device, operations=operations) + elif block_type == 'F': + return FeedForwardBlock(c_hidden, dropout=dropout, dtype=dtype, device=device, operations=operations) + elif block_type == 'T': + return TimestepBlock(c_hidden, c_r, conds=t_conds, dtype=dtype, device=device, operations=operations) + else: + raise Exception(f'Block type {block_type} not supported') + + # BLOCKS + # -- down blocks + self.down_blocks = nn.ModuleList() + self.down_downscalers = nn.ModuleList() + self.down_repeat_mappers = nn.ModuleList() + for i in range(len(c_hidden)): + if i > 0: + self.down_downscalers.append(nn.Sequential( + LayerNorm2d_op(operations)(c_hidden[i - 1], elementwise_affine=False, eps=1e-6), + UpDownBlock2d(c_hidden[i - 1], c_hidden[i], mode='down', enabled=switch_level[i - 1], dtype=dtype, device=device, operations=operations) + )) + else: + self.down_downscalers.append(nn.Identity()) + down_block = nn.ModuleList() + for _ in range(blocks[0][i]): + for block_type in level_config[i]: + block = get_block(block_type, c_hidden[i], nhead[i], dropout=dropout[i], self_attn=self_attn[i]) + down_block.append(block) + self.down_blocks.append(down_block) + if block_repeat is not None: + block_repeat_mappers = nn.ModuleList() + for _ in range(block_repeat[0][i] - 1): + block_repeat_mappers.append(operations.Conv2d(c_hidden[i], c_hidden[i], kernel_size=1, dtype=dtype, device=device)) + self.down_repeat_mappers.append(block_repeat_mappers) + + # -- up blocks + self.up_blocks = nn.ModuleList() + self.up_upscalers = nn.ModuleList() + self.up_repeat_mappers = nn.ModuleList() + for i in reversed(range(len(c_hidden))): + if i > 0: + self.up_upscalers.append(nn.Sequential( + LayerNorm2d_op(operations)(c_hidden[i], elementwise_affine=False, eps=1e-6), + UpDownBlock2d(c_hidden[i], c_hidden[i - 1], mode='up', enabled=switch_level[i - 1], dtype=dtype, device=device, operations=operations) + )) + else: + self.up_upscalers.append(nn.Identity()) + up_block = nn.ModuleList() + for j in range(blocks[1][::-1][i]): + for k, block_type in enumerate(level_config[i]): + c_skip = c_hidden[i] if i < len(c_hidden) - 1 and j == k == 0 else 0 + block = get_block(block_type, c_hidden[i], nhead[i], c_skip=c_skip, dropout=dropout[i], + self_attn=self_attn[i]) + up_block.append(block) + self.up_blocks.append(up_block) + if block_repeat is not None: + block_repeat_mappers = nn.ModuleList() + for _ in range(block_repeat[1][::-1][i] - 1): + block_repeat_mappers.append(operations.Conv2d(c_hidden[i], c_hidden[i], kernel_size=1, dtype=dtype, device=device)) + self.up_repeat_mappers.append(block_repeat_mappers) + + # OUTPUT + self.clf = nn.Sequential( + LayerNorm2d_op(operations)(c_hidden[0], elementwise_affine=False, eps=1e-6, dtype=dtype, device=device), + operations.Conv2d(c_hidden[0], c_out * (patch_size ** 2), kernel_size=1, dtype=dtype, device=device), + nn.PixelShuffle(patch_size), + ) + + # --- WEIGHT INIT --- + # self.apply(self._init_weights) # General init + # nn.init.normal_(self.clip_txt_mapper.weight, std=0.02) # conditionings + # nn.init.normal_(self.clip_txt_pooled_mapper.weight, std=0.02) # conditionings + # nn.init.normal_(self.clip_img_mapper.weight, std=0.02) # conditionings + # torch.nn.init.xavier_uniform_(self.embedding[1].weight, 0.02) # inputs + # nn.init.constant_(self.clf[1].weight, 0) # outputs + # + # # blocks + # for level_block in self.down_blocks + self.up_blocks: + # for block in level_block: + # if isinstance(block, ResBlock) or isinstance(block, FeedForwardBlock): + # block.channelwise[-1].weight.data *= np.sqrt(1 / sum(blocks[0])) + # elif isinstance(block, TimestepBlock): + # for layer in block.modules(): + # if isinstance(layer, nn.Linear): + # nn.init.constant_(layer.weight, 0) + # + # def _init_weights(self, m): + # if isinstance(m, (nn.Conv2d, nn.Linear)): + # torch.nn.init.xavier_uniform_(m.weight) + # if m.bias is not None: + # nn.init.constant_(m.bias, 0) + + def gen_r_embedding(self, r, max_positions=10000): + r = r * max_positions + half_dim = self.c_r // 2 + emb = math.log(max_positions) / (half_dim - 1) + emb = torch.arange(half_dim, device=r.device).float().mul(-emb).exp() + emb = r[:, None] * emb[None, :] + emb = torch.cat([emb.sin(), emb.cos()], dim=1) + if self.c_r % 2 == 1: # zero pad + emb = nn.functional.pad(emb, (0, 1), mode='constant') + return emb + + def gen_c_embeddings(self, clip_txt, clip_txt_pooled, clip_img): + clip_txt = self.clip_txt_mapper(clip_txt) + if len(clip_txt_pooled.shape) == 2: + clip_txt_pooled = clip_txt_pooled.unsqueeze(1) + if len(clip_img.shape) == 2: + clip_img = clip_img.unsqueeze(1) + clip_txt_pool = self.clip_txt_pooled_mapper(clip_txt_pooled).view(clip_txt_pooled.size(0), clip_txt_pooled.size(1) * self.c_clip_seq, -1) + clip_img = self.clip_img_mapper(clip_img).view(clip_img.size(0), clip_img.size(1) * self.c_clip_seq, -1) + clip = torch.cat([clip_txt, clip_txt_pool, clip_img], dim=1) + clip = self.clip_norm(clip) + return clip + + def _down_encode(self, x, r_embed, clip, cnet=None): + level_outputs = [] + block_group = zip(self.down_blocks, self.down_downscalers, self.down_repeat_mappers) + for down_block, downscaler, repmap in block_group: + x = downscaler(x) + for i in range(len(repmap) + 1): + for block in down_block: + if isinstance(block, ResBlock) or ( + hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module, + ResBlock)): + if cnet is not None: + next_cnet = cnet.pop() + if next_cnet is not None: + x = x + nn.functional.interpolate(next_cnet, size=x.shape[-2:], mode='bilinear', + align_corners=True).to(x.dtype) + x = block(x) + elif isinstance(block, AttnBlock) or ( + hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module, + AttnBlock)): + x = block(x, clip) + elif isinstance(block, TimestepBlock) or ( + hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module, + TimestepBlock)): + x = block(x, r_embed) + else: + x = block(x) + if i < len(repmap): + x = repmap[i](x) + level_outputs.insert(0, x) + return level_outputs + + def _up_decode(self, level_outputs, r_embed, clip, cnet=None): + x = level_outputs[0] + block_group = zip(self.up_blocks, self.up_upscalers, self.up_repeat_mappers) + for i, (up_block, upscaler, repmap) in enumerate(block_group): + for j in range(len(repmap) + 1): + for k, block in enumerate(up_block): + if isinstance(block, ResBlock) or ( + hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module, + ResBlock)): + skip = level_outputs[i] if k == 0 and i > 0 else None + if skip is not None and (x.size(-1) != skip.size(-1) or x.size(-2) != skip.size(-2)): + x = torch.nn.functional.interpolate(x, skip.shape[-2:], mode='bilinear', + align_corners=True) + if cnet is not None: + next_cnet = cnet.pop() + if next_cnet is not None: + x = x + nn.functional.interpolate(next_cnet, size=x.shape[-2:], mode='bilinear', + align_corners=True).to(x.dtype) + x = block(x, skip) + elif isinstance(block, AttnBlock) or ( + hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module, + AttnBlock)): + x = block(x, clip) + elif isinstance(block, TimestepBlock) or ( + hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module, + TimestepBlock)): + x = block(x, r_embed) + else: + x = block(x) + if j < len(repmap): + x = repmap[j](x) + x = upscaler(x) + return x + + def forward(self, x, r, clip_text, clip_text_pooled, clip_img, control=None, **kwargs): + # Process the conditioning embeddings + r_embed = self.gen_r_embedding(r).to(dtype=x.dtype) + for c in self.t_conds: + t_cond = kwargs.get(c, torch.zeros_like(r)) + r_embed = torch.cat([r_embed, self.gen_r_embedding(t_cond).to(dtype=x.dtype)], dim=1) + clip = self.gen_c_embeddings(clip_text, clip_text_pooled, clip_img) + + if control is not None: + cnet = control.get("input") + else: + cnet = None + + # Model Blocks + x = self.embedding(x) + level_outputs = self._down_encode(x, r_embed, clip, cnet) + x = self._up_decode(level_outputs, r_embed, clip, cnet) + return self.clf(x) + + def update_weights_ema(self, src_model, beta=0.999): + for self_params, src_params in zip(self.parameters(), src_model.parameters()): + self_params.data = self_params.data * beta + src_params.data.clone().to(self_params.device) * (1 - beta) + for self_buffers, src_buffers in zip(self.buffers(), src_model.buffers()): + self_buffers.data = self_buffers.data * beta + src_buffers.data.clone().to(self_buffers.device) * (1 - beta) diff --git a/comfy/ldm/cascade/stage_c_coder.py b/comfy/ldm/cascade/stage_c_coder.py new file mode 100644 index 00000000..0cb7c49f --- /dev/null +++ b/comfy/ldm/cascade/stage_c_coder.py @@ -0,0 +1,95 @@ +""" + This file is part of ComfyUI. + Copyright (C) 2024 Stability AI + + This program is free software: you can redistribute it and/or modify + it under the terms of the GNU General Public License as published by + the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + This program is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU General Public License for more details. + + You should have received a copy of the GNU General Public License + along with this program. If not, see . +""" +import torch +import torchvision +from torch import nn + + +# EfficientNet +class EfficientNetEncoder(nn.Module): + def __init__(self, c_latent=16): + super().__init__() + self.backbone = torchvision.models.efficientnet_v2_s().features.eval() + self.mapper = nn.Sequential( + nn.Conv2d(1280, c_latent, kernel_size=1, bias=False), + nn.BatchNorm2d(c_latent, affine=False), # then normalize them to have mean 0 and std 1 + ) + self.mean = nn.Parameter(torch.tensor([0.485, 0.456, 0.406])) + self.std = nn.Parameter(torch.tensor([0.229, 0.224, 0.225])) + + def forward(self, x): + x = x * 0.5 + 0.5 + x = (x - self.mean.view([3,1,1])) / self.std.view([3,1,1]) + o = self.mapper(self.backbone(x)) + return o + + +# Fast Decoder for Stage C latents. E.g. 16 x 24 x 24 -> 3 x 192 x 192 +class Previewer(nn.Module): + def __init__(self, c_in=16, c_hidden=512, c_out=3): + super().__init__() + self.blocks = nn.Sequential( + nn.Conv2d(c_in, c_hidden, kernel_size=1), # 16 channels to 512 channels + nn.GELU(), + nn.BatchNorm2d(c_hidden), + + nn.Conv2d(c_hidden, c_hidden, kernel_size=3, padding=1), + nn.GELU(), + nn.BatchNorm2d(c_hidden), + + nn.ConvTranspose2d(c_hidden, c_hidden // 2, kernel_size=2, stride=2), # 16 -> 32 + nn.GELU(), + nn.BatchNorm2d(c_hidden // 2), + + nn.Conv2d(c_hidden // 2, c_hidden // 2, kernel_size=3, padding=1), + nn.GELU(), + nn.BatchNorm2d(c_hidden // 2), + + nn.ConvTranspose2d(c_hidden // 2, c_hidden // 4, kernel_size=2, stride=2), # 32 -> 64 + nn.GELU(), + nn.BatchNorm2d(c_hidden // 4), + + nn.Conv2d(c_hidden // 4, c_hidden // 4, kernel_size=3, padding=1), + nn.GELU(), + nn.BatchNorm2d(c_hidden // 4), + + nn.ConvTranspose2d(c_hidden // 4, c_hidden // 4, kernel_size=2, stride=2), # 64 -> 128 + nn.GELU(), + nn.BatchNorm2d(c_hidden // 4), + + nn.Conv2d(c_hidden // 4, c_hidden // 4, kernel_size=3, padding=1), + nn.GELU(), + nn.BatchNorm2d(c_hidden // 4), + + nn.Conv2d(c_hidden // 4, c_out, kernel_size=1), + ) + + def forward(self, x): + return (self.blocks(x) - 0.5) * 2.0 + +class StageC_coder(nn.Module): + def __init__(self): + super().__init__() + self.previewer = Previewer() + self.encoder = EfficientNetEncoder() + + def encode(self, x): + return self.encoder(x) + + def decode(self, x): + return self.previewer(x) diff --git a/comfy/ldm/models/autoencoder.py b/comfy/ldm/models/autoencoder.py index d2f1d74a..b91ec324 100644 --- a/comfy/ldm/models/autoencoder.py +++ b/comfy/ldm/models/autoencoder.py @@ -8,6 +8,7 @@ from comfy.ldm.modules.distributions.distributions import DiagonalGaussianDistri from comfy.ldm.util import instantiate_from_config from comfy.ldm.modules.ema import LitEma +import comfy.ops class DiagonalGaussianRegularizer(torch.nn.Module): def __init__(self, sample: bool = True): @@ -161,12 +162,12 @@ class AutoencodingEngineLegacy(AutoencodingEngine): }, **kwargs, ) - self.quant_conv = torch.nn.Conv2d( + self.quant_conv = comfy.ops.disable_weight_init.Conv2d( (1 + ddconfig["double_z"]) * ddconfig["z_channels"], (1 + ddconfig["double_z"]) * embed_dim, 1, ) - self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1) + self.post_quant_conv = comfy.ops.disable_weight_init.Conv2d(embed_dim, ddconfig["z_channels"], 1) self.embed_dim = embed_dim def get_autoencoder_params(self) -> list: diff --git a/comfy/ldm/modules/attention.py b/comfy/ldm/modules/attention.py index f6845238..f116efee 100644 --- a/comfy/ldm/modules/attention.py +++ b/comfy/ldm/modules/attention.py @@ -1,12 +1,10 @@ -from inspect import isfunction import math import torch import torch.nn.functional as F from torch import nn, einsum from einops import rearrange, repeat from typing import Optional, Any -from functools import partial - +import logging from .diffusionmodules.util import checkpoint, AlphaBlender, timestep_embedding from .sub_quadratic_attention import efficient_dot_product_attention @@ -19,10 +17,11 @@ if model_management.xformers_enabled(): from comfy.cli_args import args import comfy.ops +ops = comfy.ops.disable_weight_init # CrossAttn precision handling if args.dont_upcast_attention: - print("disabling upcasting of attention") + logging.info("disabling upcasting of attention") _ATTN_PRECISION = "fp16" else: _ATTN_PRECISION = "fp32" @@ -55,7 +54,7 @@ def init_(tensor): # feedforward class GEGLU(nn.Module): - def __init__(self, dim_in, dim_out, dtype=None, device=None, operations=comfy.ops): + def __init__(self, dim_in, dim_out, dtype=None, device=None, operations=ops): super().__init__() self.proj = operations.Linear(dim_in, dim_out * 2, dtype=dtype, device=device) @@ -65,7 +64,7 @@ class GEGLU(nn.Module): class FeedForward(nn.Module): - def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0., dtype=None, device=None, operations=comfy.ops): + def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0., dtype=None, device=None, operations=ops): super().__init__() inner_dim = int(dim * mult) dim_out = default(dim_out, dim) @@ -83,16 +82,6 @@ class FeedForward(nn.Module): def forward(self, x): return self.net(x) - -def zero_module(module): - """ - Zero out the parameters of a module and return it. - """ - for p in module.parameters(): - p.detach().zero_() - return module - - def Normalize(in_channels, dtype=None, device=None): return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True, dtype=dtype, device=device) @@ -113,19 +102,25 @@ def attention_basic(q, k, v, heads, mask=None): # force cast to fp32 to avoid overflowing if _ATTN_PRECISION =="fp32": - with torch.autocast(enabled=False, device_type = 'cuda'): - q, k = q.float(), k.float() - sim = einsum('b i d, b j d -> b i j', q, k) * scale + sim = einsum('b i d, b j d -> b i j', q.float(), k.float()) * scale else: sim = einsum('b i d, b j d -> b i j', q, k) * scale del q, k if exists(mask): - mask = rearrange(mask, 'b ... -> b (...)') - max_neg_value = -torch.finfo(sim.dtype).max - mask = repeat(mask, 'b j -> (b h) () j', h=h) - sim.masked_fill_(~mask, max_neg_value) + if mask.dtype == torch.bool: + mask = rearrange(mask, 'b ... -> b (...)') #TODO: check if this bool part matches pytorch attention + max_neg_value = -torch.finfo(sim.dtype).max + mask = repeat(mask, 'b j -> (b h) () j', h=h) + sim.masked_fill_(~mask, max_neg_value) + else: + if len(mask.shape) == 2: + bs = 1 + else: + bs = mask.shape[0] + mask = mask.reshape(bs, -1, mask.shape[-2], mask.shape[-1]).expand(b, heads, -1, -1).reshape(-1, mask.shape[-2], mask.shape[-1]) + sim.add_(mask) # attention, what we cannot get enough of sim = sim.softmax(dim=-1) @@ -176,6 +171,13 @@ def attention_sub_quad(query, key, value, heads, mask=None): if query_chunk_size is None: query_chunk_size = 512 + if mask is not None: + if len(mask.shape) == 2: + bs = 1 + else: + bs = mask.shape[0] + mask = mask.reshape(bs, -1, mask.shape[-2], mask.shape[-1]).expand(b, heads, -1, -1).reshape(-1, mask.shape[-2], mask.shape[-1]) + hidden_states = efficient_dot_product_attention( query, key, @@ -185,6 +187,7 @@ def attention_sub_quad(query, key, value, heads, mask=None): kv_chunk_size_min=kv_chunk_size_min, use_checkpoint=False, upcast_attention=upcast_attention, + mask=mask, ) hidden_states = hidden_states.to(dtype) @@ -233,6 +236,13 @@ def attention_split(q, k, v, heads, mask=None): raise RuntimeError(f'Not enough memory, use lower resolution (max approx. {max_res}x{max_res}). ' f'Need: {mem_required/64/gb:0.1f}GB free, Have:{mem_free_total/gb:0.1f}GB free') + if mask is not None: + if len(mask.shape) == 2: + bs = 1 + else: + bs = mask.shape[0] + mask = mask.reshape(bs, -1, mask.shape[-2], mask.shape[-1]).expand(b, heads, -1, -1).reshape(-1, mask.shape[-2], mask.shape[-1]) + # print("steps", steps, mem_required, mem_free_total, modifier, q.element_size(), tensor_size) first_op_done = False cleared_cache = False @@ -247,6 +257,12 @@ def attention_split(q, k, v, heads, mask=None): else: s1 = einsum('b i d, b j d -> b i j', q[:, i:end], k) * scale + if mask is not None: + if len(mask.shape) == 2: + s1 += mask[i:end] + else: + s1 += mask[:, i:end] + s2 = s1.softmax(dim=-1).to(v.dtype) del s1 first_op_done = True @@ -259,12 +275,12 @@ def attention_split(q, k, v, heads, mask=None): model_management.soft_empty_cache(True) if cleared_cache == False: cleared_cache = True - print("out of memory error, emptying cache and trying again") + logging.warning("out of memory error, emptying cache and trying again") continue steps *= 2 if steps > 64: raise e - print("out of memory error, increasing steps and trying again", steps) + logging.warning("out of memory error, increasing steps and trying again {}".format(steps)) else: raise e @@ -302,11 +318,14 @@ def attention_xformers(q, k, v, heads, mask=None): (q, k, v), ) - # actually compute the attention, what we cannot get enough of - out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None) + if mask is not None: + pad = 8 - q.shape[1] % 8 + mask_out = torch.empty([q.shape[0], q.shape[1], q.shape[1] + pad], dtype=q.dtype, device=q.device) + mask_out[:, :, :mask.shape[-1]] = mask + mask = mask_out[:, :, :mask.shape[-1]] + + out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=mask) - if exists(mask): - raise NotImplementedError out = ( out.unsqueeze(0) .reshape(b, heads, -1, dim_head) @@ -331,27 +350,41 @@ def attention_pytorch(q, k, v, heads, mask=None): optimized_attention = attention_basic -optimized_attention_masked = attention_basic if model_management.xformers_enabled(): - print("Using xformers cross attention") + logging.info("Using xformers cross attention") optimized_attention = attention_xformers elif model_management.pytorch_attention_enabled(): - print("Using pytorch cross attention") + logging.info("Using pytorch cross attention") optimized_attention = attention_pytorch else: if args.use_split_cross_attention: - print("Using split optimization for cross attention") + logging.info("Using split optimization for cross attention") optimized_attention = attention_split else: - print("Using sub quadratic optimization for cross attention, if you have memory or speed issues try using: --use-split-cross-attention") + logging.info("Using sub quadratic optimization for cross attention, if you have memory or speed issues try using: --use-split-cross-attention") optimized_attention = attention_sub_quad -if model_management.pytorch_attention_enabled(): - optimized_attention_masked = attention_pytorch +optimized_attention_masked = optimized_attention + +def optimized_attention_for_device(device, mask=False, small_input=False): + if small_input: + if model_management.pytorch_attention_enabled(): + return attention_pytorch #TODO: need to confirm but this is probably slightly faster for small inputs in all cases + else: + return attention_basic + + if device == torch.device("cpu"): + return attention_sub_quad + + if mask: + return optimized_attention_masked + + return optimized_attention + class CrossAttention(nn.Module): - def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., dtype=None, device=None, operations=comfy.ops): + def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., dtype=None, device=None, operations=ops): super().__init__() inner_dim = dim_head * heads context_dim = default(context_dim, query_dim) @@ -384,7 +417,7 @@ class CrossAttention(nn.Module): class BasicTransformerBlock(nn.Module): def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True, ff_in=False, inner_dim=None, - disable_self_attn=False, disable_temporal_crossattention=False, switch_temporal_ca_to_sa=False, dtype=None, device=None, operations=comfy.ops): + disable_self_attn=False, disable_temporal_crossattention=False, switch_temporal_ca_to_sa=False, dtype=None, device=None, operations=ops): super().__init__() self.ff_in = ff_in or inner_dim is not None @@ -394,7 +427,7 @@ class BasicTransformerBlock(nn.Module): self.is_res = inner_dim == dim if self.ff_in: - self.norm_in = nn.LayerNorm(dim, dtype=dtype, device=device) + self.norm_in = operations.LayerNorm(dim, dtype=dtype, device=device) self.ff_in = FeedForward(dim, dim_out=inner_dim, dropout=dropout, glu=gated_ff, dtype=dtype, device=device, operations=operations) self.disable_self_attn = disable_self_attn @@ -414,10 +447,10 @@ class BasicTransformerBlock(nn.Module): self.attn2 = CrossAttention(query_dim=inner_dim, context_dim=context_dim_attn2, heads=n_heads, dim_head=d_head, dropout=dropout, dtype=dtype, device=device, operations=operations) # is self-attn if context is none - self.norm2 = nn.LayerNorm(inner_dim, dtype=dtype, device=device) + self.norm2 = operations.LayerNorm(inner_dim, dtype=dtype, device=device) - self.norm1 = nn.LayerNorm(inner_dim, dtype=dtype, device=device) - self.norm3 = nn.LayerNorm(inner_dim, dtype=dtype, device=device) + self.norm1 = operations.LayerNorm(inner_dim, dtype=dtype, device=device) + self.norm3 = operations.LayerNorm(inner_dim, dtype=dtype, device=device) self.checkpoint = checkpoint self.n_heads = n_heads self.d_head = d_head @@ -553,13 +586,13 @@ class SpatialTransformer(nn.Module): def __init__(self, in_channels, n_heads, d_head, depth=1, dropout=0., context_dim=None, disable_self_attn=False, use_linear=False, - use_checkpoint=True, dtype=None, device=None, operations=comfy.ops): + use_checkpoint=True, dtype=None, device=None, operations=ops): super().__init__() if exists(context_dim) and not isinstance(context_dim, list): context_dim = [context_dim] * depth self.in_channels = in_channels inner_dim = n_heads * d_head - self.norm = Normalize(in_channels, dtype=dtype, device=device) + self.norm = operations.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True, dtype=dtype, device=device) if not use_linear: self.proj_in = operations.Conv2d(in_channels, inner_dim, @@ -627,7 +660,7 @@ class SpatialVideoTransformer(SpatialTransformer): disable_self_attn=False, disable_temporal_crossattention=False, max_time_embed_period: int = 10000, - dtype=None, device=None, operations=comfy.ops + dtype=None, device=None, operations=ops ): super().__init__( in_channels, diff --git a/comfy/ldm/modules/diffusionmodules/model.py b/comfy/ldm/modules/diffusionmodules/model.py index f23417fd..fabc5c5e 100644 --- a/comfy/ldm/modules/diffusionmodules/model.py +++ b/comfy/ldm/modules/diffusionmodules/model.py @@ -5,9 +5,11 @@ import torch.nn as nn import numpy as np from einops import rearrange from typing import Optional, Any +import logging from comfy import model_management import comfy.ops +ops = comfy.ops.disable_weight_init if model_management.xformers_enabled_vae(): import xformers @@ -40,7 +42,7 @@ def nonlinearity(x): def Normalize(in_channels, num_groups=32): - return torch.nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True) + return ops.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True) class Upsample(nn.Module): @@ -48,7 +50,7 @@ class Upsample(nn.Module): super().__init__() self.with_conv = with_conv if self.with_conv: - self.conv = comfy.ops.Conv2d(in_channels, + self.conv = ops.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, @@ -78,7 +80,7 @@ class Downsample(nn.Module): self.with_conv = with_conv if self.with_conv: # no asymmetric padding in torch conv, must do it ourselves - self.conv = comfy.ops.Conv2d(in_channels, + self.conv = ops.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, @@ -105,30 +107,30 @@ class ResnetBlock(nn.Module): self.swish = torch.nn.SiLU(inplace=True) self.norm1 = Normalize(in_channels) - self.conv1 = comfy.ops.Conv2d(in_channels, + self.conv1 = ops.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1) if temb_channels > 0: - self.temb_proj = comfy.ops.Linear(temb_channels, + self.temb_proj = ops.Linear(temb_channels, out_channels) self.norm2 = Normalize(out_channels) self.dropout = torch.nn.Dropout(dropout, inplace=True) - self.conv2 = comfy.ops.Conv2d(out_channels, + self.conv2 = ops.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1) if self.in_channels != self.out_channels: if self.use_conv_shortcut: - self.conv_shortcut = comfy.ops.Conv2d(in_channels, + self.conv_shortcut = ops.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1) else: - self.nin_shortcut = comfy.ops.Conv2d(in_channels, + self.nin_shortcut = ops.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, @@ -189,7 +191,7 @@ def slice_attention(q, k, v): steps *= 2 if steps > 128: raise e - print("out of memory error, increasing steps and trying again", steps) + logging.warning("out of memory error, increasing steps and trying again {}".format(steps)) return r1 @@ -234,7 +236,7 @@ def pytorch_attention(q, k, v): out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False) out = out.transpose(2, 3).reshape(B, C, H, W) except model_management.OOM_EXCEPTION as e: - print("scaled_dot_product_attention OOMed: switched to slice attention") + logging.warning("scaled_dot_product_attention OOMed: switched to slice attention") out = slice_attention(q.view(B, -1, C), k.view(B, -1, C).transpose(1, 2), v.view(B, -1, C).transpose(1, 2)).reshape(B, C, H, W) return out @@ -245,35 +247,35 @@ class AttnBlock(nn.Module): self.in_channels = in_channels self.norm = Normalize(in_channels) - self.q = comfy.ops.Conv2d(in_channels, + self.q = ops.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) - self.k = comfy.ops.Conv2d(in_channels, + self.k = ops.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) - self.v = comfy.ops.Conv2d(in_channels, + self.v = ops.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) - self.proj_out = comfy.ops.Conv2d(in_channels, + self.proj_out = ops.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) if model_management.xformers_enabled_vae(): - print("Using xformers attention in VAE") + logging.info("Using xformers attention in VAE") self.optimized_attention = xformers_attention elif model_management.pytorch_attention_enabled(): - print("Using pytorch attention in VAE") + logging.info("Using pytorch attention in VAE") self.optimized_attention = pytorch_attention else: - print("Using split attention in VAE") + logging.info("Using split attention in VAE") self.optimized_attention = normal_attention def forward(self, x): @@ -312,14 +314,14 @@ class Model(nn.Module): # timestep embedding self.temb = nn.Module() self.temb.dense = nn.ModuleList([ - comfy.ops.Linear(self.ch, + ops.Linear(self.ch, self.temb_ch), - comfy.ops.Linear(self.temb_ch, + ops.Linear(self.temb_ch, self.temb_ch), ]) # downsampling - self.conv_in = comfy.ops.Conv2d(in_channels, + self.conv_in = ops.Conv2d(in_channels, self.ch, kernel_size=3, stride=1, @@ -388,7 +390,7 @@ class Model(nn.Module): # end self.norm_out = Normalize(block_in) - self.conv_out = comfy.ops.Conv2d(block_in, + self.conv_out = ops.Conv2d(block_in, out_ch, kernel_size=3, stride=1, @@ -461,7 +463,7 @@ class Encoder(nn.Module): self.in_channels = in_channels # downsampling - self.conv_in = comfy.ops.Conv2d(in_channels, + self.conv_in = ops.Conv2d(in_channels, self.ch, kernel_size=3, stride=1, @@ -506,7 +508,7 @@ class Encoder(nn.Module): # end self.norm_out = Normalize(block_in) - self.conv_out = comfy.ops.Conv2d(block_in, + self.conv_out = ops.Conv2d(block_in, 2*z_channels if double_z else z_channels, kernel_size=3, stride=1, @@ -541,7 +543,7 @@ class Decoder(nn.Module): def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks, attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels, resolution, z_channels, give_pre_end=False, tanh_out=False, use_linear_attn=False, - conv_out_op=comfy.ops.Conv2d, + conv_out_op=ops.Conv2d, resnet_op=ResnetBlock, attn_op=AttnBlock, **ignorekwargs): @@ -561,11 +563,11 @@ class Decoder(nn.Module): block_in = ch*ch_mult[self.num_resolutions-1] curr_res = resolution // 2**(self.num_resolutions-1) self.z_shape = (1,z_channels,curr_res,curr_res) - print("Working with z of shape {} = {} dimensions.".format( + logging.debug("Working with z of shape {} = {} dimensions.".format( self.z_shape, np.prod(self.z_shape))) # z to block_in - self.conv_in = comfy.ops.Conv2d(z_channels, + self.conv_in = ops.Conv2d(z_channels, block_in, kernel_size=3, stride=1, diff --git a/comfy/ldm/modules/diffusionmodules/openaimodel.py b/comfy/ldm/modules/diffusionmodules/openaimodel.py index 48264892..d782eff3 100644 --- a/comfy/ldm/modules/diffusionmodules/openaimodel.py +++ b/comfy/ldm/modules/diffusionmodules/openaimodel.py @@ -1,24 +1,22 @@ from abc import abstractmethod -import math -import numpy as np import torch as th import torch.nn as nn import torch.nn.functional as F from einops import rearrange -from functools import partial +import logging from .util import ( checkpoint, avg_pool_nd, zero_module, - normalization, timestep_embedding, AlphaBlender, ) from ..attention import SpatialTransformer, SpatialVideoTransformer, default from comfy.ldm.util import exists import comfy.ops +ops = comfy.ops.disable_weight_init class TimestepBlock(nn.Module): """ @@ -70,7 +68,7 @@ class Upsample(nn.Module): upsampling occurs in the inner-two dimensions. """ - def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1, dtype=None, device=None, operations=comfy.ops): + def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1, dtype=None, device=None, operations=ops): super().__init__() self.channels = channels self.out_channels = out_channels or channels @@ -106,7 +104,7 @@ class Downsample(nn.Module): downsampling occurs in the inner-two dimensions. """ - def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1, dtype=None, device=None, operations=comfy.ops): + def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1, dtype=None, device=None, operations=ops): super().__init__() self.channels = channels self.out_channels = out_channels or channels @@ -159,7 +157,7 @@ class ResBlock(TimestepBlock): skip_t_emb=False, dtype=None, device=None, - operations=comfy.ops + operations=ops ): super().__init__() self.channels = channels @@ -177,7 +175,7 @@ class ResBlock(TimestepBlock): padding = kernel_size // 2 self.in_layers = nn.Sequential( - nn.GroupNorm(32, channels, dtype=dtype, device=device), + operations.GroupNorm(32, channels, dtype=dtype, device=device), nn.SiLU(), operations.conv_nd(dims, channels, self.out_channels, kernel_size, padding=padding, dtype=dtype, device=device), ) @@ -206,12 +204,11 @@ class ResBlock(TimestepBlock): ), ) self.out_layers = nn.Sequential( - nn.GroupNorm(32, self.out_channels, dtype=dtype, device=device), + operations.GroupNorm(32, self.out_channels, dtype=dtype, device=device), nn.SiLU(), nn.Dropout(p=dropout), - zero_module( - operations.conv_nd(dims, self.out_channels, self.out_channels, kernel_size, padding=padding, dtype=dtype, device=device) - ), + operations.conv_nd(dims, self.out_channels, self.out_channels, kernel_size, padding=padding, dtype=dtype, device=device) + , ) if self.out_channels == channels: @@ -285,7 +282,7 @@ class VideoResBlock(ResBlock): down: bool = False, dtype=None, device=None, - operations=comfy.ops + operations=ops ): super().__init__( channels, @@ -363,7 +360,7 @@ def apply_control(h, control, name): try: h += ctrl except: - print("warning control could not be applied", h.shape, ctrl.shape) + logging.warning("warning control could not be applied {} {}".format(h.shape, ctrl.shape)) return h class UNetModel(nn.Module): @@ -435,12 +432,9 @@ class UNetModel(nn.Module): disable_temporal_crossattention=False, max_ddpm_temb_period=10000, device=None, - operations=comfy.ops, + operations=ops, ): super().__init__() - assert use_spatial_transformer == True, "use_spatial_transformer has to be true" - if use_spatial_transformer: - assert context_dim is not None, 'Fool!! You forgot to include the dimension of your cross-attention conditioning...' if context_dim is not None: assert use_spatial_transformer, 'Fool!! You forgot to use the spatial transformer for your cross-attention conditioning...' @@ -457,7 +451,6 @@ class UNetModel(nn.Module): if num_head_channels == -1: assert num_heads != -1, 'Either num_heads or num_head_channels has to be set' - self.image_size = image_size self.in_channels = in_channels self.model_channels = model_channels self.out_channels = out_channels @@ -492,7 +485,6 @@ class UNetModel(nn.Module): self.predict_codebook_ids = n_embed is not None self.default_num_video_frames = None - self.default_image_only_indicator = None time_embed_dim = model_channels * 4 self.time_embed = nn.Sequential( @@ -503,9 +495,9 @@ class UNetModel(nn.Module): if self.num_classes is not None: if isinstance(self.num_classes, int): - self.label_emb = nn.Embedding(num_classes, time_embed_dim) + self.label_emb = nn.Embedding(num_classes, time_embed_dim, dtype=self.dtype, device=device) elif self.num_classes == "continuous": - print("setting up linear c_adm embedding layer") + logging.debug("setting up linear c_adm embedding layer") self.label_emb = nn.Linear(1, time_embed_dim) elif self.num_classes == "sequential": assert adm_in_channels is not None @@ -582,7 +574,7 @@ class UNetModel(nn.Module): up=False, dtype=None, device=None, - operations=comfy.ops + operations=ops ): if self.use_temporal_resblocks: return VideoResBlock( @@ -716,27 +708,30 @@ class UNetModel(nn.Module): device=device, operations=operations )] - if transformer_depth_middle >= 0: - mid_block += [get_attention_layer( # always uses a self-attn - ch, num_heads, dim_head, depth=transformer_depth_middle, context_dim=context_dim, - disable_self_attn=disable_middle_self_attn, use_checkpoint=use_checkpoint - ), - get_resblock( - merge_factor=merge_factor, - merge_strategy=merge_strategy, - video_kernel_size=video_kernel_size, - ch=ch, - time_embed_dim=time_embed_dim, - dropout=dropout, - out_channels=None, - dims=dims, - use_checkpoint=use_checkpoint, - use_scale_shift_norm=use_scale_shift_norm, - dtype=self.dtype, - device=device, - operations=operations - )] - self.middle_block = TimestepEmbedSequential(*mid_block) + + self.middle_block = None + if transformer_depth_middle >= -1: + if transformer_depth_middle >= 0: + mid_block += [get_attention_layer( # always uses a self-attn + ch, num_heads, dim_head, depth=transformer_depth_middle, context_dim=context_dim, + disable_self_attn=disable_middle_self_attn, use_checkpoint=use_checkpoint + ), + get_resblock( + merge_factor=merge_factor, + merge_strategy=merge_strategy, + video_kernel_size=video_kernel_size, + ch=ch, + time_embed_dim=time_embed_dim, + dropout=dropout, + out_channels=None, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + dtype=self.dtype, + device=device, + operations=operations + )] + self.middle_block = TimestepEmbedSequential(*mid_block) self._feature_size += ch self.output_blocks = nn.ModuleList([]) @@ -810,13 +805,13 @@ class UNetModel(nn.Module): self._feature_size += ch self.out = nn.Sequential( - nn.GroupNorm(32, ch, dtype=self.dtype, device=device), + operations.GroupNorm(32, ch, dtype=self.dtype, device=device), nn.SiLU(), zero_module(operations.conv_nd(dims, model_channels, out_channels, 3, padding=1, dtype=self.dtype, device=device)), ) if self.predict_codebook_ids: self.id_predictor = nn.Sequential( - nn.GroupNorm(32, ch, dtype=self.dtype, device=device), + operations.GroupNorm(32, ch, dtype=self.dtype, device=device), operations.conv_nd(dims, model_channels, n_embed, 1, dtype=self.dtype, device=device), #nn.LogSoftmax(dim=1) # change to cross_entropy and produce non-normalized logits ) @@ -835,21 +830,21 @@ class UNetModel(nn.Module): transformer_patches = transformer_options.get("patches", {}) num_video_frames = kwargs.get("num_video_frames", self.default_num_video_frames) - image_only_indicator = kwargs.get("image_only_indicator", self.default_image_only_indicator) + image_only_indicator = kwargs.get("image_only_indicator", None) time_context = kwargs.get("time_context", None) assert (y is not None) == ( self.num_classes is not None ), "must specify y if and only if the model is class-conditional" hs = [] - t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False).to(self.dtype) + t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False).to(x.dtype) emb = self.time_embed(t_emb) if self.num_classes is not None: assert y.shape[0] == x.shape[0] emb = emb + self.label_emb(y) - h = x.type(self.dtype) + h = x for id, module in enumerate(self.input_blocks): transformer_options["block"] = ("input", id) h = forward_timestep_embed(module, h, emb, context, transformer_options, time_context=time_context, num_video_frames=num_video_frames, image_only_indicator=image_only_indicator) @@ -866,7 +861,8 @@ class UNetModel(nn.Module): h = p(h, transformer_options) transformer_options["block"] = ("middle", 0) - h = forward_timestep_embed(self.middle_block, h, emb, context, transformer_options, time_context=time_context, num_video_frames=num_video_frames, image_only_indicator=image_only_indicator) + if self.middle_block is not None: + h = forward_timestep_embed(self.middle_block, h, emb, context, transformer_options, time_context=time_context, num_video_frames=num_video_frames, image_only_indicator=image_only_indicator) h = apply_control(h, control, 'middle') diff --git a/comfy/ldm/modules/diffusionmodules/upscaling.py b/comfy/ldm/modules/diffusionmodules/upscaling.py index 709a7f52..f5ac7c2f 100644 --- a/comfy/ldm/modules/diffusionmodules/upscaling.py +++ b/comfy/ldm/modules/diffusionmodules/upscaling.py @@ -41,10 +41,14 @@ class AbstractLowScaleModel(nn.Module): self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod))) self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod - 1))) - def q_sample(self, x_start, t, noise=None): - noise = default(noise, lambda: torch.randn_like(x_start)) - return (extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start + - extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise) + def q_sample(self, x_start, t, noise=None, seed=None): + if noise is None: + if seed is None: + noise = torch.randn_like(x_start) + else: + noise = torch.randn(x_start.size(), dtype=x_start.dtype, layout=x_start.layout, generator=torch.manual_seed(seed)).to(x_start.device) + return (extract_into_tensor(self.sqrt_alphas_cumprod.to(x_start.device), t, x_start.shape) * x_start + + extract_into_tensor(self.sqrt_one_minus_alphas_cumprod.to(x_start.device), t, x_start.shape) * noise) def forward(self, x): return x, None @@ -69,12 +73,12 @@ class ImageConcatWithNoiseAugmentation(AbstractLowScaleModel): super().__init__(noise_schedule_config=noise_schedule_config) self.max_noise_level = max_noise_level - def forward(self, x, noise_level=None): + def forward(self, x, noise_level=None, seed=None): if noise_level is None: noise_level = torch.randint(0, self.max_noise_level, (x.shape[0],), device=x.device).long() else: assert isinstance(noise_level, torch.Tensor) - z = self.q_sample(x, noise_level) + z = self.q_sample(x, noise_level, seed=seed) return z, noise_level diff --git a/comfy/ldm/modules/diffusionmodules/util.py b/comfy/ldm/modules/diffusionmodules/util.py index 704bbe57..ce14ad5e 100644 --- a/comfy/ldm/modules/diffusionmodules/util.py +++ b/comfy/ldm/modules/diffusionmodules/util.py @@ -16,7 +16,6 @@ import numpy as np from einops import repeat, rearrange from comfy.ldm.util import instantiate_from_config -import comfy.ops class AlphaBlender(nn.Module): strategies = ["learned", "fixed", "learned_with_images"] @@ -47,23 +46,25 @@ class AlphaBlender(nn.Module): else: raise ValueError(f"unknown merge strategy {self.merge_strategy}") - def get_alpha(self, image_only_indicator: torch.Tensor) -> torch.Tensor: + def get_alpha(self, image_only_indicator: torch.Tensor, device) -> torch.Tensor: # skip_time_mix = rearrange(repeat(skip_time_mix, 'b -> (b t) () () ()', t=t), '(b t) 1 ... -> b 1 t ...', t=t) if self.merge_strategy == "fixed": # make shape compatible # alpha = repeat(self.mix_factor, '1 -> b () t () ()', t=t, b=bs) - alpha = self.mix_factor + alpha = self.mix_factor.to(device) elif self.merge_strategy == "learned": - alpha = torch.sigmoid(self.mix_factor) + alpha = torch.sigmoid(self.mix_factor.to(device)) # make shape compatible # alpha = repeat(alpha, '1 -> s () ()', s = t * bs) elif self.merge_strategy == "learned_with_images": - assert image_only_indicator is not None, "need image_only_indicator ..." - alpha = torch.where( - image_only_indicator.bool(), - torch.ones(1, 1, device=image_only_indicator.device), - rearrange(torch.sigmoid(self.mix_factor), "... -> ... 1"), - ) + if image_only_indicator is None: + alpha = rearrange(torch.sigmoid(self.mix_factor.to(device)), "... -> ... 1") + else: + alpha = torch.where( + image_only_indicator.bool(), + torch.ones(1, 1, device=image_only_indicator.device), + rearrange(torch.sigmoid(self.mix_factor.to(image_only_indicator.device)), "... -> ... 1"), + ) alpha = rearrange(alpha, self.rearrange_pattern) # make shape compatible # alpha = repeat(alpha, '1 -> s () ()', s = t * bs) @@ -77,7 +78,7 @@ class AlphaBlender(nn.Module): x_temporal, image_only_indicator=None, ) -> torch.Tensor: - alpha = self.get_alpha(image_only_indicator) + alpha = self.get_alpha(image_only_indicator, x_spatial.device) x = ( alpha.to(x_spatial.dtype) * x_spatial + (1.0 - alpha).to(x_spatial.dtype) * x_temporal @@ -99,7 +100,7 @@ def make_beta_schedule(schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, alphas = torch.cos(alphas).pow(2) alphas = alphas / alphas[0] betas = 1 - alphas[1:] / alphas[:-1] - betas = np.clip(betas, a_min=0, a_max=0.999) + betas = torch.clamp(betas, min=0, max=0.999) elif schedule == "squaredcos_cap_v2": # used for karlo prior # return early @@ -114,7 +115,7 @@ def make_beta_schedule(schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) ** 0.5 else: raise ValueError(f"schedule '{schedule}' unknown.") - return betas.numpy() + return betas def make_ddim_timesteps(ddim_discr_method, num_ddim_timesteps, num_ddpm_timesteps, verbose=True): @@ -273,46 +274,6 @@ def mean_flat(tensor): return tensor.mean(dim=list(range(1, len(tensor.shape)))) -def normalization(channels, dtype=None): - """ - Make a standard normalization layer. - :param channels: number of input channels. - :return: an nn.Module for normalization. - """ - return GroupNorm32(32, channels, dtype=dtype) - - -# PyTorch 1.7 has SiLU, but we support PyTorch 1.5. -class SiLU(nn.Module): - def forward(self, x): - return x * torch.sigmoid(x) - - -class GroupNorm32(nn.GroupNorm): - def forward(self, x): - return super().forward(x.float()).type(x.dtype) - - -def conv_nd(dims, *args, **kwargs): - """ - Create a 1D, 2D, or 3D convolution module. - """ - if dims == 1: - return nn.Conv1d(*args, **kwargs) - elif dims == 2: - return comfy.ops.Conv2d(*args, **kwargs) - elif dims == 3: - return nn.Conv3d(*args, **kwargs) - raise ValueError(f"unsupported dimensions: {dims}") - - -def linear(*args, **kwargs): - """ - Create a linear module. - """ - return comfy.ops.Linear(*args, **kwargs) - - def avg_pool_nd(dims, *args, **kwargs): """ Create a 1D, 2D, or 3D average pooling module. diff --git a/comfy/ldm/modules/encoders/noise_aug_modules.py b/comfy/ldm/modules/encoders/noise_aug_modules.py index b59bf204..a5d86603 100644 --- a/comfy/ldm/modules/encoders/noise_aug_modules.py +++ b/comfy/ldm/modules/encoders/noise_aug_modules.py @@ -15,21 +15,21 @@ class CLIPEmbeddingNoiseAugmentation(ImageConcatWithNoiseAugmentation): def scale(self, x): # re-normalize to centered mean and unit variance - x = (x - self.data_mean) * 1. / self.data_std + x = (x - self.data_mean.to(x.device)) * 1. / self.data_std.to(x.device) return x def unscale(self, x): # back to original data stats - x = (x * self.data_std) + self.data_mean + x = (x * self.data_std.to(x.device)) + self.data_mean.to(x.device) return x - def forward(self, x, noise_level=None): + def forward(self, x, noise_level=None, seed=None): if noise_level is None: noise_level = torch.randint(0, self.max_noise_level, (x.shape[0],), device=x.device).long() else: assert isinstance(noise_level, torch.Tensor) x = self.scale(x) - z = self.q_sample(x, noise_level) + z = self.q_sample(x, noise_level, seed=seed) z = self.unscale(z) noise_level = self.time_embed(noise_level) return z, noise_level diff --git a/comfy/ldm/modules/sub_quadratic_attention.py b/comfy/ldm/modules/sub_quadratic_attention.py index 8e8e8054..1bc4138c 100644 --- a/comfy/ldm/modules/sub_quadratic_attention.py +++ b/comfy/ldm/modules/sub_quadratic_attention.py @@ -14,6 +14,7 @@ import torch from torch import Tensor from torch.utils.checkpoint import checkpoint import math +import logging try: from typing import Optional, NamedTuple, List, Protocol @@ -61,6 +62,7 @@ def _summarize_chunk( value: Tensor, scale: float, upcast_attention: bool, + mask, ) -> AttnChunk: if upcast_attention: with torch.autocast(enabled=False, device_type = 'cuda'): @@ -84,6 +86,8 @@ def _summarize_chunk( max_score, _ = torch.max(attn_weights, -1, keepdim=True) max_score = max_score.detach() attn_weights -= max_score + if mask is not None: + attn_weights += mask torch.exp(attn_weights, out=attn_weights) exp_weights = attn_weights.to(value.dtype) exp_values = torch.bmm(exp_weights, value) @@ -96,11 +100,12 @@ def _query_chunk_attention( value: Tensor, summarize_chunk: SummarizeChunk, kv_chunk_size: int, + mask, ) -> Tensor: batch_x_heads, k_channels_per_head, k_tokens = key_t.shape _, _, v_channels_per_head = value.shape - def chunk_scanner(chunk_idx: int) -> AttnChunk: + def chunk_scanner(chunk_idx: int, mask) -> AttnChunk: key_chunk = dynamic_slice( key_t, (0, 0, chunk_idx), @@ -111,10 +116,13 @@ def _query_chunk_attention( (0, chunk_idx, 0), (batch_x_heads, kv_chunk_size, v_channels_per_head) ) - return summarize_chunk(query, key_chunk, value_chunk) + if mask is not None: + mask = mask[:,:,chunk_idx:chunk_idx + kv_chunk_size] + + return summarize_chunk(query, key_chunk, value_chunk, mask=mask) chunks: List[AttnChunk] = [ - chunk_scanner(chunk) for chunk in torch.arange(0, k_tokens, kv_chunk_size) + chunk_scanner(chunk, mask) for chunk in torch.arange(0, k_tokens, kv_chunk_size) ] acc_chunk = AttnChunk(*map(torch.stack, zip(*chunks))) chunk_values, chunk_weights, chunk_max = acc_chunk @@ -135,6 +143,7 @@ def _get_attention_scores_no_kv_chunking( value: Tensor, scale: float, upcast_attention: bool, + mask, ) -> Tensor: if upcast_attention: with torch.autocast(enabled=False, device_type = 'cuda'): @@ -156,11 +165,13 @@ def _get_attention_scores_no_kv_chunking( beta=0, ) + if mask is not None: + attn_scores += mask try: attn_probs = attn_scores.softmax(dim=-1) del attn_scores except model_management.OOM_EXCEPTION: - print("ran out of memory while running softmax in _get_attention_scores_no_kv_chunking, trying slower in place softmax instead") + logging.warning("ran out of memory while running softmax in _get_attention_scores_no_kv_chunking, trying slower in place softmax instead") attn_scores -= attn_scores.max(dim=-1, keepdim=True).values torch.exp(attn_scores, out=attn_scores) summed = torch.sum(attn_scores, dim=-1, keepdim=True) @@ -183,6 +194,7 @@ def efficient_dot_product_attention( kv_chunk_size_min: Optional[int] = None, use_checkpoint=True, upcast_attention=False, + mask = None, ): """Computes efficient dot-product attention given query, transposed key, and value. This is efficient version of attention presented in @@ -209,13 +221,22 @@ def efficient_dot_product_attention( if kv_chunk_size_min is not None: kv_chunk_size = max(kv_chunk_size, kv_chunk_size_min) + if mask is not None and len(mask.shape) == 2: + mask = mask.unsqueeze(0) + def get_query_chunk(chunk_idx: int) -> Tensor: return dynamic_slice( query, (0, chunk_idx, 0), (batch_x_heads, min(query_chunk_size, q_tokens), q_channels_per_head) ) - + + def get_mask_chunk(chunk_idx: int) -> Tensor: + if mask is None: + return None + chunk = min(query_chunk_size, q_tokens) + return mask[:,chunk_idx:chunk_idx + chunk] + summarize_chunk: SummarizeChunk = partial(_summarize_chunk, scale=scale, upcast_attention=upcast_attention) summarize_chunk: SummarizeChunk = partial(checkpoint, summarize_chunk) if use_checkpoint else summarize_chunk compute_query_chunk_attn: ComputeQueryChunkAttn = partial( @@ -237,6 +258,7 @@ def efficient_dot_product_attention( query=query, key_t=key_t, value=value, + mask=mask, ) # TODO: maybe we should use torch.empty_like(query) to allocate storage in-advance, @@ -246,6 +268,7 @@ def efficient_dot_product_attention( query=get_query_chunk(i * query_chunk_size), key_t=key_t, value=value, + mask=get_mask_chunk(i * query_chunk_size) ) for i in range(math.ceil(q_tokens / query_chunk_size)) ], dim=1) return res diff --git a/comfy/ldm/modules/temporal_ae.py b/comfy/ldm/modules/temporal_ae.py index 11ae049f..2992aeaf 100644 --- a/comfy/ldm/modules/temporal_ae.py +++ b/comfy/ldm/modules/temporal_ae.py @@ -5,6 +5,7 @@ import torch from einops import rearrange, repeat import comfy.ops +ops = comfy.ops.disable_weight_init from .diffusionmodules.model import ( AttnBlock, @@ -81,14 +82,14 @@ class VideoResBlock(ResnetBlock): x = self.time_stack(x, temb) - alpha = self.get_alpha(bs=b // timesteps) + alpha = self.get_alpha(bs=b // timesteps).to(x.device) x = alpha * x + (1.0 - alpha) * x_mix x = rearrange(x, "b c t h w -> (b t) c h w") return x -class AE3DConv(torch.nn.Conv2d): +class AE3DConv(ops.Conv2d): def __init__(self, in_channels, out_channels, video_kernel_size=3, *args, **kwargs): super().__init__(in_channels, out_channels, *args, **kwargs) if isinstance(video_kernel_size, Iterable): @@ -96,7 +97,7 @@ class AE3DConv(torch.nn.Conv2d): else: padding = int(video_kernel_size // 2) - self.time_mix_conv = torch.nn.Conv3d( + self.time_mix_conv = ops.Conv3d( in_channels=out_channels, out_channels=out_channels, kernel_size=video_kernel_size, @@ -130,9 +131,9 @@ class AttnVideoBlock(AttnBlock): time_embed_dim = self.in_channels * 4 self.video_time_embed = torch.nn.Sequential( - comfy.ops.Linear(self.in_channels, time_embed_dim), + ops.Linear(self.in_channels, time_embed_dim), torch.nn.SiLU(), - comfy.ops.Linear(time_embed_dim, self.in_channels), + ops.Linear(time_embed_dim, self.in_channels), ) self.merge_strategy = merge_strategy @@ -166,7 +167,7 @@ class AttnVideoBlock(AttnBlock): emb = emb[:, None, :] x_mix = x_mix + emb - alpha = self.get_alpha() + alpha = self.get_alpha().to(x.device) x_mix = self.time_mix_block(x_mix, timesteps=timesteps) x = alpha * x + (1.0 - alpha) * x_mix # alpha merge diff --git a/comfy/lora.py b/comfy/lora.py index 29c59d89..096285bb 100644 --- a/comfy/lora.py +++ b/comfy/lora.py @@ -1,4 +1,5 @@ import comfy.utils +import logging LORA_CLIP_MAP = { "mlp.fc1": "mlp_fc1", @@ -20,6 +21,12 @@ def load_lora(lora, to_load): alpha = lora[alpha_name].item() loaded_keys.add(alpha_name) + dora_scale_name = "{}.dora_scale".format(x) + dora_scale = None + if dora_scale_name in lora.keys(): + dora_scale = lora[dora_scale_name] + loaded_keys.add(dora_scale_name) + regular_lora = "{}.lora_up.weight".format(x) diffusers_lora = "{}_lora.up.weight".format(x) transformers_lora = "{}.lora_linear_layer.up.weight".format(x) @@ -43,7 +50,7 @@ def load_lora(lora, to_load): if mid_name is not None and mid_name in lora.keys(): mid = lora[mid_name] loaded_keys.add(mid_name) - patch_dict[to_load[x]] = (lora[A_name], lora[B_name], alpha, mid) + patch_dict[to_load[x]] = ("lora", (lora[A_name], lora[B_name], alpha, mid, dora_scale)) loaded_keys.add(A_name) loaded_keys.add(B_name) @@ -64,7 +71,7 @@ def load_lora(lora, to_load): loaded_keys.add(hada_t1_name) loaded_keys.add(hada_t2_name) - patch_dict[to_load[x]] = (lora[hada_w1_a_name], lora[hada_w1_b_name], alpha, lora[hada_w2_a_name], lora[hada_w2_b_name], hada_t1, hada_t2) + patch_dict[to_load[x]] = ("loha", (lora[hada_w1_a_name], lora[hada_w1_b_name], alpha, lora[hada_w2_a_name], lora[hada_w2_b_name], hada_t1, hada_t2, dora_scale)) loaded_keys.add(hada_w1_a_name) loaded_keys.add(hada_w1_b_name) loaded_keys.add(hada_w2_a_name) @@ -116,8 +123,19 @@ def load_lora(lora, to_load): loaded_keys.add(lokr_t2_name) if (lokr_w1 is not None) or (lokr_w2 is not None) or (lokr_w1_a is not None) or (lokr_w2_a is not None): - patch_dict[to_load[x]] = (lokr_w1, lokr_w2, alpha, lokr_w1_a, lokr_w1_b, lokr_w2_a, lokr_w2_b, lokr_t2) - + patch_dict[to_load[x]] = ("lokr", (lokr_w1, lokr_w2, alpha, lokr_w1_a, lokr_w1_b, lokr_w2_a, lokr_w2_b, lokr_t2, dora_scale)) + + #glora + a1_name = "{}.a1.weight".format(x) + a2_name = "{}.a2.weight".format(x) + b1_name = "{}.b1.weight".format(x) + b2_name = "{}.b2.weight".format(x) + if a1_name in lora: + patch_dict[to_load[x]] = ("glora", (lora[a1_name], lora[a2_name], lora[b1_name], lora[b2_name], alpha, dora_scale)) + loaded_keys.add(a1_name) + loaded_keys.add(a2_name) + loaded_keys.add(b1_name) + loaded_keys.add(b2_name) w_norm_name = "{}.w_norm".format(x) b_norm_name = "{}.b_norm".format(x) @@ -126,26 +144,26 @@ def load_lora(lora, to_load): if w_norm is not None: loaded_keys.add(w_norm_name) - patch_dict[to_load[x]] = (w_norm,) + patch_dict[to_load[x]] = ("diff", (w_norm,)) if b_norm is not None: loaded_keys.add(b_norm_name) - patch_dict["{}.bias".format(to_load[x][:-len(".weight")])] = (b_norm,) + patch_dict["{}.bias".format(to_load[x][:-len(".weight")])] = ("diff", (b_norm,)) diff_name = "{}.diff".format(x) diff_weight = lora.get(diff_name, None) if diff_weight is not None: - patch_dict[to_load[x]] = (diff_weight,) + patch_dict[to_load[x]] = ("diff", (diff_weight,)) loaded_keys.add(diff_name) diff_bias_name = "{}.diff_b".format(x) diff_bias = lora.get(diff_bias_name, None) if diff_bias is not None: - patch_dict["{}.bias".format(to_load[x][:-len(".weight")])] = (diff_bias,) + patch_dict["{}.bias".format(to_load[x][:-len(".weight")])] = ("diff", (diff_bias,)) loaded_keys.add(diff_bias_name) for x in lora.keys(): if x not in loaded_keys: - print("lora key not loaded", x) + logging.warning("lora key not loaded: {}".format(x)) return patch_dict def model_lora_keys_clip(model, key_map={}): @@ -186,6 +204,15 @@ def model_lora_keys_clip(model, key_map={}): key_map[lora_key] = k lora_key = "text_encoder.text_model.encoder.layers.{}.{}".format(b, c) #diffusers lora key_map[lora_key] = k + lora_key = "lora_prior_te_text_model_encoder_layers_{}_{}".format(b, LORA_CLIP_MAP[c]) #cascade lora: TODO put lora key prefix in the model config + key_map[lora_key] = k + + + k = "clip_g.transformer.text_projection.weight" + if k in sdk: + key_map["lora_prior_te_text_projection"] = k #cascade lora? + # key_map["text_encoder.text_projection"] = k #TODO: check if other lora have the text_projection too + # key_map["lora_te_text_projection"] = k return key_map @@ -196,6 +223,7 @@ def model_lora_keys_unet(model, key_map={}): if k.startswith("diffusion_model.") and k.endswith(".weight"): key_lora = k[len("diffusion_model."):-len(".weight")].replace(".", "_") key_map["lora_unet_{}".format(key_lora)] = k + key_map["lora_prior_unet_{}".format(key_lora)] = k #cascade lora: TODO put lora key prefix in the model config diffusers_keys = comfy.utils.unet_to_diffusers(model.model_config.unet_config) for k in diffusers_keys: diff --git a/comfy/model_base.py b/comfy/model_base.py index 786c9cf4..bc019de5 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -1,9 +1,13 @@ import torch -from comfy.ldm.modules.diffusionmodules.openaimodel import UNetModel +import logging +from comfy.ldm.modules.diffusionmodules.openaimodel import UNetModel, Timestep +from comfy.ldm.cascade.stage_c import StageC +from comfy.ldm.cascade.stage_b import StageB from comfy.ldm.modules.encoders.noise_aug_modules import CLIPEmbeddingNoiseAugmentation -from comfy.ldm.modules.diffusionmodules.openaimodel import Timestep +from comfy.ldm.modules.diffusionmodules.upscaling import ImageConcatWithNoiseAugmentation import comfy.model_management import comfy.conds +import comfy.ops from enum import Enum from . import utils @@ -11,9 +15,11 @@ class ModelType(Enum): EPS = 1 V_PREDICTION = 2 V_PREDICTION_EDM = 3 + STABLE_CASCADE = 4 + EDM = 5 -from comfy.model_sampling import EPS, V_PREDICTION, ModelSamplingDiscrete, ModelSamplingContinuousEDM +from comfy.model_sampling import EPS, V_PREDICTION, EDM, ModelSamplingDiscrete, ModelSamplingContinuousEDM, StableCascadeSampling def model_sampling(model_config, model_type): @@ -26,6 +32,12 @@ def model_sampling(model_config, model_type): elif model_type == ModelType.V_PREDICTION_EDM: c = V_PREDICTION s = ModelSamplingContinuousEDM + elif model_type == ModelType.STABLE_CASCADE: + c = EPS + s = StableCascadeSampling + elif model_type == ModelType.EDM: + c = EDM + s = ModelSamplingContinuousEDM class ModelSampling(s, c): pass @@ -34,15 +46,20 @@ def model_sampling(model_config, model_type): class BaseModel(torch.nn.Module): - def __init__(self, model_config, model_type=ModelType.EPS, device=None): + def __init__(self, model_config, model_type=ModelType.EPS, device=None, unet_model=UNetModel): super().__init__() unet_config = model_config.unet_config self.latent_format = model_config.latent_format self.model_config = model_config + self.manual_cast_dtype = model_config.manual_cast_dtype if not unet_config.get("disable_unet_model_creation", False): - self.diffusion_model = UNetModel(**unet_config, device=device) + if self.manual_cast_dtype is not None: + operations = comfy.ops.manual_cast + else: + operations = comfy.ops.disable_weight_init + self.diffusion_model = unet_model(**unet_config, device=device, operations=operations) self.model_type = model_type self.model_sampling = model_sampling(model_config, model_type) @@ -50,8 +67,8 @@ class BaseModel(torch.nn.Module): if self.adm_channels is None: self.adm_channels = 0 self.inpaint_model = False - print("model_type", model_type.name) - print("adm", self.adm_channels) + logging.info("model_type {}".format(model_type.name)) + logging.debug("adm {}".format(self.adm_channels)) def apply_model(self, x, t, c_concat=None, c_crossattn=None, control=None, transformer_options={}, **kwargs): sigma = t @@ -61,15 +78,21 @@ class BaseModel(torch.nn.Module): context = c_crossattn dtype = self.get_dtype() + + if self.manual_cast_dtype is not None: + dtype = self.manual_cast_dtype + xc = xc.to(dtype) t = self.model_sampling.timestep(t).float() context = context.to(dtype) extra_conds = {} for o in kwargs: extra = kwargs[o] - if hasattr(extra, "to"): - extra = extra.to(dtype) + if hasattr(extra, "dtype"): + if extra.dtype != torch.int and extra.dtype != torch.long: + extra = extra.to(dtype) extra_conds[o] = extra + model_output = self.diffusion_model(xc, t, context=context, control=control, transformer_options=transformer_options, **extra_conds).float() return self.model_sampling.calculate_denoised(sigma, model_output, x) @@ -87,11 +110,29 @@ class BaseModel(torch.nn.Module): if self.inpaint_model: concat_keys = ("mask", "masked_image") cond_concat = [] - denoise_mask = kwargs.get("denoise_mask", None) - latent_image = kwargs.get("latent_image", None) + denoise_mask = kwargs.get("concat_mask", kwargs.get("denoise_mask", None)) + concat_latent_image = kwargs.get("concat_latent_image", None) + if concat_latent_image is None: + concat_latent_image = kwargs.get("latent_image", None) + else: + concat_latent_image = self.process_latent_in(concat_latent_image) + noise = kwargs.get("noise", None) device = kwargs["device"] + if concat_latent_image.shape[1:] != noise.shape[1:]: + concat_latent_image = utils.common_upscale(concat_latent_image, noise.shape[-1], noise.shape[-2], "bilinear", "center") + + concat_latent_image = utils.resize_to_batch_size(concat_latent_image, noise.shape[0]) + + if len(denoise_mask.shape) == len(noise.shape): + denoise_mask = denoise_mask[:,:1] + + denoise_mask = denoise_mask.reshape((-1, 1, denoise_mask.shape[-2], denoise_mask.shape[-1])) + if denoise_mask.shape[-2:] != noise.shape[-2:]: + denoise_mask = utils.common_upscale(denoise_mask, noise.shape[-1], noise.shape[-2], "bilinear", "center") + denoise_mask = utils.resize_to_batch_size(denoise_mask.round(), noise.shape[0]) + def blank_inpaint_image_like(latent_image): blank_image = torch.ones_like(latent_image) # these are the values for "zero" in pixel space translated to latent space @@ -104,9 +145,9 @@ class BaseModel(torch.nn.Module): for ck in concat_keys: if denoise_mask is not None: if ck == "mask": - cond_concat.append(denoise_mask[:,:1].to(device)) + cond_concat.append(denoise_mask.to(device)) elif ck == "masked_image": - cond_concat.append(latent_image.to(device)) #NOTE: the latent_image should be masked by the mask in pixel space + cond_concat.append(concat_latent_image.to(device)) #NOTE: the latent_image should be masked by the mask in pixel space else: if ck == "mask": cond_concat.append(torch.ones_like(noise)[:,:1]) @@ -114,9 +155,23 @@ class BaseModel(torch.nn.Module): cond_concat.append(blank_inpaint_image_like(noise)) data = torch.cat(cond_concat, dim=1) out['c_concat'] = comfy.conds.CONDNoiseShape(data) + adm = self.encode_adm(**kwargs) if adm is not None: out['y'] = comfy.conds.CONDRegular(adm) + + cross_attn = kwargs.get("cross_attn", None) + if cross_attn is not None: + out['c_crossattn'] = comfy.conds.CONDCrossAttn(cross_attn) + + cross_attn_cnet = kwargs.get("cross_attn_controlnet", None) + if cross_attn_cnet is not None: + out['crossattn_controlnet'] = comfy.conds.CONDCrossAttn(cross_attn_cnet) + + c_concat = kwargs.get("noise_concat", None) + if c_concat is not None: + out['c_concat'] = comfy.conds.CONDNoiseShape(data) + return out def load_model_weights(self, sd, unet_prefix=""): @@ -129,10 +184,10 @@ class BaseModel(torch.nn.Module): to_load = self.model_config.process_unet_state_dict(to_load) m, u = self.diffusion_model.load_state_dict(to_load, strict=False) if len(m) > 0: - print("unet missing:", m) + logging.warning("unet missing: {}".format(m)) if len(u) > 0: - print("unet unexpected:", u) + logging.warning("unet unexpected: {}".format(u)) del to_load return self @@ -142,39 +197,47 @@ class BaseModel(torch.nn.Module): def process_latent_out(self, latent): return self.latent_format.process_out(latent) - def state_dict_for_saving(self, clip_state_dict, vae_state_dict): - clip_state_dict = self.model_config.process_clip_state_dict_for_saving(clip_state_dict) - unet_sd = self.diffusion_model.state_dict() - unet_state_dict = {} - for k in unet_sd: - unet_state_dict[k] = comfy.model_management.resolve_lowvram_weight(unet_sd[k], self.diffusion_model, k) + def state_dict_for_saving(self, clip_state_dict=None, vae_state_dict=None, clip_vision_state_dict=None): + extra_sds = [] + if clip_state_dict is not None: + extra_sds.append(self.model_config.process_clip_state_dict_for_saving(clip_state_dict)) + if vae_state_dict is not None: + extra_sds.append(self.model_config.process_vae_state_dict_for_saving(vae_state_dict)) + if clip_vision_state_dict is not None: + extra_sds.append(self.model_config.process_clip_vision_state_dict_for_saving(clip_vision_state_dict)) + unet_state_dict = self.diffusion_model.state_dict() unet_state_dict = self.model_config.process_unet_state_dict_for_saving(unet_state_dict) - vae_state_dict = self.model_config.process_vae_state_dict_for_saving(vae_state_dict) + if self.get_dtype() == torch.float16: - clip_state_dict = utils.convert_sd_to(clip_state_dict, torch.float16) - vae_state_dict = utils.convert_sd_to(vae_state_dict, torch.float16) + extra_sds = map(lambda sd: utils.convert_sd_to(sd, torch.float16), extra_sds) if self.model_type == ModelType.V_PREDICTION: unet_state_dict["v_pred"] = torch.tensor([]) - return {**unet_state_dict, **vae_state_dict, **clip_state_dict} + for sd in extra_sds: + unet_state_dict.update(sd) + + return unet_state_dict def set_inpaint(self): self.inpaint_model = True def memory_required(self, input_shape): if comfy.model_management.xformers_enabled() or comfy.model_management.pytorch_attention_flash_attention(): + dtype = self.get_dtype() + if self.manual_cast_dtype is not None: + dtype = self.manual_cast_dtype #TODO: this needs to be tweaked area = input_shape[0] * input_shape[2] * input_shape[3] - return (area * comfy.model_management.dtype_size(self.get_dtype()) / 50) * (1024 * 1024) + return (area * comfy.model_management.dtype_size(dtype) / 50) * (1024 * 1024) else: #TODO: this formula might be too aggressive since I tweaked the sub-quad and split algorithms to use less memory. area = input_shape[0] * input_shape[2] * input_shape[3] return (((area * 0.6) / 0.9) + 1024) * (1024 * 1024) -def unclip_adm(unclip_conditioning, device, noise_augmentor, noise_augment_merge=0.0): +def unclip_adm(unclip_conditioning, device, noise_augmentor, noise_augment_merge=0.0, seed=None): adm_inputs = [] weights = [] noise_aug = [] @@ -183,7 +246,7 @@ def unclip_adm(unclip_conditioning, device, noise_augmentor, noise_augment_merge weight = unclip_cond["strength"] noise_augment = unclip_cond["noise_augmentation"] noise_level = round((noise_augmentor.max_noise_level - 1) * noise_augment) - c_adm, noise_level_emb = noise_augmentor(adm_cond.to(device), noise_level=torch.tensor([noise_level], device=device)) + c_adm, noise_level_emb = noise_augmentor(adm_cond.to(device), noise_level=torch.tensor([noise_level], device=device), seed=seed) adm_out = torch.cat((c_adm, noise_level_emb), 1) * weight weights.append(weight) noise_aug.append(noise_augment) @@ -209,11 +272,11 @@ class SD21UNCLIP(BaseModel): if unclip_conditioning is None: return torch.zeros((1, self.adm_channels)) else: - return unclip_adm(unclip_conditioning, device, self.noise_augmentor, kwargs.get("unclip_noise_augment_merge", 0.05)) + return unclip_adm(unclip_conditioning, device, self.noise_augmentor, kwargs.get("unclip_noise_augment_merge", 0.05), kwargs.get("seed", 0) - 10) def sdxl_pooled(args, noise_augmentor): if "unclip_conditioning" in args: - return unclip_adm(args.get("unclip_conditioning", None), args["device"], noise_augmentor)[:,:1280] + return unclip_adm(args.get("unclip_conditioning", None), args["device"], noise_augmentor, seed=args.get("seed", 0) - 10)[:,:1280] else: return args["pooled_output"] @@ -303,13 +366,157 @@ class SVD_img2vid(BaseModel): if latent_image.shape[1:] != noise.shape[1:]: latent_image = utils.common_upscale(latent_image, noise.shape[-1], noise.shape[-2], "bilinear", "center") - latent_image = utils.repeat_to_batch_size(latent_image, noise.shape[0]) + latent_image = utils.resize_to_batch_size(latent_image, noise.shape[0]) out['c_concat'] = comfy.conds.CONDNoiseShape(latent_image) + cross_attn = kwargs.get("cross_attn", None) + if cross_attn is not None: + out['c_crossattn'] = comfy.conds.CONDCrossAttn(cross_attn) + if "time_conditioning" in kwargs: out["time_context"] = comfy.conds.CONDCrossAttn(kwargs["time_conditioning"]) - out['image_only_indicator'] = comfy.conds.CONDConstant(torch.zeros((1,), device=device)) out['num_video_frames'] = comfy.conds.CONDConstant(noise.shape[0]) return out + +class SV3D_u(SVD_img2vid): + def encode_adm(self, **kwargs): + augmentation = kwargs.get("augmentation_level", 0) + + out = [] + out.append(self.embedder(torch.flatten(torch.Tensor([augmentation])))) + + flat = torch.flatten(torch.cat(out)).unsqueeze(dim=0) + return flat + +class SV3D_p(SVD_img2vid): + def __init__(self, model_config, model_type=ModelType.V_PREDICTION_EDM, device=None): + super().__init__(model_config, model_type, device=device) + self.embedder_512 = Timestep(512) + + def encode_adm(self, **kwargs): + augmentation = kwargs.get("augmentation_level", 0) + elevation = kwargs.get("elevation", 0) #elevation and azimuth are in degrees here + azimuth = kwargs.get("azimuth", 0) + noise = kwargs.get("noise", None) + + out = [] + out.append(self.embedder(torch.flatten(torch.Tensor([augmentation])))) + out.append(self.embedder_512(torch.deg2rad(torch.fmod(torch.flatten(90 - torch.Tensor([elevation])), 360.0)))) + out.append(self.embedder_512(torch.deg2rad(torch.fmod(torch.flatten(torch.Tensor([azimuth])), 360.0)))) + + out = list(map(lambda a: utils.resize_to_batch_size(a, noise.shape[0]), out)) + return torch.cat(out, dim=1) + + +class Stable_Zero123(BaseModel): + def __init__(self, model_config, model_type=ModelType.EPS, device=None, cc_projection_weight=None, cc_projection_bias=None): + super().__init__(model_config, model_type, device=device) + self.cc_projection = comfy.ops.manual_cast.Linear(cc_projection_weight.shape[1], cc_projection_weight.shape[0], dtype=self.get_dtype(), device=device) + self.cc_projection.weight.copy_(cc_projection_weight) + self.cc_projection.bias.copy_(cc_projection_bias) + + def extra_conds(self, **kwargs): + out = {} + + latent_image = kwargs.get("concat_latent_image", None) + noise = kwargs.get("noise", None) + + if latent_image is None: + latent_image = torch.zeros_like(noise) + + if latent_image.shape[1:] != noise.shape[1:]: + latent_image = utils.common_upscale(latent_image, noise.shape[-1], noise.shape[-2], "bilinear", "center") + + latent_image = utils.resize_to_batch_size(latent_image, noise.shape[0]) + + out['c_concat'] = comfy.conds.CONDNoiseShape(latent_image) + + cross_attn = kwargs.get("cross_attn", None) + if cross_attn is not None: + if cross_attn.shape[-1] != 768: + cross_attn = self.cc_projection(cross_attn) + out['c_crossattn'] = comfy.conds.CONDCrossAttn(cross_attn) + return out + +class SD_X4Upscaler(BaseModel): + def __init__(self, model_config, model_type=ModelType.V_PREDICTION, device=None): + super().__init__(model_config, model_type, device=device) + self.noise_augmentor = ImageConcatWithNoiseAugmentation(noise_schedule_config={"linear_start": 0.0001, "linear_end": 0.02}, max_noise_level=350) + + def extra_conds(self, **kwargs): + out = {} + + image = kwargs.get("concat_image", None) + noise = kwargs.get("noise", None) + noise_augment = kwargs.get("noise_augmentation", 0.0) + device = kwargs["device"] + seed = kwargs["seed"] - 10 + + noise_level = round((self.noise_augmentor.max_noise_level) * noise_augment) + + if image is None: + image = torch.zeros_like(noise)[:,:3] + + if image.shape[1:] != noise.shape[1:]: + image = utils.common_upscale(image.to(device), noise.shape[-1], noise.shape[-2], "bilinear", "center") + + noise_level = torch.tensor([noise_level], device=device) + if noise_augment > 0: + image, noise_level = self.noise_augmentor(image.to(device), noise_level=noise_level, seed=seed) + + image = utils.resize_to_batch_size(image, noise.shape[0]) + + out['c_concat'] = comfy.conds.CONDNoiseShape(image) + out['y'] = comfy.conds.CONDRegular(noise_level) + return out + +class StableCascade_C(BaseModel): + def __init__(self, model_config, model_type=ModelType.STABLE_CASCADE, device=None): + super().__init__(model_config, model_type, device=device, unet_model=StageC) + self.diffusion_model.eval().requires_grad_(False) + + def extra_conds(self, **kwargs): + out = {} + clip_text_pooled = kwargs["pooled_output"] + if clip_text_pooled is not None: + out['clip_text_pooled'] = comfy.conds.CONDRegular(clip_text_pooled) + + if "unclip_conditioning" in kwargs: + embeds = [] + for unclip_cond in kwargs["unclip_conditioning"]: + weight = unclip_cond["strength"] + embeds.append(unclip_cond["clip_vision_output"].image_embeds.unsqueeze(0) * weight) + clip_img = torch.cat(embeds, dim=1) + else: + clip_img = torch.zeros((1, 1, 768)) + out["clip_img"] = comfy.conds.CONDRegular(clip_img) + out["sca"] = comfy.conds.CONDRegular(torch.zeros((1,))) + out["crp"] = comfy.conds.CONDRegular(torch.zeros((1,))) + + cross_attn = kwargs.get("cross_attn", None) + if cross_attn is not None: + out['clip_text'] = comfy.conds.CONDCrossAttn(cross_attn) + return out + + +class StableCascade_B(BaseModel): + def __init__(self, model_config, model_type=ModelType.STABLE_CASCADE, device=None): + super().__init__(model_config, model_type, device=device, unet_model=StageB) + self.diffusion_model.eval().requires_grad_(False) + + def extra_conds(self, **kwargs): + out = {} + noise = kwargs.get("noise", None) + + clip_text_pooled = kwargs["pooled_output"] + if clip_text_pooled is not None: + out['clip'] = comfy.conds.CONDRegular(clip_text_pooled) + + #size of prior doesn't really matter if zeros because it gets resized but I still want it to get batched + prior = kwargs.get("stable_cascade_prior", torch.zeros((1, 16, (noise.shape[2] * 4) // 42, (noise.shape[3] * 4) // 42), dtype=noise.dtype, layout=noise.layout, device=noise.device)) + + out["effnet"] = comfy.conds.CONDRegular(prior) + out["sca"] = comfy.conds.CONDRegular(torch.zeros((1,))) + return out diff --git a/comfy/model_detection.py b/comfy/model_detection.py index c682c3e1..b7c3be30 100644 --- a/comfy/model_detection.py +++ b/comfy/model_detection.py @@ -1,5 +1,6 @@ import comfy.supported_models import comfy.supported_models_base +import logging def count_blocks(state_dict_keys, prefix_string): count = 0 @@ -28,13 +29,41 @@ def calculate_transformer_depth(prefix, state_dict_keys, state_dict): return last_transformer_depth, context_dim, use_linear_in_transformer, time_stack return None -def detect_unet_config(state_dict, key_prefix, dtype): +def detect_unet_config(state_dict, key_prefix): state_dict_keys = list(state_dict.keys()) + if '{}clf.1.weight'.format(key_prefix) in state_dict_keys: #stable cascade + unet_config = {} + text_mapper_name = '{}clip_txt_mapper.weight'.format(key_prefix) + if text_mapper_name in state_dict_keys: + unet_config['stable_cascade_stage'] = 'c' + w = state_dict[text_mapper_name] + if w.shape[0] == 1536: #stage c lite + unet_config['c_cond'] = 1536 + unet_config['c_hidden'] = [1536, 1536] + unet_config['nhead'] = [24, 24] + unet_config['blocks'] = [[4, 12], [12, 4]] + elif w.shape[0] == 2048: #stage c full + unet_config['c_cond'] = 2048 + elif '{}clip_mapper.weight'.format(key_prefix) in state_dict_keys: + unet_config['stable_cascade_stage'] = 'b' + w = state_dict['{}down_blocks.1.0.channelwise.0.weight'.format(key_prefix)] + if w.shape[-1] == 640: + unet_config['c_hidden'] = [320, 640, 1280, 1280] + unet_config['nhead'] = [-1, -1, 20, 20] + unet_config['blocks'] = [[2, 6, 28, 6], [6, 28, 6, 2]] + unet_config['block_repeat'] = [[1, 1, 1, 1], [3, 3, 2, 2]] + elif w.shape[-1] == 576: #stage b lite + unet_config['c_hidden'] = [320, 576, 1152, 1152] + unet_config['nhead'] = [-1, 9, 18, 18] + unet_config['blocks'] = [[2, 4, 14, 4], [4, 14, 4, 2]] + unet_config['block_repeat'] = [[1, 1, 1, 1], [2, 2, 2, 2]] + + return unet_config + unet_config = { "use_checkpoint": False, "image_size": 32, - "out_channels": 4, "use_spatial_transformer": True, "legacy": False } @@ -46,10 +75,15 @@ def detect_unet_config(state_dict, key_prefix, dtype): else: unet_config["adm_in_channels"] = None - unet_config["dtype"] = dtype model_channels = state_dict['{}input_blocks.0.0.weight'.format(key_prefix)].shape[0] in_channels = state_dict['{}input_blocks.0.0.weight'.format(key_prefix)].shape[1] + out_key = '{}out.2.weight'.format(key_prefix) + if out_key in state_dict: + out_channels = state_dict[out_key].shape[0] + else: + out_channels = 4 + num_res_blocks = [] channel_mult = [] attention_resolutions = [] @@ -118,10 +152,13 @@ def detect_unet_config(state_dict, key_prefix, dtype): channel_mult.append(last_channel_mult) if "{}middle_block.1.proj_in.weight".format(key_prefix) in state_dict_keys: transformer_depth_middle = count_blocks(state_dict_keys, '{}middle_block.1.transformer_blocks.'.format(key_prefix) + '{}') - else: + elif "{}middle_block.0.in_layers.0.weight".format(key_prefix) in state_dict_keys: transformer_depth_middle = -1 + else: + transformer_depth_middle = -2 unet_config["in_channels"] = in_channels + unet_config["out_channels"] = out_channels unet_config["model_channels"] = model_channels unet_config["num_res_blocks"] = num_res_blocks unet_config["transformer_depth"] = transformer_depth @@ -150,11 +187,11 @@ def model_config_from_unet_config(unet_config): if model_config.matches(unet_config): return model_config(unet_config) - print("no match", unet_config) + logging.error("no match {}".format(unet_config)) return None -def model_config_from_unet(state_dict, unet_key_prefix, dtype, use_base_if_no_match=False): - unet_config = detect_unet_config(state_dict, unet_key_prefix, dtype) +def model_config_from_unet(state_dict, unet_key_prefix, use_base_if_no_match=False): + unet_config = detect_unet_config(state_dict, unet_key_prefix) model_config = model_config_from_unet_config(unet_config) if model_config is None and use_base_if_no_match: return comfy.supported_models_base.BASE(unet_config) @@ -200,7 +237,7 @@ def convert_config(unet_config): return new_config -def unet_config_from_diffusers_unet(state_dict, dtype): +def unet_config_from_diffusers_unet(state_dict, dtype=None): match = {} transformer_depth = [] @@ -208,6 +245,7 @@ def unet_config_from_diffusers_unet(state_dict, dtype): down_blocks = count_blocks(state_dict, "down_blocks.{}") for i in range(down_blocks): attn_blocks = count_blocks(state_dict, "down_blocks.{}.attentions.".format(i) + '{}') + res_blocks = count_blocks(state_dict, "down_blocks.{}.resnets.".format(i) + '{}') for ab in range(attn_blocks): transformer_count = count_blocks(state_dict, "down_blocks.{}.attentions.{}.transformer_blocks.".format(i, ab) + '{}') transformer_depth.append(transformer_count) @@ -216,8 +254,8 @@ def unet_config_from_diffusers_unet(state_dict, dtype): attn_res *= 2 if attn_blocks == 0: - transformer_depth.append(0) - transformer_depth.append(0) + for i in range(res_blocks): + transformer_depth.append(0) match["transformer_depth"] = transformer_depth @@ -289,7 +327,25 @@ def unet_config_from_diffusers_unet(state_dict, dtype): 'channel_mult': [1, 2, 4], 'transformer_depth_middle': -1, 'use_linear_in_transformer': True, 'context_dim': 2048, 'num_head_channels': 64, 'use_temporal_attention': False, 'use_temporal_resblock': False} - supported_models = [SDXL, SDXL_refiner, SD21, SD15, SD21_uncliph, SD21_unclipl, SDXL_mid_cnet, SDXL_small_cnet, SDXL_diffusers_inpaint, SSD_1B] + Segmind_Vega = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False, + 'num_classes': 'sequential', 'adm_in_channels': 2816, 'dtype': dtype, 'in_channels': 4, 'model_channels': 320, + 'num_res_blocks': [2, 2, 2], 'transformer_depth': [0, 0, 1, 1, 2, 2], 'transformer_depth_output': [0, 0, 0, 1, 1, 1, 2, 2, 2], + 'channel_mult': [1, 2, 4], 'transformer_depth_middle': -1, 'use_linear_in_transformer': True, 'context_dim': 2048, 'num_head_channels': 64, + 'use_temporal_attention': False, 'use_temporal_resblock': False} + + KOALA_700M = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False, + 'num_classes': 'sequential', 'adm_in_channels': 2816, 'dtype': dtype, 'in_channels': 4, 'model_channels': 320, + 'num_res_blocks': [1, 1, 1], 'transformer_depth': [0, 2, 5], 'transformer_depth_output': [0, 0, 2, 2, 5, 5], + 'channel_mult': [1, 2, 4], 'transformer_depth_middle': -2, 'use_linear_in_transformer': True, 'context_dim': 2048, 'num_head_channels': 64, + 'use_temporal_attention': False, 'use_temporal_resblock': False} + + KOALA_1B = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False, + 'num_classes': 'sequential', 'adm_in_channels': 2816, 'dtype': dtype, 'in_channels': 4, 'model_channels': 320, + 'num_res_blocks': [1, 1, 1], 'transformer_depth': [0, 2, 6], 'transformer_depth_output': [0, 0, 2, 2, 6, 6], + 'channel_mult': [1, 2, 4], 'transformer_depth_middle': 6, 'use_linear_in_transformer': True, 'context_dim': 2048, 'num_head_channels': 64, + 'use_temporal_attention': False, 'use_temporal_resblock': False} + + supported_models = [SDXL, SDXL_refiner, SD21, SD15, SD21_uncliph, SD21_unclipl, SDXL_mid_cnet, SDXL_small_cnet, SDXL_diffusers_inpaint, SSD_1B, Segmind_Vega, KOALA_700M, KOALA_1B] for unet_config in supported_models: matches = True @@ -301,8 +357,8 @@ def unet_config_from_diffusers_unet(state_dict, dtype): return convert_config(unet_config) return None -def model_config_from_diffusers_unet(state_dict, dtype): - unet_config = unet_config_from_diffusers_unet(state_dict, dtype) +def model_config_from_diffusers_unet(state_dict): + unet_config = unet_config_from_diffusers_unet(state_dict) if unet_config is not None: return model_config_from_unet_config(unet_config) return None diff --git a/comfy/model_management.py b/comfy/model_management.py index d4acd895..715ca2ee 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -1,4 +1,5 @@ import psutil +import logging from enum import Enum from comfy.cli_args import args import comfy.utils @@ -28,6 +29,10 @@ total_vram = 0 lowvram_available = True xpu_available = False +if args.deterministic: + logging.info("Using deterministic algorithms for pytorch") + torch.use_deterministic_algorithms(True, warn_only=True) + directml_enabled = False if args.directml is not None: import torch_directml @@ -37,7 +42,7 @@ if args.directml is not None: directml_device = torch_directml.device() else: directml_device = torch_directml.device(device_index) - print("Using directml with device:", torch_directml.device_name(device_index)) + logging.info("Using directml with device: {}".format(torch_directml.device_name(device_index))) # torch_directml.disable_tiled_resources(True) lowvram_available = False #TODO: need to find a way to get free memory in directml before this can be enabled by default. @@ -113,10 +118,10 @@ def get_total_memory(dev=None, torch_total_too=False): total_vram = get_total_memory(get_torch_device()) / (1024 * 1024) total_ram = psutil.virtual_memory().total / (1024 * 1024) -print("Total VRAM {:0.0f} MB, total RAM {:0.0f} MB".format(total_vram, total_ram)) +logging.info("Total VRAM {:0.0f} MB, total RAM {:0.0f} MB".format(total_vram, total_ram)) if not args.normalvram and not args.cpu: if lowvram_available and total_vram <= 4096: - print("Trying to enable lowvram mode because your GPU seems to have 4GB or less. If you don't want this use: --normalvram") + logging.warning("Trying to enable lowvram mode because your GPU seems to have 4GB or less. If you don't want this use: --normalvram") set_vram_to = VRAMState.LOW_VRAM try: @@ -139,12 +144,10 @@ else: pass try: XFORMERS_VERSION = xformers.version.__version__ - print("xformers version:", XFORMERS_VERSION) + logging.info("xformers version: {}".format(XFORMERS_VERSION)) if XFORMERS_VERSION.startswith("0.0.18"): - print() - print("WARNING: This version of xformers has a major bug where you will get black images when generating high resolution images.") - print("Please downgrade or upgrade xformers to a different version.") - print() + logging.warning("\nWARNING: This version of xformers has a major bug where you will get black images when generating high resolution images.") + logging.warning("Please downgrade or upgrade xformers to a different version.\n") XFORMERS_ENABLED_VAE = False except: pass @@ -171,7 +174,7 @@ try: if int(torch_version[0]) >= 2: if ENABLE_PYTORCH_ATTENTION == False and args.use_split_cross_attention == False and args.use_quad_cross_attention == False: ENABLE_PYTORCH_ATTENTION = True - if torch.cuda.is_bf16_supported(): + if torch.cuda.is_bf16_supported() and torch.cuda.get_device_properties(torch.cuda.current_device()).major >= 8: VAE_DTYPE = torch.bfloat16 if is_intel_xpu(): if args.use_split_cross_attention == False and args.use_quad_cross_attention == False: @@ -182,6 +185,9 @@ except: if is_intel_xpu(): VAE_DTYPE = torch.bfloat16 +if args.cpu_vae: + VAE_DTYPE = torch.float32 + if args.fp16_vae: VAE_DTYPE = torch.float16 elif args.bf16_vae: @@ -206,23 +212,16 @@ elif args.highvram or args.gpu_only: FORCE_FP32 = False FORCE_FP16 = False if args.force_fp32: - print("Forcing FP32, if this improves things please report it.") + logging.info("Forcing FP32, if this improves things please report it.") FORCE_FP32 = True if args.force_fp16: - print("Forcing FP16.") + logging.info("Forcing FP16.") FORCE_FP16 = True if lowvram_available: - try: - import accelerate - if set_vram_to in (VRAMState.LOW_VRAM, VRAMState.NO_VRAM): - vram_state = set_vram_to - except Exception as e: - import traceback - print(traceback.format_exc()) - print("ERROR: LOW VRAM MODE NEEDS accelerate.") - lowvram_available = False + if set_vram_to in (VRAMState.LOW_VRAM, VRAMState.NO_VRAM): + vram_state = set_vram_to if cpu_state != CPUState.GPU: @@ -231,12 +230,12 @@ if cpu_state != CPUState.GPU: if cpu_state == CPUState.MPS: vram_state = VRAMState.SHARED -print(f"Set vram state to: {vram_state.name}") +logging.info(f"Set vram state to: {vram_state.name}") DISABLE_SMART_MEMORY = args.disable_smart_memory if DISABLE_SMART_MEMORY: - print("Disabling smart memory management") + logging.info("Disabling smart memory management") def get_torch_device_name(device): if hasattr(device, 'type'): @@ -254,19 +253,27 @@ def get_torch_device_name(device): return "CUDA {}: {}".format(device, torch.cuda.get_device_name(device)) try: - print("Device:", get_torch_device_name(get_torch_device())) + logging.info("Device: {}".format(get_torch_device_name(get_torch_device()))) except: - print("Could not pick default device.") + logging.warning("Could not pick default device.") -print("VAE dtype:", VAE_DTYPE) +logging.info("VAE dtype: {}".format(VAE_DTYPE)) current_loaded_models = [] +def module_size(module): + module_mem = 0 + sd = module.state_dict() + for k in sd: + t = sd[k] + module_mem += t.nelement() * t.element_size() + return module_mem + class LoadedModel: def __init__(self, model): self.model = model - self.model_accelerated = False self.device = model.load_device + self.weights_loaded = False def model_memory(self): return self.model.model_size() @@ -278,38 +285,33 @@ class LoadedModel: return self.model_memory() def model_load(self, lowvram_model_memory=0): - patch_model_to = None - if lowvram_model_memory == 0: - patch_model_to = self.device + patch_model_to = self.device self.model.model_patches_to(self.device) self.model.model_patches_to(self.model.model_dtype()) + load_weights = not self.weights_loaded + try: - self.real_model = self.model.patch_model(device_to=patch_model_to) #TODO: do something with loras and offloading to CPU + if lowvram_model_memory > 0 and load_weights: + self.real_model = self.model.patch_model_lowvram(device_to=patch_model_to, lowvram_model_memory=lowvram_model_memory) + else: + self.real_model = self.model.patch_model(device_to=patch_model_to, patch_weights=load_weights) except Exception as e: self.model.unpatch_model(self.model.offload_device) self.model_unload() raise e - if lowvram_model_memory > 0: - print("loading in lowvram mode", lowvram_model_memory/(1024 * 1024)) - device_map = accelerate.infer_auto_device_map(self.real_model, max_memory={0: "{}MiB".format(lowvram_model_memory // (1024 * 1024)), "cpu": "16GiB"}) - accelerate.dispatch_model(self.real_model, device_map=device_map, main_device=self.device) - self.model_accelerated = True - if is_intel_xpu() and not args.disable_ipex_optimize: self.real_model = torch.xpu.optimize(self.real_model.eval(), inplace=True, auto_kernel_selection=True, graph_mode=True) + self.weights_loaded = True return self.real_model - def model_unload(self): - if self.model_accelerated: - accelerate.hooks.remove_hook_from_submodules(self.real_model) - self.model_accelerated = False - - self.model.unpatch_model(self.model.offload_device) + def model_unload(self, unpatch_weights=True): + self.model.unpatch_model(self.model.offload_device, unpatch_weights=unpatch_weights) self.model.model_patches_to(self.model.offload_device) + self.weights_loaded = self.weights_loaded and not unpatch_weights def __eq__(self, other): return self.model is other.model @@ -317,31 +319,57 @@ class LoadedModel: def minimum_inference_memory(): return (1024 * 1024 * 1024) -def unload_model_clones(model): +def unload_model_clones(model, unload_weights_only=True, force_unload=True): to_unload = [] for i in range(len(current_loaded_models)): if model.is_clone(current_loaded_models[i].model): to_unload = [i] + to_unload + if len(to_unload) == 0: + return None + + same_weights = 0 + for i in to_unload: + if model.clone_has_same_weights(current_loaded_models[i].model): + same_weights += 1 + + if same_weights == len(to_unload): + unload_weight = False + else: + unload_weight = True + + if not force_unload: + if unload_weights_only and unload_weight == False: + return None + for i in to_unload: - print("unload clone", i) - current_loaded_models.pop(i).model_unload() + logging.debug("unload clone {} {}".format(i, unload_weight)) + current_loaded_models.pop(i).model_unload(unpatch_weights=unload_weight) + + return unload_weight def free_memory(memory_required, device, keep_loaded=[]): - unloaded_model = False + unloaded_model = [] + can_unload = [] + for i in range(len(current_loaded_models) -1, -1, -1): - if not DISABLE_SMART_MEMORY: - if get_free_memory(device) > memory_required: - break shift_model = current_loaded_models[i] if shift_model.device == device: if shift_model not in keep_loaded: - m = current_loaded_models.pop(i) - m.model_unload() - del m - unloaded_model = True + can_unload.append((sys.getrefcount(shift_model.model), shift_model.model_memory(), i)) - if unloaded_model: + for x in sorted(can_unload): + i = x[-1] + if not DISABLE_SMART_MEMORY: + if get_free_memory(device) > memory_required: + break + current_loaded_models[i].model_unload() + unloaded_model.append(i) + + for i in sorted(unloaded_model, reverse=True): + current_loaded_models.pop(i) + + if len(unloaded_model) > 0: soft_empty_cache() else: if vram_state != VRAMState.HIGH_VRAM: @@ -366,7 +394,7 @@ def load_models_gpu(models, memory_required=0): models_already_loaded.append(loaded_model) else: if hasattr(x, "model"): - print(f"Requested to load {x.model.__class__.__name__}") + logging.info(f"Requested to load {x.model.__class__.__name__}") models_to_load.append(loaded_model) if len(models_to_load) == 0: @@ -376,17 +404,22 @@ def load_models_gpu(models, memory_required=0): free_memory(extra_mem, d, models_already_loaded) return - print(f"Loading {len(models_to_load)} new model{'s' if len(models_to_load) > 1 else ''}") + logging.info(f"Loading {len(models_to_load)} new model{'s' if len(models_to_load) > 1 else ''}") total_memory_required = {} for loaded_model in models_to_load: - unload_model_clones(loaded_model.model) + unload_model_clones(loaded_model.model, unload_weights_only=True, force_unload=False) #unload clones where the weights are different total_memory_required[loaded_model.device] = total_memory_required.get(loaded_model.device, 0) + loaded_model.model_memory_required(loaded_model.device) for device in total_memory_required: if device != torch.device("cpu"): free_memory(total_memory_required[device] * 1.3 + extra_mem, device, models_already_loaded) + for loaded_model in models_to_load: + weights_unloaded = unload_model_clones(loaded_model.model, unload_weights_only=False, force_unload=False) #unload the rest of the clones where the weights can stay loaded + if weights_unloaded is not None: + loaded_model.weights_loaded = not weights_unloaded + for loaded_model in models_to_load: model = loaded_model.model torch_dev = model.load_device @@ -398,14 +431,14 @@ def load_models_gpu(models, memory_required=0): if lowvram_available and (vram_set_state == VRAMState.LOW_VRAM or vram_set_state == VRAMState.NORMAL_VRAM): model_size = loaded_model.model_memory_required(torch_dev) current_free_mem = get_free_memory(torch_dev) - lowvram_model_memory = int(max(256 * (1024 * 1024), (current_free_mem - 1024 * (1024 * 1024)) / 1.3 )) + lowvram_model_memory = int(max(64 * (1024 * 1024), (current_free_mem - 1024 * (1024 * 1024)) / 1.3 )) if model_size > (current_free_mem - inference_memory): #only switch to lowvram if really necessary vram_set_state = VRAMState.LOW_VRAM else: lowvram_model_memory = 0 if vram_set_state == VRAMState.NO_VRAM: - lowvram_model_memory = 256 * 1024 * 1024 + lowvram_model_memory = 64 * 1024 * 1024 cur_loaded_model = loaded_model.model_load(lowvram_model_memory) current_loaded_models.insert(0, loaded_model) @@ -430,6 +463,13 @@ def dtype_size(dtype): dtype_size = 4 if dtype == torch.float16 or dtype == torch.bfloat16: dtype_size = 2 + elif dtype == torch.float32: + dtype_size = 4 + else: + try: + dtype_size = dtype.itemsize + except: #Old pytorch doesn't have .itemsize + pass return dtype_size def unet_offload_device(): @@ -456,13 +496,44 @@ def unet_inital_load_device(parameters, dtype): else: return cpu_dev -def unet_dtype(device=None, model_params=0): +def unet_dtype(device=None, model_params=0, supported_dtypes=[torch.float16, torch.bfloat16, torch.float32]): if args.bf16_unet: return torch.bfloat16 - if should_use_fp16(device=device, model_params=model_params): + if args.fp16_unet: return torch.float16 + if args.fp8_e4m3fn_unet: + return torch.float8_e4m3fn + if args.fp8_e5m2_unet: + return torch.float8_e5m2 + if should_use_fp16(device=device, model_params=model_params, manual_cast=True): + if torch.float16 in supported_dtypes: + return torch.float16 + if should_use_bf16(device, model_params=model_params, manual_cast=True): + if torch.bfloat16 in supported_dtypes: + return torch.bfloat16 return torch.float32 +# None means no manual cast +def unet_manual_cast(weight_dtype, inference_device, supported_dtypes=[torch.float16, torch.bfloat16, torch.float32]): + if weight_dtype == torch.float32: + return None + + fp16_supported = should_use_fp16(inference_device, prioritize_performance=False) + if fp16_supported and weight_dtype == torch.float16: + return None + + bf16_supported = should_use_bf16(inference_device) + if bf16_supported and weight_dtype == torch.bfloat16: + return None + + if fp16_supported and torch.float16 in supported_dtypes: + return torch.float16 + + elif bf16_supported and torch.bfloat16 in supported_dtypes: + return torch.bfloat16 + else: + return torch.float32 + def text_encoder_offload_device(): if args.gpu_only: return get_torch_device() @@ -492,12 +563,21 @@ def text_encoder_dtype(device=None): elif args.fp32_text_enc: return torch.float32 - if should_use_fp16(device, prioritize_performance=False): + if is_device_cpu(device): return torch.float16 + + return torch.float16 + + +def intermediate_device(): + if args.gpu_only: + return get_torch_device() else: - return torch.float32 + return torch.device("cpu") def vae_device(): + if args.cpu_vae: + return torch.device("cpu") return get_torch_device() def vae_offload_device(): @@ -515,6 +595,22 @@ def get_autocast_device(dev): return dev.type return "cuda" +def supports_dtype(device, dtype): #TODO + if dtype == torch.float32: + return True + if is_device_cpu(device): + return False + if dtype == torch.float16: + return True + if dtype == torch.bfloat16: + return True + return False + +def device_supports_non_blocking(device): + if is_device_mps(device): + return False #pytorch bug? mps doesn't support non blocking + return True + def cast_to_device(tensor, device, dtype, copy=False): device_supports_cast = False if tensor.dtype == torch.float32 or tensor.dtype == torch.float16: @@ -525,15 +621,17 @@ def cast_to_device(tensor, device, dtype, copy=False): elif is_intel_xpu(): device_supports_cast = True + non_blocking = device_supports_non_blocking(device) + if device_supports_cast: if copy: if tensor.device == device: - return tensor.to(dtype, copy=copy) - return tensor.to(device, copy=copy).to(dtype) + return tensor.to(dtype, copy=copy, non_blocking=non_blocking) + return tensor.to(device, copy=copy, non_blocking=non_blocking).to(dtype, non_blocking=non_blocking) else: - return tensor.to(device).to(dtype) + return tensor.to(device, non_blocking=non_blocking).to(dtype, non_blocking=non_blocking) else: - return tensor.to(dtype).to(device, copy=copy) + return tensor.to(device, dtype, copy=copy, non_blocking=non_blocking) def xformers_enabled(): global directml_enabled @@ -606,19 +704,22 @@ def mps_mode(): global cpu_state return cpu_state == CPUState.MPS -def is_device_cpu(device): +def is_device_type(device, type): if hasattr(device, 'type'): - if (device.type == 'cpu'): + if (device.type == type): return True return False +def is_device_cpu(device): + return is_device_type(device, 'cpu') + def is_device_mps(device): - if hasattr(device, 'type'): - if (device.type == 'mps'): - return True - return False + return is_device_type(device, 'mps') + +def is_device_cuda(device): + return is_device_type(device, 'cuda') -def should_use_fp16(device=None, model_params=0, prioritize_performance=True): +def should_use_fp16(device=None, model_params=0, prioritize_performance=True, manual_cast=False): global directml_enabled if device is not None: @@ -628,9 +729,9 @@ def should_use_fp16(device=None, model_params=0, prioritize_performance=True): if FORCE_FP16: return True - if device is not None: #TODO + if device is not None: if is_device_mps(device): - return False + return True if FORCE_FP32: return False @@ -638,16 +739,22 @@ def should_use_fp16(device=None, model_params=0, prioritize_performance=True): if directml_enabled: return False - if cpu_mode() or mps_mode(): - return False #TODO ? + if mps_mode(): + return True + + if cpu_mode(): + return False if is_intel_xpu(): return True - if torch.cuda.is_bf16_supported(): + if torch.version.hip: return True props = torch.cuda.get_device_properties("cuda") + if props.major >= 8: + return True + if props.major < 6: return False @@ -655,12 +762,12 @@ def should_use_fp16(device=None, model_params=0, prioritize_performance=True): #FP16 is confirmed working on a 1080 (GP104) but it's a bit slower than FP32 so it should only be enabled #when the model doesn't actually fit on the card #TODO: actually test if GP106 and others have the same type of behavior - nvidia_10_series = ["1080", "1070", "titan x", "p3000", "p3200", "p4000", "p4200", "p5000", "p5200", "p6000", "1060", "1050"] + nvidia_10_series = ["1080", "1070", "titan x", "p3000", "p3200", "p4000", "p4200", "p5000", "p5200", "p6000", "1060", "1050", "p40", "p100", "p6", "p4"] for x in nvidia_10_series: if x in props.name.lower(): fp16_works = True - if fp16_works: + if fp16_works or manual_cast: free_model_memory = (get_free_memory() * 0.9 - minimum_inference_memory()) if (not prioritize_performance) or model_params * 4 > free_model_memory: return True @@ -676,6 +783,43 @@ def should_use_fp16(device=None, model_params=0, prioritize_performance=True): return True +def should_use_bf16(device=None, model_params=0, prioritize_performance=True, manual_cast=False): + if device is not None: + if is_device_cpu(device): #TODO ? bf16 works on CPU but is extremely slow + return False + + if device is not None: #TODO not sure about mps bf16 support + if is_device_mps(device): + return False + + if FORCE_FP32: + return False + + if directml_enabled: + return False + + if cpu_mode() or mps_mode(): + return False + + if is_intel_xpu(): + return True + + if device is None: + device = torch.device("cuda") + + props = torch.cuda.get_device_properties(device) + if props.major >= 8: + return True + + bf16_works = torch.cuda.is_bf16_supported() + + if bf16_works or manual_cast: + free_model_memory = (get_free_memory() * 0.9 - minimum_inference_memory()) + if (not prioritize_performance) or model_params * 4 > free_model_memory: + return True + + return False + def soft_empty_cache(force=False): global cpu_state if cpu_state == CPUState.MPS: @@ -687,11 +831,11 @@ def soft_empty_cache(force=False): torch.cuda.empty_cache() torch.cuda.ipc_collect() -def resolve_lowvram_weight(weight, model, key): - if weight.device == torch.device("meta"): #lowvram NOTE: this depends on the inner working of the accelerate library so it might break. - key_split = key.split('.') # I have no idea why they don't just leave the weight there instead of using the meta device. - op = comfy.utils.get_attr(model, '.'.join(key_split[:-1])) - weight = op._hf_hook.weights_map[key_split[-1]] +def unload_all_models(): + free_memory(1e30, get_torch_device()) + + +def resolve_lowvram_weight(weight, model, key): #TODO: remove return weight #TODO: might be cleaner to put this somewhere else diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index a3cffc3b..8dda84cf 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -1,10 +1,24 @@ import torch import copy import inspect +import logging +import uuid import comfy.utils import comfy.model_management +def apply_weight_decompose(dora_scale, weight): + weight_norm = ( + weight.transpose(0, 1) + .reshape(weight.shape[1], -1) + .norm(dim=1, keepdim=True) + .reshape(weight.shape[1], *[1] * (weight.dim() - 1)) + .transpose(0, 1) + ) + + return weight * (dora_scale / weight_norm) + + class ModelPatcher: def __init__(self, model, load_device, offload_device, size=0, current_device=None, weight_inplace_update=False): self.size = size @@ -23,28 +37,29 @@ class ModelPatcher: self.current_device = current_device self.weight_inplace_update = weight_inplace_update + self.model_lowvram = False + self.patches_uuid = uuid.uuid4() def model_size(self): if self.size > 0: return self.size model_sd = self.model.state_dict() - size = 0 - for k in model_sd: - t = model_sd[k] - size += t.nelement() * t.element_size() - self.size = size + self.size = comfy.model_management.module_size(self.model) self.model_keys = set(model_sd.keys()) - return size + return self.size def clone(self): n = ModelPatcher(self.model, self.load_device, self.offload_device, self.size, self.current_device, weight_inplace_update=self.weight_inplace_update) n.patches = {} for k in self.patches: n.patches[k] = self.patches[k][:] + n.patches_uuid = self.patches_uuid n.object_patches = self.object_patches.copy() n.model_options = copy.deepcopy(self.model_options) n.model_keys = self.model_keys + n.backup = self.backup + n.object_patches_backup = self.object_patches_backup return n def is_clone(self, other): @@ -52,31 +67,58 @@ class ModelPatcher: return True return False + def clone_has_same_weights(self, clone): + if not self.is_clone(clone): + return False + + if len(self.patches) == 0 and len(clone.patches) == 0: + return True + + if self.patches_uuid == clone.patches_uuid: + if len(self.patches) != len(clone.patches): + logging.warning("WARNING: something went wrong, same patch uuid but different length of patches.") + else: + return True + def memory_required(self, input_shape): return self.model.memory_required(input_shape=input_shape) - def set_model_sampler_cfg_function(self, sampler_cfg_function): + def set_model_sampler_cfg_function(self, sampler_cfg_function, disable_cfg1_optimization=False): if len(inspect.signature(sampler_cfg_function).parameters) == 3: self.model_options["sampler_cfg_function"] = lambda args: sampler_cfg_function(args["cond"], args["uncond"], args["cond_scale"]) #Old way else: self.model_options["sampler_cfg_function"] = sampler_cfg_function + if disable_cfg1_optimization: + self.model_options["disable_cfg1_optimization"] = True + + def set_model_sampler_post_cfg_function(self, post_cfg_function, disable_cfg1_optimization=False): + self.model_options["sampler_post_cfg_function"] = self.model_options.get("sampler_post_cfg_function", []) + [post_cfg_function] + if disable_cfg1_optimization: + self.model_options["disable_cfg1_optimization"] = True def set_model_unet_function_wrapper(self, unet_wrapper_function): self.model_options["model_function_wrapper"] = unet_wrapper_function + def set_model_denoise_mask_function(self, denoise_mask_function): + self.model_options["denoise_mask_function"] = denoise_mask_function + def set_model_patch(self, patch, name): to = self.model_options["transformer_options"] if "patches" not in to: to["patches"] = {} to["patches"][name] = to["patches"].get(name, []) + [patch] - def set_model_patch_replace(self, patch, name, block_name, number): + def set_model_patch_replace(self, patch, name, block_name, number, transformer_index=None): to = self.model_options["transformer_options"] if "patches_replace" not in to: to["patches_replace"] = {} if name not in to["patches_replace"]: to["patches_replace"][name] = {} - to["patches_replace"][name][(block_name, number)] = patch + if transformer_index is not None: + block = (block_name, number, transformer_index) + else: + block = (block_name, number) + to["patches_replace"][name][block] = patch def set_model_attn1_patch(self, patch): self.set_model_patch(patch, "attn1_patch") @@ -84,11 +126,11 @@ class ModelPatcher: def set_model_attn2_patch(self, patch): self.set_model_patch(patch, "attn2_patch") - def set_model_attn1_replace(self, patch, block_name, number): - self.set_model_patch_replace(patch, "attn1", block_name, number) + def set_model_attn1_replace(self, patch, block_name, number, transformer_index=None): + self.set_model_patch_replace(patch, "attn1", block_name, number, transformer_index) - def set_model_attn2_replace(self, patch, block_name, number): - self.set_model_patch_replace(patch, "attn2", block_name, number) + def set_model_attn2_replace(self, patch, block_name, number, transformer_index=None): + self.set_model_patch_replace(patch, "attn2", block_name, number, transformer_index) def set_model_attn1_output_patch(self, patch): self.set_model_patch(patch, "attn1_output_patch") @@ -142,6 +184,7 @@ class ModelPatcher: current_patches.append((strength_patch, patches[k], strength_model)) self.patches[k] = current_patches + self.patches_uuid = uuid.uuid4() return list(p) def get_key_patches(self, filter_prefix=None): @@ -167,41 +210,87 @@ class ModelPatcher: sd.pop(k) return sd - def patch_model(self, device_to=None): + def patch_weight_to_device(self, key, device_to=None): + if key not in self.patches: + return + + weight = comfy.utils.get_attr(self.model, key) + + inplace_update = self.weight_inplace_update + + if key not in self.backup: + self.backup[key] = weight.to(device=self.offload_device, copy=inplace_update) + + if device_to is not None: + temp_weight = comfy.model_management.cast_to_device(weight, device_to, torch.float32, copy=True) + else: + temp_weight = weight.to(torch.float32, copy=True) + out_weight = self.calculate_weight(self.patches[key], temp_weight, key).to(weight.dtype) + if inplace_update: + comfy.utils.copy_to_param(self.model, key, out_weight) + else: + comfy.utils.set_attr_param(self.model, key, out_weight) + + def patch_model(self, device_to=None, patch_weights=True): for k in self.object_patches: - old = getattr(self.model, k) + old = comfy.utils.set_attr(self.model, k, self.object_patches[k]) if k not in self.object_patches_backup: self.object_patches_backup[k] = old - setattr(self.model, k, self.object_patches[k]) - - model_sd = self.model_state_dict() - for key in self.patches: - if key not in model_sd: - print("could not patch. key doesn't exist in model:", key) - continue - weight = model_sd[key] - - inplace_update = self.weight_inplace_update + if patch_weights: + model_sd = self.model_state_dict() + for key in self.patches: + if key not in model_sd: + logging.warning("could not patch. key doesn't exist in model: {}".format(key)) + continue - if key not in self.backup: - self.backup[key] = weight.to(device=self.offload_device, copy=inplace_update) + self.patch_weight_to_device(key, device_to) if device_to is not None: - temp_weight = comfy.model_management.cast_to_device(weight, device_to, torch.float32, copy=True) - else: - temp_weight = weight.to(torch.float32, copy=True) - out_weight = self.calculate_weight(self.patches[key], temp_weight, key).to(weight.dtype) - if inplace_update: - comfy.utils.copy_to_param(self.model, key, out_weight) - else: - comfy.utils.set_attr(self.model, key, out_weight) - del temp_weight + self.model.to(device_to) + self.current_device = device_to - if device_to is not None: - self.model.to(device_to) - self.current_device = device_to + return self.model + def patch_model_lowvram(self, device_to=None, lowvram_model_memory=0): + self.patch_model(device_to, patch_weights=False) + + logging.info("loading in lowvram mode {}".format(lowvram_model_memory/(1024 * 1024))) + class LowVramPatch: + def __init__(self, key, model_patcher): + self.key = key + self.model_patcher = model_patcher + def __call__(self, weight): + return self.model_patcher.calculate_weight(self.model_patcher.patches[self.key], weight, self.key) + + mem_counter = 0 + for n, m in self.model.named_modules(): + lowvram_weight = False + if hasattr(m, "comfy_cast_weights"): + module_mem = comfy.model_management.module_size(m) + if mem_counter + module_mem >= lowvram_model_memory: + lowvram_weight = True + + weight_key = "{}.weight".format(n) + bias_key = "{}.bias".format(n) + + if lowvram_weight: + if weight_key in self.patches: + m.weight_function = LowVramPatch(weight_key, self) + if bias_key in self.patches: + m.bias_function = LowVramPatch(weight_key, self) + + m.prev_comfy_cast_weights = m.comfy_cast_weights + m.comfy_cast_weights = True + else: + if hasattr(m, "weight"): + self.patch_weight_to_device(weight_key, device_to) + self.patch_weight_to_device(bias_key, device_to) + m.to(device_to) + mem_counter += comfy.model_management.module_size(m) + logging.debug("lowvram: loaded module regularly {}".format(m)) + + self.model_lowvram = True return self.model def calculate_weight(self, patches, weight, key): @@ -217,15 +306,22 @@ class ModelPatcher: v = (self.calculate_weight(v[1:], v[0].clone(), key), ) if len(v) == 1: + patch_type = "diff" + elif len(v) == 2: + patch_type = v[0] + v = v[1] + + if patch_type == "diff": w1 = v[0] if alpha != 0.0: if w1.shape != weight.shape: - print("WARNING SHAPE MISMATCH {} WEIGHT NOT MERGED {} != {}".format(key, w1.shape, weight.shape)) + logging.warning("WARNING SHAPE MISMATCH {} WEIGHT NOT MERGED {} != {}".format(key, w1.shape, weight.shape)) else: weight += alpha * comfy.model_management.cast_to_device(w1, weight.device, weight.dtype) - elif len(v) == 4: #lora/locon + elif patch_type == "lora": #lora/locon mat1 = comfy.model_management.cast_to_device(v[0], weight.device, torch.float32) mat2 = comfy.model_management.cast_to_device(v[1], weight.device, torch.float32) + dora_scale = v[4] if v[2] is not None: alpha *= v[2] / mat2.shape[0] if v[3] is not None: @@ -235,9 +331,11 @@ class ModelPatcher: mat2 = torch.mm(mat2.transpose(0, 1).flatten(start_dim=1), mat3.transpose(0, 1).flatten(start_dim=1)).reshape(final_shape).transpose(0, 1) try: weight += (alpha * torch.mm(mat1.flatten(start_dim=1), mat2.flatten(start_dim=1))).reshape(weight.shape).type(weight.dtype) + if dora_scale is not None: + weight = apply_weight_decompose(comfy.model_management.cast_to_device(dora_scale, weight.device, torch.float32), weight) except Exception as e: - print("ERROR", key, e) - elif len(v) == 8: #lokr + logging.error("ERROR {} {} {}".format(patch_type, key, e)) + elif patch_type == "lokr": w1 = v[0] w2 = v[1] w1_a = v[3] @@ -245,6 +343,7 @@ class ModelPatcher: w2_a = v[5] w2_b = v[6] t2 = v[7] + dora_scale = v[8] dim = None if w1 is None: @@ -274,15 +373,18 @@ class ModelPatcher: try: weight += alpha * torch.kron(w1, w2).reshape(weight.shape).type(weight.dtype) + if dora_scale is not None: + weight = apply_weight_decompose(comfy.model_management.cast_to_device(dora_scale, weight.device, torch.float32), weight) except Exception as e: - print("ERROR", key, e) - else: #loha + logging.error("ERROR {} {} {}".format(patch_type, key, e)) + elif patch_type == "loha": w1a = v[0] w1b = v[1] if v[2] is not None: alpha *= v[2] / w1b.shape[0] w2a = v[3] w2b = v[4] + dora_scale = v[7] if v[5] is not None: #cp decomposition t1 = v[5] t2 = v[6] @@ -303,29 +405,61 @@ class ModelPatcher: try: weight += (alpha * m1 * m2).reshape(weight.shape).type(weight.dtype) + if dora_scale is not None: + weight = apply_weight_decompose(comfy.model_management.cast_to_device(dora_scale, weight.device, torch.float32), weight) except Exception as e: - print("ERROR", key, e) + logging.error("ERROR {} {} {}".format(patch_type, key, e)) + elif patch_type == "glora": + if v[4] is not None: + alpha *= v[4] / v[0].shape[0] + + dora_scale = v[5] + + a1 = comfy.model_management.cast_to_device(v[0].flatten(start_dim=1), weight.device, torch.float32) + a2 = comfy.model_management.cast_to_device(v[1].flatten(start_dim=1), weight.device, torch.float32) + b1 = comfy.model_management.cast_to_device(v[2].flatten(start_dim=1), weight.device, torch.float32) + b2 = comfy.model_management.cast_to_device(v[3].flatten(start_dim=1), weight.device, torch.float32) + + try: + weight += ((torch.mm(b2, b1) + torch.mm(torch.mm(weight.flatten(start_dim=1), a2), a1)) * alpha).reshape(weight.shape).type(weight.dtype) + if dora_scale is not None: + weight = apply_weight_decompose(comfy.model_management.cast_to_device(dora_scale, weight.device, torch.float32), weight) + except Exception as e: + logging.error("ERROR {} {} {}".format(patch_type, key, e)) + else: + logging.warning("patch type not recognized {} {}".format(patch_type, key)) return weight - def unpatch_model(self, device_to=None): - keys = list(self.backup.keys()) + def unpatch_model(self, device_to=None, unpatch_weights=True): + if unpatch_weights: + if self.model_lowvram: + for m in self.model.modules(): + if hasattr(m, "prev_comfy_cast_weights"): + m.comfy_cast_weights = m.prev_comfy_cast_weights + del m.prev_comfy_cast_weights + m.weight_function = None + m.bias_function = None - if self.weight_inplace_update: - for k in keys: - comfy.utils.copy_to_param(self.model, k, self.backup[k]) - else: - for k in keys: - comfy.utils.set_attr(self.model, k, self.backup[k]) + self.model_lowvram = False - self.backup = {} + keys = list(self.backup.keys()) - if device_to is not None: - self.model.to(device_to) - self.current_device = device_to + if self.weight_inplace_update: + for k in keys: + comfy.utils.copy_to_param(self.model, k, self.backup[k]) + else: + for k in keys: + comfy.utils.set_attr_param(self.model, k, self.backup[k]) + + self.backup.clear() + + if device_to is not None: + self.model.to(device_to) + self.current_device = device_to keys = list(self.object_patches_backup.keys()) for k in keys: - setattr(self.model, k, self.object_patches_backup[k]) + comfy.utils.set_attr(self.model, k, self.object_patches_backup[k]) self.object_patches_backup = {} diff --git a/comfy/model_sampling.py b/comfy/model_sampling.py index 69c8b1f0..37976b32 100644 --- a/comfy/model_sampling.py +++ b/comfy/model_sampling.py @@ -1,5 +1,4 @@ import torch -import numpy as np from comfy.ldm.modules.diffusionmodules.util import make_beta_schedule import math @@ -12,20 +11,43 @@ class EPS: sigma = sigma.view(sigma.shape[:1] + (1,) * (model_output.ndim - 1)) return model_input - model_output * sigma + def noise_scaling(self, sigma, noise, latent_image, max_denoise=False): + if max_denoise: + noise = noise * torch.sqrt(1.0 + sigma ** 2.0) + else: + noise = noise * sigma + + noise += latent_image + return noise + + def inverse_noise_scaling(self, sigma, latent): + return latent class V_PREDICTION(EPS): def calculate_denoised(self, sigma, model_output, model_input): sigma = sigma.view(sigma.shape[:1] + (1,) * (model_output.ndim - 1)) return model_input * self.sigma_data ** 2 / (sigma ** 2 + self.sigma_data ** 2) - model_output * sigma * self.sigma_data / (sigma ** 2 + self.sigma_data ** 2) ** 0.5 +class EDM(V_PREDICTION): + def calculate_denoised(self, sigma, model_output, model_input): + sigma = sigma.view(sigma.shape[:1] + (1,) * (model_output.ndim - 1)) + return model_input * self.sigma_data ** 2 / (sigma ** 2 + self.sigma_data ** 2) + model_output * sigma * self.sigma_data / (sigma ** 2 + self.sigma_data ** 2) ** 0.5 + class ModelSamplingDiscrete(torch.nn.Module): def __init__(self, model_config=None): super().__init__() - beta_schedule = "linear" + if model_config is not None: - beta_schedule = model_config.sampling_settings.get("beta_schedule", beta_schedule) - self._register_schedule(given_betas=None, beta_schedule=beta_schedule, timesteps=1000, linear_start=0.00085, linear_end=0.012, cosine_s=8e-3) + sampling_settings = model_config.sampling_settings + else: + sampling_settings = {} + + beta_schedule = sampling_settings.get("beta_schedule", "linear") + linear_start = sampling_settings.get("linear_start", 0.00085) + linear_end = sampling_settings.get("linear_end", 0.012) + + self._register_schedule(given_betas=None, beta_schedule=beta_schedule, timesteps=1000, linear_start=linear_start, linear_end=linear_end, cosine_s=8e-3) self.sigma_data = 1.0 def _register_schedule(self, given_betas=None, beta_schedule="linear", timesteps=1000, @@ -35,8 +57,7 @@ class ModelSamplingDiscrete(torch.nn.Module): else: betas = make_beta_schedule(beta_schedule, timesteps, linear_start=linear_start, linear_end=linear_end, cosine_s=cosine_s) alphas = 1. - betas - alphas_cumprod = torch.tensor(np.cumprod(alphas, axis=0), dtype=torch.float32) - # alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1]) + alphas_cumprod = torch.cumprod(alphas, dim=0) timesteps, = betas.shape self.num_timesteps = int(timesteps) @@ -51,8 +72,8 @@ class ModelSamplingDiscrete(torch.nn.Module): self.set_sigmas(sigmas) def set_sigmas(self, sigmas): - self.register_buffer('sigmas', sigmas) - self.register_buffer('log_sigmas', sigmas.log()) + self.register_buffer('sigmas', sigmas.float()) + self.register_buffer('log_sigmas', sigmas.log().float()) @property def sigma_min(self): @@ -87,8 +108,6 @@ class ModelSamplingDiscrete(torch.nn.Module): class ModelSamplingContinuousEDM(torch.nn.Module): def __init__(self, model_config=None): super().__init__() - self.sigma_data = 1.0 - if model_config is not None: sampling_settings = model_config.sampling_settings else: @@ -96,9 +115,11 @@ class ModelSamplingContinuousEDM(torch.nn.Module): sigma_min = sampling_settings.get("sigma_min", 0.002) sigma_max = sampling_settings.get("sigma_max", 120.0) - self.set_sigma_range(sigma_min, sigma_max) + sigma_data = sampling_settings.get("sigma_data", 1.0) + self.set_parameters(sigma_min, sigma_max, sigma_data) - def set_sigma_range(self, sigma_min, sigma_max): + def set_parameters(self, sigma_min, sigma_max, sigma_data): + self.sigma_data = sigma_data sigmas = torch.linspace(math.log(sigma_min), math.log(sigma_max), 1000).exp() self.register_buffer('sigmas', sigmas) #for compatibility with some schedulers @@ -127,3 +148,56 @@ class ModelSamplingContinuousEDM(torch.nn.Module): log_sigma_min = math.log(self.sigma_min) return math.exp((math.log(self.sigma_max) - log_sigma_min) * percent + log_sigma_min) + +class StableCascadeSampling(ModelSamplingDiscrete): + def __init__(self, model_config=None): + super().__init__() + + if model_config is not None: + sampling_settings = model_config.sampling_settings + else: + sampling_settings = {} + + self.set_parameters(sampling_settings.get("shift", 1.0)) + + def set_parameters(self, shift=1.0, cosine_s=8e-3): + self.shift = shift + self.cosine_s = torch.tensor(cosine_s) + self._init_alpha_cumprod = torch.cos(self.cosine_s / (1 + self.cosine_s) * torch.pi * 0.5) ** 2 + + #This part is just for compatibility with some schedulers in the codebase + self.num_timesteps = 10000 + sigmas = torch.empty((self.num_timesteps), dtype=torch.float32) + for x in range(self.num_timesteps): + t = (x + 1) / self.num_timesteps + sigmas[x] = self.sigma(t) + + self.set_sigmas(sigmas) + + def sigma(self, timestep): + alpha_cumprod = (torch.cos((timestep + self.cosine_s) / (1 + self.cosine_s) * torch.pi * 0.5) ** 2 / self._init_alpha_cumprod) + + if self.shift != 1.0: + var = alpha_cumprod + logSNR = (var/(1-var)).log() + logSNR += 2 * torch.log(1.0 / torch.tensor(self.shift)) + alpha_cumprod = logSNR.sigmoid() + + alpha_cumprod = alpha_cumprod.clamp(0.0001, 0.9999) + return ((1 - alpha_cumprod) / alpha_cumprod) ** 0.5 + + def timestep(self, sigma): + var = 1 / ((sigma * sigma) + 1) + var = var.clamp(0, 1.0) + s, min_var = self.cosine_s.to(var.device), self._init_alpha_cumprod.to(var.device) + t = (((var * min_var) ** 0.5).acos() / (torch.pi * 0.5)) * (1 + s) - s + return t + + def percent_to_sigma(self, percent): + if percent <= 0.0: + return 999999999.9 + if percent >= 1.0: + return 0.0 + + percent = 1.0 - percent + return self.sigma(torch.tensor(percent)) diff --git a/comfy/ops.py b/comfy/ops.py index 0bfb698a..eb650768 100644 --- a/comfy/ops.py +++ b/comfy/ops.py @@ -1,40 +1,163 @@ +""" + This file is part of ComfyUI. + Copyright (C) 2024 Stability AI + + This program is free software: you can redistribute it and/or modify + it under the terms of the GNU General Public License as published by + the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + This program is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU General Public License for more details. + + You should have received a copy of the GNU General Public License + along with this program. If not, see . +""" + import torch -from contextlib import contextmanager - -class Linear(torch.nn.Linear): - def reset_parameters(self): - return None - -class Conv2d(torch.nn.Conv2d): - def reset_parameters(self): - return None - -class Conv3d(torch.nn.Conv3d): - def reset_parameters(self): - return None - -def conv_nd(dims, *args, **kwargs): - if dims == 2: - return Conv2d(*args, **kwargs) - elif dims == 3: - return Conv3d(*args, **kwargs) - else: - raise ValueError(f"unsupported dimensions: {dims}") - -@contextmanager -def use_comfy_ops(device=None, dtype=None): # Kind of an ugly hack but I can't think of a better way - old_torch_nn_linear = torch.nn.Linear - force_device = device - force_dtype = dtype - def linear_with_dtype(in_features: int, out_features: int, bias: bool = True, device=None, dtype=None): - if force_device is not None: - device = force_device - if force_dtype is not None: - dtype = force_dtype - return Linear(in_features, out_features, bias=bias, device=device, dtype=dtype) - - torch.nn.Linear = linear_with_dtype - try: - yield - finally: - torch.nn.Linear = old_torch_nn_linear +import comfy.model_management + +def cast_bias_weight(s, input): + bias = None + non_blocking = comfy.model_management.device_supports_non_blocking(input.device) + if s.bias is not None: + bias = s.bias.to(device=input.device, dtype=input.dtype, non_blocking=non_blocking) + if s.bias_function is not None: + bias = s.bias_function(bias) + weight = s.weight.to(device=input.device, dtype=input.dtype, non_blocking=non_blocking) + if s.weight_function is not None: + weight = s.weight_function(weight) + return weight, bias + +class CastWeightBiasOp: + comfy_cast_weights = False + weight_function = None + bias_function = None + +class disable_weight_init: + class Linear(torch.nn.Linear, CastWeightBiasOp): + def reset_parameters(self): + return None + + def forward_comfy_cast_weights(self, input): + weight, bias = cast_bias_weight(self, input) + return torch.nn.functional.linear(input, weight, bias) + + def forward(self, *args, **kwargs): + if self.comfy_cast_weights: + return self.forward_comfy_cast_weights(*args, **kwargs) + else: + return super().forward(*args, **kwargs) + + class Conv2d(torch.nn.Conv2d, CastWeightBiasOp): + def reset_parameters(self): + return None + + def forward_comfy_cast_weights(self, input): + weight, bias = cast_bias_weight(self, input) + return self._conv_forward(input, weight, bias) + + def forward(self, *args, **kwargs): + if self.comfy_cast_weights: + return self.forward_comfy_cast_weights(*args, **kwargs) + else: + return super().forward(*args, **kwargs) + + class Conv3d(torch.nn.Conv3d, CastWeightBiasOp): + def reset_parameters(self): + return None + + def forward_comfy_cast_weights(self, input): + weight, bias = cast_bias_weight(self, input) + return self._conv_forward(input, weight, bias) + + def forward(self, *args, **kwargs): + if self.comfy_cast_weights: + return self.forward_comfy_cast_weights(*args, **kwargs) + else: + return super().forward(*args, **kwargs) + + class GroupNorm(torch.nn.GroupNorm, CastWeightBiasOp): + def reset_parameters(self): + return None + + def forward_comfy_cast_weights(self, input): + weight, bias = cast_bias_weight(self, input) + return torch.nn.functional.group_norm(input, self.num_groups, weight, bias, self.eps) + + def forward(self, *args, **kwargs): + if self.comfy_cast_weights: + return self.forward_comfy_cast_weights(*args, **kwargs) + else: + return super().forward(*args, **kwargs) + + + class LayerNorm(torch.nn.LayerNorm, CastWeightBiasOp): + def reset_parameters(self): + return None + + def forward_comfy_cast_weights(self, input): + if self.weight is not None: + weight, bias = cast_bias_weight(self, input) + else: + weight = None + bias = None + return torch.nn.functional.layer_norm(input, self.normalized_shape, weight, bias, self.eps) + + def forward(self, *args, **kwargs): + if self.comfy_cast_weights: + return self.forward_comfy_cast_weights(*args, **kwargs) + else: + return super().forward(*args, **kwargs) + + class ConvTranspose2d(torch.nn.ConvTranspose2d, CastWeightBiasOp): + def reset_parameters(self): + return None + + def forward_comfy_cast_weights(self, input, output_size=None): + num_spatial_dims = 2 + output_padding = self._output_padding( + input, output_size, self.stride, self.padding, self.kernel_size, + num_spatial_dims, self.dilation) + + weight, bias = cast_bias_weight(self, input) + return torch.nn.functional.conv_transpose2d( + input, weight, bias, self.stride, self.padding, + output_padding, self.groups, self.dilation) + + def forward(self, *args, **kwargs): + if self.comfy_cast_weights: + return self.forward_comfy_cast_weights(*args, **kwargs) + else: + return super().forward(*args, **kwargs) + + @classmethod + def conv_nd(s, dims, *args, **kwargs): + if dims == 2: + return s.Conv2d(*args, **kwargs) + elif dims == 3: + return s.Conv3d(*args, **kwargs) + else: + raise ValueError(f"unsupported dimensions: {dims}") + + +class manual_cast(disable_weight_init): + class Linear(disable_weight_init.Linear): + comfy_cast_weights = True + + class Conv2d(disable_weight_init.Conv2d): + comfy_cast_weights = True + + class Conv3d(disable_weight_init.Conv3d): + comfy_cast_weights = True + + class GroupNorm(disable_weight_init.GroupNorm): + comfy_cast_weights = True + + class LayerNorm(disable_weight_init.LayerNorm): + comfy_cast_weights = True + + class ConvTranspose2d(disable_weight_init.ConvTranspose2d): + comfy_cast_weights = True diff --git a/comfy/sample.py b/comfy/sample.py index 034db97e..5c8a7d13 100644 --- a/comfy/sample.py +++ b/comfy/sample.py @@ -28,7 +28,6 @@ def prepare_noise(latent_image, seed, noise_inds=None): def prepare_mask(noise_mask, shape, device): """ensures noise mask is of proper dimensions""" noise_mask = torch.nn.functional.interpolate(noise_mask.reshape((-1, 1, noise_mask.shape[-2], noise_mask.shape[-1])), size=(shape[2], shape[3]), mode="bilinear") - noise_mask = noise_mask.round() noise_mask = torch.cat([noise_mask] * shape[1], dim=1) noise_mask = comfy.utils.repeat_to_batch_size(noise_mask, shape[0]) noise_mask = noise_mask.to(device) @@ -47,7 +46,8 @@ def convert_cond(cond): temp = c[1].copy() model_conds = temp.get("model_conds", {}) if c[0] is not None: - model_conds["c_crossattn"] = comfy.conds.CONDCrossAttn(c[0]) + model_conds["c_crossattn"] = comfy.conds.CONDCrossAttn(c[0]) #TODO: remove + temp["cross_attn"] = c[0] temp["model_conds"] = model_conds out.append(temp) return out @@ -98,10 +98,10 @@ def sample(model, noise, steps, cfg, sampler_name, scheduler, positive, negative sampler = comfy.samplers.KSampler(real_model, steps=steps, device=model.load_device, sampler=sampler_name, scheduler=scheduler, denoise=denoise, model_options=model.model_options) samples = sampler.sample(noise, positive_copy, negative_copy, cfg=cfg, latent_image=latent_image, start_step=start_step, last_step=last_step, force_full_denoise=force_full_denoise, denoise_mask=noise_mask, sigmas=sigmas, callback=callback, disable_pbar=disable_pbar, seed=seed) - samples = samples.cpu() + samples = samples.to(comfy.model_management.intermediate_device()) cleanup_additional_models(models) - cleanup_additional_models(set(get_models_from_cond(positive, "control") + get_models_from_cond(negative, "control"))) + cleanup_additional_models(set(get_models_from_cond(positive_copy, "control") + get_models_from_cond(negative_copy, "control"))) return samples def sample_custom(model, noise, cfg, sampler, sigmas, positive, negative, latent_image, noise_mask=None, callback=None, disable_pbar=False, seed=None): @@ -111,8 +111,8 @@ def sample_custom(model, noise, cfg, sampler, sigmas, positive, negative, latent sigmas = sigmas.to(model.load_device) samples = comfy.samplers.sample(real_model, noise, positive_copy, negative_copy, cfg, model.load_device, sampler, sigmas, model_options=model.model_options, latent_image=latent_image, denoise_mask=noise_mask, callback=callback, disable_pbar=disable_pbar, seed=seed) - samples = samples.cpu() + samples = samples.to(comfy.model_management.intermediate_device()) cleanup_additional_models(models) - cleanup_additional_models(set(get_models_from_cond(positive, "control") + get_models_from_cond(negative, "control"))) + cleanup_additional_models(set(get_models_from_cond(positive_copy, "control") + get_models_from_cond(negative_copy, "control"))) return samples diff --git a/comfy/samplers.py b/comfy/samplers.py index 1d012a51..3678dc81 100644 --- a/comfy/samplers.py +++ b/comfy/samplers.py @@ -1,260 +1,266 @@ from .k_diffusion import sampling as k_diffusion_sampling from .extra_samplers import uni_pc import torch -import enum +import collections from comfy import model_management import math -from comfy import model_base -import comfy.utils -import comfy.conds +import logging + +def get_area_and_mult(conds, x_in, timestep_in): + area = (x_in.shape[2], x_in.shape[3], 0, 0) + strength = 1.0 + + if 'timestep_start' in conds: + timestep_start = conds['timestep_start'] + if timestep_in[0] > timestep_start: + return None + if 'timestep_end' in conds: + timestep_end = conds['timestep_end'] + if timestep_in[0] < timestep_end: + return None + if 'area' in conds: + area = conds['area'] + if 'strength' in conds: + strength = conds['strength'] + + input_x = x_in[:,:,area[2]:area[0] + area[2],area[3]:area[1] + area[3]] + if 'mask' in conds: + # Scale the mask to the size of the input + # The mask should have been resized as we began the sampling process + mask_strength = 1.0 + if "mask_strength" in conds: + mask_strength = conds["mask_strength"] + mask = conds['mask'] + assert(mask.shape[1] == x_in.shape[2]) + assert(mask.shape[2] == x_in.shape[3]) + mask = mask[:,area[2]:area[0] + area[2],area[3]:area[1] + area[3]] * mask_strength + mask = mask.unsqueeze(1).repeat(input_x.shape[0] // mask.shape[0], input_x.shape[1], 1, 1) + else: + mask = torch.ones_like(input_x) + mult = mask * strength + + if 'mask' not in conds: + rr = 8 + if area[2] != 0: + for t in range(rr): + mult[:,:,t:1+t,:] *= ((1.0/rr) * (t + 1)) + if (area[0] + area[2]) < x_in.shape[2]: + for t in range(rr): + mult[:,:,area[0] - 1 - t:area[0] - t,:] *= ((1.0/rr) * (t + 1)) + if area[3] != 0: + for t in range(rr): + mult[:,:,:,t:1+t] *= ((1.0/rr) * (t + 1)) + if (area[1] + area[3]) < x_in.shape[3]: + for t in range(rr): + mult[:,:,:,area[1] - 1 - t:area[1] - t] *= ((1.0/rr) * (t + 1)) + + conditioning = {} + model_conds = conds["model_conds"] + for c in model_conds: + conditioning[c] = model_conds[c].process_cond(batch_size=x_in.shape[0], device=x_in.device, area=area) + + control = conds.get('control', None) + + patches = None + if 'gligen' in conds: + gligen = conds['gligen'] + patches = {} + gligen_type = gligen[0] + gligen_model = gligen[1] + if gligen_type == "position": + gligen_patch = gligen_model.model.set_position(input_x.shape, gligen[2], input_x.device) + else: + gligen_patch = gligen_model.model.set_empty(input_x.shape, input_x.device) + + patches['middle_patch'] = [gligen_patch] + + cond_obj = collections.namedtuple('cond_obj', ['input_x', 'mult', 'conditioning', 'area', 'control', 'patches']) + return cond_obj(input_x, mult, conditioning, area, control, patches) + +def cond_equal_size(c1, c2): + if c1 is c2: + return True + if c1.keys() != c2.keys(): + return False + for k in c1: + if not c1[k].can_concat(c2[k]): + return False + return True + +def can_concat_cond(c1, c2): + if c1.input_x.shape != c2.input_x.shape: + return False + + def objects_concatable(obj1, obj2): + if (obj1 is None) != (obj2 is None): + return False + if obj1 is not None: + if obj1 is not obj2: + return False + return True + if not objects_concatable(c1.control, c2.control): + return False -#The main sampling function shared by all the samplers -#Returns denoised -def sampling_function(model, x, timestep, uncond, cond, cond_scale, model_options={}, seed=None): - def get_area_and_mult(conds, x_in, timestep_in): - area = (x_in.shape[2], x_in.shape[3], 0, 0) - strength = 1.0 - - if 'timestep_start' in conds: - timestep_start = conds['timestep_start'] - if timestep_in[0] > timestep_start: - return None - if 'timestep_end' in conds: - timestep_end = conds['timestep_end'] - if timestep_in[0] < timestep_end: - return None - if 'area' in conds: - area = conds['area'] - if 'strength' in conds: - strength = conds['strength'] - - input_x = x_in[:,:,area[2]:area[0] + area[2],area[3]:area[1] + area[3]] - if 'mask' in conds: - # Scale the mask to the size of the input - # The mask should have been resized as we began the sampling process - mask_strength = 1.0 - if "mask_strength" in conds: - mask_strength = conds["mask_strength"] - mask = conds['mask'] - assert(mask.shape[1] == x_in.shape[2]) - assert(mask.shape[2] == x_in.shape[3]) - mask = mask[:,area[2]:area[0] + area[2],area[3]:area[1] + area[3]] * mask_strength - mask = mask.unsqueeze(1).repeat(input_x.shape[0] // mask.shape[0], input_x.shape[1], 1, 1) - else: - mask = torch.ones_like(input_x) - mult = mask * strength - - if 'mask' not in conds: - rr = 8 - if area[2] != 0: - for t in range(rr): - mult[:,:,t:1+t,:] *= ((1.0/rr) * (t + 1)) - if (area[0] + area[2]) < x_in.shape[2]: - for t in range(rr): - mult[:,:,area[0] - 1 - t:area[0] - t,:] *= ((1.0/rr) * (t + 1)) - if area[3] != 0: - for t in range(rr): - mult[:,:,:,t:1+t] *= ((1.0/rr) * (t + 1)) - if (area[1] + area[3]) < x_in.shape[3]: - for t in range(rr): - mult[:,:,:,area[1] - 1 - t:area[1] - t] *= ((1.0/rr) * (t + 1)) - - conditionning = {} - model_conds = conds["model_conds"] - for c in model_conds: - conditionning[c] = model_conds[c].process_cond(batch_size=x_in.shape[0], device=x_in.device, area=area) - - control = None - if 'control' in conds: - control = conds['control'] - - patches = None - if 'gligen' in conds: - gligen = conds['gligen'] - patches = {} - gligen_type = gligen[0] - gligen_model = gligen[1] - if gligen_type == "position": - gligen_patch = gligen_model.model.set_position(input_x.shape, gligen[2], input_x.device) - else: - gligen_patch = gligen_model.model.set_empty(input_x.shape, input_x.device) + if not objects_concatable(c1.patches, c2.patches): + return False - patches['middle_patch'] = [gligen_patch] + return cond_equal_size(c1.conditioning, c2.conditioning) - return (input_x, mult, conditionning, area, control, patches) +def cond_cat(c_list): + c_crossattn = [] + c_concat = [] + c_adm = [] + crossattn_max_len = 0 - def cond_equal_size(c1, c2): - if c1 is c2: - return True - if c1.keys() != c2.keys(): - return False - for k in c1: - if not c1[k].can_concat(c2[k]): - return False - return True + temp = {} + for x in c_list: + for k in x: + cur = temp.get(k, []) + cur.append(x[k]) + temp[k] = cur - def can_concat_cond(c1, c2): - if c1[0].shape != c2[0].shape: - return False + out = {} + for k in temp: + conds = temp[k] + out[k] = conds[0].concat(conds[1:]) - #control - if (c1[4] is None) != (c2[4] is None): - return False - if c1[4] is not None: - if c1[4] is not c2[4]: - return False + return out - #patches - if (c1[5] is None) != (c2[5] is None): - return False - if (c1[5] is not None): - if c1[5] is not c2[5]: - return False - - return cond_equal_size(c1[2], c2[2]) - - def cond_cat(c_list): - c_crossattn = [] - c_concat = [] - c_adm = [] - crossattn_max_len = 0 - - temp = {} - for x in c_list: - for k in x: - cur = temp.get(k, []) - cur.append(x[k]) - temp[k] = cur - - out = {} - for k in temp: - conds = temp[k] - out[k] = conds[0].concat(conds[1:]) - - return out - - def calc_cond_uncond_batch(model, cond, uncond, x_in, timestep, model_options): - out_cond = torch.zeros_like(x_in) - out_count = torch.ones_like(x_in) * 1e-37 - - out_uncond = torch.zeros_like(x_in) - out_uncond_count = torch.ones_like(x_in) * 1e-37 - - COND = 0 - UNCOND = 1 - - to_run = [] - for x in cond: - p = get_area_and_mult(x, x_in, timestep) - if p is None: - continue - - to_run += [(p, COND)] - if uncond is not None: - for x in uncond: - p = get_area_and_mult(x, x_in, timestep) - if p is None: - continue - - to_run += [(p, UNCOND)] - - while len(to_run) > 0: - first = to_run[0] - first_shape = first[0][0].shape - to_batch_temp = [] - for x in range(len(to_run)): - if can_concat_cond(to_run[x][0], first[0]): - to_batch_temp += [x] - - to_batch_temp.reverse() - to_batch = to_batch_temp[:1] - - free_memory = model_management.get_free_memory(x_in.device) - for i in range(1, len(to_batch_temp) + 1): - batch_amount = to_batch_temp[:len(to_batch_temp)//i] - input_shape = [len(batch_amount) * first_shape[0]] + list(first_shape)[1:] - if model.memory_required(input_shape) < free_memory: - to_batch = batch_amount - break - - input_x = [] - mult = [] - c = [] - cond_or_uncond = [] - area = [] - control = None - patches = None - for x in to_batch: - o = to_run.pop(x) - p = o[0] - input_x += [p[0]] - mult += [p[1]] - c += [p[2]] - area += [p[3]] - cond_or_uncond += [o[1]] - control = p[4] - patches = p[5] - - batch_chunks = len(cond_or_uncond) - input_x = torch.cat(input_x) - c = cond_cat(c) - timestep_ = torch.cat([timestep] * batch_chunks) - - if control is not None: - c['control'] = control.get_control(input_x, timestep_, c, len(cond_or_uncond)) - - transformer_options = {} - if 'transformer_options' in model_options: - transformer_options = model_options['transformer_options'].copy() - - if patches is not None: - if "patches" in transformer_options: - cur_patches = transformer_options["patches"].copy() - for p in patches: - if p in cur_patches: - cur_patches[p] = cur_patches[p] + patches[p] - else: - cur_patches[p] = patches[p] - else: - transformer_options["patches"] = patches +def calc_cond_uncond_batch(model, cond, uncond, x_in, timestep, model_options): + out_cond = torch.zeros_like(x_in) + out_count = torch.ones_like(x_in) * 1e-37 - transformer_options["cond_or_uncond"] = cond_or_uncond[:] - transformer_options["sigmas"] = timestep + out_uncond = torch.zeros_like(x_in) + out_uncond_count = torch.ones_like(x_in) * 1e-37 - c['transformer_options'] = transformer_options + COND = 0 + UNCOND = 1 - if 'model_function_wrapper' in model_options: - output = model_options['model_function_wrapper'](model.apply_model, {"input": input_x, "timestep": timestep_, "c": c, "cond_or_uncond": cond_or_uncond}).chunk(batch_chunks) - else: - output = model.apply_model(input_x, timestep_, **c).chunk(batch_chunks) - del input_x + to_run = [] + for x in cond: + p = get_area_and_mult(x, x_in, timestep) + if p is None: + continue - for o in range(batch_chunks): - if cond_or_uncond[o] == COND: - out_cond[:,:,area[o][2]:area[o][0] + area[o][2],area[o][3]:area[o][1] + area[o][3]] += output[o] * mult[o] - out_count[:,:,area[o][2]:area[o][0] + area[o][2],area[o][3]:area[o][1] + area[o][3]] += mult[o] + to_run += [(p, COND)] + if uncond is not None: + for x in uncond: + p = get_area_and_mult(x, x_in, timestep) + if p is None: + continue + + to_run += [(p, UNCOND)] + + while len(to_run) > 0: + first = to_run[0] + first_shape = first[0][0].shape + to_batch_temp = [] + for x in range(len(to_run)): + if can_concat_cond(to_run[x][0], first[0]): + to_batch_temp += [x] + + to_batch_temp.reverse() + to_batch = to_batch_temp[:1] + + free_memory = model_management.get_free_memory(x_in.device) + for i in range(1, len(to_batch_temp) + 1): + batch_amount = to_batch_temp[:len(to_batch_temp)//i] + input_shape = [len(batch_amount) * first_shape[0]] + list(first_shape)[1:] + if model.memory_required(input_shape) < free_memory: + to_batch = batch_amount + break + + input_x = [] + mult = [] + c = [] + cond_or_uncond = [] + area = [] + control = None + patches = None + for x in to_batch: + o = to_run.pop(x) + p = o[0] + input_x.append(p.input_x) + mult.append(p.mult) + c.append(p.conditioning) + area.append(p.area) + cond_or_uncond.append(o[1]) + control = p.control + patches = p.patches + + batch_chunks = len(cond_or_uncond) + input_x = torch.cat(input_x) + c = cond_cat(c) + timestep_ = torch.cat([timestep] * batch_chunks) + + if control is not None: + c['control'] = control.get_control(input_x, timestep_, c, len(cond_or_uncond)) + + transformer_options = {} + if 'transformer_options' in model_options: + transformer_options = model_options['transformer_options'].copy() + + if patches is not None: + if "patches" in transformer_options: + cur_patches = transformer_options["patches"].copy() + for p in patches: + if p in cur_patches: + cur_patches[p] = cur_patches[p] + patches[p] else: - out_uncond[:,:,area[o][2]:area[o][0] + area[o][2],area[o][3]:area[o][1] + area[o][3]] += output[o] * mult[o] - out_uncond_count[:,:,area[o][2]:area[o][0] + area[o][2],area[o][3]:area[o][1] + area[o][3]] += mult[o] - del mult + cur_patches[p] = patches[p] + transformer_options["patches"] = cur_patches + else: + transformer_options["patches"] = patches - out_cond /= out_count - del out_count - out_uncond /= out_uncond_count - del out_uncond_count - return out_cond, out_uncond + transformer_options["cond_or_uncond"] = cond_or_uncond[:] + transformer_options["sigmas"] = timestep + c['transformer_options'] = transformer_options - if math.isclose(cond_scale, 1.0): - uncond = None + if 'model_function_wrapper' in model_options: + output = model_options['model_function_wrapper'](model.apply_model, {"input": input_x, "timestep": timestep_, "c": c, "cond_or_uncond": cond_or_uncond}).chunk(batch_chunks) + else: + output = model.apply_model(input_x, timestep_, **c).chunk(batch_chunks) + del input_x + + for o in range(batch_chunks): + if cond_or_uncond[o] == COND: + out_cond[:,:,area[o][2]:area[o][0] + area[o][2],area[o][3]:area[o][1] + area[o][3]] += output[o] * mult[o] + out_count[:,:,area[o][2]:area[o][0] + area[o][2],area[o][3]:area[o][1] + area[o][3]] += mult[o] + else: + out_uncond[:,:,area[o][2]:area[o][0] + area[o][2],area[o][3]:area[o][1] + area[o][3]] += output[o] * mult[o] + out_uncond_count[:,:,area[o][2]:area[o][0] + area[o][2],area[o][3]:area[o][1] + area[o][3]] += mult[o] + del mult + + out_cond /= out_count + del out_count + out_uncond /= out_uncond_count + del out_uncond_count + return out_cond, out_uncond + +#The main sampling function shared by all the samplers +#Returns denoised +def sampling_function(model, x, timestep, uncond, cond, cond_scale, model_options={}, seed=None): + if math.isclose(cond_scale, 1.0) and model_options.get("disable_cfg1_optimization", False) == False: + uncond_ = None + else: + uncond_ = uncond - cond, uncond = calc_cond_uncond_batch(model, cond, uncond, x, timestep, model_options) + cond_pred, uncond_pred = calc_cond_uncond_batch(model, cond, uncond_, x, timestep, model_options) if "sampler_cfg_function" in model_options: - args = {"cond": x - cond, "uncond": x - uncond, "cond_scale": cond_scale, "timestep": timestep, "input": x, "sigma": timestep} - return x - model_options["sampler_cfg_function"](args) + args = {"cond": x - cond_pred, "uncond": x - uncond_pred, "cond_scale": cond_scale, "timestep": timestep, "input": x, "sigma": timestep, + "cond_denoised": cond_pred, "uncond_denoised": uncond_pred, "model": model, "model_options": model_options} + cfg_result = x - model_options["sampler_cfg_function"](args) else: - return uncond + (cond - uncond) * cond_scale + cfg_result = uncond_pred + (cond_pred - uncond_pred) * cond_scale + + for fn in model_options.get("sampler_post_cfg_function", []): + args = {"denoised": cfg_result, "cond": cond, "uncond": uncond, "model": model, "uncond_denoised": uncond_pred, "cond_denoised": cond_pred, + "sigma": timestep, "model_options": model_options, "input": x} + cfg_result = fn(args) + + return cfg_result class CFGNoisePredictor(torch.nn.Module): def __init__(self, model): @@ -267,19 +273,19 @@ class CFGNoisePredictor(torch.nn.Module): return self.apply_model(*args, **kwargs) class KSamplerX0Inpaint(torch.nn.Module): - def __init__(self, model): + def __init__(self, model, sigmas): super().__init__() self.inner_model = model + self.sigmas = sigmas def forward(self, x, sigma, uncond, cond, cond_scale, denoise_mask, model_options={}, seed=None): if denoise_mask is not None: + if "denoise_mask_function" in model_options: + denoise_mask = model_options["denoise_mask_function"](sigma, denoise_mask, extra_options={"model": self.inner_model, "sigmas": self.sigmas}) latent_mask = 1. - denoise_mask - x = x * denoise_mask + (self.latent_image + self.noise * sigma.reshape([sigma.shape[0]] + [1] * (len(self.noise.shape) - 1))) * latent_mask + x = x * denoise_mask + self.inner_model.inner_model.model_sampling.noise_scaling(sigma.reshape([sigma.shape[0]] + [1] * (len(self.noise.shape) - 1)), self.noise, self.latent_image) * latent_mask out = self.inner_model(x, sigma, cond=cond, uncond=uncond, cond_scale=cond_scale, model_options=model_options, seed=seed) if denoise_mask is not None: - out *= denoise_mask - - if denoise_mask is not None: - out += self.latent_image * latent_mask + out = out * denoise_mask + self.latent_image * latent_mask return out def simple_scheduler(model, steps): @@ -294,7 +300,7 @@ def simple_scheduler(model, steps): def ddim_scheduler(model, steps): s = model.model_sampling sigs = [] - ss = len(s.sigmas) // steps + ss = max(len(s.sigmas) // steps, 1) x = 1 while x < len(s.sigmas): sigs += [float(s.sigmas[x])] @@ -512,14 +518,6 @@ class Sampler: sigma = float(sigmas[0]) return math.isclose(max_sigma, sigma, rel_tol=1e-05) or sigma > max_sigma -class UNIPC(Sampler): - def sample(self, model_wrap, sigmas, extra_args, callback, noise, latent_image=None, denoise_mask=None, disable_pbar=False): - return uni_pc.sample_unipc(model_wrap, noise, latent_image, sigmas, max_denoise=self.max_denoise(model_wrap, sigmas), extra_args=extra_args, noise_mask=denoise_mask, callback=callback, disable=disable_pbar) - -class UNIPCBH2(Sampler): - def sample(self, model_wrap, sigmas, extra_args, callback, noise, latent_image=None, denoise_mask=None, disable_pbar=False): - return uni_pc.sample_unipc(model_wrap, noise, latent_image, sigmas, max_denoise=self.max_denoise(model_wrap, sigmas), extra_args=extra_args, noise_mask=denoise_mask, callback=callback, variant='bh2', disable=disable_pbar) - KSAMPLER_NAMES = ["euler", "euler_ancestral", "heun", "heunpp2","dpm_2", "dpm_2_ancestral", "lms", "dpm_fast", "dpm_adaptive", "dpmpp_2s_ancestral", "dpmpp_sde", "dpmpp_sde_gpu", "dpmpp_2m", "dpmpp_2m_sde", "dpmpp_2m_sde_gpu", "dpmpp_3m_sde", "dpmpp_3m_sde_gpu", "ddpm", "lcm"] @@ -532,7 +530,7 @@ class KSAMPLER(Sampler): def sample(self, model_wrap, sigmas, extra_args, callback, noise, latent_image=None, denoise_mask=None, disable_pbar=False): extra_args["denoise_mask"] = denoise_mask - model_k = KSamplerX0Inpaint(model_wrap) + model_k = KSamplerX0Inpaint(model_wrap, sigmas) model_k.latent_image = latent_image if self.inpaint_options.get("random", False): #TODO: Should this be the default? generator = torch.manual_seed(extra_args.get("seed", 41) + 1) @@ -540,20 +538,15 @@ class KSAMPLER(Sampler): else: model_k.noise = noise - if self.max_denoise(model_wrap, sigmas): - noise = noise * torch.sqrt(1.0 + sigmas[0] ** 2.0) - else: - noise = noise * sigmas[0] + noise = model_wrap.inner_model.model_sampling.noise_scaling(sigmas[0], noise, latent_image, self.max_denoise(model_wrap, sigmas)) k_callback = None total_steps = len(sigmas) - 1 if callback is not None: k_callback = lambda x: callback(x["i"], x["denoised"], x["x"], total_steps) - if latent_image is not None: - noise += latent_image - samples = self.sampler_function(model_k, noise, sigmas, extra_args=extra_args, callback=k_callback, disable=disable_pbar, **self.extra_options) + samples = model_wrap.inner_model.model_sampling.inverse_noise_scaling(sigmas[-1], samples) return samples @@ -567,11 +560,11 @@ def ksampler(sampler_name, extra_options={}, inpaint_options={}): return k_diffusion_sampling.sample_dpm_fast(model, noise, sigma_min, sigmas[0], total_steps, extra_args=extra_args, callback=callback, disable=disable) sampler_function = dpm_fast_function elif sampler_name == "dpm_adaptive": - def dpm_adaptive_function(model, noise, sigmas, extra_args, callback, disable): + def dpm_adaptive_function(model, noise, sigmas, extra_args, callback, disable, **extra_options): sigma_min = sigmas[-1] if sigma_min == 0: sigma_min = sigmas[-2] - return k_diffusion_sampling.sample_dpm_adaptive(model, noise, sigma_min, sigmas[0], extra_args=extra_args, callback=callback, disable=disable) + return k_diffusion_sampling.sample_dpm_adaptive(model, noise, sigma_min, sigmas[0], extra_args=extra_args, callback=callback, disable=disable, **extra_options) sampler_function = dpm_adaptive_function else: sampler_function = getattr(k_diffusion_sampling, "sample_{}".format(sampler_name)) @@ -594,6 +587,13 @@ def sample(model, noise, positive, negative, cfg, device, sampler, sigmas, model calculate_start_end_timesteps(model, negative) calculate_start_end_timesteps(model, positive) + if latent_image is not None and torch.count_nonzero(latent_image) > 0: #Don't shift the empty latent image. + latent_image = model.process_latent_in(latent_image) + + if hasattr(model, 'extra_conds'): + positive = encode_model_conds(model.extra_conds, positive, noise, device, "positive", latent_image=latent_image, denoise_mask=denoise_mask, seed=seed) + negative = encode_model_conds(model.extra_conds, negative, noise, device, "negative", latent_image=latent_image, denoise_mask=denoise_mask, seed=seed) + #make sure each cond area has an opposite one with the same area for c in positive: create_cond_with_same_area_if_none(negative, c) @@ -605,13 +605,6 @@ def sample(model, noise, positive, negative, cfg, device, sampler, sigmas, model apply_empty_x_to_equal_area(list(filter(lambda c: c.get('control_apply_to_uncond', False) == True, positive)), negative, 'control', lambda cond_cnets, x: cond_cnets[x]) apply_empty_x_to_equal_area(positive, negative, 'gligen', lambda cond_cnets, x: cond_cnets[x]) - if latent_image is not None: - latent_image = model.process_latent_in(latent_image) - - if hasattr(model, 'extra_conds'): - positive = encode_model_conds(model.extra_conds, positive, noise, device, "positive", latent_image=latent_image, denoise_mask=denoise_mask) - negative = encode_model_conds(model.extra_conds, negative, noise, device, "negative", latent_image=latent_image, denoise_mask=denoise_mask) - extra_args = {"cond":positive, "uncond":negative, "cond_scale": cfg, "model_options": model_options, "seed":seed} samples = sampler.sample(model_wrap, sigmas, extra_args, callback, noise, latent_image, denoise_mask, disable_pbar) @@ -634,14 +627,14 @@ def calculate_sigmas_scheduler(model, scheduler_name, steps): elif scheduler_name == "sgm_uniform": sigmas = normal_scheduler(model, steps, sgm=True) else: - print("error invalid scheduler", self.scheduler) + logging.error("error invalid scheduler {}".format(scheduler_name)) return sigmas def sampler_object(name): if name == "uni_pc": - sampler = UNIPC() + sampler = KSAMPLER(uni_pc.sample_unipc) elif name == "uni_pc_bh2": - sampler = UNIPCBH2() + sampler = KSAMPLER(uni_pc.sample_unipc_bh2) elif name == "ddim": sampler = ksampler("euler", inpaint_options={"random": True}) else: @@ -651,6 +644,7 @@ def sampler_object(name): class KSampler: SCHEDULERS = SCHEDULER_NAMES SAMPLERS = SAMPLER_NAMES + DISCARD_PENULTIMATE_SIGMA_SAMPLERS = set(('dpm_2', 'dpm_2_ancestral', 'uni_pc', 'uni_pc_bh2')) def __init__(self, model, steps, device, sampler=None, scheduler=None, denoise=None, model_options={}): self.model = model @@ -669,7 +663,7 @@ class KSampler: sigmas = None discard_penultimate_sigma = False - if self.sampler in ['dpm_2', 'dpm_2_ancestral', 'uni_pc', 'uni_pc_bh2']: + if self.sampler in self.DISCARD_PENULTIMATE_SIGMA_SAMPLERS: steps += 1 discard_penultimate_sigma = True diff --git a/comfy/sd.py b/comfy/sd.py index f4f84d0a..85821120 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -1,10 +1,12 @@ import torch -import contextlib -import math +from enum import Enum +import logging from comfy import model_management -from .ldm.util import instantiate_from_config from .ldm.models.autoencoder import AutoencoderKL, AutoencodingEngine +from .ldm.cascade.stage_a import StageA +from .ldm.cascade.stage_c_coder import StageC_coder + import yaml import comfy.utils @@ -36,7 +38,7 @@ def load_model_weights(model, sd): w = sd.pop(x) del w if len(m) > 0: - print("missing", m) + logging.warning("missing {}".format(m)) return model def load_clip_weights(model, sd): @@ -51,7 +53,7 @@ def load_clip_weights(model, sd): if ids.dtype == torch.float32: sd['cond_stage_model.transformer.text_model.embeddings.position_ids'] = ids.round() - sd = comfy.utils.transformers_convert(sd, "cond_stage_model.model.", "cond_stage_model.transformer.text_model.", 24) + sd = comfy.utils.clip_text_transformers_convert(sd, "cond_stage_model.model.", "cond_stage_model.transformer.") return load_model_weights(model, sd) @@ -80,7 +82,7 @@ def load_lora_for_models(model, clip, lora, strength_model, strength_clip): k1 = set(k1) for x in loaded: if (x not in k) and (x not in k1): - print("NOT LOADED", x) + logging.warning("NOT LOADED {}".format(x)) return (new_modelpatcher, new_clip) @@ -122,10 +124,13 @@ class CLIP: return self.tokenizer.tokenize_with_weights(text, return_word_ids) def encode_from_tokens(self, tokens, return_pooled=False): + self.cond_stage_model.reset_clip_options() + if self.layer_idx is not None: - self.cond_stage_model.clip_layer(self.layer_idx) - else: - self.cond_stage_model.reset_clip_layer() + self.cond_stage_model.set_clip_options({"layer": self.layer_idx}) + + if return_pooled == "unprojected": + self.cond_stage_model.set_clip_options({"projected_pooled": False}) self.load_model() cond, pooled = self.cond_stage_model.encode_token_weights(tokens) @@ -137,8 +142,11 @@ class CLIP: tokens = self.tokenize(text) return self.encode_from_tokens(tokens) - def load_sd(self, sd): - return self.cond_stage_model.load_sd(sd) + def load_sd(self, sd, full_model=False): + if full_model: + return self.cond_stage_model.load_state_dict(sd, strict=False) + else: + return self.cond_stage_model.load_sd(sd) def get_sd(self): return self.cond_stage_model.state_dict() @@ -151,12 +159,17 @@ class CLIP: return self.patcher.get_key_patches() class VAE: - def __init__(self, sd=None, device=None, config=None): + def __init__(self, sd=None, device=None, config=None, dtype=None): if 'decoder.up_blocks.0.resnets.0.norm1.weight' in sd.keys(): #diffusers format sd = diffusers_convert.convert_vae_state_dict(sd) self.memory_used_encode = lambda shape, dtype: (1767 * shape[2] * shape[3]) * model_management.dtype_size(dtype) #These are for AutoencoderKL and need tweaking (should be lower) self.memory_used_decode = lambda shape, dtype: (2178 * shape[2] * shape[3] * 64) * model_management.dtype_size(dtype) + self.downscale_ratio = 8 + self.upscale_ratio = 8 + self.latent_channels = 4 + self.process_input = lambda image: image * 2.0 - 1.0 + self.process_output = lambda image: torch.clamp((image + 1.0) / 2.0, min=0.0, max=1.0) if config is None: if "decoder.mid.block_1.mix_factor" in sd: @@ -169,9 +182,43 @@ class VAE: decoder_config={'target': "comfy.ldm.modules.temporal_ae.VideoDecoder", 'params': decoder_config}) elif "taesd_decoder.1.weight" in sd: self.first_stage_model = comfy.taesd.taesd.TAESD() + elif "vquantizer.codebook.weight" in sd: #VQGan: stage a of stable cascade + self.first_stage_model = StageA() + self.downscale_ratio = 4 + self.upscale_ratio = 4 + #TODO + #self.memory_used_encode + #self.memory_used_decode + self.process_input = lambda image: image + self.process_output = lambda image: image + elif "backbone.1.0.block.0.1.num_batches_tracked" in sd: #effnet: encoder for stage c latent of stable cascade + self.first_stage_model = StageC_coder() + self.downscale_ratio = 32 + self.latent_channels = 16 + new_sd = {} + for k in sd: + new_sd["encoder.{}".format(k)] = sd[k] + sd = new_sd + elif "blocks.11.num_batches_tracked" in sd: #previewer: decoder for stage c latent of stable cascade + self.first_stage_model = StageC_coder() + self.latent_channels = 16 + new_sd = {} + for k in sd: + new_sd["previewer.{}".format(k)] = sd[k] + sd = new_sd + elif "encoder.backbone.1.0.block.0.1.num_batches_tracked" in sd: #combined effnet and previewer for stable cascade + self.first_stage_model = StageC_coder() + self.downscale_ratio = 32 + self.latent_channels = 16 else: #default SD1.x/SD2.x VAE parameters ddconfig = {'double_z': True, 'z_channels': 4, 'resolution': 256, 'in_channels': 3, 'out_ch': 3, 'ch': 128, 'ch_mult': [1, 2, 4, 4], 'num_res_blocks': 2, 'attn_resolutions': [], 'dropout': 0.0} + + if 'encoder.down.2.downsample.conv.weight' not in sd: #Stable diffusion x4 upscaler VAE + ddconfig['ch_mult'] = [1, 2, 4] + self.downscale_ratio = 4 + self.upscale_ratio = 4 + self.first_stage_model = AutoencoderKL(ddconfig=ddconfig, embed_dim=4) else: self.first_stage_model = AutoencoderKL(**(config['params'])) @@ -179,32 +226,44 @@ class VAE: m, u = self.first_stage_model.load_state_dict(sd, strict=False) if len(m) > 0: - print("Missing VAE keys", m) + logging.warning("Missing VAE keys {}".format(m)) if len(u) > 0: - print("Leftover VAE keys", u) + logging.debug("Leftover VAE keys {}".format(u)) if device is None: device = model_management.vae_device() self.device = device offload_device = model_management.vae_offload_device() - self.vae_dtype = model_management.vae_dtype() + if dtype is None: + dtype = model_management.vae_dtype() + self.vae_dtype = dtype self.first_stage_model.to(self.vae_dtype) + self.output_device = model_management.intermediate_device() self.patcher = comfy.model_patcher.ModelPatcher(self.first_stage_model, load_device=self.device, offload_device=offload_device) + def vae_encode_crop_pixels(self, pixels): + x = (pixels.shape[1] // self.downscale_ratio) * self.downscale_ratio + y = (pixels.shape[2] // self.downscale_ratio) * self.downscale_ratio + if pixels.shape[1] != x or pixels.shape[2] != y: + x_offset = (pixels.shape[1] % self.downscale_ratio) // 2 + y_offset = (pixels.shape[2] % self.downscale_ratio) // 2 + pixels = pixels[:, x_offset:x + x_offset, y_offset:y + y_offset, :] + return pixels + def decode_tiled_(self, samples, tile_x=64, tile_y=64, overlap = 16): steps = samples.shape[0] * comfy.utils.get_tiled_scale_steps(samples.shape[3], samples.shape[2], tile_x, tile_y, overlap) steps += samples.shape[0] * comfy.utils.get_tiled_scale_steps(samples.shape[3], samples.shape[2], tile_x // 2, tile_y * 2, overlap) steps += samples.shape[0] * comfy.utils.get_tiled_scale_steps(samples.shape[3], samples.shape[2], tile_x * 2, tile_y // 2, overlap) pbar = comfy.utils.ProgressBar(steps) - decode_fn = lambda a: (self.first_stage_model.decode(a.to(self.vae_dtype).to(self.device)) + 1.0).float() - output = torch.clamp(( - (comfy.utils.tiled_scale(samples, decode_fn, tile_x // 2, tile_y * 2, overlap, upscale_amount = 8, pbar = pbar) + - comfy.utils.tiled_scale(samples, decode_fn, tile_x * 2, tile_y // 2, overlap, upscale_amount = 8, pbar = pbar) + - comfy.utils.tiled_scale(samples, decode_fn, tile_x, tile_y, overlap, upscale_amount = 8, pbar = pbar)) - / 3.0) / 2.0, min=0.0, max=1.0) + decode_fn = lambda a: self.first_stage_model.decode(a.to(self.vae_dtype).to(self.device)).float() + output = self.process_output( + (comfy.utils.tiled_scale(samples, decode_fn, tile_x // 2, tile_y * 2, overlap, upscale_amount = self.upscale_ratio, output_device=self.output_device, pbar = pbar) + + comfy.utils.tiled_scale(samples, decode_fn, tile_x * 2, tile_y // 2, overlap, upscale_amount = self.upscale_ratio, output_device=self.output_device, pbar = pbar) + + comfy.utils.tiled_scale(samples, decode_fn, tile_x, tile_y, overlap, upscale_amount = self.upscale_ratio, output_device=self.output_device, pbar = pbar)) + / 3.0) return output def encode_tiled_(self, pixel_samples, tile_x=512, tile_y=512, overlap = 64): @@ -213,10 +272,10 @@ class VAE: steps += pixel_samples.shape[0] * comfy.utils.get_tiled_scale_steps(pixel_samples.shape[3], pixel_samples.shape[2], tile_x * 2, tile_y // 2, overlap) pbar = comfy.utils.ProgressBar(steps) - encode_fn = lambda a: self.first_stage_model.encode((2. * a - 1.).to(self.vae_dtype).to(self.device)).float() - samples = comfy.utils.tiled_scale(pixel_samples, encode_fn, tile_x, tile_y, overlap, upscale_amount = (1/8), out_channels=4, pbar=pbar) - samples += comfy.utils.tiled_scale(pixel_samples, encode_fn, tile_x * 2, tile_y // 2, overlap, upscale_amount = (1/8), out_channels=4, pbar=pbar) - samples += comfy.utils.tiled_scale(pixel_samples, encode_fn, tile_x // 2, tile_y * 2, overlap, upscale_amount = (1/8), out_channels=4, pbar=pbar) + encode_fn = lambda a: self.first_stage_model.encode((self.process_input(a)).to(self.vae_dtype).to(self.device)).float() + samples = comfy.utils.tiled_scale(pixel_samples, encode_fn, tile_x, tile_y, overlap, upscale_amount = (1/self.downscale_ratio), out_channels=self.latent_channels, output_device=self.output_device, pbar=pbar) + samples += comfy.utils.tiled_scale(pixel_samples, encode_fn, tile_x * 2, tile_y // 2, overlap, upscale_amount = (1/self.downscale_ratio), out_channels=self.latent_channels, output_device=self.output_device, pbar=pbar) + samples += comfy.utils.tiled_scale(pixel_samples, encode_fn, tile_x // 2, tile_y * 2, overlap, upscale_amount = (1/self.downscale_ratio), out_channels=self.latent_channels, output_device=self.output_device, pbar=pbar) samples /= 3.0 return samples @@ -228,15 +287,15 @@ class VAE: batch_number = int(free_memory / memory_used) batch_number = max(1, batch_number) - pixel_samples = torch.empty((samples_in.shape[0], 3, round(samples_in.shape[2] * 8), round(samples_in.shape[3] * 8)), device="cpu") + pixel_samples = torch.empty((samples_in.shape[0], 3, round(samples_in.shape[2] * self.upscale_ratio), round(samples_in.shape[3] * self.upscale_ratio)), device=self.output_device) for x in range(0, samples_in.shape[0], batch_number): samples = samples_in[x:x+batch_number].to(self.vae_dtype).to(self.device) - pixel_samples[x:x+batch_number] = torch.clamp((self.first_stage_model.decode(samples).cpu().float() + 1.0) / 2.0, min=0.0, max=1.0) + pixel_samples[x:x+batch_number] = self.process_output(self.first_stage_model.decode(samples).to(self.output_device).float()) except model_management.OOM_EXCEPTION as e: - print("Warning: Ran out of memory when regular VAE decoding, retrying with tiled VAE decoding.") + logging.warning("Warning: Ran out of memory when regular VAE decoding, retrying with tiled VAE decoding.") pixel_samples = self.decode_tiled_(samples_in) - pixel_samples = pixel_samples.cpu().movedim(1,-1) + pixel_samples = pixel_samples.to(self.output_device).movedim(1,-1) return pixel_samples def decode_tiled(self, samples, tile_x=64, tile_y=64, overlap = 16): @@ -245,6 +304,7 @@ class VAE: return output.movedim(1,-1) def encode(self, pixel_samples): + pixel_samples = self.vae_encode_crop_pixels(pixel_samples) pixel_samples = pixel_samples.movedim(-1,1) try: memory_used = self.memory_used_encode(pixel_samples.shape, self.vae_dtype) @@ -252,18 +312,19 @@ class VAE: free_memory = model_management.get_free_memory(self.device) batch_number = int(free_memory / memory_used) batch_number = max(1, batch_number) - samples = torch.empty((pixel_samples.shape[0], 4, round(pixel_samples.shape[2] // 8), round(pixel_samples.shape[3] // 8)), device="cpu") + samples = torch.empty((pixel_samples.shape[0], self.latent_channels, round(pixel_samples.shape[2] // self.downscale_ratio), round(pixel_samples.shape[3] // self.downscale_ratio)), device=self.output_device) for x in range(0, pixel_samples.shape[0], batch_number): - pixels_in = (2. * pixel_samples[x:x+batch_number] - 1.).to(self.vae_dtype).to(self.device) - samples[x:x+batch_number] = self.first_stage_model.encode(pixels_in).cpu().float() + pixels_in = self.process_input(pixel_samples[x:x+batch_number]).to(self.vae_dtype).to(self.device) + samples[x:x+batch_number] = self.first_stage_model.encode(pixels_in).to(self.output_device).float() except model_management.OOM_EXCEPTION as e: - print("Warning: Ran out of memory when regular VAE encoding, retrying with tiled VAE encoding.") + logging.warning("Warning: Ran out of memory when regular VAE encoding, retrying with tiled VAE encoding.") samples = self.encode_tiled_(pixel_samples) return samples def encode_tiled(self, pixel_samples, tile_x=512, tile_y=512, overlap = 64): + pixel_samples = self.vae_encode_crop_pixels(pixel_samples) model_management.load_model_gpu(self.patcher) pixel_samples = pixel_samples.movedim(-1,1) samples = self.encode_tiled_(pixel_samples, tile_x=tile_x, tile_y=tile_y, overlap=overlap) @@ -290,8 +351,11 @@ def load_style_model(ckpt_path): model.load_state_dict(model_data) return StyleModel(model) +class CLIPType(Enum): + STABLE_DIFFUSION = 1 + STABLE_CASCADE = 2 -def load_clip(ckpt_paths, embedding_directory=None): +def load_clip(ckpt_paths, embedding_directory=None, clip_type=CLIPType.STABLE_DIFFUSION): clip_data = [] for p in ckpt_paths: clip_data.append(comfy.utils.load_torch_file(p, safe_load=True)) @@ -301,14 +365,21 @@ def load_clip(ckpt_paths, embedding_directory=None): for i in range(len(clip_data)): if "transformer.resblocks.0.ln_1.weight" in clip_data[i]: - clip_data[i] = comfy.utils.transformers_convert(clip_data[i], "", "text_model.", 32) + clip_data[i] = comfy.utils.clip_text_transformers_convert(clip_data[i], "", "") + else: + if "text_projection" in clip_data[i]: + clip_data[i]["text_projection.weight"] = clip_data[i]["text_projection"].transpose(0, 1) #old models saved with the CLIPSave node clip_target = EmptyClass() clip_target.params = {} if len(clip_data) == 1: if "text_model.encoder.layers.30.mlp.fc1.weight" in clip_data[0]: - clip_target.clip = sdxl_clip.SDXLRefinerClipModel - clip_target.tokenizer = sdxl_clip.SDXLTokenizer + if clip_type == CLIPType.STABLE_CASCADE: + clip_target.clip = sdxl_clip.StableCascadeClipModel + clip_target.tokenizer = sdxl_clip.StableCascadeTokenizer + else: + clip_target.clip = sdxl_clip.SDXLRefinerClipModel + clip_target.tokenizer = sdxl_clip.SDXLTokenizer elif "text_model.encoder.layers.22.mlp.fc1.weight" in clip_data[0]: clip_target.clip = sd2_clip.SD2ClipModel clip_target.tokenizer = sd2_clip.SD2Tokenizer @@ -323,10 +394,10 @@ def load_clip(ckpt_paths, embedding_directory=None): for c in clip_data: m, u = clip.load_sd(c) if len(m) > 0: - print("clip missing:", m) + logging.warning("clip missing: {}".format(m)) if len(u) > 0: - print("clip unexpected:", u) + logging.debug("clip unexpected: {}".format(u)) return clip def load_gligen(ckpt_path): @@ -431,12 +502,13 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o clip_target = None parameters = comfy.utils.calculate_parameters(sd, "model.diffusion_model.") - unet_dtype = model_management.unet_dtype(model_params=parameters) + load_device = model_management.get_torch_device() - class WeightsLoader(torch.nn.Module): - pass + model_config = model_detection.model_config_from_unet(sd, "model.diffusion_model.") + unet_dtype = model_management.unet_dtype(model_params=parameters, supported_dtypes=model_config.supported_inference_dtypes) + manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device, model_config.supported_inference_dtypes) + model_config.set_inference_dtype(unet_dtype, manual_cast_dtype) - model_config = model_detection.model_config_from_unet(sd, "model.diffusion_model.", unet_dtype) if model_config is None: raise RuntimeError("ERROR: Could not detect model type of: {}".format(ckpt_path)) @@ -451,27 +523,33 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o model.load_model_weights(sd, "model.diffusion_model.") if output_vae: - vae_sd = comfy.utils.state_dict_prefix_replace(sd, {"first_stage_model.": ""}, filter_keys=True) + vae_sd = comfy.utils.state_dict_prefix_replace(sd, {k: "" for k in model_config.vae_key_prefix}, filter_keys=True) vae_sd = model_config.process_vae_state_dict(vae_sd) vae = VAE(sd=vae_sd) if output_clip: - w = WeightsLoader() clip_target = model_config.clip_target() if clip_target is not None: - clip = CLIP(clip_target, embedding_directory=embedding_directory) - w.cond_stage_model = clip.cond_stage_model - sd = model_config.process_clip_state_dict(sd) - load_model_weights(w, sd) + clip_sd = model_config.process_clip_state_dict(sd) + if len(clip_sd) > 0: + clip = CLIP(clip_target, embedding_directory=embedding_directory) + m, u = clip.load_sd(clip_sd, full_model=True) + if len(m) > 0: + logging.warning("clip missing: {}".format(m)) + + if len(u) > 0: + logging.debug("clip unexpected {}:".format(u)) + else: + logging.warning("no CLIP/text encoder weights in checkpoint, the text encoder model will not be loaded.") left_over = sd.keys() if len(left_over) > 0: - print("left over keys:", left_over) + logging.debug("left over keys: {}".format(left_over)) if output_model: - model_patcher = comfy.model_patcher.ModelPatcher(model, load_device=model_management.get_torch_device(), offload_device=model_management.unet_offload_device(), current_device=inital_load_device) + model_patcher = comfy.model_patcher.ModelPatcher(model, load_device=load_device, offload_device=model_management.unet_offload_device(), current_device=inital_load_device) if inital_load_device != torch.device("cpu"): - print("loaded straight to GPU") + logging.info("loaded straight to GPU") model_management.load_model_gpu(model_patcher) return (model_patcher, clip, vae, clipvision) @@ -480,14 +558,16 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o def load_unet_state_dict(sd): #load unet in diffusers format parameters = comfy.utils.calculate_parameters(sd) unet_dtype = model_management.unet_dtype(model_params=parameters) - if "input_blocks.0.0.weight" in sd: #ldm - model_config = model_detection.model_config_from_unet(sd, "", unet_dtype) + load_device = model_management.get_torch_device() + + if "input_blocks.0.0.weight" in sd or 'clf.1.weight' in sd: #ldm or stable cascade + model_config = model_detection.model_config_from_unet(sd, "") if model_config is None: return None new_sd = sd else: #diffusers - model_config = model_detection.model_config_from_diffusers_unet(sd, unet_dtype) + model_config = model_detection.model_config_from_diffusers_unet(sd) if model_config is None: return None @@ -498,25 +578,36 @@ def load_unet_state_dict(sd): #load unet in diffusers format if k in sd: new_sd[diffusers_keys[k]] = sd.pop(k) else: - print(diffusers_keys[k], k) + logging.warning("{} {}".format(diffusers_keys[k], k)) + offload_device = model_management.unet_offload_device() + unet_dtype = model_management.unet_dtype(model_params=parameters, supported_dtypes=model_config.supported_inference_dtypes) + manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device, model_config.supported_inference_dtypes) + model_config.set_inference_dtype(unet_dtype, manual_cast_dtype) model = model_config.get_model(new_sd, "") model = model.to(offload_device) model.load_model_weights(new_sd, "") left_over = sd.keys() if len(left_over) > 0: - print("left over keys in unet:", left_over) - return comfy.model_patcher.ModelPatcher(model, load_device=model_management.get_torch_device(), offload_device=offload_device) + logging.info("left over keys in unet: {}".format(left_over)) + return comfy.model_patcher.ModelPatcher(model, load_device=load_device, offload_device=offload_device) def load_unet(unet_path): sd = comfy.utils.load_torch_file(unet_path) model = load_unet_state_dict(sd) if model is None: - print("ERROR UNSUPPORTED UNET", unet_path) + logging.error("ERROR UNSUPPORTED UNET {}".format(unet_path)) raise RuntimeError("ERROR: Could not detect model type of: {}".format(unet_path)) return model -def save_checkpoint(output_path, model, clip, vae, metadata=None): - model_management.load_models_gpu([model, clip.load_model()]) - sd = model.model.state_dict_for_saving(clip.get_sd(), vae.get_sd()) +def save_checkpoint(output_path, model, clip=None, vae=None, clip_vision=None, metadata=None): + clip_sd = None + load_models = [model] + if clip is not None: + load_models.append(clip.load_model()) + clip_sd = clip.get_sd() + + model_management.load_models_gpu(load_models) + clip_vision_sd = clip_vision.get_sd() if clip_vision is not None else None + sd = model.model.state_dict_for_saving(clip_sd, vae.get_sd(), clip_vision_sd) comfy.utils.save_torch_file(sd, output_path, metadata=metadata) diff --git a/comfy/sd1_clip.py b/comfy/sd1_clip.py index 58acb97f..ff6db0d2 100644 --- a/comfy/sd1_clip.py +++ b/comfy/sd1_clip.py @@ -1,12 +1,14 @@ import os -from transformers import CLIPTokenizer, CLIPTextModel, CLIPTextConfig, modeling_utils +from transformers import CLIPTokenizer import comfy.ops import torch import traceback import zipfile from . import model_management -import contextlib +import comfy.clip_model +import json +import logging def gen_empty_tokens(special_tokens, length): start_token = special_tokens.get("start", None) @@ -37,7 +39,7 @@ class ClipTokenWeightEncoder: out, pooled = self.encode(to_encode) if pooled is not None: - first_pooled = pooled[0:1].cpu() + first_pooled = pooled[0:1].to(model_management.intermediate_device()) else: first_pooled = pooled @@ -54,8 +56,8 @@ class ClipTokenWeightEncoder: output.append(z) if (len(output) == 0): - return out[-1:].cpu(), first_pooled - return torch.cat(output, dim=-2).cpu(), first_pooled + return out[-1:].to(model_management.intermediate_device()), first_pooled + return torch.cat(output, dim=-2).to(model_management.intermediate_device()), first_pooled class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder): """Uses the CLIP transformer encoder for text (from huggingface)""" @@ -65,31 +67,19 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder): "hidden" ] def __init__(self, version="openai/clip-vit-large-patch14", device="cpu", max_length=77, - freeze=True, layer="last", layer_idx=None, textmodel_json_config=None, textmodel_path=None, dtype=None, - special_tokens={"start": 49406, "end": 49407, "pad": 49407},layer_norm_hidden_state=True, config_class=CLIPTextConfig, - model_class=CLIPTextModel, inner_name="text_model"): # clip-vit-base-patch32 + freeze=True, layer="last", layer_idx=None, textmodel_json_config=None, dtype=None, model_class=comfy.clip_model.CLIPTextModel, + special_tokens={"start": 49406, "end": 49407, "pad": 49407}, layer_norm_hidden_state=True, enable_attention_masks=False, return_projected_pooled=True): # clip-vit-base-patch32 super().__init__() assert layer in self.LAYERS - self.num_layers = 12 - if textmodel_path is not None: - self.transformer = model_class.from_pretrained(textmodel_path) - else: - if textmodel_json_config is None: - textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "sd1_clip_config.json") - config = config_class.from_json_file(textmodel_json_config) - self.num_layers = config.num_hidden_layers - with comfy.ops.use_comfy_ops(device, dtype): - with modeling_utils.no_init_weights(): - self.transformer = model_class(config) - - self.inner_name = inner_name - if dtype is not None: - self.transformer.to(dtype) - inner_model = getattr(self.transformer, self.inner_name) - if hasattr(inner_model, "embeddings"): - inner_model.embeddings.to(torch.float32) - else: - self.transformer.set_input_embeddings(self.transformer.get_input_embeddings().to(torch.float32)) + + if textmodel_json_config is None: + textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "sd1_clip_config.json") + + with open(textmodel_json_config) as f: + config = json.load(f) + + self.transformer = model_class(config, dtype, device, comfy.ops.manual_cast) + self.num_layers = self.transformer.num_layers self.max_length = max_length if freeze: @@ -97,16 +87,18 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder): self.layer = layer self.layer_idx = None self.special_tokens = special_tokens - self.text_projection = torch.nn.Parameter(torch.eye(self.transformer.get_input_embeddings().weight.shape[1])) + self.logit_scale = torch.nn.Parameter(torch.tensor(4.6055)) - self.enable_attention_masks = False + self.enable_attention_masks = enable_attention_masks self.layer_norm_hidden_state = layer_norm_hidden_state + self.return_projected_pooled = return_projected_pooled + if layer == "hidden": assert layer_idx is not None - assert abs(layer_idx) <= self.num_layers - self.clip_layer(layer_idx) - self.layer_default = (self.layer, self.layer_idx) + assert abs(layer_idx) < self.num_layers + self.set_clip_options({"layer": layer_idx}) + self.options_default = (self.layer, self.layer_idx, self.return_projected_pooled) def freeze(self): self.transformer = self.transformer.eval() @@ -114,16 +106,19 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder): for param in self.parameters(): param.requires_grad = False - def clip_layer(self, layer_idx): - if abs(layer_idx) >= self.num_layers: + def set_clip_options(self, options): + layer_idx = options.get("layer", self.layer_idx) + self.return_projected_pooled = options.get("projected_pooled", self.return_projected_pooled) + if layer_idx is None or abs(layer_idx) > self.num_layers: self.layer = "last" else: self.layer = "hidden" self.layer_idx = layer_idx - def reset_clip_layer(self): - self.layer = self.layer_default[0] - self.layer_idx = self.layer_default[1] + def reset_clip_options(self): + self.layer = self.options_default[0] + self.layer_idx = self.options_default[1] + self.return_projected_pooled = self.options_default[2] def set_up_textual_embeddings(self, tokens, current_embeds): out_tokens = [] @@ -143,7 +138,7 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder): tokens_temp += [next_new_token] next_new_token += 1 else: - print("WARNING: shape mismatch when trying to apply embedding, embedding will be ignored", y.shape[0], current_embeds.weight.shape[1]) + logging.warning("WARNING: shape mismatch when trying to apply embedding, embedding will be ignored {} != {}".format(y.shape[0], current_embeds.weight.shape[1])) while len(tokens_temp) < len(x): tokens_temp += [self.special_tokens["pad"]] out_tokens += [tokens_temp] @@ -170,51 +165,37 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder): tokens = self.set_up_textual_embeddings(tokens, backup_embeds) tokens = torch.LongTensor(tokens).to(device) - if getattr(self.transformer, self.inner_name).final_layer_norm.weight.dtype != torch.float32: - precision_scope = torch.autocast + attention_mask = None + if self.enable_attention_masks: + attention_mask = torch.zeros_like(tokens) + max_token = self.transformer.get_input_embeddings().weight.shape[0] - 1 + for x in range(attention_mask.shape[0]): + for y in range(attention_mask.shape[1]): + attention_mask[x, y] = 1 + if tokens[x, y] == max_token: + break + + outputs = self.transformer(tokens, attention_mask, intermediate_output=self.layer_idx, final_layer_norm_intermediate=self.layer_norm_hidden_state) + self.transformer.set_input_embeddings(backup_embeds) + + if self.layer == "last": + z = outputs[0] else: - precision_scope = lambda a, dtype: contextlib.nullcontext(a) - - with precision_scope(model_management.get_autocast_device(device), dtype=torch.float32): - attention_mask = None - if self.enable_attention_masks: - attention_mask = torch.zeros_like(tokens) - max_token = self.transformer.get_input_embeddings().weight.shape[0] - 1 - for x in range(attention_mask.shape[0]): - for y in range(attention_mask.shape[1]): - attention_mask[x, y] = 1 - if tokens[x, y] == max_token: - break - - outputs = self.transformer(input_ids=tokens, attention_mask=attention_mask, output_hidden_states=self.layer=="hidden") - self.transformer.set_input_embeddings(backup_embeds) - - if self.layer == "last": - z = outputs.last_hidden_state - elif self.layer == "pooled": - z = outputs.pooler_output[:, None, :] - else: - z = outputs.hidden_states[self.layer_idx] - if self.layer_norm_hidden_state: - z = getattr(self.transformer, self.inner_name).final_layer_norm(z) + z = outputs[1] - if hasattr(outputs, "pooler_output"): - pooled_output = outputs.pooler_output.float() - else: - pooled_output = None + pooled_output = None + if len(outputs) >= 3: + if not self.return_projected_pooled and len(outputs) >= 4 and outputs[3] is not None: + pooled_output = outputs[3].float() + elif outputs[2] is not None: + pooled_output = outputs[2].float() - if self.text_projection is not None and pooled_output is not None: - pooled_output = pooled_output.float().to(self.text_projection.device) @ self.text_projection.float() return z.float(), pooled_output def encode(self, tokens): return self(tokens) def load_sd(self, sd): - if "text_projection" in sd: - self.text_projection[:] = sd.pop("text_projection") - if "text_projection.weight" in sd: - self.text_projection[:] = sd.pop("text_projection.weight").transpose(0, 1) return self.transformer.load_state_dict(sd, strict=False) def parse_parentheses(string): @@ -349,9 +330,7 @@ def load_embed(embedding_name, embedding_directory, embedding_size, embed_key=No else: embed = torch.load(embed_path, map_location="cpu") except Exception as e: - print(traceback.format_exc()) - print() - print("error loading embedding, skipping loading:", embedding_name) + logging.warning("{}\n\nerror loading embedding, skipping loading: {}".format(traceback.format_exc(), embedding_name)) return None if embed_out is None: @@ -375,11 +354,12 @@ def load_embed(embedding_name, embedding_directory, embedding_size, embed_key=No return embed_out class SDTokenizer: - def __init__(self, tokenizer_path=None, max_length=77, pad_with_end=True, embedding_directory=None, embedding_size=768, embedding_key='clip_l', tokenizer_class=CLIPTokenizer, has_start_token=True, pad_to_max_length=True): + def __init__(self, tokenizer_path=None, max_length=77, pad_with_end=True, embedding_directory=None, embedding_size=768, embedding_key='clip_l', tokenizer_class=CLIPTokenizer, has_start_token=True, pad_to_max_length=True, min_length=None): if tokenizer_path is None: tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "sd1_tokenizer") self.tokenizer = tokenizer_class.from_pretrained(tokenizer_path) self.max_length = max_length + self.min_length = min_length empty = self.tokenizer('')["input_ids"] if has_start_token: @@ -441,7 +421,7 @@ class SDTokenizer: embedding_name = word[len(self.embedding_identifier):].strip('\n') embed, leftover = self._try_get_embedding(embedding_name) if embed is None: - print(f"warning, embedding:{embedding_name} does not exist, ignoring") + logging.warning(f"warning, embedding:{embedding_name} does not exist, ignoring") else: if len(embed.shape) == 1: tokens.append([(embed, weight)]) @@ -491,6 +471,8 @@ class SDTokenizer: batch.append((self.end_token, 1.0, 0)) if self.pad_to_max_length: batch.extend([(pad_token, 1.0, 0)] * (self.max_length - len(batch))) + if self.min_length is not None and len(batch) < self.min_length: + batch.extend([(pad_token, 1.0, 0)] * (self.min_length - len(batch))) if not return_word_ids: batched_tokens = [[(t, w) for t, w,_ in x] for x in batched_tokens] @@ -524,11 +506,11 @@ class SD1ClipModel(torch.nn.Module): self.clip = "clip_{}".format(self.clip_name) setattr(self, self.clip, clip_model(device=device, dtype=dtype, **kwargs)) - def clip_layer(self, layer_idx): - getattr(self, self.clip).clip_layer(layer_idx) + def set_clip_options(self, options): + getattr(self, self.clip).set_clip_options(options) - def reset_clip_layer(self): - getattr(self, self.clip).reset_clip_layer() + def reset_clip_options(self): + getattr(self, self.clip).reset_clip_options() def encode_token_weights(self, token_weight_pairs): token_weight_pairs = token_weight_pairs[self.clip_name] diff --git a/comfy/sd2_clip.py b/comfy/sd2_clip.py index 2ee0ca05..9c878d54 100644 --- a/comfy/sd2_clip.py +++ b/comfy/sd2_clip.py @@ -3,13 +3,13 @@ import torch import os class SD2ClipHModel(sd1_clip.SDClipModel): - def __init__(self, arch="ViT-H-14", device="cpu", max_length=77, freeze=True, layer="penultimate", layer_idx=None, textmodel_path=None, dtype=None): + def __init__(self, arch="ViT-H-14", device="cpu", max_length=77, freeze=True, layer="penultimate", layer_idx=None, dtype=None): if layer == "penultimate": layer="hidden" - layer_idx=23 + layer_idx=-2 textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "sd2_clip_config.json") - super().__init__(device=device, freeze=freeze, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, textmodel_path=textmodel_path, dtype=dtype, special_tokens={"start": 49406, "end": 49407, "pad": 0}) + super().__init__(device=device, freeze=freeze, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"start": 49406, "end": 49407, "pad": 0}) class SD2ClipHTokenizer(sd1_clip.SDTokenizer): def __init__(self, tokenizer_path=None, embedding_directory=None): diff --git a/comfy/sdxl_clip.py b/comfy/sdxl_clip.py index 673399e2..e62d1ed8 100644 --- a/comfy/sdxl_clip.py +++ b/comfy/sdxl_clip.py @@ -3,13 +3,13 @@ import torch import os class SDXLClipG(sd1_clip.SDClipModel): - def __init__(self, device="cpu", max_length=77, freeze=True, layer="penultimate", layer_idx=None, textmodel_path=None, dtype=None): + def __init__(self, device="cpu", max_length=77, freeze=True, layer="penultimate", layer_idx=None, dtype=None): if layer == "penultimate": layer="hidden" layer_idx=-2 textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_config_bigg.json") - super().__init__(device=device, freeze=freeze, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, textmodel_path=textmodel_path, dtype=dtype, + super().__init__(device=device, freeze=freeze, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"start": 49406, "end": 49407, "pad": 0}, layer_norm_hidden_state=False) def load_sd(self, sd): @@ -37,16 +37,16 @@ class SDXLTokenizer: class SDXLClipModel(torch.nn.Module): def __init__(self, device="cpu", dtype=None): super().__init__() - self.clip_l = sd1_clip.SDClipModel(layer="hidden", layer_idx=11, device=device, dtype=dtype, layer_norm_hidden_state=False) + self.clip_l = sd1_clip.SDClipModel(layer="hidden", layer_idx=-2, device=device, dtype=dtype, layer_norm_hidden_state=False) self.clip_g = SDXLClipG(device=device, dtype=dtype) - def clip_layer(self, layer_idx): - self.clip_l.clip_layer(layer_idx) - self.clip_g.clip_layer(layer_idx) + def set_clip_options(self, options): + self.clip_l.set_clip_options(options) + self.clip_g.set_clip_options(options) - def reset_clip_layer(self): - self.clip_g.reset_clip_layer() - self.clip_l.reset_clip_layer() + def reset_clip_options(self): + self.clip_g.reset_clip_options() + self.clip_l.reset_clip_options() def encode_token_weights(self, token_weight_pairs): token_weight_pairs_g = token_weight_pairs["g"] @@ -64,3 +64,25 @@ class SDXLClipModel(torch.nn.Module): class SDXLRefinerClipModel(sd1_clip.SD1ClipModel): def __init__(self, device="cpu", dtype=None): super().__init__(device=device, dtype=dtype, clip_name="g", clip_model=SDXLClipG) + + +class StableCascadeClipGTokenizer(sd1_clip.SDTokenizer): + def __init__(self, tokenizer_path=None, embedding_directory=None): + super().__init__(tokenizer_path, pad_with_end=True, embedding_directory=embedding_directory, embedding_size=1280, embedding_key='clip_g') + +class StableCascadeTokenizer(sd1_clip.SD1Tokenizer): + def __init__(self, embedding_directory=None): + super().__init__(embedding_directory=embedding_directory, clip_name="g", tokenizer=StableCascadeClipGTokenizer) + +class StableCascadeClipG(sd1_clip.SDClipModel): + def __init__(self, device="cpu", max_length=77, freeze=True, layer="hidden", layer_idx=-1, dtype=None): + textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_config_bigg.json") + super().__init__(device=device, freeze=freeze, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, + special_tokens={"start": 49406, "end": 49407, "pad": 49407}, layer_norm_hidden_state=False, enable_attention_masks=True) + + def load_sd(self, sd): + return super().load_sd(sd) + +class StableCascadeClipModel(sd1_clip.SD1ClipModel): + def __init__(self, device="cpu", dtype=None): + super().__init__(device=device, dtype=dtype, clip_name="g", clip_model=StableCascadeClipG) diff --git a/comfy/supported_models.py b/comfy/supported_models.py index 455323b9..2ce9736b 100644 --- a/comfy/supported_models.py +++ b/comfy/supported_models.py @@ -40,11 +40,16 @@ class SD15(supported_models_base.BASE): state_dict['cond_stage_model.transformer.text_model.embeddings.position_ids'] = ids.round() replace_prefix = {} - replace_prefix["cond_stage_model."] = "cond_stage_model.clip_l." - state_dict = utils.state_dict_prefix_replace(state_dict, replace_prefix) + replace_prefix["cond_stage_model."] = "clip_l." + state_dict = utils.state_dict_prefix_replace(state_dict, replace_prefix, filter_keys=True) return state_dict def process_clip_state_dict_for_saving(self, state_dict): + pop_keys = ["clip_l.transformer.text_projection.weight", "clip_l.logit_scale"] + for p in pop_keys: + if p in state_dict: + state_dict.pop(p) + replace_prefix = {"clip_l.": "cond_stage_model."} return utils.state_dict_prefix_replace(state_dict, replace_prefix) @@ -72,10 +77,10 @@ class SD20(supported_models_base.BASE): def process_clip_state_dict(self, state_dict): replace_prefix = {} - replace_prefix["conditioner.embedders.0.model."] = "cond_stage_model.model." #SD2 in sgm format - state_dict = utils.state_dict_prefix_replace(state_dict, replace_prefix) - - state_dict = utils.transformers_convert(state_dict, "cond_stage_model.model.", "cond_stage_model.clip_h.transformer.text_model.", 24) + replace_prefix["conditioner.embedders.0.model."] = "clip_h." #SD2 in sgm format + replace_prefix["cond_stage_model.model."] = "clip_h." + state_dict = utils.state_dict_prefix_replace(state_dict, replace_prefix, filter_keys=True) + state_dict = utils.clip_text_transformers_convert(state_dict, "clip_h.", "clip_h.transformer.") return state_dict def process_clip_state_dict_for_saving(self, state_dict): @@ -131,11 +136,10 @@ class SDXLRefiner(supported_models_base.BASE): def process_clip_state_dict(self, state_dict): keys_to_replace = {} replace_prefix = {} + replace_prefix["conditioner.embedders.0.model."] = "clip_g." + state_dict = utils.state_dict_prefix_replace(state_dict, replace_prefix, filter_keys=True) - state_dict = utils.transformers_convert(state_dict, "conditioner.embedders.0.model.", "cond_stage_model.clip_g.transformer.text_model.", 32) - keys_to_replace["conditioner.embedders.0.model.text_projection"] = "cond_stage_model.clip_g.text_projection" - keys_to_replace["conditioner.embedders.0.model.logit_scale"] = "cond_stage_model.clip_g.logit_scale" - + state_dict = utils.clip_text_transformers_convert(state_dict, "clip_g.", "clip_g.transformer.") state_dict = utils.state_dict_key_replace(state_dict, keys_to_replace) return state_dict @@ -164,7 +168,13 @@ class SDXL(supported_models_base.BASE): latent_format = latent_formats.SDXL def model_type(self, state_dict, prefix=""): - if "v_pred" in state_dict: + if 'edm_mean' in state_dict and 'edm_std' in state_dict: #Playground V2.5 + self.latent_format = latent_formats.SDXL_Playground_2_5() + self.sampling_settings["sigma_data"] = 0.5 + self.sampling_settings["sigma_max"] = 80.0 + self.sampling_settings["sigma_min"] = 0.002 + return model_base.ModelType.EDM + elif "v_pred" in state_dict: return model_base.ModelType.V_PREDICTION else: return model_base.ModelType.EPS @@ -179,26 +189,28 @@ class SDXL(supported_models_base.BASE): keys_to_replace = {} replace_prefix = {} - replace_prefix["conditioner.embedders.0.transformer.text_model"] = "cond_stage_model.clip_l.transformer.text_model" - state_dict = utils.transformers_convert(state_dict, "conditioner.embedders.1.model.", "cond_stage_model.clip_g.transformer.text_model.", 32) - keys_to_replace["conditioner.embedders.1.model.text_projection"] = "cond_stage_model.clip_g.text_projection" - keys_to_replace["conditioner.embedders.1.model.text_projection.weight"] = "cond_stage_model.clip_g.text_projection" - keys_to_replace["conditioner.embedders.1.model.logit_scale"] = "cond_stage_model.clip_g.logit_scale" + replace_prefix["conditioner.embedders.0.transformer.text_model"] = "clip_l.transformer.text_model" + replace_prefix["conditioner.embedders.1.model."] = "clip_g." + state_dict = utils.state_dict_prefix_replace(state_dict, replace_prefix, filter_keys=True) - state_dict = utils.state_dict_prefix_replace(state_dict, replace_prefix) state_dict = utils.state_dict_key_replace(state_dict, keys_to_replace) + state_dict = utils.clip_text_transformers_convert(state_dict, "clip_g.", "clip_g.transformer.") return state_dict def process_clip_state_dict_for_saving(self, state_dict): replace_prefix = {} keys_to_replace = {} state_dict_g = diffusers_convert.convert_text_enc_state_dict_v20(state_dict, "clip_g") - if "clip_g.transformer.text_model.embeddings.position_ids" in state_dict_g: - state_dict_g.pop("clip_g.transformer.text_model.embeddings.position_ids") for k in state_dict: if k.startswith("clip_l"): state_dict_g[k] = state_dict[k] + state_dict_g["clip_l.transformer.text_model.embeddings.position_ids"] = torch.arange(77).expand((1, -1)) + pop_keys = ["clip_l.transformer.text_projection.weight", "clip_l.logit_scale"] + for p in pop_keys: + if p in state_dict_g: + state_dict_g.pop(p) + replace_prefix["clip_g"] = "conditioner.embedders.1.model" replace_prefix["clip_l"] = "conditioner.embedders.0" state_dict_g = utils.state_dict_prefix_replace(state_dict_g, replace_prefix) @@ -217,6 +229,36 @@ class SSD1B(SDXL): "use_temporal_attention": False, } +class Segmind_Vega(SDXL): + unet_config = { + "model_channels": 320, + "use_linear_in_transformer": True, + "transformer_depth": [0, 0, 1, 1, 2, 2], + "context_dim": 2048, + "adm_in_channels": 2816, + "use_temporal_attention": False, + } + +class KOALA_700M(SDXL): + unet_config = { + "model_channels": 320, + "use_linear_in_transformer": True, + "transformer_depth": [0, 2, 5], + "context_dim": 2048, + "adm_in_channels": 2816, + "use_temporal_attention": False, + } + +class KOALA_1B(SDXL): + unet_config = { + "model_channels": 320, + "use_linear_in_transformer": True, + "transformer_depth": [0, 2, 6], + "context_dim": 2048, + "adm_in_channels": 2816, + "use_temporal_attention": False, + } + class SVD_img2vid(supported_models_base.BASE): unet_config = { "model_channels": 320, @@ -242,5 +284,161 @@ class SVD_img2vid(supported_models_base.BASE): def clip_target(self): return None -models = [SD15, SD20, SD21UnclipL, SD21UnclipH, SDXLRefiner, SDXL, SSD1B] +class SV3D_u(SVD_img2vid): + unet_config = { + "model_channels": 320, + "in_channels": 8, + "use_linear_in_transformer": True, + "transformer_depth": [1, 1, 1, 1, 1, 1, 0, 0], + "context_dim": 1024, + "adm_in_channels": 256, + "use_temporal_attention": True, + "use_temporal_resblock": True + } + + vae_key_prefix = ["conditioner.embedders.1.encoder."] + + def get_model(self, state_dict, prefix="", device=None): + out = model_base.SV3D_u(self, device=device) + return out + +class SV3D_p(SV3D_u): + unet_config = { + "model_channels": 320, + "in_channels": 8, + "use_linear_in_transformer": True, + "transformer_depth": [1, 1, 1, 1, 1, 1, 0, 0], + "context_dim": 1024, + "adm_in_channels": 1280, + "use_temporal_attention": True, + "use_temporal_resblock": True + } + + + def get_model(self, state_dict, prefix="", device=None): + out = model_base.SV3D_p(self, device=device) + return out + +class Stable_Zero123(supported_models_base.BASE): + unet_config = { + "context_dim": 768, + "model_channels": 320, + "use_linear_in_transformer": False, + "adm_in_channels": None, + "use_temporal_attention": False, + "in_channels": 8, + } + + unet_extra_config = { + "num_heads": 8, + "num_head_channels": -1, + } + + clip_vision_prefix = "cond_stage_model.model.visual." + + latent_format = latent_formats.SD15 + + def get_model(self, state_dict, prefix="", device=None): + out = model_base.Stable_Zero123(self, device=device, cc_projection_weight=state_dict["cc_projection.weight"], cc_projection_bias=state_dict["cc_projection.bias"]) + return out + + def clip_target(self): + return None + +class SD_X4Upscaler(SD20): + unet_config = { + "context_dim": 1024, + "model_channels": 256, + 'in_channels': 7, + "use_linear_in_transformer": True, + "adm_in_channels": None, + "use_temporal_attention": False, + } + + unet_extra_config = { + "disable_self_attentions": [True, True, True, False], + "num_classes": 1000, + "num_heads": 8, + "num_head_channels": -1, + } + + latent_format = latent_formats.SD_X4 + + sampling_settings = { + "linear_start": 0.0001, + "linear_end": 0.02, + } + + def get_model(self, state_dict, prefix="", device=None): + out = model_base.SD_X4Upscaler(self, device=device) + return out + +class Stable_Cascade_C(supported_models_base.BASE): + unet_config = { + "stable_cascade_stage": 'c', + } + + unet_extra_config = {} + + latent_format = latent_formats.SC_Prior + supported_inference_dtypes = [torch.bfloat16, torch.float32] + + sampling_settings = { + "shift": 2.0, + } + + vae_key_prefix = ["vae."] + text_encoder_key_prefix = ["text_encoder."] + clip_vision_prefix = "clip_l_vision." + + def process_unet_state_dict(self, state_dict): + key_list = list(state_dict.keys()) + for y in ["weight", "bias"]: + suffix = "in_proj_{}".format(y) + keys = filter(lambda a: a.endswith(suffix), key_list) + for k_from in keys: + weights = state_dict.pop(k_from) + prefix = k_from[:-(len(suffix) + 1)] + shape_from = weights.shape[0] // 3 + for x in range(3): + p = ["to_q", "to_k", "to_v"] + k_to = "{}.{}.{}".format(prefix, p[x], y) + state_dict[k_to] = weights[shape_from*x:shape_from*(x + 1)] + return state_dict + + def process_clip_state_dict(self, state_dict): + state_dict = utils.state_dict_prefix_replace(state_dict, {k: "" for k in self.text_encoder_key_prefix}, filter_keys=True) + if "clip_g.text_projection" in state_dict: + state_dict["clip_g.transformer.text_projection.weight"] = state_dict.pop("clip_g.text_projection").transpose(0, 1) + return state_dict + + def get_model(self, state_dict, prefix="", device=None): + out = model_base.StableCascade_C(self, device=device) + return out + + def clip_target(self): + return supported_models_base.ClipTarget(sdxl_clip.StableCascadeTokenizer, sdxl_clip.StableCascadeClipModel) + +class Stable_Cascade_B(Stable_Cascade_C): + unet_config = { + "stable_cascade_stage": 'b', + } + + unet_extra_config = {} + + latent_format = latent_formats.SC_B + supported_inference_dtypes = [torch.float16, torch.bfloat16, torch.float32] + + sampling_settings = { + "shift": 1.0, + } + + clip_vision_prefix = None + + def get_model(self, state_dict, prefix="", device=None): + out = model_base.StableCascade_B(self, device=device) + return out + + +models = [Stable_Zero123, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p] models += [SVD_img2vid] diff --git a/comfy/supported_models_base.py b/comfy/supported_models_base.py index 3412cfea..4d7e2593 100644 --- a/comfy/supported_models_base.py +++ b/comfy/supported_models_base.py @@ -21,11 +21,16 @@ class BASE: noise_aug_config = None sampling_settings = {} latent_format = latent_formats.LatentFormat + vae_key_prefix = ["first_stage_model."] + text_encoder_key_prefix = ["cond_stage_model."] + supported_inference_dtypes = [torch.float16, torch.bfloat16, torch.float32] + + manual_cast_dtype = None @classmethod def matches(s, unet_config): for k in s.unet_config: - if s.unet_config[k] != unet_config[k]: + if k not in unet_config or s.unet_config[k] != unet_config[k]: return False return True @@ -51,6 +56,7 @@ class BASE: return out def process_clip_state_dict(self, state_dict): + state_dict = utils.state_dict_prefix_replace(state_dict, {k: "" for k in self.text_encoder_key_prefix}, filter_keys=True) return state_dict def process_unet_state_dict(self, state_dict): @@ -60,7 +66,13 @@ class BASE: return state_dict def process_clip_state_dict_for_saving(self, state_dict): - replace_prefix = {"": "cond_stage_model."} + replace_prefix = {"": self.text_encoder_key_prefix[0]} + return utils.state_dict_prefix_replace(state_dict, replace_prefix) + + def process_clip_vision_state_dict_for_saving(self, state_dict): + replace_prefix = {} + if self.clip_vision_prefix is not None: + replace_prefix[""] = self.clip_vision_prefix return utils.state_dict_prefix_replace(state_dict, replace_prefix) def process_unet_state_dict_for_saving(self, state_dict): @@ -68,6 +80,9 @@ class BASE: return utils.state_dict_prefix_replace(state_dict, replace_prefix) def process_vae_state_dict_for_saving(self, state_dict): - replace_prefix = {"": "first_stage_model."} + replace_prefix = {"": self.vae_key_prefix[0]} return utils.state_dict_prefix_replace(state_dict, replace_prefix) + def set_inference_dtype(self, dtype, manual_cast_dtype): + self.unet_config['dtype'] = dtype + self.manual_cast_dtype = manual_cast_dtype diff --git a/comfy/taesd/taesd.py b/comfy/taesd/taesd.py index 46f3097a..8f96c54e 100644 --- a/comfy/taesd/taesd.py +++ b/comfy/taesd/taesd.py @@ -7,9 +7,10 @@ import torch import torch.nn as nn import comfy.utils +import comfy.ops def conv(n_in, n_out, **kwargs): - return nn.Conv2d(n_in, n_out, 3, padding=1, **kwargs) + return comfy.ops.disable_weight_init.Conv2d(n_in, n_out, 3, padding=1, **kwargs) class Clamp(nn.Module): def forward(self, x): @@ -19,7 +20,7 @@ class Block(nn.Module): def __init__(self, n_in, n_out): super().__init__() self.conv = nn.Sequential(conv(n_in, n_out), nn.ReLU(), conv(n_out, n_out), nn.ReLU(), conv(n_out, n_out)) - self.skip = nn.Conv2d(n_in, n_out, 1, bias=False) if n_in != n_out else nn.Identity() + self.skip = comfy.ops.disable_weight_init.Conv2d(n_in, n_out, 1, bias=False) if n_in != n_out else nn.Identity() self.fuse = nn.ReLU() def forward(self, x): return self.fuse(self.conv(x) + self.skip(x)) diff --git a/comfy/utils.py b/comfy/utils.py index 294bbb42..ab47b8f2 100644 --- a/comfy/utils.py +++ b/comfy/utils.py @@ -5,6 +5,7 @@ import comfy.checkpoint_pickle import safetensors.torch import numpy as np from PIL import Image +import logging def load_torch_file(ckpt, safe_load=False, device=None): if device is None: @@ -14,14 +15,14 @@ def load_torch_file(ckpt, safe_load=False, device=None): else: if safe_load: if not 'weights_only' in torch.load.__code__.co_varnames: - print("Warning torch.load doesn't support weights_only on this pytorch version, loading unsafely.") + logging.warning("Warning torch.load doesn't support weights_only on this pytorch version, loading unsafely.") safe_load = False if safe_load: pl_sd = torch.load(ckpt, map_location=device, weights_only=True) else: pl_sd = torch.load(ckpt, map_location=device, pickle_module=comfy.checkpoint_pickle) if "global_step" in pl_sd: - print(f"Global Step: {pl_sd['global_step']}") + logging.debug(f"Global Step: {pl_sd['global_step']}") if "state_dict" in pl_sd: sd = pl_sd["state_dict"] else: @@ -98,8 +99,22 @@ def transformers_convert(sd, prefix_from, prefix_to, number): p = ["self_attn.q_proj", "self_attn.k_proj", "self_attn.v_proj"] k_to = "{}encoder.layers.{}.{}.{}".format(prefix_to, resblock, p[x], y) sd[k_to] = weights[shape_from*x:shape_from*(x + 1)] + return sd +def clip_text_transformers_convert(sd, prefix_from, prefix_to): + sd = transformers_convert(sd, prefix_from, "{}text_model.".format(prefix_to), 32) + + tp = "{}text_projection.weight".format(prefix_from) + if tp in sd: + sd["{}text_projection.weight".format(prefix_to)] = sd.pop(tp) + + tp = "{}text_projection".format(prefix_from) + if tp in sd: + sd["{}text_projection.weight".format(prefix_to)] = sd.pop(tp).transpose(0, 1).contiguous() + return sd + + UNET_MAP_ATTENTIONS = { "proj_in.weight", "proj_in.bias", @@ -169,6 +184,8 @@ UNET_MAP_BASIC = { } def unet_to_diffusers(unet_config): + if "num_res_blocks" not in unet_config: + return {} num_res_blocks = unet_config["num_res_blocks"] channel_mult = unet_config["channel_mult"] transformer_depth = unet_config["transformer_depth"][:] @@ -239,6 +256,26 @@ def repeat_to_batch_size(tensor, batch_size): return tensor.repeat([math.ceil(batch_size / tensor.shape[0])] + [1] * (len(tensor.shape) - 1))[:batch_size] return tensor +def resize_to_batch_size(tensor, batch_size): + in_batch_size = tensor.shape[0] + if in_batch_size == batch_size: + return tensor + + if batch_size <= 1: + return tensor[:batch_size] + + output = torch.empty([batch_size] + list(tensor.shape)[1:], dtype=tensor.dtype, device=tensor.device) + if batch_size < in_batch_size: + scale = (in_batch_size - 1) / (batch_size - 1) + for i in range(batch_size): + output[i] = tensor[min(round(i * scale), in_batch_size - 1)] + else: + scale = in_batch_size / batch_size + for i in range(batch_size): + output[i] = tensor[min(math.floor((i + 0.5) * scale), in_batch_size - 1)] + + return output + def convert_sd_to(state_dict, dtype): keys = list(state_dict.keys()) for k in keys: @@ -258,8 +295,11 @@ def set_attr(obj, attr, value): for name in attrs[:-1]: obj = getattr(obj, name) prev = getattr(obj, attrs[-1]) - setattr(obj, attrs[-1], torch.nn.Parameter(value, requires_grad=False)) - del prev + setattr(obj, attrs[-1], value) + return prev + +def set_attr_param(obj, attr, value): + return set_attr(obj, attr, torch.nn.Parameter(value, requires_grad=False)) def copy_to_param(obj, attr, value): # inplace update tensor instead of replacing it @@ -356,7 +396,7 @@ def lanczos(samples, width, height): images = [image.resize((width, height), resample=Image.Resampling.LANCZOS) for image in images] images = [torch.from_numpy(np.array(image).astype(np.float32) / 255.0).movedim(-1, 0) for image in images] result = torch.stack(images) - return result + return result.to(samples.device, samples.dtype) def common_upscale(samples, width, height, upscale_method, crop): if crop == "center": @@ -385,17 +425,19 @@ def get_tiled_scale_steps(width, height, tile_x, tile_y, overlap): return math.ceil((height / (tile_y - overlap))) * math.ceil((width / (tile_x - overlap))) @torch.inference_mode() -def tiled_scale(samples, function, tile_x=64, tile_y=64, overlap = 8, upscale_amount = 4, out_channels = 3, pbar = None): - output = torch.empty((samples.shape[0], out_channels, round(samples.shape[2] * upscale_amount), round(samples.shape[3] * upscale_amount)), device="cpu") +def tiled_scale(samples, function, tile_x=64, tile_y=64, overlap = 8, upscale_amount = 4, out_channels = 3, output_device="cpu", pbar = None): + output = torch.empty((samples.shape[0], out_channels, round(samples.shape[2] * upscale_amount), round(samples.shape[3] * upscale_amount)), device=output_device) for b in range(samples.shape[0]): s = samples[b:b+1] - out = torch.zeros((s.shape[0], out_channels, round(s.shape[2] * upscale_amount), round(s.shape[3] * upscale_amount)), device="cpu") - out_div = torch.zeros((s.shape[0], out_channels, round(s.shape[2] * upscale_amount), round(s.shape[3] * upscale_amount)), device="cpu") + out = torch.zeros((s.shape[0], out_channels, round(s.shape[2] * upscale_amount), round(s.shape[3] * upscale_amount)), device=output_device) + out_div = torch.zeros((s.shape[0], out_channels, round(s.shape[2] * upscale_amount), round(s.shape[3] * upscale_amount)), device=output_device) for y in range(0, s.shape[2], tile_y - overlap): for x in range(0, s.shape[3], tile_x - overlap): + x = max(0, min(s.shape[-1] - overlap, x)) + y = max(0, min(s.shape[-2] - overlap, y)) s_in = s[:,:,y:y+tile_y,x:x+tile_x] - ps = function(s_in).cpu() + ps = function(s_in).to(output_device) mask = torch.ones_like(ps) feather = round(overlap * upscale_amount) for t in range(feather): diff --git a/comfy_extras/nodes_canny.py b/comfy_extras/nodes_canny.py index 94d453f2..8138b5f7 100644 --- a/comfy_extras/nodes_canny.py +++ b/comfy_extras/nodes_canny.py @@ -5,275 +5,7 @@ import torch import torch.nn.functional as F import comfy.model_management -def get_canny_nms_kernel(device=None, dtype=None): - """Utility function that returns 3x3 kernels for the Canny Non-maximal suppression.""" - return torch.tensor( - [ - [[[0.0, 0.0, 0.0], [0.0, 1.0, -1.0], [0.0, 0.0, 0.0]]], - [[[0.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, -1.0]]], - [[[0.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, -1.0, 0.0]]], - [[[0.0, 0.0, 0.0], [0.0, 1.0, 0.0], [-1.0, 0.0, 0.0]]], - [[[0.0, 0.0, 0.0], [-1.0, 1.0, 0.0], [0.0, 0.0, 0.0]]], - [[[-1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 0.0]]], - [[[0.0, -1.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 0.0]]], - [[[0.0, 0.0, -1.0], [0.0, 1.0, 0.0], [0.0, 0.0, 0.0]]], - ], - device=device, - dtype=dtype, - ) - - -def get_hysteresis_kernel(device=None, dtype=None): - """Utility function that returns the 3x3 kernels for the Canny hysteresis.""" - return torch.tensor( - [ - [[[0.0, 0.0, 0.0], [0.0, 0.0, 1.0], [0.0, 0.0, 0.0]]], - [[[0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 1.0]]], - [[[0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 1.0, 0.0]]], - [[[0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [1.0, 0.0, 0.0]]], - [[[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [0.0, 0.0, 0.0]]], - [[[1.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]], - [[[0.0, 1.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]], - [[[0.0, 0.0, 1.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]], - ], - device=device, - dtype=dtype, - ) - -def gaussian_blur_2d(img, kernel_size, sigma): - ksize_half = (kernel_size - 1) * 0.5 - - x = torch.linspace(-ksize_half, ksize_half, steps=kernel_size) - - pdf = torch.exp(-0.5 * (x / sigma).pow(2)) - - x_kernel = pdf / pdf.sum() - x_kernel = x_kernel.to(device=img.device, dtype=img.dtype) - - kernel2d = torch.mm(x_kernel[:, None], x_kernel[None, :]) - kernel2d = kernel2d.expand(img.shape[-3], 1, kernel2d.shape[0], kernel2d.shape[1]) - - padding = [kernel_size // 2, kernel_size // 2, kernel_size // 2, kernel_size // 2] - - img = torch.nn.functional.pad(img, padding, mode="reflect") - img = torch.nn.functional.conv2d(img, kernel2d, groups=img.shape[-3]) - - return img - -def get_sobel_kernel2d(device=None, dtype=None): - kernel_x = torch.tensor([[-1.0, 0.0, 1.0], [-2.0, 0.0, 2.0], [-1.0, 0.0, 1.0]], device=device, dtype=dtype) - kernel_y = kernel_x.transpose(0, 1) - return torch.stack([kernel_x, kernel_y]) - -def spatial_gradient(input, normalized: bool = True): - r"""Compute the first order image derivative in both x and y using a Sobel operator. - .. image:: _static/img/spatial_gradient.png - Args: - input: input image tensor with shape :math:`(B, C, H, W)`. - mode: derivatives modality, can be: `sobel` or `diff`. - order: the order of the derivatives. - normalized: whether the output is normalized. - Return: - the derivatives of the input feature map. with shape :math:`(B, C, 2, H, W)`. - .. note:: - See a working example `here `__. - Examples: - >>> input = torch.rand(1, 3, 4, 4) - >>> output = spatial_gradient(input) # 1x3x2x4x4 - >>> output.shape - torch.Size([1, 3, 2, 4, 4]) - """ - # KORNIA_CHECK_IS_TENSOR(input) - # KORNIA_CHECK_SHAPE(input, ['B', 'C', 'H', 'W']) - - # allocate kernel - kernel = get_sobel_kernel2d(device=input.device, dtype=input.dtype) - if normalized: - kernel = normalize_kernel2d(kernel) - - # prepare kernel - b, c, h, w = input.shape - tmp_kernel = kernel[:, None, ...] - - # Pad with "replicate for spatial dims, but with zeros for channel - spatial_pad = [kernel.size(1) // 2, kernel.size(1) // 2, kernel.size(2) // 2, kernel.size(2) // 2] - out_channels: int = 2 - padded_inp = torch.nn.functional.pad(input.reshape(b * c, 1, h, w), spatial_pad, 'replicate') - out = F.conv2d(padded_inp, tmp_kernel, groups=1, padding=0, stride=1) - return out.reshape(b, c, out_channels, h, w) - -def rgb_to_grayscale(image, rgb_weights = None): - r"""Convert a RGB image to grayscale version of image. - - .. image:: _static/img/rgb_to_grayscale.png - - The image data is assumed to be in the range of (0, 1). - - Args: - image: RGB image to be converted to grayscale with shape :math:`(*,3,H,W)`. - rgb_weights: Weights that will be applied on each channel (RGB). - The sum of the weights should add up to one. - Returns: - grayscale version of the image with shape :math:`(*,1,H,W)`. - - .. note:: - See a working example `here `__. - - Example: - >>> input = torch.rand(2, 3, 4, 5) - >>> gray = rgb_to_grayscale(input) # 2x1x4x5 - """ - - if len(image.shape) < 3 or image.shape[-3] != 3: - raise ValueError(f"Input size must have a shape of (*, 3, H, W). Got {image.shape}") - - if rgb_weights is None: - # 8 bit images - if image.dtype == torch.uint8: - rgb_weights = torch.tensor([76, 150, 29], device=image.device, dtype=torch.uint8) - # floating point images - elif image.dtype in (torch.float16, torch.float32, torch.float64): - rgb_weights = torch.tensor([0.299, 0.587, 0.114], device=image.device, dtype=image.dtype) - else: - raise TypeError(f"Unknown data type: {image.dtype}") - else: - # is tensor that we make sure is in the same device/dtype - rgb_weights = rgb_weights.to(image) - - # unpack the color image channels with RGB order - r: Tensor = image[..., 0:1, :, :] - g: Tensor = image[..., 1:2, :, :] - b: Tensor = image[..., 2:3, :, :] - - w_r, w_g, w_b = rgb_weights.unbind() - return w_r * r + w_g * g + w_b * b - -def canny( - input, - low_threshold = 0.1, - high_threshold = 0.2, - kernel_size = 5, - sigma = 1, - hysteresis = True, - eps = 1e-6, -): - r"""Find edges of the input image and filters them using the Canny algorithm. - .. image:: _static/img/canny.png - Args: - input: input image tensor with shape :math:`(B,C,H,W)`. - low_threshold: lower threshold for the hysteresis procedure. - high_threshold: upper threshold for the hysteresis procedure. - kernel_size: the size of the kernel for the gaussian blur. - sigma: the standard deviation of the kernel for the gaussian blur. - hysteresis: if True, applies the hysteresis edge tracking. - Otherwise, the edges are divided between weak (0.5) and strong (1) edges. - eps: regularization number to avoid NaN during backprop. - Returns: - - the canny edge magnitudes map, shape of :math:`(B,1,H,W)`. - - the canny edge detection filtered by thresholds and hysteresis, shape of :math:`(B,1,H,W)`. - .. note:: - See a working example `here `__. - Example: - >>> input = torch.rand(5, 3, 4, 4) - >>> magnitude, edges = canny(input) # 5x3x4x4 - >>> magnitude.shape - torch.Size([5, 1, 4, 4]) - >>> edges.shape - torch.Size([5, 1, 4, 4]) - """ - # KORNIA_CHECK_IS_TENSOR(input) - # KORNIA_CHECK_SHAPE(input, ['B', 'C', 'H', 'W']) - # KORNIA_CHECK( - # low_threshold <= high_threshold, - # "Invalid input thresholds. low_threshold should be smaller than the high_threshold. Got: " - # f"{low_threshold}>{high_threshold}", - # ) - # KORNIA_CHECK(0 < low_threshold < 1, f'Invalid low threshold. Should be in range (0, 1). Got: {low_threshold}') - # KORNIA_CHECK(0 < high_threshold < 1, f'Invalid high threshold. Should be in range (0, 1). Got: {high_threshold}') - - device = input.device - dtype = input.dtype - - # To Grayscale - if input.shape[1] == 3: - input = rgb_to_grayscale(input) - - # Gaussian filter - blurred: Tensor = gaussian_blur_2d(input, kernel_size, sigma) - - # Compute the gradients - gradients: Tensor = spatial_gradient(blurred, normalized=False) - - # Unpack the edges - gx: Tensor = gradients[:, :, 0] - gy: Tensor = gradients[:, :, 1] - - # Compute gradient magnitude and angle - magnitude: Tensor = torch.sqrt(gx * gx + gy * gy + eps) - angle: Tensor = torch.atan2(gy, gx) - - # Radians to Degrees - angle = 180.0 * angle / math.pi - - # Round angle to the nearest 45 degree - angle = torch.round(angle / 45) * 45 - - # Non-maximal suppression - nms_kernels: Tensor = get_canny_nms_kernel(device, dtype) - nms_magnitude: Tensor = F.conv2d(magnitude, nms_kernels, padding=nms_kernels.shape[-1] // 2) - - # Get the indices for both directions - positive_idx: Tensor = (angle / 45) % 8 - positive_idx = positive_idx.long() - - negative_idx: Tensor = ((angle / 45) + 4) % 8 - negative_idx = negative_idx.long() - - # Apply the non-maximum suppression to the different directions - channel_select_filtered_positive: Tensor = torch.gather(nms_magnitude, 1, positive_idx) - channel_select_filtered_negative: Tensor = torch.gather(nms_magnitude, 1, negative_idx) - - channel_select_filtered: Tensor = torch.stack( - [channel_select_filtered_positive, channel_select_filtered_negative], 1 - ) - - is_max: Tensor = channel_select_filtered.min(dim=1)[0] > 0.0 - - magnitude = magnitude * is_max - - # Threshold - edges: Tensor = F.threshold(magnitude, low_threshold, 0.0) - - low: Tensor = magnitude > low_threshold - high: Tensor = magnitude > high_threshold - - edges = low * 0.5 + high * 0.5 - edges = edges.to(dtype) - - # Hysteresis - if hysteresis: - edges_old: Tensor = -torch.ones(edges.shape, device=edges.device, dtype=dtype) - hysteresis_kernels: Tensor = get_hysteresis_kernel(device, dtype) - - while ((edges_old - edges).abs() != 0).any(): - weak: Tensor = (edges == 0.5).float() - strong: Tensor = (edges == 1).float() - - hysteresis_magnitude: Tensor = F.conv2d( - edges, hysteresis_kernels, padding=hysteresis_kernels.shape[-1] // 2 - ) - hysteresis_magnitude = (hysteresis_magnitude == 1).any(1, keepdim=True).to(dtype) - hysteresis_magnitude = hysteresis_magnitude * weak + strong - - edges_old = edges.clone() - edges = hysteresis_magnitude + (hysteresis_magnitude == 0) * weak * 0.5 - - edges = hysteresis_magnitude - - return magnitude, edges +from kornia.filters import canny class Canny: @@ -291,7 +23,7 @@ class Canny: def detect_edge(self, image, low_threshold, high_threshold): output = canny(image.to(comfy.model_management.get_torch_device()).movedim(-1, 1), low_threshold, high_threshold) - img_out = output[1].cpu().repeat(1, 3, 1, 1).movedim(1, -1) + img_out = output[1].to(comfy.model_management.intermediate_device()).repeat(1, 3, 1, 1).movedim(1, -1) return (img_out,) NODE_CLASS_MAPPINGS = { diff --git a/comfy_extras/nodes_cond.py b/comfy_extras/nodes_cond.py new file mode 100644 index 00000000..646fefa1 --- /dev/null +++ b/comfy_extras/nodes_cond.py @@ -0,0 +1,25 @@ + + +class CLIPTextEncodeControlnet: + @classmethod + def INPUT_TYPES(s): + return {"required": {"clip": ("CLIP", ), "conditioning": ("CONDITIONING", ), "text": ("STRING", {"multiline": True})}} + RETURN_TYPES = ("CONDITIONING",) + FUNCTION = "encode" + + CATEGORY = "_for_testing/conditioning" + + def encode(self, clip, conditioning, text): + tokens = clip.tokenize(text) + cond, pooled = clip.encode_from_tokens(tokens, return_pooled=True) + c = [] + for t in conditioning: + n = [t[0], t[1].copy()] + n[1]['cross_attn_controlnet'] = cond + n[1]['pooled_output_controlnet'] = pooled + c.append(n) + return (c, ) + +NODE_CLASS_MAPPINGS = { + "CLIPTextEncodeControlnet": CLIPTextEncodeControlnet +} diff --git a/comfy_extras/nodes_custom_sampler.py b/comfy_extras/nodes_custom_sampler.py index 008d0b8d..72ff7957 100644 --- a/comfy_extras/nodes_custom_sampler.py +++ b/comfy_extras/nodes_custom_sampler.py @@ -13,6 +13,7 @@ class BasicScheduler: {"model": ("MODEL",), "scheduler": (comfy.samplers.SCHEDULER_NAMES, ), "steps": ("INT", {"default": 20, "min": 1, "max": 10000}), + "denoise": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}), } } RETURN_TYPES = ("SIGMAS",) @@ -20,8 +21,14 @@ class BasicScheduler: FUNCTION = "get_sigmas" - def get_sigmas(self, model, scheduler, steps): - sigmas = comfy.samplers.calculate_sigmas_scheduler(model.model, scheduler, steps).cpu() + def get_sigmas(self, model, scheduler, steps, denoise): + total_steps = steps + if denoise < 1.0: + total_steps = int(steps/denoise) + + comfy.model_management.load_models_gpu([model]) + sigmas = comfy.samplers.calculate_sigmas_scheduler(model.model, scheduler, total_steps).cpu() + sigmas = sigmas[-(steps + 1):] return (sigmas, ) @@ -87,6 +94,7 @@ class SDTurboScheduler: return {"required": {"model": ("MODEL",), "steps": ("INT", {"default": 1, "min": 1, "max": 10}), + "denoise": ("FLOAT", {"default": 1.0, "min": 0, "max": 1.0, "step": 0.01}), } } RETURN_TYPES = ("SIGMAS",) @@ -94,8 +102,10 @@ class SDTurboScheduler: FUNCTION = "get_sigmas" - def get_sigmas(self, model, steps): - timesteps = torch.flip(torch.arange(1, 11) * 100 - 1, (0,))[:steps] + def get_sigmas(self, model, steps, denoise): + start_step = 10 - int(10 * denoise) + timesteps = torch.flip(torch.arange(1, 11) * 100 - 1, (0,))[start_step:start_step + steps] + comfy.model_management.load_models_gpu([model]) sigmas = model.model.model_sampling.sigma(timesteps) sigmas = torch.cat([sigmas, sigmas.new_zeros([1])]) return (sigmas, ) @@ -171,6 +181,28 @@ class KSamplerSelect: sampler = comfy.samplers.sampler_object(sampler_name) return (sampler, ) +class SamplerDPMPP_3M_SDE: + @classmethod + def INPUT_TYPES(s): + return {"required": + {"eta": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0, "step":0.01, "round": False}), + "s_noise": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0, "step":0.01, "round": False}), + "noise_device": (['gpu', 'cpu'], ), + } + } + RETURN_TYPES = ("SAMPLER",) + CATEGORY = "sampling/custom_sampling/samplers" + + FUNCTION = "get_sampler" + + def get_sampler(self, eta, s_noise, noise_device): + if noise_device == 'cpu': + sampler_name = "dpmpp_3m_sde" + else: + sampler_name = "dpmpp_3m_sde_gpu" + sampler = comfy.samplers.ksampler(sampler_name, {"eta": eta, "s_noise": s_noise}) + return (sampler, ) + class SamplerDPMPP_2M_SDE: @classmethod def INPUT_TYPES(s): @@ -218,6 +250,66 @@ class SamplerDPMPP_SDE: sampler = comfy.samplers.ksampler(sampler_name, {"eta": eta, "s_noise": s_noise, "r": r}) return (sampler, ) +class SamplerEulerAncestral: + @classmethod + def INPUT_TYPES(s): + return {"required": + {"eta": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0, "step":0.01, "round": False}), + "s_noise": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0, "step":0.01, "round": False}), + } + } + RETURN_TYPES = ("SAMPLER",) + CATEGORY = "sampling/custom_sampling/samplers" + + FUNCTION = "get_sampler" + + def get_sampler(self, eta, s_noise): + sampler = comfy.samplers.ksampler("euler_ancestral", {"eta": eta, "s_noise": s_noise}) + return (sampler, ) + +class SamplerLMS: + @classmethod + def INPUT_TYPES(s): + return {"required": + {"order": ("INT", {"default": 4, "min": 1, "max": 100}), + } + } + RETURN_TYPES = ("SAMPLER",) + CATEGORY = "sampling/custom_sampling/samplers" + + FUNCTION = "get_sampler" + + def get_sampler(self, order): + sampler = comfy.samplers.ksampler("lms", {"order": order}) + return (sampler, ) + +class SamplerDPMAdaptative: + @classmethod + def INPUT_TYPES(s): + return {"required": + {"order": ("INT", {"default": 3, "min": 2, "max": 3}), + "rtol": ("FLOAT", {"default": 0.05, "min": 0.0, "max": 100.0, "step":0.01, "round": False}), + "atol": ("FLOAT", {"default": 0.0078, "min": 0.0, "max": 100.0, "step":0.01, "round": False}), + "h_init": ("FLOAT", {"default": 0.05, "min": 0.0, "max": 100.0, "step":0.01, "round": False}), + "pcoeff": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 100.0, "step":0.01, "round": False}), + "icoeff": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0, "step":0.01, "round": False}), + "dcoeff": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 100.0, "step":0.01, "round": False}), + "accept_safety": ("FLOAT", {"default": 0.81, "min": 0.0, "max": 100.0, "step":0.01, "round": False}), + "eta": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 100.0, "step":0.01, "round": False}), + "s_noise": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0, "step":0.01, "round": False}), + } + } + RETURN_TYPES = ("SAMPLER",) + CATEGORY = "sampling/custom_sampling/samplers" + + FUNCTION = "get_sampler" + + def get_sampler(self, order, rtol, atol, h_init, pcoeff, icoeff, dcoeff, accept_safety, eta, s_noise): + sampler = comfy.samplers.ksampler("dpm_adaptive", {"order": order, "rtol": rtol, "atol": atol, "h_init": h_init, "pcoeff": pcoeff, + "icoeff": icoeff, "dcoeff": dcoeff, "accept_safety": accept_safety, "eta": eta, + "s_noise":s_noise }) + return (sampler, ) + class SamplerCustom: @classmethod def INPUT_TYPES(s): @@ -278,8 +370,12 @@ NODE_CLASS_MAPPINGS = { "VPScheduler": VPScheduler, "SDTurboScheduler": SDTurboScheduler, "KSamplerSelect": KSamplerSelect, + "SamplerEulerAncestral": SamplerEulerAncestral, + "SamplerLMS": SamplerLMS, + "SamplerDPMPP_3M_SDE": SamplerDPMPP_3M_SDE, "SamplerDPMPP_2M_SDE": SamplerDPMPP_2M_SDE, "SamplerDPMPP_SDE": SamplerDPMPP_SDE, + "SamplerDPMAdaptative": SamplerDPMAdaptative, "SplitSigmas": SplitSigmas, "FlipSigmas": FlipSigmas, } diff --git a/comfy_extras/nodes_differential_diffusion.py b/comfy_extras/nodes_differential_diffusion.py new file mode 100644 index 00000000..98dbbf10 --- /dev/null +++ b/comfy_extras/nodes_differential_diffusion.py @@ -0,0 +1,42 @@ +# code adapted from https://github.com/exx8/differential-diffusion + +import torch + +class DifferentialDiffusion(): + @classmethod + def INPUT_TYPES(s): + return {"required": {"model": ("MODEL", ), + }} + RETURN_TYPES = ("MODEL",) + FUNCTION = "apply" + CATEGORY = "_for_testing" + INIT = False + + def apply(self, model): + model = model.clone() + model.set_model_denoise_mask_function(self.forward) + return (model,) + + def forward(self, sigma: torch.Tensor, denoise_mask: torch.Tensor, extra_options: dict): + model = extra_options["model"] + step_sigmas = extra_options["sigmas"] + sigma_to = model.inner_model.model_sampling.sigma_min + if step_sigmas[-1] > sigma_to: + sigma_to = step_sigmas[-1] + sigma_from = step_sigmas[0] + + ts_from = model.inner_model.model_sampling.timestep(sigma_from) + ts_to = model.inner_model.model_sampling.timestep(sigma_to) + current_ts = model.inner_model.model_sampling.timestep(sigma[0]) + + threshold = (current_ts - ts_to) / (ts_from - ts_to) + + return (denoise_mask >= threshold).to(denoise_mask.dtype) + + +NODE_CLASS_MAPPINGS = { + "DifferentialDiffusion": DifferentialDiffusion, +} +NODE_DISPLAY_NAME_MAPPINGS = { + "DifferentialDiffusion": "Differential Diffusion", +} diff --git a/comfy_extras/nodes_freelunch.py b/comfy_extras/nodes_freelunch.py index 7512b841..6f1d87bf 100644 --- a/comfy_extras/nodes_freelunch.py +++ b/comfy_extras/nodes_freelunch.py @@ -1,7 +1,7 @@ #code originally taken from: https://github.com/ChenyangSi/FreeU (under MIT License) import torch - +import logging def Fourier_filter(x, threshold, scale): # FFT @@ -34,7 +34,7 @@ class FreeU: RETURN_TYPES = ("MODEL",) FUNCTION = "patch" - CATEGORY = "_for_testing" + CATEGORY = "model_patches" def patch(self, model, b1, b2, s1, s2): model_channels = model.model.model_config.unet_config["model_channels"] @@ -49,7 +49,7 @@ class FreeU: try: hsp = Fourier_filter(hsp, threshold=1, scale=scale[1]) except: - print("Device", hsp.device, "does not support the torch.fft functions used in the FreeU node, switching to CPU.") + logging.warning("Device {} does not support the torch.fft functions used in the FreeU node, switching to CPU.".format(hsp.device)) on_cpu_devices[hsp.device] = True hsp = Fourier_filter(hsp.cpu(), threshold=1, scale=scale[1]).to(hsp.device) else: @@ -73,7 +73,7 @@ class FreeU_V2: RETURN_TYPES = ("MODEL",) FUNCTION = "patch" - CATEGORY = "_for_testing" + CATEGORY = "model_patches" def patch(self, model, b1, b2, s1, s2): model_channels = model.model.model_config.unet_config["model_channels"] @@ -95,7 +95,7 @@ class FreeU_V2: try: hsp = Fourier_filter(hsp, threshold=1, scale=scale[1]) except: - print("Device", hsp.device, "does not support the torch.fft functions used in the FreeU node, switching to CPU.") + logging.warning("Device {} does not support the torch.fft functions used in the FreeU node, switching to CPU.".format(hsp.device)) on_cpu_devices[hsp.device] = True hsp = Fourier_filter(hsp.cpu(), threshold=1, scale=scale[1]).to(hsp.device) else: diff --git a/comfy_extras/nodes_hypernetwork.py b/comfy_extras/nodes_hypernetwork.py index f692945a..cafafa6a 100644 --- a/comfy_extras/nodes_hypernetwork.py +++ b/comfy_extras/nodes_hypernetwork.py @@ -1,6 +1,7 @@ import comfy.utils import folder_paths import torch +import logging def load_hypernetwork_patch(path, strength): sd = comfy.utils.load_torch_file(path, safe_load=True) @@ -23,7 +24,7 @@ def load_hypernetwork_patch(path, strength): } if activation_func not in valid_activation: - print("Unsupported Hypernetwork format, if you report it I might implement it.", path, " ", activation_func, is_layer_norm, use_dropout, activate_output, last_layer_dropout) + logging.error("Unsupported Hypernetwork format, if you report it I might implement it. {} {} {} {} {} {}".format(path, activation_func, is_layer_norm, use_dropout, activate_output, last_layer_dropout)) return None out = {} diff --git a/comfy_extras/nodes_hypertile.py b/comfy_extras/nodes_hypertile.py index 0d7d4c95..ae55d23d 100644 --- a/comfy_extras/nodes_hypertile.py +++ b/comfy_extras/nodes_hypertile.py @@ -2,9 +2,10 @@ import math from einops import rearrange -import random +# Use torch rng for consistency across generations +from torch import randint -def random_divisor(value: int, min_value: int, /, max_options: int = 1, counter = 0) -> int: +def random_divisor(value: int, min_value: int, /, max_options: int = 1) -> int: min_value = min(min_value, value) # All big divisors of value (inclusive) @@ -12,8 +13,10 @@ def random_divisor(value: int, min_value: int, /, max_options: int = 1, counter ns = [value // i for i in divisors[:max_options]] # has at least 1 element - random.seed(counter) - idx = random.randint(0, len(ns) - 1) + if len(ns) - 1 > 0: + idx = randint(low=0, high=len(ns) - 1, size=(1,)).item() + else: + idx = 0 return ns[idx] @@ -29,34 +32,31 @@ class HyperTile: RETURN_TYPES = ("MODEL",) FUNCTION = "patch" - CATEGORY = "_for_testing" + CATEGORY = "model_patches" def patch(self, model, tile_size, swap_size, max_depth, scale_depth): model_channels = model.model.model_config.unet_config["model_channels"] - apply_to = set() - temp = model_channels - for x in range(max_depth + 1): - apply_to.add(temp) - temp *= 2 - latent_tile_size = max(32, tile_size) // 8 self.temp = None - self.counter = 1 def hypertile_in(q, k, v, extra_options): - if q.shape[-1] in apply_to: + model_chans = q.shape[-2] + orig_shape = extra_options['original_shape'] + apply_to = [] + for i in range(max_depth + 1): + apply_to.append((orig_shape[-2] / (2 ** i)) * (orig_shape[-1] / (2 ** i))) + + if model_chans in apply_to: shape = extra_options["original_shape"] aspect_ratio = shape[-1] / shape[-2] hw = q.size(1) h, w = round(math.sqrt(hw * aspect_ratio)), round(math.sqrt(hw / aspect_ratio)) - factor = 2**((q.shape[-1] // model_channels) - 1) if scale_depth else 1 - nh = random_divisor(h, latent_tile_size * factor, swap_size, self.counter) - self.counter += 1 - nw = random_divisor(w, latent_tile_size * factor, swap_size, self.counter) - self.counter += 1 + factor = (2 ** apply_to.index(model_chans)) if scale_depth else 1 + nh = random_divisor(h, latent_tile_size * factor, swap_size) + nw = random_divisor(w, latent_tile_size * factor, swap_size) if nh * nw > 1: q = rearrange(q, "b (nh h nw w) c -> (b nh nw) (h w) c", h=h // nh, w=w // nw, nh=nh, nw=nw) diff --git a/comfy_extras/nodes_images.py b/comfy_extras/nodes_images.py index 5ad2235a..af37666b 100644 --- a/comfy_extras/nodes_images.py +++ b/comfy_extras/nodes_images.py @@ -37,7 +37,7 @@ class RepeatImageBatch: @classmethod def INPUT_TYPES(s): return {"required": { "image": ("IMAGE",), - "amount": ("INT", {"default": 1, "min": 1, "max": 64}), + "amount": ("INT", {"default": 1, "min": 1, "max": 4096}), }} RETURN_TYPES = ("IMAGE",) FUNCTION = "repeat" @@ -48,6 +48,25 @@ class RepeatImageBatch: s = image.repeat((amount, 1,1,1)) return (s,) +class ImageFromBatch: + @classmethod + def INPUT_TYPES(s): + return {"required": { "image": ("IMAGE",), + "batch_index": ("INT", {"default": 0, "min": 0, "max": 4095}), + "length": ("INT", {"default": 1, "min": 1, "max": 4096}), + }} + RETURN_TYPES = ("IMAGE",) + FUNCTION = "frombatch" + + CATEGORY = "image/batch" + + def frombatch(self, image, batch_index, length): + s_in = image + batch_index = min(s_in.shape[0] - 1, batch_index) + length = min(s_in.shape[0] - batch_index, length) + s = s_in[batch_index:batch_index + length].clone() + return (s,) + class SaveAnimatedWEBP: def __init__(self): self.output_dir = folder_paths.get_output_directory() @@ -74,7 +93,7 @@ class SaveAnimatedWEBP: OUTPUT_NODE = True - CATEGORY = "_for_testing" + CATEGORY = "image/animation" def save_images(self, images, fps, filename_prefix, lossless, quality, method, num_frames=0, prompt=None, extra_pnginfo=None): method = self.methods.get(method) @@ -136,7 +155,7 @@ class SaveAnimatedPNG: OUTPUT_NODE = True - CATEGORY = "_for_testing" + CATEGORY = "image/animation" def save_images(self, images, fps, compress_level, filename_prefix="ComfyUI", prompt=None, extra_pnginfo=None): filename_prefix += self.prefix_append @@ -170,6 +189,7 @@ class SaveAnimatedPNG: NODE_CLASS_MAPPINGS = { "ImageCrop": ImageCrop, "RepeatImageBatch": RepeatImageBatch, + "ImageFromBatch": ImageFromBatch, "SaveAnimatedWEBP": SaveAnimatedWEBP, "SaveAnimatedPNG": SaveAnimatedPNG, } diff --git a/comfy_extras/nodes_latent.py b/comfy_extras/nodes_latent.py index cedf39d6..eabae088 100644 --- a/comfy_extras/nodes_latent.py +++ b/comfy_extras/nodes_latent.py @@ -3,9 +3,7 @@ import torch def reshape_latent_to(target_shape, latent): if latent.shape[1:] != target_shape[1:]: - latent.movedim(1, -1) latent = comfy.utils.common_upscale(latent, target_shape[3], target_shape[2], "bilinear", "center") - latent.movedim(-1, 1) return comfy.utils.repeat_to_batch_size(latent, target_shape[0]) @@ -102,9 +100,56 @@ class LatentInterpolate: samples_out["samples"] = st * (m1 * ratio + m2 * (1.0 - ratio)) return (samples_out,) +class LatentBatch: + @classmethod + def INPUT_TYPES(s): + return {"required": { "samples1": ("LATENT",), "samples2": ("LATENT",)}} + + RETURN_TYPES = ("LATENT",) + FUNCTION = "batch" + + CATEGORY = "latent/batch" + + def batch(self, samples1, samples2): + samples_out = samples1.copy() + s1 = samples1["samples"] + s2 = samples2["samples"] + + if s1.shape[1:] != s2.shape[1:]: + s2 = comfy.utils.common_upscale(s2, s1.shape[3], s1.shape[2], "bilinear", "center") + s = torch.cat((s1, s2), dim=0) + samples_out["samples"] = s + samples_out["batch_index"] = samples1.get("batch_index", [x for x in range(0, s1.shape[0])]) + samples2.get("batch_index", [x for x in range(0, s2.shape[0])]) + return (samples_out,) + +class LatentBatchSeedBehavior: + @classmethod + def INPUT_TYPES(s): + return {"required": { "samples": ("LATENT",), + "seed_behavior": (["random", "fixed"],{"default": "fixed"}),}} + + RETURN_TYPES = ("LATENT",) + FUNCTION = "op" + + CATEGORY = "latent/advanced" + + def op(self, samples, seed_behavior): + samples_out = samples.copy() + latent = samples["samples"] + if seed_behavior == "random": + if 'batch_index' in samples_out: + samples_out.pop('batch_index') + elif seed_behavior == "fixed": + batch_number = samples_out.get("batch_index", [0])[0] + samples_out["batch_index"] = [batch_number] * latent.shape[0] + + return (samples_out,) + NODE_CLASS_MAPPINGS = { "LatentAdd": LatentAdd, "LatentSubtract": LatentSubtract, "LatentMultiply": LatentMultiply, "LatentInterpolate": LatentInterpolate, + "LatentBatch": LatentBatch, + "LatentBatchSeedBehavior": LatentBatchSeedBehavior, } diff --git a/comfy_extras/nodes_mask.py b/comfy_extras/nodes_mask.py index d8c65c2b..29589b4a 100644 --- a/comfy_extras/nodes_mask.py +++ b/comfy_extras/nodes_mask.py @@ -6,6 +6,7 @@ import comfy.utils from nodes import MAX_RESOLUTION def composite(destination, source, x, y, mask = None, multiplier = 8, resize_source = False): + source = source.to(destination.device) if resize_source: source = torch.nn.functional.interpolate(source, size=(destination.shape[2], destination.shape[3]), mode="bilinear") @@ -20,7 +21,7 @@ def composite(destination, source, x, y, mask = None, multiplier = 8, resize_sou if mask is None: mask = torch.ones_like(source) else: - mask = mask.clone() + mask = mask.to(destination.device, copy=True) mask = torch.nn.functional.interpolate(mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1])), size=(source.shape[2], source.shape[3]), mode="bilinear") mask = comfy.utils.repeat_to_batch_size(mask, source.shape[0]) @@ -340,6 +341,24 @@ class GrowMask: out.append(output) return (torch.stack(out, dim=0),) +class ThresholdMask: + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "mask": ("MASK",), + "value": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 1.0, "step": 0.01}), + } + } + + CATEGORY = "mask" + + RETURN_TYPES = ("MASK",) + FUNCTION = "image_to_mask" + + def image_to_mask(self, mask, value): + mask = (mask > value).float() + return (mask,) NODE_CLASS_MAPPINGS = { @@ -354,6 +373,7 @@ NODE_CLASS_MAPPINGS = { "MaskComposite": MaskComposite, "FeatherMask": FeatherMask, "GrowMask": GrowMask, + "ThresholdMask": ThresholdMask, } NODE_DISPLAY_NAME_MAPPINGS = { diff --git a/comfy_extras/nodes_model_advanced.py b/comfy_extras/nodes_model_advanced.py index efcdf193..21af4b73 100644 --- a/comfy_extras/nodes_model_advanced.py +++ b/comfy_extras/nodes_model_advanced.py @@ -1,6 +1,7 @@ import folder_paths import comfy.sd import comfy.model_sampling +import comfy.latent_formats import torch class LCM(comfy.model_sampling.EPS): @@ -17,41 +18,23 @@ class LCM(comfy.model_sampling.EPS): return c_out * x0 + c_skip * model_input -class ModelSamplingDiscreteDistilled(torch.nn.Module): - original_timesteps = 50 - - def __init__(self): - super().__init__() - self.sigma_data = 1.0 - timesteps = 1000 - beta_start = 0.00085 - beta_end = 0.012 +class X0(comfy.model_sampling.EPS): + def calculate_denoised(self, sigma, model_output, model_input): + return model_output - betas = torch.linspace(beta_start**0.5, beta_end**0.5, timesteps, dtype=torch.float32) ** 2 - alphas = 1.0 - betas - alphas_cumprod = torch.cumprod(alphas, dim=0) +class ModelSamplingDiscreteDistilled(comfy.model_sampling.ModelSamplingDiscrete): + original_timesteps = 50 - self.skip_steps = timesteps // self.original_timesteps + def __init__(self, model_config=None): + super().__init__(model_config) + self.skip_steps = self.num_timesteps // self.original_timesteps - alphas_cumprod_valid = torch.zeros((self.original_timesteps), dtype=torch.float32) + sigmas_valid = torch.zeros((self.original_timesteps), dtype=torch.float32) for x in range(self.original_timesteps): - alphas_cumprod_valid[self.original_timesteps - 1 - x] = alphas_cumprod[timesteps - 1 - x * self.skip_steps] - - sigmas = ((1 - alphas_cumprod_valid) / alphas_cumprod_valid) ** 0.5 - self.set_sigmas(sigmas) - - def set_sigmas(self, sigmas): - self.register_buffer('sigmas', sigmas) - self.register_buffer('log_sigmas', sigmas.log()) - - @property - def sigma_min(self): - return self.sigmas[0] + sigmas_valid[self.original_timesteps - 1 - x] = self.sigmas[self.num_timesteps - 1 - x * self.skip_steps] - @property - def sigma_max(self): - return self.sigmas[-1] + self.set_sigmas(sigmas_valid) def timestep(self, sigma): log_sigma = sigma.log() @@ -66,14 +49,6 @@ class ModelSamplingDiscreteDistilled(torch.nn.Module): log_sigma = (1 - w) * self.log_sigmas[low_idx] + w * self.log_sigmas[high_idx] return log_sigma.exp().to(timestep.device) - def percent_to_sigma(self, percent): - if percent <= 0.0: - return 999999999.9 - if percent >= 1.0: - return 0.0 - percent = 1.0 - percent - return self.sigma(torch.tensor(percent * 999.0)).item() - def rescale_zero_terminal_snr_sigmas(sigmas): alphas_cumprod = 1 / ((sigmas * sigmas) + 1) @@ -98,7 +73,7 @@ class ModelSamplingDiscrete: @classmethod def INPUT_TYPES(s): return {"required": { "model": ("MODEL",), - "sampling": (["eps", "v_prediction", "lcm"],), + "sampling": (["eps", "v_prediction", "lcm", "x0"],), "zsnr": ("BOOLEAN", {"default": False}), }} @@ -118,22 +93,50 @@ class ModelSamplingDiscrete: elif sampling == "lcm": sampling_type = LCM sampling_base = ModelSamplingDiscreteDistilled + elif sampling == "x0": + sampling_type = X0 class ModelSamplingAdvanced(sampling_base, sampling_type): pass - model_sampling = ModelSamplingAdvanced() + model_sampling = ModelSamplingAdvanced(model.model.model_config) if zsnr: model_sampling.set_sigmas(rescale_zero_terminal_snr_sigmas(model_sampling.sigmas)) m.add_object_patch("model_sampling", model_sampling) return (m, ) +class ModelSamplingStableCascade: + @classmethod + def INPUT_TYPES(s): + return {"required": { "model": ("MODEL",), + "shift": ("FLOAT", {"default": 2.0, "min": 0.0, "max": 100.0, "step":0.01}), + }} + + RETURN_TYPES = ("MODEL",) + FUNCTION = "patch" + + CATEGORY = "advanced/model" + + def patch(self, model, shift): + m = model.clone() + + sampling_base = comfy.model_sampling.StableCascadeSampling + sampling_type = comfy.model_sampling.EPS + + class ModelSamplingAdvanced(sampling_base, sampling_type): + pass + + model_sampling = ModelSamplingAdvanced(model.model.model_config) + model_sampling.set_parameters(shift) + m.add_object_patch("model_sampling", model_sampling) + return (m, ) + class ModelSamplingContinuousEDM: @classmethod def INPUT_TYPES(s): return {"required": { "model": ("MODEL",), - "sampling": (["v_prediction", "eps"],), + "sampling": (["v_prediction", "edm_playground_v2.5", "eps"],), "sigma_max": ("FLOAT", {"default": 120.0, "min": 0.0, "max": 1000.0, "step":0.001, "round": False}), "sigma_min": ("FLOAT", {"default": 0.002, "min": 0.0, "max": 1000.0, "step":0.001, "round": False}), }} @@ -146,17 +149,25 @@ class ModelSamplingContinuousEDM: def patch(self, model, sampling, sigma_max, sigma_min): m = model.clone() + latent_format = None + sigma_data = 1.0 if sampling == "eps": sampling_type = comfy.model_sampling.EPS elif sampling == "v_prediction": sampling_type = comfy.model_sampling.V_PREDICTION + elif sampling == "edm_playground_v2.5": + sampling_type = comfy.model_sampling.EDM + sigma_data = 0.5 + latent_format = comfy.latent_formats.SDXL_Playground_2_5() class ModelSamplingAdvanced(comfy.model_sampling.ModelSamplingContinuousEDM, sampling_type): pass - model_sampling = ModelSamplingAdvanced() - model_sampling.set_sigma_range(sigma_min, sigma_max) + model_sampling = ModelSamplingAdvanced(model.model.model_config) + model_sampling.set_parameters(sigma_min, sigma_max, sigma_data) m.add_object_patch("model_sampling", model_sampling) + if latent_format is not None: + m.add_object_patch("latent_format", latent_format) return (m, ) class RescaleCFG: @@ -201,5 +212,6 @@ class RescaleCFG: NODE_CLASS_MAPPINGS = { "ModelSamplingDiscrete": ModelSamplingDiscrete, "ModelSamplingContinuousEDM": ModelSamplingContinuousEDM, + "ModelSamplingStableCascade": ModelSamplingStableCascade, "RescaleCFG": RescaleCFG, } diff --git a/comfy_extras/nodes_model_merging.py b/comfy_extras/nodes_model_merging.py index dad1dd63..a25b73ca 100644 --- a/comfy_extras/nodes_model_merging.py +++ b/comfy_extras/nodes_model_merging.py @@ -87,6 +87,50 @@ class CLIPMergeSimple: m.add_patches({k: kp[k]}, 1.0 - ratio, ratio) return (m, ) + +class CLIPSubtract: + @classmethod + def INPUT_TYPES(s): + return {"required": { "clip1": ("CLIP",), + "clip2": ("CLIP",), + "multiplier": ("FLOAT", {"default": 1.0, "min": -10.0, "max": 10.0, "step": 0.01}), + }} + RETURN_TYPES = ("CLIP",) + FUNCTION = "merge" + + CATEGORY = "advanced/model_merging" + + def merge(self, clip1, clip2, multiplier): + m = clip1.clone() + kp = clip2.get_key_patches() + for k in kp: + if k.endswith(".position_ids") or k.endswith(".logit_scale"): + continue + m.add_patches({k: kp[k]}, - multiplier, multiplier) + return (m, ) + + +class CLIPAdd: + @classmethod + def INPUT_TYPES(s): + return {"required": { "clip1": ("CLIP",), + "clip2": ("CLIP",), + }} + RETURN_TYPES = ("CLIP",) + FUNCTION = "merge" + + CATEGORY = "advanced/model_merging" + + def merge(self, clip1, clip2): + m = clip1.clone() + kp = clip2.get_key_patches() + for k in kp: + if k.endswith(".position_ids") or k.endswith(".logit_scale"): + continue + m.add_patches({k: kp[k]}, 1.0, 1.0) + return (m, ) + + class ModelMergeBlocks: @classmethod def INPUT_TYPES(s): @@ -119,6 +163,48 @@ class ModelMergeBlocks: m.add_patches({k: kp[k]}, 1.0 - ratio, ratio) return (m, ) +def save_checkpoint(model, clip=None, vae=None, clip_vision=None, filename_prefix=None, output_dir=None, prompt=None, extra_pnginfo=None): + full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, output_dir) + prompt_info = "" + if prompt is not None: + prompt_info = json.dumps(prompt) + + metadata = {} + + enable_modelspec = True + if isinstance(model.model, comfy.model_base.SDXL): + metadata["modelspec.architecture"] = "stable-diffusion-xl-v1-base" + elif isinstance(model.model, comfy.model_base.SDXLRefiner): + metadata["modelspec.architecture"] = "stable-diffusion-xl-v1-refiner" + else: + enable_modelspec = False + + if enable_modelspec: + metadata["modelspec.sai_model_spec"] = "1.0.0" + metadata["modelspec.implementation"] = "sgm" + metadata["modelspec.title"] = "{} {}".format(filename, counter) + + #TODO: + # "stable-diffusion-v1", "stable-diffusion-v1-inpainting", "stable-diffusion-v2-512", + # "stable-diffusion-v2-768-v", "stable-diffusion-v2-unclip-l", "stable-diffusion-v2-unclip-h", + # "v2-inpainting" + + if model.model.model_type == comfy.model_base.ModelType.EPS: + metadata["modelspec.predict_key"] = "epsilon" + elif model.model.model_type == comfy.model_base.ModelType.V_PREDICTION: + metadata["modelspec.predict_key"] = "v" + + if not args.disable_metadata: + metadata["prompt"] = prompt_info + if extra_pnginfo is not None: + for x in extra_pnginfo: + metadata[x] = json.dumps(extra_pnginfo[x]) + + output_checkpoint = f"{filename}_{counter:05}_.safetensors" + output_checkpoint = os.path.join(full_output_folder, output_checkpoint) + + comfy.sd.save_checkpoint(output_checkpoint, model, clip, vae, clip_vision, metadata=metadata) + class CheckpointSave: def __init__(self): self.output_dir = folder_paths.get_output_directory() @@ -137,46 +223,7 @@ class CheckpointSave: CATEGORY = "advanced/model_merging" def save(self, model, clip, vae, filename_prefix, prompt=None, extra_pnginfo=None): - full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, self.output_dir) - prompt_info = "" - if prompt is not None: - prompt_info = json.dumps(prompt) - - metadata = {} - - enable_modelspec = True - if isinstance(model.model, comfy.model_base.SDXL): - metadata["modelspec.architecture"] = "stable-diffusion-xl-v1-base" - elif isinstance(model.model, comfy.model_base.SDXLRefiner): - metadata["modelspec.architecture"] = "stable-diffusion-xl-v1-refiner" - else: - enable_modelspec = False - - if enable_modelspec: - metadata["modelspec.sai_model_spec"] = "1.0.0" - metadata["modelspec.implementation"] = "sgm" - metadata["modelspec.title"] = "{} {}".format(filename, counter) - - #TODO: - # "stable-diffusion-v1", "stable-diffusion-v1-inpainting", "stable-diffusion-v2-512", - # "stable-diffusion-v2-768-v", "stable-diffusion-v2-unclip-l", "stable-diffusion-v2-unclip-h", - # "v2-inpainting" - - if model.model.model_type == comfy.model_base.ModelType.EPS: - metadata["modelspec.predict_key"] = "epsilon" - elif model.model.model_type == comfy.model_base.ModelType.V_PREDICTION: - metadata["modelspec.predict_key"] = "v" - - if not args.disable_metadata: - metadata["prompt"] = prompt_info - if extra_pnginfo is not None: - for x in extra_pnginfo: - metadata[x] = json.dumps(extra_pnginfo[x]) - - output_checkpoint = f"{filename}_{counter:05}_.safetensors" - output_checkpoint = os.path.join(full_output_folder, output_checkpoint) - - comfy.sd.save_checkpoint(output_checkpoint, model, clip, vae, metadata=metadata) + save_checkpoint(model, clip=clip, vae=vae, filename_prefix=filename_prefix, output_dir=self.output_dir, prompt=prompt, extra_pnginfo=extra_pnginfo) return {} class CLIPSave: @@ -276,6 +323,8 @@ NODE_CLASS_MAPPINGS = { "ModelMergeAdd": ModelAdd, "CheckpointSave": CheckpointSave, "CLIPMergeSimple": CLIPMergeSimple, + "CLIPMergeSubtract": CLIPSubtract, + "CLIPMergeAdd": CLIPAdd, "CLIPSave": CLIPSave, "VAESave": VAESave, } diff --git a/comfy_extras/nodes_morphology.py b/comfy_extras/nodes_morphology.py new file mode 100644 index 00000000..071521d8 --- /dev/null +++ b/comfy_extras/nodes_morphology.py @@ -0,0 +1,49 @@ +import torch +import comfy.model_management + +from kornia.morphology import dilation, erosion, opening, closing, gradient, top_hat, bottom_hat + + +class Morphology: + @classmethod + def INPUT_TYPES(s): + return {"required": {"image": ("IMAGE",), + "operation": (["erode", "dilate", "open", "close", "gradient", "bottom_hat", "top_hat"],), + "kernel_size": ("INT", {"default": 3, "min": 3, "max": 999, "step": 1}), + }} + + RETURN_TYPES = ("IMAGE",) + FUNCTION = "process" + + CATEGORY = "image/postprocessing" + + def process(self, image, operation, kernel_size): + device = comfy.model_management.get_torch_device() + kernel = torch.ones(kernel_size, kernel_size, device=device) + image_k = image.to(device).movedim(-1, 1) + if operation == "erode": + output = erosion(image_k, kernel) + elif operation == "dilate": + output = dilation(image_k, kernel) + elif operation == "open": + output = opening(image_k, kernel) + elif operation == "close": + output = closing(image_k, kernel) + elif operation == "gradient": + output = gradient(image_k, kernel) + elif operation == "top_hat": + output = top_hat(image_k, kernel) + elif operation == "bottom_hat": + output = bottom_hat(image_k, kernel) + else: + raise ValueError(f"Invalid operation {operation} for morphology. Must be one of 'erode', 'dilate', 'open', 'close', 'gradient', 'tophat', 'bottomhat'") + img_out = output.to(comfy.model_management.intermediate_device()).movedim(1, -1) + return (img_out,) + +NODE_CLASS_MAPPINGS = { + "Morphology": Morphology, +} + +NODE_DISPLAY_NAME_MAPPINGS = { + "Morphology": "ImageMorphology", +} \ No newline at end of file diff --git a/comfy_extras/nodes_perpneg.py b/comfy_extras/nodes_perpneg.py new file mode 100644 index 00000000..dc73c552 --- /dev/null +++ b/comfy_extras/nodes_perpneg.py @@ -0,0 +1,55 @@ +import torch +import comfy.model_management +import comfy.sample +import comfy.samplers +import comfy.utils + + +class PerpNeg: + @classmethod + def INPUT_TYPES(s): + return {"required": {"model": ("MODEL", ), + "empty_conditioning": ("CONDITIONING", ), + "neg_scale": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0, "step": 0.01}), + }} + RETURN_TYPES = ("MODEL",) + FUNCTION = "patch" + + CATEGORY = "_for_testing" + + def patch(self, model, empty_conditioning, neg_scale): + m = model.clone() + nocond = comfy.sample.convert_cond(empty_conditioning) + + def cfg_function(args): + model = args["model"] + noise_pred_pos = args["cond_denoised"] + noise_pred_neg = args["uncond_denoised"] + cond_scale = args["cond_scale"] + x = args["input"] + sigma = args["sigma"] + model_options = args["model_options"] + nocond_processed = comfy.samplers.encode_model_conds(model.extra_conds, nocond, x, x.device, "negative") + + (noise_pred_nocond, _) = comfy.samplers.calc_cond_uncond_batch(model, nocond_processed, None, x, sigma, model_options) + + pos = noise_pred_pos - noise_pred_nocond + neg = noise_pred_neg - noise_pred_nocond + perp = neg - ((torch.mul(neg, pos).sum())/(torch.norm(pos)**2)) * pos + perp_neg = perp * neg_scale + cfg_result = noise_pred_nocond + cond_scale*(pos - perp_neg) + cfg_result = x - cfg_result + return cfg_result + + m.set_model_sampler_cfg_function(cfg_function) + + return (m, ) + + +NODE_CLASS_MAPPINGS = { + "PerpNeg": PerpNeg, +} + +NODE_DISPLAY_NAME_MAPPINGS = { + "PerpNeg": "Perp-Neg", +} diff --git a/comfy_extras/nodes_photomaker.py b/comfy_extras/nodes_photomaker.py new file mode 100644 index 00000000..90130142 --- /dev/null +++ b/comfy_extras/nodes_photomaker.py @@ -0,0 +1,187 @@ +import torch +import torch.nn as nn +import folder_paths +import comfy.clip_model +import comfy.clip_vision +import comfy.ops + +# code for model from: https://github.com/TencentARC/PhotoMaker/blob/main/photomaker/model.py under Apache License Version 2.0 +VISION_CONFIG_DICT = { + "hidden_size": 1024, + "image_size": 224, + "intermediate_size": 4096, + "num_attention_heads": 16, + "num_channels": 3, + "num_hidden_layers": 24, + "patch_size": 14, + "projection_dim": 768, + "hidden_act": "quick_gelu", +} + +class MLP(nn.Module): + def __init__(self, in_dim, out_dim, hidden_dim, use_residual=True, operations=comfy.ops): + super().__init__() + if use_residual: + assert in_dim == out_dim + self.layernorm = operations.LayerNorm(in_dim) + self.fc1 = operations.Linear(in_dim, hidden_dim) + self.fc2 = operations.Linear(hidden_dim, out_dim) + self.use_residual = use_residual + self.act_fn = nn.GELU() + + def forward(self, x): + residual = x + x = self.layernorm(x) + x = self.fc1(x) + x = self.act_fn(x) + x = self.fc2(x) + if self.use_residual: + x = x + residual + return x + + +class FuseModule(nn.Module): + def __init__(self, embed_dim, operations): + super().__init__() + self.mlp1 = MLP(embed_dim * 2, embed_dim, embed_dim, use_residual=False, operations=operations) + self.mlp2 = MLP(embed_dim, embed_dim, embed_dim, use_residual=True, operations=operations) + self.layer_norm = operations.LayerNorm(embed_dim) + + def fuse_fn(self, prompt_embeds, id_embeds): + stacked_id_embeds = torch.cat([prompt_embeds, id_embeds], dim=-1) + stacked_id_embeds = self.mlp1(stacked_id_embeds) + prompt_embeds + stacked_id_embeds = self.mlp2(stacked_id_embeds) + stacked_id_embeds = self.layer_norm(stacked_id_embeds) + return stacked_id_embeds + + def forward( + self, + prompt_embeds, + id_embeds, + class_tokens_mask, + ) -> torch.Tensor: + # id_embeds shape: [b, max_num_inputs, 1, 2048] + id_embeds = id_embeds.to(prompt_embeds.dtype) + num_inputs = class_tokens_mask.sum().unsqueeze(0) # TODO: check for training case + batch_size, max_num_inputs = id_embeds.shape[:2] + # seq_length: 77 + seq_length = prompt_embeds.shape[1] + # flat_id_embeds shape: [b*max_num_inputs, 1, 2048] + flat_id_embeds = id_embeds.view( + -1, id_embeds.shape[-2], id_embeds.shape[-1] + ) + # valid_id_mask [b*max_num_inputs] + valid_id_mask = ( + torch.arange(max_num_inputs, device=flat_id_embeds.device)[None, :] + < num_inputs[:, None] + ) + valid_id_embeds = flat_id_embeds[valid_id_mask.flatten()] + + prompt_embeds = prompt_embeds.view(-1, prompt_embeds.shape[-1]) + class_tokens_mask = class_tokens_mask.view(-1) + valid_id_embeds = valid_id_embeds.view(-1, valid_id_embeds.shape[-1]) + # slice out the image token embeddings + image_token_embeds = prompt_embeds[class_tokens_mask] + stacked_id_embeds = self.fuse_fn(image_token_embeds, valid_id_embeds) + assert class_tokens_mask.sum() == stacked_id_embeds.shape[0], f"{class_tokens_mask.sum()} != {stacked_id_embeds.shape[0]}" + prompt_embeds.masked_scatter_(class_tokens_mask[:, None], stacked_id_embeds.to(prompt_embeds.dtype)) + updated_prompt_embeds = prompt_embeds.view(batch_size, seq_length, -1) + return updated_prompt_embeds + +class PhotoMakerIDEncoder(comfy.clip_model.CLIPVisionModelProjection): + def __init__(self): + self.load_device = comfy.model_management.text_encoder_device() + offload_device = comfy.model_management.text_encoder_offload_device() + dtype = comfy.model_management.text_encoder_dtype(self.load_device) + + super().__init__(VISION_CONFIG_DICT, dtype, offload_device, comfy.ops.manual_cast) + self.visual_projection_2 = comfy.ops.manual_cast.Linear(1024, 1280, bias=False) + self.fuse_module = FuseModule(2048, comfy.ops.manual_cast) + + def forward(self, id_pixel_values, prompt_embeds, class_tokens_mask): + b, num_inputs, c, h, w = id_pixel_values.shape + id_pixel_values = id_pixel_values.view(b * num_inputs, c, h, w) + + shared_id_embeds = self.vision_model(id_pixel_values)[2] + id_embeds = self.visual_projection(shared_id_embeds) + id_embeds_2 = self.visual_projection_2(shared_id_embeds) + + id_embeds = id_embeds.view(b, num_inputs, 1, -1) + id_embeds_2 = id_embeds_2.view(b, num_inputs, 1, -1) + + id_embeds = torch.cat((id_embeds, id_embeds_2), dim=-1) + updated_prompt_embeds = self.fuse_module(prompt_embeds, id_embeds, class_tokens_mask) + + return updated_prompt_embeds + + +class PhotoMakerLoader: + @classmethod + def INPUT_TYPES(s): + return {"required": { "photomaker_model_name": (folder_paths.get_filename_list("photomaker"), )}} + + RETURN_TYPES = ("PHOTOMAKER",) + FUNCTION = "load_photomaker_model" + + CATEGORY = "_for_testing/photomaker" + + def load_photomaker_model(self, photomaker_model_name): + photomaker_model_path = folder_paths.get_full_path("photomaker", photomaker_model_name) + photomaker_model = PhotoMakerIDEncoder() + data = comfy.utils.load_torch_file(photomaker_model_path, safe_load=True) + if "id_encoder" in data: + data = data["id_encoder"] + photomaker_model.load_state_dict(data) + return (photomaker_model,) + + +class PhotoMakerEncode: + @classmethod + def INPUT_TYPES(s): + return {"required": { "photomaker": ("PHOTOMAKER",), + "image": ("IMAGE",), + "clip": ("CLIP", ), + "text": ("STRING", {"multiline": True, "default": "photograph of photomaker"}), + }} + + RETURN_TYPES = ("CONDITIONING",) + FUNCTION = "apply_photomaker" + + CATEGORY = "_for_testing/photomaker" + + def apply_photomaker(self, photomaker, image, clip, text): + special_token = "photomaker" + pixel_values = comfy.clip_vision.clip_preprocess(image.to(photomaker.load_device)).float() + try: + index = text.split(" ").index(special_token) + 1 + except ValueError: + index = -1 + tokens = clip.tokenize(text, return_word_ids=True) + out_tokens = {} + for k in tokens: + out_tokens[k] = [] + for t in tokens[k]: + f = list(filter(lambda x: x[2] != index, t)) + while len(f) < len(t): + f.append(t[-1]) + out_tokens[k].append(f) + + cond, pooled = clip.encode_from_tokens(out_tokens, return_pooled=True) + + if index > 0: + token_index = index - 1 + num_id_images = 1 + class_tokens_mask = [True if token_index <= i < token_index+num_id_images else False for i in range(77)] + out = photomaker(id_pixel_values=pixel_values.unsqueeze(0), prompt_embeds=cond.to(photomaker.load_device), + class_tokens_mask=torch.tensor(class_tokens_mask, dtype=torch.bool, device=photomaker.load_device).unsqueeze(0)) + else: + out = cond + + return ([[out, {"pooled_output": pooled}]], ) + + +NODE_CLASS_MAPPINGS = { + "PhotoMakerLoader": PhotoMakerLoader, + "PhotoMakerEncode": PhotoMakerEncode, +} + diff --git a/comfy_extras/nodes_post_processing.py b/comfy_extras/nodes_post_processing.py index 12704f54..cb5c7d22 100644 --- a/comfy_extras/nodes_post_processing.py +++ b/comfy_extras/nodes_post_processing.py @@ -33,6 +33,7 @@ class Blend: CATEGORY = "image/postprocessing" def blend_images(self, image1: torch.Tensor, image2: torch.Tensor, blend_factor: float, blend_mode: str): + image2 = image2.to(image1.device) if image1.shape != image2.shape: image2 = image2.permute(0, 3, 1, 2) image2 = comfy.utils.common_upscale(image2, image1.shape[2], image1.shape[1], upscale_method='bicubic', crop='center') @@ -226,7 +227,7 @@ class Sharpen: batch_size, height, width, channels = image.shape kernel_size = sharpen_radius * 2 + 1 - kernel = gaussian_kernel(kernel_size, sigma) * -(alpha*10) + kernel = gaussian_kernel(kernel_size, sigma, device=image.device) * -(alpha*10) center = kernel_size // 2 kernel[center, center] = kernel[center, center] - kernel.sum() + 1.0 kernel = kernel.repeat(channels, 1, 1).unsqueeze(1) diff --git a/comfy_extras/nodes_rebatch.py b/comfy_extras/nodes_rebatch.py index 88a4ebe2..3010fbd4 100644 --- a/comfy_extras/nodes_rebatch.py +++ b/comfy_extras/nodes_rebatch.py @@ -99,10 +99,40 @@ class LatentRebatch: return (output_list,) +class ImageRebatch: + @classmethod + def INPUT_TYPES(s): + return {"required": { "images": ("IMAGE",), + "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}), + }} + RETURN_TYPES = ("IMAGE",) + INPUT_IS_LIST = True + OUTPUT_IS_LIST = (True, ) + + FUNCTION = "rebatch" + + CATEGORY = "image/batch" + + def rebatch(self, images, batch_size): + batch_size = batch_size[0] + + output_list = [] + all_images = [] + for img in images: + for i in range(img.shape[0]): + all_images.append(img[i:i+1]) + + for i in range(0, len(all_images), batch_size): + output_list.append(torch.cat(all_images[i:i+batch_size], dim=0)) + + return (output_list,) + NODE_CLASS_MAPPINGS = { "RebatchLatents": LatentRebatch, + "RebatchImages": ImageRebatch, } NODE_DISPLAY_NAME_MAPPINGS = { "RebatchLatents": "Rebatch Latents", -} \ No newline at end of file + "RebatchImages": "Rebatch Images", +} diff --git a/comfy_extras/nodes_sag.py b/comfy_extras/nodes_sag.py new file mode 100644 index 00000000..bbd38080 --- /dev/null +++ b/comfy_extras/nodes_sag.py @@ -0,0 +1,170 @@ +import torch +from torch import einsum +import torch.nn.functional as F +import math + +from einops import rearrange, repeat +import os +from comfy.ldm.modules.attention import optimized_attention, _ATTN_PRECISION +import comfy.samplers + +# from comfy/ldm/modules/attention.py +# but modified to return attention scores as well as output +def attention_basic_with_sim(q, k, v, heads, mask=None): + b, _, dim_head = q.shape + dim_head //= heads + scale = dim_head ** -0.5 + + h = heads + q, k, v = map( + lambda t: t.unsqueeze(3) + .reshape(b, -1, heads, dim_head) + .permute(0, 2, 1, 3) + .reshape(b * heads, -1, dim_head) + .contiguous(), + (q, k, v), + ) + + # force cast to fp32 to avoid overflowing + if _ATTN_PRECISION =="fp32": + sim = einsum('b i d, b j d -> b i j', q.float(), k.float()) * scale + else: + sim = einsum('b i d, b j d -> b i j', q, k) * scale + + del q, k + + if mask is not None: + mask = rearrange(mask, 'b ... -> b (...)') + max_neg_value = -torch.finfo(sim.dtype).max + mask = repeat(mask, 'b j -> (b h) () j', h=h) + sim.masked_fill_(~mask, max_neg_value) + + # attention, what we cannot get enough of + sim = sim.softmax(dim=-1) + + out = einsum('b i j, b j d -> b i d', sim.to(v.dtype), v) + out = ( + out.unsqueeze(0) + .reshape(b, heads, -1, dim_head) + .permute(0, 2, 1, 3) + .reshape(b, -1, heads * dim_head) + ) + return (out, sim) + +def create_blur_map(x0, attn, sigma=3.0, threshold=1.0): + # reshape and GAP the attention map + _, hw1, hw2 = attn.shape + b, _, lh, lw = x0.shape + attn = attn.reshape(b, -1, hw1, hw2) + # Global Average Pool + mask = attn.mean(1, keepdim=False).sum(1, keepdim=False) > threshold + ratio = 2**(math.ceil(math.sqrt(lh * lw / hw1)) - 1).bit_length() + mid_shape = [math.ceil(lh / ratio), math.ceil(lw / ratio)] + + # Reshape + mask = ( + mask.reshape(b, *mid_shape) + .unsqueeze(1) + .type(attn.dtype) + ) + # Upsample + mask = F.interpolate(mask, (lh, lw)) + + blurred = gaussian_blur_2d(x0, kernel_size=9, sigma=sigma) + blurred = blurred * mask + x0 * (1 - mask) + return blurred + +def gaussian_blur_2d(img, kernel_size, sigma): + ksize_half = (kernel_size - 1) * 0.5 + + x = torch.linspace(-ksize_half, ksize_half, steps=kernel_size) + + pdf = torch.exp(-0.5 * (x / sigma).pow(2)) + + x_kernel = pdf / pdf.sum() + x_kernel = x_kernel.to(device=img.device, dtype=img.dtype) + + kernel2d = torch.mm(x_kernel[:, None], x_kernel[None, :]) + kernel2d = kernel2d.expand(img.shape[-3], 1, kernel2d.shape[0], kernel2d.shape[1]) + + padding = [kernel_size // 2, kernel_size // 2, kernel_size // 2, kernel_size // 2] + + img = F.pad(img, padding, mode="reflect") + img = F.conv2d(img, kernel2d, groups=img.shape[-3]) + return img + +class SelfAttentionGuidance: + @classmethod + def INPUT_TYPES(s): + return {"required": { "model": ("MODEL",), + "scale": ("FLOAT", {"default": 0.5, "min": -2.0, "max": 5.0, "step": 0.1}), + "blur_sigma": ("FLOAT", {"default": 2.0, "min": 0.0, "max": 10.0, "step": 0.1}), + }} + RETURN_TYPES = ("MODEL",) + FUNCTION = "patch" + + CATEGORY = "_for_testing" + + def patch(self, model, scale, blur_sigma): + m = model.clone() + + attn_scores = None + + # TODO: make this work properly with chunked batches + # currently, we can only save the attn from one UNet call + def attn_and_record(q, k, v, extra_options): + nonlocal attn_scores + # if uncond, save the attention scores + heads = extra_options["n_heads"] + cond_or_uncond = extra_options["cond_or_uncond"] + b = q.shape[0] // len(cond_or_uncond) + if 1 in cond_or_uncond: + uncond_index = cond_or_uncond.index(1) + # do the entire attention operation, but save the attention scores to attn_scores + (out, sim) = attention_basic_with_sim(q, k, v, heads=heads) + # when using a higher batch size, I BELIEVE the result batch dimension is [uc1, ... ucn, c1, ... cn] + n_slices = heads * b + attn_scores = sim[n_slices * uncond_index:n_slices * (uncond_index+1)] + return out + else: + return optimized_attention(q, k, v, heads=heads) + + def post_cfg_function(args): + nonlocal attn_scores + uncond_attn = attn_scores + + sag_scale = scale + sag_sigma = blur_sigma + sag_threshold = 1.0 + model = args["model"] + uncond_pred = args["uncond_denoised"] + uncond = args["uncond"] + cfg_result = args["denoised"] + sigma = args["sigma"] + model_options = args["model_options"] + x = args["input"] + if min(cfg_result.shape[2:]) <= 4: #skip when too small to add padding + return cfg_result + + # create the adversarially blurred image + degraded = create_blur_map(uncond_pred, uncond_attn, sag_sigma, sag_threshold) + degraded_noised = degraded + x - uncond_pred + # call into the UNet + (sag, _) = comfy.samplers.calc_cond_uncond_batch(model, uncond, None, degraded_noised, sigma, model_options) + return cfg_result + (degraded - sag) * sag_scale + + m.set_model_sampler_post_cfg_function(post_cfg_function, disable_cfg1_optimization=True) + + # from diffusers: + # unet.mid_block.attentions[0].transformer_blocks[0].attn1.patch + m.set_model_attn1_replace(attn_and_record, "middle", 0, 0) + + return (m, ) + +NODE_CLASS_MAPPINGS = { + "SelfAttentionGuidance": SelfAttentionGuidance, +} + +NODE_DISPLAY_NAME_MAPPINGS = { + "SelfAttentionGuidance": "Self-Attention Guidance", +} diff --git a/comfy_extras/nodes_sdupscale.py b/comfy_extras/nodes_sdupscale.py new file mode 100644 index 00000000..28c1cb0f --- /dev/null +++ b/comfy_extras/nodes_sdupscale.py @@ -0,0 +1,47 @@ +import torch +import nodes +import comfy.utils + +class SD_4XUpscale_Conditioning: + @classmethod + def INPUT_TYPES(s): + return {"required": { "images": ("IMAGE",), + "positive": ("CONDITIONING",), + "negative": ("CONDITIONING",), + "scale_ratio": ("FLOAT", {"default": 4.0, "min": 0.0, "max": 10.0, "step": 0.01}), + "noise_augmentation": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.001}), + }} + RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT") + RETURN_NAMES = ("positive", "negative", "latent") + + FUNCTION = "encode" + + CATEGORY = "conditioning/upscale_diffusion" + + def encode(self, images, positive, negative, scale_ratio, noise_augmentation): + width = max(1, round(images.shape[-2] * scale_ratio)) + height = max(1, round(images.shape[-3] * scale_ratio)) + + pixels = comfy.utils.common_upscale((images.movedim(-1,1) * 2.0) - 1.0, width // 4, height // 4, "bilinear", "center") + + out_cp = [] + out_cn = [] + + for t in positive: + n = [t[0], t[1].copy()] + n[1]['concat_image'] = pixels + n[1]['noise_augmentation'] = noise_augmentation + out_cp.append(n) + + for t in negative: + n = [t[0], t[1].copy()] + n[1]['concat_image'] = pixels + n[1]['noise_augmentation'] = noise_augmentation + out_cn.append(n) + + latent = torch.zeros([images.shape[0], 4, height // 4, width // 4]) + return (out_cp, out_cn, {"samples":latent}) + +NODE_CLASS_MAPPINGS = { + "SD_4XUpscale_Conditioning": SD_4XUpscale_Conditioning, +} diff --git a/comfy_extras/nodes_stable3d.py b/comfy_extras/nodes_stable3d.py new file mode 100644 index 00000000..be2e34c2 --- /dev/null +++ b/comfy_extras/nodes_stable3d.py @@ -0,0 +1,143 @@ +import torch +import nodes +import comfy.utils + +def camera_embeddings(elevation, azimuth): + elevation = torch.as_tensor([elevation]) + azimuth = torch.as_tensor([azimuth]) + embeddings = torch.stack( + [ + torch.deg2rad( + (90 - elevation) - (90) + ), # Zero123 polar is 90-elevation + torch.sin(torch.deg2rad(azimuth)), + torch.cos(torch.deg2rad(azimuth)), + torch.deg2rad( + 90 - torch.full_like(elevation, 0) + ), + ], dim=-1).unsqueeze(1) + + return embeddings + + +class StableZero123_Conditioning: + @classmethod + def INPUT_TYPES(s): + return {"required": { "clip_vision": ("CLIP_VISION",), + "init_image": ("IMAGE",), + "vae": ("VAE",), + "width": ("INT", {"default": 256, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 8}), + "height": ("INT", {"default": 256, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 8}), + "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}), + "elevation": ("FLOAT", {"default": 0.0, "min": -180.0, "max": 180.0, "step": 0.1, "round": False}), + "azimuth": ("FLOAT", {"default": 0.0, "min": -180.0, "max": 180.0, "step": 0.1, "round": False}), + }} + RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT") + RETURN_NAMES = ("positive", "negative", "latent") + + FUNCTION = "encode" + + CATEGORY = "conditioning/3d_models" + + def encode(self, clip_vision, init_image, vae, width, height, batch_size, elevation, azimuth): + output = clip_vision.encode_image(init_image) + pooled = output.image_embeds.unsqueeze(0) + pixels = comfy.utils.common_upscale(init_image.movedim(-1,1), width, height, "bilinear", "center").movedim(1,-1) + encode_pixels = pixels[:,:,:,:3] + t = vae.encode(encode_pixels) + cam_embeds = camera_embeddings(elevation, azimuth) + cond = torch.cat([pooled, cam_embeds.to(pooled.device).repeat((pooled.shape[0], 1, 1))], dim=-1) + + positive = [[cond, {"concat_latent_image": t}]] + negative = [[torch.zeros_like(pooled), {"concat_latent_image": torch.zeros_like(t)}]] + latent = torch.zeros([batch_size, 4, height // 8, width // 8]) + return (positive, negative, {"samples":latent}) + +class StableZero123_Conditioning_Batched: + @classmethod + def INPUT_TYPES(s): + return {"required": { "clip_vision": ("CLIP_VISION",), + "init_image": ("IMAGE",), + "vae": ("VAE",), + "width": ("INT", {"default": 256, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 8}), + "height": ("INT", {"default": 256, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 8}), + "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}), + "elevation": ("FLOAT", {"default": 0.0, "min": -180.0, "max": 180.0, "step": 0.1, "round": False}), + "azimuth": ("FLOAT", {"default": 0.0, "min": -180.0, "max": 180.0, "step": 0.1, "round": False}), + "elevation_batch_increment": ("FLOAT", {"default": 0.0, "min": -180.0, "max": 180.0, "step": 0.1, "round": False}), + "azimuth_batch_increment": ("FLOAT", {"default": 0.0, "min": -180.0, "max": 180.0, "step": 0.1, "round": False}), + }} + RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT") + RETURN_NAMES = ("positive", "negative", "latent") + + FUNCTION = "encode" + + CATEGORY = "conditioning/3d_models" + + def encode(self, clip_vision, init_image, vae, width, height, batch_size, elevation, azimuth, elevation_batch_increment, azimuth_batch_increment): + output = clip_vision.encode_image(init_image) + pooled = output.image_embeds.unsqueeze(0) + pixels = comfy.utils.common_upscale(init_image.movedim(-1,1), width, height, "bilinear", "center").movedim(1,-1) + encode_pixels = pixels[:,:,:,:3] + t = vae.encode(encode_pixels) + + cam_embeds = [] + for i in range(batch_size): + cam_embeds.append(camera_embeddings(elevation, azimuth)) + elevation += elevation_batch_increment + azimuth += azimuth_batch_increment + + cam_embeds = torch.cat(cam_embeds, dim=0) + cond = torch.cat([comfy.utils.repeat_to_batch_size(pooled, batch_size), cam_embeds], dim=-1) + + positive = [[cond, {"concat_latent_image": t}]] + negative = [[torch.zeros_like(pooled), {"concat_latent_image": torch.zeros_like(t)}]] + latent = torch.zeros([batch_size, 4, height // 8, width // 8]) + return (positive, negative, {"samples":latent, "batch_index": [0] * batch_size}) + +class SV3D_Conditioning: + @classmethod + def INPUT_TYPES(s): + return {"required": { "clip_vision": ("CLIP_VISION",), + "init_image": ("IMAGE",), + "vae": ("VAE",), + "width": ("INT", {"default": 576, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 8}), + "height": ("INT", {"default": 576, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 8}), + "video_frames": ("INT", {"default": 21, "min": 1, "max": 4096}), + "elevation": ("FLOAT", {"default": 0.0, "min": -90.0, "max": 90.0, "step": 0.1, "round": False}), + }} + RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT") + RETURN_NAMES = ("positive", "negative", "latent") + + FUNCTION = "encode" + + CATEGORY = "conditioning/3d_models" + + def encode(self, clip_vision, init_image, vae, width, height, video_frames, elevation): + output = clip_vision.encode_image(init_image) + pooled = output.image_embeds.unsqueeze(0) + pixels = comfy.utils.common_upscale(init_image.movedim(-1,1), width, height, "bilinear", "center").movedim(1,-1) + encode_pixels = pixels[:,:,:,:3] + t = vae.encode(encode_pixels) + + azimuth = 0 + azimuth_increment = 360 / (max(video_frames, 2) - 1) + + elevations = [] + azimuths = [] + for i in range(video_frames): + elevations.append(elevation) + azimuths.append(azimuth) + azimuth += azimuth_increment + + positive = [[pooled, {"concat_latent_image": t, "elevation": elevations, "azimuth": azimuths}]] + negative = [[torch.zeros_like(pooled), {"concat_latent_image": torch.zeros_like(t), "elevation": elevations, "azimuth": azimuths}]] + latent = torch.zeros([video_frames, 4, height // 8, width // 8]) + return (positive, negative, {"samples":latent}) + + +NODE_CLASS_MAPPINGS = { + "StableZero123_Conditioning": StableZero123_Conditioning, + "StableZero123_Conditioning_Batched": StableZero123_Conditioning_Batched, + "SV3D_Conditioning": SV3D_Conditioning, +} diff --git a/comfy_extras/nodes_stable_cascade.py b/comfy_extras/nodes_stable_cascade.py new file mode 100644 index 00000000..fcbbeb27 --- /dev/null +++ b/comfy_extras/nodes_stable_cascade.py @@ -0,0 +1,140 @@ +""" + This file is part of ComfyUI. + Copyright (C) 2024 Stability AI + + This program is free software: you can redistribute it and/or modify + it under the terms of the GNU General Public License as published by + the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + This program is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU General Public License for more details. + + You should have received a copy of the GNU General Public License + along with this program. If not, see . +""" + +import torch +import nodes +import comfy.utils + + +class StableCascade_EmptyLatentImage: + def __init__(self, device="cpu"): + self.device = device + + @classmethod + def INPUT_TYPES(s): + return {"required": { + "width": ("INT", {"default": 1024, "min": 256, "max": nodes.MAX_RESOLUTION, "step": 8}), + "height": ("INT", {"default": 1024, "min": 256, "max": nodes.MAX_RESOLUTION, "step": 8}), + "compression": ("INT", {"default": 42, "min": 4, "max": 128, "step": 1}), + "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}) + }} + RETURN_TYPES = ("LATENT", "LATENT") + RETURN_NAMES = ("stage_c", "stage_b") + FUNCTION = "generate" + + CATEGORY = "latent/stable_cascade" + + def generate(self, width, height, compression, batch_size=1): + c_latent = torch.zeros([batch_size, 16, height // compression, width // compression]) + b_latent = torch.zeros([batch_size, 4, height // 4, width // 4]) + return ({ + "samples": c_latent, + }, { + "samples": b_latent, + }) + +class StableCascade_StageC_VAEEncode: + def __init__(self, device="cpu"): + self.device = device + + @classmethod + def INPUT_TYPES(s): + return {"required": { + "image": ("IMAGE",), + "vae": ("VAE", ), + "compression": ("INT", {"default": 42, "min": 4, "max": 128, "step": 1}), + }} + RETURN_TYPES = ("LATENT", "LATENT") + RETURN_NAMES = ("stage_c", "stage_b") + FUNCTION = "generate" + + CATEGORY = "latent/stable_cascade" + + def generate(self, image, vae, compression): + width = image.shape[-2] + height = image.shape[-3] + out_width = (width // compression) * vae.downscale_ratio + out_height = (height // compression) * vae.downscale_ratio + + s = comfy.utils.common_upscale(image.movedim(-1,1), out_width, out_height, "bicubic", "center").movedim(1,-1) + + c_latent = vae.encode(s[:,:,:,:3]) + b_latent = torch.zeros([c_latent.shape[0], 4, (height // 8) * 2, (width // 8) * 2]) + return ({ + "samples": c_latent, + }, { + "samples": b_latent, + }) + +class StableCascade_StageB_Conditioning: + @classmethod + def INPUT_TYPES(s): + return {"required": { "conditioning": ("CONDITIONING",), + "stage_c": ("LATENT",), + }} + RETURN_TYPES = ("CONDITIONING",) + + FUNCTION = "set_prior" + + CATEGORY = "conditioning/stable_cascade" + + def set_prior(self, conditioning, stage_c): + c = [] + for t in conditioning: + d = t[1].copy() + d['stable_cascade_prior'] = stage_c['samples'] + n = [t[0], d] + c.append(n) + return (c, ) + +class StableCascade_SuperResolutionControlnet: + def __init__(self, device="cpu"): + self.device = device + + @classmethod + def INPUT_TYPES(s): + return {"required": { + "image": ("IMAGE",), + "vae": ("VAE", ), + }} + RETURN_TYPES = ("IMAGE", "LATENT", "LATENT") + RETURN_NAMES = ("controlnet_input", "stage_c", "stage_b") + FUNCTION = "generate" + + CATEGORY = "_for_testing/stable_cascade" + + def generate(self, image, vae): + width = image.shape[-2] + height = image.shape[-3] + batch_size = image.shape[0] + controlnet_input = vae.encode(image[:,:,:,:3]).movedim(1, -1) + + c_latent = torch.zeros([batch_size, 16, height // 16, width // 16]) + b_latent = torch.zeros([batch_size, 4, height // 2, width // 2]) + return (controlnet_input, { + "samples": c_latent, + }, { + "samples": b_latent, + }) + +NODE_CLASS_MAPPINGS = { + "StableCascade_EmptyLatentImage": StableCascade_EmptyLatentImage, + "StableCascade_StageB_Conditioning": StableCascade_StageB_Conditioning, + "StableCascade_StageC_VAEEncode": StableCascade_StageC_VAEEncode, + "StableCascade_SuperResolutionControlnet": StableCascade_SuperResolutionControlnet, +} diff --git a/comfy_extras/nodes_video_model.py b/comfy_extras/nodes_video_model.py index 26a717a3..1a0189ed 100644 --- a/comfy_extras/nodes_video_model.py +++ b/comfy_extras/nodes_video_model.py @@ -3,6 +3,7 @@ import torch import comfy.utils import comfy.sd import folder_paths +import comfy_extras.nodes_model_merging class ImageOnlyCheckpointLoader: @@ -78,10 +79,54 @@ class VideoLinearCFGGuidance: m.set_model_sampler_cfg_function(linear_cfg) return (m, ) +class VideoTriangleCFGGuidance: + @classmethod + def INPUT_TYPES(s): + return {"required": { "model": ("MODEL",), + "min_cfg": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0, "step":0.5, "round": 0.01}), + }} + RETURN_TYPES = ("MODEL",) + FUNCTION = "patch" + + CATEGORY = "sampling/video_models" + + def patch(self, model, min_cfg): + def linear_cfg(args): + cond = args["cond"] + uncond = args["uncond"] + cond_scale = args["cond_scale"] + period = 1.0 + values = torch.linspace(0, 1, cond.shape[0], device=cond.device) + values = 2 * (values / period - torch.floor(values / period + 0.5)).abs() + scale = (values * (cond_scale - min_cfg) + min_cfg).reshape((cond.shape[0], 1, 1, 1)) + + return uncond + scale * (cond - uncond) + + m = model.clone() + m.set_model_sampler_cfg_function(linear_cfg) + return (m, ) + +class ImageOnlyCheckpointSave(comfy_extras.nodes_model_merging.CheckpointSave): + CATEGORY = "_for_testing" + + @classmethod + def INPUT_TYPES(s): + return {"required": { "model": ("MODEL",), + "clip_vision": ("CLIP_VISION",), + "vae": ("VAE",), + "filename_prefix": ("STRING", {"default": "checkpoints/ComfyUI"}),}, + "hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"},} + + def save(self, model, clip_vision, vae, filename_prefix, prompt=None, extra_pnginfo=None): + comfy_extras.nodes_model_merging.save_checkpoint(model, clip_vision=clip_vision, vae=vae, filename_prefix=filename_prefix, output_dir=self.output_dir, prompt=prompt, extra_pnginfo=extra_pnginfo) + return {} + NODE_CLASS_MAPPINGS = { "ImageOnlyCheckpointLoader": ImageOnlyCheckpointLoader, "SVD_img2vid_Conditioning": SVD_img2vid_Conditioning, "VideoLinearCFGGuidance": VideoLinearCFGGuidance, + "VideoTriangleCFGGuidance": VideoTriangleCFGGuidance, + "ImageOnlyCheckpointSave": ImageOnlyCheckpointSave, } NODE_DISPLAY_NAME_MAPPINGS = { diff --git a/cuda_malloc.py b/cuda_malloc.py index 144cdacd..eb2857c5 100644 --- a/cuda_malloc.py +++ b/cuda_malloc.py @@ -1,6 +1,7 @@ import os import importlib.util from comfy.cli_args import args +import subprocess #Can't use pytorch to get the GPU names because the cuda malloc has to be set before the first import. def get_gpu_names(): @@ -34,14 +35,19 @@ def get_gpu_names(): return gpu_names return enum_display_devices() else: - return set() + gpu_names = set() + out = subprocess.check_output(['nvidia-smi', '-L']) + for l in out.split(b'\n'): + if len(l) > 0: + gpu_names.add(l.decode('utf-8').split(' (UUID')[0]) + return gpu_names blacklist = {"GeForce GTX TITAN X", "GeForce GTX 980", "GeForce GTX 970", "GeForce GTX 960", "GeForce GTX 950", "GeForce 945M", "GeForce 940M", "GeForce 930M", "GeForce 920M", "GeForce 910M", "GeForce GTX 750", "GeForce GTX 745", "Quadro K620", "Quadro K1200", "Quadro K2200", "Quadro M500", "Quadro M520", "Quadro M600", "Quadro M620", "Quadro M1000", "Quadro M1200", "Quadro M2000", "Quadro M2200", "Quadro M3000", "Quadro M4000", "Quadro M5000", "Quadro M5500", "Quadro M6000", "GeForce MX110", "GeForce MX130", "GeForce 830M", "GeForce 840M", "GeForce GTX 850M", "GeForce GTX 860M", - "GeForce GTX 1650", "GeForce GTX 1630" + "GeForce GTX 1650", "GeForce GTX 1630", "Tesla M4", "Tesla M6", "Tesla M10", "Tesla M40", "Tesla M60" } def cuda_malloc_supported(): diff --git a/custom_nodes/example_node.py.example b/custom_nodes/example_node.py.example index 733014f3..f0663259 100644 --- a/custom_nodes/example_node.py.example +++ b/custom_nodes/example_node.py.example @@ -6,6 +6,8 @@ class Example: ------------- INPUT_TYPES (dict): Tell the main program input parameters of nodes. + IS_CHANGED: + optional method to control when the node is re executed. Attributes ---------- @@ -89,6 +91,20 @@ class Example: image = 1.0 - image return (image,) + """ + The node will always be re executed if any of the inputs change but + this method can be used to force the node to execute again even when the inputs don't change. + You can make this node return a number or a string. This value will be compared to the one returned the last time the node was + executed, if it is different the node will be executed again. + This method is used in the core repo for the LoadImage node where they return the image hash as a string, if the image hash + changes between executions the LoadImage node is executed again. + """ + #@classmethod + #def IS_CHANGED(s, image, string_field, int_field, float_field, print_to_screen): + # return "" + +# Set the web directory, any .js file in that directory will be loaded by the frontend as a frontend extension +# WEB_DIRECTORY = "./somejs" # A dictionary that contains all nodes you want to export with their names # NOTE: names should be globally unique diff --git a/custom_nodes/websocket_image_save.py b/custom_nodes/websocket_image_save.py new file mode 100644 index 00000000..5aa57364 --- /dev/null +++ b/custom_nodes/websocket_image_save.py @@ -0,0 +1,45 @@ +from PIL import Image, ImageOps +from io import BytesIO +import numpy as np +import struct +import comfy.utils +import time + +#You can use this node to save full size images through the websocket, the +#images will be sent in exactly the same format as the image previews: as +#binary images on the websocket with a 8 byte header indicating the type +#of binary message (first 4 bytes) and the image format (next 4 bytes). + +#Note that no metadata will be put in the images saved with this node. + +class SaveImageWebsocket: + @classmethod + def INPUT_TYPES(s): + return {"required": + {"images": ("IMAGE", ),} + } + + RETURN_TYPES = () + FUNCTION = "save_images" + + OUTPUT_NODE = True + + CATEGORY = "api/image" + + def save_images(self, images): + pbar = comfy.utils.ProgressBar(images.shape[0]) + step = 0 + for image in images: + i = 255. * image.cpu().numpy() + img = Image.fromarray(np.clip(i, 0, 255).astype(np.uint8)) + pbar.update_absolute(step, images.shape[0], ("PNG", img, None)) + step += 1 + + return {} + + def IS_CHANGED(s, images): + return time.time() + +NODE_CLASS_MAPPINGS = { + "SaveImageWebsocket": SaveImageWebsocket, +} diff --git a/execution.py b/execution.py index 7db1f095..1b8f606a 100644 --- a/execution.py +++ b/execution.py @@ -1,12 +1,11 @@ -import os import sys import copy -import json import logging import threading import heapq import traceback -import gc +import inspect +from typing import List, Literal, NamedTuple, Optional import torch import nodes @@ -36,8 +35,7 @@ def get_input_data(inputs, class_def, unique_id, outputs={}, prompt={}, extra_da if h[x] == "PROMPT": input_data_all[x] = [prompt] if h[x] == "EXTRA_PNGINFO": - if "extra_pnginfo" in extra_data: - input_data_all[x] = [extra_data['extra_pnginfo']] + input_data_all[x] = [extra_data.get('extra_pnginfo', None)] if h[x] == "UNIQUE_ID": input_data_all[x] = [unique_id] return input_data_all @@ -195,8 +193,12 @@ def recursive_execute(server, prompt, outputs, current_item, extra_data, execute return (True, None, None) -def recursive_will_execute(prompt, outputs, current_item): +def recursive_will_execute(prompt, outputs, current_item, memo={}): unique_id = current_item + + if unique_id in memo: + return memo[unique_id] + inputs = prompt[unique_id]['inputs'] will_execute = [] if unique_id in outputs: @@ -208,9 +210,10 @@ def recursive_will_execute(prompt, outputs, current_item): input_unique_id = input_data[0] output_index = input_data[1] if input_unique_id not in outputs: - will_execute += recursive_will_execute(prompt, outputs, input_unique_id) + will_execute += recursive_will_execute(prompt, outputs, input_unique_id, memo) - return will_execute + [unique_id] + memo[unique_id] = will_execute + [unique_id] + return memo[unique_id] def recursive_output_delete_if_changed(prompt, old_prompt, outputs, current_item): unique_id = current_item @@ -267,11 +270,21 @@ def recursive_output_delete_if_changed(prompt, old_prompt, outputs, current_item class PromptExecutor: def __init__(self, server): + self.server = server + self.reset() + + def reset(self): self.outputs = {} self.object_storage = {} self.outputs_ui = {} + self.status_messages = [] + self.success = True self.old_prompt = {} - self.server = server + + def add_message(self, event, data, broadcast: bool): + self.status_messages.append((event, data)) + if self.server.client_id is not None or broadcast: + self.server.send_sync(event, data, self.server.client_id) def handle_execution_error(self, prompt_id, prompt, current_outputs, executed, error, ex): node_id = error["node_id"] @@ -286,23 +299,22 @@ class PromptExecutor: "node_type": class_type, "executed": list(executed), } - self.server.send_sync("execution_interrupted", mes, self.server.client_id) + self.add_message("execution_interrupted", mes, broadcast=True) else: - if self.server.client_id is not None: - mes = { - "prompt_id": prompt_id, - "node_id": node_id, - "node_type": class_type, - "executed": list(executed), - - "exception_message": error["exception_message"], - "exception_type": error["exception_type"], - "traceback": error["traceback"], - "current_inputs": error["current_inputs"], - "current_outputs": error["current_outputs"], - } - self.server.send_sync("execution_error", mes, self.server.client_id) + mes = { + "prompt_id": prompt_id, + "node_id": node_id, + "node_type": class_type, + "executed": list(executed), + "exception_message": error["exception_message"], + "exception_type": error["exception_type"], + "traceback": error["traceback"], + "current_inputs": error["current_inputs"], + "current_outputs": error["current_outputs"], + } + self.add_message("execution_error", mes, broadcast=False) + # Next, remove the subsequent outputs since they will not be executed to_delete = [] for o in self.outputs: @@ -323,8 +335,8 @@ class PromptExecutor: else: self.server.client_id = None - if self.server.client_id is not None: - self.server.send_sync("execution_start", { "prompt_id": prompt_id}, self.server.client_id) + self.status_messages = [] + self.add_message("execution_start", { "prompt_id": prompt_id}, broadcast=False) with torch.inference_mode(): #delete cached outputs if nodes don't exist for them @@ -356,9 +368,9 @@ class PromptExecutor: d = self.outputs_ui.pop(x) del d - comfy.model_management.cleanup_models() - if self.server.client_id is not None: - self.server.send_sync("execution_cached", { "nodes": list(current_outputs) , "prompt_id": prompt_id}, self.server.client_id) + self.add_message("execution_cached", + { "nodes": list(current_outputs) , "prompt_id": prompt_id}, + broadcast=False) executed = set() output_node_id = None to_execute = [] @@ -368,20 +380,23 @@ class PromptExecutor: while len(to_execute) > 0: #always execute the output that depends on the least amount of unexecuted nodes first - to_execute = sorted(list(map(lambda a: (len(recursive_will_execute(prompt, self.outputs, a[-1])), a[-1]), to_execute))) + memo = {} + to_execute = sorted(list(map(lambda a: (len(recursive_will_execute(prompt, self.outputs, a[-1], memo)), a[-1]), to_execute))) output_node_id = to_execute.pop(0)[-1] # This call shouldn't raise anything if there's an error deep in # the actual SD code, instead it will report the node where the # error was raised - success, error, ex = recursive_execute(self.server, prompt, self.outputs, output_node_id, extra_data, executed, prompt_id, self.outputs_ui, self.object_storage) - if success is not True: + self.success, error, ex = recursive_execute(self.server, prompt, self.outputs, output_node_id, extra_data, executed, prompt_id, self.outputs_ui, self.object_storage) + if self.success is not True: self.handle_execution_error(prompt_id, prompt, current_outputs, executed, error, ex) break for x in executed: self.old_prompt[x] = copy.deepcopy(prompt[x]) self.server.last_node_id = None + if comfy.model_management.DISABLE_SMART_MEMORY: + comfy.model_management.unload_all_models() @@ -400,6 +415,10 @@ def validate_inputs(prompt, item, validated): errors = [] valid = True + validate_function_inputs = [] + if hasattr(obj_class, "VALIDATE_INPUTS"): + validate_function_inputs = inspect.getfullargspec(obj_class.VALIDATE_INPUTS).args + for x in required_inputs: if x not in inputs: error = { @@ -529,29 +548,7 @@ def validate_inputs(prompt, item, validated): errors.append(error) continue - if hasattr(obj_class, "VALIDATE_INPUTS"): - input_data_all = get_input_data(inputs, obj_class, unique_id) - #ret = obj_class.VALIDATE_INPUTS(**input_data_all) - ret = map_node_over_list(obj_class, input_data_all, "VALIDATE_INPUTS") - for i, r in enumerate(ret): - if r is not True: - details = f"{x}" - if r is not False: - details += f" - {str(r)}" - - error = { - "type": "custom_validation_failed", - "message": "Custom validation failed for node", - "details": details, - "extra_info": { - "input_name": x, - "input_config": info, - "received_value": val, - } - } - errors.append(error) - continue - else: + if x not in validate_function_inputs: if isinstance(type_input, list): if val not in type_input: input_config = info @@ -578,6 +575,35 @@ def validate_inputs(prompt, item, validated): errors.append(error) continue + if len(validate_function_inputs) > 0: + input_data_all = get_input_data(inputs, obj_class, unique_id) + input_filtered = {} + for x in input_data_all: + if x in validate_function_inputs: + input_filtered[x] = input_data_all[x] + + #ret = obj_class.VALIDATE_INPUTS(**input_filtered) + ret = map_node_over_list(obj_class, input_filtered, "VALIDATE_INPUTS") + for x in input_filtered: + for i, r in enumerate(ret): + if r is not True: + details = f"{x}" + if r is not False: + details += f" - {str(r)}" + + error = { + "type": "custom_validation_failed", + "message": "Custom validation failed for node", + "details": details, + "extra_info": { + "input_name": x, + "input_config": info, + "received_value": val, + } + } + errors.append(error) + continue + if len(errors) > 0 or valid is not True: ret = (False, errors, unique_id) else: @@ -692,6 +718,7 @@ class PromptQueue: self.queue = [] self.currently_running = {} self.history = {} + self.flags = {} server.prompt_queue = self def put(self, item): @@ -713,14 +740,27 @@ class PromptQueue: self.server.queue_updated() return (item, i) - def task_done(self, item_id, outputs): + class ExecutionStatus(NamedTuple): + status_str: Literal['success', 'error'] + completed: bool + messages: List[str] + + def task_done(self, item_id, outputs, + status: Optional['PromptQueue.ExecutionStatus']): with self.mutex: prompt = self.currently_running.pop(item_id) if len(self.history) > MAXIMUM_HISTORY_SIZE: self.history.pop(next(iter(self.history))) - self.history[prompt[1]] = { "prompt": prompt, "outputs": {} } - for o in outputs: - self.history[prompt[1]]["outputs"][o] = outputs[o] + + status_dict: Optional[dict] = None + if status is not None: + status_dict = copy.deepcopy(status._asdict()) + + self.history[prompt[1]] = { + "prompt": prompt, + "outputs": copy.deepcopy(outputs), + 'status': status_dict, + } self.server.queue_updated() def get_current_queue(self): @@ -778,3 +818,17 @@ class PromptQueue: def delete_history_item(self, id_to_delete): with self.mutex: self.history.pop(id_to_delete, None) + + def set_flag(self, name, data): + with self.mutex: + self.flags[name] = data + self.not_empty.notify() + + def get_flags(self, reset=True): + with self.mutex: + if reset: + ret = self.flags + self.flags = {} + return ret + else: + return self.flags.copy() diff --git a/folder_paths.py b/folder_paths.py index 98704945..f1bf40f8 100644 --- a/folder_paths.py +++ b/folder_paths.py @@ -29,11 +29,14 @@ folder_names_and_paths["custom_nodes"] = ([os.path.join(base_path, "custom_nodes folder_names_and_paths["hypernetworks"] = ([os.path.join(models_dir, "hypernetworks")], supported_pt_extensions) +folder_names_and_paths["photomaker"] = ([os.path.join(models_dir, "photomaker")], supported_pt_extensions) + folder_names_and_paths["classifiers"] = ([os.path.join(models_dir, "classifiers")], {""}) output_directory = os.path.join(os.path.dirname(os.path.realpath(__file__)), "output") temp_directory = os.path.join(os.path.dirname(os.path.realpath(__file__)), "temp") input_directory = os.path.join(os.path.dirname(os.path.realpath(__file__)), "input") +user_directory = os.path.join(os.path.dirname(os.path.realpath(__file__)), "user") filename_list_cache = {} @@ -137,15 +140,27 @@ def recursive_search(directory, excluded_dir_names=None): excluded_dir_names = [] result = [] - dirs = {directory: os.path.getmtime(directory)} + dirs = {} + + # Attempt to add the initial directory to dirs with error handling + try: + dirs[directory] = os.path.getmtime(directory) + except FileNotFoundError: + print(f"Warning: Unable to access {directory}. Skipping this path.") + for dirpath, subdirs, filenames in os.walk(directory, followlinks=True, topdown=True): subdirs[:] = [d for d in subdirs if d not in excluded_dir_names] for file_name in filenames: relative_path = os.path.relpath(os.path.join(dirpath, file_name), directory) result.append(relative_path) + for d in subdirs: path = os.path.join(dirpath, d) - dirs[path] = os.path.getmtime(path) + try: + dirs[path] = os.path.getmtime(path) + except FileNotFoundError: + print(f"Warning: Unable to access {path}. Skipping this path.") + continue return result, dirs def filter_files_extensions(files, extensions): @@ -184,8 +199,7 @@ def cached_filename_list_(folder_name): if folder_name not in filename_list_cache: return None out = filename_list_cache[folder_name] - if time.perf_counter() < (out[2] + 0.5): - return out + for x in out[1]: time_modified = out[1][x] folder = x diff --git a/latent_preview.py b/latent_preview.py index 61754751..4dbcbf45 100644 --- a/latent_preview.py +++ b/latent_preview.py @@ -6,6 +6,7 @@ from comfy.cli_args import args, LatentPreviewMethod from comfy.taesd.taesd import TAESD import folder_paths import comfy.utils +import logging MAX_PREVIEW_RESOLUTION = 512 @@ -70,7 +71,7 @@ def get_previewer(device, latent_format): taesd = TAESD(None, taesd_decoder_path).to(device) previewer = TAESDPreviewerImpl(taesd) else: - print("Warning: TAESD previews enabled, but could not find models/vae_approx/{}".format(latent_format.taesd_decoder_name)) + logging.warning("Warning: TAESD previews enabled, but could not find models/vae_approx/{}".format(latent_format.taesd_decoder_name)) if previewer is None: if latent_format.latent_rgb_factors is not None: diff --git a/main.py b/main.py index 1f9c5f44..b3a3ebea 100644 --- a/main.py +++ b/main.py @@ -54,15 +54,19 @@ import threading import gc from comfy.cli_args import args +import logging if os.name == "nt": - import logging logging.getLogger("xformers").addFilter(lambda record: 'A matching Triton is not available' not in record.getMessage()) if __name__ == "__main__": if args.cuda_device is not None: os.environ['CUDA_VISIBLE_DEVICES'] = str(args.cuda_device) - print("Set cuda device to:", args.cuda_device) + logging.info("Set cuda device to: {}".format(args.cuda_device)) + + if args.deterministic: + if 'CUBLAS_WORKSPACE_CONFIG' not in os.environ: + os.environ['CUBLAS_WORKSPACE_CONFIG'] = ":4096:8" import cuda_malloc @@ -84,7 +88,7 @@ def cuda_malloc_warning(): if b in device_name: cuda_malloc_warning = True if cuda_malloc_warning: - print("\nWARNING: this card most likely does not support cuda-malloc, if you get \"CUDA error\" please run ComfyUI with: --disable-cuda-malloc\n") + logging.warning("\nWARNING: this card most likely does not support cuda-malloc, if you get \"CUDA error\" please run ComfyUI with: --disable-cuda-malloc\n") def prompt_worker(q, server): e = execution.PromptExecutor(server) @@ -93,7 +97,7 @@ def prompt_worker(q, server): gc_collect_interval = 10.0 while True: - timeout = None + timeout = 1000.0 if need_gc: timeout = max(gc_collect_interval - (current_time - last_gc_collect), 0.0) @@ -102,19 +106,40 @@ def prompt_worker(q, server): item, item_id = queue_item execution_start_time = time.perf_counter() prompt_id = item[1] + server.last_prompt_id = prompt_id + e.execute(item[2], prompt_id, item[3], item[4]) need_gc = True - q.task_done(item_id, e.outputs_ui) + q.task_done(item_id, + e.outputs_ui, + status=execution.PromptQueue.ExecutionStatus( + status_str='success' if e.success else 'error', + completed=e.success, + messages=e.status_messages)) if server.client_id is not None: server.send_sync("executing", { "node": None, "prompt_id": prompt_id }, server.client_id) current_time = time.perf_counter() execution_time = current_time - execution_start_time - print("Prompt executed in {:.2f} seconds".format(execution_time)) + logging.info("Prompt executed in {:.2f} seconds".format(execution_time)) + + flags = q.get_flags() + free_memory = flags.get("free_memory", False) + + if flags.get("unload_models", free_memory): + comfy.model_management.unload_all_models() + need_gc = True + last_gc_collect = 0 + + if free_memory: + e.reset() + need_gc = True + last_gc_collect = 0 if need_gc: current_time = time.perf_counter() if (current_time - last_gc_collect) > gc_collect_interval: + comfy.model_management.cleanup_models() gc.collect() comfy.model_management.soft_empty_cache() last_gc_collect = current_time @@ -127,7 +152,9 @@ async def run(server, address='', port=8188, verbose=True, call_on_start=None): def hijack_progress(server): def hook(value, total, preview_image): comfy.model_management.throw_exception_if_processing_interrupted() - server.send_sync("progress", {"value": value, "max": total}, server.client_id) + progress = {"value": value, "max": total, "prompt_id": server.last_prompt_id, "node": server.last_node_id} + + server.send_sync("progress", progress, server.client_id) if preview_image is not None: server.send_sync(BinaryEventTypes.UNENCODED_PREVIEW_IMAGE, preview_image, server.client_id) comfy.utils.set_progress_bar_global_hook(hook) @@ -156,17 +183,24 @@ def load_extra_path_config(yaml_path): full_path = y if base_path is not None: full_path = os.path.join(base_path, full_path) - print("Adding extra search path", x, full_path) + logging.info("Adding extra search path {} {}".format(x, full_path)) folder_paths.add_model_folder_path(x, full_path) if __name__ == "__main__": if args.temp_directory: temp_dir = os.path.join(os.path.abspath(args.temp_directory), "temp") - print(f"Setting temp directory to: {temp_dir}") + logging.info(f"Setting temp directory to: {temp_dir}") folder_paths.set_temp_directory(temp_dir) cleanup_temp() + if args.windows_standalone_build: + try: + import new_updater + new_updater.update_windows_updater() + except: + pass + loop = asyncio.new_event_loop() asyncio.set_event_loop(loop) server = server.PromptServer(loop) @@ -191,7 +225,7 @@ if __name__ == "__main__": if args.output_directory: output_dir = os.path.abspath(args.output_directory) - print(f"Setting output directory to: {output_dir}") + logging.info(f"Setting output directory to: {output_dir}") folder_paths.set_output_directory(output_dir) #These are the default folders that checkpoints, clip and vae models will be saved to when using CheckpointSave, etc.. nodes @@ -201,7 +235,7 @@ if __name__ == "__main__": if args.input_directory: input_dir = os.path.abspath(args.input_directory) - print(f"Setting input directory to: {input_dir}") + logging.info(f"Setting input directory to: {input_dir}") folder_paths.set_input_directory(input_dir) if args.quick_test_for_ci: @@ -219,6 +253,6 @@ if __name__ == "__main__": try: loop.run_until_complete(run(server, address=args.listen, port=args.port, verbose=not args.dont_print_server, call_on_start=call_on_start)) except KeyboardInterrupt: - print("\nStopped server") + logging.info("\nStopped server") cleanup_temp() diff --git a/models/photomaker/put_photomaker_models_here b/models/photomaker/put_photomaker_models_here new file mode 100644 index 00000000..e69de29b diff --git a/new_updater.py b/new_updater.py new file mode 100644 index 00000000..a49e0877 --- /dev/null +++ b/new_updater.py @@ -0,0 +1,35 @@ +import os +import shutil + +base_path = os.path.dirname(os.path.realpath(__file__)) + + +def update_windows_updater(): + top_path = os.path.dirname(base_path) + updater_path = os.path.join(base_path, ".ci/update_windows/update.py") + bat_path = os.path.join(base_path, ".ci/update_windows/update_comfyui.bat") + + dest_updater_path = os.path.join(top_path, "update/update.py") + dest_bat_path = os.path.join(top_path, "update/update_comfyui.bat") + dest_bat_deps_path = os.path.join(top_path, "update/update_comfyui_and_python_dependencies.bat") + + try: + with open(dest_bat_path, 'rb') as f: + contents = f.read() + except: + return + + if not contents.startswith(b"..\\python_embeded\\python.exe .\\update.py"): + return + + shutil.copy(updater_path, dest_updater_path) + try: + with open(dest_bat_deps_path, 'rb') as f: + contents = f.read() + contents = contents.replace(b'..\\python_embeded\\python.exe .\\update.py ..\\ComfyUI\\', b'call update_comfyui.bat nopause') + with open(dest_bat_deps_path, 'wb') as f: + f.write(contents) + except: + pass + shutil.copy(bat_path, dest_bat_path) + print("Updated the windows standalone package updater.") diff --git a/nodes.py b/nodes.py index 24e591fd..6c05e3a2 100644 --- a/nodes.py +++ b/nodes.py @@ -8,8 +8,9 @@ import traceback import math import time import random +import logging -from PIL import Image, ImageOps +from PIL import Image, ImageOps, ImageSequence from PIL.PngImagePlugin import PngInfo import numpy as np import safetensors.torch @@ -40,7 +41,7 @@ def before_node_execution(): def interrupt_processing(value=True): comfy.model_management.interrupt_current_processing(value) -MAX_RESOLUTION=8192 +MAX_RESOLUTION=16384 class CLIPTextEncode: @classmethod @@ -83,7 +84,7 @@ class ConditioningAverage : out = [] if len(conditioning_from) > 1: - print("Warning: ConditioningAverage conditioning_from contains more than 1 cond, only the first one will actually be applied to conditioning_to.") + logging.warning("Warning: ConditioningAverage conditioning_from contains more than 1 cond, only the first one will actually be applied to conditioning_to.") cond_from = conditioning_from[0][0] pooled_output_from = conditioning_from[0][1].get("pooled_output", None) @@ -122,7 +123,7 @@ class ConditioningConcat: out = [] if len(conditioning_from) > 1: - print("Warning: ConditioningConcat conditioning_from contains more than 1 cond, only the first one will actually be applied to conditioning_to.") + logging.warning("Warning: ConditioningConcat conditioning_from contains more than 1 cond, only the first one will actually be applied to conditioning_to.") cond_from = conditioning_from[0][0] @@ -184,6 +185,26 @@ class ConditioningSetAreaPercentage: c.append(n) return (c, ) +class ConditioningSetAreaStrength: + @classmethod + def INPUT_TYPES(s): + return {"required": {"conditioning": ("CONDITIONING", ), + "strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}), + }} + RETURN_TYPES = ("CONDITIONING",) + FUNCTION = "append" + + CATEGORY = "conditioning" + + def append(self, conditioning, strength): + c = [] + for t in conditioning: + n = [t[0], t[1].copy()] + n[1]['strength'] = strength + c.append(n) + return (c, ) + + class ConditioningSetMask: @classmethod def INPUT_TYPES(s): @@ -289,18 +310,7 @@ class VAEEncode: CATEGORY = "latent" - @staticmethod - def vae_encode_crop_pixels(pixels): - x = (pixels.shape[1] // 8) * 8 - y = (pixels.shape[2] // 8) * 8 - if pixels.shape[1] != x or pixels.shape[2] != y: - x_offset = (pixels.shape[1] % 8) // 2 - y_offset = (pixels.shape[2] % 8) // 2 - pixels = pixels[:, x_offset:x + x_offset, y_offset:y + y_offset, :] - return pixels - def encode(self, vae, pixels): - pixels = self.vae_encode_crop_pixels(pixels) t = vae.encode(pixels[:,:,:,:3]) return ({"samples":t}, ) @@ -316,7 +326,6 @@ class VAEEncodeTiled: CATEGORY = "_for_testing" def encode(self, vae, pixels, tile_size): - pixels = VAEEncode.vae_encode_crop_pixels(pixels) t = vae.encode_tiled(pixels[:,:,:,:3], tile_x=tile_size, tile_y=tile_size, ) return ({"samples":t}, ) @@ -330,14 +339,14 @@ class VAEEncodeForInpaint: CATEGORY = "latent/inpaint" def encode(self, vae, pixels, mask, grow_mask_by=6): - x = (pixels.shape[1] // 8) * 8 - y = (pixels.shape[2] // 8) * 8 + x = (pixels.shape[1] // vae.downscale_ratio) * vae.downscale_ratio + y = (pixels.shape[2] // vae.downscale_ratio) * vae.downscale_ratio mask = torch.nn.functional.interpolate(mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1])), size=(pixels.shape[1], pixels.shape[2]), mode="bilinear") pixels = pixels.clone() if pixels.shape[1] != x or pixels.shape[2] != y: - x_offset = (pixels.shape[1] % 8) // 2 - y_offset = (pixels.shape[2] % 8) // 2 + x_offset = (pixels.shape[1] % vae.downscale_ratio) // 2 + y_offset = (pixels.shape[2] % vae.downscale_ratio) // 2 pixels = pixels[:,x_offset:x + x_offset, y_offset:y + y_offset,:] mask = mask[:,:,x_offset:x + x_offset, y_offset:y + y_offset] @@ -359,6 +368,62 @@ class VAEEncodeForInpaint: return ({"samples":t, "noise_mask": (mask_erosion[:,:,:x,:y].round())}, ) + +class InpaintModelConditioning: + @classmethod + def INPUT_TYPES(s): + return {"required": {"positive": ("CONDITIONING", ), + "negative": ("CONDITIONING", ), + "vae": ("VAE", ), + "pixels": ("IMAGE", ), + "mask": ("MASK", ), + }} + + RETURN_TYPES = ("CONDITIONING","CONDITIONING","LATENT") + RETURN_NAMES = ("positive", "negative", "latent") + FUNCTION = "encode" + + CATEGORY = "conditioning/inpaint" + + def encode(self, positive, negative, pixels, vae, mask): + x = (pixels.shape[1] // 8) * 8 + y = (pixels.shape[2] // 8) * 8 + mask = torch.nn.functional.interpolate(mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1])), size=(pixels.shape[1], pixels.shape[2]), mode="bilinear") + + orig_pixels = pixels + pixels = orig_pixels.clone() + if pixels.shape[1] != x or pixels.shape[2] != y: + x_offset = (pixels.shape[1] % 8) // 2 + y_offset = (pixels.shape[2] % 8) // 2 + pixels = pixels[:,x_offset:x + x_offset, y_offset:y + y_offset,:] + mask = mask[:,:,x_offset:x + x_offset, y_offset:y + y_offset] + + m = (1.0 - mask.round()).squeeze(1) + for i in range(3): + pixels[:,:,:,i] -= 0.5 + pixels[:,:,:,i] *= m + pixels[:,:,:,i] += 0.5 + concat_latent = vae.encode(pixels) + orig_latent = vae.encode(orig_pixels) + + out_latent = {} + + out_latent["samples"] = orig_latent + out_latent["noise_mask"] = mask + + out = [] + for conditioning in [positive, negative]: + c = [] + for t in conditioning: + d = t[1].copy() + d["concat_latent_image"] = concat_latent + d["concat_mask"] = mask + n = [t[0], d] + c.append(n) + out.append(c) + return (out[0], out[1], out_latent) + + class SaveLatent: def __init__(self): self.output_dir = folder_paths.get_output_directory() @@ -778,15 +843,20 @@ class CLIPLoader: @classmethod def INPUT_TYPES(s): return {"required": { "clip_name": (folder_paths.get_filename_list("clip"), ), + "type": (["stable_diffusion", "stable_cascade"], ), }} RETURN_TYPES = ("CLIP",) FUNCTION = "load_clip" CATEGORY = "advanced/loaders" - def load_clip(self, clip_name): + def load_clip(self, clip_name, type="stable_diffusion"): + clip_type = comfy.sd.CLIPType.STABLE_DIFFUSION + if type == "stable_cascade": + clip_type = comfy.sd.CLIPType.STABLE_CASCADE + clip_path = folder_paths.get_full_path("clip", clip_name) - clip = comfy.sd.load_clip(ckpt_paths=[clip_path], embedding_directory=folder_paths.get_folder_paths("embeddings")) + clip = comfy.sd.load_clip(ckpt_paths=[clip_path], embedding_directory=folder_paths.get_folder_paths("embeddings"), clip_type=clip_type) return (clip,) class DualCLIPLoader: @@ -934,7 +1004,7 @@ class GLIGENTextBoxApply: def append(self, conditioning_to, clip, gligen_textbox_model, text, width, height, x, y): c = [] - cond, cond_pooled = clip.encode_from_tokens(clip.tokenize(text), return_pooled=True) + cond, cond_pooled = clip.encode_from_tokens(clip.tokenize(text), return_pooled="unprojected") for t in conditioning_to: n = [t[0], t[1].copy()] position_params = [(cond_pooled, height // 8, width // 8, y // 8, x // 8)] @@ -947,8 +1017,8 @@ class GLIGENTextBoxApply: return (c, ) class EmptyLatentImage: - def __init__(self, device="cpu"): - self.device = device + def __init__(self): + self.device = comfy.model_management.intermediate_device() @classmethod def INPUT_TYPES(s): @@ -961,7 +1031,7 @@ class EmptyLatentImage: CATEGORY = "latent" def generate(self, width, height, batch_size=1): - latent = torch.zeros([batch_size, 4, height // 8, width // 8]) + latent = torch.zeros([batch_size, 4, height // 8, width // 8], device=self.device) return ({"samples":latent}, ) @@ -1358,7 +1428,7 @@ class SaveImage: filename_prefix += self.prefix_append full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, self.output_dir, images[0].shape[1], images[0].shape[0]) results = list() - for image in images: + for (batch_number, image) in enumerate(images): i = 255. * image.cpu().numpy() img = Image.fromarray(np.clip(i, 0, 255).astype(np.uint8)) metadata = None @@ -1370,7 +1440,8 @@ class SaveImage: for x in extra_pnginfo: metadata.add_text(x, json.dumps(extra_pnginfo[x])) - file = f"{filename}_{counter:05}_.png" + filename_with_batch_num = filename.replace("%batch_num%", str(batch_number)) + file = f"{filename_with_batch_num}_{counter:05}_.png" img.save(os.path.join(full_output_folder, file), pnginfo=metadata, compress_level=self.compress_level) results.append({ "filename": file, @@ -1410,17 +1481,32 @@ class LoadImage: FUNCTION = "load_image" def load_image(self, image): image_path = folder_paths.get_annotated_filepath(image) - i = Image.open(image_path) - i = ImageOps.exif_transpose(i) - image = i.convert("RGB") - image = np.array(image).astype(np.float32) / 255.0 - image = torch.from_numpy(image)[None,] - if 'A' in i.getbands(): - mask = np.array(i.getchannel('A')).astype(np.float32) / 255.0 - mask = 1. - torch.from_numpy(mask) + img = Image.open(image_path) + output_images = [] + output_masks = [] + for i in ImageSequence.Iterator(img): + i = ImageOps.exif_transpose(i) + if i.mode == 'I': + i = i.point(lambda i: i * (1 / 255)) + image = i.convert("RGB") + image = np.array(image).astype(np.float32) / 255.0 + image = torch.from_numpy(image)[None,] + if 'A' in i.getbands(): + mask = np.array(i.getchannel('A')).astype(np.float32) / 255.0 + mask = 1. - torch.from_numpy(mask) + else: + mask = torch.zeros((64,64), dtype=torch.float32, device="cpu") + output_images.append(image) + output_masks.append(mask.unsqueeze(0)) + + if len(output_images) > 1: + output_image = torch.cat(output_images, dim=0) + output_mask = torch.cat(output_masks, dim=0) else: - mask = torch.zeros((64,64), dtype=torch.float32, device="cpu") - return (image, mask.unsqueeze(0)) + output_image = output_images[0] + output_mask = output_masks[0] + + return (output_image, output_mask) @classmethod def IS_CHANGED(s, image): @@ -1457,6 +1543,8 @@ class LoadImageMask: i = Image.open(image_path) i = ImageOps.exif_transpose(i) if i.getbands() != ("R", "G", "B", "A"): + if i.mode == 'I': + i = i.point(lambda i: i * (1 / 255)) i = i.convert("RGBA") mask = None c = channel[0].upper() @@ -1478,13 +1566,10 @@ class LoadImageMask: return m.digest().hex() @classmethod - def VALIDATE_INPUTS(s, image, channel): + def VALIDATE_INPUTS(s, image): if not folder_paths.exists_annotated_filepath(image): return "Invalid image file: {}".format(image) - if channel not in s._color_channels: - return "Invalid color channel: {}".format(channel) - return True class ImageScale: @@ -1614,10 +1699,11 @@ class ImagePadForOutpaint: def expand_image(self, image, left, top, right, bottom, feathering): d1, d2, d3, d4 = image.size() - new_image = torch.zeros( + new_image = torch.ones( (d1, d2 + top + bottom, d3 + left + right, d4), dtype=torch.float32, - ) + ) * 0.5 + new_image[:, top:top + d2, left:left + d3, :] = image mask = torch.ones( @@ -1683,6 +1769,7 @@ NODE_CLASS_MAPPINGS = { "ConditioningConcat": ConditioningConcat, "ConditioningSetArea": ConditioningSetArea, "ConditioningSetAreaPercentage": ConditioningSetAreaPercentage, + "ConditioningSetAreaStrength": ConditioningSetAreaStrength, "ConditioningSetMask": ConditioningSetMask, "KSamplerAdvanced": KSamplerAdvanced, "SetLatentNoiseMask": SetLatentNoiseMask, @@ -1709,6 +1796,7 @@ NODE_CLASS_MAPPINGS = { "unCLIPCheckpointLoader": unCLIPCheckpointLoader, "GLIGENLoader": GLIGENLoader, "GLIGENTextBoxApply": GLIGENTextBoxApply, + "InpaintModelConditioning": InpaintModelConditioning, "CheckpointLoader": CheckpointLoader, "DiffusersLoader": DiffusersLoader, @@ -1812,11 +1900,11 @@ def load_custom_node(module_path, ignore=set()): NODE_DISPLAY_NAME_MAPPINGS.update(module.NODE_DISPLAY_NAME_MAPPINGS) return True else: - print(f"Skip {module_path} module for custom nodes due to the lack of NODE_CLASS_MAPPINGS.") + logging.warning(f"Skip {module_path} module for custom nodes due to the lack of NODE_CLASS_MAPPINGS.") return False except Exception as e: - print(traceback.format_exc()) - print(f"Cannot import {module_path} module for custom nodes:", e) + logging.warning(traceback.format_exc()) + logging.warning(f"Cannot import {module_path} module for custom nodes: {e}") return False def load_custom_nodes(): @@ -1837,14 +1925,14 @@ def load_custom_nodes(): node_import_times.append((time.perf_counter() - time_before, module_path, success)) if len(node_import_times) > 0: - print("\nImport times for custom nodes:") + logging.info("\nImport times for custom nodes:") for n in sorted(node_import_times): if n[2]: import_message = "" else: import_message = " (IMPORT FAILED)" - print("{:6.1f} seconds{}:".format(n[0], import_message), n[1]) - print() + logging.info("{:6.1f} seconds{}: {}".format(n[0], import_message, n[1])) + logging.info("") def init_custom_nodes(): extras_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras") @@ -1867,9 +1955,31 @@ def init_custom_nodes(): "nodes_model_downscale.py", "nodes_images.py", "nodes_video_model.py", + "nodes_sag.py", + "nodes_perpneg.py", + "nodes_stable3d.py", + "nodes_sdupscale.py", + "nodes_photomaker.py", + "nodes_cond.py", + "nodes_morphology.py", + "nodes_stable_cascade.py", + "nodes_differential_diffusion.py", ] + import_failed = [] for node_file in extras_files: - load_custom_node(os.path.join(extras_dir, node_file)) + if not load_custom_node(os.path.join(extras_dir, node_file)): + import_failed.append(node_file) load_custom_nodes() + + if len(import_failed) > 0: + logging.warning("WARNING: some comfy_extras/ nodes did not import correctly. This may be because they are missing some dependencies.\n") + for node in import_failed: + logging.warning("IMPORT FAILED: {}".format(node)) + logging.warning("\nThis issue might be caused by new missing dependencies added the last time you updated ComfyUI.") + if args.windows_standalone_build: + logging.warning("Please run the update script: update/update_comfyui.bat") + else: + logging.warning("Please do a: pip install -r requirements.txt") + logging.warning("") diff --git a/requirements.txt b/requirements.txt index 4a303b81..a1a29218 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,13 +1,14 @@ torch torchsde +torchvision einops transformers>=4.25.1 safetensors>=0.3.0 aiohttp -accelerate pyyaml Pillow scipy tqdm psutil +kornia>=0.7.1 spandrel==0.3.1 \ No newline at end of file diff --git a/script_examples/websockets_api_example_ws_images.py b/script_examples/websockets_api_example_ws_images.py new file mode 100644 index 00000000..73748862 --- /dev/null +++ b/script_examples/websockets_api_example_ws_images.py @@ -0,0 +1,159 @@ +#This is an example that uses the websockets api and the SaveImageWebsocket node to get images directly without +#them being saved to disk + +import websocket #NOTE: websocket-client (https://github.com/websocket-client/websocket-client) +import uuid +import json +import urllib.request +import urllib.parse + +server_address = "127.0.0.1:8188" +client_id = str(uuid.uuid4()) + +def queue_prompt(prompt): + p = {"prompt": prompt, "client_id": client_id} + data = json.dumps(p).encode('utf-8') + req = urllib.request.Request("http://{}/prompt".format(server_address), data=data) + return json.loads(urllib.request.urlopen(req).read()) + +def get_image(filename, subfolder, folder_type): + data = {"filename": filename, "subfolder": subfolder, "type": folder_type} + url_values = urllib.parse.urlencode(data) + with urllib.request.urlopen("http://{}/view?{}".format(server_address, url_values)) as response: + return response.read() + +def get_history(prompt_id): + with urllib.request.urlopen("http://{}/history/{}".format(server_address, prompt_id)) as response: + return json.loads(response.read()) + +def get_images(ws, prompt): + prompt_id = queue_prompt(prompt)['prompt_id'] + output_images = {} + current_node = "" + while True: + out = ws.recv() + if isinstance(out, str): + message = json.loads(out) + if message['type'] == 'executing': + data = message['data'] + if data['prompt_id'] == prompt_id: + if data['node'] is None: + break #Execution is done + else: + current_node = data['node'] + else: + if current_node == 'save_image_websocket_node': + images_output = output_images.get(current_node, []) + images_output.append(out[8:]) + output_images[current_node] = images_output + + return output_images + +prompt_text = """ +{ + "3": { + "class_type": "KSampler", + "inputs": { + "cfg": 8, + "denoise": 1, + "latent_image": [ + "5", + 0 + ], + "model": [ + "4", + 0 + ], + "negative": [ + "7", + 0 + ], + "positive": [ + "6", + 0 + ], + "sampler_name": "euler", + "scheduler": "normal", + "seed": 8566257, + "steps": 20 + } + }, + "4": { + "class_type": "CheckpointLoaderSimple", + "inputs": { + "ckpt_name": "v1-5-pruned-emaonly.ckpt" + } + }, + "5": { + "class_type": "EmptyLatentImage", + "inputs": { + "batch_size": 1, + "height": 512, + "width": 512 + } + }, + "6": { + "class_type": "CLIPTextEncode", + "inputs": { + "clip": [ + "4", + 1 + ], + "text": "masterpiece best quality girl" + } + }, + "7": { + "class_type": "CLIPTextEncode", + "inputs": { + "clip": [ + "4", + 1 + ], + "text": "bad hands" + } + }, + "8": { + "class_type": "VAEDecode", + "inputs": { + "samples": [ + "3", + 0 + ], + "vae": [ + "4", + 2 + ] + } + }, + "save_image_websocket_node": { + "class_type": "SaveImageWebsocket", + "inputs": { + "images": [ + "8", + 0 + ] + } + } +} +""" + +prompt = json.loads(prompt_text) +#set the text prompt for our positive CLIPTextEncode +prompt["6"]["inputs"]["text"] = "masterpiece best quality man" + +#set the seed for our KSampler node +prompt["3"]["inputs"]["seed"] = 5 + +ws = websocket.WebSocket() +ws.connect("ws://{}/ws?clientId={}".format(server_address, client_id)) +images = get_images(ws, prompt) + +#Commented out code to display the output images: + +# for node_id in images: +# for image_data in images[node_id]: +# from PIL import Image +# import io +# image = Image.open(io.BytesIO(image_data)) +# image.show() + diff --git a/server.py b/server.py index 9b1e3269..5642bd5e 100644 --- a/server.py +++ b/server.py @@ -15,21 +15,16 @@ from PIL import Image, ImageOps from PIL.PngImagePlugin import PngInfo from io import BytesIO -try: - import aiohttp - from aiohttp import web -except ImportError: - print("Module 'aiohttp' not installed. Please install it via:") - print("pip install aiohttp") - print("or") - print("pip install -r requirements.txt") - sys.exit() +import aiohttp +from aiohttp import web +import logging import mimetypes from comfy.cli_args import args import comfy.utils import comfy.model_management +from app.user_manager import UserManager class BinaryEventTypes: PREVIEW_IMAGE = 1 @@ -39,7 +34,7 @@ async def send_socket_catch_exception(function, message): try: await function(message) except (aiohttp.ClientError, aiohttp.ClientPayloadError, ConnectionResetError) as err: - print("send error:", err) + logging.warning("send error: {}".format(err)) @web.middleware async def cache_control(request: web.Request, handler): @@ -72,6 +67,7 @@ class PromptServer(): mimetypes.init() mimetypes.types_map['.js'] = 'application/javascript; charset=utf-8' + self.user_manager = UserManager() self.supports = ["custom_nodes_from_web"] self.prompt_queue = None self.loop = loop @@ -116,7 +112,7 @@ class PromptServer(): async for msg in ws: if msg.type == aiohttp.WSMsgType.ERROR: - print('ws connection closed with exception %s' % ws.exception()) + logging.warning('ws connection closed with exception %s' % ws.exception()) finally: self.sockets.pop(sid, None) return ws @@ -417,8 +413,8 @@ class PromptServer(): try: out[x] = node_info(x) except Exception as e: - print(f"[ERROR] An error occurred while retrieving information for the '{x}' node.", file=sys.stderr) - traceback.print_exc() + logging.error(f"[ERROR] An error occurred while retrieving information for the '{x}' node.") + logging.error(traceback.format_exc()) return web.json_response(out) @routes.get("/object_info/{node_class}") @@ -451,7 +447,7 @@ class PromptServer(): @routes.post("/prompt") async def post_prompt(request): - print("got prompt") + logging.info("got prompt") resp_code = 200 out_string = "" json_data = await request.json() @@ -483,7 +479,7 @@ class PromptServer(): response = {"prompt_id": prompt_id, "number": number, "node_errors": valid[3]} return web.json_response(response) else: - print("invalid prompt:", valid[1]) + logging.warning("invalid prompt: {}".format(valid[1])) return web.json_response({"error": valid[1], "node_errors": valid[3]}, status=400) else: return web.json_response({"error": "no prompt", "node_errors": []}, status=400) @@ -507,6 +503,17 @@ class PromptServer(): nodes.interrupt_processing() return web.Response(status=200) + @routes.post("/free") + async def post_free(request): + json_data = await request.json() + unload_models = json_data.get("unload_models", False) + free_memory = json_data.get("free_memory", False) + if unload_models: + self.prompt_queue.set_flag("unload_models", unload_models) + if free_memory: + self.prompt_queue.set_flag("free_memory", free_memory) + return web.Response(status=200) + @routes.post("/history") async def post_history(request): json_data = await request.json() @@ -521,15 +528,16 @@ class PromptServer(): return web.Response(status=200) def add_routes(self): + self.user_manager.add_routes(self.routes) self.app.add_routes(self.routes) for name, dir in nodes.EXTENSION_WEB_DIRS.items(): self.app.add_routes([ - web.static('/extensions/' + urllib.parse.quote(name), dir, follow_symlinks=True), + web.static('/extensions/' + urllib.parse.quote(name), dir), ]) self.app.add_routes([ - web.static('/', self.web_root, follow_symlinks=True), + web.static('/', self.web_root), ]) def get_queue_info(self): @@ -584,7 +592,8 @@ class PromptServer(): message = self.encode_bytes(event, data) if sid is None: - for ws in self.sockets.values(): + sockets = list(self.sockets.values()) + for ws in sockets: await send_socket_catch_exception(ws.send_bytes, message) elif sid in self.sockets: await send_socket_catch_exception(self.sockets[sid].send_bytes, message) @@ -593,7 +602,8 @@ class PromptServer(): message = {"type": event, "data": data} if sid is None: - for ws in self.sockets.values(): + sockets = list(self.sockets.values()) + for ws in sockets: await send_socket_catch_exception(ws.send_json, message) elif sid in self.sockets: await send_socket_catch_exception(self.sockets[sid].send_json, message) @@ -616,11 +626,9 @@ class PromptServer(): site = web.TCPSite(runner, address, port) await site.start() - if address == '': - address = '0.0.0.0' if verbose: - print("Starting server\n") - print("To see the GUI go to: http://{}:{}".format(address, port)) + logging.info("Starting server\n") + logging.info("To see the GUI go to: http://{}:{}".format(address, port)) if call_on_start is not None: call_on_start(address, port) @@ -632,7 +640,7 @@ class PromptServer(): try: json_data = handler(json_data) except Exception as e: - print(f"[ERROR] An error occurred during the on_prompt_handler processing") - traceback.print_exc() + logging.warning(f"[ERROR] An error occurred during the on_prompt_handler processing") + logging.warning(traceback.format_exc()) return json_data diff --git a/tests-ui/afterSetup.js b/tests-ui/afterSetup.js new file mode 100644 index 00000000..983f3af6 --- /dev/null +++ b/tests-ui/afterSetup.js @@ -0,0 +1,9 @@ +const { start } = require("./utils"); +const lg = require("./utils/litegraph"); + +// Load things once per test file before to ensure its all warmed up for the tests +beforeAll(async () => { + lg.setup(global); + await start({ resetEnv: true }); + lg.teardown(global); +}); diff --git a/tests-ui/babel.config.json b/tests-ui/babel.config.json index 526ddfd8..f27d6c39 100644 --- a/tests-ui/babel.config.json +++ b/tests-ui/babel.config.json @@ -1,3 +1,4 @@ { - "presets": ["@babel/preset-env"] + "presets": ["@babel/preset-env"], + "plugins": ["babel-plugin-transform-import-meta"] } diff --git a/tests-ui/jest.config.js b/tests-ui/jest.config.js index b5a5d646..86fff505 100644 --- a/tests-ui/jest.config.js +++ b/tests-ui/jest.config.js @@ -2,8 +2,10 @@ const config = { testEnvironment: "jsdom", setupFiles: ["./globalSetup.js"], + setupFilesAfterEnv: ["./afterSetup.js"], clearMocks: true, resetModules: true, + testTimeout: 10000 }; module.exports = config; diff --git a/tests-ui/package-lock.json b/tests-ui/package-lock.json index 35911cd7..0f409ca2 100644 --- a/tests-ui/package-lock.json +++ b/tests-ui/package-lock.json @@ -11,6 +11,7 @@ "devDependencies": { "@babel/preset-env": "^7.22.20", "@types/jest": "^29.5.5", + "babel-plugin-transform-import-meta": "^2.2.1", "jest": "^29.7.0", "jest-environment-jsdom": "^29.7.0" } @@ -2591,6 +2592,19 @@ "@babel/core": "^7.4.0 || ^8.0.0-0 <8.0.0" } }, + "node_modules/babel-plugin-transform-import-meta": { + "version": "2.2.1", + "resolved": "https://registry.npmjs.org/babel-plugin-transform-import-meta/-/babel-plugin-transform-import-meta-2.2.1.tgz", + "integrity": "sha512-AxNh27Pcg8Kt112RGa3Vod2QS2YXKKJ6+nSvRtv7qQTJAdx0MZa4UHZ4lnxHUWA2MNbLuZQv5FVab4P1CoLOWw==", + "dev": true, + "dependencies": { + "@babel/template": "^7.4.4", + "tslib": "^2.4.0" + }, + "peerDependencies": { + "@babel/core": "^7.10.0" + } + }, "node_modules/babel-preset-current-node-syntax": { "version": "1.0.1", "resolved": "https://registry.npmjs.org/babel-preset-current-node-syntax/-/babel-preset-current-node-syntax-1.0.1.tgz", @@ -5233,6 +5247,12 @@ "node": ">=12" } }, + "node_modules/tslib": { + "version": "2.6.2", + "resolved": "https://registry.npmjs.org/tslib/-/tslib-2.6.2.tgz", + "integrity": "sha512-AEYxH93jGFPn/a2iVAwW87VuUIkR1FVUKB77NwMF7nBTDkDrrT/Hpt/IrCJ0QXhW27jTBDcf5ZY7w6RiqTMw2Q==", + "dev": true + }, "node_modules/type-detect": { "version": "4.0.8", "resolved": "https://registry.npmjs.org/type-detect/-/type-detect-4.0.8.tgz", diff --git a/tests-ui/package.json b/tests-ui/package.json index e7b60ad8..ae7e4908 100644 --- a/tests-ui/package.json +++ b/tests-ui/package.json @@ -24,6 +24,7 @@ "devDependencies": { "@babel/preset-env": "^7.22.20", "@types/jest": "^29.5.5", + "babel-plugin-transform-import-meta": "^2.2.1", "jest": "^29.7.0", "jest-environment-jsdom": "^29.7.0" } diff --git a/tests-ui/tests/extensions.test.js b/tests-ui/tests/extensions.test.js index b82e55c3..159e5113 100644 --- a/tests-ui/tests/extensions.test.js +++ b/tests-ui/tests/extensions.test.js @@ -52,7 +52,7 @@ describe("extensions", () => { const nodeNames = Object.keys(defs); const nodeCount = nodeNames.length; expect(mockExtension.beforeRegisterNodeDef).toHaveBeenCalledTimes(nodeCount); - for (let i = 0; i < nodeCount; i++) { + for (let i = 0; i < 10; i++) { // It should be send the JS class and the original JSON definition const nodeClass = mockExtension.beforeRegisterNodeDef.mock.calls[i][0]; const nodeDef = mockExtension.beforeRegisterNodeDef.mock.calls[i][1]; @@ -133,7 +133,7 @@ describe("extensions", () => { expect(mockExtension.nodeCreated).toHaveBeenCalledTimes(graphData.nodes.length + 2); expect(mockExtension.loadedGraphNode).toHaveBeenCalledTimes(graphData.nodes.length + 1); expect(mockExtension.afterConfigureGraph).toHaveBeenCalledTimes(2); - }); + }, 15000); it("allows custom nodeDefs and widgets to be registered", async () => { const widgetMock = jest.fn((node, inputName, inputData, app) => { diff --git a/tests-ui/tests/groupNode.test.js b/tests-ui/tests/groupNode.test.js index ce54c115..53a5828d 100644 --- a/tests-ui/tests/groupNode.test.js +++ b/tests-ui/tests/groupNode.test.js @@ -1,7 +1,7 @@ // @ts-check /// -const { start, createDefaultWorkflow } = require("../utils"); +const { start, createDefaultWorkflow, getNodeDef, checkBeforeAndAfterReload } = require("../utils"); const lg = require("../utils/litegraph"); describe("group node", () => { @@ -273,7 +273,7 @@ describe("group node", () => { let reroutes = []; let prevNode = nodes.ckpt; - for(let i = 0; i < 5; i++) { + for (let i = 0; i < 5; i++) { const reroute = ez.Reroute(); prevNode.outputs[0].connectTo(reroute.inputs[0]); prevNode = reroute; @@ -283,7 +283,7 @@ describe("group node", () => { const group = await convertToGroup(app, graph, "test", [...reroutes, ...Object.values(nodes)]); expect((await graph.toPrompt()).output).toEqual(getOutput()); - + group.menu["Convert to nodes"].call(); expect((await graph.toPrompt()).output).toEqual(getOutput()); }); @@ -383,6 +383,43 @@ describe("group node", () => { getOutput([nodes.pos.id, nodes.neg.id, nodes.empty.id, nodes.sampler.id]) ); }); + test("groups can connect to each other via internal reroutes", async () => { + const { ez, graph, app } = await start(); + + const latent = ez.EmptyLatentImage(); + const vae = ez.VAELoader(); + const latentReroute = ez.Reroute(); + const vaeReroute = ez.Reroute(); + + latent.outputs[0].connectTo(latentReroute.inputs[0]); + vae.outputs[0].connectTo(vaeReroute.inputs[0]); + + const group1 = await convertToGroup(app, graph, "test", [latentReroute, vaeReroute]); + group1.menu.Clone.call(); + expect(app.graph._nodes).toHaveLength(4); + const group2 = graph.find(app.graph._nodes[3]); + expect(group2.node.type).toEqual("workflow/test"); + expect(group2.id).not.toEqual(group1.id); + + group1.outputs.VAE.connectTo(group2.inputs.VAE); + group1.outputs.LATENT.connectTo(group2.inputs.LATENT); + + const decode = ez.VAEDecode(group2.outputs.LATENT, group2.outputs.VAE); + const preview = ez.PreviewImage(decode.outputs[0]); + + const output = { + [latent.id]: { inputs: { width: 512, height: 512, batch_size: 1 }, class_type: "EmptyLatentImage" }, + [vae.id]: { inputs: { vae_name: "vae1.safetensors" }, class_type: "VAELoader" }, + [decode.id]: { inputs: { samples: [latent.id + "", 0], vae: [vae.id + "", 0] }, class_type: "VAEDecode" }, + [preview.id]: { inputs: { images: [decode.id + "", 0] }, class_type: "PreviewImage" }, + }; + expect((await graph.toPrompt()).output).toEqual(output); + + // Ensure missing connections dont cause errors + group2.inputs.VAE.disconnect(); + delete output[decode.id].inputs.vae; + expect((await graph.toPrompt()).output).toEqual(output); + }); test("displays generated image on group node", async () => { const { ez, graph, app } = await start(); const nodes = createDefaultWorkflow(ez, graph); @@ -642,6 +679,55 @@ describe("group node", () => { 2: { inputs: { text: "positive" }, class_type: "CLIPTextEncode" }, }); }); + test("correctly handles widget inputs", async () => { + const { ez, graph, app } = await start(); + const upscaleMethods = (await getNodeDef("ImageScaleBy")).input.required["upscale_method"][0]; + + const image = ez.LoadImage(); + const scale1 = ez.ImageScaleBy(image.outputs[0]); + const scale2 = ez.ImageScaleBy(image.outputs[0]); + const preview1 = ez.PreviewImage(scale1.outputs[0]); + const preview2 = ez.PreviewImage(scale2.outputs[0]); + scale1.widgets.upscale_method.value = upscaleMethods[1]; + scale1.widgets.upscale_method.convertToInput(); + + const group = await convertToGroup(app, graph, "test", [scale1, scale2]); + expect(group.inputs.length).toBe(3); + expect(group.inputs[0].input.type).toBe("IMAGE"); + expect(group.inputs[1].input.type).toBe("IMAGE"); + expect(group.inputs[2].input.type).toBe("COMBO"); + + // Ensure links are maintained + expect(group.inputs[0].connection?.originNode?.id).toBe(image.id); + expect(group.inputs[1].connection?.originNode?.id).toBe(image.id); + expect(group.inputs[2].connection).toBeFalsy(); + + // Ensure primitive gets correct type + const primitive = ez.PrimitiveNode(); + primitive.outputs[0].connectTo(group.inputs[2]); + expect(primitive.widgets.value.widget.options.values).toBe(upscaleMethods); + expect(primitive.widgets.value.value).toBe(upscaleMethods[1]); // Ensure value is copied + primitive.widgets.value.value = upscaleMethods[1]; + + await checkBeforeAndAfterReload(graph, async (r) => { + const scale1id = r ? `${group.id}:0` : scale1.id; + const scale2id = r ? `${group.id}:1` : scale2.id; + // Ensure widget value is applied to prompt + expect((await graph.toPrompt()).output).toStrictEqual({ + [image.id]: { inputs: { image: "example.png", upload: "image" }, class_type: "LoadImage" }, + [scale1id]: { + inputs: { upscale_method: upscaleMethods[1], scale_by: 1, image: [`${image.id}`, 0] }, + class_type: "ImageScaleBy", + }, + [scale2id]: { + inputs: { upscale_method: "nearest-exact", scale_by: 1, image: [`${image.id}`, 0] }, + class_type: "ImageScaleBy", + }, + [preview1.id]: { inputs: { images: [`${scale1id}`, 0] }, class_type: "PreviewImage" }, + [preview2.id]: { inputs: { images: [`${scale2id}`, 0] }, class_type: "PreviewImage" }, + }); + }); + }); test("adds widgets in node execution order", async () => { const { ez, graph, app } = await start(); const scale = ez.LatentUpscale(); @@ -815,4 +901,105 @@ describe("group node", () => { expect(p2.widgets.control_after_generate.value).toBe("randomize"); expect(p2.widgets.control_filter_list.value).toBe("/.+/"); }); + test("internal reroutes work with converted inputs and merge options", async () => { + const { ez, graph, app } = await start(); + const vae = ez.VAELoader(); + const latent = ez.EmptyLatentImage(); + const decode = ez.VAEDecode(latent.outputs.LATENT, vae.outputs.VAE); + const scale = ez.ImageScale(decode.outputs.IMAGE); + ez.PreviewImage(scale.outputs.IMAGE); + + const r1 = ez.Reroute(); + const r2 = ez.Reroute(); + + latent.widgets.width.value = 64; + latent.widgets.height.value = 128; + + latent.widgets.width.convertToInput(); + latent.widgets.height.convertToInput(); + latent.widgets.batch_size.convertToInput(); + + scale.widgets.width.convertToInput(); + scale.widgets.height.convertToInput(); + + r1.inputs[0].input.label = "hbw"; + r1.outputs[0].connectTo(latent.inputs.height); + r1.outputs[0].connectTo(latent.inputs.batch_size); + r1.outputs[0].connectTo(scale.inputs.width); + + r2.inputs[0].input.label = "wh"; + r2.outputs[0].connectTo(latent.inputs.width); + r2.outputs[0].connectTo(scale.inputs.height); + + const group = await convertToGroup(app, graph, "test", [r1, r2, latent, decode, scale]); + + expect(group.inputs[0].input.type).toBe("VAE"); + expect(group.inputs[1].input.type).toBe("INT"); + expect(group.inputs[2].input.type).toBe("INT"); + + const p1 = ez.PrimitiveNode(); + const p2 = ez.PrimitiveNode(); + p1.outputs[0].connectTo(group.inputs[1]); + p2.outputs[0].connectTo(group.inputs[2]); + + expect(p1.widgets.value.widget.options?.min).toBe(16); // width/height min + expect(p1.widgets.value.widget.options?.max).toBe(4096); // batch max + expect(p1.widgets.value.widget.options?.step).toBe(80); // width/height step * 10 + + expect(p2.widgets.value.widget.options?.min).toBe(16); // width/height min + expect(p2.widgets.value.widget.options?.max).toBe(16384); // width/height max + expect(p2.widgets.value.widget.options?.step).toBe(80); // width/height step * 10 + + expect(p1.widgets.value.value).toBe(128); + expect(p2.widgets.value.value).toBe(64); + + p1.widgets.value.value = 16; + p2.widgets.value.value = 32; + + await checkBeforeAndAfterReload(graph, async (r) => { + const id = (v) => (r ? `${group.id}:` : "") + v; + expect((await graph.toPrompt()).output).toStrictEqual({ + 1: { inputs: { vae_name: "vae1.safetensors" }, class_type: "VAELoader" }, + [id(2)]: { inputs: { width: 32, height: 16, batch_size: 16 }, class_type: "EmptyLatentImage" }, + [id(3)]: { inputs: { samples: [id(2), 0], vae: ["1", 0] }, class_type: "VAEDecode" }, + [id(4)]: { + inputs: { upscale_method: "nearest-exact", width: 16, height: 32, crop: "disabled", image: [id(3), 0] }, + class_type: "ImageScale", + }, + 5: { inputs: { images: [id(4), 0] }, class_type: "PreviewImage" }, + }); + }); + }); + test("converted inputs with linked widgets map values correctly on creation", async () => { + const { ez, graph, app } = await start(); + const k1 = ez.KSampler(); + const k2 = ez.KSampler(); + k1.widgets.seed.convertToInput(); + k2.widgets.seed.convertToInput(); + + const rr = ez.Reroute(); + rr.outputs[0].connectTo(k1.inputs.seed); + rr.outputs[0].connectTo(k2.inputs.seed); + + const group = await convertToGroup(app, graph, "test", [k1, k2, rr]); + expect(group.widgets.steps.value).toBe(20); + expect(group.widgets.cfg.value).toBe(8); + expect(group.widgets.scheduler.value).toBe("normal"); + expect(group.widgets["KSampler steps"].value).toBe(20); + expect(group.widgets["KSampler cfg"].value).toBe(8); + expect(group.widgets["KSampler scheduler"].value).toBe("normal"); + }); + test("allow multiple of the same node type to be added", async () => { + const { ez, graph, app } = await start(); + const nodes = [...Array(10)].map(() => ez.ImageScaleBy()); + const group = await convertToGroup(app, graph, "test", nodes); + expect(group.inputs.length).toBe(10); + expect(group.outputs.length).toBe(10); + expect(group.widgets.length).toBe(20); + expect(group.widgets.map((w) => w.widget.name)).toStrictEqual( + [...Array(10)] + .map((_, i) => `${i > 0 ? "ImageScaleBy " : ""}${i > 1 ? i + " " : ""}`) + .flatMap((p) => [`${p}upscale_method`, `${p}scale_by`]) + ); + }); }); diff --git a/tests-ui/tests/users.test.js b/tests-ui/tests/users.test.js new file mode 100644 index 00000000..5e073073 --- /dev/null +++ b/tests-ui/tests/users.test.js @@ -0,0 +1,295 @@ +// @ts-check +/// +const { start } = require("../utils"); +const lg = require("../utils/litegraph"); + +describe("users", () => { + beforeEach(() => { + lg.setup(global); + }); + + afterEach(() => { + lg.teardown(global); + }); + + function expectNoUserScreen() { + // Ensure login isnt visible + const selection = document.querySelectorAll("#comfy-user-selection")?.[0]; + expect(selection["style"].display).toBe("none"); + const menu = document.querySelectorAll(".comfy-menu")?.[0]; + expect(window.getComputedStyle(menu)?.display).not.toBe("none"); + } + + describe("multi-user", () => { + function mockAddStylesheet() { + const utils = require("../../web/scripts/utils"); + utils.addStylesheet = jest.fn().mockReturnValue(Promise.resolve()); + } + + async function waitForUserScreenShow() { + mockAddStylesheet(); + + // Wait for "show" to be called + const { UserSelectionScreen } = require("../../web/scripts/ui/userSelection"); + let resolve, reject; + const fn = UserSelectionScreen.prototype.show; + const p = new Promise((res, rej) => { + resolve = res; + reject = rej; + }); + jest.spyOn(UserSelectionScreen.prototype, "show").mockImplementation(async (...args) => { + const res = fn(...args); + await new Promise(process.nextTick); // wait for promises to resolve + resolve(); + return res; + }); + // @ts-ignore + setTimeout(() => reject("timeout waiting for UserSelectionScreen to be shown."), 500); + await p; + await new Promise(process.nextTick); // wait for promises to resolve + } + + async function testUserScreen(onShown, users) { + if (!users) { + users = {}; + } + const starting = start({ + resetEnv: true, + userConfig: { storage: "server", users }, + }); + + // Ensure no current user + expect(localStorage["Comfy.userId"]).toBeFalsy(); + expect(localStorage["Comfy.userName"]).toBeFalsy(); + + await waitForUserScreenShow(); + + const selection = document.querySelectorAll("#comfy-user-selection")?.[0]; + expect(selection).toBeTruthy(); + + // Ensure login is visible + expect(window.getComputedStyle(selection)?.display).not.toBe("none"); + // Ensure menu is hidden + const menu = document.querySelectorAll(".comfy-menu")?.[0]; + expect(window.getComputedStyle(menu)?.display).toBe("none"); + + const isCreate = await onShown(selection); + + // Submit form + selection.querySelectorAll("form")[0].submit(); + await new Promise(process.nextTick); // wait for promises to resolve + + // Wait for start + const s = await starting; + + // Ensure login is removed + expect(document.querySelectorAll("#comfy-user-selection")).toHaveLength(0); + expect(window.getComputedStyle(menu)?.display).not.toBe("none"); + + // Ensure settings + templates are saved + const { api } = require("../../web/scripts/api"); + expect(api.createUser).toHaveBeenCalledTimes(+isCreate); + expect(api.storeSettings).toHaveBeenCalledTimes(+isCreate); + expect(api.storeUserData).toHaveBeenCalledTimes(+isCreate); + if (isCreate) { + expect(api.storeUserData).toHaveBeenCalledWith("comfy.templates.json", null, { stringify: false }); + expect(s.app.isNewUserSession).toBeTruthy(); + } else { + expect(s.app.isNewUserSession).toBeFalsy(); + } + + return { users, selection, ...s }; + } + + it("allows user creation if no users", async () => { + const { users } = await testUserScreen((selection) => { + // Ensure we have no users flag added + expect(selection.classList.contains("no-users")).toBeTruthy(); + + // Enter a username + const input = selection.getElementsByTagName("input")[0]; + input.focus(); + input.value = "Test User"; + + return true; + }); + + expect(users).toStrictEqual({ + "Test User!": "Test User", + }); + + expect(localStorage["Comfy.userId"]).toBe("Test User!"); + expect(localStorage["Comfy.userName"]).toBe("Test User"); + }); + it("allows user creation if no current user but other users", async () => { + const users = { + "Test User 2!": "Test User 2", + }; + + await testUserScreen((selection) => { + expect(selection.classList.contains("no-users")).toBeFalsy(); + + // Enter a username + const input = selection.getElementsByTagName("input")[0]; + input.focus(); + input.value = "Test User 3"; + return true; + }, users); + + expect(users).toStrictEqual({ + "Test User 2!": "Test User 2", + "Test User 3!": "Test User 3", + }); + + expect(localStorage["Comfy.userId"]).toBe("Test User 3!"); + expect(localStorage["Comfy.userName"]).toBe("Test User 3"); + }); + it("allows user selection if no current user but other users", async () => { + const users = { + "A!": "A", + "B!": "B", + "C!": "C", + }; + + await testUserScreen((selection) => { + expect(selection.classList.contains("no-users")).toBeFalsy(); + + // Check user list + const select = selection.getElementsByTagName("select")[0]; + const options = select.getElementsByTagName("option"); + expect( + [...options] + .filter((o) => !o.disabled) + .reduce((p, n) => { + p[n.getAttribute("value")] = n.textContent; + return p; + }, {}) + ).toStrictEqual(users); + + // Select an option + select.focus(); + select.value = options[2].value; + + return false; + }, users); + + expect(users).toStrictEqual(users); + + expect(localStorage["Comfy.userId"]).toBe("B!"); + expect(localStorage["Comfy.userName"]).toBe("B"); + }); + it("doesnt show user screen if current user", async () => { + const starting = start({ + resetEnv: true, + userConfig: { + storage: "server", + users: { + "User!": "User", + }, + }, + localStorage: { + "Comfy.userId": "User!", + "Comfy.userName": "User", + }, + }); + await new Promise(process.nextTick); // wait for promises to resolve + + expectNoUserScreen(); + + await starting; + }); + it("allows user switching", async () => { + const { app } = await start({ + resetEnv: true, + userConfig: { + storage: "server", + users: { + "User!": "User", + }, + }, + localStorage: { + "Comfy.userId": "User!", + "Comfy.userName": "User", + }, + }); + + // cant actually test switching user easily but can check the setting is present + expect(app.ui.settings.settingsLookup["Comfy.SwitchUser"]).toBeTruthy(); + }); + }); + describe("single-user", () => { + it("doesnt show user creation if no default user", async () => { + const { app } = await start({ + resetEnv: true, + userConfig: { migrated: false, storage: "server" }, + }); + expectNoUserScreen(); + + // It should store the settings + const { api } = require("../../web/scripts/api"); + expect(api.storeSettings).toHaveBeenCalledTimes(1); + expect(api.storeUserData).toHaveBeenCalledTimes(1); + expect(api.storeUserData).toHaveBeenCalledWith("comfy.templates.json", null, { stringify: false }); + expect(app.isNewUserSession).toBeTruthy(); + }); + it("doesnt show user creation if default user", async () => { + const { app } = await start({ + resetEnv: true, + userConfig: { migrated: true, storage: "server" }, + }); + expectNoUserScreen(); + + // It should store the settings + const { api } = require("../../web/scripts/api"); + expect(api.storeSettings).toHaveBeenCalledTimes(0); + expect(api.storeUserData).toHaveBeenCalledTimes(0); + expect(app.isNewUserSession).toBeFalsy(); + }); + it("doesnt allow user switching", async () => { + const { app } = await start({ + resetEnv: true, + userConfig: { migrated: true, storage: "server" }, + }); + expectNoUserScreen(); + + expect(app.ui.settings.settingsLookup["Comfy.SwitchUser"]).toBeFalsy(); + }); + }); + describe("browser-user", () => { + it("doesnt show user creation if no default user", async () => { + const { app } = await start({ + resetEnv: true, + userConfig: { migrated: false, storage: "browser" }, + }); + expectNoUserScreen(); + + // It should store the settings + const { api } = require("../../web/scripts/api"); + expect(api.storeSettings).toHaveBeenCalledTimes(0); + expect(api.storeUserData).toHaveBeenCalledTimes(0); + expect(app.isNewUserSession).toBeFalsy(); + }); + it("doesnt show user creation if default user", async () => { + const { app } = await start({ + resetEnv: true, + userConfig: { migrated: true, storage: "server" }, + }); + expectNoUserScreen(); + + // It should store the settings + const { api } = require("../../web/scripts/api"); + expect(api.storeSettings).toHaveBeenCalledTimes(0); + expect(api.storeUserData).toHaveBeenCalledTimes(0); + expect(app.isNewUserSession).toBeFalsy(); + }); + it("doesnt allow user switching", async () => { + const { app } = await start({ + resetEnv: true, + userConfig: { migrated: true, storage: "browser" }, + }); + expectNoUserScreen(); + + expect(app.ui.settings.settingsLookup["Comfy.SwitchUser"]).toBeFalsy(); + }); + }); +}); diff --git a/tests-ui/tests/widgetInputs.test.js b/tests-ui/tests/widgetInputs.test.js index 8e191adf..67e3fa34 100644 --- a/tests-ui/tests/widgetInputs.test.js +++ b/tests-ui/tests/widgetInputs.test.js @@ -1,7 +1,13 @@ // @ts-check /// -const { start, makeNodeDef, checkBeforeAndAfterReload, assertNotNullOrUndefined } = require("../utils"); +const { + start, + makeNodeDef, + checkBeforeAndAfterReload, + assertNotNullOrUndefined, + createDefaultWorkflow, +} = require("../utils"); const lg = require("../utils/litegraph"); /** @@ -36,7 +42,7 @@ async function connectPrimitiveAndReload(ez, graph, input, widgetType, controlWi if (controlWidgetCount) { const controlWidget = primitive.widgets.control_after_generate; expect(controlWidget.widget.type).toBe("combo"); - if(widgetType === "combo") { + if (widgetType === "combo") { const filterWidget = primitive.widgets.control_filter_list; expect(filterWidget.widget.type).toBe("string"); } @@ -308,8 +314,8 @@ describe("widget inputs", () => { const { ez } = await start({ mockNodeDefs: { ...makeNodeDef("TestNode1", {}, [["A", "B"]]), - ...makeNodeDef("TestNode2", { example: [["A", "B"], { forceInput: true}] }), - ...makeNodeDef("TestNode3", { example: [["A", "B", "C"], { forceInput: true}] }), + ...makeNodeDef("TestNode2", { example: [["A", "B"], { forceInput: true }] }), + ...makeNodeDef("TestNode3", { example: [["A", "B", "C"], { forceInput: true }] }), }, }); @@ -330,7 +336,7 @@ describe("widget inputs", () => { const n1 = ez.TestNode1(); n1.widgets.example.convertToInput(); - const p = ez.PrimitiveNode() + const p = ez.PrimitiveNode(); p.outputs[0].connectTo(n1.inputs[0]); const value = p.widgets.value; @@ -380,7 +386,7 @@ describe("widget inputs", () => { // Check random control.value = "randomize"; filter.value = "/D/"; - for(let i = 0; i < 100; i++) { + for (let i = 0; i < 100; i++) { control["afterQueued"](); expect(value.value === "D" || value.value === "DD").toBeTruthy(); } @@ -392,4 +398,160 @@ describe("widget inputs", () => { control["afterQueued"](); expect(value.value).toBe("B"); }); + + describe("reroutes", () => { + async function checkOutput(graph, values) { + expect((await graph.toPrompt()).output).toStrictEqual({ + 1: { inputs: { ckpt_name: "model1.safetensors" }, class_type: "CheckpointLoaderSimple" }, + 2: { inputs: { text: "positive", clip: ["1", 1] }, class_type: "CLIPTextEncode" }, + 3: { inputs: { text: "negative", clip: ["1", 1] }, class_type: "CLIPTextEncode" }, + 4: { + inputs: { width: values.width ?? 512, height: values.height ?? 512, batch_size: values?.batch_size ?? 1 }, + class_type: "EmptyLatentImage", + }, + 5: { + inputs: { + seed: 0, + steps: 20, + cfg: 8, + sampler_name: "euler", + scheduler: values?.scheduler ?? "normal", + denoise: 1, + model: ["1", 0], + positive: ["2", 0], + negative: ["3", 0], + latent_image: ["4", 0], + }, + class_type: "KSampler", + }, + 6: { inputs: { samples: ["5", 0], vae: ["1", 2] }, class_type: "VAEDecode" }, + 7: { + inputs: { filename_prefix: values.filename_prefix ?? "ComfyUI", images: ["6", 0] }, + class_type: "SaveImage", + }, + }); + } + + async function waitForWidget(node) { + // widgets are created slightly after the graph is ready + // hard to find an exact hook to get these so just wait for them to be ready + for (let i = 0; i < 10; i++) { + await new Promise((r) => setTimeout(r, 10)); + if (node.widgets?.value) { + return; + } + } + } + + it("can connect primitive via a reroute path to a widget input", async () => { + const { ez, graph } = await start(); + const nodes = createDefaultWorkflow(ez, graph); + + nodes.empty.widgets.width.convertToInput(); + nodes.sampler.widgets.scheduler.convertToInput(); + nodes.save.widgets.filename_prefix.convertToInput(); + + let widthReroute = ez.Reroute(); + let schedulerReroute = ez.Reroute(); + let fileReroute = ez.Reroute(); + + let widthNext = widthReroute; + let schedulerNext = schedulerReroute; + let fileNext = fileReroute; + + for (let i = 0; i < 5; i++) { + let next = ez.Reroute(); + widthNext.outputs[0].connectTo(next.inputs[0]); + widthNext = next; + + next = ez.Reroute(); + schedulerNext.outputs[0].connectTo(next.inputs[0]); + schedulerNext = next; + + next = ez.Reroute(); + fileNext.outputs[0].connectTo(next.inputs[0]); + fileNext = next; + } + + widthNext.outputs[0].connectTo(nodes.empty.inputs.width); + schedulerNext.outputs[0].connectTo(nodes.sampler.inputs.scheduler); + fileNext.outputs[0].connectTo(nodes.save.inputs.filename_prefix); + + let widthPrimitive = ez.PrimitiveNode(); + let schedulerPrimitive = ez.PrimitiveNode(); + let filePrimitive = ez.PrimitiveNode(); + + widthPrimitive.outputs[0].connectTo(widthReroute.inputs[0]); + schedulerPrimitive.outputs[0].connectTo(schedulerReroute.inputs[0]); + filePrimitive.outputs[0].connectTo(fileReroute.inputs[0]); + expect(widthPrimitive.widgets.value.value).toBe(512); + widthPrimitive.widgets.value.value = 1024; + expect(schedulerPrimitive.widgets.value.value).toBe("normal"); + schedulerPrimitive.widgets.value.value = "simple"; + expect(filePrimitive.widgets.value.value).toBe("ComfyUI"); + filePrimitive.widgets.value.value = "ComfyTest"; + + await checkBeforeAndAfterReload(graph, async () => { + widthPrimitive = graph.find(widthPrimitive); + schedulerPrimitive = graph.find(schedulerPrimitive); + filePrimitive = graph.find(filePrimitive); + await waitForWidget(filePrimitive); + expect(widthPrimitive.widgets.length).toBe(2); + expect(schedulerPrimitive.widgets.length).toBe(3); + expect(filePrimitive.widgets.length).toBe(1); + + await checkOutput(graph, { + width: 1024, + scheduler: "simple", + filename_prefix: "ComfyTest", + }); + }); + }); + it("can connect primitive via a reroute path to multiple widget inputs", async () => { + const { ez, graph } = await start(); + const nodes = createDefaultWorkflow(ez, graph); + + nodes.empty.widgets.width.convertToInput(); + nodes.empty.widgets.height.convertToInput(); + nodes.empty.widgets.batch_size.convertToInput(); + + let reroute = ez.Reroute(); + let prevReroute = reroute; + for (let i = 0; i < 5; i++) { + const next = ez.Reroute(); + prevReroute.outputs[0].connectTo(next.inputs[0]); + prevReroute = next; + } + + const r1 = ez.Reroute(prevReroute.outputs[0]); + const r2 = ez.Reroute(prevReroute.outputs[0]); + const r3 = ez.Reroute(r2.outputs[0]); + const r4 = ez.Reroute(r2.outputs[0]); + + r1.outputs[0].connectTo(nodes.empty.inputs.width); + r3.outputs[0].connectTo(nodes.empty.inputs.height); + r4.outputs[0].connectTo(nodes.empty.inputs.batch_size); + + let primitive = ez.PrimitiveNode(); + primitive.outputs[0].connectTo(reroute.inputs[0]); + expect(primitive.widgets.value.value).toBe(1); + primitive.widgets.value.value = 64; + + await checkBeforeAndAfterReload(graph, async (r) => { + primitive = graph.find(primitive); + await waitForWidget(primitive); + + // Ensure widget configs are merged + expect(primitive.widgets.value.widget.options?.min).toBe(16); // width/height min + expect(primitive.widgets.value.widget.options?.max).toBe(4096); // batch max + expect(primitive.widgets.value.widget.options?.step).toBe(80); // width/height step * 10 + + await checkOutput(graph, { + width: 64, + height: 64, + batch_size: 64, + }); + }); + }); + }); }); diff --git a/tests-ui/utils/ezgraph.js b/tests-ui/utils/ezgraph.js index 898b82db..8a55246e 100644 --- a/tests-ui/utils/ezgraph.js +++ b/tests-ui/utils/ezgraph.js @@ -78,6 +78,14 @@ export class EzInput extends EzSlot { this.input = input; } + get connection() { + const link = this.node.node.inputs?.[this.index]?.link; + if (link == null) { + return null; + } + return new EzConnection(this.node.app, this.node.app.graph.links[link]); + } + disconnect() { this.node.node.disconnectInput(this.index); } @@ -117,7 +125,7 @@ export class EzOutput extends EzSlot { const inp = input.input; const inName = inp.name || inp.label || inp.type; throw new Error( - `Connecting from ${input.node.node.type}[${inName}#${input.index}] -> ${this.node.node.type}[${ + `Connecting from ${input.node.node.type}#${input.node.id}[${inName}#${input.index}] -> ${this.node.node.type}#${this.node.id}[${ this.output.name ?? this.output.type }#${this.index}] failed.` ); @@ -179,6 +187,7 @@ export class EzWidget { set value(v) { this.widget.value = v; + this.widget.callback?.call?.(this.widget, v) } get isConvertedToInput() { @@ -319,7 +328,7 @@ export class EzGraph { } stringify() { - return JSON.stringify(this.app.graph.serialize(), undefined, "\t"); + return JSON.stringify(this.app.graph.serialize(), undefined); } /** diff --git a/tests-ui/utils/index.js b/tests-ui/utils/index.js index 3a018f56..74b6cf93 100644 --- a/tests-ui/utils/index.js +++ b/tests-ui/utils/index.js @@ -1,10 +1,18 @@ const { mockApi } = require("./setup"); const { Ez } = require("./ezgraph"); const lg = require("./litegraph"); +const fs = require("fs"); +const path = require("path"); + +const html = fs.readFileSync(path.resolve(__dirname, "../../web/index.html")) /** * - * @param { Parameters[0] & { resetEnv?: boolean, preSetup?(app): Promise } } config + * @param { Parameters[0] & { + * resetEnv?: boolean, + * preSetup?(app): Promise, + * localStorage?: Record + * } } config * @returns */ export async function start(config = {}) { @@ -12,12 +20,18 @@ export async function start(config = {}) { jest.resetModules(); jest.resetAllMocks(); lg.setup(global); + localStorage.clear(); + sessionStorage.clear(); } + Object.assign(localStorage, config.localStorage ?? {}); + document.body.innerHTML = html; + mockApi(config); const { app } = require("../../web/scripts/app"); config.preSetup?.(app); await app.setup(); + return { ...Ez.graph(app, global["LiteGraph"], global["LGraphCanvas"]), app }; } @@ -104,3 +118,12 @@ export function createDefaultWorkflow(ez, graph) { return { ckpt, pos, neg, empty, sampler, decode, save }; } + +export async function getNodeDefs() { + const { api } = require("../../web/scripts/api"); + return api.getNodeDefs(); +} + +export async function getNodeDef(nodeId) { + return (await getNodeDefs())[nodeId]; +} \ No newline at end of file diff --git a/tests-ui/utils/setup.js b/tests-ui/utils/setup.js index dd150214..e4625894 100644 --- a/tests-ui/utils/setup.js +++ b/tests-ui/utils/setup.js @@ -18,9 +18,21 @@ function* walkSync(dir) { */ /** - * @param { { mockExtensions?: string[], mockNodeDefs?: Record } } config + * @param {{ + * mockExtensions?: string[], + * mockNodeDefs?: Record, +* settings?: Record +* userConfig?: {storage: "server" | "browser", users?: Record, migrated?: boolean }, +* userData?: Record + * }} config */ -export function mockApi({ mockExtensions, mockNodeDefs } = {}) { +export function mockApi(config = {}) { + let { mockExtensions, mockNodeDefs, userConfig, settings, userData } = { + userConfig, + settings: {}, + userData: {}, + ...config, + }; if (!mockExtensions) { mockExtensions = Array.from(walkSync(path.resolve("../web/extensions/core"))) .filter((x) => x.endsWith(".js")) @@ -40,6 +52,26 @@ export function mockApi({ mockExtensions, mockNodeDefs } = {}) { getNodeDefs: jest.fn(() => mockNodeDefs), init: jest.fn(), apiURL: jest.fn((x) => "../../web/" + x), + createUser: jest.fn((username) => { + if(username in userConfig.users) { + return { status: 400, json: () => "Duplicate" } + } + userConfig.users[username + "!"] = username; + return { status: 200, json: () => username + "!" } + }), + getUserConfig: jest.fn(() => userConfig ?? { storage: "browser", migrated: false }), + getSettings: jest.fn(() => settings), + storeSettings: jest.fn((v) => Object.assign(settings, v)), + getUserData: jest.fn((f) => { + if (f in userData) { + return { status: 200, json: () => userData[f] }; + } else { + return { status: 404 }; + } + }), + storeUserData: jest.fn((file, data) => { + userData[file] = data; + }), }; jest.mock("../../web/scripts/api", () => ({ get api() { diff --git a/web/extensions/core/groupNode.js b/web/extensions/core/groupNode.js index 4b4bf74f..0b0763d1 100644 --- a/web/extensions/core/groupNode.js +++ b/web/extensions/core/groupNode.js @@ -1,6 +1,7 @@ import { app } from "../../scripts/app.js"; import { api } from "../../scripts/api.js"; import { mergeIfValid } from "./widgetInputs.js"; +import { ManageGroupDialog } from "./groupNodeManage.js"; const GROUP = Symbol(); @@ -61,11 +62,7 @@ class GroupNodeBuilder { ); return; case Workflow.InUse.Registered: - if ( - !confirm( - "An group node with this name already exists embedded in this workflow, are you sure you want to overwrite it?" - ) - ) { + if (!confirm("A group node with this name already exists embedded in this workflow, are you sure you want to overwrite it?")) { return; } break; @@ -151,6 +148,8 @@ export class GroupNodeConfig { this.primitiveDefs = {}; this.widgetToPrimitive = {}; this.primitiveToWidget = {}; + this.nodeInputs = {}; + this.outputVisibility = []; } async registerType(source = "workflow") { @@ -158,6 +157,7 @@ export class GroupNodeConfig { output: [], output_name: [], output_is_list: [], + output_is_hidden: [], name: source + "/" + this.name, display_name: this.name, category: "group nodes" + ("/" + source), @@ -174,6 +174,11 @@ export class GroupNodeConfig { node.index = i; this.processNode(node, seenInputs, seenOutputs); } + + for (const p of this.#convertedToProcess) { + p(); + } + this.#convertedToProcess = null; await app.registerNodeDef("workflow/" + this.name, this.nodeDef); } @@ -192,7 +197,10 @@ export class GroupNodeConfig { if (!this.linksFrom[sourceNodeId]) { this.linksFrom[sourceNodeId] = {}; } - this.linksFrom[sourceNodeId][sourceNodeSlot] = l; + if (!this.linksFrom[sourceNodeId][sourceNodeSlot]) { + this.linksFrom[sourceNodeId][sourceNodeSlot] = []; + } + this.linksFrom[sourceNodeId][sourceNodeSlot].push(l); if (!this.linksTo[targetNodeId]) { this.linksTo[targetNodeId] = {}; @@ -230,11 +238,11 @@ export class GroupNodeConfig { // Skip as its not linked if (!linksFrom) return; - let type = linksFrom["0"][5]; + let type = linksFrom["0"][0][5]; if (type === "COMBO") { // Use the array items const source = node.outputs[0].widget.name; - const fromTypeName = this.nodeData.nodes[linksFrom["0"][2]].type; + const fromTypeName = this.nodeData.nodes[linksFrom["0"][0][2]].type; const fromType = globalDefs[fromTypeName]; const input = fromType.input.required[source] ?? fromType.input.optional[source]; type = input[0]; @@ -258,10 +266,32 @@ export class GroupNodeConfig { return null; } + let config = {}; let rerouteType = "*"; if (linksFrom) { - const [, , id, slot] = linksFrom["0"]; - rerouteType = this.nodeData.nodes[id].inputs[slot].type; + for (const [, , id, slot] of linksFrom["0"]) { + const node = this.nodeData.nodes[id]; + const input = node.inputs[slot]; + if (rerouteType === "*") { + rerouteType = input.type; + } + if (input.widget) { + const targetDef = globalDefs[node.type]; + const targetWidget = targetDef.input.required[input.widget.name] ?? targetDef.input.optional[input.widget.name]; + + const widget = [targetWidget[0], config]; + const res = mergeIfValid( + { + widget, + }, + targetWidget, + false, + null, + widget + ); + config = res?.customConfig ?? config; + } + } } else if (linksTo) { const [id, slot] = linksTo["0"]; rerouteType = this.nodeData.nodes[id].outputs[slot].type; @@ -282,10 +312,11 @@ export class GroupNodeConfig { } } + config.forceInput = true; return { input: { required: { - [rerouteType]: [rerouteType, {}], + [rerouteType]: [rerouteType, config], }, }, output: [rerouteType], @@ -298,17 +329,19 @@ export class GroupNodeConfig { } getInputConfig(node, inputName, seenInputs, config, extra) { - let name = node.inputs?.find((inp) => inp.name === inputName)?.label ?? inputName; + const customConfig = this.nodeData.config?.[node.index]?.input?.[inputName]; + let name = customConfig?.name ?? node.inputs?.find((inp) => inp.name === inputName)?.label ?? inputName; + let key = name; let prefix = ""; // Special handling for primitive to include the title if it is set rather than just "value" if ((node.type === "PrimitiveNode" && node.title) || name in seenInputs) { prefix = `${node.title ?? node.type} `; - name = `${prefix}${inputName}`; + key = name = `${prefix}${inputName}`; if (name in seenInputs) { name = `${prefix}${seenInputs[name]} ${inputName}`; } } - seenInputs[name] = (seenInputs[name] ?? 1) + 1; + seenInputs[key] = (seenInputs[key] ?? 1) + 1; if (inputName === "seed" || inputName === "noise_seed") { if (!extra) extra = {}; @@ -316,14 +349,14 @@ export class GroupNodeConfig { } if (config[0] === "IMAGEUPLOAD") { if (!extra) extra = {}; - extra.widget = `${prefix}${config[1]?.widget ?? "image"}`; + extra.widget = this.oldToNewWidgetMap[node.index]?.[config[1]?.widget ?? "image"] ?? "image"; } if (extra) { config = [config[0], { ...config[1], ...extra }]; } - return { name, config }; + return { name, config, customConfig }; } processWidgetInputs(inputs, node, inputNames, seenInputs) { @@ -333,9 +366,7 @@ export class GroupNodeConfig { for (const inputName of inputNames) { let widgetType = app.getWidgetType(inputs[inputName], inputName); if (widgetType) { - const convertedIndex = node.inputs?.findIndex( - (inp) => inp.name === inputName && inp.widget?.name === inputName - ); + const convertedIndex = node.inputs?.findIndex((inp) => inp.name === inputName && inp.widget?.name === inputName); if (convertedIndex > -1) { // This widget has been converted to a widget // We need to store this in the correct position so link ids line up @@ -391,6 +422,7 @@ export class GroupNodeConfig { } processInputSlots(inputs, node, slots, linksTo, inputMap, seenInputs) { + this.nodeInputs[node.index] = {}; for (let i = 0; i < slots.length; i++) { const inputName = slots[i]; if (linksTo[i]) { @@ -399,7 +431,11 @@ export class GroupNodeConfig { continue; } - const { name, config } = this.getInputConfig(node, inputName, seenInputs, inputs[inputName]); + const { name, config, customConfig } = this.getInputConfig(node, inputName, seenInputs, inputs[inputName]); + + this.nodeInputs[node.index][inputName] = name; + if(customConfig?.visible === false) continue; + this.nodeDef.input.required[name] = config; inputMap[i] = this.inputCount++; } @@ -419,11 +455,20 @@ export class GroupNodeConfig { const { name, config } = this.getInputConfig(node, inputName, seenInputs, inputs[inputName], { defaultInput: true, }); + this.nodeDef.input.required[name] = config; + this.newToOldWidgetMap[name] = { node, inputName }; + + if (!this.oldToNewWidgetMap[node.index]) { + this.oldToNewWidgetMap[node.index] = {}; + } + this.oldToNewWidgetMap[node.index][inputName] = name; + inputMap[slots.length + i] = this.inputCount++; } } + #convertedToProcess = []; processNodeInputs(node, seenInputs, inputs) { const inputMapping = []; @@ -434,7 +479,9 @@ export class GroupNodeConfig { const linksTo = this.linksTo[node.index] ?? {}; const inputMap = (this.oldToNewInputMap[node.index] = {}); this.processInputSlots(inputs, node, slots, linksTo, inputMap, seenInputs); - this.processConvertedWidgets(inputs, node, slots, converted, linksTo, inputMap, seenInputs); + + // Converted inputs have to be processed after all other nodes as they'll be at the end of the list + this.#convertedToProcess.push(() => this.processConvertedWidgets(inputs, node, slots, converted, linksTo, inputMap, seenInputs)); return inputMapping; } @@ -445,8 +492,12 @@ export class GroupNodeConfig { // Add outputs for (let outputId = 0; outputId < def.output.length; outputId++) { const linksFrom = this.linksFrom[node.index]; - if (linksFrom?.[outputId] && !this.externalFrom[node.index]?.[outputId]) { - // This output is linked internally so we can skip it + // If this output is linked internally we flag it to hide + const hasLink = linksFrom?.[outputId] && !this.externalFrom[node.index]?.[outputId]; + const customConfig = this.nodeData.config?.[node.index]?.output?.[outputId]; + const visible = customConfig?.visible ?? !hasLink; + this.outputVisibility.push(visible); + if (!visible) { continue; } @@ -455,11 +506,15 @@ export class GroupNodeConfig { this.nodeDef.output.push(def.output[outputId]); this.nodeDef.output_is_list.push(def.output_is_list[outputId]); - let label = def.output_name?.[outputId] ?? def.output[outputId]; - const output = node.outputs.find((o) => o.name === label); - if (output?.label) { - label = output.label; + let label = customConfig?.name; + if (!label) { + label = def.output_name?.[outputId] ?? def.output[outputId]; + const output = node.outputs.find((o) => o.name === label); + if (output?.label) { + label = output.label; + } } + let name = label; if (name in seenOutputs) { const prefix = `${node.title ?? node.type} `; @@ -475,6 +530,16 @@ export class GroupNodeConfig { } static async registerFromWorkflow(groupNodes, missingNodeTypes) { + const clean = app.clean; + app.clean = function () { + for (const g in groupNodes) { + try { + LiteGraph.unregisterNodeType("workflow/" + g); + } catch (error) {} + } + app.clean = clean; + }; + for (const g in groupNodes) { const groupData = groupNodes[g]; @@ -482,7 +547,24 @@ export class GroupNodeConfig { for (const n of groupData.nodes) { // Find missing node types if (!(n.type in LiteGraph.registered_node_types)) { - missingNodeTypes.push(n.type); + missingNodeTypes.push({ + type: n.type, + hint: ` (In group node 'workflow/${g}')`, + }); + + missingNodeTypes.push({ + type: "workflow/" + g, + action: { + text: "Remove from workflow", + callback: (e) => { + delete groupNodes[g]; + e.target.textContent = "Removed"; + e.target.style.pointerEvents = "none"; + e.target.style.opacity = 0.7; + }, + }, + }); + hasMissing = true; } } @@ -570,11 +652,19 @@ export class GroupNodeHandler { const output = this.groupData.newToOldOutputMap[link.origin_slot]; let innerNode = this.innerNodes[output.node.index]; let l; - while (innerNode.type === "Reroute") { + while (innerNode?.type === "Reroute") { l = innerNode.getInputLink(0); innerNode = innerNode.getInputNode(0); } + if (!innerNode) { + return null; + } + + if (l && GroupNodeHandler.isGroupNode(innerNode)) { + return innerNode.updateLink(l); + } + link.origin_id = innerNode.id; link.origin_slot = l?.origin_slot ?? output.slot; return link; @@ -597,6 +687,25 @@ export class GroupNodeHandler { return this.innerNodes; }; + this.node.recreate = async () => { + const id = this.node.id; + const sz = this.node.size; + const nodes = this.node.convertToNodes(); + + const groupNode = LiteGraph.createNode(this.node.type); + groupNode.id = id; + + // Reuse the existing nodes for this instance + groupNode.setInnerNodes(nodes); + groupNode[GROUP].populateWidgets(); + app.graph.add(groupNode); + groupNode.size = [Math.max(groupNode.size[0], sz[0]), Math.max(groupNode.size[1], sz[1])]; + + // Remove all converted nodes and relink them + groupNode[GROUP].replaceNodes(nodes); + return groupNode; + }; + this.node.convertToNodes = () => { const addInnerNodes = () => { const backup = localStorage.getItem("litegrapheditor_clipboard"); @@ -638,6 +747,8 @@ export class GroupNodeHandler { top = newNode.pos[1]; } + if (!newNode.widgets) continue; + const map = this.groupData.oldToNewWidgetMap[innerNode.index]; if (map) { const widgets = Object.keys(map); @@ -687,6 +798,7 @@ export class GroupNodeHandler { const slot = node.inputs[groupSlotId]; if (slot.link == null) continue; const link = app.graph.links[slot.link]; + if (!link) continue; // connect this node output to the input of another node const originNode = app.graph.getNodeById(link.origin_id); originNode.connect(link.origin_slot, newNode, +innerInputId); @@ -694,7 +806,7 @@ export class GroupNodeHandler { } }; - const reconnectOutputs = () => { + const reconnectOutputs = (selectedIds) => { for (let groupOutputId = 0; groupOutputId < node.outputs?.length; groupOutputId++) { const output = node.outputs[groupOutputId]; if (!output.links) continue; @@ -724,12 +836,23 @@ export class GroupNodeHandler { let optionIndex = options.findIndex((o) => o.content === "Outputs"); if (optionIndex === -1) optionIndex = options.length; else optionIndex++; - options.splice(optionIndex, 0, null, { - content: "Convert to nodes", - callback: () => { - return this.convertToNodes(); + options.splice( + optionIndex, + 0, + null, + { + content: "Convert to nodes", + callback: () => { + return this.convertToNodes(); + }, }, - }); + { + content: "Manage Group Node", + callback: () => { + new ManageGroupDialog(app).show(this.type); + }, + } + ); }; // Draw custom collapse icon to identity this as a group @@ -761,6 +884,7 @@ export class GroupNodeHandler { const r = onDrawForeground?.apply?.(this, arguments); if (+app.runningNodeId === this.id && this.runningInternalNodeId !== null) { const n = groupData.nodes[this.runningInternalNodeId]; + if(!n) return; const message = `Running ${n.title || n.type} (${this.runningInternalNodeId}/${groupData.nodes.length})`; ctx.save(); ctx.font = "12px sans-serif"; @@ -783,6 +907,31 @@ export class GroupNodeHandler { return onExecutionStart?.apply(this, arguments); }; + const self = this; + const onNodeCreated = this.node.onNodeCreated; + this.node.onNodeCreated = function () { + if (!this.widgets) { + return; + } + const config = self.groupData.nodeData.config; + if (config) { + for (const n in config) { + const inputs = config[n]?.input; + for (const w in inputs) { + if (inputs[w].visible !== false) continue; + const widgetName = self.groupData.oldToNewWidgetMap[n][w]; + const widget = this.widgets.find((w) => w.name === widgetName); + if (widget) { + widget.type = "hidden"; + widget.computeSize = () => [0, -4]; + } + } + } + } + + return onNodeCreated?.apply(this, arguments); + }; + function handleEvent(type, getId, getEvent) { const handler = ({ detail }) => { const id = getId(detail); @@ -820,6 +969,26 @@ export class GroupNodeHandler { api.removeEventListener("executing", executing); api.removeEventListener("executed", executed); }; + + this.node.refreshComboInNode = (defs) => { + // Update combo widget options + for (const widgetName in this.groupData.newToOldWidgetMap) { + const widget = this.node.widgets.find((w) => w.name === widgetName); + if (widget?.type === "combo") { + const old = this.groupData.newToOldWidgetMap[widgetName]; + const def = defs[old.node.type]; + const input = def?.input?.required?.[old.inputName] ?? def?.input?.optional?.[old.inputName]; + if (!input) continue; + + widget.options.values = input[0]; + + if (old.inputName !== "image" && !widget.options.values.includes(widget.value)) { + widget.value = widget.options.values[0]; + widget.callback(widget.value); + } + } + } + }; } updateInnerWidgets() { @@ -834,7 +1003,7 @@ export class GroupNodeHandler { if (innerNode.type === "PrimitiveNode") { innerNode.primitiveValue = newValue; const primitiveLinked = this.groupData.primitiveToWidget[old.node.index]; - for (const linked of primitiveLinked) { + for (const linked of primitiveLinked ?? []) { const node = this.innerNodes[linked.nodeId]; const widget = node.widgets.find((w) => w.name === linked.inputName); @@ -843,6 +1012,20 @@ export class GroupNodeHandler { } } continue; + } else if (innerNode.type === "Reroute") { + const rerouteLinks = this.groupData.linksFrom[old.node.index]; + if (rerouteLinks) { + for (const [_, , targetNodeId, targetSlot] of rerouteLinks["0"]) { + const node = this.innerNodes[targetNodeId]; + const input = node.inputs[targetSlot]; + if (input.widget) { + const widget = node.widgets?.find((w) => w.name === input.widget.name); + if (widget) { + widget.value = newValue; + } + } + } + } } const widget = innerNode.widgets?.find((w) => w.name === old.inputName); @@ -870,33 +1053,57 @@ export class GroupNodeHandler { this.node.widgets[targetWidgetIndex + i].value = primitiveNode.widgets[i].value; } } + return true; + } + + populateReroute(node, nodeId, map) { + if (node.type !== "Reroute") return; + + const link = this.groupData.linksFrom[nodeId]?.[0]?.[0]; + if (!link) return; + const [, , targetNodeId, targetNodeSlot] = link; + const targetNode = this.groupData.nodeData.nodes[targetNodeId]; + const inputs = targetNode.inputs; + const targetWidget = inputs?.[targetNodeSlot]?.widget; + if (!targetWidget) return; + + const offset = inputs.length - (targetNode.widgets_values?.length ?? 0); + const v = targetNode.widgets_values?.[targetNodeSlot - offset]; + if (v == null) return; + + const widgetName = Object.values(map)[0]; + const widget = this.node.widgets.find((w) => w.name === widgetName); + if (widget) { + widget.value = v; + } } populateWidgets() { + if (!this.node.widgets) return; + for (let nodeId = 0; nodeId < this.groupData.nodeData.nodes.length; nodeId++) { const node = this.groupData.nodeData.nodes[nodeId]; - - if (!node.widgets_values?.length) continue; - - const map = this.groupData.oldToNewWidgetMap[nodeId]; + const map = this.groupData.oldToNewWidgetMap[nodeId] ?? {}; const widgets = Object.keys(map); + if (!node.widgets_values?.length) { + // special handling for populating values into reroutes + // this allows primitives connect to them to pick up the correct value + this.populateReroute(node, nodeId, map); + continue; + } + let linkedShift = 0; for (let i = 0; i < widgets.length; i++) { const oldName = widgets[i]; const newName = map[oldName]; const widgetIndex = this.node.widgets.findIndex((w) => w.name === newName); const mainWidget = this.node.widgets[widgetIndex]; - if (!newName) { - // New name will be null if its a converted widget - this.populatePrimitive(node, nodeId, oldName, i, linkedShift); - + if (this.populatePrimitive(node, nodeId, oldName, i, linkedShift) || widgetIndex === -1) { // Find the inner widget and shift by the number of linked widgets as they will have been removed too const innerWidget = this.innerNodes[nodeId].widgets?.find((w) => w.name === oldName); - linkedShift += innerWidget.linkedWidgets?.length ?? 0; - continue; + linkedShift += innerWidget?.linkedWidgets?.length ?? 0; } - if (widgetIndex === -1) { continue; } @@ -961,7 +1168,7 @@ export class GroupNodeHandler { } static getGroupData(node) { - return node.constructor?.nodeData?.[GROUP]; + return (node.nodeData ?? node.constructor?.nodeData)?.[GROUP]; } static isGroupNode(node) { @@ -993,7 +1200,7 @@ export class GroupNodeHandler { } function addConvertToGroupOptions() { - function addOption(options, index) { + function addConvertOption(options, index) { const selected = Object.values(app.canvas.selected_nodes ?? {}); const disabled = selected.length < 2 || selected.find((n) => GroupNodeHandler.isGroupNode(n)); options.splice(index + 1, null, { @@ -1005,12 +1212,25 @@ function addConvertToGroupOptions() { }); } + function addManageOption(options, index) { + const groups = app.graph.extra?.groupNodes; + const disabled = !groups || !Object.keys(groups).length; + options.splice(index + 1, null, { + content: `Manage Group Nodes`, + disabled, + callback: () => { + new ManageGroupDialog(app).show(); + }, + }); + } + // Add to canvas const getCanvasMenuOptions = LGraphCanvas.prototype.getCanvasMenuOptions; LGraphCanvas.prototype.getCanvasMenuOptions = function () { const options = getCanvasMenuOptions.apply(this, arguments); const index = options.findIndex((o) => o?.content === "Add Group") + 1 || options.length; - addOption(options, index); + addConvertOption(options, index); + addManageOption(options, index + 1); return options; }; @@ -1020,7 +1240,7 @@ function addConvertToGroupOptions() { const options = getNodeMenuOptions.apply(this, arguments); if (!GroupNodeHandler.isGroupNode(node)) { const index = options.findIndex((o) => o?.content === "Outputs") + 1 || options.length - 1; - addOption(options, index); + addConvertOption(options, index); } return options; }; @@ -1048,6 +1268,14 @@ const ext = { node[GROUP] = new GroupNodeHandler(node); } }, + async refreshComboInNodes(defs) { + // Re-register group nodes so new ones are created with the correct options + Object.assign(globalDefs, defs); + const nodes = app.graph.extra?.groupNodes; + if (nodes) { + await GroupNodeConfig.registerFromWorkflow(nodes, {}); + } + } }; app.registerExtension(ext); diff --git a/web/extensions/core/groupNodeManage.css b/web/extensions/core/groupNodeManage.css new file mode 100644 index 00000000..5470ecb5 --- /dev/null +++ b/web/extensions/core/groupNodeManage.css @@ -0,0 +1,149 @@ +.comfy-group-manage { + background: var(--bg-color); + color: var(--fg-color); + padding: 0; + font-family: Arial, Helvetica, sans-serif; + border-color: black; + margin: 20vh auto; + max-height: 60vh; +} +.comfy-group-manage-outer { + max-height: 60vh; + min-width: 500px; + display: flex; + flex-direction: column; +} +.comfy-group-manage-outer > header { + display: flex; + align-items: center; + gap: 10px; + justify-content: space-between; + background: var(--comfy-menu-bg); + padding: 15px 20px; +} +.comfy-group-manage-outer > header select { + background: var(--comfy-input-bg); + border: 1px solid var(--border-color); + color: var(--input-text); + padding: 5px 10px; + border-radius: 5px; +} +.comfy-group-manage h2 { + margin: 0; + font-weight: normal; +} +.comfy-group-manage main { + display: flex; + overflow: hidden; +} +.comfy-group-manage .drag-handle { + font-weight: bold; +} +.comfy-group-manage-list { + border-right: 1px solid var(--comfy-menu-bg); +} +.comfy-group-manage-list ul { + margin: 40px 0 0; + padding: 0; + list-style: none; +} +.comfy-group-manage-list-items { + max-height: calc(100% - 40px); + overflow-y: scroll; + overflow-x: hidden; +} +.comfy-group-manage-list li { + display: flex; + padding: 10px 20px 10px 10px; + cursor: pointer; + align-items: center; + gap: 5px; +} +.comfy-group-manage-list div { + display: flex; + flex-direction: column; +} +.comfy-group-manage-list li:not(.selected):hover div { + text-decoration: underline; +} +.comfy-group-manage-list li.selected { + background: var(--border-color); +} +.comfy-group-manage-list li span { + opacity: 0.7; + font-size: smaller; +} +.comfy-group-manage-node { + flex: auto; + background: var(--border-color); + display: flex; + flex-direction: column; +} +.comfy-group-manage-node > div { + overflow: auto; +} +.comfy-group-manage-node header { + display: flex; + background: var(--bg-color); + height: 40px; +} +.comfy-group-manage-node header a { + text-align: center; + flex: auto; + border-right: 1px solid var(--comfy-menu-bg); + border-bottom: 1px solid var(--comfy-menu-bg); + padding: 10px; + cursor: pointer; + font-size: 15px; +} +.comfy-group-manage-node header a:last-child { + border-right: none; +} +.comfy-group-manage-node header a:not(.active):hover { + text-decoration: underline; +} +.comfy-group-manage-node header a.active { + background: var(--border-color); + border-bottom: none; +} +.comfy-group-manage-node-page { + display: none; + overflow: auto; +} +.comfy-group-manage-node-page.active { + display: block; +} +.comfy-group-manage-node-page div { + padding: 10px; + display: flex; + align-items: center; + gap: 10px; +} +.comfy-group-manage-node-page input { + border: none; + color: var(--input-text); + background: var(--comfy-input-bg); + padding: 5px 10px; +} +.comfy-group-manage-node-page input[type="text"] { + flex: auto; +} +.comfy-group-manage-node-page label { + display: flex; + gap: 5px; + align-items: center; +} +.comfy-group-manage footer { + border-top: 1px solid var(--comfy-menu-bg); + padding: 10px; + display: flex; + gap: 10px; +} +.comfy-group-manage footer button { + font-size: 14px; + padding: 5px 10px; + border-radius: 0; +} +.comfy-group-manage footer button:first-child { + margin-right: auto; +} diff --git a/web/extensions/core/groupNodeManage.js b/web/extensions/core/groupNodeManage.js new file mode 100644 index 00000000..1ab33838 --- /dev/null +++ b/web/extensions/core/groupNodeManage.js @@ -0,0 +1,422 @@ +import { $el, ComfyDialog } from "../../scripts/ui.js"; +import { DraggableList } from "../../scripts/ui/draggableList.js"; +import { addStylesheet } from "../../scripts/utils.js"; +import { GroupNodeConfig, GroupNodeHandler } from "./groupNode.js"; + +addStylesheet(import.meta.url); + +const ORDER = Symbol(); + +function merge(target, source) { + if (typeof target === "object" && typeof source === "object") { + for (const key in source) { + const sv = source[key]; + if (typeof sv === "object") { + let tv = target[key]; + if (!tv) tv = target[key] = {}; + merge(tv, source[key]); + } else { + target[key] = sv; + } + } + } + + return target; +} + +export class ManageGroupDialog extends ComfyDialog { + /** @type { Record<"Inputs" | "Outputs" | "Widgets", {tab: HTMLAnchorElement, page: HTMLElement}> } */ + tabs = {}; + /** @type { number | null | undefined } */ + selectedNodeIndex; + /** @type { keyof ManageGroupDialog["tabs"] } */ + selectedTab = "Inputs"; + /** @type { string | undefined } */ + selectedGroup; + + /** @type { Record>> } */ + modifications = {}; + + get selectedNodeInnerIndex() { + return +this.nodeItems[this.selectedNodeIndex].dataset.nodeindex; + } + + constructor(app) { + super(); + this.app = app; + this.element = $el("dialog.comfy-group-manage", { + parent: document.body, + }); + } + + changeTab(tab) { + this.tabs[this.selectedTab].tab.classList.remove("active"); + this.tabs[this.selectedTab].page.classList.remove("active"); + this.tabs[tab].tab.classList.add("active"); + this.tabs[tab].page.classList.add("active"); + this.selectedTab = tab; + } + + changeNode(index, force) { + if (!force && this.selectedNodeIndex === index) return; + + if (this.selectedNodeIndex != null) { + this.nodeItems[this.selectedNodeIndex].classList.remove("selected"); + } + this.nodeItems[index].classList.add("selected"); + this.selectedNodeIndex = index; + + if (!this.buildInputsPage() && this.selectedTab === "Inputs") { + this.changeTab("Widgets"); + } + if (!this.buildWidgetsPage() && this.selectedTab === "Widgets") { + this.changeTab("Outputs"); + } + if (!this.buildOutputsPage() && this.selectedTab === "Outputs") { + this.changeTab("Inputs"); + } + + this.changeTab(this.selectedTab); + } + + getGroupData() { + this.groupNodeType = LiteGraph.registered_node_types["workflow/" + this.selectedGroup]; + this.groupNodeDef = this.groupNodeType.nodeData; + this.groupData = GroupNodeHandler.getGroupData(this.groupNodeType); + } + + changeGroup(group, reset = true) { + this.selectedGroup = group; + this.getGroupData(); + + const nodes = this.groupData.nodeData.nodes; + this.nodeItems = nodes.map((n, i) => + $el( + "li.draggable-item", + { + dataset: { + nodeindex: n.index + "", + }, + onclick: () => { + this.changeNode(i); + }, + }, + [ + $el("span.drag-handle"), + $el( + "div", + { + textContent: n.title ?? n.type, + }, + n.title + ? $el("span", { + textContent: n.type, + }) + : [] + ), + ] + ) + ); + + this.innerNodesList.replaceChildren(...this.nodeItems); + + if (reset) { + this.selectedNodeIndex = null; + this.changeNode(0); + } else { + const items = this.draggable.getAllItems(); + let index = items.findIndex(item => item.classList.contains("selected")); + if(index === -1) index = this.selectedNodeIndex; + this.changeNode(index, true); + } + + const ordered = [...nodes]; + this.draggable?.dispose(); + this.draggable = new DraggableList(this.innerNodesList, "li"); + this.draggable.addEventListener("dragend", ({ detail: { oldPosition, newPosition } }) => { + if (oldPosition === newPosition) return; + ordered.splice(newPosition, 0, ordered.splice(oldPosition, 1)[0]); + for (let i = 0; i < ordered.length; i++) { + this.storeModification({ nodeIndex: ordered[i].index, section: ORDER, prop: "order", value: i }); + } + }); + } + + storeModification({ nodeIndex, section, prop, value }) { + const groupMod = (this.modifications[this.selectedGroup] ??= {}); + const nodesMod = (groupMod.nodes ??= {}); + const nodeMod = (nodesMod[nodeIndex ?? this.selectedNodeInnerIndex] ??= {}); + const typeMod = (nodeMod[section] ??= {}); + if (typeof value === "object") { + const objMod = (typeMod[prop] ??= {}); + Object.assign(objMod, value); + } else { + typeMod[prop] = value; + } + } + + getEditElement(section, prop, value, placeholder, checked, checkable = true) { + if (value === placeholder) value = ""; + + const mods = this.modifications[this.selectedGroup]?.nodes?.[this.selectedNodeInnerIndex]?.[section]?.[prop]; + if (mods) { + if (mods.name != null) { + value = mods.name; + } + if (mods.visible != null) { + checked = mods.visible; + } + } + + return $el("div", [ + $el("input", { + value, + placeholder, + type: "text", + onchange: (e) => { + this.storeModification({ section, prop, value: { name: e.target.value } }); + }, + }), + $el("label", { textContent: "Visible" }, [ + $el("input", { + type: "checkbox", + checked, + disabled: !checkable, + onchange: (e) => { + this.storeModification({ section, prop, value: { visible: !!e.target.checked } }); + }, + }), + ]), + ]); + } + + buildWidgetsPage() { + const widgets = this.groupData.oldToNewWidgetMap[this.selectedNodeInnerIndex]; + const items = Object.keys(widgets ?? {}); + const type = app.graph.extra.groupNodes[this.selectedGroup]; + const config = type.config?.[this.selectedNodeInnerIndex]?.input; + this.widgetsPage.replaceChildren( + ...items.map((oldName) => { + return this.getEditElement("input", oldName, widgets[oldName], oldName, config?.[oldName]?.visible !== false); + }) + ); + return !!items.length; + } + + buildInputsPage() { + const inputs = this.groupData.nodeInputs[this.selectedNodeInnerIndex]; + const items = Object.keys(inputs ?? {}); + const type = app.graph.extra.groupNodes[this.selectedGroup]; + const config = type.config?.[this.selectedNodeInnerIndex]?.input; + this.inputsPage.replaceChildren( + ...items + .map((oldName) => { + let value = inputs[oldName]; + if (!value) { + return; + } + + return this.getEditElement("input", oldName, value, oldName, config?.[oldName]?.visible !== false); + }) + .filter(Boolean) + ); + return !!items.length; + } + + buildOutputsPage() { + const nodes = this.groupData.nodeData.nodes; + const innerNodeDef = this.groupData.getNodeDef(nodes[this.selectedNodeInnerIndex]); + const outputs = innerNodeDef?.output ?? []; + const groupOutputs = this.groupData.oldToNewOutputMap[this.selectedNodeInnerIndex]; + + const type = app.graph.extra.groupNodes[this.selectedGroup]; + const config = type.config?.[this.selectedNodeInnerIndex]?.output; + const node = this.groupData.nodeData.nodes[this.selectedNodeInnerIndex]; + const checkable = node.type !== "PrimitiveNode"; + this.outputsPage.replaceChildren( + ...outputs + .map((type, slot) => { + const groupOutputIndex = groupOutputs?.[slot]; + const oldName = innerNodeDef.output_name?.[slot] ?? type; + let value = config?.[slot]?.name; + const visible = config?.[slot]?.visible || groupOutputIndex != null; + if (!value || value === oldName) { + value = ""; + } + return this.getEditElement("output", slot, value, oldName, visible, checkable); + }) + .filter(Boolean) + ); + return !!outputs.length; + } + + show(type) { + const groupNodes = Object.keys(app.graph.extra?.groupNodes ?? {}).sort((a, b) => a.localeCompare(b)); + + this.innerNodesList = $el("ul.comfy-group-manage-list-items"); + this.widgetsPage = $el("section.comfy-group-manage-node-page"); + this.inputsPage = $el("section.comfy-group-manage-node-page"); + this.outputsPage = $el("section.comfy-group-manage-node-page"); + const pages = $el("div", [this.widgetsPage, this.inputsPage, this.outputsPage]); + + this.tabs = [ + ["Inputs", this.inputsPage], + ["Widgets", this.widgetsPage], + ["Outputs", this.outputsPage], + ].reduce((p, [name, page]) => { + p[name] = { + tab: $el("a", { + onclick: () => { + this.changeTab(name); + }, + textContent: name, + }), + page, + }; + return p; + }, {}); + + const outer = $el("div.comfy-group-manage-outer", [ + $el("header", [ + $el("h2", "Group Nodes"), + $el( + "select", + { + onchange: (e) => { + this.changeGroup(e.target.value); + }, + }, + groupNodes.map((g) => + $el("option", { + textContent: g, + selected: "workflow/" + g === type, + value: g, + }) + ) + ), + ]), + $el("main", [ + $el("section.comfy-group-manage-list", this.innerNodesList), + $el("section.comfy-group-manage-node", [ + $el( + "header", + Object.values(this.tabs).map((t) => t.tab) + ), + pages, + ]), + ]), + $el("footer", [ + $el( + "button.comfy-btn", + { + onclick: (e) => { + const node = app.graph._nodes.find((n) => n.type === "workflow/" + this.selectedGroup); + if (node) { + alert("This group node is in use in the current workflow, please first remove these."); + return; + } + if (confirm(`Are you sure you want to remove the node: "${this.selectedGroup}"`)) { + delete app.graph.extra.groupNodes[this.selectedGroup]; + LiteGraph.unregisterNodeType("workflow/" + this.selectedGroup); + } + this.show(); + }, + }, + "Delete Group Node" + ), + $el( + "button.comfy-btn", + { + onclick: async () => { + let nodesByType; + let recreateNodes = []; + const types = {}; + for (const g in this.modifications) { + const type = app.graph.extra.groupNodes[g]; + let config = (type.config ??= {}); + + let nodeMods = this.modifications[g]?.nodes; + if (nodeMods) { + const keys = Object.keys(nodeMods); + if (nodeMods[keys[0]][ORDER]) { + // If any node is reordered, they will all need sequencing + const orderedNodes = []; + const orderedMods = {}; + const orderedConfig = {}; + + for (const n of keys) { + const order = nodeMods[n][ORDER].order; + orderedNodes[order] = type.nodes[+n]; + orderedMods[order] = nodeMods[n]; + orderedNodes[order].index = order; + } + + // Rewrite links + for (const l of type.links) { + if (l[0] != null) l[0] = type.nodes[l[0]].index; + if (l[2] != null) l[2] = type.nodes[l[2]].index; + } + + // Rewrite externals + if (type.external) { + for (const ext of type.external) { + ext[0] = type.nodes[ext[0]]; + } + } + + // Rewrite modifications + for (const id of keys) { + if (config[id]) { + orderedConfig[type.nodes[id].index] = config[id]; + } + delete config[id]; + } + + type.nodes = orderedNodes; + nodeMods = orderedMods; + type.config = config = orderedConfig; + } + + merge(config, nodeMods); + } + + types[g] = type; + + if (!nodesByType) { + nodesByType = app.graph._nodes.reduce((p, n) => { + p[n.type] ??= []; + p[n.type].push(n); + return p; + }, {}); + } + + const nodes = nodesByType["workflow/" + g]; + if (nodes) recreateNodes.push(...nodes); + } + + await GroupNodeConfig.registerFromWorkflow(types, {}); + + for (const node of recreateNodes) { + node.recreate(); + } + + this.modifications = {}; + this.app.graph.setDirtyCanvas(true, true); + this.changeGroup(this.selectedGroup, false); + }, + }, + "Save" + ), + $el("button.comfy-btn", { onclick: () => this.element.close() }, "Close"), + ]), + ]); + + this.element.replaceChildren(outer); + this.changeGroup(type ? groupNodes.find((g) => "workflow/" + g === type) : groupNodes[0]); + this.element.showModal(); + + this.element.addEventListener("close", () => { + this.draggable?.dispose(); + }); + } +} \ No newline at end of file diff --git a/web/extensions/core/maskeditor.js b/web/extensions/core/maskeditor.js index f6292b9e..4f69ac76 100644 --- a/web/extensions/core/maskeditor.js +++ b/web/extensions/core/maskeditor.js @@ -33,6 +33,18 @@ function loadedImageToBlob(image) { return blob; } +function loadImage(imagePath) { + return new Promise((resolve, reject) => { + const image = new Image(); + + image.onload = function() { + resolve(image); + }; + + image.src = imagePath; + }); +} + async function uploadMask(filepath, formData) { await api.fetchApi('/upload/mask', { method: 'POST', @@ -42,7 +54,7 @@ async function uploadMask(filepath, formData) { }); ComfyApp.clipspace.imgs[ComfyApp.clipspace['selectedIndex']] = new Image(); - ComfyApp.clipspace.imgs[ComfyApp.clipspace['selectedIndex']].src = api.apiURL("/view?" + new URLSearchParams(filepath).toString() + app.getPreviewFormatParam()); + ComfyApp.clipspace.imgs[ComfyApp.clipspace['selectedIndex']].src = api.apiURL("/view?" + new URLSearchParams(filepath).toString() + app.getPreviewFormatParam() + app.getRandParam()); if(ComfyApp.clipspace.images) ComfyApp.clipspace.images[ComfyApp.clipspace['selectedIndex']] = filepath; @@ -50,25 +62,25 @@ async function uploadMask(filepath, formData) { ClipspaceDialog.invalidatePreview(); } -function prepareRGB(image, backupCanvas, backupCtx) { +function prepare_mask(image, maskCanvas, maskCtx, maskColor) { // paste mask data into alpha channel - backupCtx.drawImage(image, 0, 0, backupCanvas.width, backupCanvas.height); - const backupData = backupCtx.getImageData(0, 0, backupCanvas.width, backupCanvas.height); + maskCtx.drawImage(image, 0, 0, maskCanvas.width, maskCanvas.height); + const maskData = maskCtx.getImageData(0, 0, maskCanvas.width, maskCanvas.height); - // refine mask image - for (let i = 0; i < backupData.data.length; i += 4) { - if(backupData.data[i+3] == 255) - backupData.data[i+3] = 0; + // invert mask + for (let i = 0; i < maskData.data.length; i += 4) { + if(maskData.data[i+3] == 255) + maskData.data[i+3] = 0; else - backupData.data[i+3] = 255; + maskData.data[i+3] = 255; - backupData.data[i] = 0; - backupData.data[i+1] = 0; - backupData.data[i+2] = 0; + maskData.data[i] = maskColor.r; + maskData.data[i+1] = maskColor.g; + maskData.data[i+2] = maskColor.b; } - backupCtx.globalCompositeOperation = 'source-over'; - backupCtx.putImageData(backupData, 0, 0); + maskCtx.globalCompositeOperation = 'source-over'; + maskCtx.putImageData(maskData, 0, 0); } class MaskEditorDialog extends ComfyDialog { @@ -98,6 +110,7 @@ class MaskEditorDialog extends ComfyDialog { createButton(name, callback) { var button = document.createElement("button"); + button.style.pointerEvents = "auto"; button.innerText = name; button.addEventListener("click", callback); return button; @@ -134,6 +147,7 @@ class MaskEditorDialog extends ComfyDialog { divElement.style.display = "flex"; divElement.style.position = "relative"; divElement.style.top = "2px"; + divElement.style.pointerEvents = "auto"; self.brush_slider_input = document.createElement('input'); self.brush_slider_input.setAttribute('type', 'range'); self.brush_slider_input.setAttribute('min', '1'); @@ -155,16 +169,13 @@ class MaskEditorDialog extends ComfyDialog { // If it is specified as relative, using it only as a hidden placeholder for padding is recommended // to prevent anomalies where it exceeds a certain size and goes outside of the window. - var placeholder = document.createElement("div"); - placeholder.style.position = "relative"; - placeholder.style.height = "50px"; - var bottom_panel = document.createElement("div"); bottom_panel.style.position = "absolute"; bottom_panel.style.bottom = "0px"; bottom_panel.style.left = "20px"; bottom_panel.style.right = "20px"; bottom_panel.style.height = "50px"; + bottom_panel.style.pointerEvents = "none"; var brush = document.createElement("div"); brush.id = "brush"; @@ -180,19 +191,32 @@ class MaskEditorDialog extends ComfyDialog { this.brush = brush; this.element.appendChild(imgCanvas); this.element.appendChild(maskCanvas); - this.element.appendChild(placeholder); // must below z-index than bottom_panel to avoid covering button this.element.appendChild(bottom_panel); document.body.appendChild(brush); - var brush_size_slider = this.createLeftSlider(self, "Thickness", (event) => { + var clearButton = this.createLeftButton("Clear", () => { + self.maskCtx.clearRect(0, 0, self.maskCanvas.width, self.maskCanvas.height); + }); + + this.brush_size_slider = this.createLeftSlider(self, "Thickness", (event) => { self.brush_size = event.target.value; self.updateBrushPreview(self, null, null); }); - var clearButton = this.createLeftButton("Clear", - () => { - self.maskCtx.clearRect(0, 0, self.maskCanvas.width, self.maskCanvas.height); - self.backupCtx.clearRect(0, 0, self.backupCanvas.width, self.backupCanvas.height); - }); + + this.colorButton = this.createLeftButton(this.getColorButtonText(), () => { + if (self.brush_color_mode === "black") { + self.brush_color_mode = "white"; + } + else if (self.brush_color_mode === "white") { + self.brush_color_mode = "negative"; + } + else { + self.brush_color_mode = "black"; + } + + self.updateWhenBrushColorModeChanged(); + }); + var cancelButton = this.createRightButton("Cancel", () => { document.removeEventListener("mouseup", MaskEditorDialog.handleMouseUp); document.removeEventListener("keydown", MaskEditorDialog.handleKeyDown); @@ -207,40 +231,47 @@ class MaskEditorDialog extends ComfyDialog { this.element.appendChild(imgCanvas); this.element.appendChild(maskCanvas); - this.element.appendChild(placeholder); // must below z-index than bottom_panel to avoid covering button this.element.appendChild(bottom_panel); bottom_panel.appendChild(clearButton); bottom_panel.appendChild(this.saveButton); bottom_panel.appendChild(cancelButton); - bottom_panel.appendChild(brush_size_slider); + bottom_panel.appendChild(this.brush_size_slider); + bottom_panel.appendChild(this.colorButton); + + imgCanvas.style.position = "absolute"; + maskCanvas.style.position = "absolute"; - imgCanvas.style.position = "relative"; imgCanvas.style.top = "200"; imgCanvas.style.left = "0"; - maskCanvas.style.position = "absolute"; + maskCanvas.style.top = imgCanvas.style.top; + maskCanvas.style.left = imgCanvas.style.left; + + const maskCanvasStyle = this.getMaskCanvasStyle(); + maskCanvas.style.mixBlendMode = maskCanvasStyle.mixBlendMode; + maskCanvas.style.opacity = maskCanvasStyle.opacity; } - show() { + async show() { + this.zoom_ratio = 1.0; + this.pan_x = 0; + this.pan_y = 0; + if(!this.is_layout_created) { // layout const imgCanvas = document.createElement('canvas'); const maskCanvas = document.createElement('canvas'); - const backupCanvas = document.createElement('canvas'); imgCanvas.id = "imageCanvas"; maskCanvas.id = "maskCanvas"; - backupCanvas.id = "backupCanvas"; this.setlayout(imgCanvas, maskCanvas); // prepare content this.imgCanvas = imgCanvas; this.maskCanvas = maskCanvas; - this.backupCanvas = backupCanvas; - this.maskCtx = maskCanvas.getContext('2d'); - this.backupCtx = backupCanvas.getContext('2d'); + this.maskCtx = maskCanvas.getContext('2d', {willReadFrequently: true }); this.setEventHandler(maskCanvas); @@ -252,6 +283,8 @@ class MaskEditorDialog extends ComfyDialog { mutations.forEach(function(mutation) { if (mutation.type === 'attributes' && mutation.attributeName === 'style') { if(self.last_display_style && self.last_display_style != 'none' && self.element.style.display == 'none') { + document.removeEventListener("mouseup", MaskEditorDialog.handleMouseUp); + self.brush.style.display = "none"; ComfyApp.onClipspaceEditorClosed(); } @@ -264,7 +297,8 @@ class MaskEditorDialog extends ComfyDialog { observer.observe(this.element, config); } - this.setImages(this.imgCanvas, this.backupCanvas); + // The keydown event needs to be reconfigured when closing the dialog as it gets removed. + document.addEventListener('keydown', MaskEditorDialog.handleKeyDown); if(ComfyApp.clipspace_return_node) { this.saveButton.innerText = "Save to node"; @@ -275,100 +309,237 @@ class MaskEditorDialog extends ComfyDialog { this.saveButton.disabled = false; this.element.style.display = "block"; + this.element.style.width = "85%"; + this.element.style.margin = "0 7.5%"; + this.element.style.height = "100vh"; + this.element.style.top = "50%"; + this.element.style.left = "42%"; this.element.style.zIndex = 8888; // NOTE: alert dialog must be high priority. + + await this.setImages(this.imgCanvas); + + this.is_visible = true; } isOpened() { return this.element.style.display == "block"; } - setImages(imgCanvas, backupCanvas) { - const imgCtx = imgCanvas.getContext('2d'); - const backupCtx = backupCanvas.getContext('2d'); + invalidateCanvas(orig_image, mask_image) { + this.imgCanvas.width = orig_image.width; + this.imgCanvas.height = orig_image.height; + + this.maskCanvas.width = orig_image.width; + this.maskCanvas.height = orig_image.height; + + let imgCtx = this.imgCanvas.getContext('2d', {willReadFrequently: true }); + let maskCtx = this.maskCanvas.getContext('2d', {willReadFrequently: true }); + + imgCtx.drawImage(orig_image, 0, 0, orig_image.width, orig_image.height); + prepare_mask(mask_image, this.maskCanvas, maskCtx, this.getMaskColor()); + } + + async setImages(imgCanvas) { + let self = this; + + const imgCtx = imgCanvas.getContext('2d', {willReadFrequently: true }); const maskCtx = this.maskCtx; const maskCanvas = this.maskCanvas; - backupCtx.clearRect(0,0,this.backupCanvas.width,this.backupCanvas.height); imgCtx.clearRect(0,0,this.imgCanvas.width,this.imgCanvas.height); maskCtx.clearRect(0,0,this.maskCanvas.width,this.maskCanvas.height); // image load - const orig_image = new Image(); - window.addEventListener("resize", () => { - // repositioning - imgCanvas.width = window.innerWidth - 250; - imgCanvas.height = window.innerHeight - 200; - - // redraw image - let drawWidth = orig_image.width; - let drawHeight = orig_image.height; - if (orig_image.width > imgCanvas.width) { - drawWidth = imgCanvas.width; - drawHeight = (drawWidth / orig_image.width) * orig_image.height; - } - - if (drawHeight > imgCanvas.height) { - drawHeight = imgCanvas.height; - drawWidth = (drawHeight / orig_image.height) * orig_image.width; - } - - imgCtx.drawImage(orig_image, 0, 0, drawWidth, drawHeight); - - // update mask - maskCanvas.width = drawWidth; - maskCanvas.height = drawHeight; - maskCanvas.style.top = imgCanvas.offsetTop + "px"; - maskCanvas.style.left = imgCanvas.offsetLeft + "px"; - backupCtx.drawImage(maskCanvas, 0, 0, maskCanvas.width, maskCanvas.height, 0, 0, backupCanvas.width, backupCanvas.height); - maskCtx.drawImage(backupCanvas, 0, 0, backupCanvas.width, backupCanvas.height, 0, 0, maskCanvas.width, maskCanvas.height); - }); - const filepath = ComfyApp.clipspace.images; - const touched_image = new Image(); - - touched_image.onload = function() { - backupCanvas.width = touched_image.width; - backupCanvas.height = touched_image.height; - - prepareRGB(touched_image, backupCanvas, backupCtx); - }; - const alpha_url = new URL(ComfyApp.clipspace.imgs[ComfyApp.clipspace['selectedIndex']].src) alpha_url.searchParams.delete('channel'); alpha_url.searchParams.delete('preview'); alpha_url.searchParams.set('channel', 'a'); - touched_image.src = alpha_url; + let mask_image = await loadImage(alpha_url); // original image load - orig_image.onload = function() { - window.dispatchEvent(new Event('resize')); - }; - const rgb_url = new URL(ComfyApp.clipspace.imgs[ComfyApp.clipspace['selectedIndex']].src); rgb_url.searchParams.delete('channel'); rgb_url.searchParams.set('channel', 'rgb'); - orig_image.src = rgb_url; - this.image = orig_image; + this.image = new Image(); + this.image.onload = function() { + maskCanvas.width = self.image.width; + maskCanvas.height = self.image.height; + + self.invalidateCanvas(self.image, mask_image); + self.initializeCanvasPanZoom(); + }; + this.image.src = rgb_url; } - setEventHandler(maskCanvas) { - maskCanvas.addEventListener("contextmenu", (event) => { - event.preventDefault(); - }); + initializeCanvasPanZoom() { + // set initialize + let drawWidth = this.image.width; + let drawHeight = this.image.height; + + let width = this.element.clientWidth; + let height = this.element.clientHeight; + + if (this.image.width > width) { + drawWidth = width; + drawHeight = (drawWidth / this.image.width) * this.image.height; + } + + if (drawHeight > height) { + drawHeight = height; + drawWidth = (drawHeight / this.image.height) * this.image.width; + } + + this.zoom_ratio = drawWidth/this.image.width; + + const canvasX = (width - drawWidth) / 2; + const canvasY = (height - drawHeight) / 2; + this.pan_x = canvasX; + this.pan_y = canvasY; + + this.invalidatePanZoom(); + } + + + invalidatePanZoom() { + let raw_width = this.image.width * this.zoom_ratio; + let raw_height = this.image.height * this.zoom_ratio; + if(this.pan_x + raw_width < 10) { + this.pan_x = 10 - raw_width; + } + + if(this.pan_y + raw_height < 10) { + this.pan_y = 10 - raw_height; + } + + let width = `${raw_width}px`; + let height = `${raw_height}px`; + + let left = `${this.pan_x}px`; + let top = `${this.pan_y}px`; + + this.maskCanvas.style.width = width; + this.maskCanvas.style.height = height; + this.maskCanvas.style.left = left; + this.maskCanvas.style.top = top; + + this.imgCanvas.style.width = width; + this.imgCanvas.style.height = height; + this.imgCanvas.style.left = left; + this.imgCanvas.style.top = top; + } + + + setEventHandler(maskCanvas) { const self = this; - maskCanvas.addEventListener('wheel', (event) => this.handleWheelEvent(self,event)); - maskCanvas.addEventListener('pointerdown', (event) => this.handlePointerDown(self,event)); - document.addEventListener('pointerup', MaskEditorDialog.handlePointerUp); - maskCanvas.addEventListener('pointermove', (event) => this.draw_move(self,event)); - maskCanvas.addEventListener('touchmove', (event) => this.draw_move(self,event)); - maskCanvas.addEventListener('pointerover', (event) => { this.brush.style.display = "block"; }); - maskCanvas.addEventListener('pointerleave', (event) => { this.brush.style.display = "none"; }); - document.addEventListener('keydown', MaskEditorDialog.handleKeyDown); + + if(!this.handler_registered) { + maskCanvas.addEventListener("contextmenu", (event) => { + event.preventDefault(); + }); + + this.element.addEventListener('wheel', (event) => this.handleWheelEvent(self,event)); + this.element.addEventListener('pointermove', (event) => this.pointMoveEvent(self,event)); + this.element.addEventListener('touchmove', (event) => this.pointMoveEvent(self,event)); + + this.element.addEventListener('dragstart', (event) => { + if(event.ctrlKey) { + event.preventDefault(); + } + }); + + maskCanvas.addEventListener('pointerdown', (event) => this.handlePointerDown(self,event)); + maskCanvas.addEventListener('pointermove', (event) => this.draw_move(self,event)); + maskCanvas.addEventListener('touchmove', (event) => this.draw_move(self,event)); + maskCanvas.addEventListener('pointerover', (event) => { this.brush.style.display = "block"; }); + maskCanvas.addEventListener('pointerleave', (event) => { this.brush.style.display = "none"; }); + + document.addEventListener('pointerup', MaskEditorDialog.handlePointerUp); + + this.handler_registered = true; + } + } + + getMaskCanvasStyle() { + if (this.brush_color_mode === "negative") { + return { + mixBlendMode: "difference", + opacity: "1", + }; + } + else { + return { + mixBlendMode: "initial", + opacity: "0.7", + }; + } + } + + getMaskColor() { + if (this.brush_color_mode === "black") { + return { r: 0, g: 0, b: 0 }; + } + if (this.brush_color_mode === "white") { + return { r: 255, g: 255, b: 255 }; + } + if (this.brush_color_mode === "negative") { + // negative effect only works with white color + return { r: 255, g: 255, b: 255 }; + } + + return { r: 0, g: 0, b: 0 }; + } + + getMaskFillStyle() { + const maskColor = this.getMaskColor(); + + return "rgb(" + maskColor.r + "," + maskColor.g + "," + maskColor.b + ")"; + } + + getColorButtonText() { + let colorCaption = "unknown"; + + if (this.brush_color_mode === "black") { + colorCaption = "black"; + } + else if (this.brush_color_mode === "white") { + colorCaption = "white"; + } + else if (this.brush_color_mode === "negative") { + colorCaption = "negative"; + } + + return "Color: " + colorCaption; + } + + updateWhenBrushColorModeChanged() { + this.colorButton.innerText = this.getColorButtonText(); + + // update mask canvas css styles + + const maskCanvasStyle = this.getMaskCanvasStyle(); + this.maskCanvas.style.mixBlendMode = maskCanvasStyle.mixBlendMode; + this.maskCanvas.style.opacity = maskCanvasStyle.opacity; + + // update mask canvas rgb colors + + const maskColor = this.getMaskColor(); + + const maskData = this.maskCtx.getImageData(0, 0, this.maskCanvas.width, this.maskCanvas.height); + + for (let i = 0; i < maskData.data.length; i += 4) { + maskData.data[i] = maskColor.r; + maskData.data[i+1] = maskColor.g; + maskData.data[i+2] = maskColor.b; + } + + this.maskCtx.putImageData(maskData, 0, 0); } brush_size = 10; + brush_color_mode = "black"; drawing_mode = false; lastx = -1; lasty = -1; @@ -378,8 +549,10 @@ class MaskEditorDialog extends ComfyDialog { const self = MaskEditorDialog.instance; if (event.key === ']') { self.brush_size = Math.min(self.brush_size+2, 100); + self.brush_slider_input.value = self.brush_size; } else if (event.key === '[') { self.brush_size = Math.max(self.brush_size-2, 1); + self.brush_slider_input.value = self.brush_size; } else if(event.key === 'Enter') { self.save(); } @@ -389,6 +562,10 @@ class MaskEditorDialog extends ComfyDialog { static handlePointerUp(event) { event.preventDefault(); + + this.mousedown_x = null; + this.mousedown_y = null; + MaskEditorDialog.instance.drawing_mode = false; } @@ -398,24 +575,83 @@ class MaskEditorDialog extends ComfyDialog { var centerX = self.cursorX; var centerY = self.cursorY; - brush.style.width = self.brush_size * 2 + "px"; - brush.style.height = self.brush_size * 2 + "px"; - brush.style.left = (centerX - self.brush_size) + "px"; - brush.style.top = (centerY - self.brush_size) + "px"; + brush.style.width = self.brush_size * 2 * this.zoom_ratio + "px"; + brush.style.height = self.brush_size * 2 * this.zoom_ratio + "px"; + brush.style.left = (centerX - self.brush_size * this.zoom_ratio) + "px"; + brush.style.top = (centerY - self.brush_size * this.zoom_ratio) + "px"; } handleWheelEvent(self, event) { - if(event.deltaY < 0) - self.brush_size = Math.min(self.brush_size+2, 100); - else - self.brush_size = Math.max(self.brush_size-2, 1); + event.preventDefault(); - self.brush_slider_input.value = self.brush_size; + if(event.ctrlKey) { + // zoom canvas + if(event.deltaY < 0) { + this.zoom_ratio = Math.min(10.0, this.zoom_ratio+0.2); + } + else { + this.zoom_ratio = Math.max(0.2, this.zoom_ratio-0.2); + } + + this.invalidatePanZoom(); + } + else { + // adjust brush size + if(event.deltaY < 0) + this.brush_size = Math.min(this.brush_size+2, 100); + else + this.brush_size = Math.max(this.brush_size-2, 1); + + this.brush_slider_input.value = this.brush_size; + + this.updateBrushPreview(this); + } + } + + pointMoveEvent(self, event) { + this.cursorX = event.pageX; + this.cursorY = event.pageY; self.updateBrushPreview(self); + + if(event.ctrlKey) { + event.preventDefault(); + self.pan_move(self, event); + } + + let left_button_down = window.TouchEvent && event instanceof TouchEvent || event.buttons == 1; + + if(event.shiftKey && left_button_down) { + self.drawing_mode = false; + + const y = event.clientY; + let delta = (self.zoom_lasty - y)*0.005; + self.zoom_ratio = Math.max(Math.min(10.0, self.last_zoom_ratio - delta), 0.2); + + this.invalidatePanZoom(); + return; + } + } + + pan_move(self, event) { + if(event.buttons == 1) { + if(this.mousedown_x) { + let deltaX = this.mousedown_x - event.clientX; + let deltaY = this.mousedown_y - event.clientY; + + self.pan_x = this.mousedown_pan_x - deltaX; + self.pan_y = this.mousedown_pan_y - deltaY; + + self.invalidatePanZoom(); + } + } } draw_move(self, event) { + if(event.ctrlKey || event.shiftKey) { + return; + } + event.preventDefault(); this.cursorX = event.pageX; @@ -423,7 +659,10 @@ class MaskEditorDialog extends ComfyDialog { self.updateBrushPreview(self); - if (window.TouchEvent && event instanceof TouchEvent || event.buttons == 1) { + let left_button_down = window.TouchEvent && event instanceof TouchEvent || event.buttons == 1; + let right_button_down = [2, 5, 32].includes(event.buttons); + + if (!event.altKey && left_button_down) { var diff = performance.now() - self.lasttime; const maskRect = self.maskCanvas.getBoundingClientRect(); @@ -439,6 +678,9 @@ class MaskEditorDialog extends ComfyDialog { y = event.targetTouches[0].clientY - maskRect.top; } + x /= self.zoom_ratio; + y /= self.zoom_ratio; + var brush_size = this.brush_size; if(event instanceof PointerEvent && event.pointerType == 'pen') { brush_size *= event.pressure; @@ -455,7 +697,7 @@ class MaskEditorDialog extends ComfyDialog { if(diff > 20 && !this.drawing_mode) requestAnimationFrame(() => { self.maskCtx.beginPath(); - self.maskCtx.fillStyle = "rgb(0,0,0)"; + self.maskCtx.fillStyle = this.getMaskFillStyle(); self.maskCtx.globalCompositeOperation = "source-over"; self.maskCtx.arc(x, y, brush_size, 0, Math.PI * 2, false); self.maskCtx.fill(); @@ -465,7 +707,7 @@ class MaskEditorDialog extends ComfyDialog { else requestAnimationFrame(() => { self.maskCtx.beginPath(); - self.maskCtx.fillStyle = "rgb(0,0,0)"; + self.maskCtx.fillStyle = this.getMaskFillStyle(); self.maskCtx.globalCompositeOperation = "source-over"; var dx = x - self.lastx; @@ -487,10 +729,10 @@ class MaskEditorDialog extends ComfyDialog { self.lasttime = performance.now(); } - else if(event.buttons == 2 || event.buttons == 5 || event.buttons == 32) { + else if((event.altKey && left_button_down) || right_button_down) { const maskRect = self.maskCanvas.getBoundingClientRect(); - const x = event.offsetX || event.targetTouches[0].clientX - maskRect.left; - const y = event.offsetY || event.targetTouches[0].clientY - maskRect.top; + const x = (event.offsetX || event.targetTouches[0].clientX - maskRect.left) / self.zoom_ratio; + const y = (event.offsetY || event.targetTouches[0].clientY - maskRect.top) / self.zoom_ratio; var brush_size = this.brush_size; if(event instanceof PointerEvent && event.pointerType == 'pen') { @@ -540,6 +782,17 @@ class MaskEditorDialog extends ComfyDialog { } handlePointerDown(self, event) { + if(event.ctrlKey) { + if (event.buttons == 1) { + this.mousedown_x = event.clientX; + this.mousedown_y = event.clientY; + + this.mousedown_pan_x = this.pan_x; + this.mousedown_pan_y = this.pan_y; + } + return; + } + var brush_size = this.brush_size; if(event instanceof PointerEvent && event.pointerType == 'pen') { brush_size *= event.pressure; @@ -550,13 +803,20 @@ class MaskEditorDialog extends ComfyDialog { self.drawing_mode = true; event.preventDefault(); + + if(event.shiftKey) { + self.zoom_lasty = event.clientY; + self.last_zoom_ratio = self.zoom_ratio; + return; + } + const maskRect = self.maskCanvas.getBoundingClientRect(); - const x = event.offsetX || event.targetTouches[0].clientX - maskRect.left; - const y = event.offsetY || event.targetTouches[0].clientY - maskRect.top; + const x = (event.offsetX || event.targetTouches[0].clientX - maskRect.left) / self.zoom_ratio; + const y = (event.offsetY || event.targetTouches[0].clientY - maskRect.top) / self.zoom_ratio; self.maskCtx.beginPath(); - if (event.button == 0) { - self.maskCtx.fillStyle = "rgb(0,0,0)"; + if (!event.altKey && event.button == 0) { + self.maskCtx.fillStyle = this.getMaskFillStyle(); self.maskCtx.globalCompositeOperation = "source-over"; } else { self.maskCtx.globalCompositeOperation = "destination-out"; @@ -570,15 +830,18 @@ class MaskEditorDialog extends ComfyDialog { } async save() { - const backupCtx = this.backupCanvas.getContext('2d', {willReadFrequently:true}); + const backupCanvas = document.createElement('canvas'); + const backupCtx = backupCanvas.getContext('2d', {willReadFrequently:true}); + backupCanvas.width = this.image.width; + backupCanvas.height = this.image.height; - backupCtx.clearRect(0,0,this.backupCanvas.width,this.backupCanvas.height); + backupCtx.clearRect(0,0, backupCanvas.width, backupCanvas.height); backupCtx.drawImage(this.maskCanvas, 0, 0, this.maskCanvas.width, this.maskCanvas.height, - 0, 0, this.backupCanvas.width, this.backupCanvas.height); + 0, 0, backupCanvas.width, backupCanvas.height); // paste mask data into alpha channel - const backupData = backupCtx.getImageData(0, 0, this.backupCanvas.width, this.backupCanvas.height); + const backupData = backupCtx.getImageData(0, 0, backupCanvas.width, backupCanvas.height); // refine mask image for (let i = 0; i < backupData.data.length; i += 4) { @@ -615,7 +878,7 @@ class MaskEditorDialog extends ComfyDialog { ComfyApp.clipspace.widgets[index].value = item; } - const dataURL = this.backupCanvas.toDataURL(); + const dataURL = backupCanvas.toDataURL(); const blob = dataURLToBlob(dataURL); let original_url = new URL(this.image.src); @@ -657,4 +920,4 @@ app.registerExtension({ const context_predicate = () => ComfyApp.clipspace && ComfyApp.clipspace.imgs && ComfyApp.clipspace.imgs.length > 0 ClipspaceDialog.registerButton("MaskEditor", context_predicate, ComfyApp.open_maskeditor); } -}); \ No newline at end of file +}); diff --git a/web/extensions/core/nodeTemplates.js b/web/extensions/core/nodeTemplates.js index 2d482174..9350ba65 100644 --- a/web/extensions/core/nodeTemplates.js +++ b/web/extensions/core/nodeTemplates.js @@ -1,4 +1,5 @@ import { app } from "../../scripts/app.js"; +import { api } from "../../scripts/api.js"; import { ComfyDialog, $el } from "../../scripts/ui.js"; import { GroupNodeConfig, GroupNodeHandler } from "./groupNode.js"; @@ -20,16 +21,20 @@ import { GroupNodeConfig, GroupNodeHandler } from "./groupNode.js"; // Open the manage dialog and Drag and drop elements using the "Name:" label as handle const id = "Comfy.NodeTemplates"; +const file = "comfy.templates.json"; class ManageTemplates extends ComfyDialog { constructor() { super(); + this.load().then((v) => { + this.templates = v; + }); + this.element.classList.add("comfy-manage-templates"); - this.templates = this.load(); this.draggedEl = null; this.saveVisualCue = null; this.emptyImg = new Image(); - this.emptyImg.src = ''; + this.emptyImg.src = ""; this.importInput = $el("input", { type: "file", @@ -67,17 +72,50 @@ class ManageTemplates extends ComfyDialog { return btns; } - load() { - const templates = localStorage.getItem(id); - if (templates) { - return JSON.parse(templates); + async load() { + let templates = []; + if (app.storageLocation === "server") { + if (app.isNewUserSession) { + // New user so migrate existing templates + const json = localStorage.getItem(id); + if (json) { + templates = JSON.parse(json); + } + await api.storeUserData(file, json, { stringify: false }); + } else { + const res = await api.getUserData(file); + if (res.status === 200) { + try { + templates = await res.json(); + } catch (error) { + } + } else if (res.status !== 404) { + console.error(res.status + " " + res.statusText); + } + } } else { - return []; + const json = localStorage.getItem(id); + if (json) { + templates = JSON.parse(json); + } } + + return templates ?? []; } - store() { - localStorage.setItem(id, JSON.stringify(this.templates)); + async store() { + if(app.storageLocation === "server") { + const templates = JSON.stringify(this.templates, undefined, 4); + localStorage.setItem(id, templates); // Backwards compatibility + try { + await api.storeUserData(file, templates, { stringify: false }); + } catch (error) { + console.error(error); + alert(error.message); + } + } else { + localStorage.setItem(id, JSON.stringify(this.templates)); + } } async importAll() { @@ -85,14 +123,14 @@ class ManageTemplates extends ComfyDialog { if (file.type === "application/json" || file.name.endsWith(".json")) { const reader = new FileReader(); reader.onload = async () => { - var importFile = JSON.parse(reader.result); - if (importFile && importFile?.templates) { + const importFile = JSON.parse(reader.result); + if (importFile?.templates) { for (const template of importFile.templates) { if (template?.name && template?.data) { this.templates.push(template); } } - this.store(); + await this.store(); } }; await reader.readAsText(file); @@ -159,12 +197,12 @@ class ManageTemplates extends ComfyDialog { e.currentTarget.style.border = "1px dashed transparent"; e.currentTarget.removeAttribute("draggable"); - // rearrange the elements in the localStorage + // rearrange the elements this.element.querySelectorAll('.tempateManagerRow').forEach((el,i) => { var prev_i = el.dataset.id; if ( el == this.draggedEl && prev_i != i ) { - [this.templates[i], this.templates[prev_i]] = [this.templates[prev_i], this.templates[i]]; + this.templates.splice(i, 0, this.templates.splice(prev_i, 1)[0]); } el.dataset.id = i; }); diff --git a/web/extensions/core/rerouteNode.js b/web/extensions/core/rerouteNode.js index 499a171d..4feff91e 100644 --- a/web/extensions/core/rerouteNode.js +++ b/web/extensions/core/rerouteNode.js @@ -1,10 +1,11 @@ import { app } from "../../scripts/app.js"; +import { mergeIfValid, getWidgetConfig, setWidgetConfig } from "./widgetInputs.js"; // Node that allows you to redirect connections for cleaner graphs app.registerExtension({ name: "Comfy.RerouteNode", - registerCustomNodes() { + registerCustomNodes(app) { class RerouteNode { constructor() { if (!this.properties) { @@ -16,6 +17,12 @@ app.registerExtension({ this.addInput("", "*"); this.addOutput(this.properties.showOutputText ? "*" : "", "*"); + this.onAfterGraphConfigured = function () { + requestAnimationFrame(() => { + this.onConnectionsChange(LiteGraph.INPUT, null, true, null); + }); + }; + this.onConnectionsChange = function (type, index, connected, link_info) { this.applyOrientation(); @@ -47,6 +54,7 @@ app.registerExtension({ const linkId = currentNode.inputs[0].link; if (linkId !== null) { const link = app.graph.links[linkId]; + if (!link) return; const node = app.graph.getNodeById(link.origin_id); const type = node.constructor.type; if (type === "Reroute") { @@ -54,8 +62,7 @@ app.registerExtension({ // We've found a circle currentNode.disconnectInput(link.target_slot); currentNode = null; - } - else { + } else { // Move the previous node currentNode = node; } @@ -94,8 +101,11 @@ app.registerExtension({ updateNodes.push(node); } else { // We've found an output - const nodeOutType = node.inputs && node.inputs[link?.target_slot] && node.inputs[link.target_slot].type ? node.inputs[link.target_slot].type : null; - if (inputType && nodeOutType !== inputType) { + const nodeOutType = + node.inputs && node.inputs[link?.target_slot] && node.inputs[link.target_slot].type + ? node.inputs[link.target_slot].type + : null; + if (inputType && inputType !== "*" && nodeOutType !== inputType) { // The output doesnt match our input so disconnect it node.disconnectInput(link.target_slot); } else { @@ -111,6 +121,9 @@ app.registerExtension({ const displayType = inputType || outputType || "*"; const color = LGraphCanvas.link_type_colors[displayType]; + let widgetConfig; + let targetWidget; + let widgetType; // Update the types of each node for (const node of updateNodes) { // If we dont have an input type we are always wildcard but we'll show the output type @@ -125,10 +138,38 @@ app.registerExtension({ const link = app.graph.links[l]; if (link) { link.color = color; + + if (app.configuringGraph) continue; + const targetNode = app.graph.getNodeById(link.target_id); + const targetInput = targetNode.inputs?.[link.target_slot]; + if (targetInput?.widget) { + const config = getWidgetConfig(targetInput); + if (!widgetConfig) { + widgetConfig = config[1] ?? {}; + widgetType = config[0]; + } + if (!targetWidget) { + targetWidget = targetNode.widgets?.find((w) => w.name === targetInput.widget.name); + } + + const merged = mergeIfValid(targetInput, [config[0], widgetConfig]); + if (merged.customConfig) { + widgetConfig = merged.customConfig; + } + } } } } + for (const node of updateNodes) { + if (widgetConfig && outputType) { + node.inputs[0].widget = { name: "value" }; + setWidgetConfig(node.inputs[0], [widgetType ?? displayType, widgetConfig], targetWidget); + } else { + setWidgetConfig(node.inputs[0], null); + } + } + if (inputNode) { const link = app.graph.links[inputNode.inputs[0].link]; if (link) { @@ -173,8 +214,8 @@ app.registerExtension({ }, { // naming is inverted with respect to LiteGraphNode.horizontal - // LiteGraphNode.horizontal == true means that - // each slot in the inputs and outputs are layed out horizontally, + // LiteGraphNode.horizontal == true means that + // each slot in the inputs and outputs are layed out horizontally, // which is the opposite of the visual orientation of the inputs and outputs as a node content: "Set " + (this.properties.horizontal ? "Horizontal" : "Vertical"), callback: () => { @@ -187,7 +228,7 @@ app.registerExtension({ applyOrientation() { this.horizontal = this.properties.horizontal; if (this.horizontal) { - // we correct the input position, because LiteGraphNode.horizontal + // we correct the input position, because LiteGraphNode.horizontal // doesn't account for title presence // which reroute nodes don't have this.inputs[0].pos = [this.size[0] / 2, 0]; diff --git a/web/extensions/core/saveImageExtraOutput.js b/web/extensions/core/saveImageExtraOutput.js index 99e2213b..a0506b43 100644 --- a/web/extensions/core/saveImageExtraOutput.js +++ b/web/extensions/core/saveImageExtraOutput.js @@ -1,5 +1,5 @@ import { app } from "../../scripts/app.js"; - +import { applyTextReplacements } from "../../scripts/utils.js"; // Use widget values and dates in output filenames app.registerExtension({ @@ -7,84 +7,19 @@ app.registerExtension({ async beforeRegisterNodeDef(nodeType, nodeData, app) { if (nodeData.name === "SaveImage") { const onNodeCreated = nodeType.prototype.onNodeCreated; - - // Simple date formatter - const parts = { - d: (d) => d.getDate(), - M: (d) => d.getMonth() + 1, - h: (d) => d.getHours(), - m: (d) => d.getMinutes(), - s: (d) => d.getSeconds(), - }; - const format = - Object.keys(parts) - .map((k) => k + k + "?") - .join("|") + "|yyy?y?"; - - function formatDate(text, date) { - return text.replace(new RegExp(format, "g"), function (text) { - if (text === "yy") return (date.getFullYear() + "").substring(2); - if (text === "yyyy") return date.getFullYear(); - if (text[0] in parts) { - const p = parts[text[0]](date); - return (p + "").padStart(text.length, "0"); - } - return text; - }); - } - - // When the SaveImage node is created we want to override the serialization of the output name widget to run our S&R + // When the SaveImage node is created we want to override the serialization of the output name widget to run our S&R nodeType.prototype.onNodeCreated = function () { const r = onNodeCreated ? onNodeCreated.apply(this, arguments) : undefined; const widget = this.widgets.find((w) => w.name === "filename_prefix"); widget.serializeValue = () => { - return widget.value.replace(/%([^%]+)%/g, function (match, text) { - const split = text.split("."); - if (split.length !== 2) { - // Special handling for dates - if (split[0].startsWith("date:")) { - return formatDate(split[0].substring(5), new Date()); - } - - if (text !== "width" && text !== "height") { - // Dont warn on standard replacements - console.warn("Invalid replacement pattern", text); - } - return match; - } - - // Find node with matching S&R property name - let nodes = app.graph._nodes.filter((n) => n.properties?.["Node name for S&R"] === split[0]); - // If we cant, see if there is a node with that title - if (!nodes.length) { - nodes = app.graph._nodes.filter((n) => n.title === split[0]); - } - if (!nodes.length) { - console.warn("Unable to find node", split[0]); - return match; - } - - if (nodes.length > 1) { - console.warn("Multiple nodes matched", split[0], "using first match"); - } - - const node = nodes[0]; - - const widget = node.widgets?.find((w) => w.name === split[1]); - if (!widget) { - console.warn("Unable to find widget", split[1], "on node", split[0], node); - return match; - } - - return ((widget.value ?? "") + "").replaceAll(/\/|\\/g, "_"); - }); + return applyTextReplacements(app, widget.value); }; return r; }; } else { - // When any other node is created add a property to alias the node + // When any other node is created add a property to alias the node const onNodeCreated = nodeType.prototype.onNodeCreated; nodeType.prototype.onNodeCreated = function () { const r = onNodeCreated ? onNodeCreated.apply(this, arguments) : undefined; diff --git a/web/extensions/core/simpleTouchSupport.js b/web/extensions/core/simpleTouchSupport.js new file mode 100644 index 00000000..041fc2c4 --- /dev/null +++ b/web/extensions/core/simpleTouchSupport.js @@ -0,0 +1,102 @@ +import { app } from "../../scripts/app.js"; + +let touchZooming; +let touchCount = 0; + +app.registerExtension({ + name: "Comfy.SimpleTouchSupport", + setup() { + let zoomPos; + let touchTime; + let lastTouch; + + function getMultiTouchPos(e) { + return Math.hypot(e.touches[0].clientX - e.touches[1].clientX, e.touches[0].clientY - e.touches[1].clientY); + } + + app.canvasEl.addEventListener( + "touchstart", + (e) => { + touchCount++; + lastTouch = null; + if (e.touches?.length === 1) { + // Store start time for press+hold for context menu + touchTime = new Date(); + lastTouch = e.touches[0]; + } else { + touchTime = null; + if (e.touches?.length === 2) { + // Store center pos for zoom + zoomPos = getMultiTouchPos(e); + app.canvas.pointer_is_down = false; + } + } + }, + true + ); + + app.canvasEl.addEventListener("touchend", (e) => { + touchZooming = false; + touchCount = e.touches?.length ?? touchCount - 1; + if (touchTime && !e.touches?.length) { + if (new Date() - touchTime > 600) { + try { + // hack to get litegraph to use this event + e.constructor = CustomEvent; + } catch (error) {} + e.clientX = lastTouch.clientX; + e.clientY = lastTouch.clientY; + + app.canvas.pointer_is_down = true; + app.canvas._mousedown_callback(e); + } + touchTime = null; + } + }); + + app.canvasEl.addEventListener( + "touchmove", + (e) => { + touchTime = null; + if (e.touches?.length === 2) { + app.canvas.pointer_is_down = false; + touchZooming = true; + LiteGraph.closeAllContextMenus(); + app.canvas.search_box?.close(); + const newZoomPos = getMultiTouchPos(e); + + const midX = (e.touches[0].clientX + e.touches[1].clientX) / 2; + const midY = (e.touches[0].clientY + e.touches[1].clientY) / 2; + + let scale = app.canvas.ds.scale; + const diff = zoomPos - newZoomPos; + if (diff > 0.5) { + scale *= 1 / 1.07; + } else if (diff < -0.5) { + scale *= 1.07; + } + app.canvas.ds.changeScale(scale, [midX, midY]); + app.canvas.setDirty(true, true); + zoomPos = newZoomPos; + } + }, + true + ); + }, +}); + +const processMouseDown = LGraphCanvas.prototype.processMouseDown; +LGraphCanvas.prototype.processMouseDown = function (e) { + if (touchZooming || touchCount) { + return; + } + return processMouseDown.apply(this, arguments); +}; + +const processMouseMove = LGraphCanvas.prototype.processMouseMove; +LGraphCanvas.prototype.processMouseMove = function (e) { + if (touchZooming || touchCount > 1) { + return; + } + return processMouseMove.apply(this, arguments); +}; diff --git a/web/extensions/core/undoRedo.js b/web/extensions/core/undoRedo.js index c6613b0f..900eed2a 100644 --- a/web/extensions/core/undoRedo.js +++ b/web/extensions/core/undoRedo.js @@ -1,4 +1,5 @@ import { app } from "../../scripts/app.js"; +import { api } from "../../scripts/api.js" const MAX_HISTORY = 50; @@ -15,6 +16,7 @@ function checkState() { } activeState = clone(currentState); redo.length = 0; + api.dispatchEvent(new CustomEvent("graphChanged", { detail: activeState })); } } @@ -71,31 +73,28 @@ function graphEqual(a, b, root = true) { } const undoRedo = async (e) => { + const updateState = async (source, target) => { + const prevState = source.pop(); + if (prevState) { + target.push(activeState); + isOurLoad = true; + await app.loadGraphData(prevState, false); + activeState = prevState; + } + } if (e.ctrlKey || e.metaKey) { if (e.key === "y") { - const prevState = redo.pop(); - if (prevState) { - undo.push(activeState); - isOurLoad = true; - await app.loadGraphData(prevState); - activeState = prevState; - } + updateState(redo, undo); return true; } else if (e.key === "z") { - const prevState = undo.pop(); - if (prevState) { - redo.push(activeState); - isOurLoad = true; - await app.loadGraphData(prevState); - activeState = prevState; - } + updateState(undo, redo); return true; } } }; const bindInput = (activeEl) => { - if (activeEl?.tagName !== "CANVAS" && activeEl?.tagName !== "BODY") { + if (activeEl && activeEl.tagName !== "CANVAS" && activeEl.tagName !== "BODY") { for (const evt of ["change", "input", "blur"]) { if (`on${evt}` in activeEl) { const listener = () => { @@ -109,15 +108,23 @@ const bindInput = (activeEl) => { } }; +let keyIgnored = false; window.addEventListener( "keydown", (e) => { requestAnimationFrame(async () => { - const activeEl = document.activeElement; - if (activeEl?.tagName === "INPUT" || activeEl?.type === "textarea") { - // Ignore events on inputs, they have their native history - return; + let activeEl; + // If we are auto queue in change mode then we do want to trigger on inputs + if (!app.ui.autoQueueEnabled || app.ui.autoQueueMode === "instant") { + activeEl = document.activeElement; + if (activeEl?.tagName === "INPUT" || activeEl?.type === "textarea") { + // Ignore events on inputs, they have their native history + return; + } } + + keyIgnored = e.key === "Control" || e.key === "Shift" || e.key === "Alt" || e.key === "Meta"; + if (keyIgnored) return; // Check if this is a ctrl+z ctrl+y if (await undoRedo(e)) return; @@ -130,11 +137,23 @@ window.addEventListener( true ); +window.addEventListener("keyup", (e) => { + if (keyIgnored) { + keyIgnored = false; + checkState(); + } +}); + // Handle clicking DOM elements (e.g. widgets) window.addEventListener("mouseup", () => { checkState(); }); +// Handle prompt queue event for dynamic widget changes +api.addEventListener("promptQueued", () => { + checkState(); +}); + // Handle litegraph clicks const processMouseUp = LGraphCanvas.prototype.processMouseUp; LGraphCanvas.prototype.processMouseUp = function (e) { @@ -148,3 +167,11 @@ LGraphCanvas.prototype.processMouseDown = function (e) { checkState(); return v; }; + +// Handle litegraph context menu for COMBO widgets +const close = LiteGraph.ContextMenu.prototype.close; +LiteGraph.ContextMenu.prototype.close = function(e) { + const v = close.apply(this, arguments); + checkState(); + return v; +} \ No newline at end of file diff --git a/web/extensions/core/widgetInputs.js b/web/extensions/core/widgetInputs.js index b6fa411f..23f51d81 100644 --- a/web/extensions/core/widgetInputs.js +++ b/web/extensions/core/widgetInputs.js @@ -1,10 +1,16 @@ import { ComfyWidgets, addValueControlWidgets } from "../../scripts/widgets.js"; import { app } from "../../scripts/app.js"; +import { applyTextReplacements } from "../../scripts/utils.js"; const CONVERTED_TYPE = "converted-widget"; const VALID_TYPES = ["STRING", "combo", "number", "BOOLEAN"]; const CONFIG = Symbol(); const GET_CONFIG = Symbol(); +const TARGET = Symbol(); // Used for reroutes to specify the real target widget + +export function getWidgetConfig(slot) { + return slot.widget[CONFIG] ?? slot.widget[GET_CONFIG](); +} function getConfig(widgetName) { const { nodeData } = this.constructor; @@ -16,6 +22,7 @@ function isConvertableWidget(widget, config) { } function hideWidget(node, widget, suffix = "") { + if (widget.type?.startsWith(CONVERTED_TYPE)) return; widget.origType = widget.type; widget.origComputeSize = widget.computeSize; widget.origSerializeValue = widget.serializeValue; @@ -100,7 +107,6 @@ function getWidgetType(config) { return { type }; } - function isValidCombo(combo, obj) { // New input isnt a combo if (!(obj instanceof Array)) { @@ -121,6 +127,31 @@ function isValidCombo(combo, obj) { return true; } +export function setWidgetConfig(slot, config, target) { + if (!slot.widget) return; + if (config) { + slot.widget[GET_CONFIG] = () => config; + slot.widget[TARGET] = target; + } else { + delete slot.widget; + } + + if (slot.link) { + const link = app.graph.links[slot.link]; + if (link) { + const originNode = app.graph.getNodeById(link.origin_id); + if (originNode.type === "PrimitiveNode") { + if (config) { + originNode.recreateWidget(); + } else if(!app.configuringGraph) { + originNode.disconnectOutput(0); + originNode.onLastDisconnect(); + } + } + } + } +} + export function mergeIfValid(output, config2, forceUpdate, recreateWidget, config1) { if (!config1) { config1 = output.widget[CONFIG] ?? output.widget[GET_CONFIG](); @@ -150,7 +181,7 @@ export function mergeIfValid(output, config2, forceUpdate, recreateWidget, confi const isNumber = config1[0] === "INT" || config1[0] === "FLOAT"; for (const k of keys.values()) { - if (k !== "default" && k !== "forceInput" && k !== "defaultInput") { + if (k !== "default" && k !== "forceInput" && k !== "defaultInput" && k !== "control_after_generate" && k !== "multiline") { let v1 = config1[1][k]; let v2 = config2[1]?.[k]; @@ -230,6 +261,12 @@ app.registerExtension({ async beforeRegisterNodeDef(nodeType, nodeData, app) { // Add menu options to conver to/from widgets const origGetExtraMenuOptions = nodeType.prototype.getExtraMenuOptions; + nodeType.prototype.convertWidgetToInput = function (widget) { + const config = getConfig.call(this, widget.name) ?? [widget.type, widget.options || {}]; + if (!isConvertableWidget(widget, config)) return false; + convertToInput(this, widget, config); + return true; + }; nodeType.prototype.getExtraMenuOptions = function (_, options) { const r = origGetExtraMenuOptions ? origGetExtraMenuOptions.apply(this, arguments) : undefined; @@ -405,11 +442,16 @@ app.registerExtension({ }; }, registerCustomNodes() { + const replacePropertyName = "Run widget replace on values"; class PrimitiveNode { constructor() { this.addOutput("connect to widget input", "*"); this.serialize_widgets = true; this.isVirtualNode = true; + + if (!this.properties || !(replacePropertyName in this.properties)) { + this.addProperty(replacePropertyName, false, "boolean"); + } } applyToGraph(extraLinks = []) { @@ -430,18 +472,29 @@ app.registerExtension({ } let links = [...get_links(this).map((l) => app.graph.links[l]), ...extraLinks]; + let v = this.widgets?.[0].value; + if(v && this.properties[replacePropertyName]) { + v = applyTextReplacements(app, v); + } + // For each output link copy our value over the original widget value for (const linkInfo of links) { const node = this.graph.getNodeById(linkInfo.target_id); const input = node.inputs[linkInfo.target_slot]; - const widgetName = input.widget.name; - if (widgetName) { - const widget = node.widgets.find((w) => w.name === widgetName); - if (widget) { - widget.value = this.widgets[0].value; - if (widget.callback) { - widget.callback(widget.value, app.canvas, node, app.canvas.graph_mouse, {}); - } + let widget; + if (input.widget[TARGET]) { + widget = input.widget[TARGET]; + } else { + const widgetName = input.widget.name; + if (widgetName) { + widget = node.widgets.find((w) => w.name === widgetName); + } + } + + if (widget) { + widget.value = v; + if (widget.callback) { + widget.callback(widget.value, app.canvas, node, app.canvas.graph_mouse, {}); } } } @@ -494,14 +547,13 @@ app.registerExtension({ this.#mergeWidgetConfig(); if (!links?.length) { - this.#onLastDisconnect(); + this.onLastDisconnect(); } } } onConnectOutput(slot, type, input, target_node, target_slot) { // Fires before the link is made allowing us to reject it if it isn't valid - // No widget, we cant connect if (!input.widget) { if (!(input.type in ComfyWidgets)) return false; @@ -519,6 +571,10 @@ app.registerExtension({ #onFirstConnection(recreating) { // First connection can fire before the graph is ready on initial load so random things can be missing + if (!this.outputs[0].links) { + this.onLastDisconnect(); + return; + } const linkId = this.outputs[0].links[0]; const link = this.graph.links[linkId]; if (!link) return; @@ -546,10 +602,10 @@ app.registerExtension({ this.outputs[0].name = type; this.outputs[0].widget = widget; - this.#createWidget(widget[CONFIG] ?? config, theirNode, widget.name, recreating); + this.#createWidget(widget[CONFIG] ?? config, theirNode, widget.name, recreating, widget[TARGET]); } - #createWidget(inputData, node, widgetName, recreating) { + #createWidget(inputData, node, widgetName, recreating, targetWidget) { let type = inputData[0]; if (type instanceof Array) { @@ -563,7 +619,9 @@ app.registerExtension({ widget = this.addWidget(type, "value", null, () => {}, {}); } - if (node?.widgets && widget) { + if (targetWidget) { + widget.value = targetWidget.value; + } else if (node?.widgets && widget) { const theirWidget = node.widgets.find((w) => w.name === widgetName); if (theirWidget) { widget.value = theirWidget.value; @@ -577,11 +635,19 @@ app.registerExtension({ } addValueControlWidgets(this, widget, control_value, undefined, inputData); let filter = this.widgets_values?.[2]; - if(filter && this.widgets.length === 3) { + if (filter && this.widgets.length === 3) { this.widgets[2].value = filter; } } + // Restore any saved control values + const controlValues = this.controlValues; + if(this.lastType === this.widgets[0].type && controlValues?.length === this.widgets.length - 1) { + for(let i = 0; i < controlValues.length; i++) { + this.widgets[i + 1].value = controlValues[i]; + } + } + // When our value changes, update other widgets to reflect our changes // e.g. so LoadImage shows correct image const callback = widget.callback; @@ -610,12 +676,14 @@ app.registerExtension({ } } - #recreateWidget() { - const values = this.widgets.map((w) => w.value); + recreateWidget() { + const values = this.widgets?.map((w) => w.value); this.#removeWidgets(); this.#onFirstConnection(true); - for (let i = 0; i < this.widgets?.length; i++) this.widgets[i].value = values[i]; - return this.widgets[0]; + if (values?.length) { + for (let i = 0; i < this.widgets?.length; i++) this.widgets[i].value = values[i]; + } + return this.widgets?.[0]; } #mergeWidgetConfig() { @@ -631,7 +699,7 @@ app.registerExtension({ if (links?.length < 2 && hasConfig) { // Copy the widget options from the source if (links.length) { - this.#recreateWidget(); + this.recreateWidget(); } return; @@ -657,7 +725,7 @@ app.registerExtension({ // Only allow connections where the configs match const output = this.outputs[0]; const config2 = input.widget[GET_CONFIG](); - return !!mergeIfValid.call(this, output, config2, forceUpdate, this.#recreateWidget); + return !!mergeIfValid.call(this, output, config2, forceUpdate, this.recreateWidget); } #removeWidgets() { @@ -668,11 +736,20 @@ app.registerExtension({ w.onRemove(); } } + + // Temporarily store the current values in case the node is being recreated + // e.g. by group node conversion + this.controlValues = []; + this.lastType = this.widgets[0]?.type; + for(let i = 1; i < this.widgets.length; i++) { + this.controlValues.push(this.widgets[i].value); + } + setTimeout(() => { delete this.lastType; delete this.controlValues }, 15); this.widgets.length = 0; } } - #onLastDisconnect() { + onLastDisconnect() { // We cant remove + re-add the output here as if you drag a link over the same link // it removes, then re-adds, causing it to break this.outputs[0].type = "*"; diff --git a/web/index.html b/web/index.html index 41bc246c..094db9d1 100644 --- a/web/index.html +++ b/web/index.html @@ -16,5 +16,33 @@ window.graph = app.graph; - + + + + ComfyUI + + + New user: + + + + + OR + + + Existing user: + + Select a user + + + + + + + + +