Browse Source

Merge remote-tracking branch 'comfyanonymous/master' into replace-chainner-models-spandrel

pull/2146/head
Joey Ballentine 8 months ago
parent
commit
8796bde936
  1. 3
      .ci/nightly/update_windows/update_comfyui_and_python_dependencies.bat
  2. 2
      .ci/nightly/windows_base_files/run_nvidia_gpu.bat
  3. 49
      .ci/update_windows/update.py
  4. 8
      .ci/update_windows/update_comfyui.bat
  5. 3
      .ci/update_windows/update_comfyui_and_python_dependencies.bat
  6. 11
      .ci/update_windows_cu118/update_comfyui_and_python_dependencies.bat
  7. 2
      .github/workflows/test-ui.yaml
  8. 71
      .github/workflows/windows_release_cu118_dependencies.yml
  9. 37
      .github/workflows/windows_release_cu118_dependencies_2.yml
  10. 79
      .github/workflows/windows_release_cu118_package.yml
  11. 7
      .github/workflows/windows_release_dependencies.yml
  12. 33
      .github/workflows/windows_release_nightly_pytorch.yml
  13. 2
      .github/workflows/windows_release_package.yml
  14. 1
      .gitignore
  15. 9
      .vscode/settings.json
  16. 17
      README.md
  17. 54
      app/app_settings.py
  18. 140
      app/user_manager.py
  19. 34
      comfy/cldm/cldm.py
  20. 22
      comfy/cli_args.py
  21. 194
      comfy/clip_model.py
  22. 73
      comfy/clip_vision.py
  23. 1
      comfy/conds.py
  24. 143
      comfy/controlnet.py
  25. 11
      comfy/diffusers_convert.py
  26. 1
      comfy/diffusers_load.py
  27. 37
      comfy/extra_samplers/uni_pc.py
  28. 54
      comfy/gligen.py
  29. 2
      comfy/k_diffusion/sampling.py
  30. 69
      comfy/latent_formats.py
  31. 161
      comfy/ldm/cascade/common.py
  32. 93
      comfy/ldm/cascade/controlnet.py
  33. 255
      comfy/ldm/cascade/stage_a.py
  34. 257
      comfy/ldm/cascade/stage_b.py
  35. 274
      comfy/ldm/cascade/stage_c.py
  36. 95
      comfy/ldm/cascade/stage_c_coder.py
  37. 5
      comfy/ldm/models/autoencoder.py
  38. 123
      comfy/ldm/modules/attention.py
  39. 54
      comfy/ldm/modules/diffusionmodules/model.py
  40. 96
      comfy/ldm/modules/diffusionmodules/openaimodel.py
  41. 16
      comfy/ldm/modules/diffusionmodules/upscaling.py
  42. 67
      comfy/ldm/modules/diffusionmodules/util.py
  43. 8
      comfy/ldm/modules/encoders/noise_aug_modules.py
  44. 31
      comfy/ldm/modules/sub_quadratic_attention.py
  45. 13
      comfy/ldm/modules/temporal_ae.py
  46. 46
      comfy/lora.py
  47. 271
      comfy/model_base.py
  48. 82
      comfy/model_detection.py
  49. 316
      comfy/model_management.py
  50. 252
      comfy/model_patcher.py
  51. 98
      comfy/model_sampling.py
  52. 201
      comfy/ops.py
  53. 12
      comfy/sample.py
  54. 532
      comfy/samplers.py
  55. 215
      comfy/sd.py
  56. 150
      comfy/sd1_clip.py
  57. 6
      comfy/sd2_clip.py
  58. 40
      comfy/sdxl_clip.py
  59. 238
      comfy/supported_models.py
  60. 21
      comfy/supported_models_base.py
  61. 5
      comfy/taesd/taesd.py
  62. 62
      comfy/utils.py
  63. 272
      comfy_extras/nodes_canny.py
  64. 25
      comfy_extras/nodes_cond.py
  65. 104
      comfy_extras/nodes_custom_sampler.py
  66. 42
      comfy_extras/nodes_differential_diffusion.py
  67. 10
      comfy_extras/nodes_freelunch.py
  68. 3
      comfy_extras/nodes_hypernetwork.py
  69. 36
      comfy_extras/nodes_hypertile.py
  70. 26
      comfy_extras/nodes_images.py
  71. 49
      comfy_extras/nodes_latent.py
  72. 22
      comfy_extras/nodes_mask.py
  73. 96
      comfy_extras/nodes_model_advanced.py
  74. 129
      comfy_extras/nodes_model_merging.py
  75. 49
      comfy_extras/nodes_morphology.py
  76. 55
      comfy_extras/nodes_perpneg.py
  77. 187
      comfy_extras/nodes_photomaker.py
  78. 3
      comfy_extras/nodes_post_processing.py
  79. 30
      comfy_extras/nodes_rebatch.py
  80. 170
      comfy_extras/nodes_sag.py
  81. 47
      comfy_extras/nodes_sdupscale.py
  82. 143
      comfy_extras/nodes_stable3d.py
  83. 140
      comfy_extras/nodes_stable_cascade.py
  84. 45
      comfy_extras/nodes_video_model.py
  85. 10
      cuda_malloc.py
  86. 16
      custom_nodes/example_node.py.example
  87. 45
      custom_nodes/websocket_image_save.py
  88. 172
      execution.py
  89. 22
      folder_paths.py
  90. 3
      latent_preview.py
  91. 58
      main.py
  92. 0
      models/photomaker/put_photomaker_models_here
  93. 35
      new_updater.py
  94. 212
      nodes.py
  95. 3
      requirements.txt
  96. 159
      script_examples/websockets_api_example_ws_images.py
  97. 58
      server.py
  98. 9
      tests-ui/afterSetup.js
  99. 3
      tests-ui/babel.config.json
  100. 2
      tests-ui/jest.config.js
  101. Some files were not shown because too many files have changed in this diff Show More

3
.ci/nightly/update_windows/update_comfyui_and_python_dependencies.bat

@ -1,3 +0,0 @@
..\python_embeded\python.exe .\update.py ..\ComfyUI\
..\python_embeded\python.exe -s -m pip install --upgrade --pre torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/nightly/cu121 -r ../ComfyUI/requirements.txt pygit2
pause

2
.ci/nightly/windows_base_files/run_nvidia_gpu.bat

@ -1,2 +0,0 @@
.\python_embeded\python.exe -s ComfyUI\main.py --windows-standalone-build --use-pytorch-cross-attention
pause

49
.ci/update_windows/update.py

@ -1,6 +1,9 @@
import pygit2 import pygit2
from datetime import datetime from datetime import datetime
import sys import sys
import os
import shutil
import filecmp
def pull(repo, remote_name='origin', branch='master'): def pull(repo, remote_name='origin', branch='master'):
for remote in repo.remotes: for remote in repo.remotes:
@ -42,7 +45,8 @@ def pull(repo, remote_name='origin', branch='master'):
raise AssertionError('Unknown merge analysis result') raise AssertionError('Unknown merge analysis result')
pygit2.option(pygit2.GIT_OPT_SET_OWNER_VALIDATION, 0) pygit2.option(pygit2.GIT_OPT_SET_OWNER_VALIDATION, 0)
repo = pygit2.Repository(str(sys.argv[1])) repo_path = str(sys.argv[1])
repo = pygit2.Repository(repo_path)
ident = pygit2.Signature('comfyui', 'comfy@ui') ident = pygit2.Signature('comfyui', 'comfy@ui')
try: try:
print("stashing current changes") print("stashing current changes")
@ -51,7 +55,10 @@ except KeyError:
print("nothing to stash") print("nothing to stash")
backup_branch_name = 'backup_branch_{}'.format(datetime.today().strftime('%Y-%m-%d_%H_%M_%S')) backup_branch_name = 'backup_branch_{}'.format(datetime.today().strftime('%Y-%m-%d_%H_%M_%S'))
print("creating backup branch: {}".format(backup_branch_name)) print("creating backup branch: {}".format(backup_branch_name))
repo.branches.local.create(backup_branch_name, repo.head.peel()) try:
repo.branches.local.create(backup_branch_name, repo.head.peel())
except:
pass
print("checking out master branch") print("checking out master branch")
branch = repo.lookup_branch('master') branch = repo.lookup_branch('master')
@ -63,3 +70,41 @@ pull(repo)
print("Done!") print("Done!")
self_update = True
if len(sys.argv) > 2:
self_update = '--skip_self_update' not in sys.argv
update_py_path = os.path.realpath(__file__)
repo_update_py_path = os.path.join(repo_path, ".ci/update_windows/update.py")
cur_path = os.path.dirname(update_py_path)
req_path = os.path.join(cur_path, "current_requirements.txt")
repo_req_path = os.path.join(repo_path, "requirements.txt")
def files_equal(file1, file2):
try:
return filecmp.cmp(file1, file2, shallow=False)
except:
return False
def file_size(f):
try:
return os.path.getsize(f)
except:
return 0
if self_update and not files_equal(update_py_path, repo_update_py_path) and file_size(repo_update_py_path) > 10:
shutil.copy(repo_update_py_path, os.path.join(cur_path, "update_new.py"))
exit()
if not os.path.exists(req_path) or not files_equal(repo_req_path, req_path):
import subprocess
try:
subprocess.check_call([sys.executable, '-s', '-m', 'pip', 'install', '-r', repo_req_path])
shutil.copy(repo_req_path, req_path)
except:
pass

8
.ci/update_windows/update_comfyui.bat

@ -1,2 +1,8 @@
@echo off
..\python_embeded\python.exe .\update.py ..\ComfyUI\ ..\python_embeded\python.exe .\update.py ..\ComfyUI\
pause if exist update_new.py (
move /y update_new.py update.py
echo Running updater again since it got updated.
..\python_embeded\python.exe .\update.py ..\ComfyUI\ --skip_self_update
)
if "%~1"=="" pause

3
.ci/update_windows/update_comfyui_and_python_dependencies.bat

@ -1,3 +0,0 @@
..\python_embeded\python.exe .\update.py ..\ComfyUI\
..\python_embeded\python.exe -s -m pip install --upgrade torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu117 xformers -r ../ComfyUI/requirements.txt pygit2
pause

11
.ci/update_windows_cu118/update_comfyui_and_python_dependencies.bat

@ -1,11 +0,0 @@
@echo off
..\python_embeded\python.exe .\update.py ..\ComfyUI\
echo
echo This will try to update pytorch and all python dependencies, if you get an error wait for pytorch/xformers to fix their stuff
echo You should not be running this anyways unless you really have to
echo
echo If you just want to update normally, close this and run update_comfyui.bat instead.
echo
pause
..\python_embeded\python.exe -s -m pip install --upgrade torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu118 xformers -r ../ComfyUI/requirements.txt pygit2
pause

2
.github/workflows/test-ui.yaml

@ -22,5 +22,5 @@ jobs:
run: | run: |
npm ci npm ci
npm run test:generate npm run test:generate
npm test npm test -- --verbose
working-directory: ./tests-ui working-directory: ./tests-ui

71
.github/workflows/windows_release_cu118_dependencies.yml

@ -1,71 +0,0 @@
name: "Windows Release cu118 dependencies"
on:
workflow_dispatch:
# push:
# branches:
# - master
jobs:
build_dependencies:
env:
# you need at least cuda 5.0 for some of the stuff compiled here.
TORCH_CUDA_ARCH_LIST: "5.0+PTX 6.0 6.1 7.0 7.5 8.0 8.6 8.9"
FORCE_CUDA: 1
MAX_JOBS: 1 # will crash otherwise
DISTUTILS_USE_SDK: 1 # otherwise distutils will complain on windows about multiple versions of msvc
XFORMERS_BUILD_TYPE: "Release"
runs-on: windows-latest
steps:
- name: Cache Built Dependencies
uses: actions/cache@v3
id: cache-cu118_python_stuff
with:
path: cu118_python_deps.tar
key: ${{ runner.os }}-build-cu118
- if: steps.cache-cu118_python_stuff.outputs.cache-hit != 'true'
uses: actions/checkout@v3
- if: steps.cache-cu118_python_stuff.outputs.cache-hit != 'true'
uses: actions/setup-python@v4
with:
python-version: '3.10.9'
- if: steps.cache-cu118_python_stuff.outputs.cache-hit != 'true'
uses: comfyanonymous/cuda-toolkit@test
id: cuda-toolkit
with:
cuda: '11.8.0'
# copied from xformers github
- name: Setup MSVC
uses: ilammy/msvc-dev-cmd@v1
- name: Configure Pagefile
# windows runners will OOM with many CUDA architectures
# we cheat here with a page file
uses: al-cheb/configure-pagefile-action@v1.3
with:
minimum-size: 2GB
# really unfortunate: https://github.com/ilammy/msvc-dev-cmd#name-conflicts-with-shell-bash
- name: Remove link.exe
shell: bash
run: rm /usr/bin/link
- if: steps.cache-cu118_python_stuff.outputs.cache-hit != 'true'
shell: bash
run: |
python -m pip wheel --no-cache-dir torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu118 -r requirements.txt pygit2 -w ./temp_wheel_dir
python -m pip install --no-cache-dir ./temp_wheel_dir/*
echo installed basic
git clone --recurse-submodules https://github.com/facebookresearch/xformers.git
cd xformers
python -m pip install --no-cache-dir wheel setuptools twine
echo building xformers
python setup.py bdist_wheel -d ../temp_wheel_dir/
cd ..
rm -rf xformers
ls -lah temp_wheel_dir
mv temp_wheel_dir cu118_python_deps
tar cf cu118_python_deps.tar cu118_python_deps

37
.github/workflows/windows_release_cu118_dependencies_2.yml

@ -1,37 +0,0 @@
name: "Windows Release cu118 dependencies 2"
on:
workflow_dispatch:
inputs:
xformers:
description: 'xformers version'
required: true
type: string
default: "xformers"
# push:
# branches:
# - master
jobs:
build_dependencies:
runs-on: windows-latest
steps:
- uses: actions/checkout@v3
- uses: actions/setup-python@v4
with:
python-version: '3.10.9'
- shell: bash
run: |
python -m pip wheel --no-cache-dir torch torchvision torchaudio ${{ inputs.xformers }} --extra-index-url https://download.pytorch.org/whl/cu118 -r requirements.txt pygit2 -w ./temp_wheel_dir
python -m pip install --no-cache-dir ./temp_wheel_dir/*
echo installed basic
ls -lah temp_wheel_dir
mv temp_wheel_dir cu118_python_deps
tar cf cu118_python_deps.tar cu118_python_deps
- uses: actions/cache/save@v3
with:
path: cu118_python_deps.tar
key: ${{ runner.os }}-build-cu118

79
.github/workflows/windows_release_cu118_package.yml

@ -1,79 +0,0 @@
name: "Windows Release cu118 packaging"
on:
workflow_dispatch:
# push:
# branches:
# - master
jobs:
package_comfyui:
permissions:
contents: "write"
packages: "write"
pull-requests: "read"
runs-on: windows-latest
steps:
- uses: actions/cache/restore@v3
id: cache
with:
path: cu118_python_deps.tar
key: ${{ runner.os }}-build-cu118
- shell: bash
run: |
mv cu118_python_deps.tar ../
cd ..
tar xf cu118_python_deps.tar
pwd
ls
- uses: actions/checkout@v3
with:
fetch-depth: 0
persist-credentials: false
- shell: bash
run: |
cd ..
cp -r ComfyUI ComfyUI_copy
curl https://www.python.org/ftp/python/3.10.9/python-3.10.9-embed-amd64.zip -o python_embeded.zip
unzip python_embeded.zip -d python_embeded
cd python_embeded
echo 'import site' >> ./python310._pth
curl https://bootstrap.pypa.io/get-pip.py -o get-pip.py
./python.exe get-pip.py
./python.exe -s -m pip install ../cu118_python_deps/*
sed -i '1i../ComfyUI' ./python310._pth
cd ..
git clone https://github.com/comfyanonymous/taesd
cp taesd/*.pth ./ComfyUI_copy/models/vae_approx/
mkdir ComfyUI_windows_portable
mv python_embeded ComfyUI_windows_portable
mv ComfyUI_copy ComfyUI_windows_portable/ComfyUI
cd ComfyUI_windows_portable
mkdir update
cp -r ComfyUI/.ci/update_windows/* ./update/
cp -r ComfyUI/.ci/update_windows_cu118/* ./update/
cp -r ComfyUI/.ci/windows_base_files/* ./
cd ..
"C:\Program Files\7-Zip\7z.exe" a -t7z -m0=lzma -mx=8 -mfb=64 -md=32m -ms=on -mf=BCJ2 ComfyUI_windows_portable.7z ComfyUI_windows_portable
mv ComfyUI_windows_portable.7z ComfyUI/new_ComfyUI_windows_portable_nvidia_cu118_or_cpu.7z
cd ComfyUI_windows_portable
python_embeded/python.exe -s ComfyUI/main.py --quick-test-for-ci --cpu
ls
- name: Upload binaries to release
uses: svenstaro/upload-release-action@v2
with:
repo_token: ${{ secrets.GITHUB_TOKEN }}
file: new_ComfyUI_windows_portable_nvidia_cu118_or_cpu.7z
tag: "latest"
overwrite: true

7
.github/workflows/windows_release_dependencies.yml

@ -24,7 +24,7 @@ on:
description: 'python patch version' description: 'python patch version'
required: true required: true
type: string type: string
default: "6" default: "8"
# push: # push:
# branches: # branches:
# - master # - master
@ -41,10 +41,9 @@ jobs:
- shell: bash - shell: bash
run: | run: |
echo "@echo off echo "@echo off
..\python_embeded\python.exe .\update.py ..\ComfyUI\\ call update_comfyui.bat nopause
echo - echo -
echo This will try to update pytorch and all python dependencies, if you get an error wait for pytorch/xformers to fix their stuff echo This will try to update pytorch and all python dependencies.
echo You should not be running this anyways unless you really have to
echo - echo -
echo If you just want to update normally, close this and run update_comfyui.bat instead. echo If you just want to update normally, close this and run update_comfyui.bat instead.
echo - echo -

33
.github/workflows/windows_release_nightly_pytorch.yml

@ -2,6 +2,24 @@ name: "Windows Release Nightly pytorch"
on: on:
workflow_dispatch: workflow_dispatch:
inputs:
cu:
description: 'cuda version'
required: true
type: string
default: "121"
python_minor:
description: 'python minor version'
required: true
type: string
default: "12"
python_patch:
description: 'python patch version'
required: true
type: string
default: "2"
# push: # push:
# branches: # branches:
# - master # - master
@ -20,21 +38,21 @@ jobs:
persist-credentials: false persist-credentials: false
- uses: actions/setup-python@v4 - uses: actions/setup-python@v4
with: with:
python-version: '3.11.6' python-version: 3.${{ inputs.python_minor }}.${{ inputs.python_patch }}
- shell: bash - shell: bash
run: | run: |
cd .. cd ..
cp -r ComfyUI ComfyUI_copy cp -r ComfyUI ComfyUI_copy
curl https://www.python.org/ftp/python/3.11.6/python-3.11.6-embed-amd64.zip -o python_embeded.zip curl https://www.python.org/ftp/python/3.${{ inputs.python_minor }}.${{ inputs.python_patch }}/python-3.${{ inputs.python_minor }}.${{ inputs.python_patch }}-embed-amd64.zip -o python_embeded.zip
unzip python_embeded.zip -d python_embeded unzip python_embeded.zip -d python_embeded
cd python_embeded cd python_embeded
echo 'import site' >> ./python311._pth echo 'import site' >> ./python3${{ inputs.python_minor }}._pth
curl https://bootstrap.pypa.io/get-pip.py -o get-pip.py curl https://bootstrap.pypa.io/get-pip.py -o get-pip.py
./python.exe get-pip.py ./python.exe get-pip.py
python -m pip wheel torch torchvision torchaudio aiohttp==3.8.5 --pre --extra-index-url https://download.pytorch.org/whl/nightly/cu121 -r ../ComfyUI/requirements.txt pygit2 -w ../temp_wheel_dir python -m pip wheel torch torchvision torchaudio mpmath==1.3.0 --pre --extra-index-url https://download.pytorch.org/whl/nightly/cu${{ inputs.cu }} -r ../ComfyUI/requirements.txt pygit2 -w ../temp_wheel_dir
ls ../temp_wheel_dir ls ../temp_wheel_dir
./python.exe -s -m pip install --pre ../temp_wheel_dir/* ./python.exe -s -m pip install --pre ../temp_wheel_dir/*
sed -i '1i../ComfyUI' ./python311._pth sed -i '1i../ComfyUI' ./python3${{ inputs.python_minor }}._pth
cd .. cd ..
git clone https://github.com/comfyanonymous/taesd git clone https://github.com/comfyanonymous/taesd
@ -49,9 +67,10 @@ jobs:
mkdir update mkdir update
cp -r ComfyUI/.ci/update_windows/* ./update/ cp -r ComfyUI/.ci/update_windows/* ./update/
cp -r ComfyUI/.ci/windows_base_files/* ./ cp -r ComfyUI/.ci/windows_base_files/* ./
cp -r ComfyUI/.ci/nightly/update_windows/* ./update/
cp -r ComfyUI/.ci/nightly/windows_base_files/* ./
echo "call update_comfyui.bat nopause
..\python_embeded\python.exe -s -m pip install --upgrade --pre torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/nightly/cu${{ inputs.cu }} -r ../ComfyUI/requirements.txt pygit2
pause" > ./update/update_comfyui_and_python_dependencies.bat
cd .. cd ..
"C:\Program Files\7-Zip\7z.exe" a -t7z -m0=lzma -mx=8 -mfb=64 -md=32m -ms=on -mf=BCJ2 ComfyUI_windows_portable_nightly_pytorch.7z ComfyUI_windows_portable_nightly_pytorch "C:\Program Files\7-Zip\7z.exe" a -t7z -m0=lzma -mx=8 -mfb=64 -md=32m -ms=on -mf=BCJ2 ComfyUI_windows_portable_nightly_pytorch.7z ComfyUI_windows_portable_nightly_pytorch

2
.github/workflows/windows_release_package.yml

@ -19,7 +19,7 @@ on:
description: 'python patch version' description: 'python patch version'
required: true required: true
type: string type: string
default: "6" default: "8"
# push: # push:
# branches: # branches:
# - master # - master

1
.gitignore vendored

@ -15,3 +15,4 @@ venv/
!/web/extensions/logging.js.example !/web/extensions/logging.js.example
!/web/extensions/core/ !/web/extensions/core/
/tests-ui/data/object_info.json /tests-ui/data/object_info.json
/user/

9
.vscode/settings.json vendored

@ -1,9 +0,0 @@
{
"path-intellisense.mappings": {
"../": "${workspaceFolder}/web/extensions/core"
},
"[python]": {
"editor.defaultFormatter": "ms-python.autopep8"
},
"python.formatting.provider": "none"
}

17
README.md

@ -11,7 +11,7 @@ This ui will let you design and execute advanced stable diffusion pipelines usin
## Features ## Features
- Nodes/graph/flowchart interface to experiment and create complex Stable Diffusion workflows without needing to code anything. - Nodes/graph/flowchart interface to experiment and create complex Stable Diffusion workflows without needing to code anything.
- Fully supports SD1.x, SD2.x, [SDXL](https://comfyanonymous.github.io/ComfyUI_examples/sdxl/) and [Stable Video Diffusion](https://comfyanonymous.github.io/ComfyUI_examples/video/) - Fully supports SD1.x, SD2.x, [SDXL](https://comfyanonymous.github.io/ComfyUI_examples/sdxl/), [Stable Video Diffusion](https://comfyanonymous.github.io/ComfyUI_examples/video/) and [Stable Cascade](https://comfyanonymous.github.io/ComfyUI_examples/stable_cascade/)
- Asynchronous Queue system - Asynchronous Queue system
- Many optimizations: Only re-executes the parts of the workflow that changes between executions. - Many optimizations: Only re-executes the parts of the workflow that changes between executions.
- Command line option: ```--lowvram``` to make it work on GPUs with less than 3GB vram (enabled automatically on GPUs with low vram) - Command line option: ```--lowvram``` to make it work on GPUs with less than 3GB vram (enabled automatically on GPUs with low vram)
@ -77,6 +77,8 @@ There is a portable standalone build for Windows that should work for running on
Simply download, extract with [7-Zip](https://7-zip.org) and run. Make sure you put your Stable Diffusion checkpoints/models (the huge ckpt/safetensors files) in: ComfyUI\models\checkpoints Simply download, extract with [7-Zip](https://7-zip.org) and run. Make sure you put your Stable Diffusion checkpoints/models (the huge ckpt/safetensors files) in: ComfyUI\models\checkpoints
If you have trouble extracting it, right click the file -> properties -> unblock
#### How do I share models between another UI and ComfyUI? #### How do I share models between another UI and ComfyUI?
See the [Config file](extra_model_paths.yaml.example) to set the search paths for models. In the standalone windows build you can find this file in the ComfyUI directory. Rename this file to extra_model_paths.yaml and edit it with your favorite text editor. See the [Config file](extra_model_paths.yaml.example) to set the search paths for models. In the standalone windows build you can find this file in the ComfyUI directory. Rename this file to extra_model_paths.yaml and edit it with your favorite text editor.
@ -93,23 +95,26 @@ Put your SD checkpoints (the huge ckpt/safetensors files) in: models/checkpoints
Put your VAE in: models/vae Put your VAE in: models/vae
Note: pytorch does not support python 3.12 yet so make sure your python version is 3.11 or earlier.
### AMD GPUs (Linux only) ### AMD GPUs (Linux only)
AMD users can install rocm and pytorch with pip if you don't have it already installed, this is the command to install the stable version: AMD users can install rocm and pytorch with pip if you don't have it already installed, this is the command to install the stable version:
```pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/rocm5.6``` ```pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/rocm5.7```
This is the command to install the nightly with ROCm 5.7 that might have some performance improvements: This is the command to install the nightly with ROCm 6.0 which might have some performance improvements:
```pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/rocm5.7``` ```pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/rocm6.0```
### NVIDIA ### NVIDIA
Nvidia users should install pytorch using this command: Nvidia users should install stable pytorch using this command:
```pip install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu121``` ```pip install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu121```
This is the command to install pytorch nightly instead which might have performance improvements:
```pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu121```
#### Troubleshooting #### Troubleshooting
If you get the "Torch not compiled with CUDA enabled" error, uninstall torch with: If you get the "Torch not compiled with CUDA enabled" error, uninstall torch with:

54
app/app_settings.py

@ -0,0 +1,54 @@
import os
import json
from aiohttp import web
class AppSettings():
def __init__(self, user_manager):
self.user_manager = user_manager
def get_settings(self, request):
file = self.user_manager.get_request_user_filepath(
request, "comfy.settings.json")
if os.path.isfile(file):
with open(file) as f:
return json.load(f)
else:
return {}
def save_settings(self, request, settings):
file = self.user_manager.get_request_user_filepath(
request, "comfy.settings.json")
with open(file, "w") as f:
f.write(json.dumps(settings, indent=4))
def add_routes(self, routes):
@routes.get("/settings")
async def get_settings(request):
return web.json_response(self.get_settings(request))
@routes.get("/settings/{id}")
async def get_setting(request):
value = None
settings = self.get_settings(request)
setting_id = request.match_info.get("id", None)
if setting_id and setting_id in settings:
value = settings[setting_id]
return web.json_response(value)
@routes.post("/settings")
async def post_settings(request):
settings = self.get_settings(request)
new_settings = await request.json()
self.save_settings(request, {**settings, **new_settings})
return web.Response(status=200)
@routes.post("/settings/{id}")
async def post_setting(request):
setting_id = request.match_info.get("id", None)
if not setting_id:
return web.Response(status=400)
settings = self.get_settings(request)
settings[setting_id] = await request.json()
self.save_settings(request, settings)
return web.Response(status=200)

140
app/user_manager.py

@ -0,0 +1,140 @@
import json
import os
import re
import uuid
from aiohttp import web
from comfy.cli_args import args
from folder_paths import user_directory
from .app_settings import AppSettings
default_user = "default"
users_file = os.path.join(user_directory, "users.json")
class UserManager():
def __init__(self):
global user_directory
self.settings = AppSettings(self)
if not os.path.exists(user_directory):
os.mkdir(user_directory)
if not args.multi_user:
print("****** User settings have been changed to be stored on the server instead of browser storage. ******")
print("****** For multi-user setups add the --multi-user CLI argument to enable multiple user profiles. ******")
if args.multi_user:
if os.path.isfile(users_file):
with open(users_file) as f:
self.users = json.load(f)
else:
self.users = {}
else:
self.users = {"default": "default"}
def get_request_user_id(self, request):
user = "default"
if args.multi_user and "comfy-user" in request.headers:
user = request.headers["comfy-user"]
if user not in self.users:
raise KeyError("Unknown user: " + user)
return user
def get_request_user_filepath(self, request, file, type="userdata", create_dir=True):
global user_directory
if type == "userdata":
root_dir = user_directory
else:
raise KeyError("Unknown filepath type:" + type)
user = self.get_request_user_id(request)
path = user_root = os.path.abspath(os.path.join(root_dir, user))
# prevent leaving /{type}
if os.path.commonpath((root_dir, user_root)) != root_dir:
return None
parent = user_root
if file is not None:
# prevent leaving /{type}/{user}
path = os.path.abspath(os.path.join(user_root, file))
if os.path.commonpath((user_root, path)) != user_root:
return None
if create_dir and not os.path.exists(parent):
os.mkdir(parent)
return path
def add_user(self, name):
name = name.strip()
if not name:
raise ValueError("username not provided")
user_id = re.sub("[^a-zA-Z0-9-_]+", '-', name)
user_id = user_id + "_" + str(uuid.uuid4())
self.users[user_id] = name
global users_file
with open(users_file, "w") as f:
json.dump(self.users, f)
return user_id
def add_routes(self, routes):
self.settings.add_routes(routes)
@routes.get("/users")
async def get_users(request):
if args.multi_user:
return web.json_response({"storage": "server", "users": self.users})
else:
user_dir = self.get_request_user_filepath(request, None, create_dir=False)
return web.json_response({
"storage": "server",
"migrated": os.path.exists(user_dir)
})
@routes.post("/users")
async def post_users(request):
body = await request.json()
username = body["username"]
if username in self.users.values():
return web.json_response({"error": "Duplicate username."}, status=400)
user_id = self.add_user(username)
return web.json_response(user_id)
@routes.get("/userdata/{file}")
async def getuserdata(request):
file = request.match_info.get("file", None)
if not file:
return web.Response(status=400)
path = self.get_request_user_filepath(request, file)
if not path:
return web.Response(status=403)
if not os.path.exists(path):
return web.Response(status=404)
return web.FileResponse(path)
@routes.post("/userdata/{file}")
async def post_userdata(request):
file = request.match_info.get("file", None)
if not file:
return web.Response(status=400)
path = self.get_request_user_filepath(request, file)
if not path:
return web.Response(status=403)
body = await request.read()
with open(path, "wb") as f:
f.write(body)
return web.Response(status=200)

34
comfy/cldm/cldm.py

@ -53,7 +53,7 @@ class ControlNet(nn.Module):
transformer_depth_middle=None, transformer_depth_middle=None,
transformer_depth_output=None, transformer_depth_output=None,
device=None, device=None,
operations=comfy.ops, operations=comfy.ops.disable_weight_init,
**kwargs, **kwargs,
): ):
super().__init__() super().__init__()
@ -141,24 +141,24 @@ class ControlNet(nn.Module):
) )
] ]
) )
self.zero_convs = nn.ModuleList([self.make_zero_conv(model_channels, operations=operations)]) self.zero_convs = nn.ModuleList([self.make_zero_conv(model_channels, operations=operations, dtype=self.dtype, device=device)])
self.input_hint_block = TimestepEmbedSequential( self.input_hint_block = TimestepEmbedSequential(
operations.conv_nd(dims, hint_channels, 16, 3, padding=1), operations.conv_nd(dims, hint_channels, 16, 3, padding=1, dtype=self.dtype, device=device),
nn.SiLU(), nn.SiLU(),
operations.conv_nd(dims, 16, 16, 3, padding=1), operations.conv_nd(dims, 16, 16, 3, padding=1, dtype=self.dtype, device=device),
nn.SiLU(), nn.SiLU(),
operations.conv_nd(dims, 16, 32, 3, padding=1, stride=2), operations.conv_nd(dims, 16, 32, 3, padding=1, stride=2, dtype=self.dtype, device=device),
nn.SiLU(), nn.SiLU(),
operations.conv_nd(dims, 32, 32, 3, padding=1), operations.conv_nd(dims, 32, 32, 3, padding=1, dtype=self.dtype, device=device),
nn.SiLU(), nn.SiLU(),
operations.conv_nd(dims, 32, 96, 3, padding=1, stride=2), operations.conv_nd(dims, 32, 96, 3, padding=1, stride=2, dtype=self.dtype, device=device),
nn.SiLU(), nn.SiLU(),
operations.conv_nd(dims, 96, 96, 3, padding=1), operations.conv_nd(dims, 96, 96, 3, padding=1, dtype=self.dtype, device=device),
nn.SiLU(), nn.SiLU(),
operations.conv_nd(dims, 96, 256, 3, padding=1, stride=2), operations.conv_nd(dims, 96, 256, 3, padding=1, stride=2, dtype=self.dtype, device=device),
nn.SiLU(), nn.SiLU(),
zero_module(operations.conv_nd(dims, 256, model_channels, 3, padding=1)) operations.conv_nd(dims, 256, model_channels, 3, padding=1, dtype=self.dtype, device=device)
) )
self._feature_size = model_channels self._feature_size = model_channels
@ -206,7 +206,7 @@ class ControlNet(nn.Module):
) )
) )
self.input_blocks.append(TimestepEmbedSequential(*layers)) self.input_blocks.append(TimestepEmbedSequential(*layers))
self.zero_convs.append(self.make_zero_conv(ch, operations=operations)) self.zero_convs.append(self.make_zero_conv(ch, operations=operations, dtype=self.dtype, device=device))
self._feature_size += ch self._feature_size += ch
input_block_chans.append(ch) input_block_chans.append(ch)
if level != len(channel_mult) - 1: if level != len(channel_mult) - 1:
@ -234,7 +234,7 @@ class ControlNet(nn.Module):
) )
ch = out_ch ch = out_ch
input_block_chans.append(ch) input_block_chans.append(ch)
self.zero_convs.append(self.make_zero_conv(ch, operations=operations)) self.zero_convs.append(self.make_zero_conv(ch, operations=operations, dtype=self.dtype, device=device))
ds *= 2 ds *= 2
self._feature_size += ch self._feature_size += ch
@ -276,14 +276,14 @@ class ControlNet(nn.Module):
operations=operations operations=operations
)] )]
self.middle_block = TimestepEmbedSequential(*mid_block) self.middle_block = TimestepEmbedSequential(*mid_block)
self.middle_block_out = self.make_zero_conv(ch, operations=operations) self.middle_block_out = self.make_zero_conv(ch, operations=operations, dtype=self.dtype, device=device)
self._feature_size += ch self._feature_size += ch
def make_zero_conv(self, channels, operations=None): def make_zero_conv(self, channels, operations=None, dtype=None, device=None):
return TimestepEmbedSequential(zero_module(operations.conv_nd(self.dims, channels, channels, 1, padding=0))) return TimestepEmbedSequential(operations.conv_nd(self.dims, channels, channels, 1, padding=0, dtype=dtype, device=device))
def forward(self, x, hint, timesteps, context, y=None, **kwargs): def forward(self, x, hint, timesteps, context, y=None, **kwargs):
t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False).to(self.dtype) t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False).to(x.dtype)
emb = self.time_embed(t_emb) emb = self.time_embed(t_emb)
guided_hint = self.input_hint_block(hint, emb, context) guided_hint = self.input_hint_block(hint, emb, context)
@ -295,7 +295,7 @@ class ControlNet(nn.Module):
assert y.shape[0] == x.shape[0] assert y.shape[0] == x.shape[0]
emb = emb + self.label_emb(y) emb = emb + self.label_emb(y)
h = x.type(self.dtype) h = x
for module, zero_conv in zip(self.input_blocks, self.zero_convs): for module, zero_conv in zip(self.input_blocks, self.zero_convs):
if guided_hint is not None: if guided_hint is not None:
h = module(h, emb, context) h = module(h, emb, context)

22
comfy/cli_args.py

@ -55,13 +55,19 @@ fp_group = parser.add_mutually_exclusive_group()
fp_group.add_argument("--force-fp32", action="store_true", help="Force fp32 (If this makes your GPU work better please report it).") fp_group.add_argument("--force-fp32", action="store_true", help="Force fp32 (If this makes your GPU work better please report it).")
fp_group.add_argument("--force-fp16", action="store_true", help="Force fp16.") fp_group.add_argument("--force-fp16", action="store_true", help="Force fp16.")
parser.add_argument("--bf16-unet", action="store_true", help="Run the UNET in bf16. This should only be used for testing stuff.") fpunet_group = parser.add_mutually_exclusive_group()
fpunet_group.add_argument("--bf16-unet", action="store_true", help="Run the UNET in bf16. This should only be used for testing stuff.")
fpunet_group.add_argument("--fp16-unet", action="store_true", help="Store unet weights in fp16.")
fpunet_group.add_argument("--fp8_e4m3fn-unet", action="store_true", help="Store unet weights in fp8_e4m3fn.")
fpunet_group.add_argument("--fp8_e5m2-unet", action="store_true", help="Store unet weights in fp8_e5m2.")
fpvae_group = parser.add_mutually_exclusive_group() fpvae_group = parser.add_mutually_exclusive_group()
fpvae_group.add_argument("--fp16-vae", action="store_true", help="Run the VAE in fp16, might cause black images.") fpvae_group.add_argument("--fp16-vae", action="store_true", help="Run the VAE in fp16, might cause black images.")
fpvae_group.add_argument("--fp32-vae", action="store_true", help="Run the VAE in full precision fp32.") fpvae_group.add_argument("--fp32-vae", action="store_true", help="Run the VAE in full precision fp32.")
fpvae_group.add_argument("--bf16-vae", action="store_true", help="Run the VAE in bf16.") fpvae_group.add_argument("--bf16-vae", action="store_true", help="Run the VAE in bf16.")
parser.add_argument("--cpu-vae", action="store_true", help="Run the VAE on the CPU.")
fpte_group = parser.add_mutually_exclusive_group() fpte_group = parser.add_mutually_exclusive_group()
fpte_group.add_argument("--fp8_e4m3fn-text-enc", action="store_true", help="Store text encoder weights in fp8 (e4m3fn variant).") fpte_group.add_argument("--fp8_e4m3fn-text-enc", action="store_true", help="Store text encoder weights in fp8 (e4m3fn variant).")
fpte_group.add_argument("--fp8_e5m2-text-enc", action="store_true", help="Store text encoder weights in fp8 (e5m2 variant).") fpte_group.add_argument("--fp8_e5m2-text-enc", action="store_true", help="Store text encoder weights in fp8 (e5m2 variant).")
@ -98,7 +104,7 @@ vram_group.add_argument("--cpu", action="store_true", help="To use the CPU for e
parser.add_argument("--disable-smart-memory", action="store_true", help="Force ComfyUI to agressively offload to regular ram instead of keeping models in vram when it can.") parser.add_argument("--disable-smart-memory", action="store_true", help="Force ComfyUI to agressively offload to regular ram instead of keeping models in vram when it can.")
parser.add_argument("--deterministic", action="store_true", help="Make pytorch use slower deterministic algorithms when it can. Note that this might not make images deterministic in all cases.")
parser.add_argument("--dont-print-server", action="store_true", help="Don't print server output.") parser.add_argument("--dont-print-server", action="store_true", help="Don't print server output.")
parser.add_argument("--quick-test-for-ci", action="store_true", help="Quick test for CI.") parser.add_argument("--quick-test-for-ci", action="store_true", help="Quick test for CI.")
@ -106,6 +112,11 @@ parser.add_argument("--windows-standalone-build", action="store_true", help="Win
parser.add_argument("--disable-metadata", action="store_true", help="Disable saving prompt metadata in files.") parser.add_argument("--disable-metadata", action="store_true", help="Disable saving prompt metadata in files.")
parser.add_argument("--multi-user", action="store_true", help="Enables per-user storage.")
parser.add_argument("--verbose", action="store_true", help="Enables more debug prints.")
if comfy.options.args_parsing: if comfy.options.args_parsing:
args = parser.parse_args() args = parser.parse_args()
else: else:
@ -116,3 +127,10 @@ if args.windows_standalone_build:
if args.disable_auto_launch: if args.disable_auto_launch:
args.auto_launch = False args.auto_launch = False
import logging
logging_level = logging.INFO
if args.verbose:
logging_level = logging.DEBUG
logging.basicConfig(format="%(message)s", level=logging_level)

194
comfy/clip_model.py

@ -0,0 +1,194 @@
import torch
from comfy.ldm.modules.attention import optimized_attention_for_device
class CLIPAttention(torch.nn.Module):
def __init__(self, embed_dim, heads, dtype, device, operations):
super().__init__()
self.heads = heads
self.q_proj = operations.Linear(embed_dim, embed_dim, bias=True, dtype=dtype, device=device)
self.k_proj = operations.Linear(embed_dim, embed_dim, bias=True, dtype=dtype, device=device)
self.v_proj = operations.Linear(embed_dim, embed_dim, bias=True, dtype=dtype, device=device)
self.out_proj = operations.Linear(embed_dim, embed_dim, bias=True, dtype=dtype, device=device)
def forward(self, x, mask=None, optimized_attention=None):
q = self.q_proj(x)
k = self.k_proj(x)
v = self.v_proj(x)
out = optimized_attention(q, k, v, self.heads, mask)
return self.out_proj(out)
ACTIVATIONS = {"quick_gelu": lambda a: a * torch.sigmoid(1.702 * a),
"gelu": torch.nn.functional.gelu,
}
class CLIPMLP(torch.nn.Module):
def __init__(self, embed_dim, intermediate_size, activation, dtype, device, operations):
super().__init__()
self.fc1 = operations.Linear(embed_dim, intermediate_size, bias=True, dtype=dtype, device=device)
self.activation = ACTIVATIONS[activation]
self.fc2 = operations.Linear(intermediate_size, embed_dim, bias=True, dtype=dtype, device=device)
def forward(self, x):
x = self.fc1(x)
x = self.activation(x)
x = self.fc2(x)
return x
class CLIPLayer(torch.nn.Module):
def __init__(self, embed_dim, heads, intermediate_size, intermediate_activation, dtype, device, operations):
super().__init__()
self.layer_norm1 = operations.LayerNorm(embed_dim, dtype=dtype, device=device)
self.self_attn = CLIPAttention(embed_dim, heads, dtype, device, operations)
self.layer_norm2 = operations.LayerNorm(embed_dim, dtype=dtype, device=device)
self.mlp = CLIPMLP(embed_dim, intermediate_size, intermediate_activation, dtype, device, operations)
def forward(self, x, mask=None, optimized_attention=None):
x += self.self_attn(self.layer_norm1(x), mask, optimized_attention)
x += self.mlp(self.layer_norm2(x))
return x
class CLIPEncoder(torch.nn.Module):
def __init__(self, num_layers, embed_dim, heads, intermediate_size, intermediate_activation, dtype, device, operations):
super().__init__()
self.layers = torch.nn.ModuleList([CLIPLayer(embed_dim, heads, intermediate_size, intermediate_activation, dtype, device, operations) for i in range(num_layers)])
def forward(self, x, mask=None, intermediate_output=None):
optimized_attention = optimized_attention_for_device(x.device, mask=mask is not None, small_input=True)
if intermediate_output is not None:
if intermediate_output < 0:
intermediate_output = len(self.layers) + intermediate_output
intermediate = None
for i, l in enumerate(self.layers):
x = l(x, mask, optimized_attention)
if i == intermediate_output:
intermediate = x.clone()
return x, intermediate
class CLIPEmbeddings(torch.nn.Module):
def __init__(self, embed_dim, vocab_size=49408, num_positions=77, dtype=None, device=None):
super().__init__()
self.token_embedding = torch.nn.Embedding(vocab_size, embed_dim, dtype=dtype, device=device)
self.position_embedding = torch.nn.Embedding(num_positions, embed_dim, dtype=dtype, device=device)
def forward(self, input_tokens):
return self.token_embedding(input_tokens) + self.position_embedding.weight
class CLIPTextModel_(torch.nn.Module):
def __init__(self, config_dict, dtype, device, operations):
num_layers = config_dict["num_hidden_layers"]
embed_dim = config_dict["hidden_size"]
heads = config_dict["num_attention_heads"]
intermediate_size = config_dict["intermediate_size"]
intermediate_activation = config_dict["hidden_act"]
super().__init__()
self.embeddings = CLIPEmbeddings(embed_dim, dtype=torch.float32, device=device)
self.encoder = CLIPEncoder(num_layers, embed_dim, heads, intermediate_size, intermediate_activation, dtype, device, operations)
self.final_layer_norm = operations.LayerNorm(embed_dim, dtype=dtype, device=device)
def forward(self, input_tokens, attention_mask=None, intermediate_output=None, final_layer_norm_intermediate=True):
x = self.embeddings(input_tokens)
mask = None
if attention_mask is not None:
mask = 1.0 - attention_mask.to(x.dtype).reshape((attention_mask.shape[0], 1, -1, attention_mask.shape[-1])).expand(attention_mask.shape[0], 1, attention_mask.shape[-1], attention_mask.shape[-1])
mask = mask.masked_fill(mask.to(torch.bool), float("-inf"))
causal_mask = torch.empty(x.shape[1], x.shape[1], dtype=x.dtype, device=x.device).fill_(float("-inf")).triu_(1)
if mask is not None:
mask += causal_mask
else:
mask = causal_mask
x, i = self.encoder(x, mask=mask, intermediate_output=intermediate_output)
x = self.final_layer_norm(x)
if i is not None and final_layer_norm_intermediate:
i = self.final_layer_norm(i)
pooled_output = x[torch.arange(x.shape[0], device=x.device), input_tokens.to(dtype=torch.int, device=x.device).argmax(dim=-1),]
return x, i, pooled_output
class CLIPTextModel(torch.nn.Module):
def __init__(self, config_dict, dtype, device, operations):
super().__init__()
self.num_layers = config_dict["num_hidden_layers"]
self.text_model = CLIPTextModel_(config_dict, dtype, device, operations)
embed_dim = config_dict["hidden_size"]
self.text_projection = operations.Linear(embed_dim, embed_dim, bias=False, dtype=dtype, device=device)
self.text_projection.weight.copy_(torch.eye(embed_dim))
self.dtype = dtype
def get_input_embeddings(self):
return self.text_model.embeddings.token_embedding
def set_input_embeddings(self, embeddings):
self.text_model.embeddings.token_embedding = embeddings
def forward(self, *args, **kwargs):
x = self.text_model(*args, **kwargs)
out = self.text_projection(x[2])
return (x[0], x[1], out, x[2])
class CLIPVisionEmbeddings(torch.nn.Module):
def __init__(self, embed_dim, num_channels=3, patch_size=14, image_size=224, dtype=None, device=None, operations=None):
super().__init__()
self.class_embedding = torch.nn.Parameter(torch.empty(embed_dim, dtype=dtype, device=device))
self.patch_embedding = operations.Conv2d(
in_channels=num_channels,
out_channels=embed_dim,
kernel_size=patch_size,
stride=patch_size,
bias=False,
dtype=dtype,
device=device
)
num_patches = (image_size // patch_size) ** 2
num_positions = num_patches + 1
self.position_embedding = torch.nn.Embedding(num_positions, embed_dim, dtype=dtype, device=device)
def forward(self, pixel_values):
embeds = self.patch_embedding(pixel_values).flatten(2).transpose(1, 2)
return torch.cat([self.class_embedding.to(embeds.device).expand(pixel_values.shape[0], 1, -1), embeds], dim=1) + self.position_embedding.weight.to(embeds.device)
class CLIPVision(torch.nn.Module):
def __init__(self, config_dict, dtype, device, operations):
super().__init__()
num_layers = config_dict["num_hidden_layers"]
embed_dim = config_dict["hidden_size"]
heads = config_dict["num_attention_heads"]
intermediate_size = config_dict["intermediate_size"]
intermediate_activation = config_dict["hidden_act"]
self.embeddings = CLIPVisionEmbeddings(embed_dim, config_dict["num_channels"], config_dict["patch_size"], config_dict["image_size"], dtype=torch.float32, device=device, operations=operations)
self.pre_layrnorm = operations.LayerNorm(embed_dim)
self.encoder = CLIPEncoder(num_layers, embed_dim, heads, intermediate_size, intermediate_activation, dtype, device, operations)
self.post_layernorm = operations.LayerNorm(embed_dim)
def forward(self, pixel_values, attention_mask=None, intermediate_output=None):
x = self.embeddings(pixel_values)
x = self.pre_layrnorm(x)
#TODO: attention_mask?
x, i = self.encoder(x, mask=None, intermediate_output=intermediate_output)
pooled_output = self.post_layernorm(x[:, 0, :])
return x, i, pooled_output
class CLIPVisionModelProjection(torch.nn.Module):
def __init__(self, config_dict, dtype, device, operations):
super().__init__()
self.vision_model = CLIPVision(config_dict, dtype, device, operations)
self.visual_projection = operations.Linear(config_dict["hidden_size"], config_dict["projection_dim"], bias=False)
def forward(self, *args, **kwargs):
x = self.vision_model(*args, **kwargs)
out = self.visual_projection(x[2])
return (x[0], x[1], out)

73
comfy/clip_vision.py

@ -1,64 +1,62 @@
from transformers import CLIPVisionModelWithProjection, CLIPVisionConfig, modeling_utils from .utils import load_torch_file, transformers_convert, state_dict_prefix_replace
from .utils import load_torch_file, transformers_convert, common_upscale
import os import os
import torch import torch
import contextlib import json
import logging
import comfy.ops import comfy.ops
import comfy.model_patcher import comfy.model_patcher
import comfy.model_management import comfy.model_management
import comfy.utils import comfy.utils
import comfy.clip_model
class Output:
def __getitem__(self, key):
return getattr(self, key)
def __setitem__(self, key, item):
setattr(self, key, item)
def clip_preprocess(image, size=224): def clip_preprocess(image, size=224):
mean = torch.tensor([ 0.48145466,0.4578275,0.40821073], device=image.device, dtype=image.dtype) mean = torch.tensor([ 0.48145466,0.4578275,0.40821073], device=image.device, dtype=image.dtype)
std = torch.tensor([0.26862954,0.26130258,0.27577711], device=image.device, dtype=image.dtype) std = torch.tensor([0.26862954,0.26130258,0.27577711], device=image.device, dtype=image.dtype)
scale = (size / min(image.shape[1], image.shape[2])) image = image.movedim(-1, 1)
image = torch.nn.functional.interpolate(image.movedim(-1, 1), size=(round(scale * image.shape[1]), round(scale * image.shape[2])), mode="bicubic", antialias=True) if not (image.shape[2] == size and image.shape[3] == size):
h = (image.shape[2] - size)//2 scale = (size / min(image.shape[2], image.shape[3]))
w = (image.shape[3] - size)//2 image = torch.nn.functional.interpolate(image, size=(round(scale * image.shape[2]), round(scale * image.shape[3])), mode="bicubic", antialias=True)
image = image[:,:,h:h+size,w:w+size] h = (image.shape[2] - size)//2
w = (image.shape[3] - size)//2
image = image[:,:,h:h+size,w:w+size]
image = torch.clip((255. * image), 0, 255).round() / 255.0 image = torch.clip((255. * image), 0, 255).round() / 255.0
return (image - mean.view([3,1,1])) / std.view([3,1,1]) return (image - mean.view([3,1,1])) / std.view([3,1,1])
class ClipVisionModel(): class ClipVisionModel():
def __init__(self, json_config): def __init__(self, json_config):
config = CLIPVisionConfig.from_json_file(json_config) with open(json_config) as f:
config = json.load(f)
self.load_device = comfy.model_management.text_encoder_device() self.load_device = comfy.model_management.text_encoder_device()
offload_device = comfy.model_management.text_encoder_offload_device() offload_device = comfy.model_management.text_encoder_offload_device()
self.dtype = torch.float32 self.dtype = comfy.model_management.text_encoder_dtype(self.load_device)
if comfy.model_management.should_use_fp16(self.load_device, prioritize_performance=False): self.model = comfy.clip_model.CLIPVisionModelProjection(config, self.dtype, offload_device, comfy.ops.manual_cast)
self.dtype = torch.float16 self.model.eval()
with comfy.ops.use_comfy_ops(offload_device, self.dtype):
with modeling_utils.no_init_weights():
self.model = CLIPVisionModelWithProjection(config)
self.model.to(self.dtype)
self.patcher = comfy.model_patcher.ModelPatcher(self.model, load_device=self.load_device, offload_device=offload_device) self.patcher = comfy.model_patcher.ModelPatcher(self.model, load_device=self.load_device, offload_device=offload_device)
def load_sd(self, sd): def load_sd(self, sd):
return self.model.load_state_dict(sd, strict=False) return self.model.load_state_dict(sd, strict=False)
def get_sd(self):
return self.model.state_dict()
def encode_image(self, image): def encode_image(self, image):
comfy.model_management.load_model_gpu(self.patcher) comfy.model_management.load_model_gpu(self.patcher)
pixel_values = clip_preprocess(image.to(self.load_device)) pixel_values = clip_preprocess(image.to(self.load_device)).float()
out = self.model(pixel_values=pixel_values, intermediate_output=-2)
if self.dtype != torch.float32:
precision_scope = torch.autocast
else:
precision_scope = lambda a, b: contextlib.nullcontext(a)
with precision_scope(comfy.model_management.get_autocast_device(self.load_device), torch.float32):
outputs = self.model(pixel_values=pixel_values, output_hidden_states=True)
for k in outputs:
t = outputs[k]
if t is not None:
if k == 'hidden_states':
outputs["penultimate_hidden_states"] = t[-2].cpu()
outputs["hidden_states"] = None
else:
outputs[k] = t.cpu()
outputs = Output()
outputs["last_hidden_state"] = out[0].to(comfy.model_management.intermediate_device())
outputs["image_embeds"] = out[2].to(comfy.model_management.intermediate_device())
outputs["penultimate_hidden_states"] = out[1].to(comfy.model_management.intermediate_device())
return outputs return outputs
def convert_to_transformers(sd, prefix): def convert_to_transformers(sd, prefix):
@ -82,6 +80,9 @@ def convert_to_transformers(sd, prefix):
sd['visual_projection.weight'] = sd.pop("{}proj".format(prefix)).transpose(0, 1) sd['visual_projection.weight'] = sd.pop("{}proj".format(prefix)).transpose(0, 1)
sd = transformers_convert(sd, prefix, "vision_model.", 48) sd = transformers_convert(sd, prefix, "vision_model.", 48)
else:
replace_prefix = {prefix: ""}
sd = state_dict_prefix_replace(sd, replace_prefix)
return sd return sd
def load_clipvision_from_sd(sd, prefix="", convert_keys=False): def load_clipvision_from_sd(sd, prefix="", convert_keys=False):
@ -99,7 +100,7 @@ def load_clipvision_from_sd(sd, prefix="", convert_keys=False):
clip = ClipVisionModel(json_config) clip = ClipVisionModel(json_config)
m, u = clip.load_sd(sd) m, u = clip.load_sd(sd)
if len(m) > 0: if len(m) > 0:
print("missing clip vision:", m) logging.warning("missing clip vision: {}".format(m))
u = set(u) u = set(u)
keys = list(sd.keys()) keys = list(sd.keys())
for k in keys: for k in keys:

1
comfy/conds.py

@ -1,4 +1,3 @@
import enum
import torch import torch
import math import math
import comfy.utils import comfy.utils

143
comfy/controlnet.py

@ -1,13 +1,16 @@
import torch import torch
import math import math
import os import os
import logging
import comfy.utils import comfy.utils
import comfy.model_management import comfy.model_management
import comfy.model_detection import comfy.model_detection
import comfy.model_patcher import comfy.model_patcher
import comfy.ops
import comfy.cldm.cldm import comfy.cldm.cldm
import comfy.t2i_adapter.adapter import comfy.t2i_adapter.adapter
import comfy.ldm.cascade.controlnet
def broadcast_image_to(tensor, target_batch_size, batched_number): def broadcast_image_to(tensor, target_batch_size, batched_number):
@ -34,13 +37,15 @@ class ControlBase:
self.cond_hint = None self.cond_hint = None
self.strength = 1.0 self.strength = 1.0
self.timestep_percent_range = (0.0, 1.0) self.timestep_percent_range = (0.0, 1.0)
self.global_average_pooling = False
self.timestep_range = None self.timestep_range = None
self.compression_ratio = 8
self.upscale_algorithm = 'nearest-exact'
if device is None: if device is None:
device = comfy.model_management.get_torch_device() device = comfy.model_management.get_torch_device()
self.device = device self.device = device
self.previous_controlnet = None self.previous_controlnet = None
self.global_average_pooling = False
def set_cond_hint(self, cond_hint, strength=1.0, timestep_percent_range=(0.0, 1.0)): def set_cond_hint(self, cond_hint, strength=1.0, timestep_percent_range=(0.0, 1.0)):
self.cond_hint_original = cond_hint self.cond_hint_original = cond_hint
@ -75,6 +80,9 @@ class ControlBase:
c.cond_hint_original = self.cond_hint_original c.cond_hint_original = self.cond_hint_original
c.strength = self.strength c.strength = self.strength
c.timestep_percent_range = self.timestep_percent_range c.timestep_percent_range = self.timestep_percent_range
c.global_average_pooling = self.global_average_pooling
c.compression_ratio = self.compression_ratio
c.upscale_algorithm = self.upscale_algorithm
def inference_memory_requirements(self, dtype): def inference_memory_requirements(self, dtype):
if self.previous_controlnet is not None: if self.previous_controlnet is not None:
@ -123,16 +131,21 @@ class ControlBase:
if o[i] is None: if o[i] is None:
o[i] = prev_val o[i] = prev_val
else: else:
o[i] += prev_val if o[i].shape[0] < prev_val.shape[0]:
o[i] = prev_val + o[i]
else:
o[i] += prev_val
return out return out
class ControlNet(ControlBase): class ControlNet(ControlBase):
def __init__(self, control_model, global_average_pooling=False, device=None): def __init__(self, control_model, global_average_pooling=False, device=None, load_device=None, manual_cast_dtype=None):
super().__init__(device) super().__init__(device)
self.control_model = control_model self.control_model = control_model
self.control_model_wrapped = comfy.model_patcher.ModelPatcher(self.control_model, load_device=comfy.model_management.get_torch_device(), offload_device=comfy.model_management.unet_offload_device()) self.load_device = load_device
self.control_model_wrapped = comfy.model_patcher.ModelPatcher(self.control_model, load_device=load_device, offload_device=comfy.model_management.unet_offload_device())
self.global_average_pooling = global_average_pooling self.global_average_pooling = global_average_pooling
self.model_sampling_current = None self.model_sampling_current = None
self.manual_cast_dtype = manual_cast_dtype
def get_control(self, x_noisy, t, cond, batched_number): def get_control(self, x_noisy, t, cond, batched_number):
control_prev = None control_prev = None
@ -146,28 +159,31 @@ class ControlNet(ControlBase):
else: else:
return None return None
dtype = self.control_model.dtype
if self.manual_cast_dtype is not None:
dtype = self.manual_cast_dtype
output_dtype = x_noisy.dtype output_dtype = x_noisy.dtype
if self.cond_hint is None or x_noisy.shape[2] * 8 != self.cond_hint.shape[2] or x_noisy.shape[3] * 8 != self.cond_hint.shape[3]: if self.cond_hint is None or x_noisy.shape[2] * self.compression_ratio != self.cond_hint.shape[2] or x_noisy.shape[3] * self.compression_ratio != self.cond_hint.shape[3]:
if self.cond_hint is not None: if self.cond_hint is not None:
del self.cond_hint del self.cond_hint
self.cond_hint = None self.cond_hint = None
self.cond_hint = comfy.utils.common_upscale(self.cond_hint_original, x_noisy.shape[3] * 8, x_noisy.shape[2] * 8, 'nearest-exact', "center").to(self.control_model.dtype).to(self.device) self.cond_hint = comfy.utils.common_upscale(self.cond_hint_original, x_noisy.shape[3] * self.compression_ratio, x_noisy.shape[2] * self.compression_ratio, self.upscale_algorithm, "center").to(dtype).to(self.device)
if x_noisy.shape[0] != self.cond_hint.shape[0]: if x_noisy.shape[0] != self.cond_hint.shape[0]:
self.cond_hint = broadcast_image_to(self.cond_hint, x_noisy.shape[0], batched_number) self.cond_hint = broadcast_image_to(self.cond_hint, x_noisy.shape[0], batched_number)
context = cond.get('crossattn_controlnet', cond['c_crossattn'])
context = cond['c_crossattn']
y = cond.get('y', None) y = cond.get('y', None)
if y is not None: if y is not None:
y = y.to(self.control_model.dtype) y = y.to(dtype)
timestep = self.model_sampling_current.timestep(t) timestep = self.model_sampling_current.timestep(t)
x_noisy = self.model_sampling_current.calculate_input(t, x_noisy) x_noisy = self.model_sampling_current.calculate_input(t, x_noisy)
control = self.control_model(x=x_noisy.to(self.control_model.dtype), hint=self.cond_hint, timesteps=timestep.float(), context=context.to(self.control_model.dtype), y=y) control = self.control_model(x=x_noisy.to(dtype), hint=self.cond_hint, timesteps=timestep.float(), context=context.to(dtype), y=y)
return self.control_merge(None, control, control_prev, output_dtype) return self.control_merge(None, control, control_prev, output_dtype)
def copy(self): def copy(self):
c = ControlNet(self.control_model, global_average_pooling=self.global_average_pooling) c = ControlNet(self.control_model, global_average_pooling=self.global_average_pooling, load_device=self.load_device, manual_cast_dtype=self.manual_cast_dtype)
self.copy_to(c) self.copy_to(c)
return c return c
@ -185,7 +201,7 @@ class ControlNet(ControlBase):
super().cleanup() super().cleanup()
class ControlLoraOps: class ControlLoraOps:
class Linear(torch.nn.Module): class Linear(torch.nn.Module, comfy.ops.CastWeightBiasOp):
def __init__(self, in_features: int, out_features: int, bias: bool = True, def __init__(self, in_features: int, out_features: int, bias: bool = True,
device=None, dtype=None) -> None: device=None, dtype=None) -> None:
factory_kwargs = {'device': device, 'dtype': dtype} factory_kwargs = {'device': device, 'dtype': dtype}
@ -198,12 +214,13 @@ class ControlLoraOps:
self.bias = None self.bias = None
def forward(self, input): def forward(self, input):
weight, bias = comfy.ops.cast_bias_weight(self, input)
if self.up is not None: if self.up is not None:
return torch.nn.functional.linear(input, self.weight.to(input.device) + (torch.mm(self.up.flatten(start_dim=1), self.down.flatten(start_dim=1))).reshape(self.weight.shape).type(input.dtype), self.bias) return torch.nn.functional.linear(input, weight + (torch.mm(self.up.flatten(start_dim=1), self.down.flatten(start_dim=1))).reshape(self.weight.shape).type(input.dtype), bias)
else: else:
return torch.nn.functional.linear(input, self.weight.to(input.device), self.bias) return torch.nn.functional.linear(input, weight, bias)
class Conv2d(torch.nn.Module): class Conv2d(torch.nn.Module, comfy.ops.CastWeightBiasOp):
def __init__( def __init__(
self, self,
in_channels, in_channels,
@ -237,16 +254,11 @@ class ControlLoraOps:
def forward(self, input): def forward(self, input):
weight, bias = comfy.ops.cast_bias_weight(self, input)
if self.up is not None: if self.up is not None:
return torch.nn.functional.conv2d(input, self.weight.to(input.device) + (torch.mm(self.up.flatten(start_dim=1), self.down.flatten(start_dim=1))).reshape(self.weight.shape).type(input.dtype), self.bias, self.stride, self.padding, self.dilation, self.groups) return torch.nn.functional.conv2d(input, weight + (torch.mm(self.up.flatten(start_dim=1), self.down.flatten(start_dim=1))).reshape(self.weight.shape).type(input.dtype), bias, self.stride, self.padding, self.dilation, self.groups)
else: else:
return torch.nn.functional.conv2d(input, self.weight.to(input.device), self.bias, self.stride, self.padding, self.dilation, self.groups) return torch.nn.functional.conv2d(input, weight, bias, self.stride, self.padding, self.dilation, self.groups)
def conv_nd(self, dims, *args, **kwargs):
if dims == 2:
return self.Conv2d(*args, **kwargs)
else:
raise ValueError(f"unsupported dimensions: {dims}")
class ControlLora(ControlNet): class ControlLora(ControlNet):
@ -260,25 +272,34 @@ class ControlLora(ControlNet):
controlnet_config = model.model_config.unet_config.copy() controlnet_config = model.model_config.unet_config.copy()
controlnet_config.pop("out_channels") controlnet_config.pop("out_channels")
controlnet_config["hint_channels"] = self.control_weights["input_hint_block.0.weight"].shape[1] controlnet_config["hint_channels"] = self.control_weights["input_hint_block.0.weight"].shape[1]
controlnet_config["operations"] = ControlLoraOps() self.manual_cast_dtype = model.manual_cast_dtype
self.control_model = comfy.cldm.cldm.ControlNet(**controlnet_config)
dtype = model.get_dtype() dtype = model.get_dtype()
self.control_model.to(dtype) if self.manual_cast_dtype is None:
class control_lora_ops(ControlLoraOps, comfy.ops.disable_weight_init):
pass
else:
class control_lora_ops(ControlLoraOps, comfy.ops.manual_cast):
pass
dtype = self.manual_cast_dtype
controlnet_config["operations"] = control_lora_ops
controlnet_config["dtype"] = dtype
self.control_model = comfy.cldm.cldm.ControlNet(**controlnet_config)
self.control_model.to(comfy.model_management.get_torch_device()) self.control_model.to(comfy.model_management.get_torch_device())
diffusion_model = model.diffusion_model diffusion_model = model.diffusion_model
sd = diffusion_model.state_dict() sd = diffusion_model.state_dict()
cm = self.control_model.state_dict() cm = self.control_model.state_dict()
for k in sd: for k in sd:
weight = comfy.model_management.resolve_lowvram_weight(sd[k], diffusion_model, k) weight = sd[k]
try: try:
comfy.utils.set_attr(self.control_model, k, weight) comfy.utils.set_attr_param(self.control_model, k, weight)
except: except:
pass pass
for k in self.control_weights: for k in self.control_weights:
if k not in {"lora_controlnet"}: if k not in {"lora_controlnet"}:
comfy.utils.set_attr(self.control_model, k, self.control_weights[k].to(dtype).to(comfy.model_management.get_torch_device())) comfy.utils.set_attr_param(self.control_model, k, self.control_weights[k].to(dtype).to(comfy.model_management.get_torch_device()))
def copy(self): def copy(self):
c = ControlLora(self.control_weights, global_average_pooling=self.global_average_pooling) c = ControlLora(self.control_weights, global_average_pooling=self.global_average_pooling)
@ -303,9 +324,10 @@ def load_controlnet(ckpt_path, model=None):
return ControlLora(controlnet_data) return ControlLora(controlnet_data)
controlnet_config = None controlnet_config = None
supported_inference_dtypes = None
if "controlnet_cond_embedding.conv_in.weight" in controlnet_data: #diffusers format if "controlnet_cond_embedding.conv_in.weight" in controlnet_data: #diffusers format
unet_dtype = comfy.model_management.unet_dtype() controlnet_config = comfy.model_detection.unet_config_from_diffusers_unet(controlnet_data)
controlnet_config = comfy.model_detection.unet_config_from_diffusers_unet(controlnet_data, unet_dtype)
diffusers_keys = comfy.utils.unet_to_diffusers(controlnet_config) diffusers_keys = comfy.utils.unet_to_diffusers(controlnet_config)
diffusers_keys["controlnet_mid_block.weight"] = "middle_block_out.0.weight" diffusers_keys["controlnet_mid_block.weight"] = "middle_block_out.0.weight"
diffusers_keys["controlnet_mid_block.bias"] = "middle_block_out.0.bias" diffusers_keys["controlnet_mid_block.bias"] = "middle_block_out.0.bias"
@ -346,7 +368,7 @@ def load_controlnet(ckpt_path, model=None):
leftover_keys = controlnet_data.keys() leftover_keys = controlnet_data.keys()
if len(leftover_keys) > 0: if len(leftover_keys) > 0:
print("leftover keys:", leftover_keys) logging.warning("leftover keys: {}".format(leftover_keys))
controlnet_data = new_sd controlnet_data = new_sd
pth_key = 'control_model.zero_convs.0.0.weight' pth_key = 'control_model.zero_convs.0.0.weight'
@ -361,12 +383,24 @@ def load_controlnet(ckpt_path, model=None):
else: else:
net = load_t2i_adapter(controlnet_data) net = load_t2i_adapter(controlnet_data)
if net is None: if net is None:
print("error checkpoint does not contain controlnet or t2i adapter data", ckpt_path) logging.error("error checkpoint does not contain controlnet or t2i adapter data {}".format(ckpt_path))
return net return net
if controlnet_config is None: if controlnet_config is None:
model_config = comfy.model_detection.model_config_from_unet(controlnet_data, prefix, True)
supported_inference_dtypes = model_config.supported_inference_dtypes
controlnet_config = model_config.unet_config
load_device = comfy.model_management.get_torch_device()
if supported_inference_dtypes is None:
unet_dtype = comfy.model_management.unet_dtype() unet_dtype = comfy.model_management.unet_dtype()
controlnet_config = comfy.model_detection.model_config_from_unet(controlnet_data, prefix, unet_dtype, True).unet_config else:
unet_dtype = comfy.model_management.unet_dtype(supported_dtypes=supported_inference_dtypes)
manual_cast_dtype = comfy.model_management.unet_manual_cast(unet_dtype, load_device)
if manual_cast_dtype is not None:
controlnet_config["operations"] = comfy.ops.manual_cast
controlnet_config["dtype"] = unet_dtype
controlnet_config.pop("out_channels") controlnet_config.pop("out_channels")
controlnet_config["hint_channels"] = controlnet_data["{}input_hint_block.0.weight".format(prefix)].shape[1] controlnet_config["hint_channels"] = controlnet_data["{}input_hint_block.0.weight".format(prefix)].shape[1]
control_model = comfy.cldm.cldm.ControlNet(**controlnet_config) control_model = comfy.cldm.cldm.ControlNet(**controlnet_config)
@ -384,7 +418,7 @@ def load_controlnet(ckpt_path, model=None):
cd = controlnet_data[x] cd = controlnet_data[x]
cd += model_sd[sd_key].type(cd.dtype).to(cd.device) cd += model_sd[sd_key].type(cd.dtype).to(cd.device)
else: else:
print("WARNING: Loaded a diff controlnet without a model. It will very likely not work.") logging.warning("WARNING: Loaded a diff controlnet without a model. It will very likely not work.")
class WeightsLoader(torch.nn.Module): class WeightsLoader(torch.nn.Module):
pass pass
@ -393,24 +427,29 @@ def load_controlnet(ckpt_path, model=None):
missing, unexpected = w.load_state_dict(controlnet_data, strict=False) missing, unexpected = w.load_state_dict(controlnet_data, strict=False)
else: else:
missing, unexpected = control_model.load_state_dict(controlnet_data, strict=False) missing, unexpected = control_model.load_state_dict(controlnet_data, strict=False)
print(missing, unexpected)
control_model = control_model.to(unet_dtype) if len(missing) > 0:
logging.warning("missing controlnet keys: {}".format(missing))
if len(unexpected) > 0:
logging.debug("unexpected controlnet keys: {}".format(unexpected))
global_average_pooling = False global_average_pooling = False
filename = os.path.splitext(ckpt_path)[0] filename = os.path.splitext(ckpt_path)[0]
if filename.endswith("_shuffle") or filename.endswith("_shuffle_fp16"): #TODO: smarter way of enabling global_average_pooling if filename.endswith("_shuffle") or filename.endswith("_shuffle_fp16"): #TODO: smarter way of enabling global_average_pooling
global_average_pooling = True global_average_pooling = True
control = ControlNet(control_model, global_average_pooling=global_average_pooling) control = ControlNet(control_model, global_average_pooling=global_average_pooling, load_device=load_device, manual_cast_dtype=manual_cast_dtype)
return control return control
class T2IAdapter(ControlBase): class T2IAdapter(ControlBase):
def __init__(self, t2i_model, channels_in, device=None): def __init__(self, t2i_model, channels_in, compression_ratio, upscale_algorithm, device=None):
super().__init__(device) super().__init__(device)
self.t2i_model = t2i_model self.t2i_model = t2i_model
self.channels_in = channels_in self.channels_in = channels_in
self.control_input = None self.control_input = None
self.compression_ratio = compression_ratio
self.upscale_algorithm = upscale_algorithm
def scale_image_to(self, width, height): def scale_image_to(self, width, height):
unshuffle_amount = self.t2i_model.unshuffle_amount unshuffle_amount = self.t2i_model.unshuffle_amount
@ -430,13 +469,13 @@ class T2IAdapter(ControlBase):
else: else:
return None return None
if self.cond_hint is None or x_noisy.shape[2] * 8 != self.cond_hint.shape[2] or x_noisy.shape[3] * 8 != self.cond_hint.shape[3]: if self.cond_hint is None or x_noisy.shape[2] * self.compression_ratio != self.cond_hint.shape[2] or x_noisy.shape[3] * self.compression_ratio != self.cond_hint.shape[3]:
if self.cond_hint is not None: if self.cond_hint is not None:
del self.cond_hint del self.cond_hint
self.control_input = None self.control_input = None
self.cond_hint = None self.cond_hint = None
width, height = self.scale_image_to(x_noisy.shape[3] * 8, x_noisy.shape[2] * 8) width, height = self.scale_image_to(x_noisy.shape[3] * self.compression_ratio, x_noisy.shape[2] * self.compression_ratio)
self.cond_hint = comfy.utils.common_upscale(self.cond_hint_original, width, height, 'nearest-exact', "center").float().to(self.device) self.cond_hint = comfy.utils.common_upscale(self.cond_hint_original, width, height, self.upscale_algorithm, "center").float().to(self.device)
if self.channels_in == 1 and self.cond_hint.shape[1] > 1: if self.channels_in == 1 and self.cond_hint.shape[1] > 1:
self.cond_hint = torch.mean(self.cond_hint, 1, keepdim=True) self.cond_hint = torch.mean(self.cond_hint, 1, keepdim=True)
if x_noisy.shape[0] != self.cond_hint.shape[0]: if x_noisy.shape[0] != self.cond_hint.shape[0]:
@ -455,11 +494,14 @@ class T2IAdapter(ControlBase):
return self.control_merge(control_input, mid, control_prev, x_noisy.dtype) return self.control_merge(control_input, mid, control_prev, x_noisy.dtype)
def copy(self): def copy(self):
c = T2IAdapter(self.t2i_model, self.channels_in) c = T2IAdapter(self.t2i_model, self.channels_in, self.compression_ratio, self.upscale_algorithm)
self.copy_to(c) self.copy_to(c)
return c return c
def load_t2i_adapter(t2i_data): def load_t2i_adapter(t2i_data):
compression_ratio = 8
upscale_algorithm = 'nearest-exact'
if 'adapter' in t2i_data: if 'adapter' in t2i_data:
t2i_data = t2i_data['adapter'] t2i_data = t2i_data['adapter']
if 'adapter.body.0.resnets.0.block1.weight' in t2i_data: #diffusers format if 'adapter.body.0.resnets.0.block1.weight' in t2i_data: #diffusers format
@ -487,13 +529,22 @@ def load_t2i_adapter(t2i_data):
if cin == 256 or cin == 768: if cin == 256 or cin == 768:
xl = True xl = True
model_ad = comfy.t2i_adapter.adapter.Adapter(cin=cin, channels=[channel, channel*2, channel*4, channel*4][:4], nums_rb=2, ksize=ksize, sk=True, use_conv=use_conv, xl=xl) model_ad = comfy.t2i_adapter.adapter.Adapter(cin=cin, channels=[channel, channel*2, channel*4, channel*4][:4], nums_rb=2, ksize=ksize, sk=True, use_conv=use_conv, xl=xl)
elif "backbone.0.0.weight" in keys:
model_ad = comfy.ldm.cascade.controlnet.ControlNet(c_in=t2i_data['backbone.0.0.weight'].shape[1], proj_blocks=[0, 4, 8, 12, 51, 55, 59, 63])
compression_ratio = 32
upscale_algorithm = 'bilinear'
elif "backbone.10.blocks.0.weight" in keys:
model_ad = comfy.ldm.cascade.controlnet.ControlNet(c_in=t2i_data['backbone.0.weight'].shape[1], bottleneck_mode="large", proj_blocks=[0, 4, 8, 12, 51, 55, 59, 63])
compression_ratio = 1
upscale_algorithm = 'nearest-exact'
else: else:
return None return None
missing, unexpected = model_ad.load_state_dict(t2i_data) missing, unexpected = model_ad.load_state_dict(t2i_data)
if len(missing) > 0: if len(missing) > 0:
print("t2i missing", missing) logging.warning("t2i missing {}".format(missing))
if len(unexpected) > 0: if len(unexpected) > 0:
print("t2i unexpected", unexpected) logging.debug("t2i unexpected {}".format(unexpected))
return T2IAdapter(model_ad, model_ad.input_channels) return T2IAdapter(model_ad, model_ad.input_channels, compression_ratio, upscale_algorithm)

11
comfy/diffusers_convert.py

@ -1,5 +1,6 @@
import re import re
import torch import torch
import logging
# conversion code from https://github.com/huggingface/diffusers/blob/main/scripts/convert_diffusers_to_original_stable_diffusion.py # conversion code from https://github.com/huggingface/diffusers/blob/main/scripts/convert_diffusers_to_original_stable_diffusion.py
@ -177,7 +178,7 @@ def convert_vae_state_dict(vae_state_dict):
for k, v in new_state_dict.items(): for k, v in new_state_dict.items():
for weight_name in weights_to_convert: for weight_name in weights_to_convert:
if f"mid.attn_1.{weight_name}.weight" in k: if f"mid.attn_1.{weight_name}.weight" in k:
print(f"Reshaping {k} for SD format") logging.debug(f"Reshaping {k} for SD format")
new_state_dict[k] = reshape_weight_for_sd(v) new_state_dict[k] = reshape_weight_for_sd(v)
return new_state_dict return new_state_dict
@ -237,8 +238,12 @@ def convert_text_enc_state_dict_v20(text_enc_dict, prefix=""):
capture_qkv_bias[k_pre][code2idx[k_code]] = v capture_qkv_bias[k_pre][code2idx[k_code]] = v
continue continue
relabelled_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], k) text_proj = "transformer.text_projection.weight"
new_state_dict[relabelled_key] = v if k.endswith(text_proj):
new_state_dict[k.replace(text_proj, "text_projection")] = v.transpose(0, 1).contiguous()
else:
relabelled_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], k)
new_state_dict[relabelled_key] = v
for k_pre, tensors in capture_qkv_weight.items(): for k_pre, tensors in capture_qkv_weight.items():
if None in tensors: if None in tensors:

1
comfy/diffusers_load.py

@ -1,4 +1,3 @@
import json
import os import os
import comfy.sd import comfy.sd

37
comfy/extra_samplers/uni_pc.py

@ -358,9 +358,6 @@ class UniPC:
thresholding=False, thresholding=False,
max_val=1., max_val=1.,
variant='bh1', variant='bh1',
noise_mask=None,
masked_image=None,
noise=None,
): ):
"""Construct a UniPC. """Construct a UniPC.
@ -372,9 +369,6 @@ class UniPC:
self.predict_x0 = predict_x0 self.predict_x0 = predict_x0
self.thresholding = thresholding self.thresholding = thresholding
self.max_val = max_val self.max_val = max_val
self.noise_mask = noise_mask
self.masked_image = masked_image
self.noise = noise
def dynamic_thresholding_fn(self, x0, t=None): def dynamic_thresholding_fn(self, x0, t=None):
""" """
@ -391,10 +385,7 @@ class UniPC:
""" """
Return the noise prediction model. Return the noise prediction model.
""" """
if self.noise_mask is not None: return self.model(x, t)
return self.model(x, t) * self.noise_mask
else:
return self.model(x, t)
def data_prediction_fn(self, x, t): def data_prediction_fn(self, x, t):
""" """
@ -409,8 +400,6 @@ class UniPC:
s = torch.quantile(torch.abs(x0).reshape((x0.shape[0], -1)), p, dim=1) s = torch.quantile(torch.abs(x0).reshape((x0.shape[0], -1)), p, dim=1)
s = expand_dims(torch.maximum(s, self.max_val * torch.ones_like(s).to(s.device)), dims) s = expand_dims(torch.maximum(s, self.max_val * torch.ones_like(s).to(s.device)), dims)
x0 = torch.clamp(x0, -s, s) / s x0 = torch.clamp(x0, -s, s) / s
if self.noise_mask is not None:
x0 = x0 * self.noise_mask + (1. - self.noise_mask) * self.masked_image
return x0 return x0
def model_fn(self, x, t): def model_fn(self, x, t):
@ -723,8 +712,6 @@ class UniPC:
assert timesteps.shape[0] - 1 == steps assert timesteps.shape[0] - 1 == steps
# with torch.no_grad(): # with torch.no_grad():
for step_index in trange(steps, disable=disable_pbar): for step_index in trange(steps, disable=disable_pbar):
if self.noise_mask is not None:
x = x * self.noise_mask + (1. - self.noise_mask) * (self.masked_image * self.noise_schedule.marginal_alpha(timesteps[step_index]) + self.noise * self.noise_schedule.marginal_std(timesteps[step_index]))
if step_index == 0: if step_index == 0:
vec_t = timesteps[0].expand((x.shape[0])) vec_t = timesteps[0].expand((x.shape[0]))
model_prev_list = [self.model_fn(x, vec_t)] model_prev_list = [self.model_fn(x, vec_t)]
@ -766,7 +753,7 @@ class UniPC:
model_x = self.model_fn(x, vec_t) model_x = self.model_fn(x, vec_t)
model_prev_list[-1] = model_x model_prev_list[-1] = model_x
if callback is not None: if callback is not None:
callback(step_index, model_prev_list[-1], x, steps) callback({'x': x, 'i': step_index, 'denoised': model_prev_list[-1]})
else: else:
raise NotImplementedError() raise NotImplementedError()
# if denoise_to_zero: # if denoise_to_zero:
@ -858,7 +845,7 @@ def predict_eps_sigma(model, input, sigma_in, **kwargs):
return (input - model(input, sigma_in, **kwargs)) / sigma return (input - model(input, sigma_in, **kwargs)) / sigma
def sample_unipc(model, noise, image, sigmas, max_denoise, extra_args=None, callback=None, disable=False, noise_mask=None, variant='bh1'): def sample_unipc(model, noise, sigmas, extra_args=None, callback=None, disable=False, variant='bh1'):
timesteps = sigmas.clone() timesteps = sigmas.clone()
if sigmas[-1] == 0: if sigmas[-1] == 0:
timesteps = sigmas[:] timesteps = sigmas[:]
@ -867,16 +854,7 @@ def sample_unipc(model, noise, image, sigmas, max_denoise, extra_args=None, call
timesteps = sigmas.clone() timesteps = sigmas.clone()
ns = SigmaConvert() ns = SigmaConvert()
if image is not None: noise = noise / torch.sqrt(1.0 + timesteps[0] ** 2.0)
img = image * ns.marginal_alpha(timesteps[0])
if max_denoise:
noise_mult = 1.0
else:
noise_mult = ns.marginal_std(timesteps[0])
img += noise * noise_mult
else:
img = noise
model_type = "noise" model_type = "noise"
model_fn = model_wrapper( model_fn = model_wrapper(
@ -888,7 +866,10 @@ def sample_unipc(model, noise, image, sigmas, max_denoise, extra_args=None, call
) )
order = min(3, len(timesteps) - 2) order = min(3, len(timesteps) - 2)
uni_pc = UniPC(model_fn, ns, predict_x0=True, thresholding=False, noise_mask=noise_mask, masked_image=image, noise=noise, variant=variant) uni_pc = UniPC(model_fn, ns, predict_x0=True, thresholding=False, variant=variant)
x = uni_pc.sample(img, timesteps=timesteps, skip_type="time_uniform", method="multistep", order=order, lower_order_final=True, callback=callback, disable_pbar=disable) x = uni_pc.sample(noise, timesteps=timesteps, skip_type="time_uniform", method="multistep", order=order, lower_order_final=True, callback=callback, disable_pbar=disable)
x /= ns.marginal_alpha(timesteps[-1]) x /= ns.marginal_alpha(timesteps[-1])
return x return x
def sample_unipc_bh2(model, noise, sigmas, extra_args=None, callback=None, disable=False):
return sample_unipc(model, noise, sigmas, extra_args, callback, disable, variant='bh2')

54
comfy/gligen.py

@ -1,8 +1,9 @@
import torch import torch
from torch import nn, einsum from torch import nn
from .ldm.modules.attention import CrossAttention from .ldm.modules.attention import CrossAttention
from inspect import isfunction from inspect import isfunction
import comfy.ops
ops = comfy.ops.manual_cast
def exists(val): def exists(val):
return val is not None return val is not None
@ -22,7 +23,7 @@ def default(val, d):
class GEGLU(nn.Module): class GEGLU(nn.Module):
def __init__(self, dim_in, dim_out): def __init__(self, dim_in, dim_out):
super().__init__() super().__init__()
self.proj = nn.Linear(dim_in, dim_out * 2) self.proj = ops.Linear(dim_in, dim_out * 2)
def forward(self, x): def forward(self, x):
x, gate = self.proj(x).chunk(2, dim=-1) x, gate = self.proj(x).chunk(2, dim=-1)
@ -35,14 +36,14 @@ class FeedForward(nn.Module):
inner_dim = int(dim * mult) inner_dim = int(dim * mult)
dim_out = default(dim_out, dim) dim_out = default(dim_out, dim)
project_in = nn.Sequential( project_in = nn.Sequential(
nn.Linear(dim, inner_dim), ops.Linear(dim, inner_dim),
nn.GELU() nn.GELU()
) if not glu else GEGLU(dim, inner_dim) ) if not glu else GEGLU(dim, inner_dim)
self.net = nn.Sequential( self.net = nn.Sequential(
project_in, project_in,
nn.Dropout(dropout), nn.Dropout(dropout),
nn.Linear(inner_dim, dim_out) ops.Linear(inner_dim, dim_out)
) )
def forward(self, x): def forward(self, x):
@ -57,11 +58,12 @@ class GatedCrossAttentionDense(nn.Module):
query_dim=query_dim, query_dim=query_dim,
context_dim=context_dim, context_dim=context_dim,
heads=n_heads, heads=n_heads,
dim_head=d_head) dim_head=d_head,
operations=ops)
self.ff = FeedForward(query_dim, glu=True) self.ff = FeedForward(query_dim, glu=True)
self.norm1 = nn.LayerNorm(query_dim) self.norm1 = ops.LayerNorm(query_dim)
self.norm2 = nn.LayerNorm(query_dim) self.norm2 = ops.LayerNorm(query_dim)
self.register_parameter('alpha_attn', nn.Parameter(torch.tensor(0.))) self.register_parameter('alpha_attn', nn.Parameter(torch.tensor(0.)))
self.register_parameter('alpha_dense', nn.Parameter(torch.tensor(0.))) self.register_parameter('alpha_dense', nn.Parameter(torch.tensor(0.)))
@ -87,17 +89,18 @@ class GatedSelfAttentionDense(nn.Module):
# we need a linear projection since we need cat visual feature and obj # we need a linear projection since we need cat visual feature and obj
# feature # feature
self.linear = nn.Linear(context_dim, query_dim) self.linear = ops.Linear(context_dim, query_dim)
self.attn = CrossAttention( self.attn = CrossAttention(
query_dim=query_dim, query_dim=query_dim,
context_dim=query_dim, context_dim=query_dim,
heads=n_heads, heads=n_heads,
dim_head=d_head) dim_head=d_head,
operations=ops)
self.ff = FeedForward(query_dim, glu=True) self.ff = FeedForward(query_dim, glu=True)
self.norm1 = nn.LayerNorm(query_dim) self.norm1 = ops.LayerNorm(query_dim)
self.norm2 = nn.LayerNorm(query_dim) self.norm2 = ops.LayerNorm(query_dim)
self.register_parameter('alpha_attn', nn.Parameter(torch.tensor(0.))) self.register_parameter('alpha_attn', nn.Parameter(torch.tensor(0.)))
self.register_parameter('alpha_dense', nn.Parameter(torch.tensor(0.))) self.register_parameter('alpha_dense', nn.Parameter(torch.tensor(0.)))
@ -126,14 +129,14 @@ class GatedSelfAttentionDense2(nn.Module):
# we need a linear projection since we need cat visual feature and obj # we need a linear projection since we need cat visual feature and obj
# feature # feature
self.linear = nn.Linear(context_dim, query_dim) self.linear = ops.Linear(context_dim, query_dim)
self.attn = CrossAttention( self.attn = CrossAttention(
query_dim=query_dim, context_dim=query_dim, dim_head=d_head) query_dim=query_dim, context_dim=query_dim, dim_head=d_head, operations=ops)
self.ff = FeedForward(query_dim, glu=True) self.ff = FeedForward(query_dim, glu=True)
self.norm1 = nn.LayerNorm(query_dim) self.norm1 = ops.LayerNorm(query_dim)
self.norm2 = nn.LayerNorm(query_dim) self.norm2 = ops.LayerNorm(query_dim)
self.register_parameter('alpha_attn', nn.Parameter(torch.tensor(0.))) self.register_parameter('alpha_attn', nn.Parameter(torch.tensor(0.)))
self.register_parameter('alpha_dense', nn.Parameter(torch.tensor(0.))) self.register_parameter('alpha_dense', nn.Parameter(torch.tensor(0.)))
@ -201,11 +204,11 @@ class PositionNet(nn.Module):
self.position_dim = fourier_freqs * 2 * 4 # 2 is sin&cos, 4 is xyxy self.position_dim = fourier_freqs * 2 * 4 # 2 is sin&cos, 4 is xyxy
self.linears = nn.Sequential( self.linears = nn.Sequential(
nn.Linear(self.in_dim + self.position_dim, 512), ops.Linear(self.in_dim + self.position_dim, 512),
nn.SiLU(), nn.SiLU(),
nn.Linear(512, 512), ops.Linear(512, 512),
nn.SiLU(), nn.SiLU(),
nn.Linear(512, out_dim), ops.Linear(512, out_dim),
) )
self.null_positive_feature = torch.nn.Parameter( self.null_positive_feature = torch.nn.Parameter(
@ -215,16 +218,15 @@ class PositionNet(nn.Module):
def forward(self, boxes, masks, positive_embeddings): def forward(self, boxes, masks, positive_embeddings):
B, N, _ = boxes.shape B, N, _ = boxes.shape
dtype = self.linears[0].weight.dtype masks = masks.unsqueeze(-1)
masks = masks.unsqueeze(-1).to(dtype) positive_embeddings = positive_embeddings
positive_embeddings = positive_embeddings.to(dtype)
# embedding position (it may includes padding as placeholder) # embedding position (it may includes padding as placeholder)
xyxy_embedding = self.fourier_embedder(boxes.to(dtype)) # B*N*4 --> B*N*C xyxy_embedding = self.fourier_embedder(boxes) # B*N*4 --> B*N*C
# learnable null embedding # learnable null embedding
positive_null = self.null_positive_feature.view(1, 1, -1) positive_null = self.null_positive_feature.to(device=boxes.device, dtype=boxes.dtype).view(1, 1, -1)
xyxy_null = self.null_position_feature.view(1, 1, -1) xyxy_null = self.null_position_feature.to(device=boxes.device, dtype=boxes.dtype).view(1, 1, -1)
# replace padding with learnable null embedding # replace padding with learnable null embedding
positive_embeddings = positive_embeddings * \ positive_embeddings = positive_embeddings * \
@ -251,7 +253,7 @@ class Gligen(nn.Module):
def func(x, extra_options): def func(x, extra_options):
key = extra_options["transformer_index"] key = extra_options["transformer_index"]
module = self.module_list[key] module = self.module_list[key]
return module(x, objs) return module(x, objs.to(device=x.device, dtype=x.dtype))
return func return func
def set_position(self, latent_image_shape, position_params, device): def set_position(self, latent_image_shape, position_params, device):

2
comfy/k_diffusion/sampling.py

@ -748,7 +748,7 @@ def sample_lcm(model, x, sigmas, extra_args=None, callback=None, disable=None, n
x = denoised x = denoised
if sigmas[i + 1] > 0: if sigmas[i + 1] > 0:
x += sigmas[i + 1] * noise_sampler(sigmas[i], sigmas[i + 1]) x = model.inner_model.inner_model.model_sampling.noise_scaling(sigmas[i + 1], noise_sampler(sigmas[i], sigmas[i + 1]), x)
return x return x

69
comfy/latent_formats.py

@ -1,3 +1,4 @@
import torch
class LatentFormat: class LatentFormat:
scale_factor = 1.0 scale_factor = 1.0
@ -33,3 +34,71 @@ class SDXL(LatentFormat):
[-0.3112, -0.2359, -0.2076] [-0.3112, -0.2359, -0.2076]
] ]
self.taesd_decoder_name = "taesdxl_decoder" self.taesd_decoder_name = "taesdxl_decoder"
class SDXL_Playground_2_5(LatentFormat):
def __init__(self):
self.scale_factor = 0.5
self.latents_mean = torch.tensor([-1.6574, 1.886, -1.383, 2.5155]).view(1, 4, 1, 1)
self.latents_std = torch.tensor([8.4927, 5.9022, 6.5498, 5.2299]).view(1, 4, 1, 1)
self.latent_rgb_factors = [
# R G B
[ 0.3920, 0.4054, 0.4549],
[-0.2634, -0.0196, 0.0653],
[ 0.0568, 0.1687, -0.0755],
[-0.3112, -0.2359, -0.2076]
]
self.taesd_decoder_name = "taesdxl_decoder"
def process_in(self, latent):
latents_mean = self.latents_mean.to(latent.device, latent.dtype)
latents_std = self.latents_std.to(latent.device, latent.dtype)
return (latent - latents_mean) * self.scale_factor / latents_std
def process_out(self, latent):
latents_mean = self.latents_mean.to(latent.device, latent.dtype)
latents_std = self.latents_std.to(latent.device, latent.dtype)
return latent * latents_std / self.scale_factor + latents_mean
class SD_X4(LatentFormat):
def __init__(self):
self.scale_factor = 0.08333
self.latent_rgb_factors = [
[-0.2340, -0.3863, -0.3257],
[ 0.0994, 0.0885, -0.0908],
[-0.2833, -0.2349, -0.3741],
[ 0.2523, -0.0055, -0.1651]
]
class SC_Prior(LatentFormat):
def __init__(self):
self.scale_factor = 1.0
self.latent_rgb_factors = [
[-0.0326, -0.0204, -0.0127],
[-0.1592, -0.0427, 0.0216],
[ 0.0873, 0.0638, -0.0020],
[-0.0602, 0.0442, 0.1304],
[ 0.0800, -0.0313, -0.1796],
[-0.0810, -0.0638, -0.1581],
[ 0.1791, 0.1180, 0.0967],
[ 0.0740, 0.1416, 0.0432],
[-0.1745, -0.1888, -0.1373],
[ 0.2412, 0.1577, 0.0928],
[ 0.1908, 0.0998, 0.0682],
[ 0.0209, 0.0365, -0.0092],
[ 0.0448, -0.0650, -0.1728],
[-0.1658, -0.1045, -0.1308],
[ 0.0542, 0.1545, 0.1325],
[-0.0352, -0.1672, -0.2541]
]
class SC_B(LatentFormat):
def __init__(self):
self.scale_factor = 1.0 / 0.43
self.latent_rgb_factors = [
[ 0.1121, 0.2006, 0.1023],
[-0.2093, -0.0222, -0.0195],
[-0.3087, -0.1535, 0.0366],
[ 0.0290, -0.1574, -0.4078]
]

161
comfy/ldm/cascade/common.py

@ -0,0 +1,161 @@
"""
This file is part of ComfyUI.
Copyright (C) 2024 Stability AI
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU General Public License for more details.
You should have received a copy of the GNU General Public License
along with this program. If not, see <https://www.gnu.org/licenses/>.
"""
import torch
import torch.nn as nn
from comfy.ldm.modules.attention import optimized_attention
class Linear(torch.nn.Linear):
def reset_parameters(self):
return None
class Conv2d(torch.nn.Conv2d):
def reset_parameters(self):
return None
class OptimizedAttention(nn.Module):
def __init__(self, c, nhead, dropout=0.0, dtype=None, device=None, operations=None):
super().__init__()
self.heads = nhead
self.to_q = operations.Linear(c, c, bias=True, dtype=dtype, device=device)
self.to_k = operations.Linear(c, c, bias=True, dtype=dtype, device=device)
self.to_v = operations.Linear(c, c, bias=True, dtype=dtype, device=device)
self.out_proj = operations.Linear(c, c, bias=True, dtype=dtype, device=device)
def forward(self, q, k, v):
q = self.to_q(q)
k = self.to_k(k)
v = self.to_v(v)
out = optimized_attention(q, k, v, self.heads)
return self.out_proj(out)
class Attention2D(nn.Module):
def __init__(self, c, nhead, dropout=0.0, dtype=None, device=None, operations=None):
super().__init__()
self.attn = OptimizedAttention(c, nhead, dtype=dtype, device=device, operations=operations)
# self.attn = nn.MultiheadAttention(c, nhead, dropout=dropout, bias=True, batch_first=True, dtype=dtype, device=device)
def forward(self, x, kv, self_attn=False):
orig_shape = x.shape
x = x.view(x.size(0), x.size(1), -1).permute(0, 2, 1) # Bx4xHxW -> Bx(HxW)x4
if self_attn:
kv = torch.cat([x, kv], dim=1)
# x = self.attn(x, kv, kv, need_weights=False)[0]
x = self.attn(x, kv, kv)
x = x.permute(0, 2, 1).view(*orig_shape)
return x
def LayerNorm2d_op(operations):
class LayerNorm2d(operations.LayerNorm):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
def forward(self, x):
return super().forward(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
return LayerNorm2d
class GlobalResponseNorm(nn.Module):
"from https://github.com/facebookresearch/ConvNeXt-V2/blob/3608f67cc1dae164790c5d0aead7bf2d73d9719b/models/utils.py#L105"
def __init__(self, dim, dtype=None, device=None):
super().__init__()
self.gamma = nn.Parameter(torch.zeros(1, 1, 1, dim, dtype=dtype, device=device))
self.beta = nn.Parameter(torch.zeros(1, 1, 1, dim, dtype=dtype, device=device))
def forward(self, x):
Gx = torch.norm(x, p=2, dim=(1, 2), keepdim=True)
Nx = Gx / (Gx.mean(dim=-1, keepdim=True) + 1e-6)
return self.gamma.to(device=x.device, dtype=x.dtype) * (x * Nx) + self.beta.to(device=x.device, dtype=x.dtype) + x
class ResBlock(nn.Module):
def __init__(self, c, c_skip=0, kernel_size=3, dropout=0.0, dtype=None, device=None, operations=None): # , num_heads=4, expansion=2):
super().__init__()
self.depthwise = operations.Conv2d(c, c, kernel_size=kernel_size, padding=kernel_size // 2, groups=c, dtype=dtype, device=device)
# self.depthwise = SAMBlock(c, num_heads, expansion)
self.norm = LayerNorm2d_op(operations)(c, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
self.channelwise = nn.Sequential(
operations.Linear(c + c_skip, c * 4, dtype=dtype, device=device),
nn.GELU(),
GlobalResponseNorm(c * 4, dtype=dtype, device=device),
nn.Dropout(dropout),
operations.Linear(c * 4, c, dtype=dtype, device=device)
)
def forward(self, x, x_skip=None):
x_res = x
x = self.norm(self.depthwise(x))
if x_skip is not None:
x = torch.cat([x, x_skip], dim=1)
x = self.channelwise(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
return x + x_res
class AttnBlock(nn.Module):
def __init__(self, c, c_cond, nhead, self_attn=True, dropout=0.0, dtype=None, device=None, operations=None):
super().__init__()
self.self_attn = self_attn
self.norm = LayerNorm2d_op(operations)(c, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
self.attention = Attention2D(c, nhead, dropout, dtype=dtype, device=device, operations=operations)
self.kv_mapper = nn.Sequential(
nn.SiLU(),
operations.Linear(c_cond, c, dtype=dtype, device=device)
)
def forward(self, x, kv):
kv = self.kv_mapper(kv)
x = x + self.attention(self.norm(x), kv, self_attn=self.self_attn)
return x
class FeedForwardBlock(nn.Module):
def __init__(self, c, dropout=0.0, dtype=None, device=None, operations=None):
super().__init__()
self.norm = LayerNorm2d_op(operations)(c, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
self.channelwise = nn.Sequential(
operations.Linear(c, c * 4, dtype=dtype, device=device),
nn.GELU(),
GlobalResponseNorm(c * 4, dtype=dtype, device=device),
nn.Dropout(dropout),
operations.Linear(c * 4, c, dtype=dtype, device=device)
)
def forward(self, x):
x = x + self.channelwise(self.norm(x).permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
return x
class TimestepBlock(nn.Module):
def __init__(self, c, c_timestep, conds=['sca'], dtype=None, device=None, operations=None):
super().__init__()
self.mapper = operations.Linear(c_timestep, c * 2, dtype=dtype, device=device)
self.conds = conds
for cname in conds:
setattr(self, f"mapper_{cname}", operations.Linear(c_timestep, c * 2, dtype=dtype, device=device))
def forward(self, x, t):
t = t.chunk(len(self.conds) + 1, dim=1)
a, b = self.mapper(t[0])[:, :, None, None].chunk(2, dim=1)
for i, c in enumerate(self.conds):
ac, bc = getattr(self, f"mapper_{c}")(t[i + 1])[:, :, None, None].chunk(2, dim=1)
a, b = a + ac, b + bc
return x * (1 + a) + b

93
comfy/ldm/cascade/controlnet.py

@ -0,0 +1,93 @@
"""
This file is part of ComfyUI.
Copyright (C) 2024 Stability AI
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU General Public License for more details.
You should have received a copy of the GNU General Public License
along with this program. If not, see <https://www.gnu.org/licenses/>.
"""
import torch
import torchvision
from torch import nn
from .common import LayerNorm2d_op
class CNetResBlock(nn.Module):
def __init__(self, c, dtype=None, device=None, operations=None):
super().__init__()
self.blocks = nn.Sequential(
LayerNorm2d_op(operations)(c, dtype=dtype, device=device),
nn.GELU(),
operations.Conv2d(c, c, kernel_size=3, padding=1),
LayerNorm2d_op(operations)(c, dtype=dtype, device=device),
nn.GELU(),
operations.Conv2d(c, c, kernel_size=3, padding=1),
)
def forward(self, x):
return x + self.blocks(x)
class ControlNet(nn.Module):
def __init__(self, c_in=3, c_proj=2048, proj_blocks=None, bottleneck_mode=None, dtype=None, device=None, operations=nn):
super().__init__()
if bottleneck_mode is None:
bottleneck_mode = 'effnet'
self.proj_blocks = proj_blocks
if bottleneck_mode == 'effnet':
embd_channels = 1280
self.backbone = torchvision.models.efficientnet_v2_s().features.eval()
if c_in != 3:
in_weights = self.backbone[0][0].weight.data
self.backbone[0][0] = operations.Conv2d(c_in, 24, kernel_size=3, stride=2, bias=False, dtype=dtype, device=device)
if c_in > 3:
# nn.init.constant_(self.backbone[0][0].weight, 0)
self.backbone[0][0].weight.data[:, :3] = in_weights[:, :3].clone()
else:
self.backbone[0][0].weight.data = in_weights[:, :c_in].clone()
elif bottleneck_mode == 'simple':
embd_channels = c_in
self.backbone = nn.Sequential(
operations.Conv2d(embd_channels, embd_channels * 4, kernel_size=3, padding=1, dtype=dtype, device=device),
nn.LeakyReLU(0.2, inplace=True),
operations.Conv2d(embd_channels * 4, embd_channels, kernel_size=3, padding=1, dtype=dtype, device=device),
)
elif bottleneck_mode == 'large':
self.backbone = nn.Sequential(
operations.Conv2d(c_in, 4096 * 4, kernel_size=1, dtype=dtype, device=device),
nn.LeakyReLU(0.2, inplace=True),
operations.Conv2d(4096 * 4, 1024, kernel_size=1, dtype=dtype, device=device),
*[CNetResBlock(1024, dtype=dtype, device=device, operations=operations) for _ in range(8)],
operations.Conv2d(1024, 1280, kernel_size=1, dtype=dtype, device=device),
)
embd_channels = 1280
else:
raise ValueError(f'Unknown bottleneck mode: {bottleneck_mode}')
self.projections = nn.ModuleList()
for _ in range(len(proj_blocks)):
self.projections.append(nn.Sequential(
operations.Conv2d(embd_channels, embd_channels, kernel_size=1, bias=False, dtype=dtype, device=device),
nn.LeakyReLU(0.2, inplace=True),
operations.Conv2d(embd_channels, c_proj, kernel_size=1, bias=False, dtype=dtype, device=device),
))
# nn.init.constant_(self.projections[-1][-1].weight, 0) # zero output projection
self.xl = False
self.input_channels = c_in
self.unshuffle_amount = 8
def forward(self, x):
x = self.backbone(x)
proj_outputs = [None for _ in range(max(self.proj_blocks) + 1)]
for i, idx in enumerate(self.proj_blocks):
proj_outputs[idx] = self.projections[i](x)
return proj_outputs

255
comfy/ldm/cascade/stage_a.py

@ -0,0 +1,255 @@
"""
This file is part of ComfyUI.
Copyright (C) 2024 Stability AI
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU General Public License for more details.
You should have received a copy of the GNU General Public License
along with this program. If not, see <https://www.gnu.org/licenses/>.
"""
import torch
from torch import nn
from torch.autograd import Function
class vector_quantize(Function):
@staticmethod
def forward(ctx, x, codebook):
with torch.no_grad():
codebook_sqr = torch.sum(codebook ** 2, dim=1)
x_sqr = torch.sum(x ** 2, dim=1, keepdim=True)
dist = torch.addmm(codebook_sqr + x_sqr, x, codebook.t(), alpha=-2.0, beta=1.0)
_, indices = dist.min(dim=1)
ctx.save_for_backward(indices, codebook)
ctx.mark_non_differentiable(indices)
nn = torch.index_select(codebook, 0, indices)
return nn, indices
@staticmethod
def backward(ctx, grad_output, grad_indices):
grad_inputs, grad_codebook = None, None
if ctx.needs_input_grad[0]:
grad_inputs = grad_output.clone()
if ctx.needs_input_grad[1]:
# Gradient wrt. the codebook
indices, codebook = ctx.saved_tensors
grad_codebook = torch.zeros_like(codebook)
grad_codebook.index_add_(0, indices, grad_output)
return (grad_inputs, grad_codebook)
class VectorQuantize(nn.Module):
def __init__(self, embedding_size, k, ema_decay=0.99, ema_loss=False):
"""
Takes an input of variable size (as long as the last dimension matches the embedding size).
Returns one tensor containing the nearest neigbour embeddings to each of the inputs,
with the same size as the input, vq and commitment components for the loss as a touple
in the second output and the indices of the quantized vectors in the third:
quantized, (vq_loss, commit_loss), indices
"""
super(VectorQuantize, self).__init__()
self.codebook = nn.Embedding(k, embedding_size)
self.codebook.weight.data.uniform_(-1./k, 1./k)
self.vq = vector_quantize.apply
self.ema_decay = ema_decay
self.ema_loss = ema_loss
if ema_loss:
self.register_buffer('ema_element_count', torch.ones(k))
self.register_buffer('ema_weight_sum', torch.zeros_like(self.codebook.weight))
def _laplace_smoothing(self, x, epsilon):
n = torch.sum(x)
return ((x + epsilon) / (n + x.size(0) * epsilon) * n)
def _updateEMA(self, z_e_x, indices):
mask = nn.functional.one_hot(indices, self.ema_element_count.size(0)).float()
elem_count = mask.sum(dim=0)
weight_sum = torch.mm(mask.t(), z_e_x)
self.ema_element_count = (self.ema_decay * self.ema_element_count) + ((1-self.ema_decay) * elem_count)
self.ema_element_count = self._laplace_smoothing(self.ema_element_count, 1e-5)
self.ema_weight_sum = (self.ema_decay * self.ema_weight_sum) + ((1-self.ema_decay) * weight_sum)
self.codebook.weight.data = self.ema_weight_sum / self.ema_element_count.unsqueeze(-1)
def idx2vq(self, idx, dim=-1):
q_idx = self.codebook(idx)
if dim != -1:
q_idx = q_idx.movedim(-1, dim)
return q_idx
def forward(self, x, get_losses=True, dim=-1):
if dim != -1:
x = x.movedim(dim, -1)
z_e_x = x.contiguous().view(-1, x.size(-1)) if len(x.shape) > 2 else x
z_q_x, indices = self.vq(z_e_x, self.codebook.weight.detach())
vq_loss, commit_loss = None, None
if self.ema_loss and self.training:
self._updateEMA(z_e_x.detach(), indices.detach())
# pick the graded embeddings after updating the codebook in order to have a more accurate commitment loss
z_q_x_grd = torch.index_select(self.codebook.weight, dim=0, index=indices)
if get_losses:
vq_loss = (z_q_x_grd - z_e_x.detach()).pow(2).mean()
commit_loss = (z_e_x - z_q_x_grd.detach()).pow(2).mean()
z_q_x = z_q_x.view(x.shape)
if dim != -1:
z_q_x = z_q_x.movedim(-1, dim)
return z_q_x, (vq_loss, commit_loss), indices.view(x.shape[:-1])
class ResBlock(nn.Module):
def __init__(self, c, c_hidden):
super().__init__()
# depthwise/attention
self.norm1 = nn.LayerNorm(c, elementwise_affine=False, eps=1e-6)
self.depthwise = nn.Sequential(
nn.ReplicationPad2d(1),
nn.Conv2d(c, c, kernel_size=3, groups=c)
)
# channelwise
self.norm2 = nn.LayerNorm(c, elementwise_affine=False, eps=1e-6)
self.channelwise = nn.Sequential(
nn.Linear(c, c_hidden),
nn.GELU(),
nn.Linear(c_hidden, c),
)
self.gammas = nn.Parameter(torch.zeros(6), requires_grad=True)
# Init weights
def _basic_init(module):
if isinstance(module, nn.Linear) or isinstance(module, nn.Conv2d):
torch.nn.init.xavier_uniform_(module.weight)
if module.bias is not None:
nn.init.constant_(module.bias, 0)
self.apply(_basic_init)
def _norm(self, x, norm):
return norm(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
def forward(self, x):
mods = self.gammas
x_temp = self._norm(x, self.norm1) * (1 + mods[0]) + mods[1]
try:
x = x + self.depthwise(x_temp) * mods[2]
except: #operation not implemented for bf16
x_temp = self.depthwise[0](x_temp.float()).to(x.dtype)
x = x + self.depthwise[1](x_temp) * mods[2]
x_temp = self._norm(x, self.norm2) * (1 + mods[3]) + mods[4]
x = x + self.channelwise(x_temp.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) * mods[5]
return x
class StageA(nn.Module):
def __init__(self, levels=2, bottleneck_blocks=12, c_hidden=384, c_latent=4, codebook_size=8192):
super().__init__()
self.c_latent = c_latent
c_levels = [c_hidden // (2 ** i) for i in reversed(range(levels))]
# Encoder blocks
self.in_block = nn.Sequential(
nn.PixelUnshuffle(2),
nn.Conv2d(3 * 4, c_levels[0], kernel_size=1)
)
down_blocks = []
for i in range(levels):
if i > 0:
down_blocks.append(nn.Conv2d(c_levels[i - 1], c_levels[i], kernel_size=4, stride=2, padding=1))
block = ResBlock(c_levels[i], c_levels[i] * 4)
down_blocks.append(block)
down_blocks.append(nn.Sequential(
nn.Conv2d(c_levels[-1], c_latent, kernel_size=1, bias=False),
nn.BatchNorm2d(c_latent), # then normalize them to have mean 0 and std 1
))
self.down_blocks = nn.Sequential(*down_blocks)
self.down_blocks[0]
self.codebook_size = codebook_size
self.vquantizer = VectorQuantize(c_latent, k=codebook_size)
# Decoder blocks
up_blocks = [nn.Sequential(
nn.Conv2d(c_latent, c_levels[-1], kernel_size=1)
)]
for i in range(levels):
for j in range(bottleneck_blocks if i == 0 else 1):
block = ResBlock(c_levels[levels - 1 - i], c_levels[levels - 1 - i] * 4)
up_blocks.append(block)
if i < levels - 1:
up_blocks.append(
nn.ConvTranspose2d(c_levels[levels - 1 - i], c_levels[levels - 2 - i], kernel_size=4, stride=2,
padding=1))
self.up_blocks = nn.Sequential(*up_blocks)
self.out_block = nn.Sequential(
nn.Conv2d(c_levels[0], 3 * 4, kernel_size=1),
nn.PixelShuffle(2),
)
def encode(self, x, quantize=False):
x = self.in_block(x)
x = self.down_blocks(x)
if quantize:
qe, (vq_loss, commit_loss), indices = self.vquantizer.forward(x, dim=1)
return qe, x, indices, vq_loss + commit_loss * 0.25
else:
return x
def decode(self, x):
x = self.up_blocks(x)
x = self.out_block(x)
return x
def forward(self, x, quantize=False):
qe, x, _, vq_loss = self.encode(x, quantize)
x = self.decode(qe)
return x, vq_loss
class Discriminator(nn.Module):
def __init__(self, c_in=3, c_cond=0, c_hidden=512, depth=6):
super().__init__()
d = max(depth - 3, 3)
layers = [
nn.utils.spectral_norm(nn.Conv2d(c_in, c_hidden // (2 ** d), kernel_size=3, stride=2, padding=1)),
nn.LeakyReLU(0.2),
]
for i in range(depth - 1):
c_in = c_hidden // (2 ** max((d - i), 0))
c_out = c_hidden // (2 ** max((d - 1 - i), 0))
layers.append(nn.utils.spectral_norm(nn.Conv2d(c_in, c_out, kernel_size=3, stride=2, padding=1)))
layers.append(nn.InstanceNorm2d(c_out))
layers.append(nn.LeakyReLU(0.2))
self.encoder = nn.Sequential(*layers)
self.shuffle = nn.Conv2d((c_hidden + c_cond) if c_cond > 0 else c_hidden, 1, kernel_size=1)
self.logits = nn.Sigmoid()
def forward(self, x, cond=None):
x = self.encoder(x)
if cond is not None:
cond = cond.view(cond.size(0), cond.size(1), 1, 1, ).expand(-1, -1, x.size(-2), x.size(-1))
x = torch.cat([x, cond], dim=1)
x = self.shuffle(x)
x = self.logits(x)
return x

257
comfy/ldm/cascade/stage_b.py

@ -0,0 +1,257 @@
"""
This file is part of ComfyUI.
Copyright (C) 2024 Stability AI
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU General Public License for more details.
You should have received a copy of the GNU General Public License
along with this program. If not, see <https://www.gnu.org/licenses/>.
"""
import math
import numpy as np
import torch
from torch import nn
from .common import AttnBlock, LayerNorm2d_op, ResBlock, FeedForwardBlock, TimestepBlock
class StageB(nn.Module):
def __init__(self, c_in=4, c_out=4, c_r=64, patch_size=2, c_cond=1280, c_hidden=[320, 640, 1280, 1280],
nhead=[-1, -1, 20, 20], blocks=[[2, 6, 28, 6], [6, 28, 6, 2]],
block_repeat=[[1, 1, 1, 1], [3, 3, 2, 2]], level_config=['CT', 'CT', 'CTA', 'CTA'], c_clip=1280,
c_clip_seq=4, c_effnet=16, c_pixels=3, kernel_size=3, dropout=[0, 0, 0.0, 0.0], self_attn=True,
t_conds=['sca'], stable_cascade_stage=None, dtype=None, device=None, operations=None):
super().__init__()
self.dtype = dtype
self.c_r = c_r
self.t_conds = t_conds
self.c_clip_seq = c_clip_seq
if not isinstance(dropout, list):
dropout = [dropout] * len(c_hidden)
if not isinstance(self_attn, list):
self_attn = [self_attn] * len(c_hidden)
# CONDITIONING
self.effnet_mapper = nn.Sequential(
operations.Conv2d(c_effnet, c_hidden[0] * 4, kernel_size=1, dtype=dtype, device=device),
nn.GELU(),
operations.Conv2d(c_hidden[0] * 4, c_hidden[0], kernel_size=1, dtype=dtype, device=device),
LayerNorm2d_op(operations)(c_hidden[0], elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
)
self.pixels_mapper = nn.Sequential(
operations.Conv2d(c_pixels, c_hidden[0] * 4, kernel_size=1, dtype=dtype, device=device),
nn.GELU(),
operations.Conv2d(c_hidden[0] * 4, c_hidden[0], kernel_size=1, dtype=dtype, device=device),
LayerNorm2d_op(operations)(c_hidden[0], elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
)
self.clip_mapper = operations.Linear(c_clip, c_cond * c_clip_seq, dtype=dtype, device=device)
self.clip_norm = operations.LayerNorm(c_cond, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
self.embedding = nn.Sequential(
nn.PixelUnshuffle(patch_size),
operations.Conv2d(c_in * (patch_size ** 2), c_hidden[0], kernel_size=1, dtype=dtype, device=device),
LayerNorm2d_op(operations)(c_hidden[0], elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
)
def get_block(block_type, c_hidden, nhead, c_skip=0, dropout=0, self_attn=True):
if block_type == 'C':
return ResBlock(c_hidden, c_skip, kernel_size=kernel_size, dropout=dropout, dtype=dtype, device=device, operations=operations)
elif block_type == 'A':
return AttnBlock(c_hidden, c_cond, nhead, self_attn=self_attn, dropout=dropout, dtype=dtype, device=device, operations=operations)
elif block_type == 'F':
return FeedForwardBlock(c_hidden, dropout=dropout, dtype=dtype, device=device, operations=operations)
elif block_type == 'T':
return TimestepBlock(c_hidden, c_r, conds=t_conds, dtype=dtype, device=device, operations=operations)
else:
raise Exception(f'Block type {block_type} not supported')
# BLOCKS
# -- down blocks
self.down_blocks = nn.ModuleList()
self.down_downscalers = nn.ModuleList()
self.down_repeat_mappers = nn.ModuleList()
for i in range(len(c_hidden)):
if i > 0:
self.down_downscalers.append(nn.Sequential(
LayerNorm2d_op(operations)(c_hidden[i - 1], elementwise_affine=False, eps=1e-6, dtype=dtype, device=device),
operations.Conv2d(c_hidden[i - 1], c_hidden[i], kernel_size=2, stride=2, dtype=dtype, device=device),
))
else:
self.down_downscalers.append(nn.Identity())
down_block = nn.ModuleList()
for _ in range(blocks[0][i]):
for block_type in level_config[i]:
block = get_block(block_type, c_hidden[i], nhead[i], dropout=dropout[i], self_attn=self_attn[i])
down_block.append(block)
self.down_blocks.append(down_block)
if block_repeat is not None:
block_repeat_mappers = nn.ModuleList()
for _ in range(block_repeat[0][i] - 1):
block_repeat_mappers.append(operations.Conv2d(c_hidden[i], c_hidden[i], kernel_size=1, dtype=dtype, device=device))
self.down_repeat_mappers.append(block_repeat_mappers)
# -- up blocks
self.up_blocks = nn.ModuleList()
self.up_upscalers = nn.ModuleList()
self.up_repeat_mappers = nn.ModuleList()
for i in reversed(range(len(c_hidden))):
if i > 0:
self.up_upscalers.append(nn.Sequential(
LayerNorm2d_op(operations)(c_hidden[i], elementwise_affine=False, eps=1e-6, dtype=dtype, device=device),
operations.ConvTranspose2d(c_hidden[i], c_hidden[i - 1], kernel_size=2, stride=2, dtype=dtype, device=device),
))
else:
self.up_upscalers.append(nn.Identity())
up_block = nn.ModuleList()
for j in range(blocks[1][::-1][i]):
for k, block_type in enumerate(level_config[i]):
c_skip = c_hidden[i] if i < len(c_hidden) - 1 and j == k == 0 else 0
block = get_block(block_type, c_hidden[i], nhead[i], c_skip=c_skip, dropout=dropout[i],
self_attn=self_attn[i])
up_block.append(block)
self.up_blocks.append(up_block)
if block_repeat is not None:
block_repeat_mappers = nn.ModuleList()
for _ in range(block_repeat[1][::-1][i] - 1):
block_repeat_mappers.append(operations.Conv2d(c_hidden[i], c_hidden[i], kernel_size=1, dtype=dtype, device=device))
self.up_repeat_mappers.append(block_repeat_mappers)
# OUTPUT
self.clf = nn.Sequential(
LayerNorm2d_op(operations)(c_hidden[0], elementwise_affine=False, eps=1e-6, dtype=dtype, device=device),
operations.Conv2d(c_hidden[0], c_out * (patch_size ** 2), kernel_size=1, dtype=dtype, device=device),
nn.PixelShuffle(patch_size),
)
# --- WEIGHT INIT ---
# self.apply(self._init_weights) # General init
# nn.init.normal_(self.clip_mapper.weight, std=0.02) # conditionings
# nn.init.normal_(self.effnet_mapper[0].weight, std=0.02) # conditionings
# nn.init.normal_(self.effnet_mapper[2].weight, std=0.02) # conditionings
# nn.init.normal_(self.pixels_mapper[0].weight, std=0.02) # conditionings
# nn.init.normal_(self.pixels_mapper[2].weight, std=0.02) # conditionings
# torch.nn.init.xavier_uniform_(self.embedding[1].weight, 0.02) # inputs
# nn.init.constant_(self.clf[1].weight, 0) # outputs
#
# # blocks
# for level_block in self.down_blocks + self.up_blocks:
# for block in level_block:
# if isinstance(block, ResBlock) or isinstance(block, FeedForwardBlock):
# block.channelwise[-1].weight.data *= np.sqrt(1 / sum(blocks[0]))
# elif isinstance(block, TimestepBlock):
# for layer in block.modules():
# if isinstance(layer, nn.Linear):
# nn.init.constant_(layer.weight, 0)
#
# def _init_weights(self, m):
# if isinstance(m, (nn.Conv2d, nn.Linear)):
# torch.nn.init.xavier_uniform_(m.weight)
# if m.bias is not None:
# nn.init.constant_(m.bias, 0)
def gen_r_embedding(self, r, max_positions=10000):
r = r * max_positions
half_dim = self.c_r // 2
emb = math.log(max_positions) / (half_dim - 1)
emb = torch.arange(half_dim, device=r.device).float().mul(-emb).exp()
emb = r[:, None] * emb[None, :]
emb = torch.cat([emb.sin(), emb.cos()], dim=1)
if self.c_r % 2 == 1: # zero pad
emb = nn.functional.pad(emb, (0, 1), mode='constant')
return emb
def gen_c_embeddings(self, clip):
if len(clip.shape) == 2:
clip = clip.unsqueeze(1)
clip = self.clip_mapper(clip).view(clip.size(0), clip.size(1) * self.c_clip_seq, -1)
clip = self.clip_norm(clip)
return clip
def _down_encode(self, x, r_embed, clip):
level_outputs = []
block_group = zip(self.down_blocks, self.down_downscalers, self.down_repeat_mappers)
for down_block, downscaler, repmap in block_group:
x = downscaler(x)
for i in range(len(repmap) + 1):
for block in down_block:
if isinstance(block, ResBlock) or (
hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module,
ResBlock)):
x = block(x)
elif isinstance(block, AttnBlock) or (
hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module,
AttnBlock)):
x = block(x, clip)
elif isinstance(block, TimestepBlock) or (
hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module,
TimestepBlock)):
x = block(x, r_embed)
else:
x = block(x)
if i < len(repmap):
x = repmap[i](x)
level_outputs.insert(0, x)
return level_outputs
def _up_decode(self, level_outputs, r_embed, clip):
x = level_outputs[0]
block_group = zip(self.up_blocks, self.up_upscalers, self.up_repeat_mappers)
for i, (up_block, upscaler, repmap) in enumerate(block_group):
for j in range(len(repmap) + 1):
for k, block in enumerate(up_block):
if isinstance(block, ResBlock) or (
hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module,
ResBlock)):
skip = level_outputs[i] if k == 0 and i > 0 else None
if skip is not None and (x.size(-1) != skip.size(-1) or x.size(-2) != skip.size(-2)):
x = torch.nn.functional.interpolate(x, skip.shape[-2:], mode='bilinear',
align_corners=True)
x = block(x, skip)
elif isinstance(block, AttnBlock) or (
hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module,
AttnBlock)):
x = block(x, clip)
elif isinstance(block, TimestepBlock) or (
hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module,
TimestepBlock)):
x = block(x, r_embed)
else:
x = block(x)
if j < len(repmap):
x = repmap[j](x)
x = upscaler(x)
return x
def forward(self, x, r, effnet, clip, pixels=None, **kwargs):
if pixels is None:
pixels = x.new_zeros(x.size(0), 3, 8, 8)
# Process the conditioning embeddings
r_embed = self.gen_r_embedding(r).to(dtype=x.dtype)
for c in self.t_conds:
t_cond = kwargs.get(c, torch.zeros_like(r))
r_embed = torch.cat([r_embed, self.gen_r_embedding(t_cond).to(dtype=x.dtype)], dim=1)
clip = self.gen_c_embeddings(clip)
# Model Blocks
x = self.embedding(x)
x = x + self.effnet_mapper(
nn.functional.interpolate(effnet, size=x.shape[-2:], mode='bilinear', align_corners=True))
x = x + nn.functional.interpolate(self.pixels_mapper(pixels), size=x.shape[-2:], mode='bilinear',
align_corners=True)
level_outputs = self._down_encode(x, r_embed, clip)
x = self._up_decode(level_outputs, r_embed, clip)
return self.clf(x)
def update_weights_ema(self, src_model, beta=0.999):
for self_params, src_params in zip(self.parameters(), src_model.parameters()):
self_params.data = self_params.data * beta + src_params.data.clone().to(self_params.device) * (1 - beta)
for self_buffers, src_buffers in zip(self.buffers(), src_model.buffers()):
self_buffers.data = self_buffers.data * beta + src_buffers.data.clone().to(self_buffers.device) * (1 - beta)

274
comfy/ldm/cascade/stage_c.py

@ -0,0 +1,274 @@
"""
This file is part of ComfyUI.
Copyright (C) 2024 Stability AI
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU General Public License for more details.
You should have received a copy of the GNU General Public License
along with this program. If not, see <https://www.gnu.org/licenses/>.
"""
import torch
from torch import nn
import numpy as np
import math
from .common import AttnBlock, LayerNorm2d_op, ResBlock, FeedForwardBlock, TimestepBlock
# from .controlnet import ControlNetDeliverer
class UpDownBlock2d(nn.Module):
def __init__(self, c_in, c_out, mode, enabled=True, dtype=None, device=None, operations=None):
super().__init__()
assert mode in ['up', 'down']
interpolation = nn.Upsample(scale_factor=2 if mode == 'up' else 0.5, mode='bilinear',
align_corners=True) if enabled else nn.Identity()
mapping = operations.Conv2d(c_in, c_out, kernel_size=1, dtype=dtype, device=device)
self.blocks = nn.ModuleList([interpolation, mapping] if mode == 'up' else [mapping, interpolation])
def forward(self, x):
for block in self.blocks:
x = block(x)
return x
class StageC(nn.Module):
def __init__(self, c_in=16, c_out=16, c_r=64, patch_size=1, c_cond=2048, c_hidden=[2048, 2048], nhead=[32, 32],
blocks=[[8, 24], [24, 8]], block_repeat=[[1, 1], [1, 1]], level_config=['CTA', 'CTA'],
c_clip_text=1280, c_clip_text_pooled=1280, c_clip_img=768, c_clip_seq=4, kernel_size=3,
dropout=[0.0, 0.0], self_attn=True, t_conds=['sca', 'crp'], switch_level=[False], stable_cascade_stage=None,
dtype=None, device=None, operations=None):
super().__init__()
self.dtype = dtype
self.c_r = c_r
self.t_conds = t_conds
self.c_clip_seq = c_clip_seq
if not isinstance(dropout, list):
dropout = [dropout] * len(c_hidden)
if not isinstance(self_attn, list):
self_attn = [self_attn] * len(c_hidden)
# CONDITIONING
self.clip_txt_mapper = operations.Linear(c_clip_text, c_cond, dtype=dtype, device=device)
self.clip_txt_pooled_mapper = operations.Linear(c_clip_text_pooled, c_cond * c_clip_seq, dtype=dtype, device=device)
self.clip_img_mapper = operations.Linear(c_clip_img, c_cond * c_clip_seq, dtype=dtype, device=device)
self.clip_norm = operations.LayerNorm(c_cond, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
self.embedding = nn.Sequential(
nn.PixelUnshuffle(patch_size),
operations.Conv2d(c_in * (patch_size ** 2), c_hidden[0], kernel_size=1, dtype=dtype, device=device),
LayerNorm2d_op(operations)(c_hidden[0], elementwise_affine=False, eps=1e-6)
)
def get_block(block_type, c_hidden, nhead, c_skip=0, dropout=0, self_attn=True):
if block_type == 'C':
return ResBlock(c_hidden, c_skip, kernel_size=kernel_size, dropout=dropout, dtype=dtype, device=device, operations=operations)
elif block_type == 'A':
return AttnBlock(c_hidden, c_cond, nhead, self_attn=self_attn, dropout=dropout, dtype=dtype, device=device, operations=operations)
elif block_type == 'F':
return FeedForwardBlock(c_hidden, dropout=dropout, dtype=dtype, device=device, operations=operations)
elif block_type == 'T':
return TimestepBlock(c_hidden, c_r, conds=t_conds, dtype=dtype, device=device, operations=operations)
else:
raise Exception(f'Block type {block_type} not supported')
# BLOCKS
# -- down blocks
self.down_blocks = nn.ModuleList()
self.down_downscalers = nn.ModuleList()
self.down_repeat_mappers = nn.ModuleList()
for i in range(len(c_hidden)):
if i > 0:
self.down_downscalers.append(nn.Sequential(
LayerNorm2d_op(operations)(c_hidden[i - 1], elementwise_affine=False, eps=1e-6),
UpDownBlock2d(c_hidden[i - 1], c_hidden[i], mode='down', enabled=switch_level[i - 1], dtype=dtype, device=device, operations=operations)
))
else:
self.down_downscalers.append(nn.Identity())
down_block = nn.ModuleList()
for _ in range(blocks[0][i]):
for block_type in level_config[i]:
block = get_block(block_type, c_hidden[i], nhead[i], dropout=dropout[i], self_attn=self_attn[i])
down_block.append(block)
self.down_blocks.append(down_block)
if block_repeat is not None:
block_repeat_mappers = nn.ModuleList()
for _ in range(block_repeat[0][i] - 1):
block_repeat_mappers.append(operations.Conv2d(c_hidden[i], c_hidden[i], kernel_size=1, dtype=dtype, device=device))
self.down_repeat_mappers.append(block_repeat_mappers)
# -- up blocks
self.up_blocks = nn.ModuleList()
self.up_upscalers = nn.ModuleList()
self.up_repeat_mappers = nn.ModuleList()
for i in reversed(range(len(c_hidden))):
if i > 0:
self.up_upscalers.append(nn.Sequential(
LayerNorm2d_op(operations)(c_hidden[i], elementwise_affine=False, eps=1e-6),
UpDownBlock2d(c_hidden[i], c_hidden[i - 1], mode='up', enabled=switch_level[i - 1], dtype=dtype, device=device, operations=operations)
))
else:
self.up_upscalers.append(nn.Identity())
up_block = nn.ModuleList()
for j in range(blocks[1][::-1][i]):
for k, block_type in enumerate(level_config[i]):
c_skip = c_hidden[i] if i < len(c_hidden) - 1 and j == k == 0 else 0
block = get_block(block_type, c_hidden[i], nhead[i], c_skip=c_skip, dropout=dropout[i],
self_attn=self_attn[i])
up_block.append(block)
self.up_blocks.append(up_block)
if block_repeat is not None:
block_repeat_mappers = nn.ModuleList()
for _ in range(block_repeat[1][::-1][i] - 1):
block_repeat_mappers.append(operations.Conv2d(c_hidden[i], c_hidden[i], kernel_size=1, dtype=dtype, device=device))
self.up_repeat_mappers.append(block_repeat_mappers)
# OUTPUT
self.clf = nn.Sequential(
LayerNorm2d_op(operations)(c_hidden[0], elementwise_affine=False, eps=1e-6, dtype=dtype, device=device),
operations.Conv2d(c_hidden[0], c_out * (patch_size ** 2), kernel_size=1, dtype=dtype, device=device),
nn.PixelShuffle(patch_size),
)
# --- WEIGHT INIT ---
# self.apply(self._init_weights) # General init
# nn.init.normal_(self.clip_txt_mapper.weight, std=0.02) # conditionings
# nn.init.normal_(self.clip_txt_pooled_mapper.weight, std=0.02) # conditionings
# nn.init.normal_(self.clip_img_mapper.weight, std=0.02) # conditionings
# torch.nn.init.xavier_uniform_(self.embedding[1].weight, 0.02) # inputs
# nn.init.constant_(self.clf[1].weight, 0) # outputs
#
# # blocks
# for level_block in self.down_blocks + self.up_blocks:
# for block in level_block:
# if isinstance(block, ResBlock) or isinstance(block, FeedForwardBlock):
# block.channelwise[-1].weight.data *= np.sqrt(1 / sum(blocks[0]))
# elif isinstance(block, TimestepBlock):
# for layer in block.modules():
# if isinstance(layer, nn.Linear):
# nn.init.constant_(layer.weight, 0)
#
# def _init_weights(self, m):
# if isinstance(m, (nn.Conv2d, nn.Linear)):
# torch.nn.init.xavier_uniform_(m.weight)
# if m.bias is not None:
# nn.init.constant_(m.bias, 0)
def gen_r_embedding(self, r, max_positions=10000):
r = r * max_positions
half_dim = self.c_r // 2
emb = math.log(max_positions) / (half_dim - 1)
emb = torch.arange(half_dim, device=r.device).float().mul(-emb).exp()
emb = r[:, None] * emb[None, :]
emb = torch.cat([emb.sin(), emb.cos()], dim=1)
if self.c_r % 2 == 1: # zero pad
emb = nn.functional.pad(emb, (0, 1), mode='constant')
return emb
def gen_c_embeddings(self, clip_txt, clip_txt_pooled, clip_img):
clip_txt = self.clip_txt_mapper(clip_txt)
if len(clip_txt_pooled.shape) == 2:
clip_txt_pooled = clip_txt_pooled.unsqueeze(1)
if len(clip_img.shape) == 2:
clip_img = clip_img.unsqueeze(1)
clip_txt_pool = self.clip_txt_pooled_mapper(clip_txt_pooled).view(clip_txt_pooled.size(0), clip_txt_pooled.size(1) * self.c_clip_seq, -1)
clip_img = self.clip_img_mapper(clip_img).view(clip_img.size(0), clip_img.size(1) * self.c_clip_seq, -1)
clip = torch.cat([clip_txt, clip_txt_pool, clip_img], dim=1)
clip = self.clip_norm(clip)
return clip
def _down_encode(self, x, r_embed, clip, cnet=None):
level_outputs = []
block_group = zip(self.down_blocks, self.down_downscalers, self.down_repeat_mappers)
for down_block, downscaler, repmap in block_group:
x = downscaler(x)
for i in range(len(repmap) + 1):
for block in down_block:
if isinstance(block, ResBlock) or (
hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module,
ResBlock)):
if cnet is not None:
next_cnet = cnet.pop()
if next_cnet is not None:
x = x + nn.functional.interpolate(next_cnet, size=x.shape[-2:], mode='bilinear',
align_corners=True).to(x.dtype)
x = block(x)
elif isinstance(block, AttnBlock) or (
hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module,
AttnBlock)):
x = block(x, clip)
elif isinstance(block, TimestepBlock) or (
hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module,
TimestepBlock)):
x = block(x, r_embed)
else:
x = block(x)
if i < len(repmap):
x = repmap[i](x)
level_outputs.insert(0, x)
return level_outputs
def _up_decode(self, level_outputs, r_embed, clip, cnet=None):
x = level_outputs[0]
block_group = zip(self.up_blocks, self.up_upscalers, self.up_repeat_mappers)
for i, (up_block, upscaler, repmap) in enumerate(block_group):
for j in range(len(repmap) + 1):
for k, block in enumerate(up_block):
if isinstance(block, ResBlock) or (
hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module,
ResBlock)):
skip = level_outputs[i] if k == 0 and i > 0 else None
if skip is not None and (x.size(-1) != skip.size(-1) or x.size(-2) != skip.size(-2)):
x = torch.nn.functional.interpolate(x, skip.shape[-2:], mode='bilinear',
align_corners=True)
if cnet is not None:
next_cnet = cnet.pop()
if next_cnet is not None:
x = x + nn.functional.interpolate(next_cnet, size=x.shape[-2:], mode='bilinear',
align_corners=True).to(x.dtype)
x = block(x, skip)
elif isinstance(block, AttnBlock) or (
hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module,
AttnBlock)):
x = block(x, clip)
elif isinstance(block, TimestepBlock) or (
hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module,
TimestepBlock)):
x = block(x, r_embed)
else:
x = block(x)
if j < len(repmap):
x = repmap[j](x)
x = upscaler(x)
return x
def forward(self, x, r, clip_text, clip_text_pooled, clip_img, control=None, **kwargs):
# Process the conditioning embeddings
r_embed = self.gen_r_embedding(r).to(dtype=x.dtype)
for c in self.t_conds:
t_cond = kwargs.get(c, torch.zeros_like(r))
r_embed = torch.cat([r_embed, self.gen_r_embedding(t_cond).to(dtype=x.dtype)], dim=1)
clip = self.gen_c_embeddings(clip_text, clip_text_pooled, clip_img)
if control is not None:
cnet = control.get("input")
else:
cnet = None
# Model Blocks
x = self.embedding(x)
level_outputs = self._down_encode(x, r_embed, clip, cnet)
x = self._up_decode(level_outputs, r_embed, clip, cnet)
return self.clf(x)
def update_weights_ema(self, src_model, beta=0.999):
for self_params, src_params in zip(self.parameters(), src_model.parameters()):
self_params.data = self_params.data * beta + src_params.data.clone().to(self_params.device) * (1 - beta)
for self_buffers, src_buffers in zip(self.buffers(), src_model.buffers()):
self_buffers.data = self_buffers.data * beta + src_buffers.data.clone().to(self_buffers.device) * (1 - beta)

95
comfy/ldm/cascade/stage_c_coder.py

@ -0,0 +1,95 @@
"""
This file is part of ComfyUI.
Copyright (C) 2024 Stability AI
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU General Public License for more details.
You should have received a copy of the GNU General Public License
along with this program. If not, see <https://www.gnu.org/licenses/>.
"""
import torch
import torchvision
from torch import nn
# EfficientNet
class EfficientNetEncoder(nn.Module):
def __init__(self, c_latent=16):
super().__init__()
self.backbone = torchvision.models.efficientnet_v2_s().features.eval()
self.mapper = nn.Sequential(
nn.Conv2d(1280, c_latent, kernel_size=1, bias=False),
nn.BatchNorm2d(c_latent, affine=False), # then normalize them to have mean 0 and std 1
)
self.mean = nn.Parameter(torch.tensor([0.485, 0.456, 0.406]))
self.std = nn.Parameter(torch.tensor([0.229, 0.224, 0.225]))
def forward(self, x):
x = x * 0.5 + 0.5
x = (x - self.mean.view([3,1,1])) / self.std.view([3,1,1])
o = self.mapper(self.backbone(x))
return o
# Fast Decoder for Stage C latents. E.g. 16 x 24 x 24 -> 3 x 192 x 192
class Previewer(nn.Module):
def __init__(self, c_in=16, c_hidden=512, c_out=3):
super().__init__()
self.blocks = nn.Sequential(
nn.Conv2d(c_in, c_hidden, kernel_size=1), # 16 channels to 512 channels
nn.GELU(),
nn.BatchNorm2d(c_hidden),
nn.Conv2d(c_hidden, c_hidden, kernel_size=3, padding=1),
nn.GELU(),
nn.BatchNorm2d(c_hidden),
nn.ConvTranspose2d(c_hidden, c_hidden // 2, kernel_size=2, stride=2), # 16 -> 32
nn.GELU(),
nn.BatchNorm2d(c_hidden // 2),
nn.Conv2d(c_hidden // 2, c_hidden // 2, kernel_size=3, padding=1),
nn.GELU(),
nn.BatchNorm2d(c_hidden // 2),
nn.ConvTranspose2d(c_hidden // 2, c_hidden // 4, kernel_size=2, stride=2), # 32 -> 64
nn.GELU(),
nn.BatchNorm2d(c_hidden // 4),
nn.Conv2d(c_hidden // 4, c_hidden // 4, kernel_size=3, padding=1),
nn.GELU(),
nn.BatchNorm2d(c_hidden // 4),
nn.ConvTranspose2d(c_hidden // 4, c_hidden // 4, kernel_size=2, stride=2), # 64 -> 128
nn.GELU(),
nn.BatchNorm2d(c_hidden // 4),
nn.Conv2d(c_hidden // 4, c_hidden // 4, kernel_size=3, padding=1),
nn.GELU(),
nn.BatchNorm2d(c_hidden // 4),
nn.Conv2d(c_hidden // 4, c_out, kernel_size=1),
)
def forward(self, x):
return (self.blocks(x) - 0.5) * 2.0
class StageC_coder(nn.Module):
def __init__(self):
super().__init__()
self.previewer = Previewer()
self.encoder = EfficientNetEncoder()
def encode(self, x):
return self.encoder(x)
def decode(self, x):
return self.previewer(x)

5
comfy/ldm/models/autoencoder.py

@ -8,6 +8,7 @@ from comfy.ldm.modules.distributions.distributions import DiagonalGaussianDistri
from comfy.ldm.util import instantiate_from_config from comfy.ldm.util import instantiate_from_config
from comfy.ldm.modules.ema import LitEma from comfy.ldm.modules.ema import LitEma
import comfy.ops
class DiagonalGaussianRegularizer(torch.nn.Module): class DiagonalGaussianRegularizer(torch.nn.Module):
def __init__(self, sample: bool = True): def __init__(self, sample: bool = True):
@ -161,12 +162,12 @@ class AutoencodingEngineLegacy(AutoencodingEngine):
}, },
**kwargs, **kwargs,
) )
self.quant_conv = torch.nn.Conv2d( self.quant_conv = comfy.ops.disable_weight_init.Conv2d(
(1 + ddconfig["double_z"]) * ddconfig["z_channels"], (1 + ddconfig["double_z"]) * ddconfig["z_channels"],
(1 + ddconfig["double_z"]) * embed_dim, (1 + ddconfig["double_z"]) * embed_dim,
1, 1,
) )
self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1) self.post_quant_conv = comfy.ops.disable_weight_init.Conv2d(embed_dim, ddconfig["z_channels"], 1)
self.embed_dim = embed_dim self.embed_dim = embed_dim
def get_autoencoder_params(self) -> list: def get_autoencoder_params(self) -> list:

123
comfy/ldm/modules/attention.py

@ -1,12 +1,10 @@
from inspect import isfunction
import math import math
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
from torch import nn, einsum from torch import nn, einsum
from einops import rearrange, repeat from einops import rearrange, repeat
from typing import Optional, Any from typing import Optional, Any
from functools import partial import logging
from .diffusionmodules.util import checkpoint, AlphaBlender, timestep_embedding from .diffusionmodules.util import checkpoint, AlphaBlender, timestep_embedding
from .sub_quadratic_attention import efficient_dot_product_attention from .sub_quadratic_attention import efficient_dot_product_attention
@ -19,10 +17,11 @@ if model_management.xformers_enabled():
from comfy.cli_args import args from comfy.cli_args import args
import comfy.ops import comfy.ops
ops = comfy.ops.disable_weight_init
# CrossAttn precision handling # CrossAttn precision handling
if args.dont_upcast_attention: if args.dont_upcast_attention:
print("disabling upcasting of attention") logging.info("disabling upcasting of attention")
_ATTN_PRECISION = "fp16" _ATTN_PRECISION = "fp16"
else: else:
_ATTN_PRECISION = "fp32" _ATTN_PRECISION = "fp32"
@ -55,7 +54,7 @@ def init_(tensor):
# feedforward # feedforward
class GEGLU(nn.Module): class GEGLU(nn.Module):
def __init__(self, dim_in, dim_out, dtype=None, device=None, operations=comfy.ops): def __init__(self, dim_in, dim_out, dtype=None, device=None, operations=ops):
super().__init__() super().__init__()
self.proj = operations.Linear(dim_in, dim_out * 2, dtype=dtype, device=device) self.proj = operations.Linear(dim_in, dim_out * 2, dtype=dtype, device=device)
@ -65,7 +64,7 @@ class GEGLU(nn.Module):
class FeedForward(nn.Module): class FeedForward(nn.Module):
def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0., dtype=None, device=None, operations=comfy.ops): def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0., dtype=None, device=None, operations=ops):
super().__init__() super().__init__()
inner_dim = int(dim * mult) inner_dim = int(dim * mult)
dim_out = default(dim_out, dim) dim_out = default(dim_out, dim)
@ -83,16 +82,6 @@ class FeedForward(nn.Module):
def forward(self, x): def forward(self, x):
return self.net(x) return self.net(x)
def zero_module(module):
"""
Zero out the parameters of a module and return it.
"""
for p in module.parameters():
p.detach().zero_()
return module
def Normalize(in_channels, dtype=None, device=None): def Normalize(in_channels, dtype=None, device=None):
return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True, dtype=dtype, device=device) return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True, dtype=dtype, device=device)
@ -113,19 +102,25 @@ def attention_basic(q, k, v, heads, mask=None):
# force cast to fp32 to avoid overflowing # force cast to fp32 to avoid overflowing
if _ATTN_PRECISION =="fp32": if _ATTN_PRECISION =="fp32":
with torch.autocast(enabled=False, device_type = 'cuda'): sim = einsum('b i d, b j d -> b i j', q.float(), k.float()) * scale
q, k = q.float(), k.float()
sim = einsum('b i d, b j d -> b i j', q, k) * scale
else: else:
sim = einsum('b i d, b j d -> b i j', q, k) * scale sim = einsum('b i d, b j d -> b i j', q, k) * scale
del q, k del q, k
if exists(mask): if exists(mask):
mask = rearrange(mask, 'b ... -> b (...)') if mask.dtype == torch.bool:
max_neg_value = -torch.finfo(sim.dtype).max mask = rearrange(mask, 'b ... -> b (...)') #TODO: check if this bool part matches pytorch attention
mask = repeat(mask, 'b j -> (b h) () j', h=h) max_neg_value = -torch.finfo(sim.dtype).max
sim.masked_fill_(~mask, max_neg_value) mask = repeat(mask, 'b j -> (b h) () j', h=h)
sim.masked_fill_(~mask, max_neg_value)
else:
if len(mask.shape) == 2:
bs = 1
else:
bs = mask.shape[0]
mask = mask.reshape(bs, -1, mask.shape[-2], mask.shape[-1]).expand(b, heads, -1, -1).reshape(-1, mask.shape[-2], mask.shape[-1])
sim.add_(mask)
# attention, what we cannot get enough of # attention, what we cannot get enough of
sim = sim.softmax(dim=-1) sim = sim.softmax(dim=-1)
@ -176,6 +171,13 @@ def attention_sub_quad(query, key, value, heads, mask=None):
if query_chunk_size is None: if query_chunk_size is None:
query_chunk_size = 512 query_chunk_size = 512
if mask is not None:
if len(mask.shape) == 2:
bs = 1
else:
bs = mask.shape[0]
mask = mask.reshape(bs, -1, mask.shape[-2], mask.shape[-1]).expand(b, heads, -1, -1).reshape(-1, mask.shape[-2], mask.shape[-1])
hidden_states = efficient_dot_product_attention( hidden_states = efficient_dot_product_attention(
query, query,
key, key,
@ -185,6 +187,7 @@ def attention_sub_quad(query, key, value, heads, mask=None):
kv_chunk_size_min=kv_chunk_size_min, kv_chunk_size_min=kv_chunk_size_min,
use_checkpoint=False, use_checkpoint=False,
upcast_attention=upcast_attention, upcast_attention=upcast_attention,
mask=mask,
) )
hidden_states = hidden_states.to(dtype) hidden_states = hidden_states.to(dtype)
@ -233,6 +236,13 @@ def attention_split(q, k, v, heads, mask=None):
raise RuntimeError(f'Not enough memory, use lower resolution (max approx. {max_res}x{max_res}). ' raise RuntimeError(f'Not enough memory, use lower resolution (max approx. {max_res}x{max_res}). '
f'Need: {mem_required/64/gb:0.1f}GB free, Have:{mem_free_total/gb:0.1f}GB free') f'Need: {mem_required/64/gb:0.1f}GB free, Have:{mem_free_total/gb:0.1f}GB free')
if mask is not None:
if len(mask.shape) == 2:
bs = 1
else:
bs = mask.shape[0]
mask = mask.reshape(bs, -1, mask.shape[-2], mask.shape[-1]).expand(b, heads, -1, -1).reshape(-1, mask.shape[-2], mask.shape[-1])
# print("steps", steps, mem_required, mem_free_total, modifier, q.element_size(), tensor_size) # print("steps", steps, mem_required, mem_free_total, modifier, q.element_size(), tensor_size)
first_op_done = False first_op_done = False
cleared_cache = False cleared_cache = False
@ -247,6 +257,12 @@ def attention_split(q, k, v, heads, mask=None):
else: else:
s1 = einsum('b i d, b j d -> b i j', q[:, i:end], k) * scale s1 = einsum('b i d, b j d -> b i j', q[:, i:end], k) * scale
if mask is not None:
if len(mask.shape) == 2:
s1 += mask[i:end]
else:
s1 += mask[:, i:end]
s2 = s1.softmax(dim=-1).to(v.dtype) s2 = s1.softmax(dim=-1).to(v.dtype)
del s1 del s1
first_op_done = True first_op_done = True
@ -259,12 +275,12 @@ def attention_split(q, k, v, heads, mask=None):
model_management.soft_empty_cache(True) model_management.soft_empty_cache(True)
if cleared_cache == False: if cleared_cache == False:
cleared_cache = True cleared_cache = True
print("out of memory error, emptying cache and trying again") logging.warning("out of memory error, emptying cache and trying again")
continue continue
steps *= 2 steps *= 2
if steps > 64: if steps > 64:
raise e raise e
print("out of memory error, increasing steps and trying again", steps) logging.warning("out of memory error, increasing steps and trying again {}".format(steps))
else: else:
raise e raise e
@ -302,11 +318,14 @@ def attention_xformers(q, k, v, heads, mask=None):
(q, k, v), (q, k, v),
) )
# actually compute the attention, what we cannot get enough of if mask is not None:
out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None) pad = 8 - q.shape[1] % 8
mask_out = torch.empty([q.shape[0], q.shape[1], q.shape[1] + pad], dtype=q.dtype, device=q.device)
mask_out[:, :, :mask.shape[-1]] = mask
mask = mask_out[:, :, :mask.shape[-1]]
out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=mask)
if exists(mask):
raise NotImplementedError
out = ( out = (
out.unsqueeze(0) out.unsqueeze(0)
.reshape(b, heads, -1, dim_head) .reshape(b, heads, -1, dim_head)
@ -331,27 +350,41 @@ def attention_pytorch(q, k, v, heads, mask=None):
optimized_attention = attention_basic optimized_attention = attention_basic
optimized_attention_masked = attention_basic
if model_management.xformers_enabled(): if model_management.xformers_enabled():
print("Using xformers cross attention") logging.info("Using xformers cross attention")
optimized_attention = attention_xformers optimized_attention = attention_xformers
elif model_management.pytorch_attention_enabled(): elif model_management.pytorch_attention_enabled():
print("Using pytorch cross attention") logging.info("Using pytorch cross attention")
optimized_attention = attention_pytorch optimized_attention = attention_pytorch
else: else:
if args.use_split_cross_attention: if args.use_split_cross_attention:
print("Using split optimization for cross attention") logging.info("Using split optimization for cross attention")
optimized_attention = attention_split optimized_attention = attention_split
else: else:
print("Using sub quadratic optimization for cross attention, if you have memory or speed issues try using: --use-split-cross-attention") logging.info("Using sub quadratic optimization for cross attention, if you have memory or speed issues try using: --use-split-cross-attention")
optimized_attention = attention_sub_quad optimized_attention = attention_sub_quad
if model_management.pytorch_attention_enabled(): optimized_attention_masked = optimized_attention
optimized_attention_masked = attention_pytorch
def optimized_attention_for_device(device, mask=False, small_input=False):
if small_input:
if model_management.pytorch_attention_enabled():
return attention_pytorch #TODO: need to confirm but this is probably slightly faster for small inputs in all cases
else:
return attention_basic
if device == torch.device("cpu"):
return attention_sub_quad
if mask:
return optimized_attention_masked
return optimized_attention
class CrossAttention(nn.Module): class CrossAttention(nn.Module):
def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., dtype=None, device=None, operations=comfy.ops): def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., dtype=None, device=None, operations=ops):
super().__init__() super().__init__()
inner_dim = dim_head * heads inner_dim = dim_head * heads
context_dim = default(context_dim, query_dim) context_dim = default(context_dim, query_dim)
@ -384,7 +417,7 @@ class CrossAttention(nn.Module):
class BasicTransformerBlock(nn.Module): class BasicTransformerBlock(nn.Module):
def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True, ff_in=False, inner_dim=None, def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True, ff_in=False, inner_dim=None,
disable_self_attn=False, disable_temporal_crossattention=False, switch_temporal_ca_to_sa=False, dtype=None, device=None, operations=comfy.ops): disable_self_attn=False, disable_temporal_crossattention=False, switch_temporal_ca_to_sa=False, dtype=None, device=None, operations=ops):
super().__init__() super().__init__()
self.ff_in = ff_in or inner_dim is not None self.ff_in = ff_in or inner_dim is not None
@ -394,7 +427,7 @@ class BasicTransformerBlock(nn.Module):
self.is_res = inner_dim == dim self.is_res = inner_dim == dim
if self.ff_in: if self.ff_in:
self.norm_in = nn.LayerNorm(dim, dtype=dtype, device=device) self.norm_in = operations.LayerNorm(dim, dtype=dtype, device=device)
self.ff_in = FeedForward(dim, dim_out=inner_dim, dropout=dropout, glu=gated_ff, dtype=dtype, device=device, operations=operations) self.ff_in = FeedForward(dim, dim_out=inner_dim, dropout=dropout, glu=gated_ff, dtype=dtype, device=device, operations=operations)
self.disable_self_attn = disable_self_attn self.disable_self_attn = disable_self_attn
@ -414,10 +447,10 @@ class BasicTransformerBlock(nn.Module):
self.attn2 = CrossAttention(query_dim=inner_dim, context_dim=context_dim_attn2, self.attn2 = CrossAttention(query_dim=inner_dim, context_dim=context_dim_attn2,
heads=n_heads, dim_head=d_head, dropout=dropout, dtype=dtype, device=device, operations=operations) # is self-attn if context is none heads=n_heads, dim_head=d_head, dropout=dropout, dtype=dtype, device=device, operations=operations) # is self-attn if context is none
self.norm2 = nn.LayerNorm(inner_dim, dtype=dtype, device=device) self.norm2 = operations.LayerNorm(inner_dim, dtype=dtype, device=device)
self.norm1 = nn.LayerNorm(inner_dim, dtype=dtype, device=device) self.norm1 = operations.LayerNorm(inner_dim, dtype=dtype, device=device)
self.norm3 = nn.LayerNorm(inner_dim, dtype=dtype, device=device) self.norm3 = operations.LayerNorm(inner_dim, dtype=dtype, device=device)
self.checkpoint = checkpoint self.checkpoint = checkpoint
self.n_heads = n_heads self.n_heads = n_heads
self.d_head = d_head self.d_head = d_head
@ -553,13 +586,13 @@ class SpatialTransformer(nn.Module):
def __init__(self, in_channels, n_heads, d_head, def __init__(self, in_channels, n_heads, d_head,
depth=1, dropout=0., context_dim=None, depth=1, dropout=0., context_dim=None,
disable_self_attn=False, use_linear=False, disable_self_attn=False, use_linear=False,
use_checkpoint=True, dtype=None, device=None, operations=comfy.ops): use_checkpoint=True, dtype=None, device=None, operations=ops):
super().__init__() super().__init__()
if exists(context_dim) and not isinstance(context_dim, list): if exists(context_dim) and not isinstance(context_dim, list):
context_dim = [context_dim] * depth context_dim = [context_dim] * depth
self.in_channels = in_channels self.in_channels = in_channels
inner_dim = n_heads * d_head inner_dim = n_heads * d_head
self.norm = Normalize(in_channels, dtype=dtype, device=device) self.norm = operations.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True, dtype=dtype, device=device)
if not use_linear: if not use_linear:
self.proj_in = operations.Conv2d(in_channels, self.proj_in = operations.Conv2d(in_channels,
inner_dim, inner_dim,
@ -627,7 +660,7 @@ class SpatialVideoTransformer(SpatialTransformer):
disable_self_attn=False, disable_self_attn=False,
disable_temporal_crossattention=False, disable_temporal_crossattention=False,
max_time_embed_period: int = 10000, max_time_embed_period: int = 10000,
dtype=None, device=None, operations=comfy.ops dtype=None, device=None, operations=ops
): ):
super().__init__( super().__init__(
in_channels, in_channels,

54
comfy/ldm/modules/diffusionmodules/model.py

@ -5,9 +5,11 @@ import torch.nn as nn
import numpy as np import numpy as np
from einops import rearrange from einops import rearrange
from typing import Optional, Any from typing import Optional, Any
import logging
from comfy import model_management from comfy import model_management
import comfy.ops import comfy.ops
ops = comfy.ops.disable_weight_init
if model_management.xformers_enabled_vae(): if model_management.xformers_enabled_vae():
import xformers import xformers
@ -40,7 +42,7 @@ def nonlinearity(x):
def Normalize(in_channels, num_groups=32): def Normalize(in_channels, num_groups=32):
return torch.nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True) return ops.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True)
class Upsample(nn.Module): class Upsample(nn.Module):
@ -48,7 +50,7 @@ class Upsample(nn.Module):
super().__init__() super().__init__()
self.with_conv = with_conv self.with_conv = with_conv
if self.with_conv: if self.with_conv:
self.conv = comfy.ops.Conv2d(in_channels, self.conv = ops.Conv2d(in_channels,
in_channels, in_channels,
kernel_size=3, kernel_size=3,
stride=1, stride=1,
@ -78,7 +80,7 @@ class Downsample(nn.Module):
self.with_conv = with_conv self.with_conv = with_conv
if self.with_conv: if self.with_conv:
# no asymmetric padding in torch conv, must do it ourselves # no asymmetric padding in torch conv, must do it ourselves
self.conv = comfy.ops.Conv2d(in_channels, self.conv = ops.Conv2d(in_channels,
in_channels, in_channels,
kernel_size=3, kernel_size=3,
stride=2, stride=2,
@ -105,30 +107,30 @@ class ResnetBlock(nn.Module):
self.swish = torch.nn.SiLU(inplace=True) self.swish = torch.nn.SiLU(inplace=True)
self.norm1 = Normalize(in_channels) self.norm1 = Normalize(in_channels)
self.conv1 = comfy.ops.Conv2d(in_channels, self.conv1 = ops.Conv2d(in_channels,
out_channels, out_channels,
kernel_size=3, kernel_size=3,
stride=1, stride=1,
padding=1) padding=1)
if temb_channels > 0: if temb_channels > 0:
self.temb_proj = comfy.ops.Linear(temb_channels, self.temb_proj = ops.Linear(temb_channels,
out_channels) out_channels)
self.norm2 = Normalize(out_channels) self.norm2 = Normalize(out_channels)
self.dropout = torch.nn.Dropout(dropout, inplace=True) self.dropout = torch.nn.Dropout(dropout, inplace=True)
self.conv2 = comfy.ops.Conv2d(out_channels, self.conv2 = ops.Conv2d(out_channels,
out_channels, out_channels,
kernel_size=3, kernel_size=3,
stride=1, stride=1,
padding=1) padding=1)
if self.in_channels != self.out_channels: if self.in_channels != self.out_channels:
if self.use_conv_shortcut: if self.use_conv_shortcut:
self.conv_shortcut = comfy.ops.Conv2d(in_channels, self.conv_shortcut = ops.Conv2d(in_channels,
out_channels, out_channels,
kernel_size=3, kernel_size=3,
stride=1, stride=1,
padding=1) padding=1)
else: else:
self.nin_shortcut = comfy.ops.Conv2d(in_channels, self.nin_shortcut = ops.Conv2d(in_channels,
out_channels, out_channels,
kernel_size=1, kernel_size=1,
stride=1, stride=1,
@ -189,7 +191,7 @@ def slice_attention(q, k, v):
steps *= 2 steps *= 2
if steps > 128: if steps > 128:
raise e raise e
print("out of memory error, increasing steps and trying again", steps) logging.warning("out of memory error, increasing steps and trying again {}".format(steps))
return r1 return r1
@ -234,7 +236,7 @@ def pytorch_attention(q, k, v):
out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False) out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False)
out = out.transpose(2, 3).reshape(B, C, H, W) out = out.transpose(2, 3).reshape(B, C, H, W)
except model_management.OOM_EXCEPTION as e: except model_management.OOM_EXCEPTION as e:
print("scaled_dot_product_attention OOMed: switched to slice attention") logging.warning("scaled_dot_product_attention OOMed: switched to slice attention")
out = slice_attention(q.view(B, -1, C), k.view(B, -1, C).transpose(1, 2), v.view(B, -1, C).transpose(1, 2)).reshape(B, C, H, W) out = slice_attention(q.view(B, -1, C), k.view(B, -1, C).transpose(1, 2), v.view(B, -1, C).transpose(1, 2)).reshape(B, C, H, W)
return out return out
@ -245,35 +247,35 @@ class AttnBlock(nn.Module):
self.in_channels = in_channels self.in_channels = in_channels
self.norm = Normalize(in_channels) self.norm = Normalize(in_channels)
self.q = comfy.ops.Conv2d(in_channels, self.q = ops.Conv2d(in_channels,
in_channels, in_channels,
kernel_size=1, kernel_size=1,
stride=1, stride=1,
padding=0) padding=0)
self.k = comfy.ops.Conv2d(in_channels, self.k = ops.Conv2d(in_channels,
in_channels, in_channels,
kernel_size=1, kernel_size=1,
stride=1, stride=1,
padding=0) padding=0)
self.v = comfy.ops.Conv2d(in_channels, self.v = ops.Conv2d(in_channels,
in_channels, in_channels,
kernel_size=1, kernel_size=1,
stride=1, stride=1,
padding=0) padding=0)
self.proj_out = comfy.ops.Conv2d(in_channels, self.proj_out = ops.Conv2d(in_channels,
in_channels, in_channels,
kernel_size=1, kernel_size=1,
stride=1, stride=1,
padding=0) padding=0)
if model_management.xformers_enabled_vae(): if model_management.xformers_enabled_vae():
print("Using xformers attention in VAE") logging.info("Using xformers attention in VAE")
self.optimized_attention = xformers_attention self.optimized_attention = xformers_attention
elif model_management.pytorch_attention_enabled(): elif model_management.pytorch_attention_enabled():
print("Using pytorch attention in VAE") logging.info("Using pytorch attention in VAE")
self.optimized_attention = pytorch_attention self.optimized_attention = pytorch_attention
else: else:
print("Using split attention in VAE") logging.info("Using split attention in VAE")
self.optimized_attention = normal_attention self.optimized_attention = normal_attention
def forward(self, x): def forward(self, x):
@ -312,14 +314,14 @@ class Model(nn.Module):
# timestep embedding # timestep embedding
self.temb = nn.Module() self.temb = nn.Module()
self.temb.dense = nn.ModuleList([ self.temb.dense = nn.ModuleList([
comfy.ops.Linear(self.ch, ops.Linear(self.ch,
self.temb_ch), self.temb_ch),
comfy.ops.Linear(self.temb_ch, ops.Linear(self.temb_ch,
self.temb_ch), self.temb_ch),
]) ])
# downsampling # downsampling
self.conv_in = comfy.ops.Conv2d(in_channels, self.conv_in = ops.Conv2d(in_channels,
self.ch, self.ch,
kernel_size=3, kernel_size=3,
stride=1, stride=1,
@ -388,7 +390,7 @@ class Model(nn.Module):
# end # end
self.norm_out = Normalize(block_in) self.norm_out = Normalize(block_in)
self.conv_out = comfy.ops.Conv2d(block_in, self.conv_out = ops.Conv2d(block_in,
out_ch, out_ch,
kernel_size=3, kernel_size=3,
stride=1, stride=1,
@ -461,7 +463,7 @@ class Encoder(nn.Module):
self.in_channels = in_channels self.in_channels = in_channels
# downsampling # downsampling
self.conv_in = comfy.ops.Conv2d(in_channels, self.conv_in = ops.Conv2d(in_channels,
self.ch, self.ch,
kernel_size=3, kernel_size=3,
stride=1, stride=1,
@ -506,7 +508,7 @@ class Encoder(nn.Module):
# end # end
self.norm_out = Normalize(block_in) self.norm_out = Normalize(block_in)
self.conv_out = comfy.ops.Conv2d(block_in, self.conv_out = ops.Conv2d(block_in,
2*z_channels if double_z else z_channels, 2*z_channels if double_z else z_channels,
kernel_size=3, kernel_size=3,
stride=1, stride=1,
@ -541,7 +543,7 @@ class Decoder(nn.Module):
def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks, def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels, attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
resolution, z_channels, give_pre_end=False, tanh_out=False, use_linear_attn=False, resolution, z_channels, give_pre_end=False, tanh_out=False, use_linear_attn=False,
conv_out_op=comfy.ops.Conv2d, conv_out_op=ops.Conv2d,
resnet_op=ResnetBlock, resnet_op=ResnetBlock,
attn_op=AttnBlock, attn_op=AttnBlock,
**ignorekwargs): **ignorekwargs):
@ -561,11 +563,11 @@ class Decoder(nn.Module):
block_in = ch*ch_mult[self.num_resolutions-1] block_in = ch*ch_mult[self.num_resolutions-1]
curr_res = resolution // 2**(self.num_resolutions-1) curr_res = resolution // 2**(self.num_resolutions-1)
self.z_shape = (1,z_channels,curr_res,curr_res) self.z_shape = (1,z_channels,curr_res,curr_res)
print("Working with z of shape {} = {} dimensions.".format( logging.debug("Working with z of shape {} = {} dimensions.".format(
self.z_shape, np.prod(self.z_shape))) self.z_shape, np.prod(self.z_shape)))
# z to block_in # z to block_in
self.conv_in = comfy.ops.Conv2d(z_channels, self.conv_in = ops.Conv2d(z_channels,
block_in, block_in,
kernel_size=3, kernel_size=3,
stride=1, stride=1,

96
comfy/ldm/modules/diffusionmodules/openaimodel.py

@ -1,24 +1,22 @@
from abc import abstractmethod from abc import abstractmethod
import math
import numpy as np
import torch as th import torch as th
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from einops import rearrange from einops import rearrange
from functools import partial import logging
from .util import ( from .util import (
checkpoint, checkpoint,
avg_pool_nd, avg_pool_nd,
zero_module, zero_module,
normalization,
timestep_embedding, timestep_embedding,
AlphaBlender, AlphaBlender,
) )
from ..attention import SpatialTransformer, SpatialVideoTransformer, default from ..attention import SpatialTransformer, SpatialVideoTransformer, default
from comfy.ldm.util import exists from comfy.ldm.util import exists
import comfy.ops import comfy.ops
ops = comfy.ops.disable_weight_init
class TimestepBlock(nn.Module): class TimestepBlock(nn.Module):
""" """
@ -70,7 +68,7 @@ class Upsample(nn.Module):
upsampling occurs in the inner-two dimensions. upsampling occurs in the inner-two dimensions.
""" """
def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1, dtype=None, device=None, operations=comfy.ops): def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1, dtype=None, device=None, operations=ops):
super().__init__() super().__init__()
self.channels = channels self.channels = channels
self.out_channels = out_channels or channels self.out_channels = out_channels or channels
@ -106,7 +104,7 @@ class Downsample(nn.Module):
downsampling occurs in the inner-two dimensions. downsampling occurs in the inner-two dimensions.
""" """
def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1, dtype=None, device=None, operations=comfy.ops): def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1, dtype=None, device=None, operations=ops):
super().__init__() super().__init__()
self.channels = channels self.channels = channels
self.out_channels = out_channels or channels self.out_channels = out_channels or channels
@ -159,7 +157,7 @@ class ResBlock(TimestepBlock):
skip_t_emb=False, skip_t_emb=False,
dtype=None, dtype=None,
device=None, device=None,
operations=comfy.ops operations=ops
): ):
super().__init__() super().__init__()
self.channels = channels self.channels = channels
@ -177,7 +175,7 @@ class ResBlock(TimestepBlock):
padding = kernel_size // 2 padding = kernel_size // 2
self.in_layers = nn.Sequential( self.in_layers = nn.Sequential(
nn.GroupNorm(32, channels, dtype=dtype, device=device), operations.GroupNorm(32, channels, dtype=dtype, device=device),
nn.SiLU(), nn.SiLU(),
operations.conv_nd(dims, channels, self.out_channels, kernel_size, padding=padding, dtype=dtype, device=device), operations.conv_nd(dims, channels, self.out_channels, kernel_size, padding=padding, dtype=dtype, device=device),
) )
@ -206,12 +204,11 @@ class ResBlock(TimestepBlock):
), ),
) )
self.out_layers = nn.Sequential( self.out_layers = nn.Sequential(
nn.GroupNorm(32, self.out_channels, dtype=dtype, device=device), operations.GroupNorm(32, self.out_channels, dtype=dtype, device=device),
nn.SiLU(), nn.SiLU(),
nn.Dropout(p=dropout), nn.Dropout(p=dropout),
zero_module( operations.conv_nd(dims, self.out_channels, self.out_channels, kernel_size, padding=padding, dtype=dtype, device=device)
operations.conv_nd(dims, self.out_channels, self.out_channels, kernel_size, padding=padding, dtype=dtype, device=device) ,
),
) )
if self.out_channels == channels: if self.out_channels == channels:
@ -285,7 +282,7 @@ class VideoResBlock(ResBlock):
down: bool = False, down: bool = False,
dtype=None, dtype=None,
device=None, device=None,
operations=comfy.ops operations=ops
): ):
super().__init__( super().__init__(
channels, channels,
@ -363,7 +360,7 @@ def apply_control(h, control, name):
try: try:
h += ctrl h += ctrl
except: except:
print("warning control could not be applied", h.shape, ctrl.shape) logging.warning("warning control could not be applied {} {}".format(h.shape, ctrl.shape))
return h return h
class UNetModel(nn.Module): class UNetModel(nn.Module):
@ -435,12 +432,9 @@ class UNetModel(nn.Module):
disable_temporal_crossattention=False, disable_temporal_crossattention=False,
max_ddpm_temb_period=10000, max_ddpm_temb_period=10000,
device=None, device=None,
operations=comfy.ops, operations=ops,
): ):
super().__init__() super().__init__()
assert use_spatial_transformer == True, "use_spatial_transformer has to be true"
if use_spatial_transformer:
assert context_dim is not None, 'Fool!! You forgot to include the dimension of your cross-attention conditioning...'
if context_dim is not None: if context_dim is not None:
assert use_spatial_transformer, 'Fool!! You forgot to use the spatial transformer for your cross-attention conditioning...' assert use_spatial_transformer, 'Fool!! You forgot to use the spatial transformer for your cross-attention conditioning...'
@ -457,7 +451,6 @@ class UNetModel(nn.Module):
if num_head_channels == -1: if num_head_channels == -1:
assert num_heads != -1, 'Either num_heads or num_head_channels has to be set' assert num_heads != -1, 'Either num_heads or num_head_channels has to be set'
self.image_size = image_size
self.in_channels = in_channels self.in_channels = in_channels
self.model_channels = model_channels self.model_channels = model_channels
self.out_channels = out_channels self.out_channels = out_channels
@ -492,7 +485,6 @@ class UNetModel(nn.Module):
self.predict_codebook_ids = n_embed is not None self.predict_codebook_ids = n_embed is not None
self.default_num_video_frames = None self.default_num_video_frames = None
self.default_image_only_indicator = None
time_embed_dim = model_channels * 4 time_embed_dim = model_channels * 4
self.time_embed = nn.Sequential( self.time_embed = nn.Sequential(
@ -503,9 +495,9 @@ class UNetModel(nn.Module):
if self.num_classes is not None: if self.num_classes is not None:
if isinstance(self.num_classes, int): if isinstance(self.num_classes, int):
self.label_emb = nn.Embedding(num_classes, time_embed_dim) self.label_emb = nn.Embedding(num_classes, time_embed_dim, dtype=self.dtype, device=device)
elif self.num_classes == "continuous": elif self.num_classes == "continuous":
print("setting up linear c_adm embedding layer") logging.debug("setting up linear c_adm embedding layer")
self.label_emb = nn.Linear(1, time_embed_dim) self.label_emb = nn.Linear(1, time_embed_dim)
elif self.num_classes == "sequential": elif self.num_classes == "sequential":
assert adm_in_channels is not None assert adm_in_channels is not None
@ -582,7 +574,7 @@ class UNetModel(nn.Module):
up=False, up=False,
dtype=None, dtype=None,
device=None, device=None,
operations=comfy.ops operations=ops
): ):
if self.use_temporal_resblocks: if self.use_temporal_resblocks:
return VideoResBlock( return VideoResBlock(
@ -716,27 +708,30 @@ class UNetModel(nn.Module):
device=device, device=device,
operations=operations operations=operations
)] )]
if transformer_depth_middle >= 0:
mid_block += [get_attention_layer( # always uses a self-attn self.middle_block = None
ch, num_heads, dim_head, depth=transformer_depth_middle, context_dim=context_dim, if transformer_depth_middle >= -1:
disable_self_attn=disable_middle_self_attn, use_checkpoint=use_checkpoint if transformer_depth_middle >= 0:
), mid_block += [get_attention_layer( # always uses a self-attn
get_resblock( ch, num_heads, dim_head, depth=transformer_depth_middle, context_dim=context_dim,
merge_factor=merge_factor, disable_self_attn=disable_middle_self_attn, use_checkpoint=use_checkpoint
merge_strategy=merge_strategy, ),
video_kernel_size=video_kernel_size, get_resblock(
ch=ch, merge_factor=merge_factor,
time_embed_dim=time_embed_dim, merge_strategy=merge_strategy,
dropout=dropout, video_kernel_size=video_kernel_size,
out_channels=None, ch=ch,
dims=dims, time_embed_dim=time_embed_dim,
use_checkpoint=use_checkpoint, dropout=dropout,
use_scale_shift_norm=use_scale_shift_norm, out_channels=None,
dtype=self.dtype, dims=dims,
device=device, use_checkpoint=use_checkpoint,
operations=operations use_scale_shift_norm=use_scale_shift_norm,
)] dtype=self.dtype,
self.middle_block = TimestepEmbedSequential(*mid_block) device=device,
operations=operations
)]
self.middle_block = TimestepEmbedSequential(*mid_block)
self._feature_size += ch self._feature_size += ch
self.output_blocks = nn.ModuleList([]) self.output_blocks = nn.ModuleList([])
@ -810,13 +805,13 @@ class UNetModel(nn.Module):
self._feature_size += ch self._feature_size += ch
self.out = nn.Sequential( self.out = nn.Sequential(
nn.GroupNorm(32, ch, dtype=self.dtype, device=device), operations.GroupNorm(32, ch, dtype=self.dtype, device=device),
nn.SiLU(), nn.SiLU(),
zero_module(operations.conv_nd(dims, model_channels, out_channels, 3, padding=1, dtype=self.dtype, device=device)), zero_module(operations.conv_nd(dims, model_channels, out_channels, 3, padding=1, dtype=self.dtype, device=device)),
) )
if self.predict_codebook_ids: if self.predict_codebook_ids:
self.id_predictor = nn.Sequential( self.id_predictor = nn.Sequential(
nn.GroupNorm(32, ch, dtype=self.dtype, device=device), operations.GroupNorm(32, ch, dtype=self.dtype, device=device),
operations.conv_nd(dims, model_channels, n_embed, 1, dtype=self.dtype, device=device), operations.conv_nd(dims, model_channels, n_embed, 1, dtype=self.dtype, device=device),
#nn.LogSoftmax(dim=1) # change to cross_entropy and produce non-normalized logits #nn.LogSoftmax(dim=1) # change to cross_entropy and produce non-normalized logits
) )
@ -835,21 +830,21 @@ class UNetModel(nn.Module):
transformer_patches = transformer_options.get("patches", {}) transformer_patches = transformer_options.get("patches", {})
num_video_frames = kwargs.get("num_video_frames", self.default_num_video_frames) num_video_frames = kwargs.get("num_video_frames", self.default_num_video_frames)
image_only_indicator = kwargs.get("image_only_indicator", self.default_image_only_indicator) image_only_indicator = kwargs.get("image_only_indicator", None)
time_context = kwargs.get("time_context", None) time_context = kwargs.get("time_context", None)
assert (y is not None) == ( assert (y is not None) == (
self.num_classes is not None self.num_classes is not None
), "must specify y if and only if the model is class-conditional" ), "must specify y if and only if the model is class-conditional"
hs = [] hs = []
t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False).to(self.dtype) t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False).to(x.dtype)
emb = self.time_embed(t_emb) emb = self.time_embed(t_emb)
if self.num_classes is not None: if self.num_classes is not None:
assert y.shape[0] == x.shape[0] assert y.shape[0] == x.shape[0]
emb = emb + self.label_emb(y) emb = emb + self.label_emb(y)
h = x.type(self.dtype) h = x
for id, module in enumerate(self.input_blocks): for id, module in enumerate(self.input_blocks):
transformer_options["block"] = ("input", id) transformer_options["block"] = ("input", id)
h = forward_timestep_embed(module, h, emb, context, transformer_options, time_context=time_context, num_video_frames=num_video_frames, image_only_indicator=image_only_indicator) h = forward_timestep_embed(module, h, emb, context, transformer_options, time_context=time_context, num_video_frames=num_video_frames, image_only_indicator=image_only_indicator)
@ -866,7 +861,8 @@ class UNetModel(nn.Module):
h = p(h, transformer_options) h = p(h, transformer_options)
transformer_options["block"] = ("middle", 0) transformer_options["block"] = ("middle", 0)
h = forward_timestep_embed(self.middle_block, h, emb, context, transformer_options, time_context=time_context, num_video_frames=num_video_frames, image_only_indicator=image_only_indicator) if self.middle_block is not None:
h = forward_timestep_embed(self.middle_block, h, emb, context, transformer_options, time_context=time_context, num_video_frames=num_video_frames, image_only_indicator=image_only_indicator)
h = apply_control(h, control, 'middle') h = apply_control(h, control, 'middle')

16
comfy/ldm/modules/diffusionmodules/upscaling.py

@ -41,10 +41,14 @@ class AbstractLowScaleModel(nn.Module):
self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod))) self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod)))
self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod - 1))) self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod - 1)))
def q_sample(self, x_start, t, noise=None): def q_sample(self, x_start, t, noise=None, seed=None):
noise = default(noise, lambda: torch.randn_like(x_start)) if noise is None:
return (extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start + if seed is None:
extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise) noise = torch.randn_like(x_start)
else:
noise = torch.randn(x_start.size(), dtype=x_start.dtype, layout=x_start.layout, generator=torch.manual_seed(seed)).to(x_start.device)
return (extract_into_tensor(self.sqrt_alphas_cumprod.to(x_start.device), t, x_start.shape) * x_start +
extract_into_tensor(self.sqrt_one_minus_alphas_cumprod.to(x_start.device), t, x_start.shape) * noise)
def forward(self, x): def forward(self, x):
return x, None return x, None
@ -69,12 +73,12 @@ class ImageConcatWithNoiseAugmentation(AbstractLowScaleModel):
super().__init__(noise_schedule_config=noise_schedule_config) super().__init__(noise_schedule_config=noise_schedule_config)
self.max_noise_level = max_noise_level self.max_noise_level = max_noise_level
def forward(self, x, noise_level=None): def forward(self, x, noise_level=None, seed=None):
if noise_level is None: if noise_level is None:
noise_level = torch.randint(0, self.max_noise_level, (x.shape[0],), device=x.device).long() noise_level = torch.randint(0, self.max_noise_level, (x.shape[0],), device=x.device).long()
else: else:
assert isinstance(noise_level, torch.Tensor) assert isinstance(noise_level, torch.Tensor)
z = self.q_sample(x, noise_level) z = self.q_sample(x, noise_level, seed=seed)
return z, noise_level return z, noise_level

67
comfy/ldm/modules/diffusionmodules/util.py

@ -16,7 +16,6 @@ import numpy as np
from einops import repeat, rearrange from einops import repeat, rearrange
from comfy.ldm.util import instantiate_from_config from comfy.ldm.util import instantiate_from_config
import comfy.ops
class AlphaBlender(nn.Module): class AlphaBlender(nn.Module):
strategies = ["learned", "fixed", "learned_with_images"] strategies = ["learned", "fixed", "learned_with_images"]
@ -47,23 +46,25 @@ class AlphaBlender(nn.Module):
else: else:
raise ValueError(f"unknown merge strategy {self.merge_strategy}") raise ValueError(f"unknown merge strategy {self.merge_strategy}")
def get_alpha(self, image_only_indicator: torch.Tensor) -> torch.Tensor: def get_alpha(self, image_only_indicator: torch.Tensor, device) -> torch.Tensor:
# skip_time_mix = rearrange(repeat(skip_time_mix, 'b -> (b t) () () ()', t=t), '(b t) 1 ... -> b 1 t ...', t=t) # skip_time_mix = rearrange(repeat(skip_time_mix, 'b -> (b t) () () ()', t=t), '(b t) 1 ... -> b 1 t ...', t=t)
if self.merge_strategy == "fixed": if self.merge_strategy == "fixed":
# make shape compatible # make shape compatible
# alpha = repeat(self.mix_factor, '1 -> b () t () ()', t=t, b=bs) # alpha = repeat(self.mix_factor, '1 -> b () t () ()', t=t, b=bs)
alpha = self.mix_factor alpha = self.mix_factor.to(device)
elif self.merge_strategy == "learned": elif self.merge_strategy == "learned":
alpha = torch.sigmoid(self.mix_factor) alpha = torch.sigmoid(self.mix_factor.to(device))
# make shape compatible # make shape compatible
# alpha = repeat(alpha, '1 -> s () ()', s = t * bs) # alpha = repeat(alpha, '1 -> s () ()', s = t * bs)
elif self.merge_strategy == "learned_with_images": elif self.merge_strategy == "learned_with_images":
assert image_only_indicator is not None, "need image_only_indicator ..." if image_only_indicator is None:
alpha = torch.where( alpha = rearrange(torch.sigmoid(self.mix_factor.to(device)), "... -> ... 1")
image_only_indicator.bool(), else:
torch.ones(1, 1, device=image_only_indicator.device), alpha = torch.where(
rearrange(torch.sigmoid(self.mix_factor), "... -> ... 1"), image_only_indicator.bool(),
) torch.ones(1, 1, device=image_only_indicator.device),
rearrange(torch.sigmoid(self.mix_factor.to(image_only_indicator.device)), "... -> ... 1"),
)
alpha = rearrange(alpha, self.rearrange_pattern) alpha = rearrange(alpha, self.rearrange_pattern)
# make shape compatible # make shape compatible
# alpha = repeat(alpha, '1 -> s () ()', s = t * bs) # alpha = repeat(alpha, '1 -> s () ()', s = t * bs)
@ -77,7 +78,7 @@ class AlphaBlender(nn.Module):
x_temporal, x_temporal,
image_only_indicator=None, image_only_indicator=None,
) -> torch.Tensor: ) -> torch.Tensor:
alpha = self.get_alpha(image_only_indicator) alpha = self.get_alpha(image_only_indicator, x_spatial.device)
x = ( x = (
alpha.to(x_spatial.dtype) * x_spatial alpha.to(x_spatial.dtype) * x_spatial
+ (1.0 - alpha).to(x_spatial.dtype) * x_temporal + (1.0 - alpha).to(x_spatial.dtype) * x_temporal
@ -99,7 +100,7 @@ def make_beta_schedule(schedule, n_timestep, linear_start=1e-4, linear_end=2e-2,
alphas = torch.cos(alphas).pow(2) alphas = torch.cos(alphas).pow(2)
alphas = alphas / alphas[0] alphas = alphas / alphas[0]
betas = 1 - alphas[1:] / alphas[:-1] betas = 1 - alphas[1:] / alphas[:-1]
betas = np.clip(betas, a_min=0, a_max=0.999) betas = torch.clamp(betas, min=0, max=0.999)
elif schedule == "squaredcos_cap_v2": # used for karlo prior elif schedule == "squaredcos_cap_v2": # used for karlo prior
# return early # return early
@ -114,7 +115,7 @@ def make_beta_schedule(schedule, n_timestep, linear_start=1e-4, linear_end=2e-2,
betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) ** 0.5 betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) ** 0.5
else: else:
raise ValueError(f"schedule '{schedule}' unknown.") raise ValueError(f"schedule '{schedule}' unknown.")
return betas.numpy() return betas
def make_ddim_timesteps(ddim_discr_method, num_ddim_timesteps, num_ddpm_timesteps, verbose=True): def make_ddim_timesteps(ddim_discr_method, num_ddim_timesteps, num_ddpm_timesteps, verbose=True):
@ -273,46 +274,6 @@ def mean_flat(tensor):
return tensor.mean(dim=list(range(1, len(tensor.shape)))) return tensor.mean(dim=list(range(1, len(tensor.shape))))
def normalization(channels, dtype=None):
"""
Make a standard normalization layer.
:param channels: number of input channels.
:return: an nn.Module for normalization.
"""
return GroupNorm32(32, channels, dtype=dtype)
# PyTorch 1.7 has SiLU, but we support PyTorch 1.5.
class SiLU(nn.Module):
def forward(self, x):
return x * torch.sigmoid(x)
class GroupNorm32(nn.GroupNorm):
def forward(self, x):
return super().forward(x.float()).type(x.dtype)
def conv_nd(dims, *args, **kwargs):
"""
Create a 1D, 2D, or 3D convolution module.
"""
if dims == 1:
return nn.Conv1d(*args, **kwargs)
elif dims == 2:
return comfy.ops.Conv2d(*args, **kwargs)
elif dims == 3:
return nn.Conv3d(*args, **kwargs)
raise ValueError(f"unsupported dimensions: {dims}")
def linear(*args, **kwargs):
"""
Create a linear module.
"""
return comfy.ops.Linear(*args, **kwargs)
def avg_pool_nd(dims, *args, **kwargs): def avg_pool_nd(dims, *args, **kwargs):
""" """
Create a 1D, 2D, or 3D average pooling module. Create a 1D, 2D, or 3D average pooling module.

8
comfy/ldm/modules/encoders/noise_aug_modules.py

@ -15,21 +15,21 @@ class CLIPEmbeddingNoiseAugmentation(ImageConcatWithNoiseAugmentation):
def scale(self, x): def scale(self, x):
# re-normalize to centered mean and unit variance # re-normalize to centered mean and unit variance
x = (x - self.data_mean) * 1. / self.data_std x = (x - self.data_mean.to(x.device)) * 1. / self.data_std.to(x.device)
return x return x
def unscale(self, x): def unscale(self, x):
# back to original data stats # back to original data stats
x = (x * self.data_std) + self.data_mean x = (x * self.data_std.to(x.device)) + self.data_mean.to(x.device)
return x return x
def forward(self, x, noise_level=None): def forward(self, x, noise_level=None, seed=None):
if noise_level is None: if noise_level is None:
noise_level = torch.randint(0, self.max_noise_level, (x.shape[0],), device=x.device).long() noise_level = torch.randint(0, self.max_noise_level, (x.shape[0],), device=x.device).long()
else: else:
assert isinstance(noise_level, torch.Tensor) assert isinstance(noise_level, torch.Tensor)
x = self.scale(x) x = self.scale(x)
z = self.q_sample(x, noise_level) z = self.q_sample(x, noise_level, seed=seed)
z = self.unscale(z) z = self.unscale(z)
noise_level = self.time_embed(noise_level) noise_level = self.time_embed(noise_level)
return z, noise_level return z, noise_level

31
comfy/ldm/modules/sub_quadratic_attention.py

@ -14,6 +14,7 @@ import torch
from torch import Tensor from torch import Tensor
from torch.utils.checkpoint import checkpoint from torch.utils.checkpoint import checkpoint
import math import math
import logging
try: try:
from typing import Optional, NamedTuple, List, Protocol from typing import Optional, NamedTuple, List, Protocol
@ -61,6 +62,7 @@ def _summarize_chunk(
value: Tensor, value: Tensor,
scale: float, scale: float,
upcast_attention: bool, upcast_attention: bool,
mask,
) -> AttnChunk: ) -> AttnChunk:
if upcast_attention: if upcast_attention:
with torch.autocast(enabled=False, device_type = 'cuda'): with torch.autocast(enabled=False, device_type = 'cuda'):
@ -84,6 +86,8 @@ def _summarize_chunk(
max_score, _ = torch.max(attn_weights, -1, keepdim=True) max_score, _ = torch.max(attn_weights, -1, keepdim=True)
max_score = max_score.detach() max_score = max_score.detach()
attn_weights -= max_score attn_weights -= max_score
if mask is not None:
attn_weights += mask
torch.exp(attn_weights, out=attn_weights) torch.exp(attn_weights, out=attn_weights)
exp_weights = attn_weights.to(value.dtype) exp_weights = attn_weights.to(value.dtype)
exp_values = torch.bmm(exp_weights, value) exp_values = torch.bmm(exp_weights, value)
@ -96,11 +100,12 @@ def _query_chunk_attention(
value: Tensor, value: Tensor,
summarize_chunk: SummarizeChunk, summarize_chunk: SummarizeChunk,
kv_chunk_size: int, kv_chunk_size: int,
mask,
) -> Tensor: ) -> Tensor:
batch_x_heads, k_channels_per_head, k_tokens = key_t.shape batch_x_heads, k_channels_per_head, k_tokens = key_t.shape
_, _, v_channels_per_head = value.shape _, _, v_channels_per_head = value.shape
def chunk_scanner(chunk_idx: int) -> AttnChunk: def chunk_scanner(chunk_idx: int, mask) -> AttnChunk:
key_chunk = dynamic_slice( key_chunk = dynamic_slice(
key_t, key_t,
(0, 0, chunk_idx), (0, 0, chunk_idx),
@ -111,10 +116,13 @@ def _query_chunk_attention(
(0, chunk_idx, 0), (0, chunk_idx, 0),
(batch_x_heads, kv_chunk_size, v_channels_per_head) (batch_x_heads, kv_chunk_size, v_channels_per_head)
) )
return summarize_chunk(query, key_chunk, value_chunk) if mask is not None:
mask = mask[:,:,chunk_idx:chunk_idx + kv_chunk_size]
return summarize_chunk(query, key_chunk, value_chunk, mask=mask)
chunks: List[AttnChunk] = [ chunks: List[AttnChunk] = [
chunk_scanner(chunk) for chunk in torch.arange(0, k_tokens, kv_chunk_size) chunk_scanner(chunk, mask) for chunk in torch.arange(0, k_tokens, kv_chunk_size)
] ]
acc_chunk = AttnChunk(*map(torch.stack, zip(*chunks))) acc_chunk = AttnChunk(*map(torch.stack, zip(*chunks)))
chunk_values, chunk_weights, chunk_max = acc_chunk chunk_values, chunk_weights, chunk_max = acc_chunk
@ -135,6 +143,7 @@ def _get_attention_scores_no_kv_chunking(
value: Tensor, value: Tensor,
scale: float, scale: float,
upcast_attention: bool, upcast_attention: bool,
mask,
) -> Tensor: ) -> Tensor:
if upcast_attention: if upcast_attention:
with torch.autocast(enabled=False, device_type = 'cuda'): with torch.autocast(enabled=False, device_type = 'cuda'):
@ -156,11 +165,13 @@ def _get_attention_scores_no_kv_chunking(
beta=0, beta=0,
) )
if mask is not None:
attn_scores += mask
try: try:
attn_probs = attn_scores.softmax(dim=-1) attn_probs = attn_scores.softmax(dim=-1)
del attn_scores del attn_scores
except model_management.OOM_EXCEPTION: except model_management.OOM_EXCEPTION:
print("ran out of memory while running softmax in _get_attention_scores_no_kv_chunking, trying slower in place softmax instead") logging.warning("ran out of memory while running softmax in _get_attention_scores_no_kv_chunking, trying slower in place softmax instead")
attn_scores -= attn_scores.max(dim=-1, keepdim=True).values attn_scores -= attn_scores.max(dim=-1, keepdim=True).values
torch.exp(attn_scores, out=attn_scores) torch.exp(attn_scores, out=attn_scores)
summed = torch.sum(attn_scores, dim=-1, keepdim=True) summed = torch.sum(attn_scores, dim=-1, keepdim=True)
@ -183,6 +194,7 @@ def efficient_dot_product_attention(
kv_chunk_size_min: Optional[int] = None, kv_chunk_size_min: Optional[int] = None,
use_checkpoint=True, use_checkpoint=True,
upcast_attention=False, upcast_attention=False,
mask = None,
): ):
"""Computes efficient dot-product attention given query, transposed key, and value. """Computes efficient dot-product attention given query, transposed key, and value.
This is efficient version of attention presented in This is efficient version of attention presented in
@ -209,6 +221,9 @@ def efficient_dot_product_attention(
if kv_chunk_size_min is not None: if kv_chunk_size_min is not None:
kv_chunk_size = max(kv_chunk_size, kv_chunk_size_min) kv_chunk_size = max(kv_chunk_size, kv_chunk_size_min)
if mask is not None and len(mask.shape) == 2:
mask = mask.unsqueeze(0)
def get_query_chunk(chunk_idx: int) -> Tensor: def get_query_chunk(chunk_idx: int) -> Tensor:
return dynamic_slice( return dynamic_slice(
query, query,
@ -216,6 +231,12 @@ def efficient_dot_product_attention(
(batch_x_heads, min(query_chunk_size, q_tokens), q_channels_per_head) (batch_x_heads, min(query_chunk_size, q_tokens), q_channels_per_head)
) )
def get_mask_chunk(chunk_idx: int) -> Tensor:
if mask is None:
return None
chunk = min(query_chunk_size, q_tokens)
return mask[:,chunk_idx:chunk_idx + chunk]
summarize_chunk: SummarizeChunk = partial(_summarize_chunk, scale=scale, upcast_attention=upcast_attention) summarize_chunk: SummarizeChunk = partial(_summarize_chunk, scale=scale, upcast_attention=upcast_attention)
summarize_chunk: SummarizeChunk = partial(checkpoint, summarize_chunk) if use_checkpoint else summarize_chunk summarize_chunk: SummarizeChunk = partial(checkpoint, summarize_chunk) if use_checkpoint else summarize_chunk
compute_query_chunk_attn: ComputeQueryChunkAttn = partial( compute_query_chunk_attn: ComputeQueryChunkAttn = partial(
@ -237,6 +258,7 @@ def efficient_dot_product_attention(
query=query, query=query,
key_t=key_t, key_t=key_t,
value=value, value=value,
mask=mask,
) )
# TODO: maybe we should use torch.empty_like(query) to allocate storage in-advance, # TODO: maybe we should use torch.empty_like(query) to allocate storage in-advance,
@ -246,6 +268,7 @@ def efficient_dot_product_attention(
query=get_query_chunk(i * query_chunk_size), query=get_query_chunk(i * query_chunk_size),
key_t=key_t, key_t=key_t,
value=value, value=value,
mask=get_mask_chunk(i * query_chunk_size)
) for i in range(math.ceil(q_tokens / query_chunk_size)) ) for i in range(math.ceil(q_tokens / query_chunk_size))
], dim=1) ], dim=1)
return res return res

13
comfy/ldm/modules/temporal_ae.py

@ -5,6 +5,7 @@ import torch
from einops import rearrange, repeat from einops import rearrange, repeat
import comfy.ops import comfy.ops
ops = comfy.ops.disable_weight_init
from .diffusionmodules.model import ( from .diffusionmodules.model import (
AttnBlock, AttnBlock,
@ -81,14 +82,14 @@ class VideoResBlock(ResnetBlock):
x = self.time_stack(x, temb) x = self.time_stack(x, temb)
alpha = self.get_alpha(bs=b // timesteps) alpha = self.get_alpha(bs=b // timesteps).to(x.device)
x = alpha * x + (1.0 - alpha) * x_mix x = alpha * x + (1.0 - alpha) * x_mix
x = rearrange(x, "b c t h w -> (b t) c h w") x = rearrange(x, "b c t h w -> (b t) c h w")
return x return x
class AE3DConv(torch.nn.Conv2d): class AE3DConv(ops.Conv2d):
def __init__(self, in_channels, out_channels, video_kernel_size=3, *args, **kwargs): def __init__(self, in_channels, out_channels, video_kernel_size=3, *args, **kwargs):
super().__init__(in_channels, out_channels, *args, **kwargs) super().__init__(in_channels, out_channels, *args, **kwargs)
if isinstance(video_kernel_size, Iterable): if isinstance(video_kernel_size, Iterable):
@ -96,7 +97,7 @@ class AE3DConv(torch.nn.Conv2d):
else: else:
padding = int(video_kernel_size // 2) padding = int(video_kernel_size // 2)
self.time_mix_conv = torch.nn.Conv3d( self.time_mix_conv = ops.Conv3d(
in_channels=out_channels, in_channels=out_channels,
out_channels=out_channels, out_channels=out_channels,
kernel_size=video_kernel_size, kernel_size=video_kernel_size,
@ -130,9 +131,9 @@ class AttnVideoBlock(AttnBlock):
time_embed_dim = self.in_channels * 4 time_embed_dim = self.in_channels * 4
self.video_time_embed = torch.nn.Sequential( self.video_time_embed = torch.nn.Sequential(
comfy.ops.Linear(self.in_channels, time_embed_dim), ops.Linear(self.in_channels, time_embed_dim),
torch.nn.SiLU(), torch.nn.SiLU(),
comfy.ops.Linear(time_embed_dim, self.in_channels), ops.Linear(time_embed_dim, self.in_channels),
) )
self.merge_strategy = merge_strategy self.merge_strategy = merge_strategy
@ -166,7 +167,7 @@ class AttnVideoBlock(AttnBlock):
emb = emb[:, None, :] emb = emb[:, None, :]
x_mix = x_mix + emb x_mix = x_mix + emb
alpha = self.get_alpha() alpha = self.get_alpha().to(x.device)
x_mix = self.time_mix_block(x_mix, timesteps=timesteps) x_mix = self.time_mix_block(x_mix, timesteps=timesteps)
x = alpha * x + (1.0 - alpha) * x_mix # alpha merge x = alpha * x + (1.0 - alpha) * x_mix # alpha merge

46
comfy/lora.py

@ -1,4 +1,5 @@
import comfy.utils import comfy.utils
import logging
LORA_CLIP_MAP = { LORA_CLIP_MAP = {
"mlp.fc1": "mlp_fc1", "mlp.fc1": "mlp_fc1",
@ -20,6 +21,12 @@ def load_lora(lora, to_load):
alpha = lora[alpha_name].item() alpha = lora[alpha_name].item()
loaded_keys.add(alpha_name) loaded_keys.add(alpha_name)
dora_scale_name = "{}.dora_scale".format(x)
dora_scale = None
if dora_scale_name in lora.keys():
dora_scale = lora[dora_scale_name]
loaded_keys.add(dora_scale_name)
regular_lora = "{}.lora_up.weight".format(x) regular_lora = "{}.lora_up.weight".format(x)
diffusers_lora = "{}_lora.up.weight".format(x) diffusers_lora = "{}_lora.up.weight".format(x)
transformers_lora = "{}.lora_linear_layer.up.weight".format(x) transformers_lora = "{}.lora_linear_layer.up.weight".format(x)
@ -43,7 +50,7 @@ def load_lora(lora, to_load):
if mid_name is not None and mid_name in lora.keys(): if mid_name is not None and mid_name in lora.keys():
mid = lora[mid_name] mid = lora[mid_name]
loaded_keys.add(mid_name) loaded_keys.add(mid_name)
patch_dict[to_load[x]] = (lora[A_name], lora[B_name], alpha, mid) patch_dict[to_load[x]] = ("lora", (lora[A_name], lora[B_name], alpha, mid, dora_scale))
loaded_keys.add(A_name) loaded_keys.add(A_name)
loaded_keys.add(B_name) loaded_keys.add(B_name)
@ -64,7 +71,7 @@ def load_lora(lora, to_load):
loaded_keys.add(hada_t1_name) loaded_keys.add(hada_t1_name)
loaded_keys.add(hada_t2_name) loaded_keys.add(hada_t2_name)
patch_dict[to_load[x]] = (lora[hada_w1_a_name], lora[hada_w1_b_name], alpha, lora[hada_w2_a_name], lora[hada_w2_b_name], hada_t1, hada_t2) patch_dict[to_load[x]] = ("loha", (lora[hada_w1_a_name], lora[hada_w1_b_name], alpha, lora[hada_w2_a_name], lora[hada_w2_b_name], hada_t1, hada_t2, dora_scale))
loaded_keys.add(hada_w1_a_name) loaded_keys.add(hada_w1_a_name)
loaded_keys.add(hada_w1_b_name) loaded_keys.add(hada_w1_b_name)
loaded_keys.add(hada_w2_a_name) loaded_keys.add(hada_w2_a_name)
@ -116,8 +123,19 @@ def load_lora(lora, to_load):
loaded_keys.add(lokr_t2_name) loaded_keys.add(lokr_t2_name)
if (lokr_w1 is not None) or (lokr_w2 is not None) or (lokr_w1_a is not None) or (lokr_w2_a is not None): if (lokr_w1 is not None) or (lokr_w2 is not None) or (lokr_w1_a is not None) or (lokr_w2_a is not None):
patch_dict[to_load[x]] = (lokr_w1, lokr_w2, alpha, lokr_w1_a, lokr_w1_b, lokr_w2_a, lokr_w2_b, lokr_t2) patch_dict[to_load[x]] = ("lokr", (lokr_w1, lokr_w2, alpha, lokr_w1_a, lokr_w1_b, lokr_w2_a, lokr_w2_b, lokr_t2, dora_scale))
#glora
a1_name = "{}.a1.weight".format(x)
a2_name = "{}.a2.weight".format(x)
b1_name = "{}.b1.weight".format(x)
b2_name = "{}.b2.weight".format(x)
if a1_name in lora:
patch_dict[to_load[x]] = ("glora", (lora[a1_name], lora[a2_name], lora[b1_name], lora[b2_name], alpha, dora_scale))
loaded_keys.add(a1_name)
loaded_keys.add(a2_name)
loaded_keys.add(b1_name)
loaded_keys.add(b2_name)
w_norm_name = "{}.w_norm".format(x) w_norm_name = "{}.w_norm".format(x)
b_norm_name = "{}.b_norm".format(x) b_norm_name = "{}.b_norm".format(x)
@ -126,26 +144,26 @@ def load_lora(lora, to_load):
if w_norm is not None: if w_norm is not None:
loaded_keys.add(w_norm_name) loaded_keys.add(w_norm_name)
patch_dict[to_load[x]] = (w_norm,) patch_dict[to_load[x]] = ("diff", (w_norm,))
if b_norm is not None: if b_norm is not None:
loaded_keys.add(b_norm_name) loaded_keys.add(b_norm_name)
patch_dict["{}.bias".format(to_load[x][:-len(".weight")])] = (b_norm,) patch_dict["{}.bias".format(to_load[x][:-len(".weight")])] = ("diff", (b_norm,))
diff_name = "{}.diff".format(x) diff_name = "{}.diff".format(x)
diff_weight = lora.get(diff_name, None) diff_weight = lora.get(diff_name, None)
if diff_weight is not None: if diff_weight is not None:
patch_dict[to_load[x]] = (diff_weight,) patch_dict[to_load[x]] = ("diff", (diff_weight,))
loaded_keys.add(diff_name) loaded_keys.add(diff_name)
diff_bias_name = "{}.diff_b".format(x) diff_bias_name = "{}.diff_b".format(x)
diff_bias = lora.get(diff_bias_name, None) diff_bias = lora.get(diff_bias_name, None)
if diff_bias is not None: if diff_bias is not None:
patch_dict["{}.bias".format(to_load[x][:-len(".weight")])] = (diff_bias,) patch_dict["{}.bias".format(to_load[x][:-len(".weight")])] = ("diff", (diff_bias,))
loaded_keys.add(diff_bias_name) loaded_keys.add(diff_bias_name)
for x in lora.keys(): for x in lora.keys():
if x not in loaded_keys: if x not in loaded_keys:
print("lora key not loaded", x) logging.warning("lora key not loaded: {}".format(x))
return patch_dict return patch_dict
def model_lora_keys_clip(model, key_map={}): def model_lora_keys_clip(model, key_map={}):
@ -186,6 +204,15 @@ def model_lora_keys_clip(model, key_map={}):
key_map[lora_key] = k key_map[lora_key] = k
lora_key = "text_encoder.text_model.encoder.layers.{}.{}".format(b, c) #diffusers lora lora_key = "text_encoder.text_model.encoder.layers.{}.{}".format(b, c) #diffusers lora
key_map[lora_key] = k key_map[lora_key] = k
lora_key = "lora_prior_te_text_model_encoder_layers_{}_{}".format(b, LORA_CLIP_MAP[c]) #cascade lora: TODO put lora key prefix in the model config
key_map[lora_key] = k
k = "clip_g.transformer.text_projection.weight"
if k in sdk:
key_map["lora_prior_te_text_projection"] = k #cascade lora?
# key_map["text_encoder.text_projection"] = k #TODO: check if other lora have the text_projection too
# key_map["lora_te_text_projection"] = k
return key_map return key_map
@ -196,6 +223,7 @@ def model_lora_keys_unet(model, key_map={}):
if k.startswith("diffusion_model.") and k.endswith(".weight"): if k.startswith("diffusion_model.") and k.endswith(".weight"):
key_lora = k[len("diffusion_model."):-len(".weight")].replace(".", "_") key_lora = k[len("diffusion_model."):-len(".weight")].replace(".", "_")
key_map["lora_unet_{}".format(key_lora)] = k key_map["lora_unet_{}".format(key_lora)] = k
key_map["lora_prior_unet_{}".format(key_lora)] = k #cascade lora: TODO put lora key prefix in the model config
diffusers_keys = comfy.utils.unet_to_diffusers(model.model_config.unet_config) diffusers_keys = comfy.utils.unet_to_diffusers(model.model_config.unet_config)
for k in diffusers_keys: for k in diffusers_keys:

271
comfy/model_base.py

@ -1,9 +1,13 @@
import torch import torch
from comfy.ldm.modules.diffusionmodules.openaimodel import UNetModel import logging
from comfy.ldm.modules.diffusionmodules.openaimodel import UNetModel, Timestep
from comfy.ldm.cascade.stage_c import StageC
from comfy.ldm.cascade.stage_b import StageB
from comfy.ldm.modules.encoders.noise_aug_modules import CLIPEmbeddingNoiseAugmentation from comfy.ldm.modules.encoders.noise_aug_modules import CLIPEmbeddingNoiseAugmentation
from comfy.ldm.modules.diffusionmodules.openaimodel import Timestep from comfy.ldm.modules.diffusionmodules.upscaling import ImageConcatWithNoiseAugmentation
import comfy.model_management import comfy.model_management
import comfy.conds import comfy.conds
import comfy.ops
from enum import Enum from enum import Enum
from . import utils from . import utils
@ -11,9 +15,11 @@ class ModelType(Enum):
EPS = 1 EPS = 1
V_PREDICTION = 2 V_PREDICTION = 2
V_PREDICTION_EDM = 3 V_PREDICTION_EDM = 3
STABLE_CASCADE = 4
EDM = 5
from comfy.model_sampling import EPS, V_PREDICTION, ModelSamplingDiscrete, ModelSamplingContinuousEDM from comfy.model_sampling import EPS, V_PREDICTION, EDM, ModelSamplingDiscrete, ModelSamplingContinuousEDM, StableCascadeSampling
def model_sampling(model_config, model_type): def model_sampling(model_config, model_type):
@ -26,6 +32,12 @@ def model_sampling(model_config, model_type):
elif model_type == ModelType.V_PREDICTION_EDM: elif model_type == ModelType.V_PREDICTION_EDM:
c = V_PREDICTION c = V_PREDICTION
s = ModelSamplingContinuousEDM s = ModelSamplingContinuousEDM
elif model_type == ModelType.STABLE_CASCADE:
c = EPS
s = StableCascadeSampling
elif model_type == ModelType.EDM:
c = EDM
s = ModelSamplingContinuousEDM
class ModelSampling(s, c): class ModelSampling(s, c):
pass pass
@ -34,15 +46,20 @@ def model_sampling(model_config, model_type):
class BaseModel(torch.nn.Module): class BaseModel(torch.nn.Module):
def __init__(self, model_config, model_type=ModelType.EPS, device=None): def __init__(self, model_config, model_type=ModelType.EPS, device=None, unet_model=UNetModel):
super().__init__() super().__init__()
unet_config = model_config.unet_config unet_config = model_config.unet_config
self.latent_format = model_config.latent_format self.latent_format = model_config.latent_format
self.model_config = model_config self.model_config = model_config
self.manual_cast_dtype = model_config.manual_cast_dtype
if not unet_config.get("disable_unet_model_creation", False): if not unet_config.get("disable_unet_model_creation", False):
self.diffusion_model = UNetModel(**unet_config, device=device) if self.manual_cast_dtype is not None:
operations = comfy.ops.manual_cast
else:
operations = comfy.ops.disable_weight_init
self.diffusion_model = unet_model(**unet_config, device=device, operations=operations)
self.model_type = model_type self.model_type = model_type
self.model_sampling = model_sampling(model_config, model_type) self.model_sampling = model_sampling(model_config, model_type)
@ -50,8 +67,8 @@ class BaseModel(torch.nn.Module):
if self.adm_channels is None: if self.adm_channels is None:
self.adm_channels = 0 self.adm_channels = 0
self.inpaint_model = False self.inpaint_model = False
print("model_type", model_type.name) logging.info("model_type {}".format(model_type.name))
print("adm", self.adm_channels) logging.debug("adm {}".format(self.adm_channels))
def apply_model(self, x, t, c_concat=None, c_crossattn=None, control=None, transformer_options={}, **kwargs): def apply_model(self, x, t, c_concat=None, c_crossattn=None, control=None, transformer_options={}, **kwargs):
sigma = t sigma = t
@ -61,15 +78,21 @@ class BaseModel(torch.nn.Module):
context = c_crossattn context = c_crossattn
dtype = self.get_dtype() dtype = self.get_dtype()
if self.manual_cast_dtype is not None:
dtype = self.manual_cast_dtype
xc = xc.to(dtype) xc = xc.to(dtype)
t = self.model_sampling.timestep(t).float() t = self.model_sampling.timestep(t).float()
context = context.to(dtype) context = context.to(dtype)
extra_conds = {} extra_conds = {}
for o in kwargs: for o in kwargs:
extra = kwargs[o] extra = kwargs[o]
if hasattr(extra, "to"): if hasattr(extra, "dtype"):
extra = extra.to(dtype) if extra.dtype != torch.int and extra.dtype != torch.long:
extra = extra.to(dtype)
extra_conds[o] = extra extra_conds[o] = extra
model_output = self.diffusion_model(xc, t, context=context, control=control, transformer_options=transformer_options, **extra_conds).float() model_output = self.diffusion_model(xc, t, context=context, control=control, transformer_options=transformer_options, **extra_conds).float()
return self.model_sampling.calculate_denoised(sigma, model_output, x) return self.model_sampling.calculate_denoised(sigma, model_output, x)
@ -87,11 +110,29 @@ class BaseModel(torch.nn.Module):
if self.inpaint_model: if self.inpaint_model:
concat_keys = ("mask", "masked_image") concat_keys = ("mask", "masked_image")
cond_concat = [] cond_concat = []
denoise_mask = kwargs.get("denoise_mask", None) denoise_mask = kwargs.get("concat_mask", kwargs.get("denoise_mask", None))
latent_image = kwargs.get("latent_image", None) concat_latent_image = kwargs.get("concat_latent_image", None)
if concat_latent_image is None:
concat_latent_image = kwargs.get("latent_image", None)
else:
concat_latent_image = self.process_latent_in(concat_latent_image)
noise = kwargs.get("noise", None) noise = kwargs.get("noise", None)
device = kwargs["device"] device = kwargs["device"]
if concat_latent_image.shape[1:] != noise.shape[1:]:
concat_latent_image = utils.common_upscale(concat_latent_image, noise.shape[-1], noise.shape[-2], "bilinear", "center")
concat_latent_image = utils.resize_to_batch_size(concat_latent_image, noise.shape[0])
if len(denoise_mask.shape) == len(noise.shape):
denoise_mask = denoise_mask[:,:1]
denoise_mask = denoise_mask.reshape((-1, 1, denoise_mask.shape[-2], denoise_mask.shape[-1]))
if denoise_mask.shape[-2:] != noise.shape[-2:]:
denoise_mask = utils.common_upscale(denoise_mask, noise.shape[-1], noise.shape[-2], "bilinear", "center")
denoise_mask = utils.resize_to_batch_size(denoise_mask.round(), noise.shape[0])
def blank_inpaint_image_like(latent_image): def blank_inpaint_image_like(latent_image):
blank_image = torch.ones_like(latent_image) blank_image = torch.ones_like(latent_image)
# these are the values for "zero" in pixel space translated to latent space # these are the values for "zero" in pixel space translated to latent space
@ -104,9 +145,9 @@ class BaseModel(torch.nn.Module):
for ck in concat_keys: for ck in concat_keys:
if denoise_mask is not None: if denoise_mask is not None:
if ck == "mask": if ck == "mask":
cond_concat.append(denoise_mask[:,:1].to(device)) cond_concat.append(denoise_mask.to(device))
elif ck == "masked_image": elif ck == "masked_image":
cond_concat.append(latent_image.to(device)) #NOTE: the latent_image should be masked by the mask in pixel space cond_concat.append(concat_latent_image.to(device)) #NOTE: the latent_image should be masked by the mask in pixel space
else: else:
if ck == "mask": if ck == "mask":
cond_concat.append(torch.ones_like(noise)[:,:1]) cond_concat.append(torch.ones_like(noise)[:,:1])
@ -114,9 +155,23 @@ class BaseModel(torch.nn.Module):
cond_concat.append(blank_inpaint_image_like(noise)) cond_concat.append(blank_inpaint_image_like(noise))
data = torch.cat(cond_concat, dim=1) data = torch.cat(cond_concat, dim=1)
out['c_concat'] = comfy.conds.CONDNoiseShape(data) out['c_concat'] = comfy.conds.CONDNoiseShape(data)
adm = self.encode_adm(**kwargs) adm = self.encode_adm(**kwargs)
if adm is not None: if adm is not None:
out['y'] = comfy.conds.CONDRegular(adm) out['y'] = comfy.conds.CONDRegular(adm)
cross_attn = kwargs.get("cross_attn", None)
if cross_attn is not None:
out['c_crossattn'] = comfy.conds.CONDCrossAttn(cross_attn)
cross_attn_cnet = kwargs.get("cross_attn_controlnet", None)
if cross_attn_cnet is not None:
out['crossattn_controlnet'] = comfy.conds.CONDCrossAttn(cross_attn_cnet)
c_concat = kwargs.get("noise_concat", None)
if c_concat is not None:
out['c_concat'] = comfy.conds.CONDNoiseShape(data)
return out return out
def load_model_weights(self, sd, unet_prefix=""): def load_model_weights(self, sd, unet_prefix=""):
@ -129,10 +184,10 @@ class BaseModel(torch.nn.Module):
to_load = self.model_config.process_unet_state_dict(to_load) to_load = self.model_config.process_unet_state_dict(to_load)
m, u = self.diffusion_model.load_state_dict(to_load, strict=False) m, u = self.diffusion_model.load_state_dict(to_load, strict=False)
if len(m) > 0: if len(m) > 0:
print("unet missing:", m) logging.warning("unet missing: {}".format(m))
if len(u) > 0: if len(u) > 0:
print("unet unexpected:", u) logging.warning("unet unexpected: {}".format(u))
del to_load del to_load
return self return self
@ -142,39 +197,47 @@ class BaseModel(torch.nn.Module):
def process_latent_out(self, latent): def process_latent_out(self, latent):
return self.latent_format.process_out(latent) return self.latent_format.process_out(latent)
def state_dict_for_saving(self, clip_state_dict, vae_state_dict): def state_dict_for_saving(self, clip_state_dict=None, vae_state_dict=None, clip_vision_state_dict=None):
clip_state_dict = self.model_config.process_clip_state_dict_for_saving(clip_state_dict) extra_sds = []
unet_sd = self.diffusion_model.state_dict() if clip_state_dict is not None:
unet_state_dict = {} extra_sds.append(self.model_config.process_clip_state_dict_for_saving(clip_state_dict))
for k in unet_sd: if vae_state_dict is not None:
unet_state_dict[k] = comfy.model_management.resolve_lowvram_weight(unet_sd[k], self.diffusion_model, k) extra_sds.append(self.model_config.process_vae_state_dict_for_saving(vae_state_dict))
if clip_vision_state_dict is not None:
extra_sds.append(self.model_config.process_clip_vision_state_dict_for_saving(clip_vision_state_dict))
unet_state_dict = self.diffusion_model.state_dict()
unet_state_dict = self.model_config.process_unet_state_dict_for_saving(unet_state_dict) unet_state_dict = self.model_config.process_unet_state_dict_for_saving(unet_state_dict)
vae_state_dict = self.model_config.process_vae_state_dict_for_saving(vae_state_dict)
if self.get_dtype() == torch.float16: if self.get_dtype() == torch.float16:
clip_state_dict = utils.convert_sd_to(clip_state_dict, torch.float16) extra_sds = map(lambda sd: utils.convert_sd_to(sd, torch.float16), extra_sds)
vae_state_dict = utils.convert_sd_to(vae_state_dict, torch.float16)
if self.model_type == ModelType.V_PREDICTION: if self.model_type == ModelType.V_PREDICTION:
unet_state_dict["v_pred"] = torch.tensor([]) unet_state_dict["v_pred"] = torch.tensor([])
return {**unet_state_dict, **vae_state_dict, **clip_state_dict} for sd in extra_sds:
unet_state_dict.update(sd)
return unet_state_dict
def set_inpaint(self): def set_inpaint(self):
self.inpaint_model = True self.inpaint_model = True
def memory_required(self, input_shape): def memory_required(self, input_shape):
if comfy.model_management.xformers_enabled() or comfy.model_management.pytorch_attention_flash_attention(): if comfy.model_management.xformers_enabled() or comfy.model_management.pytorch_attention_flash_attention():
dtype = self.get_dtype()
if self.manual_cast_dtype is not None:
dtype = self.manual_cast_dtype
#TODO: this needs to be tweaked #TODO: this needs to be tweaked
area = input_shape[0] * input_shape[2] * input_shape[3] area = input_shape[0] * input_shape[2] * input_shape[3]
return (area * comfy.model_management.dtype_size(self.get_dtype()) / 50) * (1024 * 1024) return (area * comfy.model_management.dtype_size(dtype) / 50) * (1024 * 1024)
else: else:
#TODO: this formula might be too aggressive since I tweaked the sub-quad and split algorithms to use less memory. #TODO: this formula might be too aggressive since I tweaked the sub-quad and split algorithms to use less memory.
area = input_shape[0] * input_shape[2] * input_shape[3] area = input_shape[0] * input_shape[2] * input_shape[3]
return (((area * 0.6) / 0.9) + 1024) * (1024 * 1024) return (((area * 0.6) / 0.9) + 1024) * (1024 * 1024)
def unclip_adm(unclip_conditioning, device, noise_augmentor, noise_augment_merge=0.0): def unclip_adm(unclip_conditioning, device, noise_augmentor, noise_augment_merge=0.0, seed=None):
adm_inputs = [] adm_inputs = []
weights = [] weights = []
noise_aug = [] noise_aug = []
@ -183,7 +246,7 @@ def unclip_adm(unclip_conditioning, device, noise_augmentor, noise_augment_merge
weight = unclip_cond["strength"] weight = unclip_cond["strength"]
noise_augment = unclip_cond["noise_augmentation"] noise_augment = unclip_cond["noise_augmentation"]
noise_level = round((noise_augmentor.max_noise_level - 1) * noise_augment) noise_level = round((noise_augmentor.max_noise_level - 1) * noise_augment)
c_adm, noise_level_emb = noise_augmentor(adm_cond.to(device), noise_level=torch.tensor([noise_level], device=device)) c_adm, noise_level_emb = noise_augmentor(adm_cond.to(device), noise_level=torch.tensor([noise_level], device=device), seed=seed)
adm_out = torch.cat((c_adm, noise_level_emb), 1) * weight adm_out = torch.cat((c_adm, noise_level_emb), 1) * weight
weights.append(weight) weights.append(weight)
noise_aug.append(noise_augment) noise_aug.append(noise_augment)
@ -209,11 +272,11 @@ class SD21UNCLIP(BaseModel):
if unclip_conditioning is None: if unclip_conditioning is None:
return torch.zeros((1, self.adm_channels)) return torch.zeros((1, self.adm_channels))
else: else:
return unclip_adm(unclip_conditioning, device, self.noise_augmentor, kwargs.get("unclip_noise_augment_merge", 0.05)) return unclip_adm(unclip_conditioning, device, self.noise_augmentor, kwargs.get("unclip_noise_augment_merge", 0.05), kwargs.get("seed", 0) - 10)
def sdxl_pooled(args, noise_augmentor): def sdxl_pooled(args, noise_augmentor):
if "unclip_conditioning" in args: if "unclip_conditioning" in args:
return unclip_adm(args.get("unclip_conditioning", None), args["device"], noise_augmentor)[:,:1280] return unclip_adm(args.get("unclip_conditioning", None), args["device"], noise_augmentor, seed=args.get("seed", 0) - 10)[:,:1280]
else: else:
return args["pooled_output"] return args["pooled_output"]
@ -303,13 +366,157 @@ class SVD_img2vid(BaseModel):
if latent_image.shape[1:] != noise.shape[1:]: if latent_image.shape[1:] != noise.shape[1:]:
latent_image = utils.common_upscale(latent_image, noise.shape[-1], noise.shape[-2], "bilinear", "center") latent_image = utils.common_upscale(latent_image, noise.shape[-1], noise.shape[-2], "bilinear", "center")
latent_image = utils.repeat_to_batch_size(latent_image, noise.shape[0]) latent_image = utils.resize_to_batch_size(latent_image, noise.shape[0])
out['c_concat'] = comfy.conds.CONDNoiseShape(latent_image) out['c_concat'] = comfy.conds.CONDNoiseShape(latent_image)
cross_attn = kwargs.get("cross_attn", None)
if cross_attn is not None:
out['c_crossattn'] = comfy.conds.CONDCrossAttn(cross_attn)
if "time_conditioning" in kwargs: if "time_conditioning" in kwargs:
out["time_context"] = comfy.conds.CONDCrossAttn(kwargs["time_conditioning"]) out["time_context"] = comfy.conds.CONDCrossAttn(kwargs["time_conditioning"])
out['image_only_indicator'] = comfy.conds.CONDConstant(torch.zeros((1,), device=device))
out['num_video_frames'] = comfy.conds.CONDConstant(noise.shape[0]) out['num_video_frames'] = comfy.conds.CONDConstant(noise.shape[0])
return out return out
class SV3D_u(SVD_img2vid):
def encode_adm(self, **kwargs):
augmentation = kwargs.get("augmentation_level", 0)
out = []
out.append(self.embedder(torch.flatten(torch.Tensor([augmentation]))))
flat = torch.flatten(torch.cat(out)).unsqueeze(dim=0)
return flat
class SV3D_p(SVD_img2vid):
def __init__(self, model_config, model_type=ModelType.V_PREDICTION_EDM, device=None):
super().__init__(model_config, model_type, device=device)
self.embedder_512 = Timestep(512)
def encode_adm(self, **kwargs):
augmentation = kwargs.get("augmentation_level", 0)
elevation = kwargs.get("elevation", 0) #elevation and azimuth are in degrees here
azimuth = kwargs.get("azimuth", 0)
noise = kwargs.get("noise", None)
out = []
out.append(self.embedder(torch.flatten(torch.Tensor([augmentation]))))
out.append(self.embedder_512(torch.deg2rad(torch.fmod(torch.flatten(90 - torch.Tensor([elevation])), 360.0))))
out.append(self.embedder_512(torch.deg2rad(torch.fmod(torch.flatten(torch.Tensor([azimuth])), 360.0))))
out = list(map(lambda a: utils.resize_to_batch_size(a, noise.shape[0]), out))
return torch.cat(out, dim=1)
class Stable_Zero123(BaseModel):
def __init__(self, model_config, model_type=ModelType.EPS, device=None, cc_projection_weight=None, cc_projection_bias=None):
super().__init__(model_config, model_type, device=device)
self.cc_projection = comfy.ops.manual_cast.Linear(cc_projection_weight.shape[1], cc_projection_weight.shape[0], dtype=self.get_dtype(), device=device)
self.cc_projection.weight.copy_(cc_projection_weight)
self.cc_projection.bias.copy_(cc_projection_bias)
def extra_conds(self, **kwargs):
out = {}
latent_image = kwargs.get("concat_latent_image", None)
noise = kwargs.get("noise", None)
if latent_image is None:
latent_image = torch.zeros_like(noise)
if latent_image.shape[1:] != noise.shape[1:]:
latent_image = utils.common_upscale(latent_image, noise.shape[-1], noise.shape[-2], "bilinear", "center")
latent_image = utils.resize_to_batch_size(latent_image, noise.shape[0])
out['c_concat'] = comfy.conds.CONDNoiseShape(latent_image)
cross_attn = kwargs.get("cross_attn", None)
if cross_attn is not None:
if cross_attn.shape[-1] != 768:
cross_attn = self.cc_projection(cross_attn)
out['c_crossattn'] = comfy.conds.CONDCrossAttn(cross_attn)
return out
class SD_X4Upscaler(BaseModel):
def __init__(self, model_config, model_type=ModelType.V_PREDICTION, device=None):
super().__init__(model_config, model_type, device=device)
self.noise_augmentor = ImageConcatWithNoiseAugmentation(noise_schedule_config={"linear_start": 0.0001, "linear_end": 0.02}, max_noise_level=350)
def extra_conds(self, **kwargs):
out = {}
image = kwargs.get("concat_image", None)
noise = kwargs.get("noise", None)
noise_augment = kwargs.get("noise_augmentation", 0.0)
device = kwargs["device"]
seed = kwargs["seed"] - 10
noise_level = round((self.noise_augmentor.max_noise_level) * noise_augment)
if image is None:
image = torch.zeros_like(noise)[:,:3]
if image.shape[1:] != noise.shape[1:]:
image = utils.common_upscale(image.to(device), noise.shape[-1], noise.shape[-2], "bilinear", "center")
noise_level = torch.tensor([noise_level], device=device)
if noise_augment > 0:
image, noise_level = self.noise_augmentor(image.to(device), noise_level=noise_level, seed=seed)
image = utils.resize_to_batch_size(image, noise.shape[0])
out['c_concat'] = comfy.conds.CONDNoiseShape(image)
out['y'] = comfy.conds.CONDRegular(noise_level)
return out
class StableCascade_C(BaseModel):
def __init__(self, model_config, model_type=ModelType.STABLE_CASCADE, device=None):
super().__init__(model_config, model_type, device=device, unet_model=StageC)
self.diffusion_model.eval().requires_grad_(False)
def extra_conds(self, **kwargs):
out = {}
clip_text_pooled = kwargs["pooled_output"]
if clip_text_pooled is not None:
out['clip_text_pooled'] = comfy.conds.CONDRegular(clip_text_pooled)
if "unclip_conditioning" in kwargs:
embeds = []
for unclip_cond in kwargs["unclip_conditioning"]:
weight = unclip_cond["strength"]
embeds.append(unclip_cond["clip_vision_output"].image_embeds.unsqueeze(0) * weight)
clip_img = torch.cat(embeds, dim=1)
else:
clip_img = torch.zeros((1, 1, 768))
out["clip_img"] = comfy.conds.CONDRegular(clip_img)
out["sca"] = comfy.conds.CONDRegular(torch.zeros((1,)))
out["crp"] = comfy.conds.CONDRegular(torch.zeros((1,)))
cross_attn = kwargs.get("cross_attn", None)
if cross_attn is not None:
out['clip_text'] = comfy.conds.CONDCrossAttn(cross_attn)
return out
class StableCascade_B(BaseModel):
def __init__(self, model_config, model_type=ModelType.STABLE_CASCADE, device=None):
super().__init__(model_config, model_type, device=device, unet_model=StageB)
self.diffusion_model.eval().requires_grad_(False)
def extra_conds(self, **kwargs):
out = {}
noise = kwargs.get("noise", None)
clip_text_pooled = kwargs["pooled_output"]
if clip_text_pooled is not None:
out['clip'] = comfy.conds.CONDRegular(clip_text_pooled)
#size of prior doesn't really matter if zeros because it gets resized but I still want it to get batched
prior = kwargs.get("stable_cascade_prior", torch.zeros((1, 16, (noise.shape[2] * 4) // 42, (noise.shape[3] * 4) // 42), dtype=noise.dtype, layout=noise.layout, device=noise.device))
out["effnet"] = comfy.conds.CONDRegular(prior)
out["sca"] = comfy.conds.CONDRegular(torch.zeros((1,)))
return out

82
comfy/model_detection.py

@ -1,5 +1,6 @@
import comfy.supported_models import comfy.supported_models
import comfy.supported_models_base import comfy.supported_models_base
import logging
def count_blocks(state_dict_keys, prefix_string): def count_blocks(state_dict_keys, prefix_string):
count = 0 count = 0
@ -28,13 +29,41 @@ def calculate_transformer_depth(prefix, state_dict_keys, state_dict):
return last_transformer_depth, context_dim, use_linear_in_transformer, time_stack return last_transformer_depth, context_dim, use_linear_in_transformer, time_stack
return None return None
def detect_unet_config(state_dict, key_prefix, dtype): def detect_unet_config(state_dict, key_prefix):
state_dict_keys = list(state_dict.keys()) state_dict_keys = list(state_dict.keys())
if '{}clf.1.weight'.format(key_prefix) in state_dict_keys: #stable cascade
unet_config = {}
text_mapper_name = '{}clip_txt_mapper.weight'.format(key_prefix)
if text_mapper_name in state_dict_keys:
unet_config['stable_cascade_stage'] = 'c'
w = state_dict[text_mapper_name]
if w.shape[0] == 1536: #stage c lite
unet_config['c_cond'] = 1536
unet_config['c_hidden'] = [1536, 1536]
unet_config['nhead'] = [24, 24]
unet_config['blocks'] = [[4, 12], [12, 4]]
elif w.shape[0] == 2048: #stage c full
unet_config['c_cond'] = 2048
elif '{}clip_mapper.weight'.format(key_prefix) in state_dict_keys:
unet_config['stable_cascade_stage'] = 'b'
w = state_dict['{}down_blocks.1.0.channelwise.0.weight'.format(key_prefix)]
if w.shape[-1] == 640:
unet_config['c_hidden'] = [320, 640, 1280, 1280]
unet_config['nhead'] = [-1, -1, 20, 20]
unet_config['blocks'] = [[2, 6, 28, 6], [6, 28, 6, 2]]
unet_config['block_repeat'] = [[1, 1, 1, 1], [3, 3, 2, 2]]
elif w.shape[-1] == 576: #stage b lite
unet_config['c_hidden'] = [320, 576, 1152, 1152]
unet_config['nhead'] = [-1, 9, 18, 18]
unet_config['blocks'] = [[2, 4, 14, 4], [4, 14, 4, 2]]
unet_config['block_repeat'] = [[1, 1, 1, 1], [2, 2, 2, 2]]
return unet_config
unet_config = { unet_config = {
"use_checkpoint": False, "use_checkpoint": False,
"image_size": 32, "image_size": 32,
"out_channels": 4,
"use_spatial_transformer": True, "use_spatial_transformer": True,
"legacy": False "legacy": False
} }
@ -46,10 +75,15 @@ def detect_unet_config(state_dict, key_prefix, dtype):
else: else:
unet_config["adm_in_channels"] = None unet_config["adm_in_channels"] = None
unet_config["dtype"] = dtype
model_channels = state_dict['{}input_blocks.0.0.weight'.format(key_prefix)].shape[0] model_channels = state_dict['{}input_blocks.0.0.weight'.format(key_prefix)].shape[0]
in_channels = state_dict['{}input_blocks.0.0.weight'.format(key_prefix)].shape[1] in_channels = state_dict['{}input_blocks.0.0.weight'.format(key_prefix)].shape[1]
out_key = '{}out.2.weight'.format(key_prefix)
if out_key in state_dict:
out_channels = state_dict[out_key].shape[0]
else:
out_channels = 4
num_res_blocks = [] num_res_blocks = []
channel_mult = [] channel_mult = []
attention_resolutions = [] attention_resolutions = []
@ -118,10 +152,13 @@ def detect_unet_config(state_dict, key_prefix, dtype):
channel_mult.append(last_channel_mult) channel_mult.append(last_channel_mult)
if "{}middle_block.1.proj_in.weight".format(key_prefix) in state_dict_keys: if "{}middle_block.1.proj_in.weight".format(key_prefix) in state_dict_keys:
transformer_depth_middle = count_blocks(state_dict_keys, '{}middle_block.1.transformer_blocks.'.format(key_prefix) + '{}') transformer_depth_middle = count_blocks(state_dict_keys, '{}middle_block.1.transformer_blocks.'.format(key_prefix) + '{}')
else: elif "{}middle_block.0.in_layers.0.weight".format(key_prefix) in state_dict_keys:
transformer_depth_middle = -1 transformer_depth_middle = -1
else:
transformer_depth_middle = -2
unet_config["in_channels"] = in_channels unet_config["in_channels"] = in_channels
unet_config["out_channels"] = out_channels
unet_config["model_channels"] = model_channels unet_config["model_channels"] = model_channels
unet_config["num_res_blocks"] = num_res_blocks unet_config["num_res_blocks"] = num_res_blocks
unet_config["transformer_depth"] = transformer_depth unet_config["transformer_depth"] = transformer_depth
@ -150,11 +187,11 @@ def model_config_from_unet_config(unet_config):
if model_config.matches(unet_config): if model_config.matches(unet_config):
return model_config(unet_config) return model_config(unet_config)
print("no match", unet_config) logging.error("no match {}".format(unet_config))
return None return None
def model_config_from_unet(state_dict, unet_key_prefix, dtype, use_base_if_no_match=False): def model_config_from_unet(state_dict, unet_key_prefix, use_base_if_no_match=False):
unet_config = detect_unet_config(state_dict, unet_key_prefix, dtype) unet_config = detect_unet_config(state_dict, unet_key_prefix)
model_config = model_config_from_unet_config(unet_config) model_config = model_config_from_unet_config(unet_config)
if model_config is None and use_base_if_no_match: if model_config is None and use_base_if_no_match:
return comfy.supported_models_base.BASE(unet_config) return comfy.supported_models_base.BASE(unet_config)
@ -200,7 +237,7 @@ def convert_config(unet_config):
return new_config return new_config
def unet_config_from_diffusers_unet(state_dict, dtype): def unet_config_from_diffusers_unet(state_dict, dtype=None):
match = {} match = {}
transformer_depth = [] transformer_depth = []
@ -208,6 +245,7 @@ def unet_config_from_diffusers_unet(state_dict, dtype):
down_blocks = count_blocks(state_dict, "down_blocks.{}") down_blocks = count_blocks(state_dict, "down_blocks.{}")
for i in range(down_blocks): for i in range(down_blocks):
attn_blocks = count_blocks(state_dict, "down_blocks.{}.attentions.".format(i) + '{}') attn_blocks = count_blocks(state_dict, "down_blocks.{}.attentions.".format(i) + '{}')
res_blocks = count_blocks(state_dict, "down_blocks.{}.resnets.".format(i) + '{}')
for ab in range(attn_blocks): for ab in range(attn_blocks):
transformer_count = count_blocks(state_dict, "down_blocks.{}.attentions.{}.transformer_blocks.".format(i, ab) + '{}') transformer_count = count_blocks(state_dict, "down_blocks.{}.attentions.{}.transformer_blocks.".format(i, ab) + '{}')
transformer_depth.append(transformer_count) transformer_depth.append(transformer_count)
@ -216,8 +254,8 @@ def unet_config_from_diffusers_unet(state_dict, dtype):
attn_res *= 2 attn_res *= 2
if attn_blocks == 0: if attn_blocks == 0:
transformer_depth.append(0) for i in range(res_blocks):
transformer_depth.append(0) transformer_depth.append(0)
match["transformer_depth"] = transformer_depth match["transformer_depth"] = transformer_depth
@ -289,7 +327,25 @@ def unet_config_from_diffusers_unet(state_dict, dtype):
'channel_mult': [1, 2, 4], 'transformer_depth_middle': -1, 'use_linear_in_transformer': True, 'context_dim': 2048, 'num_head_channels': 64, 'channel_mult': [1, 2, 4], 'transformer_depth_middle': -1, 'use_linear_in_transformer': True, 'context_dim': 2048, 'num_head_channels': 64,
'use_temporal_attention': False, 'use_temporal_resblock': False} 'use_temporal_attention': False, 'use_temporal_resblock': False}
supported_models = [SDXL, SDXL_refiner, SD21, SD15, SD21_uncliph, SD21_unclipl, SDXL_mid_cnet, SDXL_small_cnet, SDXL_diffusers_inpaint, SSD_1B] Segmind_Vega = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
'num_classes': 'sequential', 'adm_in_channels': 2816, 'dtype': dtype, 'in_channels': 4, 'model_channels': 320,
'num_res_blocks': [2, 2, 2], 'transformer_depth': [0, 0, 1, 1, 2, 2], 'transformer_depth_output': [0, 0, 0, 1, 1, 1, 2, 2, 2],
'channel_mult': [1, 2, 4], 'transformer_depth_middle': -1, 'use_linear_in_transformer': True, 'context_dim': 2048, 'num_head_channels': 64,
'use_temporal_attention': False, 'use_temporal_resblock': False}
KOALA_700M = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
'num_classes': 'sequential', 'adm_in_channels': 2816, 'dtype': dtype, 'in_channels': 4, 'model_channels': 320,
'num_res_blocks': [1, 1, 1], 'transformer_depth': [0, 2, 5], 'transformer_depth_output': [0, 0, 2, 2, 5, 5],
'channel_mult': [1, 2, 4], 'transformer_depth_middle': -2, 'use_linear_in_transformer': True, 'context_dim': 2048, 'num_head_channels': 64,
'use_temporal_attention': False, 'use_temporal_resblock': False}
KOALA_1B = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
'num_classes': 'sequential', 'adm_in_channels': 2816, 'dtype': dtype, 'in_channels': 4, 'model_channels': 320,
'num_res_blocks': [1, 1, 1], 'transformer_depth': [0, 2, 6], 'transformer_depth_output': [0, 0, 2, 2, 6, 6],
'channel_mult': [1, 2, 4], 'transformer_depth_middle': 6, 'use_linear_in_transformer': True, 'context_dim': 2048, 'num_head_channels': 64,
'use_temporal_attention': False, 'use_temporal_resblock': False}
supported_models = [SDXL, SDXL_refiner, SD21, SD15, SD21_uncliph, SD21_unclipl, SDXL_mid_cnet, SDXL_small_cnet, SDXL_diffusers_inpaint, SSD_1B, Segmind_Vega, KOALA_700M, KOALA_1B]
for unet_config in supported_models: for unet_config in supported_models:
matches = True matches = True
@ -301,8 +357,8 @@ def unet_config_from_diffusers_unet(state_dict, dtype):
return convert_config(unet_config) return convert_config(unet_config)
return None return None
def model_config_from_diffusers_unet(state_dict, dtype): def model_config_from_diffusers_unet(state_dict):
unet_config = unet_config_from_diffusers_unet(state_dict, dtype) unet_config = unet_config_from_diffusers_unet(state_dict)
if unet_config is not None: if unet_config is not None:
return model_config_from_unet_config(unet_config) return model_config_from_unet_config(unet_config)
return None return None

316
comfy/model_management.py

@ -1,4 +1,5 @@
import psutil import psutil
import logging
from enum import Enum from enum import Enum
from comfy.cli_args import args from comfy.cli_args import args
import comfy.utils import comfy.utils
@ -28,6 +29,10 @@ total_vram = 0
lowvram_available = True lowvram_available = True
xpu_available = False xpu_available = False
if args.deterministic:
logging.info("Using deterministic algorithms for pytorch")
torch.use_deterministic_algorithms(True, warn_only=True)
directml_enabled = False directml_enabled = False
if args.directml is not None: if args.directml is not None:
import torch_directml import torch_directml
@ -37,7 +42,7 @@ if args.directml is not None:
directml_device = torch_directml.device() directml_device = torch_directml.device()
else: else:
directml_device = torch_directml.device(device_index) directml_device = torch_directml.device(device_index)
print("Using directml with device:", torch_directml.device_name(device_index)) logging.info("Using directml with device: {}".format(torch_directml.device_name(device_index)))
# torch_directml.disable_tiled_resources(True) # torch_directml.disable_tiled_resources(True)
lowvram_available = False #TODO: need to find a way to get free memory in directml before this can be enabled by default. lowvram_available = False #TODO: need to find a way to get free memory in directml before this can be enabled by default.
@ -113,10 +118,10 @@ def get_total_memory(dev=None, torch_total_too=False):
total_vram = get_total_memory(get_torch_device()) / (1024 * 1024) total_vram = get_total_memory(get_torch_device()) / (1024 * 1024)
total_ram = psutil.virtual_memory().total / (1024 * 1024) total_ram = psutil.virtual_memory().total / (1024 * 1024)
print("Total VRAM {:0.0f} MB, total RAM {:0.0f} MB".format(total_vram, total_ram)) logging.info("Total VRAM {:0.0f} MB, total RAM {:0.0f} MB".format(total_vram, total_ram))
if not args.normalvram and not args.cpu: if not args.normalvram and not args.cpu:
if lowvram_available and total_vram <= 4096: if lowvram_available and total_vram <= 4096:
print("Trying to enable lowvram mode because your GPU seems to have 4GB or less. If you don't want this use: --normalvram") logging.warning("Trying to enable lowvram mode because your GPU seems to have 4GB or less. If you don't want this use: --normalvram")
set_vram_to = VRAMState.LOW_VRAM set_vram_to = VRAMState.LOW_VRAM
try: try:
@ -139,12 +144,10 @@ else:
pass pass
try: try:
XFORMERS_VERSION = xformers.version.__version__ XFORMERS_VERSION = xformers.version.__version__
print("xformers version:", XFORMERS_VERSION) logging.info("xformers version: {}".format(XFORMERS_VERSION))
if XFORMERS_VERSION.startswith("0.0.18"): if XFORMERS_VERSION.startswith("0.0.18"):
print() logging.warning("\nWARNING: This version of xformers has a major bug where you will get black images when generating high resolution images.")
print("WARNING: This version of xformers has a major bug where you will get black images when generating high resolution images.") logging.warning("Please downgrade or upgrade xformers to a different version.\n")
print("Please downgrade or upgrade xformers to a different version.")
print()
XFORMERS_ENABLED_VAE = False XFORMERS_ENABLED_VAE = False
except: except:
pass pass
@ -171,7 +174,7 @@ try:
if int(torch_version[0]) >= 2: if int(torch_version[0]) >= 2:
if ENABLE_PYTORCH_ATTENTION == False and args.use_split_cross_attention == False and args.use_quad_cross_attention == False: if ENABLE_PYTORCH_ATTENTION == False and args.use_split_cross_attention == False and args.use_quad_cross_attention == False:
ENABLE_PYTORCH_ATTENTION = True ENABLE_PYTORCH_ATTENTION = True
if torch.cuda.is_bf16_supported(): if torch.cuda.is_bf16_supported() and torch.cuda.get_device_properties(torch.cuda.current_device()).major >= 8:
VAE_DTYPE = torch.bfloat16 VAE_DTYPE = torch.bfloat16
if is_intel_xpu(): if is_intel_xpu():
if args.use_split_cross_attention == False and args.use_quad_cross_attention == False: if args.use_split_cross_attention == False and args.use_quad_cross_attention == False:
@ -182,6 +185,9 @@ except:
if is_intel_xpu(): if is_intel_xpu():
VAE_DTYPE = torch.bfloat16 VAE_DTYPE = torch.bfloat16
if args.cpu_vae:
VAE_DTYPE = torch.float32
if args.fp16_vae: if args.fp16_vae:
VAE_DTYPE = torch.float16 VAE_DTYPE = torch.float16
elif args.bf16_vae: elif args.bf16_vae:
@ -206,23 +212,16 @@ elif args.highvram or args.gpu_only:
FORCE_FP32 = False FORCE_FP32 = False
FORCE_FP16 = False FORCE_FP16 = False
if args.force_fp32: if args.force_fp32:
print("Forcing FP32, if this improves things please report it.") logging.info("Forcing FP32, if this improves things please report it.")
FORCE_FP32 = True FORCE_FP32 = True
if args.force_fp16: if args.force_fp16:
print("Forcing FP16.") logging.info("Forcing FP16.")
FORCE_FP16 = True FORCE_FP16 = True
if lowvram_available: if lowvram_available:
try: if set_vram_to in (VRAMState.LOW_VRAM, VRAMState.NO_VRAM):
import accelerate vram_state = set_vram_to
if set_vram_to in (VRAMState.LOW_VRAM, VRAMState.NO_VRAM):
vram_state = set_vram_to
except Exception as e:
import traceback
print(traceback.format_exc())
print("ERROR: LOW VRAM MODE NEEDS accelerate.")
lowvram_available = False
if cpu_state != CPUState.GPU: if cpu_state != CPUState.GPU:
@ -231,12 +230,12 @@ if cpu_state != CPUState.GPU:
if cpu_state == CPUState.MPS: if cpu_state == CPUState.MPS:
vram_state = VRAMState.SHARED vram_state = VRAMState.SHARED
print(f"Set vram state to: {vram_state.name}") logging.info(f"Set vram state to: {vram_state.name}")
DISABLE_SMART_MEMORY = args.disable_smart_memory DISABLE_SMART_MEMORY = args.disable_smart_memory
if DISABLE_SMART_MEMORY: if DISABLE_SMART_MEMORY:
print("Disabling smart memory management") logging.info("Disabling smart memory management")
def get_torch_device_name(device): def get_torch_device_name(device):
if hasattr(device, 'type'): if hasattr(device, 'type'):
@ -254,19 +253,27 @@ def get_torch_device_name(device):
return "CUDA {}: {}".format(device, torch.cuda.get_device_name(device)) return "CUDA {}: {}".format(device, torch.cuda.get_device_name(device))
try: try:
print("Device:", get_torch_device_name(get_torch_device())) logging.info("Device: {}".format(get_torch_device_name(get_torch_device())))
except: except:
print("Could not pick default device.") logging.warning("Could not pick default device.")
print("VAE dtype:", VAE_DTYPE) logging.info("VAE dtype: {}".format(VAE_DTYPE))
current_loaded_models = [] current_loaded_models = []
def module_size(module):
module_mem = 0
sd = module.state_dict()
for k in sd:
t = sd[k]
module_mem += t.nelement() * t.element_size()
return module_mem
class LoadedModel: class LoadedModel:
def __init__(self, model): def __init__(self, model):
self.model = model self.model = model
self.model_accelerated = False
self.device = model.load_device self.device = model.load_device
self.weights_loaded = False
def model_memory(self): def model_memory(self):
return self.model.model_size() return self.model.model_size()
@ -278,38 +285,33 @@ class LoadedModel:
return self.model_memory() return self.model_memory()
def model_load(self, lowvram_model_memory=0): def model_load(self, lowvram_model_memory=0):
patch_model_to = None patch_model_to = self.device
if lowvram_model_memory == 0:
patch_model_to = self.device
self.model.model_patches_to(self.device) self.model.model_patches_to(self.device)
self.model.model_patches_to(self.model.model_dtype()) self.model.model_patches_to(self.model.model_dtype())
load_weights = not self.weights_loaded
try: try:
self.real_model = self.model.patch_model(device_to=patch_model_to) #TODO: do something with loras and offloading to CPU if lowvram_model_memory > 0 and load_weights:
self.real_model = self.model.patch_model_lowvram(device_to=patch_model_to, lowvram_model_memory=lowvram_model_memory)
else:
self.real_model = self.model.patch_model(device_to=patch_model_to, patch_weights=load_weights)
except Exception as e: except Exception as e:
self.model.unpatch_model(self.model.offload_device) self.model.unpatch_model(self.model.offload_device)
self.model_unload() self.model_unload()
raise e raise e
if lowvram_model_memory > 0:
print("loading in lowvram mode", lowvram_model_memory/(1024 * 1024))
device_map = accelerate.infer_auto_device_map(self.real_model, max_memory={0: "{}MiB".format(lowvram_model_memory // (1024 * 1024)), "cpu": "16GiB"})
accelerate.dispatch_model(self.real_model, device_map=device_map, main_device=self.device)
self.model_accelerated = True
if is_intel_xpu() and not args.disable_ipex_optimize: if is_intel_xpu() and not args.disable_ipex_optimize:
self.real_model = torch.xpu.optimize(self.real_model.eval(), inplace=True, auto_kernel_selection=True, graph_mode=True) self.real_model = torch.xpu.optimize(self.real_model.eval(), inplace=True, auto_kernel_selection=True, graph_mode=True)
self.weights_loaded = True
return self.real_model return self.real_model
def model_unload(self): def model_unload(self, unpatch_weights=True):
if self.model_accelerated: self.model.unpatch_model(self.model.offload_device, unpatch_weights=unpatch_weights)
accelerate.hooks.remove_hook_from_submodules(self.real_model)
self.model_accelerated = False
self.model.unpatch_model(self.model.offload_device)
self.model.model_patches_to(self.model.offload_device) self.model.model_patches_to(self.model.offload_device)
self.weights_loaded = self.weights_loaded and not unpatch_weights
def __eq__(self, other): def __eq__(self, other):
return self.model is other.model return self.model is other.model
@ -317,31 +319,57 @@ class LoadedModel:
def minimum_inference_memory(): def minimum_inference_memory():
return (1024 * 1024 * 1024) return (1024 * 1024 * 1024)
def unload_model_clones(model): def unload_model_clones(model, unload_weights_only=True, force_unload=True):
to_unload = [] to_unload = []
for i in range(len(current_loaded_models)): for i in range(len(current_loaded_models)):
if model.is_clone(current_loaded_models[i].model): if model.is_clone(current_loaded_models[i].model):
to_unload = [i] + to_unload to_unload = [i] + to_unload
if len(to_unload) == 0:
return None
same_weights = 0
for i in to_unload:
if model.clone_has_same_weights(current_loaded_models[i].model):
same_weights += 1
if same_weights == len(to_unload):
unload_weight = False
else:
unload_weight = True
if not force_unload:
if unload_weights_only and unload_weight == False:
return None
for i in to_unload: for i in to_unload:
print("unload clone", i) logging.debug("unload clone {} {}".format(i, unload_weight))
current_loaded_models.pop(i).model_unload() current_loaded_models.pop(i).model_unload(unpatch_weights=unload_weight)
return unload_weight
def free_memory(memory_required, device, keep_loaded=[]): def free_memory(memory_required, device, keep_loaded=[]):
unloaded_model = False unloaded_model = []
can_unload = []
for i in range(len(current_loaded_models) -1, -1, -1): for i in range(len(current_loaded_models) -1, -1, -1):
if not DISABLE_SMART_MEMORY:
if get_free_memory(device) > memory_required:
break
shift_model = current_loaded_models[i] shift_model = current_loaded_models[i]
if shift_model.device == device: if shift_model.device == device:
if shift_model not in keep_loaded: if shift_model not in keep_loaded:
m = current_loaded_models.pop(i) can_unload.append((sys.getrefcount(shift_model.model), shift_model.model_memory(), i))
m.model_unload()
del m
unloaded_model = True
if unloaded_model: for x in sorted(can_unload):
i = x[-1]
if not DISABLE_SMART_MEMORY:
if get_free_memory(device) > memory_required:
break
current_loaded_models[i].model_unload()
unloaded_model.append(i)
for i in sorted(unloaded_model, reverse=True):
current_loaded_models.pop(i)
if len(unloaded_model) > 0:
soft_empty_cache() soft_empty_cache()
else: else:
if vram_state != VRAMState.HIGH_VRAM: if vram_state != VRAMState.HIGH_VRAM:
@ -366,7 +394,7 @@ def load_models_gpu(models, memory_required=0):
models_already_loaded.append(loaded_model) models_already_loaded.append(loaded_model)
else: else:
if hasattr(x, "model"): if hasattr(x, "model"):
print(f"Requested to load {x.model.__class__.__name__}") logging.info(f"Requested to load {x.model.__class__.__name__}")
models_to_load.append(loaded_model) models_to_load.append(loaded_model)
if len(models_to_load) == 0: if len(models_to_load) == 0:
@ -376,17 +404,22 @@ def load_models_gpu(models, memory_required=0):
free_memory(extra_mem, d, models_already_loaded) free_memory(extra_mem, d, models_already_loaded)
return return
print(f"Loading {len(models_to_load)} new model{'s' if len(models_to_load) > 1 else ''}") logging.info(f"Loading {len(models_to_load)} new model{'s' if len(models_to_load) > 1 else ''}")
total_memory_required = {} total_memory_required = {}
for loaded_model in models_to_load: for loaded_model in models_to_load:
unload_model_clones(loaded_model.model) unload_model_clones(loaded_model.model, unload_weights_only=True, force_unload=False) #unload clones where the weights are different
total_memory_required[loaded_model.device] = total_memory_required.get(loaded_model.device, 0) + loaded_model.model_memory_required(loaded_model.device) total_memory_required[loaded_model.device] = total_memory_required.get(loaded_model.device, 0) + loaded_model.model_memory_required(loaded_model.device)
for device in total_memory_required: for device in total_memory_required:
if device != torch.device("cpu"): if device != torch.device("cpu"):
free_memory(total_memory_required[device] * 1.3 + extra_mem, device, models_already_loaded) free_memory(total_memory_required[device] * 1.3 + extra_mem, device, models_already_loaded)
for loaded_model in models_to_load:
weights_unloaded = unload_model_clones(loaded_model.model, unload_weights_only=False, force_unload=False) #unload the rest of the clones where the weights can stay loaded
if weights_unloaded is not None:
loaded_model.weights_loaded = not weights_unloaded
for loaded_model in models_to_load: for loaded_model in models_to_load:
model = loaded_model.model model = loaded_model.model
torch_dev = model.load_device torch_dev = model.load_device
@ -398,14 +431,14 @@ def load_models_gpu(models, memory_required=0):
if lowvram_available and (vram_set_state == VRAMState.LOW_VRAM or vram_set_state == VRAMState.NORMAL_VRAM): if lowvram_available and (vram_set_state == VRAMState.LOW_VRAM or vram_set_state == VRAMState.NORMAL_VRAM):
model_size = loaded_model.model_memory_required(torch_dev) model_size = loaded_model.model_memory_required(torch_dev)
current_free_mem = get_free_memory(torch_dev) current_free_mem = get_free_memory(torch_dev)
lowvram_model_memory = int(max(256 * (1024 * 1024), (current_free_mem - 1024 * (1024 * 1024)) / 1.3 )) lowvram_model_memory = int(max(64 * (1024 * 1024), (current_free_mem - 1024 * (1024 * 1024)) / 1.3 ))
if model_size > (current_free_mem - inference_memory): #only switch to lowvram if really necessary if model_size > (current_free_mem - inference_memory): #only switch to lowvram if really necessary
vram_set_state = VRAMState.LOW_VRAM vram_set_state = VRAMState.LOW_VRAM
else: else:
lowvram_model_memory = 0 lowvram_model_memory = 0
if vram_set_state == VRAMState.NO_VRAM: if vram_set_state == VRAMState.NO_VRAM:
lowvram_model_memory = 256 * 1024 * 1024 lowvram_model_memory = 64 * 1024 * 1024
cur_loaded_model = loaded_model.model_load(lowvram_model_memory) cur_loaded_model = loaded_model.model_load(lowvram_model_memory)
current_loaded_models.insert(0, loaded_model) current_loaded_models.insert(0, loaded_model)
@ -430,6 +463,13 @@ def dtype_size(dtype):
dtype_size = 4 dtype_size = 4
if dtype == torch.float16 or dtype == torch.bfloat16: if dtype == torch.float16 or dtype == torch.bfloat16:
dtype_size = 2 dtype_size = 2
elif dtype == torch.float32:
dtype_size = 4
else:
try:
dtype_size = dtype.itemsize
except: #Old pytorch doesn't have .itemsize
pass
return dtype_size return dtype_size
def unet_offload_device(): def unet_offload_device():
@ -456,13 +496,44 @@ def unet_inital_load_device(parameters, dtype):
else: else:
return cpu_dev return cpu_dev
def unet_dtype(device=None, model_params=0): def unet_dtype(device=None, model_params=0, supported_dtypes=[torch.float16, torch.bfloat16, torch.float32]):
if args.bf16_unet: if args.bf16_unet:
return torch.bfloat16 return torch.bfloat16
if should_use_fp16(device=device, model_params=model_params): if args.fp16_unet:
return torch.float16 return torch.float16
if args.fp8_e4m3fn_unet:
return torch.float8_e4m3fn
if args.fp8_e5m2_unet:
return torch.float8_e5m2
if should_use_fp16(device=device, model_params=model_params, manual_cast=True):
if torch.float16 in supported_dtypes:
return torch.float16
if should_use_bf16(device, model_params=model_params, manual_cast=True):
if torch.bfloat16 in supported_dtypes:
return torch.bfloat16
return torch.float32 return torch.float32
# None means no manual cast
def unet_manual_cast(weight_dtype, inference_device, supported_dtypes=[torch.float16, torch.bfloat16, torch.float32]):
if weight_dtype == torch.float32:
return None
fp16_supported = should_use_fp16(inference_device, prioritize_performance=False)
if fp16_supported and weight_dtype == torch.float16:
return None
bf16_supported = should_use_bf16(inference_device)
if bf16_supported and weight_dtype == torch.bfloat16:
return None
if fp16_supported and torch.float16 in supported_dtypes:
return torch.float16
elif bf16_supported and torch.bfloat16 in supported_dtypes:
return torch.bfloat16
else:
return torch.float32
def text_encoder_offload_device(): def text_encoder_offload_device():
if args.gpu_only: if args.gpu_only:
return get_torch_device() return get_torch_device()
@ -492,12 +563,21 @@ def text_encoder_dtype(device=None):
elif args.fp32_text_enc: elif args.fp32_text_enc:
return torch.float32 return torch.float32
if should_use_fp16(device, prioritize_performance=False): if is_device_cpu(device):
return torch.float16 return torch.float16
return torch.float16
def intermediate_device():
if args.gpu_only:
return get_torch_device()
else: else:
return torch.float32 return torch.device("cpu")
def vae_device(): def vae_device():
if args.cpu_vae:
return torch.device("cpu")
return get_torch_device() return get_torch_device()
def vae_offload_device(): def vae_offload_device():
@ -515,6 +595,22 @@ def get_autocast_device(dev):
return dev.type return dev.type
return "cuda" return "cuda"
def supports_dtype(device, dtype): #TODO
if dtype == torch.float32:
return True
if is_device_cpu(device):
return False
if dtype == torch.float16:
return True
if dtype == torch.bfloat16:
return True
return False
def device_supports_non_blocking(device):
if is_device_mps(device):
return False #pytorch bug? mps doesn't support non blocking
return True
def cast_to_device(tensor, device, dtype, copy=False): def cast_to_device(tensor, device, dtype, copy=False):
device_supports_cast = False device_supports_cast = False
if tensor.dtype == torch.float32 or tensor.dtype == torch.float16: if tensor.dtype == torch.float32 or tensor.dtype == torch.float16:
@ -525,15 +621,17 @@ def cast_to_device(tensor, device, dtype, copy=False):
elif is_intel_xpu(): elif is_intel_xpu():
device_supports_cast = True device_supports_cast = True
non_blocking = device_supports_non_blocking(device)
if device_supports_cast: if device_supports_cast:
if copy: if copy:
if tensor.device == device: if tensor.device == device:
return tensor.to(dtype, copy=copy) return tensor.to(dtype, copy=copy, non_blocking=non_blocking)
return tensor.to(device, copy=copy).to(dtype) return tensor.to(device, copy=copy, non_blocking=non_blocking).to(dtype, non_blocking=non_blocking)
else: else:
return tensor.to(device).to(dtype) return tensor.to(device, non_blocking=non_blocking).to(dtype, non_blocking=non_blocking)
else: else:
return tensor.to(dtype).to(device, copy=copy) return tensor.to(device, dtype, copy=copy, non_blocking=non_blocking)
def xformers_enabled(): def xformers_enabled():
global directml_enabled global directml_enabled
@ -606,19 +704,22 @@ def mps_mode():
global cpu_state global cpu_state
return cpu_state == CPUState.MPS return cpu_state == CPUState.MPS
def is_device_cpu(device): def is_device_type(device, type):
if hasattr(device, 'type'): if hasattr(device, 'type'):
if (device.type == 'cpu'): if (device.type == type):
return True return True
return False return False
def is_device_cpu(device):
return is_device_type(device, 'cpu')
def is_device_mps(device): def is_device_mps(device):
if hasattr(device, 'type'): return is_device_type(device, 'mps')
if (device.type == 'mps'):
return True def is_device_cuda(device):
return False return is_device_type(device, 'cuda')
def should_use_fp16(device=None, model_params=0, prioritize_performance=True): def should_use_fp16(device=None, model_params=0, prioritize_performance=True, manual_cast=False):
global directml_enabled global directml_enabled
if device is not None: if device is not None:
@ -628,9 +729,9 @@ def should_use_fp16(device=None, model_params=0, prioritize_performance=True):
if FORCE_FP16: if FORCE_FP16:
return True return True
if device is not None: #TODO if device is not None:
if is_device_mps(device): if is_device_mps(device):
return False return True
if FORCE_FP32: if FORCE_FP32:
return False return False
@ -638,16 +739,22 @@ def should_use_fp16(device=None, model_params=0, prioritize_performance=True):
if directml_enabled: if directml_enabled:
return False return False
if cpu_mode() or mps_mode(): if mps_mode():
return False #TODO ? return True
if cpu_mode():
return False
if is_intel_xpu(): if is_intel_xpu():
return True return True
if torch.cuda.is_bf16_supported(): if torch.version.hip:
return True return True
props = torch.cuda.get_device_properties("cuda") props = torch.cuda.get_device_properties("cuda")
if props.major >= 8:
return True
if props.major < 6: if props.major < 6:
return False return False
@ -655,12 +762,12 @@ def should_use_fp16(device=None, model_params=0, prioritize_performance=True):
#FP16 is confirmed working on a 1080 (GP104) but it's a bit slower than FP32 so it should only be enabled #FP16 is confirmed working on a 1080 (GP104) but it's a bit slower than FP32 so it should only be enabled
#when the model doesn't actually fit on the card #when the model doesn't actually fit on the card
#TODO: actually test if GP106 and others have the same type of behavior #TODO: actually test if GP106 and others have the same type of behavior
nvidia_10_series = ["1080", "1070", "titan x", "p3000", "p3200", "p4000", "p4200", "p5000", "p5200", "p6000", "1060", "1050"] nvidia_10_series = ["1080", "1070", "titan x", "p3000", "p3200", "p4000", "p4200", "p5000", "p5200", "p6000", "1060", "1050", "p40", "p100", "p6", "p4"]
for x in nvidia_10_series: for x in nvidia_10_series:
if x in props.name.lower(): if x in props.name.lower():
fp16_works = True fp16_works = True
if fp16_works: if fp16_works or manual_cast:
free_model_memory = (get_free_memory() * 0.9 - minimum_inference_memory()) free_model_memory = (get_free_memory() * 0.9 - minimum_inference_memory())
if (not prioritize_performance) or model_params * 4 > free_model_memory: if (not prioritize_performance) or model_params * 4 > free_model_memory:
return True return True
@ -676,6 +783,43 @@ def should_use_fp16(device=None, model_params=0, prioritize_performance=True):
return True return True
def should_use_bf16(device=None, model_params=0, prioritize_performance=True, manual_cast=False):
if device is not None:
if is_device_cpu(device): #TODO ? bf16 works on CPU but is extremely slow
return False
if device is not None: #TODO not sure about mps bf16 support
if is_device_mps(device):
return False
if FORCE_FP32:
return False
if directml_enabled:
return False
if cpu_mode() or mps_mode():
return False
if is_intel_xpu():
return True
if device is None:
device = torch.device("cuda")
props = torch.cuda.get_device_properties(device)
if props.major >= 8:
return True
bf16_works = torch.cuda.is_bf16_supported()
if bf16_works or manual_cast:
free_model_memory = (get_free_memory() * 0.9 - minimum_inference_memory())
if (not prioritize_performance) or model_params * 4 > free_model_memory:
return True
return False
def soft_empty_cache(force=False): def soft_empty_cache(force=False):
global cpu_state global cpu_state
if cpu_state == CPUState.MPS: if cpu_state == CPUState.MPS:
@ -687,11 +831,11 @@ def soft_empty_cache(force=False):
torch.cuda.empty_cache() torch.cuda.empty_cache()
torch.cuda.ipc_collect() torch.cuda.ipc_collect()
def resolve_lowvram_weight(weight, model, key): def unload_all_models():
if weight.device == torch.device("meta"): #lowvram NOTE: this depends on the inner working of the accelerate library so it might break. free_memory(1e30, get_torch_device())
key_split = key.split('.') # I have no idea why they don't just leave the weight there instead of using the meta device.
op = comfy.utils.get_attr(model, '.'.join(key_split[:-1]))
weight = op._hf_hook.weights_map[key_split[-1]] def resolve_lowvram_weight(weight, model, key): #TODO: remove
return weight return weight
#TODO: might be cleaner to put this somewhere else #TODO: might be cleaner to put this somewhere else

252
comfy/model_patcher.py

@ -1,10 +1,24 @@
import torch import torch
import copy import copy
import inspect import inspect
import logging
import uuid
import comfy.utils import comfy.utils
import comfy.model_management import comfy.model_management
def apply_weight_decompose(dora_scale, weight):
weight_norm = (
weight.transpose(0, 1)
.reshape(weight.shape[1], -1)
.norm(dim=1, keepdim=True)
.reshape(weight.shape[1], *[1] * (weight.dim() - 1))
.transpose(0, 1)
)
return weight * (dora_scale / weight_norm)
class ModelPatcher: class ModelPatcher:
def __init__(self, model, load_device, offload_device, size=0, current_device=None, weight_inplace_update=False): def __init__(self, model, load_device, offload_device, size=0, current_device=None, weight_inplace_update=False):
self.size = size self.size = size
@ -23,28 +37,29 @@ class ModelPatcher:
self.current_device = current_device self.current_device = current_device
self.weight_inplace_update = weight_inplace_update self.weight_inplace_update = weight_inplace_update
self.model_lowvram = False
self.patches_uuid = uuid.uuid4()
def model_size(self): def model_size(self):
if self.size > 0: if self.size > 0:
return self.size return self.size
model_sd = self.model.state_dict() model_sd = self.model.state_dict()
size = 0 self.size = comfy.model_management.module_size(self.model)
for k in model_sd:
t = model_sd[k]
size += t.nelement() * t.element_size()
self.size = size
self.model_keys = set(model_sd.keys()) self.model_keys = set(model_sd.keys())
return size return self.size
def clone(self): def clone(self):
n = ModelPatcher(self.model, self.load_device, self.offload_device, self.size, self.current_device, weight_inplace_update=self.weight_inplace_update) n = ModelPatcher(self.model, self.load_device, self.offload_device, self.size, self.current_device, weight_inplace_update=self.weight_inplace_update)
n.patches = {} n.patches = {}
for k in self.patches: for k in self.patches:
n.patches[k] = self.patches[k][:] n.patches[k] = self.patches[k][:]
n.patches_uuid = self.patches_uuid
n.object_patches = self.object_patches.copy() n.object_patches = self.object_patches.copy()
n.model_options = copy.deepcopy(self.model_options) n.model_options = copy.deepcopy(self.model_options)
n.model_keys = self.model_keys n.model_keys = self.model_keys
n.backup = self.backup
n.object_patches_backup = self.object_patches_backup
return n return n
def is_clone(self, other): def is_clone(self, other):
@ -52,31 +67,58 @@ class ModelPatcher:
return True return True
return False return False
def clone_has_same_weights(self, clone):
if not self.is_clone(clone):
return False
if len(self.patches) == 0 and len(clone.patches) == 0:
return True
if self.patches_uuid == clone.patches_uuid:
if len(self.patches) != len(clone.patches):
logging.warning("WARNING: something went wrong, same patch uuid but different length of patches.")
else:
return True
def memory_required(self, input_shape): def memory_required(self, input_shape):
return self.model.memory_required(input_shape=input_shape) return self.model.memory_required(input_shape=input_shape)
def set_model_sampler_cfg_function(self, sampler_cfg_function): def set_model_sampler_cfg_function(self, sampler_cfg_function, disable_cfg1_optimization=False):
if len(inspect.signature(sampler_cfg_function).parameters) == 3: if len(inspect.signature(sampler_cfg_function).parameters) == 3:
self.model_options["sampler_cfg_function"] = lambda args: sampler_cfg_function(args["cond"], args["uncond"], args["cond_scale"]) #Old way self.model_options["sampler_cfg_function"] = lambda args: sampler_cfg_function(args["cond"], args["uncond"], args["cond_scale"]) #Old way
else: else:
self.model_options["sampler_cfg_function"] = sampler_cfg_function self.model_options["sampler_cfg_function"] = sampler_cfg_function
if disable_cfg1_optimization:
self.model_options["disable_cfg1_optimization"] = True
def set_model_sampler_post_cfg_function(self, post_cfg_function, disable_cfg1_optimization=False):
self.model_options["sampler_post_cfg_function"] = self.model_options.get("sampler_post_cfg_function", []) + [post_cfg_function]
if disable_cfg1_optimization:
self.model_options["disable_cfg1_optimization"] = True
def set_model_unet_function_wrapper(self, unet_wrapper_function): def set_model_unet_function_wrapper(self, unet_wrapper_function):
self.model_options["model_function_wrapper"] = unet_wrapper_function self.model_options["model_function_wrapper"] = unet_wrapper_function
def set_model_denoise_mask_function(self, denoise_mask_function):
self.model_options["denoise_mask_function"] = denoise_mask_function
def set_model_patch(self, patch, name): def set_model_patch(self, patch, name):
to = self.model_options["transformer_options"] to = self.model_options["transformer_options"]
if "patches" not in to: if "patches" not in to:
to["patches"] = {} to["patches"] = {}
to["patches"][name] = to["patches"].get(name, []) + [patch] to["patches"][name] = to["patches"].get(name, []) + [patch]
def set_model_patch_replace(self, patch, name, block_name, number): def set_model_patch_replace(self, patch, name, block_name, number, transformer_index=None):
to = self.model_options["transformer_options"] to = self.model_options["transformer_options"]
if "patches_replace" not in to: if "patches_replace" not in to:
to["patches_replace"] = {} to["patches_replace"] = {}
if name not in to["patches_replace"]: if name not in to["patches_replace"]:
to["patches_replace"][name] = {} to["patches_replace"][name] = {}
to["patches_replace"][name][(block_name, number)] = patch if transformer_index is not None:
block = (block_name, number, transformer_index)
else:
block = (block_name, number)
to["patches_replace"][name][block] = patch
def set_model_attn1_patch(self, patch): def set_model_attn1_patch(self, patch):
self.set_model_patch(patch, "attn1_patch") self.set_model_patch(patch, "attn1_patch")
@ -84,11 +126,11 @@ class ModelPatcher:
def set_model_attn2_patch(self, patch): def set_model_attn2_patch(self, patch):
self.set_model_patch(patch, "attn2_patch") self.set_model_patch(patch, "attn2_patch")
def set_model_attn1_replace(self, patch, block_name, number): def set_model_attn1_replace(self, patch, block_name, number, transformer_index=None):
self.set_model_patch_replace(patch, "attn1", block_name, number) self.set_model_patch_replace(patch, "attn1", block_name, number, transformer_index)
def set_model_attn2_replace(self, patch, block_name, number): def set_model_attn2_replace(self, patch, block_name, number, transformer_index=None):
self.set_model_patch_replace(patch, "attn2", block_name, number) self.set_model_patch_replace(patch, "attn2", block_name, number, transformer_index)
def set_model_attn1_output_patch(self, patch): def set_model_attn1_output_patch(self, patch):
self.set_model_patch(patch, "attn1_output_patch") self.set_model_patch(patch, "attn1_output_patch")
@ -142,6 +184,7 @@ class ModelPatcher:
current_patches.append((strength_patch, patches[k], strength_model)) current_patches.append((strength_patch, patches[k], strength_model))
self.patches[k] = current_patches self.patches[k] = current_patches
self.patches_uuid = uuid.uuid4()
return list(p) return list(p)
def get_key_patches(self, filter_prefix=None): def get_key_patches(self, filter_prefix=None):
@ -167,41 +210,87 @@ class ModelPatcher:
sd.pop(k) sd.pop(k)
return sd return sd
def patch_model(self, device_to=None): def patch_weight_to_device(self, key, device_to=None):
if key not in self.patches:
return
weight = comfy.utils.get_attr(self.model, key)
inplace_update = self.weight_inplace_update
if key not in self.backup:
self.backup[key] = weight.to(device=self.offload_device, copy=inplace_update)
if device_to is not None:
temp_weight = comfy.model_management.cast_to_device(weight, device_to, torch.float32, copy=True)
else:
temp_weight = weight.to(torch.float32, copy=True)
out_weight = self.calculate_weight(self.patches[key], temp_weight, key).to(weight.dtype)
if inplace_update:
comfy.utils.copy_to_param(self.model, key, out_weight)
else:
comfy.utils.set_attr_param(self.model, key, out_weight)
def patch_model(self, device_to=None, patch_weights=True):
for k in self.object_patches: for k in self.object_patches:
old = getattr(self.model, k) old = comfy.utils.set_attr(self.model, k, self.object_patches[k])
if k not in self.object_patches_backup: if k not in self.object_patches_backup:
self.object_patches_backup[k] = old self.object_patches_backup[k] = old
setattr(self.model, k, self.object_patches[k])
model_sd = self.model_state_dict()
for key in self.patches:
if key not in model_sd:
print("could not patch. key doesn't exist in model:", key)
continue
weight = model_sd[key] if patch_weights:
model_sd = self.model_state_dict()
inplace_update = self.weight_inplace_update for key in self.patches:
if key not in model_sd:
logging.warning("could not patch. key doesn't exist in model: {}".format(key))
continue
if key not in self.backup: self.patch_weight_to_device(key, device_to)
self.backup[key] = weight.to(device=self.offload_device, copy=inplace_update)
if device_to is not None: if device_to is not None:
temp_weight = comfy.model_management.cast_to_device(weight, device_to, torch.float32, copy=True) self.model.to(device_to)
else: self.current_device = device_to
temp_weight = weight.to(torch.float32, copy=True)
out_weight = self.calculate_weight(self.patches[key], temp_weight, key).to(weight.dtype)
if inplace_update:
comfy.utils.copy_to_param(self.model, key, out_weight)
else:
comfy.utils.set_attr(self.model, key, out_weight)
del temp_weight
if device_to is not None: return self.model
self.model.to(device_to)
self.current_device = device_to
def patch_model_lowvram(self, device_to=None, lowvram_model_memory=0):
self.patch_model(device_to, patch_weights=False)
logging.info("loading in lowvram mode {}".format(lowvram_model_memory/(1024 * 1024)))
class LowVramPatch:
def __init__(self, key, model_patcher):
self.key = key
self.model_patcher = model_patcher
def __call__(self, weight):
return self.model_patcher.calculate_weight(self.model_patcher.patches[self.key], weight, self.key)
mem_counter = 0
for n, m in self.model.named_modules():
lowvram_weight = False
if hasattr(m, "comfy_cast_weights"):
module_mem = comfy.model_management.module_size(m)
if mem_counter + module_mem >= lowvram_model_memory:
lowvram_weight = True
weight_key = "{}.weight".format(n)
bias_key = "{}.bias".format(n)
if lowvram_weight:
if weight_key in self.patches:
m.weight_function = LowVramPatch(weight_key, self)
if bias_key in self.patches:
m.bias_function = LowVramPatch(weight_key, self)
m.prev_comfy_cast_weights = m.comfy_cast_weights
m.comfy_cast_weights = True
else:
if hasattr(m, "weight"):
self.patch_weight_to_device(weight_key, device_to)
self.patch_weight_to_device(bias_key, device_to)
m.to(device_to)
mem_counter += comfy.model_management.module_size(m)
logging.debug("lowvram: loaded module regularly {}".format(m))
self.model_lowvram = True
return self.model return self.model
def calculate_weight(self, patches, weight, key): def calculate_weight(self, patches, weight, key):
@ -217,15 +306,22 @@ class ModelPatcher:
v = (self.calculate_weight(v[1:], v[0].clone(), key), ) v = (self.calculate_weight(v[1:], v[0].clone(), key), )
if len(v) == 1: if len(v) == 1:
patch_type = "diff"
elif len(v) == 2:
patch_type = v[0]
v = v[1]
if patch_type == "diff":
w1 = v[0] w1 = v[0]
if alpha != 0.0: if alpha != 0.0:
if w1.shape != weight.shape: if w1.shape != weight.shape:
print("WARNING SHAPE MISMATCH {} WEIGHT NOT MERGED {} != {}".format(key, w1.shape, weight.shape)) logging.warning("WARNING SHAPE MISMATCH {} WEIGHT NOT MERGED {} != {}".format(key, w1.shape, weight.shape))
else: else:
weight += alpha * comfy.model_management.cast_to_device(w1, weight.device, weight.dtype) weight += alpha * comfy.model_management.cast_to_device(w1, weight.device, weight.dtype)
elif len(v) == 4: #lora/locon elif patch_type == "lora": #lora/locon
mat1 = comfy.model_management.cast_to_device(v[0], weight.device, torch.float32) mat1 = comfy.model_management.cast_to_device(v[0], weight.device, torch.float32)
mat2 = comfy.model_management.cast_to_device(v[1], weight.device, torch.float32) mat2 = comfy.model_management.cast_to_device(v[1], weight.device, torch.float32)
dora_scale = v[4]
if v[2] is not None: if v[2] is not None:
alpha *= v[2] / mat2.shape[0] alpha *= v[2] / mat2.shape[0]
if v[3] is not None: if v[3] is not None:
@ -235,9 +331,11 @@ class ModelPatcher:
mat2 = torch.mm(mat2.transpose(0, 1).flatten(start_dim=1), mat3.transpose(0, 1).flatten(start_dim=1)).reshape(final_shape).transpose(0, 1) mat2 = torch.mm(mat2.transpose(0, 1).flatten(start_dim=1), mat3.transpose(0, 1).flatten(start_dim=1)).reshape(final_shape).transpose(0, 1)
try: try:
weight += (alpha * torch.mm(mat1.flatten(start_dim=1), mat2.flatten(start_dim=1))).reshape(weight.shape).type(weight.dtype) weight += (alpha * torch.mm(mat1.flatten(start_dim=1), mat2.flatten(start_dim=1))).reshape(weight.shape).type(weight.dtype)
if dora_scale is not None:
weight = apply_weight_decompose(comfy.model_management.cast_to_device(dora_scale, weight.device, torch.float32), weight)
except Exception as e: except Exception as e:
print("ERROR", key, e) logging.error("ERROR {} {} {}".format(patch_type, key, e))
elif len(v) == 8: #lokr elif patch_type == "lokr":
w1 = v[0] w1 = v[0]
w2 = v[1] w2 = v[1]
w1_a = v[3] w1_a = v[3]
@ -245,6 +343,7 @@ class ModelPatcher:
w2_a = v[5] w2_a = v[5]
w2_b = v[6] w2_b = v[6]
t2 = v[7] t2 = v[7]
dora_scale = v[8]
dim = None dim = None
if w1 is None: if w1 is None:
@ -274,15 +373,18 @@ class ModelPatcher:
try: try:
weight += alpha * torch.kron(w1, w2).reshape(weight.shape).type(weight.dtype) weight += alpha * torch.kron(w1, w2).reshape(weight.shape).type(weight.dtype)
if dora_scale is not None:
weight = apply_weight_decompose(comfy.model_management.cast_to_device(dora_scale, weight.device, torch.float32), weight)
except Exception as e: except Exception as e:
print("ERROR", key, e) logging.error("ERROR {} {} {}".format(patch_type, key, e))
else: #loha elif patch_type == "loha":
w1a = v[0] w1a = v[0]
w1b = v[1] w1b = v[1]
if v[2] is not None: if v[2] is not None:
alpha *= v[2] / w1b.shape[0] alpha *= v[2] / w1b.shape[0]
w2a = v[3] w2a = v[3]
w2b = v[4] w2b = v[4]
dora_scale = v[7]
if v[5] is not None: #cp decomposition if v[5] is not None: #cp decomposition
t1 = v[5] t1 = v[5]
t2 = v[6] t2 = v[6]
@ -303,29 +405,61 @@ class ModelPatcher:
try: try:
weight += (alpha * m1 * m2).reshape(weight.shape).type(weight.dtype) weight += (alpha * m1 * m2).reshape(weight.shape).type(weight.dtype)
if dora_scale is not None:
weight = apply_weight_decompose(comfy.model_management.cast_to_device(dora_scale, weight.device, torch.float32), weight)
except Exception as e: except Exception as e:
print("ERROR", key, e) logging.error("ERROR {} {} {}".format(patch_type, key, e))
elif patch_type == "glora":
if v[4] is not None:
alpha *= v[4] / v[0].shape[0]
dora_scale = v[5]
a1 = comfy.model_management.cast_to_device(v[0].flatten(start_dim=1), weight.device, torch.float32)
a2 = comfy.model_management.cast_to_device(v[1].flatten(start_dim=1), weight.device, torch.float32)
b1 = comfy.model_management.cast_to_device(v[2].flatten(start_dim=1), weight.device, torch.float32)
b2 = comfy.model_management.cast_to_device(v[3].flatten(start_dim=1), weight.device, torch.float32)
try:
weight += ((torch.mm(b2, b1) + torch.mm(torch.mm(weight.flatten(start_dim=1), a2), a1)) * alpha).reshape(weight.shape).type(weight.dtype)
if dora_scale is not None:
weight = apply_weight_decompose(comfy.model_management.cast_to_device(dora_scale, weight.device, torch.float32), weight)
except Exception as e:
logging.error("ERROR {} {} {}".format(patch_type, key, e))
else:
logging.warning("patch type not recognized {} {}".format(patch_type, key))
return weight return weight
def unpatch_model(self, device_to=None): def unpatch_model(self, device_to=None, unpatch_weights=True):
keys = list(self.backup.keys()) if unpatch_weights:
if self.model_lowvram:
for m in self.model.modules():
if hasattr(m, "prev_comfy_cast_weights"):
m.comfy_cast_weights = m.prev_comfy_cast_weights
del m.prev_comfy_cast_weights
m.weight_function = None
m.bias_function = None
if self.weight_inplace_update: self.model_lowvram = False
for k in keys:
comfy.utils.copy_to_param(self.model, k, self.backup[k])
else:
for k in keys:
comfy.utils.set_attr(self.model, k, self.backup[k])
self.backup = {} keys = list(self.backup.keys())
if device_to is not None: if self.weight_inplace_update:
self.model.to(device_to) for k in keys:
self.current_device = device_to comfy.utils.copy_to_param(self.model, k, self.backup[k])
else:
for k in keys:
comfy.utils.set_attr_param(self.model, k, self.backup[k])
self.backup.clear()
if device_to is not None:
self.model.to(device_to)
self.current_device = device_to
keys = list(self.object_patches_backup.keys()) keys = list(self.object_patches_backup.keys())
for k in keys: for k in keys:
setattr(self.model, k, self.object_patches_backup[k]) comfy.utils.set_attr(self.model, k, self.object_patches_backup[k])
self.object_patches_backup = {} self.object_patches_backup = {}

98
comfy/model_sampling.py

@ -1,5 +1,4 @@
import torch import torch
import numpy as np
from comfy.ldm.modules.diffusionmodules.util import make_beta_schedule from comfy.ldm.modules.diffusionmodules.util import make_beta_schedule
import math import math
@ -12,20 +11,43 @@ class EPS:
sigma = sigma.view(sigma.shape[:1] + (1,) * (model_output.ndim - 1)) sigma = sigma.view(sigma.shape[:1] + (1,) * (model_output.ndim - 1))
return model_input - model_output * sigma return model_input - model_output * sigma
def noise_scaling(self, sigma, noise, latent_image, max_denoise=False):
if max_denoise:
noise = noise * torch.sqrt(1.0 + sigma ** 2.0)
else:
noise = noise * sigma
noise += latent_image
return noise
def inverse_noise_scaling(self, sigma, latent):
return latent
class V_PREDICTION(EPS): class V_PREDICTION(EPS):
def calculate_denoised(self, sigma, model_output, model_input): def calculate_denoised(self, sigma, model_output, model_input):
sigma = sigma.view(sigma.shape[:1] + (1,) * (model_output.ndim - 1)) sigma = sigma.view(sigma.shape[:1] + (1,) * (model_output.ndim - 1))
return model_input * self.sigma_data ** 2 / (sigma ** 2 + self.sigma_data ** 2) - model_output * sigma * self.sigma_data / (sigma ** 2 + self.sigma_data ** 2) ** 0.5 return model_input * self.sigma_data ** 2 / (sigma ** 2 + self.sigma_data ** 2) - model_output * sigma * self.sigma_data / (sigma ** 2 + self.sigma_data ** 2) ** 0.5
class EDM(V_PREDICTION):
def calculate_denoised(self, sigma, model_output, model_input):
sigma = sigma.view(sigma.shape[:1] + (1,) * (model_output.ndim - 1))
return model_input * self.sigma_data ** 2 / (sigma ** 2 + self.sigma_data ** 2) + model_output * sigma * self.sigma_data / (sigma ** 2 + self.sigma_data ** 2) ** 0.5
class ModelSamplingDiscrete(torch.nn.Module): class ModelSamplingDiscrete(torch.nn.Module):
def __init__(self, model_config=None): def __init__(self, model_config=None):
super().__init__() super().__init__()
beta_schedule = "linear"
if model_config is not None: if model_config is not None:
beta_schedule = model_config.sampling_settings.get("beta_schedule", beta_schedule) sampling_settings = model_config.sampling_settings
self._register_schedule(given_betas=None, beta_schedule=beta_schedule, timesteps=1000, linear_start=0.00085, linear_end=0.012, cosine_s=8e-3) else:
sampling_settings = {}
beta_schedule = sampling_settings.get("beta_schedule", "linear")
linear_start = sampling_settings.get("linear_start", 0.00085)
linear_end = sampling_settings.get("linear_end", 0.012)
self._register_schedule(given_betas=None, beta_schedule=beta_schedule, timesteps=1000, linear_start=linear_start, linear_end=linear_end, cosine_s=8e-3)
self.sigma_data = 1.0 self.sigma_data = 1.0
def _register_schedule(self, given_betas=None, beta_schedule="linear", timesteps=1000, def _register_schedule(self, given_betas=None, beta_schedule="linear", timesteps=1000,
@ -35,8 +57,7 @@ class ModelSamplingDiscrete(torch.nn.Module):
else: else:
betas = make_beta_schedule(beta_schedule, timesteps, linear_start=linear_start, linear_end=linear_end, cosine_s=cosine_s) betas = make_beta_schedule(beta_schedule, timesteps, linear_start=linear_start, linear_end=linear_end, cosine_s=cosine_s)
alphas = 1. - betas alphas = 1. - betas
alphas_cumprod = torch.tensor(np.cumprod(alphas, axis=0), dtype=torch.float32) alphas_cumprod = torch.cumprod(alphas, dim=0)
# alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1])
timesteps, = betas.shape timesteps, = betas.shape
self.num_timesteps = int(timesteps) self.num_timesteps = int(timesteps)
@ -51,8 +72,8 @@ class ModelSamplingDiscrete(torch.nn.Module):
self.set_sigmas(sigmas) self.set_sigmas(sigmas)
def set_sigmas(self, sigmas): def set_sigmas(self, sigmas):
self.register_buffer('sigmas', sigmas) self.register_buffer('sigmas', sigmas.float())
self.register_buffer('log_sigmas', sigmas.log()) self.register_buffer('log_sigmas', sigmas.log().float())
@property @property
def sigma_min(self): def sigma_min(self):
@ -87,8 +108,6 @@ class ModelSamplingDiscrete(torch.nn.Module):
class ModelSamplingContinuousEDM(torch.nn.Module): class ModelSamplingContinuousEDM(torch.nn.Module):
def __init__(self, model_config=None): def __init__(self, model_config=None):
super().__init__() super().__init__()
self.sigma_data = 1.0
if model_config is not None: if model_config is not None:
sampling_settings = model_config.sampling_settings sampling_settings = model_config.sampling_settings
else: else:
@ -96,9 +115,11 @@ class ModelSamplingContinuousEDM(torch.nn.Module):
sigma_min = sampling_settings.get("sigma_min", 0.002) sigma_min = sampling_settings.get("sigma_min", 0.002)
sigma_max = sampling_settings.get("sigma_max", 120.0) sigma_max = sampling_settings.get("sigma_max", 120.0)
self.set_sigma_range(sigma_min, sigma_max) sigma_data = sampling_settings.get("sigma_data", 1.0)
self.set_parameters(sigma_min, sigma_max, sigma_data)
def set_sigma_range(self, sigma_min, sigma_max): def set_parameters(self, sigma_min, sigma_max, sigma_data):
self.sigma_data = sigma_data
sigmas = torch.linspace(math.log(sigma_min), math.log(sigma_max), 1000).exp() sigmas = torch.linspace(math.log(sigma_min), math.log(sigma_max), 1000).exp()
self.register_buffer('sigmas', sigmas) #for compatibility with some schedulers self.register_buffer('sigmas', sigmas) #for compatibility with some schedulers
@ -127,3 +148,56 @@ class ModelSamplingContinuousEDM(torch.nn.Module):
log_sigma_min = math.log(self.sigma_min) log_sigma_min = math.log(self.sigma_min)
return math.exp((math.log(self.sigma_max) - log_sigma_min) * percent + log_sigma_min) return math.exp((math.log(self.sigma_max) - log_sigma_min) * percent + log_sigma_min)
class StableCascadeSampling(ModelSamplingDiscrete):
def __init__(self, model_config=None):
super().__init__()
if model_config is not None:
sampling_settings = model_config.sampling_settings
else:
sampling_settings = {}
self.set_parameters(sampling_settings.get("shift", 1.0))
def set_parameters(self, shift=1.0, cosine_s=8e-3):
self.shift = shift
self.cosine_s = torch.tensor(cosine_s)
self._init_alpha_cumprod = torch.cos(self.cosine_s / (1 + self.cosine_s) * torch.pi * 0.5) ** 2
#This part is just for compatibility with some schedulers in the codebase
self.num_timesteps = 10000
sigmas = torch.empty((self.num_timesteps), dtype=torch.float32)
for x in range(self.num_timesteps):
t = (x + 1) / self.num_timesteps
sigmas[x] = self.sigma(t)
self.set_sigmas(sigmas)
def sigma(self, timestep):
alpha_cumprod = (torch.cos((timestep + self.cosine_s) / (1 + self.cosine_s) * torch.pi * 0.5) ** 2 / self._init_alpha_cumprod)
if self.shift != 1.0:
var = alpha_cumprod
logSNR = (var/(1-var)).log()
logSNR += 2 * torch.log(1.0 / torch.tensor(self.shift))
alpha_cumprod = logSNR.sigmoid()
alpha_cumprod = alpha_cumprod.clamp(0.0001, 0.9999)
return ((1 - alpha_cumprod) / alpha_cumprod) ** 0.5
def timestep(self, sigma):
var = 1 / ((sigma * sigma) + 1)
var = var.clamp(0, 1.0)
s, min_var = self.cosine_s.to(var.device), self._init_alpha_cumprod.to(var.device)
t = (((var * min_var) ** 0.5).acos() / (torch.pi * 0.5)) * (1 + s) - s
return t
def percent_to_sigma(self, percent):
if percent <= 0.0:
return 999999999.9
if percent >= 1.0:
return 0.0
percent = 1.0 - percent
return self.sigma(torch.tensor(percent))

201
comfy/ops.py

@ -1,40 +1,163 @@
"""
This file is part of ComfyUI.
Copyright (C) 2024 Stability AI
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU General Public License for more details.
You should have received a copy of the GNU General Public License
along with this program. If not, see <https://www.gnu.org/licenses/>.
"""
import torch import torch
from contextlib import contextmanager import comfy.model_management
class Linear(torch.nn.Linear): def cast_bias_weight(s, input):
def reset_parameters(self): bias = None
return None non_blocking = comfy.model_management.device_supports_non_blocking(input.device)
if s.bias is not None:
class Conv2d(torch.nn.Conv2d): bias = s.bias.to(device=input.device, dtype=input.dtype, non_blocking=non_blocking)
def reset_parameters(self): if s.bias_function is not None:
return None bias = s.bias_function(bias)
weight = s.weight.to(device=input.device, dtype=input.dtype, non_blocking=non_blocking)
class Conv3d(torch.nn.Conv3d): if s.weight_function is not None:
def reset_parameters(self): weight = s.weight_function(weight)
return None return weight, bias
def conv_nd(dims, *args, **kwargs): class CastWeightBiasOp:
if dims == 2: comfy_cast_weights = False
return Conv2d(*args, **kwargs) weight_function = None
elif dims == 3: bias_function = None
return Conv3d(*args, **kwargs)
else: class disable_weight_init:
raise ValueError(f"unsupported dimensions: {dims}") class Linear(torch.nn.Linear, CastWeightBiasOp):
def reset_parameters(self):
@contextmanager return None
def use_comfy_ops(device=None, dtype=None): # Kind of an ugly hack but I can't think of a better way
old_torch_nn_linear = torch.nn.Linear def forward_comfy_cast_weights(self, input):
force_device = device weight, bias = cast_bias_weight(self, input)
force_dtype = dtype return torch.nn.functional.linear(input, weight, bias)
def linear_with_dtype(in_features: int, out_features: int, bias: bool = True, device=None, dtype=None):
if force_device is not None: def forward(self, *args, **kwargs):
device = force_device if self.comfy_cast_weights:
if force_dtype is not None: return self.forward_comfy_cast_weights(*args, **kwargs)
dtype = force_dtype else:
return Linear(in_features, out_features, bias=bias, device=device, dtype=dtype) return super().forward(*args, **kwargs)
torch.nn.Linear = linear_with_dtype class Conv2d(torch.nn.Conv2d, CastWeightBiasOp):
try: def reset_parameters(self):
yield return None
finally:
torch.nn.Linear = old_torch_nn_linear def forward_comfy_cast_weights(self, input):
weight, bias = cast_bias_weight(self, input)
return self._conv_forward(input, weight, bias)
def forward(self, *args, **kwargs):
if self.comfy_cast_weights:
return self.forward_comfy_cast_weights(*args, **kwargs)
else:
return super().forward(*args, **kwargs)
class Conv3d(torch.nn.Conv3d, CastWeightBiasOp):
def reset_parameters(self):
return None
def forward_comfy_cast_weights(self, input):
weight, bias = cast_bias_weight(self, input)
return self._conv_forward(input, weight, bias)
def forward(self, *args, **kwargs):
if self.comfy_cast_weights:
return self.forward_comfy_cast_weights(*args, **kwargs)
else:
return super().forward(*args, **kwargs)
class GroupNorm(torch.nn.GroupNorm, CastWeightBiasOp):
def reset_parameters(self):
return None
def forward_comfy_cast_weights(self, input):
weight, bias = cast_bias_weight(self, input)
return torch.nn.functional.group_norm(input, self.num_groups, weight, bias, self.eps)
def forward(self, *args, **kwargs):
if self.comfy_cast_weights:
return self.forward_comfy_cast_weights(*args, **kwargs)
else:
return super().forward(*args, **kwargs)
class LayerNorm(torch.nn.LayerNorm, CastWeightBiasOp):
def reset_parameters(self):
return None
def forward_comfy_cast_weights(self, input):
if self.weight is not None:
weight, bias = cast_bias_weight(self, input)
else:
weight = None
bias = None
return torch.nn.functional.layer_norm(input, self.normalized_shape, weight, bias, self.eps)
def forward(self, *args, **kwargs):
if self.comfy_cast_weights:
return self.forward_comfy_cast_weights(*args, **kwargs)
else:
return super().forward(*args, **kwargs)
class ConvTranspose2d(torch.nn.ConvTranspose2d, CastWeightBiasOp):
def reset_parameters(self):
return None
def forward_comfy_cast_weights(self, input, output_size=None):
num_spatial_dims = 2
output_padding = self._output_padding(
input, output_size, self.stride, self.padding, self.kernel_size,
num_spatial_dims, self.dilation)
weight, bias = cast_bias_weight(self, input)
return torch.nn.functional.conv_transpose2d(
input, weight, bias, self.stride, self.padding,
output_padding, self.groups, self.dilation)
def forward(self, *args, **kwargs):
if self.comfy_cast_weights:
return self.forward_comfy_cast_weights(*args, **kwargs)
else:
return super().forward(*args, **kwargs)
@classmethod
def conv_nd(s, dims, *args, **kwargs):
if dims == 2:
return s.Conv2d(*args, **kwargs)
elif dims == 3:
return s.Conv3d(*args, **kwargs)
else:
raise ValueError(f"unsupported dimensions: {dims}")
class manual_cast(disable_weight_init):
class Linear(disable_weight_init.Linear):
comfy_cast_weights = True
class Conv2d(disable_weight_init.Conv2d):
comfy_cast_weights = True
class Conv3d(disable_weight_init.Conv3d):
comfy_cast_weights = True
class GroupNorm(disable_weight_init.GroupNorm):
comfy_cast_weights = True
class LayerNorm(disable_weight_init.LayerNorm):
comfy_cast_weights = True
class ConvTranspose2d(disable_weight_init.ConvTranspose2d):
comfy_cast_weights = True

12
comfy/sample.py

@ -28,7 +28,6 @@ def prepare_noise(latent_image, seed, noise_inds=None):
def prepare_mask(noise_mask, shape, device): def prepare_mask(noise_mask, shape, device):
"""ensures noise mask is of proper dimensions""" """ensures noise mask is of proper dimensions"""
noise_mask = torch.nn.functional.interpolate(noise_mask.reshape((-1, 1, noise_mask.shape[-2], noise_mask.shape[-1])), size=(shape[2], shape[3]), mode="bilinear") noise_mask = torch.nn.functional.interpolate(noise_mask.reshape((-1, 1, noise_mask.shape[-2], noise_mask.shape[-1])), size=(shape[2], shape[3]), mode="bilinear")
noise_mask = noise_mask.round()
noise_mask = torch.cat([noise_mask] * shape[1], dim=1) noise_mask = torch.cat([noise_mask] * shape[1], dim=1)
noise_mask = comfy.utils.repeat_to_batch_size(noise_mask, shape[0]) noise_mask = comfy.utils.repeat_to_batch_size(noise_mask, shape[0])
noise_mask = noise_mask.to(device) noise_mask = noise_mask.to(device)
@ -47,7 +46,8 @@ def convert_cond(cond):
temp = c[1].copy() temp = c[1].copy()
model_conds = temp.get("model_conds", {}) model_conds = temp.get("model_conds", {})
if c[0] is not None: if c[0] is not None:
model_conds["c_crossattn"] = comfy.conds.CONDCrossAttn(c[0]) model_conds["c_crossattn"] = comfy.conds.CONDCrossAttn(c[0]) #TODO: remove
temp["cross_attn"] = c[0]
temp["model_conds"] = model_conds temp["model_conds"] = model_conds
out.append(temp) out.append(temp)
return out return out
@ -98,10 +98,10 @@ def sample(model, noise, steps, cfg, sampler_name, scheduler, positive, negative
sampler = comfy.samplers.KSampler(real_model, steps=steps, device=model.load_device, sampler=sampler_name, scheduler=scheduler, denoise=denoise, model_options=model.model_options) sampler = comfy.samplers.KSampler(real_model, steps=steps, device=model.load_device, sampler=sampler_name, scheduler=scheduler, denoise=denoise, model_options=model.model_options)
samples = sampler.sample(noise, positive_copy, negative_copy, cfg=cfg, latent_image=latent_image, start_step=start_step, last_step=last_step, force_full_denoise=force_full_denoise, denoise_mask=noise_mask, sigmas=sigmas, callback=callback, disable_pbar=disable_pbar, seed=seed) samples = sampler.sample(noise, positive_copy, negative_copy, cfg=cfg, latent_image=latent_image, start_step=start_step, last_step=last_step, force_full_denoise=force_full_denoise, denoise_mask=noise_mask, sigmas=sigmas, callback=callback, disable_pbar=disable_pbar, seed=seed)
samples = samples.cpu() samples = samples.to(comfy.model_management.intermediate_device())
cleanup_additional_models(models) cleanup_additional_models(models)
cleanup_additional_models(set(get_models_from_cond(positive, "control") + get_models_from_cond(negative, "control"))) cleanup_additional_models(set(get_models_from_cond(positive_copy, "control") + get_models_from_cond(negative_copy, "control")))
return samples return samples
def sample_custom(model, noise, cfg, sampler, sigmas, positive, negative, latent_image, noise_mask=None, callback=None, disable_pbar=False, seed=None): def sample_custom(model, noise, cfg, sampler, sigmas, positive, negative, latent_image, noise_mask=None, callback=None, disable_pbar=False, seed=None):
@ -111,8 +111,8 @@ def sample_custom(model, noise, cfg, sampler, sigmas, positive, negative, latent
sigmas = sigmas.to(model.load_device) sigmas = sigmas.to(model.load_device)
samples = comfy.samplers.sample(real_model, noise, positive_copy, negative_copy, cfg, model.load_device, sampler, sigmas, model_options=model.model_options, latent_image=latent_image, denoise_mask=noise_mask, callback=callback, disable_pbar=disable_pbar, seed=seed) samples = comfy.samplers.sample(real_model, noise, positive_copy, negative_copy, cfg, model.load_device, sampler, sigmas, model_options=model.model_options, latent_image=latent_image, denoise_mask=noise_mask, callback=callback, disable_pbar=disable_pbar, seed=seed)
samples = samples.cpu() samples = samples.to(comfy.model_management.intermediate_device())
cleanup_additional_models(models) cleanup_additional_models(models)
cleanup_additional_models(set(get_models_from_cond(positive, "control") + get_models_from_cond(negative, "control"))) cleanup_additional_models(set(get_models_from_cond(positive_copy, "control") + get_models_from_cond(negative_copy, "control")))
return samples return samples

532
comfy/samplers.py

@ -1,260 +1,266 @@
from .k_diffusion import sampling as k_diffusion_sampling from .k_diffusion import sampling as k_diffusion_sampling
from .extra_samplers import uni_pc from .extra_samplers import uni_pc
import torch import torch
import enum import collections
from comfy import model_management from comfy import model_management
import math import math
from comfy import model_base import logging
import comfy.utils
import comfy.conds def get_area_and_mult(conds, x_in, timestep_in):
area = (x_in.shape[2], x_in.shape[3], 0, 0)
strength = 1.0
if 'timestep_start' in conds:
timestep_start = conds['timestep_start']
if timestep_in[0] > timestep_start:
return None
if 'timestep_end' in conds:
timestep_end = conds['timestep_end']
if timestep_in[0] < timestep_end:
return None
if 'area' in conds:
area = conds['area']
if 'strength' in conds:
strength = conds['strength']
input_x = x_in[:,:,area[2]:area[0] + area[2],area[3]:area[1] + area[3]]
if 'mask' in conds:
# Scale the mask to the size of the input
# The mask should have been resized as we began the sampling process
mask_strength = 1.0
if "mask_strength" in conds:
mask_strength = conds["mask_strength"]
mask = conds['mask']
assert(mask.shape[1] == x_in.shape[2])
assert(mask.shape[2] == x_in.shape[3])
mask = mask[:,area[2]:area[0] + area[2],area[3]:area[1] + area[3]] * mask_strength
mask = mask.unsqueeze(1).repeat(input_x.shape[0] // mask.shape[0], input_x.shape[1], 1, 1)
else:
mask = torch.ones_like(input_x)
mult = mask * strength
if 'mask' not in conds:
rr = 8
if area[2] != 0:
for t in range(rr):
mult[:,:,t:1+t,:] *= ((1.0/rr) * (t + 1))
if (area[0] + area[2]) < x_in.shape[2]:
for t in range(rr):
mult[:,:,area[0] - 1 - t:area[0] - t,:] *= ((1.0/rr) * (t + 1))
if area[3] != 0:
for t in range(rr):
mult[:,:,:,t:1+t] *= ((1.0/rr) * (t + 1))
if (area[1] + area[3]) < x_in.shape[3]:
for t in range(rr):
mult[:,:,:,area[1] - 1 - t:area[1] - t] *= ((1.0/rr) * (t + 1))
conditioning = {}
model_conds = conds["model_conds"]
for c in model_conds:
conditioning[c] = model_conds[c].process_cond(batch_size=x_in.shape[0], device=x_in.device, area=area)
control = conds.get('control', None)
patches = None
if 'gligen' in conds:
gligen = conds['gligen']
patches = {}
gligen_type = gligen[0]
gligen_model = gligen[1]
if gligen_type == "position":
gligen_patch = gligen_model.model.set_position(input_x.shape, gligen[2], input_x.device)
else:
gligen_patch = gligen_model.model.set_empty(input_x.shape, input_x.device)
patches['middle_patch'] = [gligen_patch]
cond_obj = collections.namedtuple('cond_obj', ['input_x', 'mult', 'conditioning', 'area', 'control', 'patches'])
return cond_obj(input_x, mult, conditioning, area, control, patches)
def cond_equal_size(c1, c2):
if c1 is c2:
return True
if c1.keys() != c2.keys():
return False
for k in c1:
if not c1[k].can_concat(c2[k]):
return False
return True
def can_concat_cond(c1, c2):
if c1.input_x.shape != c2.input_x.shape:
return False
def objects_concatable(obj1, obj2):
if (obj1 is None) != (obj2 is None):
return False
if obj1 is not None:
if obj1 is not obj2:
return False
return True
if not objects_concatable(c1.control, c2.control):
return False
#The main sampling function shared by all the samplers if not objects_concatable(c1.patches, c2.patches):
#Returns denoised return False
def sampling_function(model, x, timestep, uncond, cond, cond_scale, model_options={}, seed=None):
def get_area_and_mult(conds, x_in, timestep_in):
area = (x_in.shape[2], x_in.shape[3], 0, 0)
strength = 1.0
if 'timestep_start' in conds:
timestep_start = conds['timestep_start']
if timestep_in[0] > timestep_start:
return None
if 'timestep_end' in conds:
timestep_end = conds['timestep_end']
if timestep_in[0] < timestep_end:
return None
if 'area' in conds:
area = conds['area']
if 'strength' in conds:
strength = conds['strength']
input_x = x_in[:,:,area[2]:area[0] + area[2],area[3]:area[1] + area[3]]
if 'mask' in conds:
# Scale the mask to the size of the input
# The mask should have been resized as we began the sampling process
mask_strength = 1.0
if "mask_strength" in conds:
mask_strength = conds["mask_strength"]
mask = conds['mask']
assert(mask.shape[1] == x_in.shape[2])
assert(mask.shape[2] == x_in.shape[3])
mask = mask[:,area[2]:area[0] + area[2],area[3]:area[1] + area[3]] * mask_strength
mask = mask.unsqueeze(1).repeat(input_x.shape[0] // mask.shape[0], input_x.shape[1], 1, 1)
else:
mask = torch.ones_like(input_x)
mult = mask * strength
if 'mask' not in conds:
rr = 8
if area[2] != 0:
for t in range(rr):
mult[:,:,t:1+t,:] *= ((1.0/rr) * (t + 1))
if (area[0] + area[2]) < x_in.shape[2]:
for t in range(rr):
mult[:,:,area[0] - 1 - t:area[0] - t,:] *= ((1.0/rr) * (t + 1))
if area[3] != 0:
for t in range(rr):
mult[:,:,:,t:1+t] *= ((1.0/rr) * (t + 1))
if (area[1] + area[3]) < x_in.shape[3]:
for t in range(rr):
mult[:,:,:,area[1] - 1 - t:area[1] - t] *= ((1.0/rr) * (t + 1))
conditionning = {}
model_conds = conds["model_conds"]
for c in model_conds:
conditionning[c] = model_conds[c].process_cond(batch_size=x_in.shape[0], device=x_in.device, area=area)
control = None
if 'control' in conds:
control = conds['control']
patches = None
if 'gligen' in conds:
gligen = conds['gligen']
patches = {}
gligen_type = gligen[0]
gligen_model = gligen[1]
if gligen_type == "position":
gligen_patch = gligen_model.model.set_position(input_x.shape, gligen[2], input_x.device)
else:
gligen_patch = gligen_model.model.set_empty(input_x.shape, input_x.device)
patches['middle_patch'] = [gligen_patch] return cond_equal_size(c1.conditioning, c2.conditioning)
return (input_x, mult, conditionning, area, control, patches) def cond_cat(c_list):
c_crossattn = []
c_concat = []
c_adm = []
crossattn_max_len = 0
def cond_equal_size(c1, c2): temp = {}
if c1 is c2: for x in c_list:
return True for k in x:
if c1.keys() != c2.keys(): cur = temp.get(k, [])
return False cur.append(x[k])
for k in c1: temp[k] = cur
if not c1[k].can_concat(c2[k]):
return False
return True
def can_concat_cond(c1, c2): out = {}
if c1[0].shape != c2[0].shape: for k in temp:
return False conds = temp[k]
out[k] = conds[0].concat(conds[1:])
#control return out
if (c1[4] is None) != (c2[4] is None):
return False
if c1[4] is not None:
if c1[4] is not c2[4]:
return False
#patches def calc_cond_uncond_batch(model, cond, uncond, x_in, timestep, model_options):
if (c1[5] is None) != (c2[5] is None): out_cond = torch.zeros_like(x_in)
return False out_count = torch.ones_like(x_in) * 1e-37
if (c1[5] is not None):
if c1[5] is not c2[5]:
return False
return cond_equal_size(c1[2], c2[2])
def cond_cat(c_list):
c_crossattn = []
c_concat = []
c_adm = []
crossattn_max_len = 0
temp = {}
for x in c_list:
for k in x:
cur = temp.get(k, [])
cur.append(x[k])
temp[k] = cur
out = {}
for k in temp:
conds = temp[k]
out[k] = conds[0].concat(conds[1:])
return out
def calc_cond_uncond_batch(model, cond, uncond, x_in, timestep, model_options):
out_cond = torch.zeros_like(x_in)
out_count = torch.ones_like(x_in) * 1e-37
out_uncond = torch.zeros_like(x_in)
out_uncond_count = torch.ones_like(x_in) * 1e-37
COND = 0
UNCOND = 1
to_run = []
for x in cond:
p = get_area_and_mult(x, x_in, timestep)
if p is None:
continue
to_run += [(p, COND)]
if uncond is not None:
for x in uncond:
p = get_area_and_mult(x, x_in, timestep)
if p is None:
continue
to_run += [(p, UNCOND)]
while len(to_run) > 0:
first = to_run[0]
first_shape = first[0][0].shape
to_batch_temp = []
for x in range(len(to_run)):
if can_concat_cond(to_run[x][0], first[0]):
to_batch_temp += [x]
to_batch_temp.reverse()
to_batch = to_batch_temp[:1]
free_memory = model_management.get_free_memory(x_in.device)
for i in range(1, len(to_batch_temp) + 1):
batch_amount = to_batch_temp[:len(to_batch_temp)//i]
input_shape = [len(batch_amount) * first_shape[0]] + list(first_shape)[1:]
if model.memory_required(input_shape) < free_memory:
to_batch = batch_amount
break
input_x = []
mult = []
c = []
cond_or_uncond = []
area = []
control = None
patches = None
for x in to_batch:
o = to_run.pop(x)
p = o[0]
input_x += [p[0]]
mult += [p[1]]
c += [p[2]]
area += [p[3]]
cond_or_uncond += [o[1]]
control = p[4]
patches = p[5]
batch_chunks = len(cond_or_uncond)
input_x = torch.cat(input_x)
c = cond_cat(c)
timestep_ = torch.cat([timestep] * batch_chunks)
if control is not None:
c['control'] = control.get_control(input_x, timestep_, c, len(cond_or_uncond))
transformer_options = {}
if 'transformer_options' in model_options:
transformer_options = model_options['transformer_options'].copy()
if patches is not None:
if "patches" in transformer_options:
cur_patches = transformer_options["patches"].copy()
for p in patches:
if p in cur_patches:
cur_patches[p] = cur_patches[p] + patches[p]
else:
cur_patches[p] = patches[p]
else:
transformer_options["patches"] = patches
transformer_options["cond_or_uncond"] = cond_or_uncond[:] out_uncond = torch.zeros_like(x_in)
transformer_options["sigmas"] = timestep out_uncond_count = torch.ones_like(x_in) * 1e-37
c['transformer_options'] = transformer_options COND = 0
UNCOND = 1
if 'model_function_wrapper' in model_options: to_run = []
output = model_options['model_function_wrapper'](model.apply_model, {"input": input_x, "timestep": timestep_, "c": c, "cond_or_uncond": cond_or_uncond}).chunk(batch_chunks) for x in cond:
else: p = get_area_and_mult(x, x_in, timestep)
output = model.apply_model(input_x, timestep_, **c).chunk(batch_chunks) if p is None:
del input_x continue
for o in range(batch_chunks): to_run += [(p, COND)]
if cond_or_uncond[o] == COND: if uncond is not None:
out_cond[:,:,area[o][2]:area[o][0] + area[o][2],area[o][3]:area[o][1] + area[o][3]] += output[o] * mult[o] for x in uncond:
out_count[:,:,area[o][2]:area[o][0] + area[o][2],area[o][3]:area[o][1] + area[o][3]] += mult[o] p = get_area_and_mult(x, x_in, timestep)
if p is None:
continue
to_run += [(p, UNCOND)]
while len(to_run) > 0:
first = to_run[0]
first_shape = first[0][0].shape
to_batch_temp = []
for x in range(len(to_run)):
if can_concat_cond(to_run[x][0], first[0]):
to_batch_temp += [x]
to_batch_temp.reverse()
to_batch = to_batch_temp[:1]
free_memory = model_management.get_free_memory(x_in.device)
for i in range(1, len(to_batch_temp) + 1):
batch_amount = to_batch_temp[:len(to_batch_temp)//i]
input_shape = [len(batch_amount) * first_shape[0]] + list(first_shape)[1:]
if model.memory_required(input_shape) < free_memory:
to_batch = batch_amount
break
input_x = []
mult = []
c = []
cond_or_uncond = []
area = []
control = None
patches = None
for x in to_batch:
o = to_run.pop(x)
p = o[0]
input_x.append(p.input_x)
mult.append(p.mult)
c.append(p.conditioning)
area.append(p.area)
cond_or_uncond.append(o[1])
control = p.control
patches = p.patches
batch_chunks = len(cond_or_uncond)
input_x = torch.cat(input_x)
c = cond_cat(c)
timestep_ = torch.cat([timestep] * batch_chunks)
if control is not None:
c['control'] = control.get_control(input_x, timestep_, c, len(cond_or_uncond))
transformer_options = {}
if 'transformer_options' in model_options:
transformer_options = model_options['transformer_options'].copy()
if patches is not None:
if "patches" in transformer_options:
cur_patches = transformer_options["patches"].copy()
for p in patches:
if p in cur_patches:
cur_patches[p] = cur_patches[p] + patches[p]
else: else:
out_uncond[:,:,area[o][2]:area[o][0] + area[o][2],area[o][3]:area[o][1] + area[o][3]] += output[o] * mult[o] cur_patches[p] = patches[p]
out_uncond_count[:,:,area[o][2]:area[o][0] + area[o][2],area[o][3]:area[o][1] + area[o][3]] += mult[o] transformer_options["patches"] = cur_patches
del mult else:
transformer_options["patches"] = patches
out_cond /= out_count transformer_options["cond_or_uncond"] = cond_or_uncond[:]
del out_count transformer_options["sigmas"] = timestep
out_uncond /= out_uncond_count
del out_uncond_count
return out_cond, out_uncond
c['transformer_options'] = transformer_options
if math.isclose(cond_scale, 1.0): if 'model_function_wrapper' in model_options:
uncond = None output = model_options['model_function_wrapper'](model.apply_model, {"input": input_x, "timestep": timestep_, "c": c, "cond_or_uncond": cond_or_uncond}).chunk(batch_chunks)
else:
output = model.apply_model(input_x, timestep_, **c).chunk(batch_chunks)
del input_x
for o in range(batch_chunks):
if cond_or_uncond[o] == COND:
out_cond[:,:,area[o][2]:area[o][0] + area[o][2],area[o][3]:area[o][1] + area[o][3]] += output[o] * mult[o]
out_count[:,:,area[o][2]:area[o][0] + area[o][2],area[o][3]:area[o][1] + area[o][3]] += mult[o]
else:
out_uncond[:,:,area[o][2]:area[o][0] + area[o][2],area[o][3]:area[o][1] + area[o][3]] += output[o] * mult[o]
out_uncond_count[:,:,area[o][2]:area[o][0] + area[o][2],area[o][3]:area[o][1] + area[o][3]] += mult[o]
del mult
out_cond /= out_count
del out_count
out_uncond /= out_uncond_count
del out_uncond_count
return out_cond, out_uncond
#The main sampling function shared by all the samplers
#Returns denoised
def sampling_function(model, x, timestep, uncond, cond, cond_scale, model_options={}, seed=None):
if math.isclose(cond_scale, 1.0) and model_options.get("disable_cfg1_optimization", False) == False:
uncond_ = None
else:
uncond_ = uncond
cond, uncond = calc_cond_uncond_batch(model, cond, uncond, x, timestep, model_options) cond_pred, uncond_pred = calc_cond_uncond_batch(model, cond, uncond_, x, timestep, model_options)
if "sampler_cfg_function" in model_options: if "sampler_cfg_function" in model_options:
args = {"cond": x - cond, "uncond": x - uncond, "cond_scale": cond_scale, "timestep": timestep, "input": x, "sigma": timestep} args = {"cond": x - cond_pred, "uncond": x - uncond_pred, "cond_scale": cond_scale, "timestep": timestep, "input": x, "sigma": timestep,
return x - model_options["sampler_cfg_function"](args) "cond_denoised": cond_pred, "uncond_denoised": uncond_pred, "model": model, "model_options": model_options}
cfg_result = x - model_options["sampler_cfg_function"](args)
else: else:
return uncond + (cond - uncond) * cond_scale cfg_result = uncond_pred + (cond_pred - uncond_pred) * cond_scale
for fn in model_options.get("sampler_post_cfg_function", []):
args = {"denoised": cfg_result, "cond": cond, "uncond": uncond, "model": model, "uncond_denoised": uncond_pred, "cond_denoised": cond_pred,
"sigma": timestep, "model_options": model_options, "input": x}
cfg_result = fn(args)
return cfg_result
class CFGNoisePredictor(torch.nn.Module): class CFGNoisePredictor(torch.nn.Module):
def __init__(self, model): def __init__(self, model):
@ -267,19 +273,19 @@ class CFGNoisePredictor(torch.nn.Module):
return self.apply_model(*args, **kwargs) return self.apply_model(*args, **kwargs)
class KSamplerX0Inpaint(torch.nn.Module): class KSamplerX0Inpaint(torch.nn.Module):
def __init__(self, model): def __init__(self, model, sigmas):
super().__init__() super().__init__()
self.inner_model = model self.inner_model = model
self.sigmas = sigmas
def forward(self, x, sigma, uncond, cond, cond_scale, denoise_mask, model_options={}, seed=None): def forward(self, x, sigma, uncond, cond, cond_scale, denoise_mask, model_options={}, seed=None):
if denoise_mask is not None: if denoise_mask is not None:
if "denoise_mask_function" in model_options:
denoise_mask = model_options["denoise_mask_function"](sigma, denoise_mask, extra_options={"model": self.inner_model, "sigmas": self.sigmas})
latent_mask = 1. - denoise_mask latent_mask = 1. - denoise_mask
x = x * denoise_mask + (self.latent_image + self.noise * sigma.reshape([sigma.shape[0]] + [1] * (len(self.noise.shape) - 1))) * latent_mask x = x * denoise_mask + self.inner_model.inner_model.model_sampling.noise_scaling(sigma.reshape([sigma.shape[0]] + [1] * (len(self.noise.shape) - 1)), self.noise, self.latent_image) * latent_mask
out = self.inner_model(x, sigma, cond=cond, uncond=uncond, cond_scale=cond_scale, model_options=model_options, seed=seed) out = self.inner_model(x, sigma, cond=cond, uncond=uncond, cond_scale=cond_scale, model_options=model_options, seed=seed)
if denoise_mask is not None: if denoise_mask is not None:
out *= denoise_mask out = out * denoise_mask + self.latent_image * latent_mask
if denoise_mask is not None:
out += self.latent_image * latent_mask
return out return out
def simple_scheduler(model, steps): def simple_scheduler(model, steps):
@ -294,7 +300,7 @@ def simple_scheduler(model, steps):
def ddim_scheduler(model, steps): def ddim_scheduler(model, steps):
s = model.model_sampling s = model.model_sampling
sigs = [] sigs = []
ss = len(s.sigmas) // steps ss = max(len(s.sigmas) // steps, 1)
x = 1 x = 1
while x < len(s.sigmas): while x < len(s.sigmas):
sigs += [float(s.sigmas[x])] sigs += [float(s.sigmas[x])]
@ -512,14 +518,6 @@ class Sampler:
sigma = float(sigmas[0]) sigma = float(sigmas[0])
return math.isclose(max_sigma, sigma, rel_tol=1e-05) or sigma > max_sigma return math.isclose(max_sigma, sigma, rel_tol=1e-05) or sigma > max_sigma
class UNIPC(Sampler):
def sample(self, model_wrap, sigmas, extra_args, callback, noise, latent_image=None, denoise_mask=None, disable_pbar=False):
return uni_pc.sample_unipc(model_wrap, noise, latent_image, sigmas, max_denoise=self.max_denoise(model_wrap, sigmas), extra_args=extra_args, noise_mask=denoise_mask, callback=callback, disable=disable_pbar)
class UNIPCBH2(Sampler):
def sample(self, model_wrap, sigmas, extra_args, callback, noise, latent_image=None, denoise_mask=None, disable_pbar=False):
return uni_pc.sample_unipc(model_wrap, noise, latent_image, sigmas, max_denoise=self.max_denoise(model_wrap, sigmas), extra_args=extra_args, noise_mask=denoise_mask, callback=callback, variant='bh2', disable=disable_pbar)
KSAMPLER_NAMES = ["euler", "euler_ancestral", "heun", "heunpp2","dpm_2", "dpm_2_ancestral", KSAMPLER_NAMES = ["euler", "euler_ancestral", "heun", "heunpp2","dpm_2", "dpm_2_ancestral",
"lms", "dpm_fast", "dpm_adaptive", "dpmpp_2s_ancestral", "dpmpp_sde", "dpmpp_sde_gpu", "lms", "dpm_fast", "dpm_adaptive", "dpmpp_2s_ancestral", "dpmpp_sde", "dpmpp_sde_gpu",
"dpmpp_2m", "dpmpp_2m_sde", "dpmpp_2m_sde_gpu", "dpmpp_3m_sde", "dpmpp_3m_sde_gpu", "ddpm", "lcm"] "dpmpp_2m", "dpmpp_2m_sde", "dpmpp_2m_sde_gpu", "dpmpp_3m_sde", "dpmpp_3m_sde_gpu", "ddpm", "lcm"]
@ -532,7 +530,7 @@ class KSAMPLER(Sampler):
def sample(self, model_wrap, sigmas, extra_args, callback, noise, latent_image=None, denoise_mask=None, disable_pbar=False): def sample(self, model_wrap, sigmas, extra_args, callback, noise, latent_image=None, denoise_mask=None, disable_pbar=False):
extra_args["denoise_mask"] = denoise_mask extra_args["denoise_mask"] = denoise_mask
model_k = KSamplerX0Inpaint(model_wrap) model_k = KSamplerX0Inpaint(model_wrap, sigmas)
model_k.latent_image = latent_image model_k.latent_image = latent_image
if self.inpaint_options.get("random", False): #TODO: Should this be the default? if self.inpaint_options.get("random", False): #TODO: Should this be the default?
generator = torch.manual_seed(extra_args.get("seed", 41) + 1) generator = torch.manual_seed(extra_args.get("seed", 41) + 1)
@ -540,20 +538,15 @@ class KSAMPLER(Sampler):
else: else:
model_k.noise = noise model_k.noise = noise
if self.max_denoise(model_wrap, sigmas): noise = model_wrap.inner_model.model_sampling.noise_scaling(sigmas[0], noise, latent_image, self.max_denoise(model_wrap, sigmas))
noise = noise * torch.sqrt(1.0 + sigmas[0] ** 2.0)
else:
noise = noise * sigmas[0]
k_callback = None k_callback = None
total_steps = len(sigmas) - 1 total_steps = len(sigmas) - 1
if callback is not None: if callback is not None:
k_callback = lambda x: callback(x["i"], x["denoised"], x["x"], total_steps) k_callback = lambda x: callback(x["i"], x["denoised"], x["x"], total_steps)
if latent_image is not None:
noise += latent_image
samples = self.sampler_function(model_k, noise, sigmas, extra_args=extra_args, callback=k_callback, disable=disable_pbar, **self.extra_options) samples = self.sampler_function(model_k, noise, sigmas, extra_args=extra_args, callback=k_callback, disable=disable_pbar, **self.extra_options)
samples = model_wrap.inner_model.model_sampling.inverse_noise_scaling(sigmas[-1], samples)
return samples return samples
@ -567,11 +560,11 @@ def ksampler(sampler_name, extra_options={}, inpaint_options={}):
return k_diffusion_sampling.sample_dpm_fast(model, noise, sigma_min, sigmas[0], total_steps, extra_args=extra_args, callback=callback, disable=disable) return k_diffusion_sampling.sample_dpm_fast(model, noise, sigma_min, sigmas[0], total_steps, extra_args=extra_args, callback=callback, disable=disable)
sampler_function = dpm_fast_function sampler_function = dpm_fast_function
elif sampler_name == "dpm_adaptive": elif sampler_name == "dpm_adaptive":
def dpm_adaptive_function(model, noise, sigmas, extra_args, callback, disable): def dpm_adaptive_function(model, noise, sigmas, extra_args, callback, disable, **extra_options):
sigma_min = sigmas[-1] sigma_min = sigmas[-1]
if sigma_min == 0: if sigma_min == 0:
sigma_min = sigmas[-2] sigma_min = sigmas[-2]
return k_diffusion_sampling.sample_dpm_adaptive(model, noise, sigma_min, sigmas[0], extra_args=extra_args, callback=callback, disable=disable) return k_diffusion_sampling.sample_dpm_adaptive(model, noise, sigma_min, sigmas[0], extra_args=extra_args, callback=callback, disable=disable, **extra_options)
sampler_function = dpm_adaptive_function sampler_function = dpm_adaptive_function
else: else:
sampler_function = getattr(k_diffusion_sampling, "sample_{}".format(sampler_name)) sampler_function = getattr(k_diffusion_sampling, "sample_{}".format(sampler_name))
@ -594,6 +587,13 @@ def sample(model, noise, positive, negative, cfg, device, sampler, sigmas, model
calculate_start_end_timesteps(model, negative) calculate_start_end_timesteps(model, negative)
calculate_start_end_timesteps(model, positive) calculate_start_end_timesteps(model, positive)
if latent_image is not None and torch.count_nonzero(latent_image) > 0: #Don't shift the empty latent image.
latent_image = model.process_latent_in(latent_image)
if hasattr(model, 'extra_conds'):
positive = encode_model_conds(model.extra_conds, positive, noise, device, "positive", latent_image=latent_image, denoise_mask=denoise_mask, seed=seed)
negative = encode_model_conds(model.extra_conds, negative, noise, device, "negative", latent_image=latent_image, denoise_mask=denoise_mask, seed=seed)
#make sure each cond area has an opposite one with the same area #make sure each cond area has an opposite one with the same area
for c in positive: for c in positive:
create_cond_with_same_area_if_none(negative, c) create_cond_with_same_area_if_none(negative, c)
@ -605,13 +605,6 @@ def sample(model, noise, positive, negative, cfg, device, sampler, sigmas, model
apply_empty_x_to_equal_area(list(filter(lambda c: c.get('control_apply_to_uncond', False) == True, positive)), negative, 'control', lambda cond_cnets, x: cond_cnets[x]) apply_empty_x_to_equal_area(list(filter(lambda c: c.get('control_apply_to_uncond', False) == True, positive)), negative, 'control', lambda cond_cnets, x: cond_cnets[x])
apply_empty_x_to_equal_area(positive, negative, 'gligen', lambda cond_cnets, x: cond_cnets[x]) apply_empty_x_to_equal_area(positive, negative, 'gligen', lambda cond_cnets, x: cond_cnets[x])
if latent_image is not None:
latent_image = model.process_latent_in(latent_image)
if hasattr(model, 'extra_conds'):
positive = encode_model_conds(model.extra_conds, positive, noise, device, "positive", latent_image=latent_image, denoise_mask=denoise_mask)
negative = encode_model_conds(model.extra_conds, negative, noise, device, "negative", latent_image=latent_image, denoise_mask=denoise_mask)
extra_args = {"cond":positive, "uncond":negative, "cond_scale": cfg, "model_options": model_options, "seed":seed} extra_args = {"cond":positive, "uncond":negative, "cond_scale": cfg, "model_options": model_options, "seed":seed}
samples = sampler.sample(model_wrap, sigmas, extra_args, callback, noise, latent_image, denoise_mask, disable_pbar) samples = sampler.sample(model_wrap, sigmas, extra_args, callback, noise, latent_image, denoise_mask, disable_pbar)
@ -634,14 +627,14 @@ def calculate_sigmas_scheduler(model, scheduler_name, steps):
elif scheduler_name == "sgm_uniform": elif scheduler_name == "sgm_uniform":
sigmas = normal_scheduler(model, steps, sgm=True) sigmas = normal_scheduler(model, steps, sgm=True)
else: else:
print("error invalid scheduler", self.scheduler) logging.error("error invalid scheduler {}".format(scheduler_name))
return sigmas return sigmas
def sampler_object(name): def sampler_object(name):
if name == "uni_pc": if name == "uni_pc":
sampler = UNIPC() sampler = KSAMPLER(uni_pc.sample_unipc)
elif name == "uni_pc_bh2": elif name == "uni_pc_bh2":
sampler = UNIPCBH2() sampler = KSAMPLER(uni_pc.sample_unipc_bh2)
elif name == "ddim": elif name == "ddim":
sampler = ksampler("euler", inpaint_options={"random": True}) sampler = ksampler("euler", inpaint_options={"random": True})
else: else:
@ -651,6 +644,7 @@ def sampler_object(name):
class KSampler: class KSampler:
SCHEDULERS = SCHEDULER_NAMES SCHEDULERS = SCHEDULER_NAMES
SAMPLERS = SAMPLER_NAMES SAMPLERS = SAMPLER_NAMES
DISCARD_PENULTIMATE_SIGMA_SAMPLERS = set(('dpm_2', 'dpm_2_ancestral', 'uni_pc', 'uni_pc_bh2'))
def __init__(self, model, steps, device, sampler=None, scheduler=None, denoise=None, model_options={}): def __init__(self, model, steps, device, sampler=None, scheduler=None, denoise=None, model_options={}):
self.model = model self.model = model
@ -669,7 +663,7 @@ class KSampler:
sigmas = None sigmas = None
discard_penultimate_sigma = False discard_penultimate_sigma = False
if self.sampler in ['dpm_2', 'dpm_2_ancestral', 'uni_pc', 'uni_pc_bh2']: if self.sampler in self.DISCARD_PENULTIMATE_SIGMA_SAMPLERS:
steps += 1 steps += 1
discard_penultimate_sigma = True discard_penultimate_sigma = True

215
comfy/sd.py

@ -1,10 +1,12 @@
import torch import torch
import contextlib from enum import Enum
import math import logging
from comfy import model_management from comfy import model_management
from .ldm.util import instantiate_from_config
from .ldm.models.autoencoder import AutoencoderKL, AutoencodingEngine from .ldm.models.autoencoder import AutoencoderKL, AutoencodingEngine
from .ldm.cascade.stage_a import StageA
from .ldm.cascade.stage_c_coder import StageC_coder
import yaml import yaml
import comfy.utils import comfy.utils
@ -36,7 +38,7 @@ def load_model_weights(model, sd):
w = sd.pop(x) w = sd.pop(x)
del w del w
if len(m) > 0: if len(m) > 0:
print("missing", m) logging.warning("missing {}".format(m))
return model return model
def load_clip_weights(model, sd): def load_clip_weights(model, sd):
@ -51,7 +53,7 @@ def load_clip_weights(model, sd):
if ids.dtype == torch.float32: if ids.dtype == torch.float32:
sd['cond_stage_model.transformer.text_model.embeddings.position_ids'] = ids.round() sd['cond_stage_model.transformer.text_model.embeddings.position_ids'] = ids.round()
sd = comfy.utils.transformers_convert(sd, "cond_stage_model.model.", "cond_stage_model.transformer.text_model.", 24) sd = comfy.utils.clip_text_transformers_convert(sd, "cond_stage_model.model.", "cond_stage_model.transformer.")
return load_model_weights(model, sd) return load_model_weights(model, sd)
@ -80,7 +82,7 @@ def load_lora_for_models(model, clip, lora, strength_model, strength_clip):
k1 = set(k1) k1 = set(k1)
for x in loaded: for x in loaded:
if (x not in k) and (x not in k1): if (x not in k) and (x not in k1):
print("NOT LOADED", x) logging.warning("NOT LOADED {}".format(x))
return (new_modelpatcher, new_clip) return (new_modelpatcher, new_clip)
@ -122,10 +124,13 @@ class CLIP:
return self.tokenizer.tokenize_with_weights(text, return_word_ids) return self.tokenizer.tokenize_with_weights(text, return_word_ids)
def encode_from_tokens(self, tokens, return_pooled=False): def encode_from_tokens(self, tokens, return_pooled=False):
self.cond_stage_model.reset_clip_options()
if self.layer_idx is not None: if self.layer_idx is not None:
self.cond_stage_model.clip_layer(self.layer_idx) self.cond_stage_model.set_clip_options({"layer": self.layer_idx})
else:
self.cond_stage_model.reset_clip_layer() if return_pooled == "unprojected":
self.cond_stage_model.set_clip_options({"projected_pooled": False})
self.load_model() self.load_model()
cond, pooled = self.cond_stage_model.encode_token_weights(tokens) cond, pooled = self.cond_stage_model.encode_token_weights(tokens)
@ -137,8 +142,11 @@ class CLIP:
tokens = self.tokenize(text) tokens = self.tokenize(text)
return self.encode_from_tokens(tokens) return self.encode_from_tokens(tokens)
def load_sd(self, sd): def load_sd(self, sd, full_model=False):
return self.cond_stage_model.load_sd(sd) if full_model:
return self.cond_stage_model.load_state_dict(sd, strict=False)
else:
return self.cond_stage_model.load_sd(sd)
def get_sd(self): def get_sd(self):
return self.cond_stage_model.state_dict() return self.cond_stage_model.state_dict()
@ -151,12 +159,17 @@ class CLIP:
return self.patcher.get_key_patches() return self.patcher.get_key_patches()
class VAE: class VAE:
def __init__(self, sd=None, device=None, config=None): def __init__(self, sd=None, device=None, config=None, dtype=None):
if 'decoder.up_blocks.0.resnets.0.norm1.weight' in sd.keys(): #diffusers format if 'decoder.up_blocks.0.resnets.0.norm1.weight' in sd.keys(): #diffusers format
sd = diffusers_convert.convert_vae_state_dict(sd) sd = diffusers_convert.convert_vae_state_dict(sd)
self.memory_used_encode = lambda shape, dtype: (1767 * shape[2] * shape[3]) * model_management.dtype_size(dtype) #These are for AutoencoderKL and need tweaking (should be lower) self.memory_used_encode = lambda shape, dtype: (1767 * shape[2] * shape[3]) * model_management.dtype_size(dtype) #These are for AutoencoderKL and need tweaking (should be lower)
self.memory_used_decode = lambda shape, dtype: (2178 * shape[2] * shape[3] * 64) * model_management.dtype_size(dtype) self.memory_used_decode = lambda shape, dtype: (2178 * shape[2] * shape[3] * 64) * model_management.dtype_size(dtype)
self.downscale_ratio = 8
self.upscale_ratio = 8
self.latent_channels = 4
self.process_input = lambda image: image * 2.0 - 1.0
self.process_output = lambda image: torch.clamp((image + 1.0) / 2.0, min=0.0, max=1.0)
if config is None: if config is None:
if "decoder.mid.block_1.mix_factor" in sd: if "decoder.mid.block_1.mix_factor" in sd:
@ -169,9 +182,43 @@ class VAE:
decoder_config={'target': "comfy.ldm.modules.temporal_ae.VideoDecoder", 'params': decoder_config}) decoder_config={'target': "comfy.ldm.modules.temporal_ae.VideoDecoder", 'params': decoder_config})
elif "taesd_decoder.1.weight" in sd: elif "taesd_decoder.1.weight" in sd:
self.first_stage_model = comfy.taesd.taesd.TAESD() self.first_stage_model = comfy.taesd.taesd.TAESD()
elif "vquantizer.codebook.weight" in sd: #VQGan: stage a of stable cascade
self.first_stage_model = StageA()
self.downscale_ratio = 4
self.upscale_ratio = 4
#TODO
#self.memory_used_encode
#self.memory_used_decode
self.process_input = lambda image: image
self.process_output = lambda image: image
elif "backbone.1.0.block.0.1.num_batches_tracked" in sd: #effnet: encoder for stage c latent of stable cascade
self.first_stage_model = StageC_coder()
self.downscale_ratio = 32
self.latent_channels = 16
new_sd = {}
for k in sd:
new_sd["encoder.{}".format(k)] = sd[k]
sd = new_sd
elif "blocks.11.num_batches_tracked" in sd: #previewer: decoder for stage c latent of stable cascade
self.first_stage_model = StageC_coder()
self.latent_channels = 16
new_sd = {}
for k in sd:
new_sd["previewer.{}".format(k)] = sd[k]
sd = new_sd
elif "encoder.backbone.1.0.block.0.1.num_batches_tracked" in sd: #combined effnet and previewer for stable cascade
self.first_stage_model = StageC_coder()
self.downscale_ratio = 32
self.latent_channels = 16
else: else:
#default SD1.x/SD2.x VAE parameters #default SD1.x/SD2.x VAE parameters
ddconfig = {'double_z': True, 'z_channels': 4, 'resolution': 256, 'in_channels': 3, 'out_ch': 3, 'ch': 128, 'ch_mult': [1, 2, 4, 4], 'num_res_blocks': 2, 'attn_resolutions': [], 'dropout': 0.0} ddconfig = {'double_z': True, 'z_channels': 4, 'resolution': 256, 'in_channels': 3, 'out_ch': 3, 'ch': 128, 'ch_mult': [1, 2, 4, 4], 'num_res_blocks': 2, 'attn_resolutions': [], 'dropout': 0.0}
if 'encoder.down.2.downsample.conv.weight' not in sd: #Stable diffusion x4 upscaler VAE
ddconfig['ch_mult'] = [1, 2, 4]
self.downscale_ratio = 4
self.upscale_ratio = 4
self.first_stage_model = AutoencoderKL(ddconfig=ddconfig, embed_dim=4) self.first_stage_model = AutoencoderKL(ddconfig=ddconfig, embed_dim=4)
else: else:
self.first_stage_model = AutoencoderKL(**(config['params'])) self.first_stage_model = AutoencoderKL(**(config['params']))
@ -179,32 +226,44 @@ class VAE:
m, u = self.first_stage_model.load_state_dict(sd, strict=False) m, u = self.first_stage_model.load_state_dict(sd, strict=False)
if len(m) > 0: if len(m) > 0:
print("Missing VAE keys", m) logging.warning("Missing VAE keys {}".format(m))
if len(u) > 0: if len(u) > 0:
print("Leftover VAE keys", u) logging.debug("Leftover VAE keys {}".format(u))
if device is None: if device is None:
device = model_management.vae_device() device = model_management.vae_device()
self.device = device self.device = device
offload_device = model_management.vae_offload_device() offload_device = model_management.vae_offload_device()
self.vae_dtype = model_management.vae_dtype() if dtype is None:
dtype = model_management.vae_dtype()
self.vae_dtype = dtype
self.first_stage_model.to(self.vae_dtype) self.first_stage_model.to(self.vae_dtype)
self.output_device = model_management.intermediate_device()
self.patcher = comfy.model_patcher.ModelPatcher(self.first_stage_model, load_device=self.device, offload_device=offload_device) self.patcher = comfy.model_patcher.ModelPatcher(self.first_stage_model, load_device=self.device, offload_device=offload_device)
def vae_encode_crop_pixels(self, pixels):
x = (pixels.shape[1] // self.downscale_ratio) * self.downscale_ratio
y = (pixels.shape[2] // self.downscale_ratio) * self.downscale_ratio
if pixels.shape[1] != x or pixels.shape[2] != y:
x_offset = (pixels.shape[1] % self.downscale_ratio) // 2
y_offset = (pixels.shape[2] % self.downscale_ratio) // 2
pixels = pixels[:, x_offset:x + x_offset, y_offset:y + y_offset, :]
return pixels
def decode_tiled_(self, samples, tile_x=64, tile_y=64, overlap = 16): def decode_tiled_(self, samples, tile_x=64, tile_y=64, overlap = 16):
steps = samples.shape[0] * comfy.utils.get_tiled_scale_steps(samples.shape[3], samples.shape[2], tile_x, tile_y, overlap) steps = samples.shape[0] * comfy.utils.get_tiled_scale_steps(samples.shape[3], samples.shape[2], tile_x, tile_y, overlap)
steps += samples.shape[0] * comfy.utils.get_tiled_scale_steps(samples.shape[3], samples.shape[2], tile_x // 2, tile_y * 2, overlap) steps += samples.shape[0] * comfy.utils.get_tiled_scale_steps(samples.shape[3], samples.shape[2], tile_x // 2, tile_y * 2, overlap)
steps += samples.shape[0] * comfy.utils.get_tiled_scale_steps(samples.shape[3], samples.shape[2], tile_x * 2, tile_y // 2, overlap) steps += samples.shape[0] * comfy.utils.get_tiled_scale_steps(samples.shape[3], samples.shape[2], tile_x * 2, tile_y // 2, overlap)
pbar = comfy.utils.ProgressBar(steps) pbar = comfy.utils.ProgressBar(steps)
decode_fn = lambda a: (self.first_stage_model.decode(a.to(self.vae_dtype).to(self.device)) + 1.0).float() decode_fn = lambda a: self.first_stage_model.decode(a.to(self.vae_dtype).to(self.device)).float()
output = torch.clamp(( output = self.process_output(
(comfy.utils.tiled_scale(samples, decode_fn, tile_x // 2, tile_y * 2, overlap, upscale_amount = 8, pbar = pbar) + (comfy.utils.tiled_scale(samples, decode_fn, tile_x // 2, tile_y * 2, overlap, upscale_amount = self.upscale_ratio, output_device=self.output_device, pbar = pbar) +
comfy.utils.tiled_scale(samples, decode_fn, tile_x * 2, tile_y // 2, overlap, upscale_amount = 8, pbar = pbar) + comfy.utils.tiled_scale(samples, decode_fn, tile_x * 2, tile_y // 2, overlap, upscale_amount = self.upscale_ratio, output_device=self.output_device, pbar = pbar) +
comfy.utils.tiled_scale(samples, decode_fn, tile_x, tile_y, overlap, upscale_amount = 8, pbar = pbar)) comfy.utils.tiled_scale(samples, decode_fn, tile_x, tile_y, overlap, upscale_amount = self.upscale_ratio, output_device=self.output_device, pbar = pbar))
/ 3.0) / 2.0, min=0.0, max=1.0) / 3.0)
return output return output
def encode_tiled_(self, pixel_samples, tile_x=512, tile_y=512, overlap = 64): def encode_tiled_(self, pixel_samples, tile_x=512, tile_y=512, overlap = 64):
@ -213,10 +272,10 @@ class VAE:
steps += pixel_samples.shape[0] * comfy.utils.get_tiled_scale_steps(pixel_samples.shape[3], pixel_samples.shape[2], tile_x * 2, tile_y // 2, overlap) steps += pixel_samples.shape[0] * comfy.utils.get_tiled_scale_steps(pixel_samples.shape[3], pixel_samples.shape[2], tile_x * 2, tile_y // 2, overlap)
pbar = comfy.utils.ProgressBar(steps) pbar = comfy.utils.ProgressBar(steps)
encode_fn = lambda a: self.first_stage_model.encode((2. * a - 1.).to(self.vae_dtype).to(self.device)).float() encode_fn = lambda a: self.first_stage_model.encode((self.process_input(a)).to(self.vae_dtype).to(self.device)).float()
samples = comfy.utils.tiled_scale(pixel_samples, encode_fn, tile_x, tile_y, overlap, upscale_amount = (1/8), out_channels=4, pbar=pbar) samples = comfy.utils.tiled_scale(pixel_samples, encode_fn, tile_x, tile_y, overlap, upscale_amount = (1/self.downscale_ratio), out_channels=self.latent_channels, output_device=self.output_device, pbar=pbar)
samples += comfy.utils.tiled_scale(pixel_samples, encode_fn, tile_x * 2, tile_y // 2, overlap, upscale_amount = (1/8), out_channels=4, pbar=pbar) samples += comfy.utils.tiled_scale(pixel_samples, encode_fn, tile_x * 2, tile_y // 2, overlap, upscale_amount = (1/self.downscale_ratio), out_channels=self.latent_channels, output_device=self.output_device, pbar=pbar)
samples += comfy.utils.tiled_scale(pixel_samples, encode_fn, tile_x // 2, tile_y * 2, overlap, upscale_amount = (1/8), out_channels=4, pbar=pbar) samples += comfy.utils.tiled_scale(pixel_samples, encode_fn, tile_x // 2, tile_y * 2, overlap, upscale_amount = (1/self.downscale_ratio), out_channels=self.latent_channels, output_device=self.output_device, pbar=pbar)
samples /= 3.0 samples /= 3.0
return samples return samples
@ -228,15 +287,15 @@ class VAE:
batch_number = int(free_memory / memory_used) batch_number = int(free_memory / memory_used)
batch_number = max(1, batch_number) batch_number = max(1, batch_number)
pixel_samples = torch.empty((samples_in.shape[0], 3, round(samples_in.shape[2] * 8), round(samples_in.shape[3] * 8)), device="cpu") pixel_samples = torch.empty((samples_in.shape[0], 3, round(samples_in.shape[2] * self.upscale_ratio), round(samples_in.shape[3] * self.upscale_ratio)), device=self.output_device)
for x in range(0, samples_in.shape[0], batch_number): for x in range(0, samples_in.shape[0], batch_number):
samples = samples_in[x:x+batch_number].to(self.vae_dtype).to(self.device) samples = samples_in[x:x+batch_number].to(self.vae_dtype).to(self.device)
pixel_samples[x:x+batch_number] = torch.clamp((self.first_stage_model.decode(samples).cpu().float() + 1.0) / 2.0, min=0.0, max=1.0) pixel_samples[x:x+batch_number] = self.process_output(self.first_stage_model.decode(samples).to(self.output_device).float())
except model_management.OOM_EXCEPTION as e: except model_management.OOM_EXCEPTION as e:
print("Warning: Ran out of memory when regular VAE decoding, retrying with tiled VAE decoding.") logging.warning("Warning: Ran out of memory when regular VAE decoding, retrying with tiled VAE decoding.")
pixel_samples = self.decode_tiled_(samples_in) pixel_samples = self.decode_tiled_(samples_in)
pixel_samples = pixel_samples.cpu().movedim(1,-1) pixel_samples = pixel_samples.to(self.output_device).movedim(1,-1)
return pixel_samples return pixel_samples
def decode_tiled(self, samples, tile_x=64, tile_y=64, overlap = 16): def decode_tiled(self, samples, tile_x=64, tile_y=64, overlap = 16):
@ -245,6 +304,7 @@ class VAE:
return output.movedim(1,-1) return output.movedim(1,-1)
def encode(self, pixel_samples): def encode(self, pixel_samples):
pixel_samples = self.vae_encode_crop_pixels(pixel_samples)
pixel_samples = pixel_samples.movedim(-1,1) pixel_samples = pixel_samples.movedim(-1,1)
try: try:
memory_used = self.memory_used_encode(pixel_samples.shape, self.vae_dtype) memory_used = self.memory_used_encode(pixel_samples.shape, self.vae_dtype)
@ -252,18 +312,19 @@ class VAE:
free_memory = model_management.get_free_memory(self.device) free_memory = model_management.get_free_memory(self.device)
batch_number = int(free_memory / memory_used) batch_number = int(free_memory / memory_used)
batch_number = max(1, batch_number) batch_number = max(1, batch_number)
samples = torch.empty((pixel_samples.shape[0], 4, round(pixel_samples.shape[2] // 8), round(pixel_samples.shape[3] // 8)), device="cpu") samples = torch.empty((pixel_samples.shape[0], self.latent_channels, round(pixel_samples.shape[2] // self.downscale_ratio), round(pixel_samples.shape[3] // self.downscale_ratio)), device=self.output_device)
for x in range(0, pixel_samples.shape[0], batch_number): for x in range(0, pixel_samples.shape[0], batch_number):
pixels_in = (2. * pixel_samples[x:x+batch_number] - 1.).to(self.vae_dtype).to(self.device) pixels_in = self.process_input(pixel_samples[x:x+batch_number]).to(self.vae_dtype).to(self.device)
samples[x:x+batch_number] = self.first_stage_model.encode(pixels_in).cpu().float() samples[x:x+batch_number] = self.first_stage_model.encode(pixels_in).to(self.output_device).float()
except model_management.OOM_EXCEPTION as e: except model_management.OOM_EXCEPTION as e:
print("Warning: Ran out of memory when regular VAE encoding, retrying with tiled VAE encoding.") logging.warning("Warning: Ran out of memory when regular VAE encoding, retrying with tiled VAE encoding.")
samples = self.encode_tiled_(pixel_samples) samples = self.encode_tiled_(pixel_samples)
return samples return samples
def encode_tiled(self, pixel_samples, tile_x=512, tile_y=512, overlap = 64): def encode_tiled(self, pixel_samples, tile_x=512, tile_y=512, overlap = 64):
pixel_samples = self.vae_encode_crop_pixels(pixel_samples)
model_management.load_model_gpu(self.patcher) model_management.load_model_gpu(self.patcher)
pixel_samples = pixel_samples.movedim(-1,1) pixel_samples = pixel_samples.movedim(-1,1)
samples = self.encode_tiled_(pixel_samples, tile_x=tile_x, tile_y=tile_y, overlap=overlap) samples = self.encode_tiled_(pixel_samples, tile_x=tile_x, tile_y=tile_y, overlap=overlap)
@ -290,8 +351,11 @@ def load_style_model(ckpt_path):
model.load_state_dict(model_data) model.load_state_dict(model_data)
return StyleModel(model) return StyleModel(model)
class CLIPType(Enum):
STABLE_DIFFUSION = 1
STABLE_CASCADE = 2
def load_clip(ckpt_paths, embedding_directory=None): def load_clip(ckpt_paths, embedding_directory=None, clip_type=CLIPType.STABLE_DIFFUSION):
clip_data = [] clip_data = []
for p in ckpt_paths: for p in ckpt_paths:
clip_data.append(comfy.utils.load_torch_file(p, safe_load=True)) clip_data.append(comfy.utils.load_torch_file(p, safe_load=True))
@ -301,14 +365,21 @@ def load_clip(ckpt_paths, embedding_directory=None):
for i in range(len(clip_data)): for i in range(len(clip_data)):
if "transformer.resblocks.0.ln_1.weight" in clip_data[i]: if "transformer.resblocks.0.ln_1.weight" in clip_data[i]:
clip_data[i] = comfy.utils.transformers_convert(clip_data[i], "", "text_model.", 32) clip_data[i] = comfy.utils.clip_text_transformers_convert(clip_data[i], "", "")
else:
if "text_projection" in clip_data[i]:
clip_data[i]["text_projection.weight"] = clip_data[i]["text_projection"].transpose(0, 1) #old models saved with the CLIPSave node
clip_target = EmptyClass() clip_target = EmptyClass()
clip_target.params = {} clip_target.params = {}
if len(clip_data) == 1: if len(clip_data) == 1:
if "text_model.encoder.layers.30.mlp.fc1.weight" in clip_data[0]: if "text_model.encoder.layers.30.mlp.fc1.weight" in clip_data[0]:
clip_target.clip = sdxl_clip.SDXLRefinerClipModel if clip_type == CLIPType.STABLE_CASCADE:
clip_target.tokenizer = sdxl_clip.SDXLTokenizer clip_target.clip = sdxl_clip.StableCascadeClipModel
clip_target.tokenizer = sdxl_clip.StableCascadeTokenizer
else:
clip_target.clip = sdxl_clip.SDXLRefinerClipModel
clip_target.tokenizer = sdxl_clip.SDXLTokenizer
elif "text_model.encoder.layers.22.mlp.fc1.weight" in clip_data[0]: elif "text_model.encoder.layers.22.mlp.fc1.weight" in clip_data[0]:
clip_target.clip = sd2_clip.SD2ClipModel clip_target.clip = sd2_clip.SD2ClipModel
clip_target.tokenizer = sd2_clip.SD2Tokenizer clip_target.tokenizer = sd2_clip.SD2Tokenizer
@ -323,10 +394,10 @@ def load_clip(ckpt_paths, embedding_directory=None):
for c in clip_data: for c in clip_data:
m, u = clip.load_sd(c) m, u = clip.load_sd(c)
if len(m) > 0: if len(m) > 0:
print("clip missing:", m) logging.warning("clip missing: {}".format(m))
if len(u) > 0: if len(u) > 0:
print("clip unexpected:", u) logging.debug("clip unexpected: {}".format(u))
return clip return clip
def load_gligen(ckpt_path): def load_gligen(ckpt_path):
@ -431,12 +502,13 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
clip_target = None clip_target = None
parameters = comfy.utils.calculate_parameters(sd, "model.diffusion_model.") parameters = comfy.utils.calculate_parameters(sd, "model.diffusion_model.")
unet_dtype = model_management.unet_dtype(model_params=parameters) load_device = model_management.get_torch_device()
class WeightsLoader(torch.nn.Module): model_config = model_detection.model_config_from_unet(sd, "model.diffusion_model.")
pass unet_dtype = model_management.unet_dtype(model_params=parameters, supported_dtypes=model_config.supported_inference_dtypes)
manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device, model_config.supported_inference_dtypes)
model_config.set_inference_dtype(unet_dtype, manual_cast_dtype)
model_config = model_detection.model_config_from_unet(sd, "model.diffusion_model.", unet_dtype)
if model_config is None: if model_config is None:
raise RuntimeError("ERROR: Could not detect model type of: {}".format(ckpt_path)) raise RuntimeError("ERROR: Could not detect model type of: {}".format(ckpt_path))
@ -451,27 +523,33 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
model.load_model_weights(sd, "model.diffusion_model.") model.load_model_weights(sd, "model.diffusion_model.")
if output_vae: if output_vae:
vae_sd = comfy.utils.state_dict_prefix_replace(sd, {"first_stage_model.": ""}, filter_keys=True) vae_sd = comfy.utils.state_dict_prefix_replace(sd, {k: "" for k in model_config.vae_key_prefix}, filter_keys=True)
vae_sd = model_config.process_vae_state_dict(vae_sd) vae_sd = model_config.process_vae_state_dict(vae_sd)
vae = VAE(sd=vae_sd) vae = VAE(sd=vae_sd)
if output_clip: if output_clip:
w = WeightsLoader()
clip_target = model_config.clip_target() clip_target = model_config.clip_target()
if clip_target is not None: if clip_target is not None:
clip = CLIP(clip_target, embedding_directory=embedding_directory) clip_sd = model_config.process_clip_state_dict(sd)
w.cond_stage_model = clip.cond_stage_model if len(clip_sd) > 0:
sd = model_config.process_clip_state_dict(sd) clip = CLIP(clip_target, embedding_directory=embedding_directory)
load_model_weights(w, sd) m, u = clip.load_sd(clip_sd, full_model=True)
if len(m) > 0:
logging.warning("clip missing: {}".format(m))
if len(u) > 0:
logging.debug("clip unexpected {}:".format(u))
else:
logging.warning("no CLIP/text encoder weights in checkpoint, the text encoder model will not be loaded.")
left_over = sd.keys() left_over = sd.keys()
if len(left_over) > 0: if len(left_over) > 0:
print("left over keys:", left_over) logging.debug("left over keys: {}".format(left_over))
if output_model: if output_model:
model_patcher = comfy.model_patcher.ModelPatcher(model, load_device=model_management.get_torch_device(), offload_device=model_management.unet_offload_device(), current_device=inital_load_device) model_patcher = comfy.model_patcher.ModelPatcher(model, load_device=load_device, offload_device=model_management.unet_offload_device(), current_device=inital_load_device)
if inital_load_device != torch.device("cpu"): if inital_load_device != torch.device("cpu"):
print("loaded straight to GPU") logging.info("loaded straight to GPU")
model_management.load_model_gpu(model_patcher) model_management.load_model_gpu(model_patcher)
return (model_patcher, clip, vae, clipvision) return (model_patcher, clip, vae, clipvision)
@ -480,14 +558,16 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
def load_unet_state_dict(sd): #load unet in diffusers format def load_unet_state_dict(sd): #load unet in diffusers format
parameters = comfy.utils.calculate_parameters(sd) parameters = comfy.utils.calculate_parameters(sd)
unet_dtype = model_management.unet_dtype(model_params=parameters) unet_dtype = model_management.unet_dtype(model_params=parameters)
if "input_blocks.0.0.weight" in sd: #ldm load_device = model_management.get_torch_device()
model_config = model_detection.model_config_from_unet(sd, "", unet_dtype)
if "input_blocks.0.0.weight" in sd or 'clf.1.weight' in sd: #ldm or stable cascade
model_config = model_detection.model_config_from_unet(sd, "")
if model_config is None: if model_config is None:
return None return None
new_sd = sd new_sd = sd
else: #diffusers else: #diffusers
model_config = model_detection.model_config_from_diffusers_unet(sd, unet_dtype) model_config = model_detection.model_config_from_diffusers_unet(sd)
if model_config is None: if model_config is None:
return None return None
@ -498,25 +578,36 @@ def load_unet_state_dict(sd): #load unet in diffusers format
if k in sd: if k in sd:
new_sd[diffusers_keys[k]] = sd.pop(k) new_sd[diffusers_keys[k]] = sd.pop(k)
else: else:
print(diffusers_keys[k], k) logging.warning("{} {}".format(diffusers_keys[k], k))
offload_device = model_management.unet_offload_device() offload_device = model_management.unet_offload_device()
unet_dtype = model_management.unet_dtype(model_params=parameters, supported_dtypes=model_config.supported_inference_dtypes)
manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device, model_config.supported_inference_dtypes)
model_config.set_inference_dtype(unet_dtype, manual_cast_dtype)
model = model_config.get_model(new_sd, "") model = model_config.get_model(new_sd, "")
model = model.to(offload_device) model = model.to(offload_device)
model.load_model_weights(new_sd, "") model.load_model_weights(new_sd, "")
left_over = sd.keys() left_over = sd.keys()
if len(left_over) > 0: if len(left_over) > 0:
print("left over keys in unet:", left_over) logging.info("left over keys in unet: {}".format(left_over))
return comfy.model_patcher.ModelPatcher(model, load_device=model_management.get_torch_device(), offload_device=offload_device) return comfy.model_patcher.ModelPatcher(model, load_device=load_device, offload_device=offload_device)
def load_unet(unet_path): def load_unet(unet_path):
sd = comfy.utils.load_torch_file(unet_path) sd = comfy.utils.load_torch_file(unet_path)
model = load_unet_state_dict(sd) model = load_unet_state_dict(sd)
if model is None: if model is None:
print("ERROR UNSUPPORTED UNET", unet_path) logging.error("ERROR UNSUPPORTED UNET {}".format(unet_path))
raise RuntimeError("ERROR: Could not detect model type of: {}".format(unet_path)) raise RuntimeError("ERROR: Could not detect model type of: {}".format(unet_path))
return model return model
def save_checkpoint(output_path, model, clip, vae, metadata=None): def save_checkpoint(output_path, model, clip=None, vae=None, clip_vision=None, metadata=None):
model_management.load_models_gpu([model, clip.load_model()]) clip_sd = None
sd = model.model.state_dict_for_saving(clip.get_sd(), vae.get_sd()) load_models = [model]
if clip is not None:
load_models.append(clip.load_model())
clip_sd = clip.get_sd()
model_management.load_models_gpu(load_models)
clip_vision_sd = clip_vision.get_sd() if clip_vision is not None else None
sd = model.model.state_dict_for_saving(clip_sd, vae.get_sd(), clip_vision_sd)
comfy.utils.save_torch_file(sd, output_path, metadata=metadata) comfy.utils.save_torch_file(sd, output_path, metadata=metadata)

150
comfy/sd1_clip.py

@ -1,12 +1,14 @@
import os import os
from transformers import CLIPTokenizer, CLIPTextModel, CLIPTextConfig, modeling_utils from transformers import CLIPTokenizer
import comfy.ops import comfy.ops
import torch import torch
import traceback import traceback
import zipfile import zipfile
from . import model_management from . import model_management
import contextlib import comfy.clip_model
import json
import logging
def gen_empty_tokens(special_tokens, length): def gen_empty_tokens(special_tokens, length):
start_token = special_tokens.get("start", None) start_token = special_tokens.get("start", None)
@ -37,7 +39,7 @@ class ClipTokenWeightEncoder:
out, pooled = self.encode(to_encode) out, pooled = self.encode(to_encode)
if pooled is not None: if pooled is not None:
first_pooled = pooled[0:1].cpu() first_pooled = pooled[0:1].to(model_management.intermediate_device())
else: else:
first_pooled = pooled first_pooled = pooled
@ -54,8 +56,8 @@ class ClipTokenWeightEncoder:
output.append(z) output.append(z)
if (len(output) == 0): if (len(output) == 0):
return out[-1:].cpu(), first_pooled return out[-1:].to(model_management.intermediate_device()), first_pooled
return torch.cat(output, dim=-2).cpu(), first_pooled return torch.cat(output, dim=-2).to(model_management.intermediate_device()), first_pooled
class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder): class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
"""Uses the CLIP transformer encoder for text (from huggingface)""" """Uses the CLIP transformer encoder for text (from huggingface)"""
@ -65,31 +67,19 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
"hidden" "hidden"
] ]
def __init__(self, version="openai/clip-vit-large-patch14", device="cpu", max_length=77, def __init__(self, version="openai/clip-vit-large-patch14", device="cpu", max_length=77,
freeze=True, layer="last", layer_idx=None, textmodel_json_config=None, textmodel_path=None, dtype=None, freeze=True, layer="last", layer_idx=None, textmodel_json_config=None, dtype=None, model_class=comfy.clip_model.CLIPTextModel,
special_tokens={"start": 49406, "end": 49407, "pad": 49407},layer_norm_hidden_state=True, config_class=CLIPTextConfig, special_tokens={"start": 49406, "end": 49407, "pad": 49407}, layer_norm_hidden_state=True, enable_attention_masks=False, return_projected_pooled=True): # clip-vit-base-patch32
model_class=CLIPTextModel, inner_name="text_model"): # clip-vit-base-patch32
super().__init__() super().__init__()
assert layer in self.LAYERS assert layer in self.LAYERS
self.num_layers = 12
if textmodel_path is not None: if textmodel_json_config is None:
self.transformer = model_class.from_pretrained(textmodel_path) textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "sd1_clip_config.json")
else:
if textmodel_json_config is None: with open(textmodel_json_config) as f:
textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "sd1_clip_config.json") config = json.load(f)
config = config_class.from_json_file(textmodel_json_config)
self.num_layers = config.num_hidden_layers self.transformer = model_class(config, dtype, device, comfy.ops.manual_cast)
with comfy.ops.use_comfy_ops(device, dtype): self.num_layers = self.transformer.num_layers
with modeling_utils.no_init_weights():
self.transformer = model_class(config)
self.inner_name = inner_name
if dtype is not None:
self.transformer.to(dtype)
inner_model = getattr(self.transformer, self.inner_name)
if hasattr(inner_model, "embeddings"):
inner_model.embeddings.to(torch.float32)
else:
self.transformer.set_input_embeddings(self.transformer.get_input_embeddings().to(torch.float32))
self.max_length = max_length self.max_length = max_length
if freeze: if freeze:
@ -97,16 +87,18 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
self.layer = layer self.layer = layer
self.layer_idx = None self.layer_idx = None
self.special_tokens = special_tokens self.special_tokens = special_tokens
self.text_projection = torch.nn.Parameter(torch.eye(self.transformer.get_input_embeddings().weight.shape[1]))
self.logit_scale = torch.nn.Parameter(torch.tensor(4.6055)) self.logit_scale = torch.nn.Parameter(torch.tensor(4.6055))
self.enable_attention_masks = False self.enable_attention_masks = enable_attention_masks
self.layer_norm_hidden_state = layer_norm_hidden_state self.layer_norm_hidden_state = layer_norm_hidden_state
self.return_projected_pooled = return_projected_pooled
if layer == "hidden": if layer == "hidden":
assert layer_idx is not None assert layer_idx is not None
assert abs(layer_idx) <= self.num_layers assert abs(layer_idx) < self.num_layers
self.clip_layer(layer_idx) self.set_clip_options({"layer": layer_idx})
self.layer_default = (self.layer, self.layer_idx) self.options_default = (self.layer, self.layer_idx, self.return_projected_pooled)
def freeze(self): def freeze(self):
self.transformer = self.transformer.eval() self.transformer = self.transformer.eval()
@ -114,16 +106,19 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
for param in self.parameters(): for param in self.parameters():
param.requires_grad = False param.requires_grad = False
def clip_layer(self, layer_idx): def set_clip_options(self, options):
if abs(layer_idx) >= self.num_layers: layer_idx = options.get("layer", self.layer_idx)
self.return_projected_pooled = options.get("projected_pooled", self.return_projected_pooled)
if layer_idx is None or abs(layer_idx) > self.num_layers:
self.layer = "last" self.layer = "last"
else: else:
self.layer = "hidden" self.layer = "hidden"
self.layer_idx = layer_idx self.layer_idx = layer_idx
def reset_clip_layer(self): def reset_clip_options(self):
self.layer = self.layer_default[0] self.layer = self.options_default[0]
self.layer_idx = self.layer_default[1] self.layer_idx = self.options_default[1]
self.return_projected_pooled = self.options_default[2]
def set_up_textual_embeddings(self, tokens, current_embeds): def set_up_textual_embeddings(self, tokens, current_embeds):
out_tokens = [] out_tokens = []
@ -143,7 +138,7 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
tokens_temp += [next_new_token] tokens_temp += [next_new_token]
next_new_token += 1 next_new_token += 1
else: else:
print("WARNING: shape mismatch when trying to apply embedding, embedding will be ignored", y.shape[0], current_embeds.weight.shape[1]) logging.warning("WARNING: shape mismatch when trying to apply embedding, embedding will be ignored {} != {}".format(y.shape[0], current_embeds.weight.shape[1]))
while len(tokens_temp) < len(x): while len(tokens_temp) < len(x):
tokens_temp += [self.special_tokens["pad"]] tokens_temp += [self.special_tokens["pad"]]
out_tokens += [tokens_temp] out_tokens += [tokens_temp]
@ -170,51 +165,37 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
tokens = self.set_up_textual_embeddings(tokens, backup_embeds) tokens = self.set_up_textual_embeddings(tokens, backup_embeds)
tokens = torch.LongTensor(tokens).to(device) tokens = torch.LongTensor(tokens).to(device)
if getattr(self.transformer, self.inner_name).final_layer_norm.weight.dtype != torch.float32: attention_mask = None
precision_scope = torch.autocast if self.enable_attention_masks:
attention_mask = torch.zeros_like(tokens)
max_token = self.transformer.get_input_embeddings().weight.shape[0] - 1
for x in range(attention_mask.shape[0]):
for y in range(attention_mask.shape[1]):
attention_mask[x, y] = 1
if tokens[x, y] == max_token:
break
outputs = self.transformer(tokens, attention_mask, intermediate_output=self.layer_idx, final_layer_norm_intermediate=self.layer_norm_hidden_state)
self.transformer.set_input_embeddings(backup_embeds)
if self.layer == "last":
z = outputs[0]
else: else:
precision_scope = lambda a, dtype: contextlib.nullcontext(a) z = outputs[1]
with precision_scope(model_management.get_autocast_device(device), dtype=torch.float32):
attention_mask = None
if self.enable_attention_masks:
attention_mask = torch.zeros_like(tokens)
max_token = self.transformer.get_input_embeddings().weight.shape[0] - 1
for x in range(attention_mask.shape[0]):
for y in range(attention_mask.shape[1]):
attention_mask[x, y] = 1
if tokens[x, y] == max_token:
break
outputs = self.transformer(input_ids=tokens, attention_mask=attention_mask, output_hidden_states=self.layer=="hidden")
self.transformer.set_input_embeddings(backup_embeds)
if self.layer == "last":
z = outputs.last_hidden_state
elif self.layer == "pooled":
z = outputs.pooler_output[:, None, :]
else:
z = outputs.hidden_states[self.layer_idx]
if self.layer_norm_hidden_state:
z = getattr(self.transformer, self.inner_name).final_layer_norm(z)
if hasattr(outputs, "pooler_output"): pooled_output = None
pooled_output = outputs.pooler_output.float() if len(outputs) >= 3:
else: if not self.return_projected_pooled and len(outputs) >= 4 and outputs[3] is not None:
pooled_output = None pooled_output = outputs[3].float()
elif outputs[2] is not None:
pooled_output = outputs[2].float()
if self.text_projection is not None and pooled_output is not None:
pooled_output = pooled_output.float().to(self.text_projection.device) @ self.text_projection.float()
return z.float(), pooled_output return z.float(), pooled_output
def encode(self, tokens): def encode(self, tokens):
return self(tokens) return self(tokens)
def load_sd(self, sd): def load_sd(self, sd):
if "text_projection" in sd:
self.text_projection[:] = sd.pop("text_projection")
if "text_projection.weight" in sd:
self.text_projection[:] = sd.pop("text_projection.weight").transpose(0, 1)
return self.transformer.load_state_dict(sd, strict=False) return self.transformer.load_state_dict(sd, strict=False)
def parse_parentheses(string): def parse_parentheses(string):
@ -349,9 +330,7 @@ def load_embed(embedding_name, embedding_directory, embedding_size, embed_key=No
else: else:
embed = torch.load(embed_path, map_location="cpu") embed = torch.load(embed_path, map_location="cpu")
except Exception as e: except Exception as e:
print(traceback.format_exc()) logging.warning("{}\n\nerror loading embedding, skipping loading: {}".format(traceback.format_exc(), embedding_name))
print()
print("error loading embedding, skipping loading:", embedding_name)
return None return None
if embed_out is None: if embed_out is None:
@ -375,11 +354,12 @@ def load_embed(embedding_name, embedding_directory, embedding_size, embed_key=No
return embed_out return embed_out
class SDTokenizer: class SDTokenizer:
def __init__(self, tokenizer_path=None, max_length=77, pad_with_end=True, embedding_directory=None, embedding_size=768, embedding_key='clip_l', tokenizer_class=CLIPTokenizer, has_start_token=True, pad_to_max_length=True): def __init__(self, tokenizer_path=None, max_length=77, pad_with_end=True, embedding_directory=None, embedding_size=768, embedding_key='clip_l', tokenizer_class=CLIPTokenizer, has_start_token=True, pad_to_max_length=True, min_length=None):
if tokenizer_path is None: if tokenizer_path is None:
tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "sd1_tokenizer") tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "sd1_tokenizer")
self.tokenizer = tokenizer_class.from_pretrained(tokenizer_path) self.tokenizer = tokenizer_class.from_pretrained(tokenizer_path)
self.max_length = max_length self.max_length = max_length
self.min_length = min_length
empty = self.tokenizer('')["input_ids"] empty = self.tokenizer('')["input_ids"]
if has_start_token: if has_start_token:
@ -441,7 +421,7 @@ class SDTokenizer:
embedding_name = word[len(self.embedding_identifier):].strip('\n') embedding_name = word[len(self.embedding_identifier):].strip('\n')
embed, leftover = self._try_get_embedding(embedding_name) embed, leftover = self._try_get_embedding(embedding_name)
if embed is None: if embed is None:
print(f"warning, embedding:{embedding_name} does not exist, ignoring") logging.warning(f"warning, embedding:{embedding_name} does not exist, ignoring")
else: else:
if len(embed.shape) == 1: if len(embed.shape) == 1:
tokens.append([(embed, weight)]) tokens.append([(embed, weight)])
@ -491,6 +471,8 @@ class SDTokenizer:
batch.append((self.end_token, 1.0, 0)) batch.append((self.end_token, 1.0, 0))
if self.pad_to_max_length: if self.pad_to_max_length:
batch.extend([(pad_token, 1.0, 0)] * (self.max_length - len(batch))) batch.extend([(pad_token, 1.0, 0)] * (self.max_length - len(batch)))
if self.min_length is not None and len(batch) < self.min_length:
batch.extend([(pad_token, 1.0, 0)] * (self.min_length - len(batch)))
if not return_word_ids: if not return_word_ids:
batched_tokens = [[(t, w) for t, w,_ in x] for x in batched_tokens] batched_tokens = [[(t, w) for t, w,_ in x] for x in batched_tokens]
@ -524,11 +506,11 @@ class SD1ClipModel(torch.nn.Module):
self.clip = "clip_{}".format(self.clip_name) self.clip = "clip_{}".format(self.clip_name)
setattr(self, self.clip, clip_model(device=device, dtype=dtype, **kwargs)) setattr(self, self.clip, clip_model(device=device, dtype=dtype, **kwargs))
def clip_layer(self, layer_idx): def set_clip_options(self, options):
getattr(self, self.clip).clip_layer(layer_idx) getattr(self, self.clip).set_clip_options(options)
def reset_clip_layer(self): def reset_clip_options(self):
getattr(self, self.clip).reset_clip_layer() getattr(self, self.clip).reset_clip_options()
def encode_token_weights(self, token_weight_pairs): def encode_token_weights(self, token_weight_pairs):
token_weight_pairs = token_weight_pairs[self.clip_name] token_weight_pairs = token_weight_pairs[self.clip_name]

6
comfy/sd2_clip.py

@ -3,13 +3,13 @@ import torch
import os import os
class SD2ClipHModel(sd1_clip.SDClipModel): class SD2ClipHModel(sd1_clip.SDClipModel):
def __init__(self, arch="ViT-H-14", device="cpu", max_length=77, freeze=True, layer="penultimate", layer_idx=None, textmodel_path=None, dtype=None): def __init__(self, arch="ViT-H-14", device="cpu", max_length=77, freeze=True, layer="penultimate", layer_idx=None, dtype=None):
if layer == "penultimate": if layer == "penultimate":
layer="hidden" layer="hidden"
layer_idx=23 layer_idx=-2
textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "sd2_clip_config.json") textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "sd2_clip_config.json")
super().__init__(device=device, freeze=freeze, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, textmodel_path=textmodel_path, dtype=dtype, special_tokens={"start": 49406, "end": 49407, "pad": 0}) super().__init__(device=device, freeze=freeze, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"start": 49406, "end": 49407, "pad": 0})
class SD2ClipHTokenizer(sd1_clip.SDTokenizer): class SD2ClipHTokenizer(sd1_clip.SDTokenizer):
def __init__(self, tokenizer_path=None, embedding_directory=None): def __init__(self, tokenizer_path=None, embedding_directory=None):

40
comfy/sdxl_clip.py

@ -3,13 +3,13 @@ import torch
import os import os
class SDXLClipG(sd1_clip.SDClipModel): class SDXLClipG(sd1_clip.SDClipModel):
def __init__(self, device="cpu", max_length=77, freeze=True, layer="penultimate", layer_idx=None, textmodel_path=None, dtype=None): def __init__(self, device="cpu", max_length=77, freeze=True, layer="penultimate", layer_idx=None, dtype=None):
if layer == "penultimate": if layer == "penultimate":
layer="hidden" layer="hidden"
layer_idx=-2 layer_idx=-2
textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_config_bigg.json") textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_config_bigg.json")
super().__init__(device=device, freeze=freeze, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, textmodel_path=textmodel_path, dtype=dtype, super().__init__(device=device, freeze=freeze, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype,
special_tokens={"start": 49406, "end": 49407, "pad": 0}, layer_norm_hidden_state=False) special_tokens={"start": 49406, "end": 49407, "pad": 0}, layer_norm_hidden_state=False)
def load_sd(self, sd): def load_sd(self, sd):
@ -37,16 +37,16 @@ class SDXLTokenizer:
class SDXLClipModel(torch.nn.Module): class SDXLClipModel(torch.nn.Module):
def __init__(self, device="cpu", dtype=None): def __init__(self, device="cpu", dtype=None):
super().__init__() super().__init__()
self.clip_l = sd1_clip.SDClipModel(layer="hidden", layer_idx=11, device=device, dtype=dtype, layer_norm_hidden_state=False) self.clip_l = sd1_clip.SDClipModel(layer="hidden", layer_idx=-2, device=device, dtype=dtype, layer_norm_hidden_state=False)
self.clip_g = SDXLClipG(device=device, dtype=dtype) self.clip_g = SDXLClipG(device=device, dtype=dtype)
def clip_layer(self, layer_idx): def set_clip_options(self, options):
self.clip_l.clip_layer(layer_idx) self.clip_l.set_clip_options(options)
self.clip_g.clip_layer(layer_idx) self.clip_g.set_clip_options(options)
def reset_clip_layer(self): def reset_clip_options(self):
self.clip_g.reset_clip_layer() self.clip_g.reset_clip_options()
self.clip_l.reset_clip_layer() self.clip_l.reset_clip_options()
def encode_token_weights(self, token_weight_pairs): def encode_token_weights(self, token_weight_pairs):
token_weight_pairs_g = token_weight_pairs["g"] token_weight_pairs_g = token_weight_pairs["g"]
@ -64,3 +64,25 @@ class SDXLClipModel(torch.nn.Module):
class SDXLRefinerClipModel(sd1_clip.SD1ClipModel): class SDXLRefinerClipModel(sd1_clip.SD1ClipModel):
def __init__(self, device="cpu", dtype=None): def __init__(self, device="cpu", dtype=None):
super().__init__(device=device, dtype=dtype, clip_name="g", clip_model=SDXLClipG) super().__init__(device=device, dtype=dtype, clip_name="g", clip_model=SDXLClipG)
class StableCascadeClipGTokenizer(sd1_clip.SDTokenizer):
def __init__(self, tokenizer_path=None, embedding_directory=None):
super().__init__(tokenizer_path, pad_with_end=True, embedding_directory=embedding_directory, embedding_size=1280, embedding_key='clip_g')
class StableCascadeTokenizer(sd1_clip.SD1Tokenizer):
def __init__(self, embedding_directory=None):
super().__init__(embedding_directory=embedding_directory, clip_name="g", tokenizer=StableCascadeClipGTokenizer)
class StableCascadeClipG(sd1_clip.SDClipModel):
def __init__(self, device="cpu", max_length=77, freeze=True, layer="hidden", layer_idx=-1, dtype=None):
textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_config_bigg.json")
super().__init__(device=device, freeze=freeze, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype,
special_tokens={"start": 49406, "end": 49407, "pad": 49407}, layer_norm_hidden_state=False, enable_attention_masks=True)
def load_sd(self, sd):
return super().load_sd(sd)
class StableCascadeClipModel(sd1_clip.SD1ClipModel):
def __init__(self, device="cpu", dtype=None):
super().__init__(device=device, dtype=dtype, clip_name="g", clip_model=StableCascadeClipG)

238
comfy/supported_models.py

@ -40,11 +40,16 @@ class SD15(supported_models_base.BASE):
state_dict['cond_stage_model.transformer.text_model.embeddings.position_ids'] = ids.round() state_dict['cond_stage_model.transformer.text_model.embeddings.position_ids'] = ids.round()
replace_prefix = {} replace_prefix = {}
replace_prefix["cond_stage_model."] = "cond_stage_model.clip_l." replace_prefix["cond_stage_model."] = "clip_l."
state_dict = utils.state_dict_prefix_replace(state_dict, replace_prefix) state_dict = utils.state_dict_prefix_replace(state_dict, replace_prefix, filter_keys=True)
return state_dict return state_dict
def process_clip_state_dict_for_saving(self, state_dict): def process_clip_state_dict_for_saving(self, state_dict):
pop_keys = ["clip_l.transformer.text_projection.weight", "clip_l.logit_scale"]
for p in pop_keys:
if p in state_dict:
state_dict.pop(p)
replace_prefix = {"clip_l.": "cond_stage_model."} replace_prefix = {"clip_l.": "cond_stage_model."}
return utils.state_dict_prefix_replace(state_dict, replace_prefix) return utils.state_dict_prefix_replace(state_dict, replace_prefix)
@ -72,10 +77,10 @@ class SD20(supported_models_base.BASE):
def process_clip_state_dict(self, state_dict): def process_clip_state_dict(self, state_dict):
replace_prefix = {} replace_prefix = {}
replace_prefix["conditioner.embedders.0.model."] = "cond_stage_model.model." #SD2 in sgm format replace_prefix["conditioner.embedders.0.model."] = "clip_h." #SD2 in sgm format
state_dict = utils.state_dict_prefix_replace(state_dict, replace_prefix) replace_prefix["cond_stage_model.model."] = "clip_h."
state_dict = utils.state_dict_prefix_replace(state_dict, replace_prefix, filter_keys=True)
state_dict = utils.transformers_convert(state_dict, "cond_stage_model.model.", "cond_stage_model.clip_h.transformer.text_model.", 24) state_dict = utils.clip_text_transformers_convert(state_dict, "clip_h.", "clip_h.transformer.")
return state_dict return state_dict
def process_clip_state_dict_for_saving(self, state_dict): def process_clip_state_dict_for_saving(self, state_dict):
@ -131,11 +136,10 @@ class SDXLRefiner(supported_models_base.BASE):
def process_clip_state_dict(self, state_dict): def process_clip_state_dict(self, state_dict):
keys_to_replace = {} keys_to_replace = {}
replace_prefix = {} replace_prefix = {}
replace_prefix["conditioner.embedders.0.model."] = "clip_g."
state_dict = utils.state_dict_prefix_replace(state_dict, replace_prefix, filter_keys=True)
state_dict = utils.transformers_convert(state_dict, "conditioner.embedders.0.model.", "cond_stage_model.clip_g.transformer.text_model.", 32) state_dict = utils.clip_text_transformers_convert(state_dict, "clip_g.", "clip_g.transformer.")
keys_to_replace["conditioner.embedders.0.model.text_projection"] = "cond_stage_model.clip_g.text_projection"
keys_to_replace["conditioner.embedders.0.model.logit_scale"] = "cond_stage_model.clip_g.logit_scale"
state_dict = utils.state_dict_key_replace(state_dict, keys_to_replace) state_dict = utils.state_dict_key_replace(state_dict, keys_to_replace)
return state_dict return state_dict
@ -164,7 +168,13 @@ class SDXL(supported_models_base.BASE):
latent_format = latent_formats.SDXL latent_format = latent_formats.SDXL
def model_type(self, state_dict, prefix=""): def model_type(self, state_dict, prefix=""):
if "v_pred" in state_dict: if 'edm_mean' in state_dict and 'edm_std' in state_dict: #Playground V2.5
self.latent_format = latent_formats.SDXL_Playground_2_5()
self.sampling_settings["sigma_data"] = 0.5
self.sampling_settings["sigma_max"] = 80.0
self.sampling_settings["sigma_min"] = 0.002
return model_base.ModelType.EDM
elif "v_pred" in state_dict:
return model_base.ModelType.V_PREDICTION return model_base.ModelType.V_PREDICTION
else: else:
return model_base.ModelType.EPS return model_base.ModelType.EPS
@ -179,26 +189,28 @@ class SDXL(supported_models_base.BASE):
keys_to_replace = {} keys_to_replace = {}
replace_prefix = {} replace_prefix = {}
replace_prefix["conditioner.embedders.0.transformer.text_model"] = "cond_stage_model.clip_l.transformer.text_model" replace_prefix["conditioner.embedders.0.transformer.text_model"] = "clip_l.transformer.text_model"
state_dict = utils.transformers_convert(state_dict, "conditioner.embedders.1.model.", "cond_stage_model.clip_g.transformer.text_model.", 32) replace_prefix["conditioner.embedders.1.model."] = "clip_g."
keys_to_replace["conditioner.embedders.1.model.text_projection"] = "cond_stage_model.clip_g.text_projection" state_dict = utils.state_dict_prefix_replace(state_dict, replace_prefix, filter_keys=True)
keys_to_replace["conditioner.embedders.1.model.text_projection.weight"] = "cond_stage_model.clip_g.text_projection"
keys_to_replace["conditioner.embedders.1.model.logit_scale"] = "cond_stage_model.clip_g.logit_scale"
state_dict = utils.state_dict_prefix_replace(state_dict, replace_prefix)
state_dict = utils.state_dict_key_replace(state_dict, keys_to_replace) state_dict = utils.state_dict_key_replace(state_dict, keys_to_replace)
state_dict = utils.clip_text_transformers_convert(state_dict, "clip_g.", "clip_g.transformer.")
return state_dict return state_dict
def process_clip_state_dict_for_saving(self, state_dict): def process_clip_state_dict_for_saving(self, state_dict):
replace_prefix = {} replace_prefix = {}
keys_to_replace = {} keys_to_replace = {}
state_dict_g = diffusers_convert.convert_text_enc_state_dict_v20(state_dict, "clip_g") state_dict_g = diffusers_convert.convert_text_enc_state_dict_v20(state_dict, "clip_g")
if "clip_g.transformer.text_model.embeddings.position_ids" in state_dict_g:
state_dict_g.pop("clip_g.transformer.text_model.embeddings.position_ids")
for k in state_dict: for k in state_dict:
if k.startswith("clip_l"): if k.startswith("clip_l"):
state_dict_g[k] = state_dict[k] state_dict_g[k] = state_dict[k]
state_dict_g["clip_l.transformer.text_model.embeddings.position_ids"] = torch.arange(77).expand((1, -1))
pop_keys = ["clip_l.transformer.text_projection.weight", "clip_l.logit_scale"]
for p in pop_keys:
if p in state_dict_g:
state_dict_g.pop(p)
replace_prefix["clip_g"] = "conditioner.embedders.1.model" replace_prefix["clip_g"] = "conditioner.embedders.1.model"
replace_prefix["clip_l"] = "conditioner.embedders.0" replace_prefix["clip_l"] = "conditioner.embedders.0"
state_dict_g = utils.state_dict_prefix_replace(state_dict_g, replace_prefix) state_dict_g = utils.state_dict_prefix_replace(state_dict_g, replace_prefix)
@ -217,6 +229,36 @@ class SSD1B(SDXL):
"use_temporal_attention": False, "use_temporal_attention": False,
} }
class Segmind_Vega(SDXL):
unet_config = {
"model_channels": 320,
"use_linear_in_transformer": True,
"transformer_depth": [0, 0, 1, 1, 2, 2],
"context_dim": 2048,
"adm_in_channels": 2816,
"use_temporal_attention": False,
}
class KOALA_700M(SDXL):
unet_config = {
"model_channels": 320,
"use_linear_in_transformer": True,
"transformer_depth": [0, 2, 5],
"context_dim": 2048,
"adm_in_channels": 2816,
"use_temporal_attention": False,
}
class KOALA_1B(SDXL):
unet_config = {
"model_channels": 320,
"use_linear_in_transformer": True,
"transformer_depth": [0, 2, 6],
"context_dim": 2048,
"adm_in_channels": 2816,
"use_temporal_attention": False,
}
class SVD_img2vid(supported_models_base.BASE): class SVD_img2vid(supported_models_base.BASE):
unet_config = { unet_config = {
"model_channels": 320, "model_channels": 320,
@ -242,5 +284,161 @@ class SVD_img2vid(supported_models_base.BASE):
def clip_target(self): def clip_target(self):
return None return None
models = [SD15, SD20, SD21UnclipL, SD21UnclipH, SDXLRefiner, SDXL, SSD1B] class SV3D_u(SVD_img2vid):
unet_config = {
"model_channels": 320,
"in_channels": 8,
"use_linear_in_transformer": True,
"transformer_depth": [1, 1, 1, 1, 1, 1, 0, 0],
"context_dim": 1024,
"adm_in_channels": 256,
"use_temporal_attention": True,
"use_temporal_resblock": True
}
vae_key_prefix = ["conditioner.embedders.1.encoder."]
def get_model(self, state_dict, prefix="", device=None):
out = model_base.SV3D_u(self, device=device)
return out
class SV3D_p(SV3D_u):
unet_config = {
"model_channels": 320,
"in_channels": 8,
"use_linear_in_transformer": True,
"transformer_depth": [1, 1, 1, 1, 1, 1, 0, 0],
"context_dim": 1024,
"adm_in_channels": 1280,
"use_temporal_attention": True,
"use_temporal_resblock": True
}
def get_model(self, state_dict, prefix="", device=None):
out = model_base.SV3D_p(self, device=device)
return out
class Stable_Zero123(supported_models_base.BASE):
unet_config = {
"context_dim": 768,
"model_channels": 320,
"use_linear_in_transformer": False,
"adm_in_channels": None,
"use_temporal_attention": False,
"in_channels": 8,
}
unet_extra_config = {
"num_heads": 8,
"num_head_channels": -1,
}
clip_vision_prefix = "cond_stage_model.model.visual."
latent_format = latent_formats.SD15
def get_model(self, state_dict, prefix="", device=None):
out = model_base.Stable_Zero123(self, device=device, cc_projection_weight=state_dict["cc_projection.weight"], cc_projection_bias=state_dict["cc_projection.bias"])
return out
def clip_target(self):
return None
class SD_X4Upscaler(SD20):
unet_config = {
"context_dim": 1024,
"model_channels": 256,
'in_channels': 7,
"use_linear_in_transformer": True,
"adm_in_channels": None,
"use_temporal_attention": False,
}
unet_extra_config = {
"disable_self_attentions": [True, True, True, False],
"num_classes": 1000,
"num_heads": 8,
"num_head_channels": -1,
}
latent_format = latent_formats.SD_X4
sampling_settings = {
"linear_start": 0.0001,
"linear_end": 0.02,
}
def get_model(self, state_dict, prefix="", device=None):
out = model_base.SD_X4Upscaler(self, device=device)
return out
class Stable_Cascade_C(supported_models_base.BASE):
unet_config = {
"stable_cascade_stage": 'c',
}
unet_extra_config = {}
latent_format = latent_formats.SC_Prior
supported_inference_dtypes = [torch.bfloat16, torch.float32]
sampling_settings = {
"shift": 2.0,
}
vae_key_prefix = ["vae."]
text_encoder_key_prefix = ["text_encoder."]
clip_vision_prefix = "clip_l_vision."
def process_unet_state_dict(self, state_dict):
key_list = list(state_dict.keys())
for y in ["weight", "bias"]:
suffix = "in_proj_{}".format(y)
keys = filter(lambda a: a.endswith(suffix), key_list)
for k_from in keys:
weights = state_dict.pop(k_from)
prefix = k_from[:-(len(suffix) + 1)]
shape_from = weights.shape[0] // 3
for x in range(3):
p = ["to_q", "to_k", "to_v"]
k_to = "{}.{}.{}".format(prefix, p[x], y)
state_dict[k_to] = weights[shape_from*x:shape_from*(x + 1)]
return state_dict
def process_clip_state_dict(self, state_dict):
state_dict = utils.state_dict_prefix_replace(state_dict, {k: "" for k in self.text_encoder_key_prefix}, filter_keys=True)
if "clip_g.text_projection" in state_dict:
state_dict["clip_g.transformer.text_projection.weight"] = state_dict.pop("clip_g.text_projection").transpose(0, 1)
return state_dict
def get_model(self, state_dict, prefix="", device=None):
out = model_base.StableCascade_C(self, device=device)
return out
def clip_target(self):
return supported_models_base.ClipTarget(sdxl_clip.StableCascadeTokenizer, sdxl_clip.StableCascadeClipModel)
class Stable_Cascade_B(Stable_Cascade_C):
unet_config = {
"stable_cascade_stage": 'b',
}
unet_extra_config = {}
latent_format = latent_formats.SC_B
supported_inference_dtypes = [torch.float16, torch.bfloat16, torch.float32]
sampling_settings = {
"shift": 1.0,
}
clip_vision_prefix = None
def get_model(self, state_dict, prefix="", device=None):
out = model_base.StableCascade_B(self, device=device)
return out
models = [Stable_Zero123, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p]
models += [SVD_img2vid] models += [SVD_img2vid]

21
comfy/supported_models_base.py

@ -21,11 +21,16 @@ class BASE:
noise_aug_config = None noise_aug_config = None
sampling_settings = {} sampling_settings = {}
latent_format = latent_formats.LatentFormat latent_format = latent_formats.LatentFormat
vae_key_prefix = ["first_stage_model."]
text_encoder_key_prefix = ["cond_stage_model."]
supported_inference_dtypes = [torch.float16, torch.bfloat16, torch.float32]
manual_cast_dtype = None
@classmethod @classmethod
def matches(s, unet_config): def matches(s, unet_config):
for k in s.unet_config: for k in s.unet_config:
if s.unet_config[k] != unet_config[k]: if k not in unet_config or s.unet_config[k] != unet_config[k]:
return False return False
return True return True
@ -51,6 +56,7 @@ class BASE:
return out return out
def process_clip_state_dict(self, state_dict): def process_clip_state_dict(self, state_dict):
state_dict = utils.state_dict_prefix_replace(state_dict, {k: "" for k in self.text_encoder_key_prefix}, filter_keys=True)
return state_dict return state_dict
def process_unet_state_dict(self, state_dict): def process_unet_state_dict(self, state_dict):
@ -60,7 +66,13 @@ class BASE:
return state_dict return state_dict
def process_clip_state_dict_for_saving(self, state_dict): def process_clip_state_dict_for_saving(self, state_dict):
replace_prefix = {"": "cond_stage_model."} replace_prefix = {"": self.text_encoder_key_prefix[0]}
return utils.state_dict_prefix_replace(state_dict, replace_prefix)
def process_clip_vision_state_dict_for_saving(self, state_dict):
replace_prefix = {}
if self.clip_vision_prefix is not None:
replace_prefix[""] = self.clip_vision_prefix
return utils.state_dict_prefix_replace(state_dict, replace_prefix) return utils.state_dict_prefix_replace(state_dict, replace_prefix)
def process_unet_state_dict_for_saving(self, state_dict): def process_unet_state_dict_for_saving(self, state_dict):
@ -68,6 +80,9 @@ class BASE:
return utils.state_dict_prefix_replace(state_dict, replace_prefix) return utils.state_dict_prefix_replace(state_dict, replace_prefix)
def process_vae_state_dict_for_saving(self, state_dict): def process_vae_state_dict_for_saving(self, state_dict):
replace_prefix = {"": "first_stage_model."} replace_prefix = {"": self.vae_key_prefix[0]}
return utils.state_dict_prefix_replace(state_dict, replace_prefix) return utils.state_dict_prefix_replace(state_dict, replace_prefix)
def set_inference_dtype(self, dtype, manual_cast_dtype):
self.unet_config['dtype'] = dtype
self.manual_cast_dtype = manual_cast_dtype

5
comfy/taesd/taesd.py

@ -7,9 +7,10 @@ import torch
import torch.nn as nn import torch.nn as nn
import comfy.utils import comfy.utils
import comfy.ops
def conv(n_in, n_out, **kwargs): def conv(n_in, n_out, **kwargs):
return nn.Conv2d(n_in, n_out, 3, padding=1, **kwargs) return comfy.ops.disable_weight_init.Conv2d(n_in, n_out, 3, padding=1, **kwargs)
class Clamp(nn.Module): class Clamp(nn.Module):
def forward(self, x): def forward(self, x):
@ -19,7 +20,7 @@ class Block(nn.Module):
def __init__(self, n_in, n_out): def __init__(self, n_in, n_out):
super().__init__() super().__init__()
self.conv = nn.Sequential(conv(n_in, n_out), nn.ReLU(), conv(n_out, n_out), nn.ReLU(), conv(n_out, n_out)) self.conv = nn.Sequential(conv(n_in, n_out), nn.ReLU(), conv(n_out, n_out), nn.ReLU(), conv(n_out, n_out))
self.skip = nn.Conv2d(n_in, n_out, 1, bias=False) if n_in != n_out else nn.Identity() self.skip = comfy.ops.disable_weight_init.Conv2d(n_in, n_out, 1, bias=False) if n_in != n_out else nn.Identity()
self.fuse = nn.ReLU() self.fuse = nn.ReLU()
def forward(self, x): def forward(self, x):
return self.fuse(self.conv(x) + self.skip(x)) return self.fuse(self.conv(x) + self.skip(x))

62
comfy/utils.py

@ -5,6 +5,7 @@ import comfy.checkpoint_pickle
import safetensors.torch import safetensors.torch
import numpy as np import numpy as np
from PIL import Image from PIL import Image
import logging
def load_torch_file(ckpt, safe_load=False, device=None): def load_torch_file(ckpt, safe_load=False, device=None):
if device is None: if device is None:
@ -14,14 +15,14 @@ def load_torch_file(ckpt, safe_load=False, device=None):
else: else:
if safe_load: if safe_load:
if not 'weights_only' in torch.load.__code__.co_varnames: if not 'weights_only' in torch.load.__code__.co_varnames:
print("Warning torch.load doesn't support weights_only on this pytorch version, loading unsafely.") logging.warning("Warning torch.load doesn't support weights_only on this pytorch version, loading unsafely.")
safe_load = False safe_load = False
if safe_load: if safe_load:
pl_sd = torch.load(ckpt, map_location=device, weights_only=True) pl_sd = torch.load(ckpt, map_location=device, weights_only=True)
else: else:
pl_sd = torch.load(ckpt, map_location=device, pickle_module=comfy.checkpoint_pickle) pl_sd = torch.load(ckpt, map_location=device, pickle_module=comfy.checkpoint_pickle)
if "global_step" in pl_sd: if "global_step" in pl_sd:
print(f"Global Step: {pl_sd['global_step']}") logging.debug(f"Global Step: {pl_sd['global_step']}")
if "state_dict" in pl_sd: if "state_dict" in pl_sd:
sd = pl_sd["state_dict"] sd = pl_sd["state_dict"]
else: else:
@ -98,8 +99,22 @@ def transformers_convert(sd, prefix_from, prefix_to, number):
p = ["self_attn.q_proj", "self_attn.k_proj", "self_attn.v_proj"] p = ["self_attn.q_proj", "self_attn.k_proj", "self_attn.v_proj"]
k_to = "{}encoder.layers.{}.{}.{}".format(prefix_to, resblock, p[x], y) k_to = "{}encoder.layers.{}.{}.{}".format(prefix_to, resblock, p[x], y)
sd[k_to] = weights[shape_from*x:shape_from*(x + 1)] sd[k_to] = weights[shape_from*x:shape_from*(x + 1)]
return sd return sd
def clip_text_transformers_convert(sd, prefix_from, prefix_to):
sd = transformers_convert(sd, prefix_from, "{}text_model.".format(prefix_to), 32)
tp = "{}text_projection.weight".format(prefix_from)
if tp in sd:
sd["{}text_projection.weight".format(prefix_to)] = sd.pop(tp)
tp = "{}text_projection".format(prefix_from)
if tp in sd:
sd["{}text_projection.weight".format(prefix_to)] = sd.pop(tp).transpose(0, 1).contiguous()
return sd
UNET_MAP_ATTENTIONS = { UNET_MAP_ATTENTIONS = {
"proj_in.weight", "proj_in.weight",
"proj_in.bias", "proj_in.bias",
@ -169,6 +184,8 @@ UNET_MAP_BASIC = {
} }
def unet_to_diffusers(unet_config): def unet_to_diffusers(unet_config):
if "num_res_blocks" not in unet_config:
return {}
num_res_blocks = unet_config["num_res_blocks"] num_res_blocks = unet_config["num_res_blocks"]
channel_mult = unet_config["channel_mult"] channel_mult = unet_config["channel_mult"]
transformer_depth = unet_config["transformer_depth"][:] transformer_depth = unet_config["transformer_depth"][:]
@ -239,6 +256,26 @@ def repeat_to_batch_size(tensor, batch_size):
return tensor.repeat([math.ceil(batch_size / tensor.shape[0])] + [1] * (len(tensor.shape) - 1))[:batch_size] return tensor.repeat([math.ceil(batch_size / tensor.shape[0])] + [1] * (len(tensor.shape) - 1))[:batch_size]
return tensor return tensor
def resize_to_batch_size(tensor, batch_size):
in_batch_size = tensor.shape[0]
if in_batch_size == batch_size:
return tensor
if batch_size <= 1:
return tensor[:batch_size]
output = torch.empty([batch_size] + list(tensor.shape)[1:], dtype=tensor.dtype, device=tensor.device)
if batch_size < in_batch_size:
scale = (in_batch_size - 1) / (batch_size - 1)
for i in range(batch_size):
output[i] = tensor[min(round(i * scale), in_batch_size - 1)]
else:
scale = in_batch_size / batch_size
for i in range(batch_size):
output[i] = tensor[min(math.floor((i + 0.5) * scale), in_batch_size - 1)]
return output
def convert_sd_to(state_dict, dtype): def convert_sd_to(state_dict, dtype):
keys = list(state_dict.keys()) keys = list(state_dict.keys())
for k in keys: for k in keys:
@ -258,8 +295,11 @@ def set_attr(obj, attr, value):
for name in attrs[:-1]: for name in attrs[:-1]:
obj = getattr(obj, name) obj = getattr(obj, name)
prev = getattr(obj, attrs[-1]) prev = getattr(obj, attrs[-1])
setattr(obj, attrs[-1], torch.nn.Parameter(value, requires_grad=False)) setattr(obj, attrs[-1], value)
del prev return prev
def set_attr_param(obj, attr, value):
return set_attr(obj, attr, torch.nn.Parameter(value, requires_grad=False))
def copy_to_param(obj, attr, value): def copy_to_param(obj, attr, value):
# inplace update tensor instead of replacing it # inplace update tensor instead of replacing it
@ -356,7 +396,7 @@ def lanczos(samples, width, height):
images = [image.resize((width, height), resample=Image.Resampling.LANCZOS) for image in images] images = [image.resize((width, height), resample=Image.Resampling.LANCZOS) for image in images]
images = [torch.from_numpy(np.array(image).astype(np.float32) / 255.0).movedim(-1, 0) for image in images] images = [torch.from_numpy(np.array(image).astype(np.float32) / 255.0).movedim(-1, 0) for image in images]
result = torch.stack(images) result = torch.stack(images)
return result return result.to(samples.device, samples.dtype)
def common_upscale(samples, width, height, upscale_method, crop): def common_upscale(samples, width, height, upscale_method, crop):
if crop == "center": if crop == "center":
@ -385,17 +425,19 @@ def get_tiled_scale_steps(width, height, tile_x, tile_y, overlap):
return math.ceil((height / (tile_y - overlap))) * math.ceil((width / (tile_x - overlap))) return math.ceil((height / (tile_y - overlap))) * math.ceil((width / (tile_x - overlap)))
@torch.inference_mode() @torch.inference_mode()
def tiled_scale(samples, function, tile_x=64, tile_y=64, overlap = 8, upscale_amount = 4, out_channels = 3, pbar = None): def tiled_scale(samples, function, tile_x=64, tile_y=64, overlap = 8, upscale_amount = 4, out_channels = 3, output_device="cpu", pbar = None):
output = torch.empty((samples.shape[0], out_channels, round(samples.shape[2] * upscale_amount), round(samples.shape[3] * upscale_amount)), device="cpu") output = torch.empty((samples.shape[0], out_channels, round(samples.shape[2] * upscale_amount), round(samples.shape[3] * upscale_amount)), device=output_device)
for b in range(samples.shape[0]): for b in range(samples.shape[0]):
s = samples[b:b+1] s = samples[b:b+1]
out = torch.zeros((s.shape[0], out_channels, round(s.shape[2] * upscale_amount), round(s.shape[3] * upscale_amount)), device="cpu") out = torch.zeros((s.shape[0], out_channels, round(s.shape[2] * upscale_amount), round(s.shape[3] * upscale_amount)), device=output_device)
out_div = torch.zeros((s.shape[0], out_channels, round(s.shape[2] * upscale_amount), round(s.shape[3] * upscale_amount)), device="cpu") out_div = torch.zeros((s.shape[0], out_channels, round(s.shape[2] * upscale_amount), round(s.shape[3] * upscale_amount)), device=output_device)
for y in range(0, s.shape[2], tile_y - overlap): for y in range(0, s.shape[2], tile_y - overlap):
for x in range(0, s.shape[3], tile_x - overlap): for x in range(0, s.shape[3], tile_x - overlap):
x = max(0, min(s.shape[-1] - overlap, x))
y = max(0, min(s.shape[-2] - overlap, y))
s_in = s[:,:,y:y+tile_y,x:x+tile_x] s_in = s[:,:,y:y+tile_y,x:x+tile_x]
ps = function(s_in).cpu() ps = function(s_in).to(output_device)
mask = torch.ones_like(ps) mask = torch.ones_like(ps)
feather = round(overlap * upscale_amount) feather = round(overlap * upscale_amount)
for t in range(feather): for t in range(feather):

272
comfy_extras/nodes_canny.py

@ -5,275 +5,7 @@ import torch
import torch.nn.functional as F import torch.nn.functional as F
import comfy.model_management import comfy.model_management
def get_canny_nms_kernel(device=None, dtype=None): from kornia.filters import canny
"""Utility function that returns 3x3 kernels for the Canny Non-maximal suppression."""
return torch.tensor(
[
[[[0.0, 0.0, 0.0], [0.0, 1.0, -1.0], [0.0, 0.0, 0.0]]],
[[[0.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, -1.0]]],
[[[0.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, -1.0, 0.0]]],
[[[0.0, 0.0, 0.0], [0.0, 1.0, 0.0], [-1.0, 0.0, 0.0]]],
[[[0.0, 0.0, 0.0], [-1.0, 1.0, 0.0], [0.0, 0.0, 0.0]]],
[[[-1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 0.0]]],
[[[0.0, -1.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 0.0]]],
[[[0.0, 0.0, -1.0], [0.0, 1.0, 0.0], [0.0, 0.0, 0.0]]],
],
device=device,
dtype=dtype,
)
def get_hysteresis_kernel(device=None, dtype=None):
"""Utility function that returns the 3x3 kernels for the Canny hysteresis."""
return torch.tensor(
[
[[[0.0, 0.0, 0.0], [0.0, 0.0, 1.0], [0.0, 0.0, 0.0]]],
[[[0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 1.0]]],
[[[0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 1.0, 0.0]]],
[[[0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [1.0, 0.0, 0.0]]],
[[[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [0.0, 0.0, 0.0]]],
[[[1.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]],
[[[0.0, 1.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]],
[[[0.0, 0.0, 1.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]],
],
device=device,
dtype=dtype,
)
def gaussian_blur_2d(img, kernel_size, sigma):
ksize_half = (kernel_size - 1) * 0.5
x = torch.linspace(-ksize_half, ksize_half, steps=kernel_size)
pdf = torch.exp(-0.5 * (x / sigma).pow(2))
x_kernel = pdf / pdf.sum()
x_kernel = x_kernel.to(device=img.device, dtype=img.dtype)
kernel2d = torch.mm(x_kernel[:, None], x_kernel[None, :])
kernel2d = kernel2d.expand(img.shape[-3], 1, kernel2d.shape[0], kernel2d.shape[1])
padding = [kernel_size // 2, kernel_size // 2, kernel_size // 2, kernel_size // 2]
img = torch.nn.functional.pad(img, padding, mode="reflect")
img = torch.nn.functional.conv2d(img, kernel2d, groups=img.shape[-3])
return img
def get_sobel_kernel2d(device=None, dtype=None):
kernel_x = torch.tensor([[-1.0, 0.0, 1.0], [-2.0, 0.0, 2.0], [-1.0, 0.0, 1.0]], device=device, dtype=dtype)
kernel_y = kernel_x.transpose(0, 1)
return torch.stack([kernel_x, kernel_y])
def spatial_gradient(input, normalized: bool = True):
r"""Compute the first order image derivative in both x and y using a Sobel operator.
.. image:: _static/img/spatial_gradient.png
Args:
input: input image tensor with shape :math:`(B, C, H, W)`.
mode: derivatives modality, can be: `sobel` or `diff`.
order: the order of the derivatives.
normalized: whether the output is normalized.
Return:
the derivatives of the input feature map. with shape :math:`(B, C, 2, H, W)`.
.. note::
See a working example `here <https://kornia-tutorials.readthedocs.io/en/latest/
filtering_edges.html>`__.
Examples:
>>> input = torch.rand(1, 3, 4, 4)
>>> output = spatial_gradient(input) # 1x3x2x4x4
>>> output.shape
torch.Size([1, 3, 2, 4, 4])
"""
# KORNIA_CHECK_IS_TENSOR(input)
# KORNIA_CHECK_SHAPE(input, ['B', 'C', 'H', 'W'])
# allocate kernel
kernel = get_sobel_kernel2d(device=input.device, dtype=input.dtype)
if normalized:
kernel = normalize_kernel2d(kernel)
# prepare kernel
b, c, h, w = input.shape
tmp_kernel = kernel[:, None, ...]
# Pad with "replicate for spatial dims, but with zeros for channel
spatial_pad = [kernel.size(1) // 2, kernel.size(1) // 2, kernel.size(2) // 2, kernel.size(2) // 2]
out_channels: int = 2
padded_inp = torch.nn.functional.pad(input.reshape(b * c, 1, h, w), spatial_pad, 'replicate')
out = F.conv2d(padded_inp, tmp_kernel, groups=1, padding=0, stride=1)
return out.reshape(b, c, out_channels, h, w)
def rgb_to_grayscale(image, rgb_weights = None):
r"""Convert a RGB image to grayscale version of image.
.. image:: _static/img/rgb_to_grayscale.png
The image data is assumed to be in the range of (0, 1).
Args:
image: RGB image to be converted to grayscale with shape :math:`(*,3,H,W)`.
rgb_weights: Weights that will be applied on each channel (RGB).
The sum of the weights should add up to one.
Returns:
grayscale version of the image with shape :math:`(*,1,H,W)`.
.. note::
See a working example `here <https://kornia-tutorials.readthedocs.io/en/latest/
color_conversions.html>`__.
Example:
>>> input = torch.rand(2, 3, 4, 5)
>>> gray = rgb_to_grayscale(input) # 2x1x4x5
"""
if len(image.shape) < 3 or image.shape[-3] != 3:
raise ValueError(f"Input size must have a shape of (*, 3, H, W). Got {image.shape}")
if rgb_weights is None:
# 8 bit images
if image.dtype == torch.uint8:
rgb_weights = torch.tensor([76, 150, 29], device=image.device, dtype=torch.uint8)
# floating point images
elif image.dtype in (torch.float16, torch.float32, torch.float64):
rgb_weights = torch.tensor([0.299, 0.587, 0.114], device=image.device, dtype=image.dtype)
else:
raise TypeError(f"Unknown data type: {image.dtype}")
else:
# is tensor that we make sure is in the same device/dtype
rgb_weights = rgb_weights.to(image)
# unpack the color image channels with RGB order
r: Tensor = image[..., 0:1, :, :]
g: Tensor = image[..., 1:2, :, :]
b: Tensor = image[..., 2:3, :, :]
w_r, w_g, w_b = rgb_weights.unbind()
return w_r * r + w_g * g + w_b * b
def canny(
input,
low_threshold = 0.1,
high_threshold = 0.2,
kernel_size = 5,
sigma = 1,
hysteresis = True,
eps = 1e-6,
):
r"""Find edges of the input image and filters them using the Canny algorithm.
.. image:: _static/img/canny.png
Args:
input: input image tensor with shape :math:`(B,C,H,W)`.
low_threshold: lower threshold for the hysteresis procedure.
high_threshold: upper threshold for the hysteresis procedure.
kernel_size: the size of the kernel for the gaussian blur.
sigma: the standard deviation of the kernel for the gaussian blur.
hysteresis: if True, applies the hysteresis edge tracking.
Otherwise, the edges are divided between weak (0.5) and strong (1) edges.
eps: regularization number to avoid NaN during backprop.
Returns:
- the canny edge magnitudes map, shape of :math:`(B,1,H,W)`.
- the canny edge detection filtered by thresholds and hysteresis, shape of :math:`(B,1,H,W)`.
.. note::
See a working example `here <https://kornia-tutorials.readthedocs.io/en/latest/
canny.html>`__.
Example:
>>> input = torch.rand(5, 3, 4, 4)
>>> magnitude, edges = canny(input) # 5x3x4x4
>>> magnitude.shape
torch.Size([5, 1, 4, 4])
>>> edges.shape
torch.Size([5, 1, 4, 4])
"""
# KORNIA_CHECK_IS_TENSOR(input)
# KORNIA_CHECK_SHAPE(input, ['B', 'C', 'H', 'W'])
# KORNIA_CHECK(
# low_threshold <= high_threshold,
# "Invalid input thresholds. low_threshold should be smaller than the high_threshold. Got: "
# f"{low_threshold}>{high_threshold}",
# )
# KORNIA_CHECK(0 < low_threshold < 1, f'Invalid low threshold. Should be in range (0, 1). Got: {low_threshold}')
# KORNIA_CHECK(0 < high_threshold < 1, f'Invalid high threshold. Should be in range (0, 1). Got: {high_threshold}')
device = input.device
dtype = input.dtype
# To Grayscale
if input.shape[1] == 3:
input = rgb_to_grayscale(input)
# Gaussian filter
blurred: Tensor = gaussian_blur_2d(input, kernel_size, sigma)
# Compute the gradients
gradients: Tensor = spatial_gradient(blurred, normalized=False)
# Unpack the edges
gx: Tensor = gradients[:, :, 0]
gy: Tensor = gradients[:, :, 1]
# Compute gradient magnitude and angle
magnitude: Tensor = torch.sqrt(gx * gx + gy * gy + eps)
angle: Tensor = torch.atan2(gy, gx)
# Radians to Degrees
angle = 180.0 * angle / math.pi
# Round angle to the nearest 45 degree
angle = torch.round(angle / 45) * 45
# Non-maximal suppression
nms_kernels: Tensor = get_canny_nms_kernel(device, dtype)
nms_magnitude: Tensor = F.conv2d(magnitude, nms_kernels, padding=nms_kernels.shape[-1] // 2)
# Get the indices for both directions
positive_idx: Tensor = (angle / 45) % 8
positive_idx = positive_idx.long()
negative_idx: Tensor = ((angle / 45) + 4) % 8
negative_idx = negative_idx.long()
# Apply the non-maximum suppression to the different directions
channel_select_filtered_positive: Tensor = torch.gather(nms_magnitude, 1, positive_idx)
channel_select_filtered_negative: Tensor = torch.gather(nms_magnitude, 1, negative_idx)
channel_select_filtered: Tensor = torch.stack(
[channel_select_filtered_positive, channel_select_filtered_negative], 1
)
is_max: Tensor = channel_select_filtered.min(dim=1)[0] > 0.0
magnitude = magnitude * is_max
# Threshold
edges: Tensor = F.threshold(magnitude, low_threshold, 0.0)
low: Tensor = magnitude > low_threshold
high: Tensor = magnitude > high_threshold
edges = low * 0.5 + high * 0.5
edges = edges.to(dtype)
# Hysteresis
if hysteresis:
edges_old: Tensor = -torch.ones(edges.shape, device=edges.device, dtype=dtype)
hysteresis_kernels: Tensor = get_hysteresis_kernel(device, dtype)
while ((edges_old - edges).abs() != 0).any():
weak: Tensor = (edges == 0.5).float()
strong: Tensor = (edges == 1).float()
hysteresis_magnitude: Tensor = F.conv2d(
edges, hysteresis_kernels, padding=hysteresis_kernels.shape[-1] // 2
)
hysteresis_magnitude = (hysteresis_magnitude == 1).any(1, keepdim=True).to(dtype)
hysteresis_magnitude = hysteresis_magnitude * weak + strong
edges_old = edges.clone()
edges = hysteresis_magnitude + (hysteresis_magnitude == 0) * weak * 0.5
edges = hysteresis_magnitude
return magnitude, edges
class Canny: class Canny:
@ -291,7 +23,7 @@ class Canny:
def detect_edge(self, image, low_threshold, high_threshold): def detect_edge(self, image, low_threshold, high_threshold):
output = canny(image.to(comfy.model_management.get_torch_device()).movedim(-1, 1), low_threshold, high_threshold) output = canny(image.to(comfy.model_management.get_torch_device()).movedim(-1, 1), low_threshold, high_threshold)
img_out = output[1].cpu().repeat(1, 3, 1, 1).movedim(1, -1) img_out = output[1].to(comfy.model_management.intermediate_device()).repeat(1, 3, 1, 1).movedim(1, -1)
return (img_out,) return (img_out,)
NODE_CLASS_MAPPINGS = { NODE_CLASS_MAPPINGS = {

25
comfy_extras/nodes_cond.py

@ -0,0 +1,25 @@
class CLIPTextEncodeControlnet:
@classmethod
def INPUT_TYPES(s):
return {"required": {"clip": ("CLIP", ), "conditioning": ("CONDITIONING", ), "text": ("STRING", {"multiline": True})}}
RETURN_TYPES = ("CONDITIONING",)
FUNCTION = "encode"
CATEGORY = "_for_testing/conditioning"
def encode(self, clip, conditioning, text):
tokens = clip.tokenize(text)
cond, pooled = clip.encode_from_tokens(tokens, return_pooled=True)
c = []
for t in conditioning:
n = [t[0], t[1].copy()]
n[1]['cross_attn_controlnet'] = cond
n[1]['pooled_output_controlnet'] = pooled
c.append(n)
return (c, )
NODE_CLASS_MAPPINGS = {
"CLIPTextEncodeControlnet": CLIPTextEncodeControlnet
}

104
comfy_extras/nodes_custom_sampler.py

@ -13,6 +13,7 @@ class BasicScheduler:
{"model": ("MODEL",), {"model": ("MODEL",),
"scheduler": (comfy.samplers.SCHEDULER_NAMES, ), "scheduler": (comfy.samplers.SCHEDULER_NAMES, ),
"steps": ("INT", {"default": 20, "min": 1, "max": 10000}), "steps": ("INT", {"default": 20, "min": 1, "max": 10000}),
"denoise": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}),
} }
} }
RETURN_TYPES = ("SIGMAS",) RETURN_TYPES = ("SIGMAS",)
@ -20,8 +21,14 @@ class BasicScheduler:
FUNCTION = "get_sigmas" FUNCTION = "get_sigmas"
def get_sigmas(self, model, scheduler, steps): def get_sigmas(self, model, scheduler, steps, denoise):
sigmas = comfy.samplers.calculate_sigmas_scheduler(model.model, scheduler, steps).cpu() total_steps = steps
if denoise < 1.0:
total_steps = int(steps/denoise)
comfy.model_management.load_models_gpu([model])
sigmas = comfy.samplers.calculate_sigmas_scheduler(model.model, scheduler, total_steps).cpu()
sigmas = sigmas[-(steps + 1):]
return (sigmas, ) return (sigmas, )
@ -87,6 +94,7 @@ class SDTurboScheduler:
return {"required": return {"required":
{"model": ("MODEL",), {"model": ("MODEL",),
"steps": ("INT", {"default": 1, "min": 1, "max": 10}), "steps": ("INT", {"default": 1, "min": 1, "max": 10}),
"denoise": ("FLOAT", {"default": 1.0, "min": 0, "max": 1.0, "step": 0.01}),
} }
} }
RETURN_TYPES = ("SIGMAS",) RETURN_TYPES = ("SIGMAS",)
@ -94,8 +102,10 @@ class SDTurboScheduler:
FUNCTION = "get_sigmas" FUNCTION = "get_sigmas"
def get_sigmas(self, model, steps): def get_sigmas(self, model, steps, denoise):
timesteps = torch.flip(torch.arange(1, 11) * 100 - 1, (0,))[:steps] start_step = 10 - int(10 * denoise)
timesteps = torch.flip(torch.arange(1, 11) * 100 - 1, (0,))[start_step:start_step + steps]
comfy.model_management.load_models_gpu([model])
sigmas = model.model.model_sampling.sigma(timesteps) sigmas = model.model.model_sampling.sigma(timesteps)
sigmas = torch.cat([sigmas, sigmas.new_zeros([1])]) sigmas = torch.cat([sigmas, sigmas.new_zeros([1])])
return (sigmas, ) return (sigmas, )
@ -171,6 +181,28 @@ class KSamplerSelect:
sampler = comfy.samplers.sampler_object(sampler_name) sampler = comfy.samplers.sampler_object(sampler_name)
return (sampler, ) return (sampler, )
class SamplerDPMPP_3M_SDE:
@classmethod
def INPUT_TYPES(s):
return {"required":
{"eta": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0, "step":0.01, "round": False}),
"s_noise": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0, "step":0.01, "round": False}),
"noise_device": (['gpu', 'cpu'], ),
}
}
RETURN_TYPES = ("SAMPLER",)
CATEGORY = "sampling/custom_sampling/samplers"
FUNCTION = "get_sampler"
def get_sampler(self, eta, s_noise, noise_device):
if noise_device == 'cpu':
sampler_name = "dpmpp_3m_sde"
else:
sampler_name = "dpmpp_3m_sde_gpu"
sampler = comfy.samplers.ksampler(sampler_name, {"eta": eta, "s_noise": s_noise})
return (sampler, )
class SamplerDPMPP_2M_SDE: class SamplerDPMPP_2M_SDE:
@classmethod @classmethod
def INPUT_TYPES(s): def INPUT_TYPES(s):
@ -218,6 +250,66 @@ class SamplerDPMPP_SDE:
sampler = comfy.samplers.ksampler(sampler_name, {"eta": eta, "s_noise": s_noise, "r": r}) sampler = comfy.samplers.ksampler(sampler_name, {"eta": eta, "s_noise": s_noise, "r": r})
return (sampler, ) return (sampler, )
class SamplerEulerAncestral:
@classmethod
def INPUT_TYPES(s):
return {"required":
{"eta": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0, "step":0.01, "round": False}),
"s_noise": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0, "step":0.01, "round": False}),
}
}
RETURN_TYPES = ("SAMPLER",)
CATEGORY = "sampling/custom_sampling/samplers"
FUNCTION = "get_sampler"
def get_sampler(self, eta, s_noise):
sampler = comfy.samplers.ksampler("euler_ancestral", {"eta": eta, "s_noise": s_noise})
return (sampler, )
class SamplerLMS:
@classmethod
def INPUT_TYPES(s):
return {"required":
{"order": ("INT", {"default": 4, "min": 1, "max": 100}),
}
}
RETURN_TYPES = ("SAMPLER",)
CATEGORY = "sampling/custom_sampling/samplers"
FUNCTION = "get_sampler"
def get_sampler(self, order):
sampler = comfy.samplers.ksampler("lms", {"order": order})
return (sampler, )
class SamplerDPMAdaptative:
@classmethod
def INPUT_TYPES(s):
return {"required":
{"order": ("INT", {"default": 3, "min": 2, "max": 3}),
"rtol": ("FLOAT", {"default": 0.05, "min": 0.0, "max": 100.0, "step":0.01, "round": False}),
"atol": ("FLOAT", {"default": 0.0078, "min": 0.0, "max": 100.0, "step":0.01, "round": False}),
"h_init": ("FLOAT", {"default": 0.05, "min": 0.0, "max": 100.0, "step":0.01, "round": False}),
"pcoeff": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 100.0, "step":0.01, "round": False}),
"icoeff": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0, "step":0.01, "round": False}),
"dcoeff": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 100.0, "step":0.01, "round": False}),
"accept_safety": ("FLOAT", {"default": 0.81, "min": 0.0, "max": 100.0, "step":0.01, "round": False}),
"eta": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 100.0, "step":0.01, "round": False}),
"s_noise": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0, "step":0.01, "round": False}),
}
}
RETURN_TYPES = ("SAMPLER",)
CATEGORY = "sampling/custom_sampling/samplers"
FUNCTION = "get_sampler"
def get_sampler(self, order, rtol, atol, h_init, pcoeff, icoeff, dcoeff, accept_safety, eta, s_noise):
sampler = comfy.samplers.ksampler("dpm_adaptive", {"order": order, "rtol": rtol, "atol": atol, "h_init": h_init, "pcoeff": pcoeff,
"icoeff": icoeff, "dcoeff": dcoeff, "accept_safety": accept_safety, "eta": eta,
"s_noise":s_noise })
return (sampler, )
class SamplerCustom: class SamplerCustom:
@classmethod @classmethod
def INPUT_TYPES(s): def INPUT_TYPES(s):
@ -278,8 +370,12 @@ NODE_CLASS_MAPPINGS = {
"VPScheduler": VPScheduler, "VPScheduler": VPScheduler,
"SDTurboScheduler": SDTurboScheduler, "SDTurboScheduler": SDTurboScheduler,
"KSamplerSelect": KSamplerSelect, "KSamplerSelect": KSamplerSelect,
"SamplerEulerAncestral": SamplerEulerAncestral,
"SamplerLMS": SamplerLMS,
"SamplerDPMPP_3M_SDE": SamplerDPMPP_3M_SDE,
"SamplerDPMPP_2M_SDE": SamplerDPMPP_2M_SDE, "SamplerDPMPP_2M_SDE": SamplerDPMPP_2M_SDE,
"SamplerDPMPP_SDE": SamplerDPMPP_SDE, "SamplerDPMPP_SDE": SamplerDPMPP_SDE,
"SamplerDPMAdaptative": SamplerDPMAdaptative,
"SplitSigmas": SplitSigmas, "SplitSigmas": SplitSigmas,
"FlipSigmas": FlipSigmas, "FlipSigmas": FlipSigmas,
} }

42
comfy_extras/nodes_differential_diffusion.py

@ -0,0 +1,42 @@
# code adapted from https://github.com/exx8/differential-diffusion
import torch
class DifferentialDiffusion():
@classmethod
def INPUT_TYPES(s):
return {"required": {"model": ("MODEL", ),
}}
RETURN_TYPES = ("MODEL",)
FUNCTION = "apply"
CATEGORY = "_for_testing"
INIT = False
def apply(self, model):
model = model.clone()
model.set_model_denoise_mask_function(self.forward)
return (model,)
def forward(self, sigma: torch.Tensor, denoise_mask: torch.Tensor, extra_options: dict):
model = extra_options["model"]
step_sigmas = extra_options["sigmas"]
sigma_to = model.inner_model.model_sampling.sigma_min
if step_sigmas[-1] > sigma_to:
sigma_to = step_sigmas[-1]
sigma_from = step_sigmas[0]
ts_from = model.inner_model.model_sampling.timestep(sigma_from)
ts_to = model.inner_model.model_sampling.timestep(sigma_to)
current_ts = model.inner_model.model_sampling.timestep(sigma[0])
threshold = (current_ts - ts_to) / (ts_from - ts_to)
return (denoise_mask >= threshold).to(denoise_mask.dtype)
NODE_CLASS_MAPPINGS = {
"DifferentialDiffusion": DifferentialDiffusion,
}
NODE_DISPLAY_NAME_MAPPINGS = {
"DifferentialDiffusion": "Differential Diffusion",
}

10
comfy_extras/nodes_freelunch.py

@ -1,7 +1,7 @@
#code originally taken from: https://github.com/ChenyangSi/FreeU (under MIT License) #code originally taken from: https://github.com/ChenyangSi/FreeU (under MIT License)
import torch import torch
import logging
def Fourier_filter(x, threshold, scale): def Fourier_filter(x, threshold, scale):
# FFT # FFT
@ -34,7 +34,7 @@ class FreeU:
RETURN_TYPES = ("MODEL",) RETURN_TYPES = ("MODEL",)
FUNCTION = "patch" FUNCTION = "patch"
CATEGORY = "_for_testing" CATEGORY = "model_patches"
def patch(self, model, b1, b2, s1, s2): def patch(self, model, b1, b2, s1, s2):
model_channels = model.model.model_config.unet_config["model_channels"] model_channels = model.model.model_config.unet_config["model_channels"]
@ -49,7 +49,7 @@ class FreeU:
try: try:
hsp = Fourier_filter(hsp, threshold=1, scale=scale[1]) hsp = Fourier_filter(hsp, threshold=1, scale=scale[1])
except: except:
print("Device", hsp.device, "does not support the torch.fft functions used in the FreeU node, switching to CPU.") logging.warning("Device {} does not support the torch.fft functions used in the FreeU node, switching to CPU.".format(hsp.device))
on_cpu_devices[hsp.device] = True on_cpu_devices[hsp.device] = True
hsp = Fourier_filter(hsp.cpu(), threshold=1, scale=scale[1]).to(hsp.device) hsp = Fourier_filter(hsp.cpu(), threshold=1, scale=scale[1]).to(hsp.device)
else: else:
@ -73,7 +73,7 @@ class FreeU_V2:
RETURN_TYPES = ("MODEL",) RETURN_TYPES = ("MODEL",)
FUNCTION = "patch" FUNCTION = "patch"
CATEGORY = "_for_testing" CATEGORY = "model_patches"
def patch(self, model, b1, b2, s1, s2): def patch(self, model, b1, b2, s1, s2):
model_channels = model.model.model_config.unet_config["model_channels"] model_channels = model.model.model_config.unet_config["model_channels"]
@ -95,7 +95,7 @@ class FreeU_V2:
try: try:
hsp = Fourier_filter(hsp, threshold=1, scale=scale[1]) hsp = Fourier_filter(hsp, threshold=1, scale=scale[1])
except: except:
print("Device", hsp.device, "does not support the torch.fft functions used in the FreeU node, switching to CPU.") logging.warning("Device {} does not support the torch.fft functions used in the FreeU node, switching to CPU.".format(hsp.device))
on_cpu_devices[hsp.device] = True on_cpu_devices[hsp.device] = True
hsp = Fourier_filter(hsp.cpu(), threshold=1, scale=scale[1]).to(hsp.device) hsp = Fourier_filter(hsp.cpu(), threshold=1, scale=scale[1]).to(hsp.device)
else: else:

3
comfy_extras/nodes_hypernetwork.py

@ -1,6 +1,7 @@
import comfy.utils import comfy.utils
import folder_paths import folder_paths
import torch import torch
import logging
def load_hypernetwork_patch(path, strength): def load_hypernetwork_patch(path, strength):
sd = comfy.utils.load_torch_file(path, safe_load=True) sd = comfy.utils.load_torch_file(path, safe_load=True)
@ -23,7 +24,7 @@ def load_hypernetwork_patch(path, strength):
} }
if activation_func not in valid_activation: if activation_func not in valid_activation:
print("Unsupported Hypernetwork format, if you report it I might implement it.", path, " ", activation_func, is_layer_norm, use_dropout, activate_output, last_layer_dropout) logging.error("Unsupported Hypernetwork format, if you report it I might implement it. {} {} {} {} {} {}".format(path, activation_func, is_layer_norm, use_dropout, activate_output, last_layer_dropout))
return None return None
out = {} out = {}

36
comfy_extras/nodes_hypertile.py

@ -2,9 +2,10 @@
import math import math
from einops import rearrange from einops import rearrange
import random # Use torch rng for consistency across generations
from torch import randint
def random_divisor(value: int, min_value: int, /, max_options: int = 1, counter = 0) -> int: def random_divisor(value: int, min_value: int, /, max_options: int = 1) -> int:
min_value = min(min_value, value) min_value = min(min_value, value)
# All big divisors of value (inclusive) # All big divisors of value (inclusive)
@ -12,8 +13,10 @@ def random_divisor(value: int, min_value: int, /, max_options: int = 1, counter
ns = [value // i for i in divisors[:max_options]] # has at least 1 element ns = [value // i for i in divisors[:max_options]] # has at least 1 element
random.seed(counter) if len(ns) - 1 > 0:
idx = random.randint(0, len(ns) - 1) idx = randint(low=0, high=len(ns) - 1, size=(1,)).item()
else:
idx = 0
return ns[idx] return ns[idx]
@ -29,34 +32,31 @@ class HyperTile:
RETURN_TYPES = ("MODEL",) RETURN_TYPES = ("MODEL",)
FUNCTION = "patch" FUNCTION = "patch"
CATEGORY = "_for_testing" CATEGORY = "model_patches"
def patch(self, model, tile_size, swap_size, max_depth, scale_depth): def patch(self, model, tile_size, swap_size, max_depth, scale_depth):
model_channels = model.model.model_config.unet_config["model_channels"] model_channels = model.model.model_config.unet_config["model_channels"]
apply_to = set()
temp = model_channels
for x in range(max_depth + 1):
apply_to.add(temp)
temp *= 2
latent_tile_size = max(32, tile_size) // 8 latent_tile_size = max(32, tile_size) // 8
self.temp = None self.temp = None
self.counter = 1
def hypertile_in(q, k, v, extra_options): def hypertile_in(q, k, v, extra_options):
if q.shape[-1] in apply_to: model_chans = q.shape[-2]
orig_shape = extra_options['original_shape']
apply_to = []
for i in range(max_depth + 1):
apply_to.append((orig_shape[-2] / (2 ** i)) * (orig_shape[-1] / (2 ** i)))
if model_chans in apply_to:
shape = extra_options["original_shape"] shape = extra_options["original_shape"]
aspect_ratio = shape[-1] / shape[-2] aspect_ratio = shape[-1] / shape[-2]
hw = q.size(1) hw = q.size(1)
h, w = round(math.sqrt(hw * aspect_ratio)), round(math.sqrt(hw / aspect_ratio)) h, w = round(math.sqrt(hw * aspect_ratio)), round(math.sqrt(hw / aspect_ratio))
factor = 2**((q.shape[-1] // model_channels) - 1) if scale_depth else 1 factor = (2 ** apply_to.index(model_chans)) if scale_depth else 1
nh = random_divisor(h, latent_tile_size * factor, swap_size, self.counter) nh = random_divisor(h, latent_tile_size * factor, swap_size)
self.counter += 1 nw = random_divisor(w, latent_tile_size * factor, swap_size)
nw = random_divisor(w, latent_tile_size * factor, swap_size, self.counter)
self.counter += 1
if nh * nw > 1: if nh * nw > 1:
q = rearrange(q, "b (nh h nw w) c -> (b nh nw) (h w) c", h=h // nh, w=w // nw, nh=nh, nw=nw) q = rearrange(q, "b (nh h nw w) c -> (b nh nw) (h w) c", h=h // nh, w=w // nw, nh=nh, nw=nw)

26
comfy_extras/nodes_images.py

@ -37,7 +37,7 @@ class RepeatImageBatch:
@classmethod @classmethod
def INPUT_TYPES(s): def INPUT_TYPES(s):
return {"required": { "image": ("IMAGE",), return {"required": { "image": ("IMAGE",),
"amount": ("INT", {"default": 1, "min": 1, "max": 64}), "amount": ("INT", {"default": 1, "min": 1, "max": 4096}),
}} }}
RETURN_TYPES = ("IMAGE",) RETURN_TYPES = ("IMAGE",)
FUNCTION = "repeat" FUNCTION = "repeat"
@ -48,6 +48,25 @@ class RepeatImageBatch:
s = image.repeat((amount, 1,1,1)) s = image.repeat((amount, 1,1,1))
return (s,) return (s,)
class ImageFromBatch:
@classmethod
def INPUT_TYPES(s):
return {"required": { "image": ("IMAGE",),
"batch_index": ("INT", {"default": 0, "min": 0, "max": 4095}),
"length": ("INT", {"default": 1, "min": 1, "max": 4096}),
}}
RETURN_TYPES = ("IMAGE",)
FUNCTION = "frombatch"
CATEGORY = "image/batch"
def frombatch(self, image, batch_index, length):
s_in = image
batch_index = min(s_in.shape[0] - 1, batch_index)
length = min(s_in.shape[0] - batch_index, length)
s = s_in[batch_index:batch_index + length].clone()
return (s,)
class SaveAnimatedWEBP: class SaveAnimatedWEBP:
def __init__(self): def __init__(self):
self.output_dir = folder_paths.get_output_directory() self.output_dir = folder_paths.get_output_directory()
@ -74,7 +93,7 @@ class SaveAnimatedWEBP:
OUTPUT_NODE = True OUTPUT_NODE = True
CATEGORY = "_for_testing" CATEGORY = "image/animation"
def save_images(self, images, fps, filename_prefix, lossless, quality, method, num_frames=0, prompt=None, extra_pnginfo=None): def save_images(self, images, fps, filename_prefix, lossless, quality, method, num_frames=0, prompt=None, extra_pnginfo=None):
method = self.methods.get(method) method = self.methods.get(method)
@ -136,7 +155,7 @@ class SaveAnimatedPNG:
OUTPUT_NODE = True OUTPUT_NODE = True
CATEGORY = "_for_testing" CATEGORY = "image/animation"
def save_images(self, images, fps, compress_level, filename_prefix="ComfyUI", prompt=None, extra_pnginfo=None): def save_images(self, images, fps, compress_level, filename_prefix="ComfyUI", prompt=None, extra_pnginfo=None):
filename_prefix += self.prefix_append filename_prefix += self.prefix_append
@ -170,6 +189,7 @@ class SaveAnimatedPNG:
NODE_CLASS_MAPPINGS = { NODE_CLASS_MAPPINGS = {
"ImageCrop": ImageCrop, "ImageCrop": ImageCrop,
"RepeatImageBatch": RepeatImageBatch, "RepeatImageBatch": RepeatImageBatch,
"ImageFromBatch": ImageFromBatch,
"SaveAnimatedWEBP": SaveAnimatedWEBP, "SaveAnimatedWEBP": SaveAnimatedWEBP,
"SaveAnimatedPNG": SaveAnimatedPNG, "SaveAnimatedPNG": SaveAnimatedPNG,
} }

49
comfy_extras/nodes_latent.py

@ -3,9 +3,7 @@ import torch
def reshape_latent_to(target_shape, latent): def reshape_latent_to(target_shape, latent):
if latent.shape[1:] != target_shape[1:]: if latent.shape[1:] != target_shape[1:]:
latent.movedim(1, -1)
latent = comfy.utils.common_upscale(latent, target_shape[3], target_shape[2], "bilinear", "center") latent = comfy.utils.common_upscale(latent, target_shape[3], target_shape[2], "bilinear", "center")
latent.movedim(-1, 1)
return comfy.utils.repeat_to_batch_size(latent, target_shape[0]) return comfy.utils.repeat_to_batch_size(latent, target_shape[0])
@ -102,9 +100,56 @@ class LatentInterpolate:
samples_out["samples"] = st * (m1 * ratio + m2 * (1.0 - ratio)) samples_out["samples"] = st * (m1 * ratio + m2 * (1.0 - ratio))
return (samples_out,) return (samples_out,)
class LatentBatch:
@classmethod
def INPUT_TYPES(s):
return {"required": { "samples1": ("LATENT",), "samples2": ("LATENT",)}}
RETURN_TYPES = ("LATENT",)
FUNCTION = "batch"
CATEGORY = "latent/batch"
def batch(self, samples1, samples2):
samples_out = samples1.copy()
s1 = samples1["samples"]
s2 = samples2["samples"]
if s1.shape[1:] != s2.shape[1:]:
s2 = comfy.utils.common_upscale(s2, s1.shape[3], s1.shape[2], "bilinear", "center")
s = torch.cat((s1, s2), dim=0)
samples_out["samples"] = s
samples_out["batch_index"] = samples1.get("batch_index", [x for x in range(0, s1.shape[0])]) + samples2.get("batch_index", [x for x in range(0, s2.shape[0])])
return (samples_out,)
class LatentBatchSeedBehavior:
@classmethod
def INPUT_TYPES(s):
return {"required": { "samples": ("LATENT",),
"seed_behavior": (["random", "fixed"],{"default": "fixed"}),}}
RETURN_TYPES = ("LATENT",)
FUNCTION = "op"
CATEGORY = "latent/advanced"
def op(self, samples, seed_behavior):
samples_out = samples.copy()
latent = samples["samples"]
if seed_behavior == "random":
if 'batch_index' in samples_out:
samples_out.pop('batch_index')
elif seed_behavior == "fixed":
batch_number = samples_out.get("batch_index", [0])[0]
samples_out["batch_index"] = [batch_number] * latent.shape[0]
return (samples_out,)
NODE_CLASS_MAPPINGS = { NODE_CLASS_MAPPINGS = {
"LatentAdd": LatentAdd, "LatentAdd": LatentAdd,
"LatentSubtract": LatentSubtract, "LatentSubtract": LatentSubtract,
"LatentMultiply": LatentMultiply, "LatentMultiply": LatentMultiply,
"LatentInterpolate": LatentInterpolate, "LatentInterpolate": LatentInterpolate,
"LatentBatch": LatentBatch,
"LatentBatchSeedBehavior": LatentBatchSeedBehavior,
} }

22
comfy_extras/nodes_mask.py

@ -6,6 +6,7 @@ import comfy.utils
from nodes import MAX_RESOLUTION from nodes import MAX_RESOLUTION
def composite(destination, source, x, y, mask = None, multiplier = 8, resize_source = False): def composite(destination, source, x, y, mask = None, multiplier = 8, resize_source = False):
source = source.to(destination.device)
if resize_source: if resize_source:
source = torch.nn.functional.interpolate(source, size=(destination.shape[2], destination.shape[3]), mode="bilinear") source = torch.nn.functional.interpolate(source, size=(destination.shape[2], destination.shape[3]), mode="bilinear")
@ -20,7 +21,7 @@ def composite(destination, source, x, y, mask = None, multiplier = 8, resize_sou
if mask is None: if mask is None:
mask = torch.ones_like(source) mask = torch.ones_like(source)
else: else:
mask = mask.clone() mask = mask.to(destination.device, copy=True)
mask = torch.nn.functional.interpolate(mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1])), size=(source.shape[2], source.shape[3]), mode="bilinear") mask = torch.nn.functional.interpolate(mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1])), size=(source.shape[2], source.shape[3]), mode="bilinear")
mask = comfy.utils.repeat_to_batch_size(mask, source.shape[0]) mask = comfy.utils.repeat_to_batch_size(mask, source.shape[0])
@ -340,6 +341,24 @@ class GrowMask:
out.append(output) out.append(output)
return (torch.stack(out, dim=0),) return (torch.stack(out, dim=0),)
class ThresholdMask:
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"mask": ("MASK",),
"value": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 1.0, "step": 0.01}),
}
}
CATEGORY = "mask"
RETURN_TYPES = ("MASK",)
FUNCTION = "image_to_mask"
def image_to_mask(self, mask, value):
mask = (mask > value).float()
return (mask,)
NODE_CLASS_MAPPINGS = { NODE_CLASS_MAPPINGS = {
@ -354,6 +373,7 @@ NODE_CLASS_MAPPINGS = {
"MaskComposite": MaskComposite, "MaskComposite": MaskComposite,
"FeatherMask": FeatherMask, "FeatherMask": FeatherMask,
"GrowMask": GrowMask, "GrowMask": GrowMask,
"ThresholdMask": ThresholdMask,
} }
NODE_DISPLAY_NAME_MAPPINGS = { NODE_DISPLAY_NAME_MAPPINGS = {

96
comfy_extras/nodes_model_advanced.py

@ -1,6 +1,7 @@
import folder_paths import folder_paths
import comfy.sd import comfy.sd
import comfy.model_sampling import comfy.model_sampling
import comfy.latent_formats
import torch import torch
class LCM(comfy.model_sampling.EPS): class LCM(comfy.model_sampling.EPS):
@ -17,41 +18,23 @@ class LCM(comfy.model_sampling.EPS):
return c_out * x0 + c_skip * model_input return c_out * x0 + c_skip * model_input
class ModelSamplingDiscreteDistilled(torch.nn.Module): class X0(comfy.model_sampling.EPS):
original_timesteps = 50 def calculate_denoised(self, sigma, model_output, model_input):
return model_output
def __init__(self):
super().__init__()
self.sigma_data = 1.0
timesteps = 1000
beta_start = 0.00085
beta_end = 0.012
betas = torch.linspace(beta_start**0.5, beta_end**0.5, timesteps, dtype=torch.float32) ** 2 class ModelSamplingDiscreteDistilled(comfy.model_sampling.ModelSamplingDiscrete):
alphas = 1.0 - betas original_timesteps = 50
alphas_cumprod = torch.cumprod(alphas, dim=0)
self.skip_steps = timesteps // self.original_timesteps def __init__(self, model_config=None):
super().__init__(model_config)
self.skip_steps = self.num_timesteps // self.original_timesteps
alphas_cumprod_valid = torch.zeros((self.original_timesteps), dtype=torch.float32) sigmas_valid = torch.zeros((self.original_timesteps), dtype=torch.float32)
for x in range(self.original_timesteps): for x in range(self.original_timesteps):
alphas_cumprod_valid[self.original_timesteps - 1 - x] = alphas_cumprod[timesteps - 1 - x * self.skip_steps] sigmas_valid[self.original_timesteps - 1 - x] = self.sigmas[self.num_timesteps - 1 - x * self.skip_steps]
sigmas = ((1 - alphas_cumprod_valid) / alphas_cumprod_valid) ** 0.5
self.set_sigmas(sigmas)
def set_sigmas(self, sigmas):
self.register_buffer('sigmas', sigmas)
self.register_buffer('log_sigmas', sigmas.log())
@property
def sigma_min(self):
return self.sigmas[0]
@property self.set_sigmas(sigmas_valid)
def sigma_max(self):
return self.sigmas[-1]
def timestep(self, sigma): def timestep(self, sigma):
log_sigma = sigma.log() log_sigma = sigma.log()
@ -66,14 +49,6 @@ class ModelSamplingDiscreteDistilled(torch.nn.Module):
log_sigma = (1 - w) * self.log_sigmas[low_idx] + w * self.log_sigmas[high_idx] log_sigma = (1 - w) * self.log_sigmas[low_idx] + w * self.log_sigmas[high_idx]
return log_sigma.exp().to(timestep.device) return log_sigma.exp().to(timestep.device)
def percent_to_sigma(self, percent):
if percent <= 0.0:
return 999999999.9
if percent >= 1.0:
return 0.0
percent = 1.0 - percent
return self.sigma(torch.tensor(percent * 999.0)).item()
def rescale_zero_terminal_snr_sigmas(sigmas): def rescale_zero_terminal_snr_sigmas(sigmas):
alphas_cumprod = 1 / ((sigmas * sigmas) + 1) alphas_cumprod = 1 / ((sigmas * sigmas) + 1)
@ -98,7 +73,7 @@ class ModelSamplingDiscrete:
@classmethod @classmethod
def INPUT_TYPES(s): def INPUT_TYPES(s):
return {"required": { "model": ("MODEL",), return {"required": { "model": ("MODEL",),
"sampling": (["eps", "v_prediction", "lcm"],), "sampling": (["eps", "v_prediction", "lcm", "x0"],),
"zsnr": ("BOOLEAN", {"default": False}), "zsnr": ("BOOLEAN", {"default": False}),
}} }}
@ -118,22 +93,50 @@ class ModelSamplingDiscrete:
elif sampling == "lcm": elif sampling == "lcm":
sampling_type = LCM sampling_type = LCM
sampling_base = ModelSamplingDiscreteDistilled sampling_base = ModelSamplingDiscreteDistilled
elif sampling == "x0":
sampling_type = X0
class ModelSamplingAdvanced(sampling_base, sampling_type): class ModelSamplingAdvanced(sampling_base, sampling_type):
pass pass
model_sampling = ModelSamplingAdvanced() model_sampling = ModelSamplingAdvanced(model.model.model_config)
if zsnr: if zsnr:
model_sampling.set_sigmas(rescale_zero_terminal_snr_sigmas(model_sampling.sigmas)) model_sampling.set_sigmas(rescale_zero_terminal_snr_sigmas(model_sampling.sigmas))
m.add_object_patch("model_sampling", model_sampling) m.add_object_patch("model_sampling", model_sampling)
return (m, ) return (m, )
class ModelSamplingStableCascade:
@classmethod
def INPUT_TYPES(s):
return {"required": { "model": ("MODEL",),
"shift": ("FLOAT", {"default": 2.0, "min": 0.0, "max": 100.0, "step":0.01}),
}}
RETURN_TYPES = ("MODEL",)
FUNCTION = "patch"
CATEGORY = "advanced/model"
def patch(self, model, shift):
m = model.clone()
sampling_base = comfy.model_sampling.StableCascadeSampling
sampling_type = comfy.model_sampling.EPS
class ModelSamplingAdvanced(sampling_base, sampling_type):
pass
model_sampling = ModelSamplingAdvanced(model.model.model_config)
model_sampling.set_parameters(shift)
m.add_object_patch("model_sampling", model_sampling)
return (m, )
class ModelSamplingContinuousEDM: class ModelSamplingContinuousEDM:
@classmethod @classmethod
def INPUT_TYPES(s): def INPUT_TYPES(s):
return {"required": { "model": ("MODEL",), return {"required": { "model": ("MODEL",),
"sampling": (["v_prediction", "eps"],), "sampling": (["v_prediction", "edm_playground_v2.5", "eps"],),
"sigma_max": ("FLOAT", {"default": 120.0, "min": 0.0, "max": 1000.0, "step":0.001, "round": False}), "sigma_max": ("FLOAT", {"default": 120.0, "min": 0.0, "max": 1000.0, "step":0.001, "round": False}),
"sigma_min": ("FLOAT", {"default": 0.002, "min": 0.0, "max": 1000.0, "step":0.001, "round": False}), "sigma_min": ("FLOAT", {"default": 0.002, "min": 0.0, "max": 1000.0, "step":0.001, "round": False}),
}} }}
@ -146,17 +149,25 @@ class ModelSamplingContinuousEDM:
def patch(self, model, sampling, sigma_max, sigma_min): def patch(self, model, sampling, sigma_max, sigma_min):
m = model.clone() m = model.clone()
latent_format = None
sigma_data = 1.0
if sampling == "eps": if sampling == "eps":
sampling_type = comfy.model_sampling.EPS sampling_type = comfy.model_sampling.EPS
elif sampling == "v_prediction": elif sampling == "v_prediction":
sampling_type = comfy.model_sampling.V_PREDICTION sampling_type = comfy.model_sampling.V_PREDICTION
elif sampling == "edm_playground_v2.5":
sampling_type = comfy.model_sampling.EDM
sigma_data = 0.5
latent_format = comfy.latent_formats.SDXL_Playground_2_5()
class ModelSamplingAdvanced(comfy.model_sampling.ModelSamplingContinuousEDM, sampling_type): class ModelSamplingAdvanced(comfy.model_sampling.ModelSamplingContinuousEDM, sampling_type):
pass pass
model_sampling = ModelSamplingAdvanced() model_sampling = ModelSamplingAdvanced(model.model.model_config)
model_sampling.set_sigma_range(sigma_min, sigma_max) model_sampling.set_parameters(sigma_min, sigma_max, sigma_data)
m.add_object_patch("model_sampling", model_sampling) m.add_object_patch("model_sampling", model_sampling)
if latent_format is not None:
m.add_object_patch("latent_format", latent_format)
return (m, ) return (m, )
class RescaleCFG: class RescaleCFG:
@ -201,5 +212,6 @@ class RescaleCFG:
NODE_CLASS_MAPPINGS = { NODE_CLASS_MAPPINGS = {
"ModelSamplingDiscrete": ModelSamplingDiscrete, "ModelSamplingDiscrete": ModelSamplingDiscrete,
"ModelSamplingContinuousEDM": ModelSamplingContinuousEDM, "ModelSamplingContinuousEDM": ModelSamplingContinuousEDM,
"ModelSamplingStableCascade": ModelSamplingStableCascade,
"RescaleCFG": RescaleCFG, "RescaleCFG": RescaleCFG,
} }

129
comfy_extras/nodes_model_merging.py

@ -87,6 +87,50 @@ class CLIPMergeSimple:
m.add_patches({k: kp[k]}, 1.0 - ratio, ratio) m.add_patches({k: kp[k]}, 1.0 - ratio, ratio)
return (m, ) return (m, )
class CLIPSubtract:
@classmethod
def INPUT_TYPES(s):
return {"required": { "clip1": ("CLIP",),
"clip2": ("CLIP",),
"multiplier": ("FLOAT", {"default": 1.0, "min": -10.0, "max": 10.0, "step": 0.01}),
}}
RETURN_TYPES = ("CLIP",)
FUNCTION = "merge"
CATEGORY = "advanced/model_merging"
def merge(self, clip1, clip2, multiplier):
m = clip1.clone()
kp = clip2.get_key_patches()
for k in kp:
if k.endswith(".position_ids") or k.endswith(".logit_scale"):
continue
m.add_patches({k: kp[k]}, - multiplier, multiplier)
return (m, )
class CLIPAdd:
@classmethod
def INPUT_TYPES(s):
return {"required": { "clip1": ("CLIP",),
"clip2": ("CLIP",),
}}
RETURN_TYPES = ("CLIP",)
FUNCTION = "merge"
CATEGORY = "advanced/model_merging"
def merge(self, clip1, clip2):
m = clip1.clone()
kp = clip2.get_key_patches()
for k in kp:
if k.endswith(".position_ids") or k.endswith(".logit_scale"):
continue
m.add_patches({k: kp[k]}, 1.0, 1.0)
return (m, )
class ModelMergeBlocks: class ModelMergeBlocks:
@classmethod @classmethod
def INPUT_TYPES(s): def INPUT_TYPES(s):
@ -119,6 +163,48 @@ class ModelMergeBlocks:
m.add_patches({k: kp[k]}, 1.0 - ratio, ratio) m.add_patches({k: kp[k]}, 1.0 - ratio, ratio)
return (m, ) return (m, )
def save_checkpoint(model, clip=None, vae=None, clip_vision=None, filename_prefix=None, output_dir=None, prompt=None, extra_pnginfo=None):
full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, output_dir)
prompt_info = ""
if prompt is not None:
prompt_info = json.dumps(prompt)
metadata = {}
enable_modelspec = True
if isinstance(model.model, comfy.model_base.SDXL):
metadata["modelspec.architecture"] = "stable-diffusion-xl-v1-base"
elif isinstance(model.model, comfy.model_base.SDXLRefiner):
metadata["modelspec.architecture"] = "stable-diffusion-xl-v1-refiner"
else:
enable_modelspec = False
if enable_modelspec:
metadata["modelspec.sai_model_spec"] = "1.0.0"
metadata["modelspec.implementation"] = "sgm"
metadata["modelspec.title"] = "{} {}".format(filename, counter)
#TODO:
# "stable-diffusion-v1", "stable-diffusion-v1-inpainting", "stable-diffusion-v2-512",
# "stable-diffusion-v2-768-v", "stable-diffusion-v2-unclip-l", "stable-diffusion-v2-unclip-h",
# "v2-inpainting"
if model.model.model_type == comfy.model_base.ModelType.EPS:
metadata["modelspec.predict_key"] = "epsilon"
elif model.model.model_type == comfy.model_base.ModelType.V_PREDICTION:
metadata["modelspec.predict_key"] = "v"
if not args.disable_metadata:
metadata["prompt"] = prompt_info
if extra_pnginfo is not None:
for x in extra_pnginfo:
metadata[x] = json.dumps(extra_pnginfo[x])
output_checkpoint = f"{filename}_{counter:05}_.safetensors"
output_checkpoint = os.path.join(full_output_folder, output_checkpoint)
comfy.sd.save_checkpoint(output_checkpoint, model, clip, vae, clip_vision, metadata=metadata)
class CheckpointSave: class CheckpointSave:
def __init__(self): def __init__(self):
self.output_dir = folder_paths.get_output_directory() self.output_dir = folder_paths.get_output_directory()
@ -137,46 +223,7 @@ class CheckpointSave:
CATEGORY = "advanced/model_merging" CATEGORY = "advanced/model_merging"
def save(self, model, clip, vae, filename_prefix, prompt=None, extra_pnginfo=None): def save(self, model, clip, vae, filename_prefix, prompt=None, extra_pnginfo=None):
full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, self.output_dir) save_checkpoint(model, clip=clip, vae=vae, filename_prefix=filename_prefix, output_dir=self.output_dir, prompt=prompt, extra_pnginfo=extra_pnginfo)
prompt_info = ""
if prompt is not None:
prompt_info = json.dumps(prompt)
metadata = {}
enable_modelspec = True
if isinstance(model.model, comfy.model_base.SDXL):
metadata["modelspec.architecture"] = "stable-diffusion-xl-v1-base"
elif isinstance(model.model, comfy.model_base.SDXLRefiner):
metadata["modelspec.architecture"] = "stable-diffusion-xl-v1-refiner"
else:
enable_modelspec = False
if enable_modelspec:
metadata["modelspec.sai_model_spec"] = "1.0.0"
metadata["modelspec.implementation"] = "sgm"
metadata["modelspec.title"] = "{} {}".format(filename, counter)
#TODO:
# "stable-diffusion-v1", "stable-diffusion-v1-inpainting", "stable-diffusion-v2-512",
# "stable-diffusion-v2-768-v", "stable-diffusion-v2-unclip-l", "stable-diffusion-v2-unclip-h",
# "v2-inpainting"
if model.model.model_type == comfy.model_base.ModelType.EPS:
metadata["modelspec.predict_key"] = "epsilon"
elif model.model.model_type == comfy.model_base.ModelType.V_PREDICTION:
metadata["modelspec.predict_key"] = "v"
if not args.disable_metadata:
metadata["prompt"] = prompt_info
if extra_pnginfo is not None:
for x in extra_pnginfo:
metadata[x] = json.dumps(extra_pnginfo[x])
output_checkpoint = f"{filename}_{counter:05}_.safetensors"
output_checkpoint = os.path.join(full_output_folder, output_checkpoint)
comfy.sd.save_checkpoint(output_checkpoint, model, clip, vae, metadata=metadata)
return {} return {}
class CLIPSave: class CLIPSave:
@ -276,6 +323,8 @@ NODE_CLASS_MAPPINGS = {
"ModelMergeAdd": ModelAdd, "ModelMergeAdd": ModelAdd,
"CheckpointSave": CheckpointSave, "CheckpointSave": CheckpointSave,
"CLIPMergeSimple": CLIPMergeSimple, "CLIPMergeSimple": CLIPMergeSimple,
"CLIPMergeSubtract": CLIPSubtract,
"CLIPMergeAdd": CLIPAdd,
"CLIPSave": CLIPSave, "CLIPSave": CLIPSave,
"VAESave": VAESave, "VAESave": VAESave,
} }

49
comfy_extras/nodes_morphology.py

@ -0,0 +1,49 @@
import torch
import comfy.model_management
from kornia.morphology import dilation, erosion, opening, closing, gradient, top_hat, bottom_hat
class Morphology:
@classmethod
def INPUT_TYPES(s):
return {"required": {"image": ("IMAGE",),
"operation": (["erode", "dilate", "open", "close", "gradient", "bottom_hat", "top_hat"],),
"kernel_size": ("INT", {"default": 3, "min": 3, "max": 999, "step": 1}),
}}
RETURN_TYPES = ("IMAGE",)
FUNCTION = "process"
CATEGORY = "image/postprocessing"
def process(self, image, operation, kernel_size):
device = comfy.model_management.get_torch_device()
kernel = torch.ones(kernel_size, kernel_size, device=device)
image_k = image.to(device).movedim(-1, 1)
if operation == "erode":
output = erosion(image_k, kernel)
elif operation == "dilate":
output = dilation(image_k, kernel)
elif operation == "open":
output = opening(image_k, kernel)
elif operation == "close":
output = closing(image_k, kernel)
elif operation == "gradient":
output = gradient(image_k, kernel)
elif operation == "top_hat":
output = top_hat(image_k, kernel)
elif operation == "bottom_hat":
output = bottom_hat(image_k, kernel)
else:
raise ValueError(f"Invalid operation {operation} for morphology. Must be one of 'erode', 'dilate', 'open', 'close', 'gradient', 'tophat', 'bottomhat'")
img_out = output.to(comfy.model_management.intermediate_device()).movedim(1, -1)
return (img_out,)
NODE_CLASS_MAPPINGS = {
"Morphology": Morphology,
}
NODE_DISPLAY_NAME_MAPPINGS = {
"Morphology": "ImageMorphology",
}

55
comfy_extras/nodes_perpneg.py

@ -0,0 +1,55 @@
import torch
import comfy.model_management
import comfy.sample
import comfy.samplers
import comfy.utils
class PerpNeg:
@classmethod
def INPUT_TYPES(s):
return {"required": {"model": ("MODEL", ),
"empty_conditioning": ("CONDITIONING", ),
"neg_scale": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0, "step": 0.01}),
}}
RETURN_TYPES = ("MODEL",)
FUNCTION = "patch"
CATEGORY = "_for_testing"
def patch(self, model, empty_conditioning, neg_scale):
m = model.clone()
nocond = comfy.sample.convert_cond(empty_conditioning)
def cfg_function(args):
model = args["model"]
noise_pred_pos = args["cond_denoised"]
noise_pred_neg = args["uncond_denoised"]
cond_scale = args["cond_scale"]
x = args["input"]
sigma = args["sigma"]
model_options = args["model_options"]
nocond_processed = comfy.samplers.encode_model_conds(model.extra_conds, nocond, x, x.device, "negative")
(noise_pred_nocond, _) = comfy.samplers.calc_cond_uncond_batch(model, nocond_processed, None, x, sigma, model_options)
pos = noise_pred_pos - noise_pred_nocond
neg = noise_pred_neg - noise_pred_nocond
perp = neg - ((torch.mul(neg, pos).sum())/(torch.norm(pos)**2)) * pos
perp_neg = perp * neg_scale
cfg_result = noise_pred_nocond + cond_scale*(pos - perp_neg)
cfg_result = x - cfg_result
return cfg_result
m.set_model_sampler_cfg_function(cfg_function)
return (m, )
NODE_CLASS_MAPPINGS = {
"PerpNeg": PerpNeg,
}
NODE_DISPLAY_NAME_MAPPINGS = {
"PerpNeg": "Perp-Neg",
}

187
comfy_extras/nodes_photomaker.py

@ -0,0 +1,187 @@
import torch
import torch.nn as nn
import folder_paths
import comfy.clip_model
import comfy.clip_vision
import comfy.ops
# code for model from: https://github.com/TencentARC/PhotoMaker/blob/main/photomaker/model.py under Apache License Version 2.0
VISION_CONFIG_DICT = {
"hidden_size": 1024,
"image_size": 224,
"intermediate_size": 4096,
"num_attention_heads": 16,
"num_channels": 3,
"num_hidden_layers": 24,
"patch_size": 14,
"projection_dim": 768,
"hidden_act": "quick_gelu",
}
class MLP(nn.Module):
def __init__(self, in_dim, out_dim, hidden_dim, use_residual=True, operations=comfy.ops):
super().__init__()
if use_residual:
assert in_dim == out_dim
self.layernorm = operations.LayerNorm(in_dim)
self.fc1 = operations.Linear(in_dim, hidden_dim)
self.fc2 = operations.Linear(hidden_dim, out_dim)
self.use_residual = use_residual
self.act_fn = nn.GELU()
def forward(self, x):
residual = x
x = self.layernorm(x)
x = self.fc1(x)
x = self.act_fn(x)
x = self.fc2(x)
if self.use_residual:
x = x + residual
return x
class FuseModule(nn.Module):
def __init__(self, embed_dim, operations):
super().__init__()
self.mlp1 = MLP(embed_dim * 2, embed_dim, embed_dim, use_residual=False, operations=operations)
self.mlp2 = MLP(embed_dim, embed_dim, embed_dim, use_residual=True, operations=operations)
self.layer_norm = operations.LayerNorm(embed_dim)
def fuse_fn(self, prompt_embeds, id_embeds):
stacked_id_embeds = torch.cat([prompt_embeds, id_embeds], dim=-1)
stacked_id_embeds = self.mlp1(stacked_id_embeds) + prompt_embeds
stacked_id_embeds = self.mlp2(stacked_id_embeds)
stacked_id_embeds = self.layer_norm(stacked_id_embeds)
return stacked_id_embeds
def forward(
self,
prompt_embeds,
id_embeds,
class_tokens_mask,
) -> torch.Tensor:
# id_embeds shape: [b, max_num_inputs, 1, 2048]
id_embeds = id_embeds.to(prompt_embeds.dtype)
num_inputs = class_tokens_mask.sum().unsqueeze(0) # TODO: check for training case
batch_size, max_num_inputs = id_embeds.shape[:2]
# seq_length: 77
seq_length = prompt_embeds.shape[1]
# flat_id_embeds shape: [b*max_num_inputs, 1, 2048]
flat_id_embeds = id_embeds.view(
-1, id_embeds.shape[-2], id_embeds.shape[-1]
)
# valid_id_mask [b*max_num_inputs]
valid_id_mask = (
torch.arange(max_num_inputs, device=flat_id_embeds.device)[None, :]
< num_inputs[:, None]
)
valid_id_embeds = flat_id_embeds[valid_id_mask.flatten()]
prompt_embeds = prompt_embeds.view(-1, prompt_embeds.shape[-1])
class_tokens_mask = class_tokens_mask.view(-1)
valid_id_embeds = valid_id_embeds.view(-1, valid_id_embeds.shape[-1])
# slice out the image token embeddings
image_token_embeds = prompt_embeds[class_tokens_mask]
stacked_id_embeds = self.fuse_fn(image_token_embeds, valid_id_embeds)
assert class_tokens_mask.sum() == stacked_id_embeds.shape[0], f"{class_tokens_mask.sum()} != {stacked_id_embeds.shape[0]}"
prompt_embeds.masked_scatter_(class_tokens_mask[:, None], stacked_id_embeds.to(prompt_embeds.dtype))
updated_prompt_embeds = prompt_embeds.view(batch_size, seq_length, -1)
return updated_prompt_embeds
class PhotoMakerIDEncoder(comfy.clip_model.CLIPVisionModelProjection):
def __init__(self):
self.load_device = comfy.model_management.text_encoder_device()
offload_device = comfy.model_management.text_encoder_offload_device()
dtype = comfy.model_management.text_encoder_dtype(self.load_device)
super().__init__(VISION_CONFIG_DICT, dtype, offload_device, comfy.ops.manual_cast)
self.visual_projection_2 = comfy.ops.manual_cast.Linear(1024, 1280, bias=False)
self.fuse_module = FuseModule(2048, comfy.ops.manual_cast)
def forward(self, id_pixel_values, prompt_embeds, class_tokens_mask):
b, num_inputs, c, h, w = id_pixel_values.shape
id_pixel_values = id_pixel_values.view(b * num_inputs, c, h, w)
shared_id_embeds = self.vision_model(id_pixel_values)[2]
id_embeds = self.visual_projection(shared_id_embeds)
id_embeds_2 = self.visual_projection_2(shared_id_embeds)
id_embeds = id_embeds.view(b, num_inputs, 1, -1)
id_embeds_2 = id_embeds_2.view(b, num_inputs, 1, -1)
id_embeds = torch.cat((id_embeds, id_embeds_2), dim=-1)
updated_prompt_embeds = self.fuse_module(prompt_embeds, id_embeds, class_tokens_mask)
return updated_prompt_embeds
class PhotoMakerLoader:
@classmethod
def INPUT_TYPES(s):
return {"required": { "photomaker_model_name": (folder_paths.get_filename_list("photomaker"), )}}
RETURN_TYPES = ("PHOTOMAKER",)
FUNCTION = "load_photomaker_model"
CATEGORY = "_for_testing/photomaker"
def load_photomaker_model(self, photomaker_model_name):
photomaker_model_path = folder_paths.get_full_path("photomaker", photomaker_model_name)
photomaker_model = PhotoMakerIDEncoder()
data = comfy.utils.load_torch_file(photomaker_model_path, safe_load=True)
if "id_encoder" in data:
data = data["id_encoder"]
photomaker_model.load_state_dict(data)
return (photomaker_model,)
class PhotoMakerEncode:
@classmethod
def INPUT_TYPES(s):
return {"required": { "photomaker": ("PHOTOMAKER",),
"image": ("IMAGE",),
"clip": ("CLIP", ),
"text": ("STRING", {"multiline": True, "default": "photograph of photomaker"}),
}}
RETURN_TYPES = ("CONDITIONING",)
FUNCTION = "apply_photomaker"
CATEGORY = "_for_testing/photomaker"
def apply_photomaker(self, photomaker, image, clip, text):
special_token = "photomaker"
pixel_values = comfy.clip_vision.clip_preprocess(image.to(photomaker.load_device)).float()
try:
index = text.split(" ").index(special_token) + 1
except ValueError:
index = -1
tokens = clip.tokenize(text, return_word_ids=True)
out_tokens = {}
for k in tokens:
out_tokens[k] = []
for t in tokens[k]:
f = list(filter(lambda x: x[2] != index, t))
while len(f) < len(t):
f.append(t[-1])
out_tokens[k].append(f)
cond, pooled = clip.encode_from_tokens(out_tokens, return_pooled=True)
if index > 0:
token_index = index - 1
num_id_images = 1
class_tokens_mask = [True if token_index <= i < token_index+num_id_images else False for i in range(77)]
out = photomaker(id_pixel_values=pixel_values.unsqueeze(0), prompt_embeds=cond.to(photomaker.load_device),
class_tokens_mask=torch.tensor(class_tokens_mask, dtype=torch.bool, device=photomaker.load_device).unsqueeze(0))
else:
out = cond
return ([[out, {"pooled_output": pooled}]], )
NODE_CLASS_MAPPINGS = {
"PhotoMakerLoader": PhotoMakerLoader,
"PhotoMakerEncode": PhotoMakerEncode,
}

3
comfy_extras/nodes_post_processing.py

@ -33,6 +33,7 @@ class Blend:
CATEGORY = "image/postprocessing" CATEGORY = "image/postprocessing"
def blend_images(self, image1: torch.Tensor, image2: torch.Tensor, blend_factor: float, blend_mode: str): def blend_images(self, image1: torch.Tensor, image2: torch.Tensor, blend_factor: float, blend_mode: str):
image2 = image2.to(image1.device)
if image1.shape != image2.shape: if image1.shape != image2.shape:
image2 = image2.permute(0, 3, 1, 2) image2 = image2.permute(0, 3, 1, 2)
image2 = comfy.utils.common_upscale(image2, image1.shape[2], image1.shape[1], upscale_method='bicubic', crop='center') image2 = comfy.utils.common_upscale(image2, image1.shape[2], image1.shape[1], upscale_method='bicubic', crop='center')
@ -226,7 +227,7 @@ class Sharpen:
batch_size, height, width, channels = image.shape batch_size, height, width, channels = image.shape
kernel_size = sharpen_radius * 2 + 1 kernel_size = sharpen_radius * 2 + 1
kernel = gaussian_kernel(kernel_size, sigma) * -(alpha*10) kernel = gaussian_kernel(kernel_size, sigma, device=image.device) * -(alpha*10)
center = kernel_size // 2 center = kernel_size // 2
kernel[center, center] = kernel[center, center] - kernel.sum() + 1.0 kernel[center, center] = kernel[center, center] - kernel.sum() + 1.0
kernel = kernel.repeat(channels, 1, 1).unsqueeze(1) kernel = kernel.repeat(channels, 1, 1).unsqueeze(1)

30
comfy_extras/nodes_rebatch.py

@ -99,10 +99,40 @@ class LatentRebatch:
return (output_list,) return (output_list,)
class ImageRebatch:
@classmethod
def INPUT_TYPES(s):
return {"required": { "images": ("IMAGE",),
"batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}),
}}
RETURN_TYPES = ("IMAGE",)
INPUT_IS_LIST = True
OUTPUT_IS_LIST = (True, )
FUNCTION = "rebatch"
CATEGORY = "image/batch"
def rebatch(self, images, batch_size):
batch_size = batch_size[0]
output_list = []
all_images = []
for img in images:
for i in range(img.shape[0]):
all_images.append(img[i:i+1])
for i in range(0, len(all_images), batch_size):
output_list.append(torch.cat(all_images[i:i+batch_size], dim=0))
return (output_list,)
NODE_CLASS_MAPPINGS = { NODE_CLASS_MAPPINGS = {
"RebatchLatents": LatentRebatch, "RebatchLatents": LatentRebatch,
"RebatchImages": ImageRebatch,
} }
NODE_DISPLAY_NAME_MAPPINGS = { NODE_DISPLAY_NAME_MAPPINGS = {
"RebatchLatents": "Rebatch Latents", "RebatchLatents": "Rebatch Latents",
"RebatchImages": "Rebatch Images",
} }

170
comfy_extras/nodes_sag.py

@ -0,0 +1,170 @@
import torch
from torch import einsum
import torch.nn.functional as F
import math
from einops import rearrange, repeat
import os
from comfy.ldm.modules.attention import optimized_attention, _ATTN_PRECISION
import comfy.samplers
# from comfy/ldm/modules/attention.py
# but modified to return attention scores as well as output
def attention_basic_with_sim(q, k, v, heads, mask=None):
b, _, dim_head = q.shape
dim_head //= heads
scale = dim_head ** -0.5
h = heads
q, k, v = map(
lambda t: t.unsqueeze(3)
.reshape(b, -1, heads, dim_head)
.permute(0, 2, 1, 3)
.reshape(b * heads, -1, dim_head)
.contiguous(),
(q, k, v),
)
# force cast to fp32 to avoid overflowing
if _ATTN_PRECISION =="fp32":
sim = einsum('b i d, b j d -> b i j', q.float(), k.float()) * scale
else:
sim = einsum('b i d, b j d -> b i j', q, k) * scale
del q, k
if mask is not None:
mask = rearrange(mask, 'b ... -> b (...)')
max_neg_value = -torch.finfo(sim.dtype).max
mask = repeat(mask, 'b j -> (b h) () j', h=h)
sim.masked_fill_(~mask, max_neg_value)
# attention, what we cannot get enough of
sim = sim.softmax(dim=-1)
out = einsum('b i j, b j d -> b i d', sim.to(v.dtype), v)
out = (
out.unsqueeze(0)
.reshape(b, heads, -1, dim_head)
.permute(0, 2, 1, 3)
.reshape(b, -1, heads * dim_head)
)
return (out, sim)
def create_blur_map(x0, attn, sigma=3.0, threshold=1.0):
# reshape and GAP the attention map
_, hw1, hw2 = attn.shape
b, _, lh, lw = x0.shape
attn = attn.reshape(b, -1, hw1, hw2)
# Global Average Pool
mask = attn.mean(1, keepdim=False).sum(1, keepdim=False) > threshold
ratio = 2**(math.ceil(math.sqrt(lh * lw / hw1)) - 1).bit_length()
mid_shape = [math.ceil(lh / ratio), math.ceil(lw / ratio)]
# Reshape
mask = (
mask.reshape(b, *mid_shape)
.unsqueeze(1)
.type(attn.dtype)
)
# Upsample
mask = F.interpolate(mask, (lh, lw))
blurred = gaussian_blur_2d(x0, kernel_size=9, sigma=sigma)
blurred = blurred * mask + x0 * (1 - mask)
return blurred
def gaussian_blur_2d(img, kernel_size, sigma):
ksize_half = (kernel_size - 1) * 0.5
x = torch.linspace(-ksize_half, ksize_half, steps=kernel_size)
pdf = torch.exp(-0.5 * (x / sigma).pow(2))
x_kernel = pdf / pdf.sum()
x_kernel = x_kernel.to(device=img.device, dtype=img.dtype)
kernel2d = torch.mm(x_kernel[:, None], x_kernel[None, :])
kernel2d = kernel2d.expand(img.shape[-3], 1, kernel2d.shape[0], kernel2d.shape[1])
padding = [kernel_size // 2, kernel_size // 2, kernel_size // 2, kernel_size // 2]
img = F.pad(img, padding, mode="reflect")
img = F.conv2d(img, kernel2d, groups=img.shape[-3])
return img
class SelfAttentionGuidance:
@classmethod
def INPUT_TYPES(s):
return {"required": { "model": ("MODEL",),
"scale": ("FLOAT", {"default": 0.5, "min": -2.0, "max": 5.0, "step": 0.1}),
"blur_sigma": ("FLOAT", {"default": 2.0, "min": 0.0, "max": 10.0, "step": 0.1}),
}}
RETURN_TYPES = ("MODEL",)
FUNCTION = "patch"
CATEGORY = "_for_testing"
def patch(self, model, scale, blur_sigma):
m = model.clone()
attn_scores = None
# TODO: make this work properly with chunked batches
# currently, we can only save the attn from one UNet call
def attn_and_record(q, k, v, extra_options):
nonlocal attn_scores
# if uncond, save the attention scores
heads = extra_options["n_heads"]
cond_or_uncond = extra_options["cond_or_uncond"]
b = q.shape[0] // len(cond_or_uncond)
if 1 in cond_or_uncond:
uncond_index = cond_or_uncond.index(1)
# do the entire attention operation, but save the attention scores to attn_scores
(out, sim) = attention_basic_with_sim(q, k, v, heads=heads)
# when using a higher batch size, I BELIEVE the result batch dimension is [uc1, ... ucn, c1, ... cn]
n_slices = heads * b
attn_scores = sim[n_slices * uncond_index:n_slices * (uncond_index+1)]
return out
else:
return optimized_attention(q, k, v, heads=heads)
def post_cfg_function(args):
nonlocal attn_scores
uncond_attn = attn_scores
sag_scale = scale
sag_sigma = blur_sigma
sag_threshold = 1.0
model = args["model"]
uncond_pred = args["uncond_denoised"]
uncond = args["uncond"]
cfg_result = args["denoised"]
sigma = args["sigma"]
model_options = args["model_options"]
x = args["input"]
if min(cfg_result.shape[2:]) <= 4: #skip when too small to add padding
return cfg_result
# create the adversarially blurred image
degraded = create_blur_map(uncond_pred, uncond_attn, sag_sigma, sag_threshold)
degraded_noised = degraded + x - uncond_pred
# call into the UNet
(sag, _) = comfy.samplers.calc_cond_uncond_batch(model, uncond, None, degraded_noised, sigma, model_options)
return cfg_result + (degraded - sag) * sag_scale
m.set_model_sampler_post_cfg_function(post_cfg_function, disable_cfg1_optimization=True)
# from diffusers:
# unet.mid_block.attentions[0].transformer_blocks[0].attn1.patch
m.set_model_attn1_replace(attn_and_record, "middle", 0, 0)
return (m, )
NODE_CLASS_MAPPINGS = {
"SelfAttentionGuidance": SelfAttentionGuidance,
}
NODE_DISPLAY_NAME_MAPPINGS = {
"SelfAttentionGuidance": "Self-Attention Guidance",
}

47
comfy_extras/nodes_sdupscale.py

@ -0,0 +1,47 @@
import torch
import nodes
import comfy.utils
class SD_4XUpscale_Conditioning:
@classmethod
def INPUT_TYPES(s):
return {"required": { "images": ("IMAGE",),
"positive": ("CONDITIONING",),
"negative": ("CONDITIONING",),
"scale_ratio": ("FLOAT", {"default": 4.0, "min": 0.0, "max": 10.0, "step": 0.01}),
"noise_augmentation": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.001}),
}}
RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT")
RETURN_NAMES = ("positive", "negative", "latent")
FUNCTION = "encode"
CATEGORY = "conditioning/upscale_diffusion"
def encode(self, images, positive, negative, scale_ratio, noise_augmentation):
width = max(1, round(images.shape[-2] * scale_ratio))
height = max(1, round(images.shape[-3] * scale_ratio))
pixels = comfy.utils.common_upscale((images.movedim(-1,1) * 2.0) - 1.0, width // 4, height // 4, "bilinear", "center")
out_cp = []
out_cn = []
for t in positive:
n = [t[0], t[1].copy()]
n[1]['concat_image'] = pixels
n[1]['noise_augmentation'] = noise_augmentation
out_cp.append(n)
for t in negative:
n = [t[0], t[1].copy()]
n[1]['concat_image'] = pixels
n[1]['noise_augmentation'] = noise_augmentation
out_cn.append(n)
latent = torch.zeros([images.shape[0], 4, height // 4, width // 4])
return (out_cp, out_cn, {"samples":latent})
NODE_CLASS_MAPPINGS = {
"SD_4XUpscale_Conditioning": SD_4XUpscale_Conditioning,
}

143
comfy_extras/nodes_stable3d.py

@ -0,0 +1,143 @@
import torch
import nodes
import comfy.utils
def camera_embeddings(elevation, azimuth):
elevation = torch.as_tensor([elevation])
azimuth = torch.as_tensor([azimuth])
embeddings = torch.stack(
[
torch.deg2rad(
(90 - elevation) - (90)
), # Zero123 polar is 90-elevation
torch.sin(torch.deg2rad(azimuth)),
torch.cos(torch.deg2rad(azimuth)),
torch.deg2rad(
90 - torch.full_like(elevation, 0)
),
], dim=-1).unsqueeze(1)
return embeddings
class StableZero123_Conditioning:
@classmethod
def INPUT_TYPES(s):
return {"required": { "clip_vision": ("CLIP_VISION",),
"init_image": ("IMAGE",),
"vae": ("VAE",),
"width": ("INT", {"default": 256, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 8}),
"height": ("INT", {"default": 256, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 8}),
"batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}),
"elevation": ("FLOAT", {"default": 0.0, "min": -180.0, "max": 180.0, "step": 0.1, "round": False}),
"azimuth": ("FLOAT", {"default": 0.0, "min": -180.0, "max": 180.0, "step": 0.1, "round": False}),
}}
RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT")
RETURN_NAMES = ("positive", "negative", "latent")
FUNCTION = "encode"
CATEGORY = "conditioning/3d_models"
def encode(self, clip_vision, init_image, vae, width, height, batch_size, elevation, azimuth):
output = clip_vision.encode_image(init_image)
pooled = output.image_embeds.unsqueeze(0)
pixels = comfy.utils.common_upscale(init_image.movedim(-1,1), width, height, "bilinear", "center").movedim(1,-1)
encode_pixels = pixels[:,:,:,:3]
t = vae.encode(encode_pixels)
cam_embeds = camera_embeddings(elevation, azimuth)
cond = torch.cat([pooled, cam_embeds.to(pooled.device).repeat((pooled.shape[0], 1, 1))], dim=-1)
positive = [[cond, {"concat_latent_image": t}]]
negative = [[torch.zeros_like(pooled), {"concat_latent_image": torch.zeros_like(t)}]]
latent = torch.zeros([batch_size, 4, height // 8, width // 8])
return (positive, negative, {"samples":latent})
class StableZero123_Conditioning_Batched:
@classmethod
def INPUT_TYPES(s):
return {"required": { "clip_vision": ("CLIP_VISION",),
"init_image": ("IMAGE",),
"vae": ("VAE",),
"width": ("INT", {"default": 256, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 8}),
"height": ("INT", {"default": 256, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 8}),
"batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}),
"elevation": ("FLOAT", {"default": 0.0, "min": -180.0, "max": 180.0, "step": 0.1, "round": False}),
"azimuth": ("FLOAT", {"default": 0.0, "min": -180.0, "max": 180.0, "step": 0.1, "round": False}),
"elevation_batch_increment": ("FLOAT", {"default": 0.0, "min": -180.0, "max": 180.0, "step": 0.1, "round": False}),
"azimuth_batch_increment": ("FLOAT", {"default": 0.0, "min": -180.0, "max": 180.0, "step": 0.1, "round": False}),
}}
RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT")
RETURN_NAMES = ("positive", "negative", "latent")
FUNCTION = "encode"
CATEGORY = "conditioning/3d_models"
def encode(self, clip_vision, init_image, vae, width, height, batch_size, elevation, azimuth, elevation_batch_increment, azimuth_batch_increment):
output = clip_vision.encode_image(init_image)
pooled = output.image_embeds.unsqueeze(0)
pixels = comfy.utils.common_upscale(init_image.movedim(-1,1), width, height, "bilinear", "center").movedim(1,-1)
encode_pixels = pixels[:,:,:,:3]
t = vae.encode(encode_pixels)
cam_embeds = []
for i in range(batch_size):
cam_embeds.append(camera_embeddings(elevation, azimuth))
elevation += elevation_batch_increment
azimuth += azimuth_batch_increment
cam_embeds = torch.cat(cam_embeds, dim=0)
cond = torch.cat([comfy.utils.repeat_to_batch_size(pooled, batch_size), cam_embeds], dim=-1)
positive = [[cond, {"concat_latent_image": t}]]
negative = [[torch.zeros_like(pooled), {"concat_latent_image": torch.zeros_like(t)}]]
latent = torch.zeros([batch_size, 4, height // 8, width // 8])
return (positive, negative, {"samples":latent, "batch_index": [0] * batch_size})
class SV3D_Conditioning:
@classmethod
def INPUT_TYPES(s):
return {"required": { "clip_vision": ("CLIP_VISION",),
"init_image": ("IMAGE",),
"vae": ("VAE",),
"width": ("INT", {"default": 576, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 8}),
"height": ("INT", {"default": 576, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 8}),
"video_frames": ("INT", {"default": 21, "min": 1, "max": 4096}),
"elevation": ("FLOAT", {"default": 0.0, "min": -90.0, "max": 90.0, "step": 0.1, "round": False}),
}}
RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT")
RETURN_NAMES = ("positive", "negative", "latent")
FUNCTION = "encode"
CATEGORY = "conditioning/3d_models"
def encode(self, clip_vision, init_image, vae, width, height, video_frames, elevation):
output = clip_vision.encode_image(init_image)
pooled = output.image_embeds.unsqueeze(0)
pixels = comfy.utils.common_upscale(init_image.movedim(-1,1), width, height, "bilinear", "center").movedim(1,-1)
encode_pixels = pixels[:,:,:,:3]
t = vae.encode(encode_pixels)
azimuth = 0
azimuth_increment = 360 / (max(video_frames, 2) - 1)
elevations = []
azimuths = []
for i in range(video_frames):
elevations.append(elevation)
azimuths.append(azimuth)
azimuth += azimuth_increment
positive = [[pooled, {"concat_latent_image": t, "elevation": elevations, "azimuth": azimuths}]]
negative = [[torch.zeros_like(pooled), {"concat_latent_image": torch.zeros_like(t), "elevation": elevations, "azimuth": azimuths}]]
latent = torch.zeros([video_frames, 4, height // 8, width // 8])
return (positive, negative, {"samples":latent})
NODE_CLASS_MAPPINGS = {
"StableZero123_Conditioning": StableZero123_Conditioning,
"StableZero123_Conditioning_Batched": StableZero123_Conditioning_Batched,
"SV3D_Conditioning": SV3D_Conditioning,
}

140
comfy_extras/nodes_stable_cascade.py

@ -0,0 +1,140 @@
"""
This file is part of ComfyUI.
Copyright (C) 2024 Stability AI
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU General Public License for more details.
You should have received a copy of the GNU General Public License
along with this program. If not, see <https://www.gnu.org/licenses/>.
"""
import torch
import nodes
import comfy.utils
class StableCascade_EmptyLatentImage:
def __init__(self, device="cpu"):
self.device = device
@classmethod
def INPUT_TYPES(s):
return {"required": {
"width": ("INT", {"default": 1024, "min": 256, "max": nodes.MAX_RESOLUTION, "step": 8}),
"height": ("INT", {"default": 1024, "min": 256, "max": nodes.MAX_RESOLUTION, "step": 8}),
"compression": ("INT", {"default": 42, "min": 4, "max": 128, "step": 1}),
"batch_size": ("INT", {"default": 1, "min": 1, "max": 4096})
}}
RETURN_TYPES = ("LATENT", "LATENT")
RETURN_NAMES = ("stage_c", "stage_b")
FUNCTION = "generate"
CATEGORY = "latent/stable_cascade"
def generate(self, width, height, compression, batch_size=1):
c_latent = torch.zeros([batch_size, 16, height // compression, width // compression])
b_latent = torch.zeros([batch_size, 4, height // 4, width // 4])
return ({
"samples": c_latent,
}, {
"samples": b_latent,
})
class StableCascade_StageC_VAEEncode:
def __init__(self, device="cpu"):
self.device = device
@classmethod
def INPUT_TYPES(s):
return {"required": {
"image": ("IMAGE",),
"vae": ("VAE", ),
"compression": ("INT", {"default": 42, "min": 4, "max": 128, "step": 1}),
}}
RETURN_TYPES = ("LATENT", "LATENT")
RETURN_NAMES = ("stage_c", "stage_b")
FUNCTION = "generate"
CATEGORY = "latent/stable_cascade"
def generate(self, image, vae, compression):
width = image.shape[-2]
height = image.shape[-3]
out_width = (width // compression) * vae.downscale_ratio
out_height = (height // compression) * vae.downscale_ratio
s = comfy.utils.common_upscale(image.movedim(-1,1), out_width, out_height, "bicubic", "center").movedim(1,-1)
c_latent = vae.encode(s[:,:,:,:3])
b_latent = torch.zeros([c_latent.shape[0], 4, (height // 8) * 2, (width // 8) * 2])
return ({
"samples": c_latent,
}, {
"samples": b_latent,
})
class StableCascade_StageB_Conditioning:
@classmethod
def INPUT_TYPES(s):
return {"required": { "conditioning": ("CONDITIONING",),
"stage_c": ("LATENT",),
}}
RETURN_TYPES = ("CONDITIONING",)
FUNCTION = "set_prior"
CATEGORY = "conditioning/stable_cascade"
def set_prior(self, conditioning, stage_c):
c = []
for t in conditioning:
d = t[1].copy()
d['stable_cascade_prior'] = stage_c['samples']
n = [t[0], d]
c.append(n)
return (c, )
class StableCascade_SuperResolutionControlnet:
def __init__(self, device="cpu"):
self.device = device
@classmethod
def INPUT_TYPES(s):
return {"required": {
"image": ("IMAGE",),
"vae": ("VAE", ),
}}
RETURN_TYPES = ("IMAGE", "LATENT", "LATENT")
RETURN_NAMES = ("controlnet_input", "stage_c", "stage_b")
FUNCTION = "generate"
CATEGORY = "_for_testing/stable_cascade"
def generate(self, image, vae):
width = image.shape[-2]
height = image.shape[-3]
batch_size = image.shape[0]
controlnet_input = vae.encode(image[:,:,:,:3]).movedim(1, -1)
c_latent = torch.zeros([batch_size, 16, height // 16, width // 16])
b_latent = torch.zeros([batch_size, 4, height // 2, width // 2])
return (controlnet_input, {
"samples": c_latent,
}, {
"samples": b_latent,
})
NODE_CLASS_MAPPINGS = {
"StableCascade_EmptyLatentImage": StableCascade_EmptyLatentImage,
"StableCascade_StageB_Conditioning": StableCascade_StageB_Conditioning,
"StableCascade_StageC_VAEEncode": StableCascade_StageC_VAEEncode,
"StableCascade_SuperResolutionControlnet": StableCascade_SuperResolutionControlnet,
}

45
comfy_extras/nodes_video_model.py

@ -3,6 +3,7 @@ import torch
import comfy.utils import comfy.utils
import comfy.sd import comfy.sd
import folder_paths import folder_paths
import comfy_extras.nodes_model_merging
class ImageOnlyCheckpointLoader: class ImageOnlyCheckpointLoader:
@ -78,10 +79,54 @@ class VideoLinearCFGGuidance:
m.set_model_sampler_cfg_function(linear_cfg) m.set_model_sampler_cfg_function(linear_cfg)
return (m, ) return (m, )
class VideoTriangleCFGGuidance:
@classmethod
def INPUT_TYPES(s):
return {"required": { "model": ("MODEL",),
"min_cfg": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0, "step":0.5, "round": 0.01}),
}}
RETURN_TYPES = ("MODEL",)
FUNCTION = "patch"
CATEGORY = "sampling/video_models"
def patch(self, model, min_cfg):
def linear_cfg(args):
cond = args["cond"]
uncond = args["uncond"]
cond_scale = args["cond_scale"]
period = 1.0
values = torch.linspace(0, 1, cond.shape[0], device=cond.device)
values = 2 * (values / period - torch.floor(values / period + 0.5)).abs()
scale = (values * (cond_scale - min_cfg) + min_cfg).reshape((cond.shape[0], 1, 1, 1))
return uncond + scale * (cond - uncond)
m = model.clone()
m.set_model_sampler_cfg_function(linear_cfg)
return (m, )
class ImageOnlyCheckpointSave(comfy_extras.nodes_model_merging.CheckpointSave):
CATEGORY = "_for_testing"
@classmethod
def INPUT_TYPES(s):
return {"required": { "model": ("MODEL",),
"clip_vision": ("CLIP_VISION",),
"vae": ("VAE",),
"filename_prefix": ("STRING", {"default": "checkpoints/ComfyUI"}),},
"hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"},}
def save(self, model, clip_vision, vae, filename_prefix, prompt=None, extra_pnginfo=None):
comfy_extras.nodes_model_merging.save_checkpoint(model, clip_vision=clip_vision, vae=vae, filename_prefix=filename_prefix, output_dir=self.output_dir, prompt=prompt, extra_pnginfo=extra_pnginfo)
return {}
NODE_CLASS_MAPPINGS = { NODE_CLASS_MAPPINGS = {
"ImageOnlyCheckpointLoader": ImageOnlyCheckpointLoader, "ImageOnlyCheckpointLoader": ImageOnlyCheckpointLoader,
"SVD_img2vid_Conditioning": SVD_img2vid_Conditioning, "SVD_img2vid_Conditioning": SVD_img2vid_Conditioning,
"VideoLinearCFGGuidance": VideoLinearCFGGuidance, "VideoLinearCFGGuidance": VideoLinearCFGGuidance,
"VideoTriangleCFGGuidance": VideoTriangleCFGGuidance,
"ImageOnlyCheckpointSave": ImageOnlyCheckpointSave,
} }
NODE_DISPLAY_NAME_MAPPINGS = { NODE_DISPLAY_NAME_MAPPINGS = {

10
cuda_malloc.py

@ -1,6 +1,7 @@
import os import os
import importlib.util import importlib.util
from comfy.cli_args import args from comfy.cli_args import args
import subprocess
#Can't use pytorch to get the GPU names because the cuda malloc has to be set before the first import. #Can't use pytorch to get the GPU names because the cuda malloc has to be set before the first import.
def get_gpu_names(): def get_gpu_names():
@ -34,14 +35,19 @@ def get_gpu_names():
return gpu_names return gpu_names
return enum_display_devices() return enum_display_devices()
else: else:
return set() gpu_names = set()
out = subprocess.check_output(['nvidia-smi', '-L'])
for l in out.split(b'\n'):
if len(l) > 0:
gpu_names.add(l.decode('utf-8').split(' (UUID')[0])
return gpu_names
blacklist = {"GeForce GTX TITAN X", "GeForce GTX 980", "GeForce GTX 970", "GeForce GTX 960", "GeForce GTX 950", "GeForce 945M", blacklist = {"GeForce GTX TITAN X", "GeForce GTX 980", "GeForce GTX 970", "GeForce GTX 960", "GeForce GTX 950", "GeForce 945M",
"GeForce 940M", "GeForce 930M", "GeForce 920M", "GeForce 910M", "GeForce GTX 750", "GeForce GTX 745", "Quadro K620", "GeForce 940M", "GeForce 930M", "GeForce 920M", "GeForce 910M", "GeForce GTX 750", "GeForce GTX 745", "Quadro K620",
"Quadro K1200", "Quadro K2200", "Quadro M500", "Quadro M520", "Quadro M600", "Quadro M620", "Quadro M1000", "Quadro K1200", "Quadro K2200", "Quadro M500", "Quadro M520", "Quadro M600", "Quadro M620", "Quadro M1000",
"Quadro M1200", "Quadro M2000", "Quadro M2200", "Quadro M3000", "Quadro M4000", "Quadro M5000", "Quadro M5500", "Quadro M6000", "Quadro M1200", "Quadro M2000", "Quadro M2200", "Quadro M3000", "Quadro M4000", "Quadro M5000", "Quadro M5500", "Quadro M6000",
"GeForce MX110", "GeForce MX130", "GeForce 830M", "GeForce 840M", "GeForce GTX 850M", "GeForce GTX 860M", "GeForce MX110", "GeForce MX130", "GeForce 830M", "GeForce 840M", "GeForce GTX 850M", "GeForce GTX 860M",
"GeForce GTX 1650", "GeForce GTX 1630" "GeForce GTX 1650", "GeForce GTX 1630", "Tesla M4", "Tesla M6", "Tesla M10", "Tesla M40", "Tesla M60"
} }
def cuda_malloc_supported(): def cuda_malloc_supported():

16
custom_nodes/example_node.py.example

@ -6,6 +6,8 @@ class Example:
------------- -------------
INPUT_TYPES (dict): INPUT_TYPES (dict):
Tell the main program input parameters of nodes. Tell the main program input parameters of nodes.
IS_CHANGED:
optional method to control when the node is re executed.
Attributes Attributes
---------- ----------
@ -89,6 +91,20 @@ class Example:
image = 1.0 - image image = 1.0 - image
return (image,) return (image,)
"""
The node will always be re executed if any of the inputs change but
this method can be used to force the node to execute again even when the inputs don't change.
You can make this node return a number or a string. This value will be compared to the one returned the last time the node was
executed, if it is different the node will be executed again.
This method is used in the core repo for the LoadImage node where they return the image hash as a string, if the image hash
changes between executions the LoadImage node is executed again.
"""
#@classmethod
#def IS_CHANGED(s, image, string_field, int_field, float_field, print_to_screen):
# return ""
# Set the web directory, any .js file in that directory will be loaded by the frontend as a frontend extension
# WEB_DIRECTORY = "./somejs"
# A dictionary that contains all nodes you want to export with their names # A dictionary that contains all nodes you want to export with their names
# NOTE: names should be globally unique # NOTE: names should be globally unique

45
custom_nodes/websocket_image_save.py

@ -0,0 +1,45 @@
from PIL import Image, ImageOps
from io import BytesIO
import numpy as np
import struct
import comfy.utils
import time
#You can use this node to save full size images through the websocket, the
#images will be sent in exactly the same format as the image previews: as
#binary images on the websocket with a 8 byte header indicating the type
#of binary message (first 4 bytes) and the image format (next 4 bytes).
#Note that no metadata will be put in the images saved with this node.
class SaveImageWebsocket:
@classmethod
def INPUT_TYPES(s):
return {"required":
{"images": ("IMAGE", ),}
}
RETURN_TYPES = ()
FUNCTION = "save_images"
OUTPUT_NODE = True
CATEGORY = "api/image"
def save_images(self, images):
pbar = comfy.utils.ProgressBar(images.shape[0])
step = 0
for image in images:
i = 255. * image.cpu().numpy()
img = Image.fromarray(np.clip(i, 0, 255).astype(np.uint8))
pbar.update_absolute(step, images.shape[0], ("PNG", img, None))
step += 1
return {}
def IS_CHANGED(s, images):
return time.time()
NODE_CLASS_MAPPINGS = {
"SaveImageWebsocket": SaveImageWebsocket,
}

172
execution.py

@ -1,12 +1,11 @@
import os
import sys import sys
import copy import copy
import json
import logging import logging
import threading import threading
import heapq import heapq
import traceback import traceback
import gc import inspect
from typing import List, Literal, NamedTuple, Optional
import torch import torch
import nodes import nodes
@ -36,8 +35,7 @@ def get_input_data(inputs, class_def, unique_id, outputs={}, prompt={}, extra_da
if h[x] == "PROMPT": if h[x] == "PROMPT":
input_data_all[x] = [prompt] input_data_all[x] = [prompt]
if h[x] == "EXTRA_PNGINFO": if h[x] == "EXTRA_PNGINFO":
if "extra_pnginfo" in extra_data: input_data_all[x] = [extra_data.get('extra_pnginfo', None)]
input_data_all[x] = [extra_data['extra_pnginfo']]
if h[x] == "UNIQUE_ID": if h[x] == "UNIQUE_ID":
input_data_all[x] = [unique_id] input_data_all[x] = [unique_id]
return input_data_all return input_data_all
@ -195,8 +193,12 @@ def recursive_execute(server, prompt, outputs, current_item, extra_data, execute
return (True, None, None) return (True, None, None)
def recursive_will_execute(prompt, outputs, current_item): def recursive_will_execute(prompt, outputs, current_item, memo={}):
unique_id = current_item unique_id = current_item
if unique_id in memo:
return memo[unique_id]
inputs = prompt[unique_id]['inputs'] inputs = prompt[unique_id]['inputs']
will_execute = [] will_execute = []
if unique_id in outputs: if unique_id in outputs:
@ -208,9 +210,10 @@ def recursive_will_execute(prompt, outputs, current_item):
input_unique_id = input_data[0] input_unique_id = input_data[0]
output_index = input_data[1] output_index = input_data[1]
if input_unique_id not in outputs: if input_unique_id not in outputs:
will_execute += recursive_will_execute(prompt, outputs, input_unique_id) will_execute += recursive_will_execute(prompt, outputs, input_unique_id, memo)
return will_execute + [unique_id] memo[unique_id] = will_execute + [unique_id]
return memo[unique_id]
def recursive_output_delete_if_changed(prompt, old_prompt, outputs, current_item): def recursive_output_delete_if_changed(prompt, old_prompt, outputs, current_item):
unique_id = current_item unique_id = current_item
@ -267,11 +270,21 @@ def recursive_output_delete_if_changed(prompt, old_prompt, outputs, current_item
class PromptExecutor: class PromptExecutor:
def __init__(self, server): def __init__(self, server):
self.server = server
self.reset()
def reset(self):
self.outputs = {} self.outputs = {}
self.object_storage = {} self.object_storage = {}
self.outputs_ui = {} self.outputs_ui = {}
self.status_messages = []
self.success = True
self.old_prompt = {} self.old_prompt = {}
self.server = server
def add_message(self, event, data, broadcast: bool):
self.status_messages.append((event, data))
if self.server.client_id is not None or broadcast:
self.server.send_sync(event, data, self.server.client_id)
def handle_execution_error(self, prompt_id, prompt, current_outputs, executed, error, ex): def handle_execution_error(self, prompt_id, prompt, current_outputs, executed, error, ex):
node_id = error["node_id"] node_id = error["node_id"]
@ -286,22 +299,21 @@ class PromptExecutor:
"node_type": class_type, "node_type": class_type,
"executed": list(executed), "executed": list(executed),
} }
self.server.send_sync("execution_interrupted", mes, self.server.client_id) self.add_message("execution_interrupted", mes, broadcast=True)
else: else:
if self.server.client_id is not None: mes = {
mes = { "prompt_id": prompt_id,
"prompt_id": prompt_id, "node_id": node_id,
"node_id": node_id, "node_type": class_type,
"node_type": class_type, "executed": list(executed),
"executed": list(executed),
"exception_message": error["exception_message"],
"exception_message": error["exception_message"], "exception_type": error["exception_type"],
"exception_type": error["exception_type"], "traceback": error["traceback"],
"traceback": error["traceback"], "current_inputs": error["current_inputs"],
"current_inputs": error["current_inputs"], "current_outputs": error["current_outputs"],
"current_outputs": error["current_outputs"], }
} self.add_message("execution_error", mes, broadcast=False)
self.server.send_sync("execution_error", mes, self.server.client_id)
# Next, remove the subsequent outputs since they will not be executed # Next, remove the subsequent outputs since they will not be executed
to_delete = [] to_delete = []
@ -323,8 +335,8 @@ class PromptExecutor:
else: else:
self.server.client_id = None self.server.client_id = None
if self.server.client_id is not None: self.status_messages = []
self.server.send_sync("execution_start", { "prompt_id": prompt_id}, self.server.client_id) self.add_message("execution_start", { "prompt_id": prompt_id}, broadcast=False)
with torch.inference_mode(): with torch.inference_mode():
#delete cached outputs if nodes don't exist for them #delete cached outputs if nodes don't exist for them
@ -356,9 +368,9 @@ class PromptExecutor:
d = self.outputs_ui.pop(x) d = self.outputs_ui.pop(x)
del d del d
comfy.model_management.cleanup_models() self.add_message("execution_cached",
if self.server.client_id is not None: { "nodes": list(current_outputs) , "prompt_id": prompt_id},
self.server.send_sync("execution_cached", { "nodes": list(current_outputs) , "prompt_id": prompt_id}, self.server.client_id) broadcast=False)
executed = set() executed = set()
output_node_id = None output_node_id = None
to_execute = [] to_execute = []
@ -368,20 +380,23 @@ class PromptExecutor:
while len(to_execute) > 0: while len(to_execute) > 0:
#always execute the output that depends on the least amount of unexecuted nodes first #always execute the output that depends on the least amount of unexecuted nodes first
to_execute = sorted(list(map(lambda a: (len(recursive_will_execute(prompt, self.outputs, a[-1])), a[-1]), to_execute))) memo = {}
to_execute = sorted(list(map(lambda a: (len(recursive_will_execute(prompt, self.outputs, a[-1], memo)), a[-1]), to_execute)))
output_node_id = to_execute.pop(0)[-1] output_node_id = to_execute.pop(0)[-1]
# This call shouldn't raise anything if there's an error deep in # This call shouldn't raise anything if there's an error deep in
# the actual SD code, instead it will report the node where the # the actual SD code, instead it will report the node where the
# error was raised # error was raised
success, error, ex = recursive_execute(self.server, prompt, self.outputs, output_node_id, extra_data, executed, prompt_id, self.outputs_ui, self.object_storage) self.success, error, ex = recursive_execute(self.server, prompt, self.outputs, output_node_id, extra_data, executed, prompt_id, self.outputs_ui, self.object_storage)
if success is not True: if self.success is not True:
self.handle_execution_error(prompt_id, prompt, current_outputs, executed, error, ex) self.handle_execution_error(prompt_id, prompt, current_outputs, executed, error, ex)
break break
for x in executed: for x in executed:
self.old_prompt[x] = copy.deepcopy(prompt[x]) self.old_prompt[x] = copy.deepcopy(prompt[x])
self.server.last_node_id = None self.server.last_node_id = None
if comfy.model_management.DISABLE_SMART_MEMORY:
comfy.model_management.unload_all_models()
@ -400,6 +415,10 @@ def validate_inputs(prompt, item, validated):
errors = [] errors = []
valid = True valid = True
validate_function_inputs = []
if hasattr(obj_class, "VALIDATE_INPUTS"):
validate_function_inputs = inspect.getfullargspec(obj_class.VALIDATE_INPUTS).args
for x in required_inputs: for x in required_inputs:
if x not in inputs: if x not in inputs:
error = { error = {
@ -529,29 +548,7 @@ def validate_inputs(prompt, item, validated):
errors.append(error) errors.append(error)
continue continue
if hasattr(obj_class, "VALIDATE_INPUTS"): if x not in validate_function_inputs:
input_data_all = get_input_data(inputs, obj_class, unique_id)
#ret = obj_class.VALIDATE_INPUTS(**input_data_all)
ret = map_node_over_list(obj_class, input_data_all, "VALIDATE_INPUTS")
for i, r in enumerate(ret):
if r is not True:
details = f"{x}"
if r is not False:
details += f" - {str(r)}"
error = {
"type": "custom_validation_failed",
"message": "Custom validation failed for node",
"details": details,
"extra_info": {
"input_name": x,
"input_config": info,
"received_value": val,
}
}
errors.append(error)
continue
else:
if isinstance(type_input, list): if isinstance(type_input, list):
if val not in type_input: if val not in type_input:
input_config = info input_config = info
@ -578,6 +575,35 @@ def validate_inputs(prompt, item, validated):
errors.append(error) errors.append(error)
continue continue
if len(validate_function_inputs) > 0:
input_data_all = get_input_data(inputs, obj_class, unique_id)
input_filtered = {}
for x in input_data_all:
if x in validate_function_inputs:
input_filtered[x] = input_data_all[x]
#ret = obj_class.VALIDATE_INPUTS(**input_filtered)
ret = map_node_over_list(obj_class, input_filtered, "VALIDATE_INPUTS")
for x in input_filtered:
for i, r in enumerate(ret):
if r is not True:
details = f"{x}"
if r is not False:
details += f" - {str(r)}"
error = {
"type": "custom_validation_failed",
"message": "Custom validation failed for node",
"details": details,
"extra_info": {
"input_name": x,
"input_config": info,
"received_value": val,
}
}
errors.append(error)
continue
if len(errors) > 0 or valid is not True: if len(errors) > 0 or valid is not True:
ret = (False, errors, unique_id) ret = (False, errors, unique_id)
else: else:
@ -692,6 +718,7 @@ class PromptQueue:
self.queue = [] self.queue = []
self.currently_running = {} self.currently_running = {}
self.history = {} self.history = {}
self.flags = {}
server.prompt_queue = self server.prompt_queue = self
def put(self, item): def put(self, item):
@ -713,14 +740,27 @@ class PromptQueue:
self.server.queue_updated() self.server.queue_updated()
return (item, i) return (item, i)
def task_done(self, item_id, outputs): class ExecutionStatus(NamedTuple):
status_str: Literal['success', 'error']
completed: bool
messages: List[str]
def task_done(self, item_id, outputs,
status: Optional['PromptQueue.ExecutionStatus']):
with self.mutex: with self.mutex:
prompt = self.currently_running.pop(item_id) prompt = self.currently_running.pop(item_id)
if len(self.history) > MAXIMUM_HISTORY_SIZE: if len(self.history) > MAXIMUM_HISTORY_SIZE:
self.history.pop(next(iter(self.history))) self.history.pop(next(iter(self.history)))
self.history[prompt[1]] = { "prompt": prompt, "outputs": {} }
for o in outputs: status_dict: Optional[dict] = None
self.history[prompt[1]]["outputs"][o] = outputs[o] if status is not None:
status_dict = copy.deepcopy(status._asdict())
self.history[prompt[1]] = {
"prompt": prompt,
"outputs": copy.deepcopy(outputs),
'status': status_dict,
}
self.server.queue_updated() self.server.queue_updated()
def get_current_queue(self): def get_current_queue(self):
@ -778,3 +818,17 @@ class PromptQueue:
def delete_history_item(self, id_to_delete): def delete_history_item(self, id_to_delete):
with self.mutex: with self.mutex:
self.history.pop(id_to_delete, None) self.history.pop(id_to_delete, None)
def set_flag(self, name, data):
with self.mutex:
self.flags[name] = data
self.not_empty.notify()
def get_flags(self, reset=True):
with self.mutex:
if reset:
ret = self.flags
self.flags = {}
return ret
else:
return self.flags.copy()

22
folder_paths.py

@ -29,11 +29,14 @@ folder_names_and_paths["custom_nodes"] = ([os.path.join(base_path, "custom_nodes
folder_names_and_paths["hypernetworks"] = ([os.path.join(models_dir, "hypernetworks")], supported_pt_extensions) folder_names_and_paths["hypernetworks"] = ([os.path.join(models_dir, "hypernetworks")], supported_pt_extensions)
folder_names_and_paths["photomaker"] = ([os.path.join(models_dir, "photomaker")], supported_pt_extensions)
folder_names_and_paths["classifiers"] = ([os.path.join(models_dir, "classifiers")], {""}) folder_names_and_paths["classifiers"] = ([os.path.join(models_dir, "classifiers")], {""})
output_directory = os.path.join(os.path.dirname(os.path.realpath(__file__)), "output") output_directory = os.path.join(os.path.dirname(os.path.realpath(__file__)), "output")
temp_directory = os.path.join(os.path.dirname(os.path.realpath(__file__)), "temp") temp_directory = os.path.join(os.path.dirname(os.path.realpath(__file__)), "temp")
input_directory = os.path.join(os.path.dirname(os.path.realpath(__file__)), "input") input_directory = os.path.join(os.path.dirname(os.path.realpath(__file__)), "input")
user_directory = os.path.join(os.path.dirname(os.path.realpath(__file__)), "user")
filename_list_cache = {} filename_list_cache = {}
@ -137,15 +140,27 @@ def recursive_search(directory, excluded_dir_names=None):
excluded_dir_names = [] excluded_dir_names = []
result = [] result = []
dirs = {directory: os.path.getmtime(directory)} dirs = {}
# Attempt to add the initial directory to dirs with error handling
try:
dirs[directory] = os.path.getmtime(directory)
except FileNotFoundError:
print(f"Warning: Unable to access {directory}. Skipping this path.")
for dirpath, subdirs, filenames in os.walk(directory, followlinks=True, topdown=True): for dirpath, subdirs, filenames in os.walk(directory, followlinks=True, topdown=True):
subdirs[:] = [d for d in subdirs if d not in excluded_dir_names] subdirs[:] = [d for d in subdirs if d not in excluded_dir_names]
for file_name in filenames: for file_name in filenames:
relative_path = os.path.relpath(os.path.join(dirpath, file_name), directory) relative_path = os.path.relpath(os.path.join(dirpath, file_name), directory)
result.append(relative_path) result.append(relative_path)
for d in subdirs: for d in subdirs:
path = os.path.join(dirpath, d) path = os.path.join(dirpath, d)
dirs[path] = os.path.getmtime(path) try:
dirs[path] = os.path.getmtime(path)
except FileNotFoundError:
print(f"Warning: Unable to access {path}. Skipping this path.")
continue
return result, dirs return result, dirs
def filter_files_extensions(files, extensions): def filter_files_extensions(files, extensions):
@ -184,8 +199,7 @@ def cached_filename_list_(folder_name):
if folder_name not in filename_list_cache: if folder_name not in filename_list_cache:
return None return None
out = filename_list_cache[folder_name] out = filename_list_cache[folder_name]
if time.perf_counter() < (out[2] + 0.5):
return out
for x in out[1]: for x in out[1]:
time_modified = out[1][x] time_modified = out[1][x]
folder = x folder = x

3
latent_preview.py

@ -6,6 +6,7 @@ from comfy.cli_args import args, LatentPreviewMethod
from comfy.taesd.taesd import TAESD from comfy.taesd.taesd import TAESD
import folder_paths import folder_paths
import comfy.utils import comfy.utils
import logging
MAX_PREVIEW_RESOLUTION = 512 MAX_PREVIEW_RESOLUTION = 512
@ -70,7 +71,7 @@ def get_previewer(device, latent_format):
taesd = TAESD(None, taesd_decoder_path).to(device) taesd = TAESD(None, taesd_decoder_path).to(device)
previewer = TAESDPreviewerImpl(taesd) previewer = TAESDPreviewerImpl(taesd)
else: else:
print("Warning: TAESD previews enabled, but could not find models/vae_approx/{}".format(latent_format.taesd_decoder_name)) logging.warning("Warning: TAESD previews enabled, but could not find models/vae_approx/{}".format(latent_format.taesd_decoder_name))
if previewer is None: if previewer is None:
if latent_format.latent_rgb_factors is not None: if latent_format.latent_rgb_factors is not None:

58
main.py

@ -54,15 +54,19 @@ import threading
import gc import gc
from comfy.cli_args import args from comfy.cli_args import args
import logging
if os.name == "nt": if os.name == "nt":
import logging
logging.getLogger("xformers").addFilter(lambda record: 'A matching Triton is not available' not in record.getMessage()) logging.getLogger("xformers").addFilter(lambda record: 'A matching Triton is not available' not in record.getMessage())
if __name__ == "__main__": if __name__ == "__main__":
if args.cuda_device is not None: if args.cuda_device is not None:
os.environ['CUDA_VISIBLE_DEVICES'] = str(args.cuda_device) os.environ['CUDA_VISIBLE_DEVICES'] = str(args.cuda_device)
print("Set cuda device to:", args.cuda_device) logging.info("Set cuda device to: {}".format(args.cuda_device))
if args.deterministic:
if 'CUBLAS_WORKSPACE_CONFIG' not in os.environ:
os.environ['CUBLAS_WORKSPACE_CONFIG'] = ":4096:8"
import cuda_malloc import cuda_malloc
@ -84,7 +88,7 @@ def cuda_malloc_warning():
if b in device_name: if b in device_name:
cuda_malloc_warning = True cuda_malloc_warning = True
if cuda_malloc_warning: if cuda_malloc_warning:
print("\nWARNING: this card most likely does not support cuda-malloc, if you get \"CUDA error\" please run ComfyUI with: --disable-cuda-malloc\n") logging.warning("\nWARNING: this card most likely does not support cuda-malloc, if you get \"CUDA error\" please run ComfyUI with: --disable-cuda-malloc\n")
def prompt_worker(q, server): def prompt_worker(q, server):
e = execution.PromptExecutor(server) e = execution.PromptExecutor(server)
@ -93,7 +97,7 @@ def prompt_worker(q, server):
gc_collect_interval = 10.0 gc_collect_interval = 10.0
while True: while True:
timeout = None timeout = 1000.0
if need_gc: if need_gc:
timeout = max(gc_collect_interval - (current_time - last_gc_collect), 0.0) timeout = max(gc_collect_interval - (current_time - last_gc_collect), 0.0)
@ -102,19 +106,40 @@ def prompt_worker(q, server):
item, item_id = queue_item item, item_id = queue_item
execution_start_time = time.perf_counter() execution_start_time = time.perf_counter()
prompt_id = item[1] prompt_id = item[1]
server.last_prompt_id = prompt_id
e.execute(item[2], prompt_id, item[3], item[4]) e.execute(item[2], prompt_id, item[3], item[4])
need_gc = True need_gc = True
q.task_done(item_id, e.outputs_ui) q.task_done(item_id,
e.outputs_ui,
status=execution.PromptQueue.ExecutionStatus(
status_str='success' if e.success else 'error',
completed=e.success,
messages=e.status_messages))
if server.client_id is not None: if server.client_id is not None:
server.send_sync("executing", { "node": None, "prompt_id": prompt_id }, server.client_id) server.send_sync("executing", { "node": None, "prompt_id": prompt_id }, server.client_id)
current_time = time.perf_counter() current_time = time.perf_counter()
execution_time = current_time - execution_start_time execution_time = current_time - execution_start_time
print("Prompt executed in {:.2f} seconds".format(execution_time)) logging.info("Prompt executed in {:.2f} seconds".format(execution_time))
flags = q.get_flags()
free_memory = flags.get("free_memory", False)
if flags.get("unload_models", free_memory):
comfy.model_management.unload_all_models()
need_gc = True
last_gc_collect = 0
if free_memory:
e.reset()
need_gc = True
last_gc_collect = 0
if need_gc: if need_gc:
current_time = time.perf_counter() current_time = time.perf_counter()
if (current_time - last_gc_collect) > gc_collect_interval: if (current_time - last_gc_collect) > gc_collect_interval:
comfy.model_management.cleanup_models()
gc.collect() gc.collect()
comfy.model_management.soft_empty_cache() comfy.model_management.soft_empty_cache()
last_gc_collect = current_time last_gc_collect = current_time
@ -127,7 +152,9 @@ async def run(server, address='', port=8188, verbose=True, call_on_start=None):
def hijack_progress(server): def hijack_progress(server):
def hook(value, total, preview_image): def hook(value, total, preview_image):
comfy.model_management.throw_exception_if_processing_interrupted() comfy.model_management.throw_exception_if_processing_interrupted()
server.send_sync("progress", {"value": value, "max": total}, server.client_id) progress = {"value": value, "max": total, "prompt_id": server.last_prompt_id, "node": server.last_node_id}
server.send_sync("progress", progress, server.client_id)
if preview_image is not None: if preview_image is not None:
server.send_sync(BinaryEventTypes.UNENCODED_PREVIEW_IMAGE, preview_image, server.client_id) server.send_sync(BinaryEventTypes.UNENCODED_PREVIEW_IMAGE, preview_image, server.client_id)
comfy.utils.set_progress_bar_global_hook(hook) comfy.utils.set_progress_bar_global_hook(hook)
@ -156,17 +183,24 @@ def load_extra_path_config(yaml_path):
full_path = y full_path = y
if base_path is not None: if base_path is not None:
full_path = os.path.join(base_path, full_path) full_path = os.path.join(base_path, full_path)
print("Adding extra search path", x, full_path) logging.info("Adding extra search path {} {}".format(x, full_path))
folder_paths.add_model_folder_path(x, full_path) folder_paths.add_model_folder_path(x, full_path)
if __name__ == "__main__": if __name__ == "__main__":
if args.temp_directory: if args.temp_directory:
temp_dir = os.path.join(os.path.abspath(args.temp_directory), "temp") temp_dir = os.path.join(os.path.abspath(args.temp_directory), "temp")
print(f"Setting temp directory to: {temp_dir}") logging.info(f"Setting temp directory to: {temp_dir}")
folder_paths.set_temp_directory(temp_dir) folder_paths.set_temp_directory(temp_dir)
cleanup_temp() cleanup_temp()
if args.windows_standalone_build:
try:
import new_updater
new_updater.update_windows_updater()
except:
pass
loop = asyncio.new_event_loop() loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop) asyncio.set_event_loop(loop)
server = server.PromptServer(loop) server = server.PromptServer(loop)
@ -191,7 +225,7 @@ if __name__ == "__main__":
if args.output_directory: if args.output_directory:
output_dir = os.path.abspath(args.output_directory) output_dir = os.path.abspath(args.output_directory)
print(f"Setting output directory to: {output_dir}") logging.info(f"Setting output directory to: {output_dir}")
folder_paths.set_output_directory(output_dir) folder_paths.set_output_directory(output_dir)
#These are the default folders that checkpoints, clip and vae models will be saved to when using CheckpointSave, etc.. nodes #These are the default folders that checkpoints, clip and vae models will be saved to when using CheckpointSave, etc.. nodes
@ -201,7 +235,7 @@ if __name__ == "__main__":
if args.input_directory: if args.input_directory:
input_dir = os.path.abspath(args.input_directory) input_dir = os.path.abspath(args.input_directory)
print(f"Setting input directory to: {input_dir}") logging.info(f"Setting input directory to: {input_dir}")
folder_paths.set_input_directory(input_dir) folder_paths.set_input_directory(input_dir)
if args.quick_test_for_ci: if args.quick_test_for_ci:
@ -219,6 +253,6 @@ if __name__ == "__main__":
try: try:
loop.run_until_complete(run(server, address=args.listen, port=args.port, verbose=not args.dont_print_server, call_on_start=call_on_start)) loop.run_until_complete(run(server, address=args.listen, port=args.port, verbose=not args.dont_print_server, call_on_start=call_on_start))
except KeyboardInterrupt: except KeyboardInterrupt:
print("\nStopped server") logging.info("\nStopped server")
cleanup_temp() cleanup_temp()

0
models/photomaker/put_photomaker_models_here

35
new_updater.py

@ -0,0 +1,35 @@
import os
import shutil
base_path = os.path.dirname(os.path.realpath(__file__))
def update_windows_updater():
top_path = os.path.dirname(base_path)
updater_path = os.path.join(base_path, ".ci/update_windows/update.py")
bat_path = os.path.join(base_path, ".ci/update_windows/update_comfyui.bat")
dest_updater_path = os.path.join(top_path, "update/update.py")
dest_bat_path = os.path.join(top_path, "update/update_comfyui.bat")
dest_bat_deps_path = os.path.join(top_path, "update/update_comfyui_and_python_dependencies.bat")
try:
with open(dest_bat_path, 'rb') as f:
contents = f.read()
except:
return
if not contents.startswith(b"..\\python_embeded\\python.exe .\\update.py"):
return
shutil.copy(updater_path, dest_updater_path)
try:
with open(dest_bat_deps_path, 'rb') as f:
contents = f.read()
contents = contents.replace(b'..\\python_embeded\\python.exe .\\update.py ..\\ComfyUI\\', b'call update_comfyui.bat nopause')
with open(dest_bat_deps_path, 'wb') as f:
f.write(contents)
except:
pass
shutil.copy(bat_path, dest_bat_path)
print("Updated the windows standalone package updater.")

212
nodes.py

@ -8,8 +8,9 @@ import traceback
import math import math
import time import time
import random import random
import logging
from PIL import Image, ImageOps from PIL import Image, ImageOps, ImageSequence
from PIL.PngImagePlugin import PngInfo from PIL.PngImagePlugin import PngInfo
import numpy as np import numpy as np
import safetensors.torch import safetensors.torch
@ -40,7 +41,7 @@ def before_node_execution():
def interrupt_processing(value=True): def interrupt_processing(value=True):
comfy.model_management.interrupt_current_processing(value) comfy.model_management.interrupt_current_processing(value)
MAX_RESOLUTION=8192 MAX_RESOLUTION=16384
class CLIPTextEncode: class CLIPTextEncode:
@classmethod @classmethod
@ -83,7 +84,7 @@ class ConditioningAverage :
out = [] out = []
if len(conditioning_from) > 1: if len(conditioning_from) > 1:
print("Warning: ConditioningAverage conditioning_from contains more than 1 cond, only the first one will actually be applied to conditioning_to.") logging.warning("Warning: ConditioningAverage conditioning_from contains more than 1 cond, only the first one will actually be applied to conditioning_to.")
cond_from = conditioning_from[0][0] cond_from = conditioning_from[0][0]
pooled_output_from = conditioning_from[0][1].get("pooled_output", None) pooled_output_from = conditioning_from[0][1].get("pooled_output", None)
@ -122,7 +123,7 @@ class ConditioningConcat:
out = [] out = []
if len(conditioning_from) > 1: if len(conditioning_from) > 1:
print("Warning: ConditioningConcat conditioning_from contains more than 1 cond, only the first one will actually be applied to conditioning_to.") logging.warning("Warning: ConditioningConcat conditioning_from contains more than 1 cond, only the first one will actually be applied to conditioning_to.")
cond_from = conditioning_from[0][0] cond_from = conditioning_from[0][0]
@ -184,6 +185,26 @@ class ConditioningSetAreaPercentage:
c.append(n) c.append(n)
return (c, ) return (c, )
class ConditioningSetAreaStrength:
@classmethod
def INPUT_TYPES(s):
return {"required": {"conditioning": ("CONDITIONING", ),
"strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
}}
RETURN_TYPES = ("CONDITIONING",)
FUNCTION = "append"
CATEGORY = "conditioning"
def append(self, conditioning, strength):
c = []
for t in conditioning:
n = [t[0], t[1].copy()]
n[1]['strength'] = strength
c.append(n)
return (c, )
class ConditioningSetMask: class ConditioningSetMask:
@classmethod @classmethod
def INPUT_TYPES(s): def INPUT_TYPES(s):
@ -289,18 +310,7 @@ class VAEEncode:
CATEGORY = "latent" CATEGORY = "latent"
@staticmethod
def vae_encode_crop_pixels(pixels):
x = (pixels.shape[1] // 8) * 8
y = (pixels.shape[2] // 8) * 8
if pixels.shape[1] != x or pixels.shape[2] != y:
x_offset = (pixels.shape[1] % 8) // 2
y_offset = (pixels.shape[2] % 8) // 2
pixels = pixels[:, x_offset:x + x_offset, y_offset:y + y_offset, :]
return pixels
def encode(self, vae, pixels): def encode(self, vae, pixels):
pixels = self.vae_encode_crop_pixels(pixels)
t = vae.encode(pixels[:,:,:,:3]) t = vae.encode(pixels[:,:,:,:3])
return ({"samples":t}, ) return ({"samples":t}, )
@ -316,7 +326,6 @@ class VAEEncodeTiled:
CATEGORY = "_for_testing" CATEGORY = "_for_testing"
def encode(self, vae, pixels, tile_size): def encode(self, vae, pixels, tile_size):
pixels = VAEEncode.vae_encode_crop_pixels(pixels)
t = vae.encode_tiled(pixels[:,:,:,:3], tile_x=tile_size, tile_y=tile_size, ) t = vae.encode_tiled(pixels[:,:,:,:3], tile_x=tile_size, tile_y=tile_size, )
return ({"samples":t}, ) return ({"samples":t}, )
@ -330,14 +339,14 @@ class VAEEncodeForInpaint:
CATEGORY = "latent/inpaint" CATEGORY = "latent/inpaint"
def encode(self, vae, pixels, mask, grow_mask_by=6): def encode(self, vae, pixels, mask, grow_mask_by=6):
x = (pixels.shape[1] // 8) * 8 x = (pixels.shape[1] // vae.downscale_ratio) * vae.downscale_ratio
y = (pixels.shape[2] // 8) * 8 y = (pixels.shape[2] // vae.downscale_ratio) * vae.downscale_ratio
mask = torch.nn.functional.interpolate(mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1])), size=(pixels.shape[1], pixels.shape[2]), mode="bilinear") mask = torch.nn.functional.interpolate(mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1])), size=(pixels.shape[1], pixels.shape[2]), mode="bilinear")
pixels = pixels.clone() pixels = pixels.clone()
if pixels.shape[1] != x or pixels.shape[2] != y: if pixels.shape[1] != x or pixels.shape[2] != y:
x_offset = (pixels.shape[1] % 8) // 2 x_offset = (pixels.shape[1] % vae.downscale_ratio) // 2
y_offset = (pixels.shape[2] % 8) // 2 y_offset = (pixels.shape[2] % vae.downscale_ratio) // 2
pixels = pixels[:,x_offset:x + x_offset, y_offset:y + y_offset,:] pixels = pixels[:,x_offset:x + x_offset, y_offset:y + y_offset,:]
mask = mask[:,:,x_offset:x + x_offset, y_offset:y + y_offset] mask = mask[:,:,x_offset:x + x_offset, y_offset:y + y_offset]
@ -359,6 +368,62 @@ class VAEEncodeForInpaint:
return ({"samples":t, "noise_mask": (mask_erosion[:,:,:x,:y].round())}, ) return ({"samples":t, "noise_mask": (mask_erosion[:,:,:x,:y].round())}, )
class InpaintModelConditioning:
@classmethod
def INPUT_TYPES(s):
return {"required": {"positive": ("CONDITIONING", ),
"negative": ("CONDITIONING", ),
"vae": ("VAE", ),
"pixels": ("IMAGE", ),
"mask": ("MASK", ),
}}
RETURN_TYPES = ("CONDITIONING","CONDITIONING","LATENT")
RETURN_NAMES = ("positive", "negative", "latent")
FUNCTION = "encode"
CATEGORY = "conditioning/inpaint"
def encode(self, positive, negative, pixels, vae, mask):
x = (pixels.shape[1] // 8) * 8
y = (pixels.shape[2] // 8) * 8
mask = torch.nn.functional.interpolate(mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1])), size=(pixels.shape[1], pixels.shape[2]), mode="bilinear")
orig_pixels = pixels
pixels = orig_pixels.clone()
if pixels.shape[1] != x or pixels.shape[2] != y:
x_offset = (pixels.shape[1] % 8) // 2
y_offset = (pixels.shape[2] % 8) // 2
pixels = pixels[:,x_offset:x + x_offset, y_offset:y + y_offset,:]
mask = mask[:,:,x_offset:x + x_offset, y_offset:y + y_offset]
m = (1.0 - mask.round()).squeeze(1)
for i in range(3):
pixels[:,:,:,i] -= 0.5
pixels[:,:,:,i] *= m
pixels[:,:,:,i] += 0.5
concat_latent = vae.encode(pixels)
orig_latent = vae.encode(orig_pixels)
out_latent = {}
out_latent["samples"] = orig_latent
out_latent["noise_mask"] = mask
out = []
for conditioning in [positive, negative]:
c = []
for t in conditioning:
d = t[1].copy()
d["concat_latent_image"] = concat_latent
d["concat_mask"] = mask
n = [t[0], d]
c.append(n)
out.append(c)
return (out[0], out[1], out_latent)
class SaveLatent: class SaveLatent:
def __init__(self): def __init__(self):
self.output_dir = folder_paths.get_output_directory() self.output_dir = folder_paths.get_output_directory()
@ -778,15 +843,20 @@ class CLIPLoader:
@classmethod @classmethod
def INPUT_TYPES(s): def INPUT_TYPES(s):
return {"required": { "clip_name": (folder_paths.get_filename_list("clip"), ), return {"required": { "clip_name": (folder_paths.get_filename_list("clip"), ),
"type": (["stable_diffusion", "stable_cascade"], ),
}} }}
RETURN_TYPES = ("CLIP",) RETURN_TYPES = ("CLIP",)
FUNCTION = "load_clip" FUNCTION = "load_clip"
CATEGORY = "advanced/loaders" CATEGORY = "advanced/loaders"
def load_clip(self, clip_name): def load_clip(self, clip_name, type="stable_diffusion"):
clip_type = comfy.sd.CLIPType.STABLE_DIFFUSION
if type == "stable_cascade":
clip_type = comfy.sd.CLIPType.STABLE_CASCADE
clip_path = folder_paths.get_full_path("clip", clip_name) clip_path = folder_paths.get_full_path("clip", clip_name)
clip = comfy.sd.load_clip(ckpt_paths=[clip_path], embedding_directory=folder_paths.get_folder_paths("embeddings")) clip = comfy.sd.load_clip(ckpt_paths=[clip_path], embedding_directory=folder_paths.get_folder_paths("embeddings"), clip_type=clip_type)
return (clip,) return (clip,)
class DualCLIPLoader: class DualCLIPLoader:
@ -934,7 +1004,7 @@ class GLIGENTextBoxApply:
def append(self, conditioning_to, clip, gligen_textbox_model, text, width, height, x, y): def append(self, conditioning_to, clip, gligen_textbox_model, text, width, height, x, y):
c = [] c = []
cond, cond_pooled = clip.encode_from_tokens(clip.tokenize(text), return_pooled=True) cond, cond_pooled = clip.encode_from_tokens(clip.tokenize(text), return_pooled="unprojected")
for t in conditioning_to: for t in conditioning_to:
n = [t[0], t[1].copy()] n = [t[0], t[1].copy()]
position_params = [(cond_pooled, height // 8, width // 8, y // 8, x // 8)] position_params = [(cond_pooled, height // 8, width // 8, y // 8, x // 8)]
@ -947,8 +1017,8 @@ class GLIGENTextBoxApply:
return (c, ) return (c, )
class EmptyLatentImage: class EmptyLatentImage:
def __init__(self, device="cpu"): def __init__(self):
self.device = device self.device = comfy.model_management.intermediate_device()
@classmethod @classmethod
def INPUT_TYPES(s): def INPUT_TYPES(s):
@ -961,7 +1031,7 @@ class EmptyLatentImage:
CATEGORY = "latent" CATEGORY = "latent"
def generate(self, width, height, batch_size=1): def generate(self, width, height, batch_size=1):
latent = torch.zeros([batch_size, 4, height // 8, width // 8]) latent = torch.zeros([batch_size, 4, height // 8, width // 8], device=self.device)
return ({"samples":latent}, ) return ({"samples":latent}, )
@ -1358,7 +1428,7 @@ class SaveImage:
filename_prefix += self.prefix_append filename_prefix += self.prefix_append
full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, self.output_dir, images[0].shape[1], images[0].shape[0]) full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, self.output_dir, images[0].shape[1], images[0].shape[0])
results = list() results = list()
for image in images: for (batch_number, image) in enumerate(images):
i = 255. * image.cpu().numpy() i = 255. * image.cpu().numpy()
img = Image.fromarray(np.clip(i, 0, 255).astype(np.uint8)) img = Image.fromarray(np.clip(i, 0, 255).astype(np.uint8))
metadata = None metadata = None
@ -1370,7 +1440,8 @@ class SaveImage:
for x in extra_pnginfo: for x in extra_pnginfo:
metadata.add_text(x, json.dumps(extra_pnginfo[x])) metadata.add_text(x, json.dumps(extra_pnginfo[x]))
file = f"{filename}_{counter:05}_.png" filename_with_batch_num = filename.replace("%batch_num%", str(batch_number))
file = f"{filename_with_batch_num}_{counter:05}_.png"
img.save(os.path.join(full_output_folder, file), pnginfo=metadata, compress_level=self.compress_level) img.save(os.path.join(full_output_folder, file), pnginfo=metadata, compress_level=self.compress_level)
results.append({ results.append({
"filename": file, "filename": file,
@ -1410,17 +1481,32 @@ class LoadImage:
FUNCTION = "load_image" FUNCTION = "load_image"
def load_image(self, image): def load_image(self, image):
image_path = folder_paths.get_annotated_filepath(image) image_path = folder_paths.get_annotated_filepath(image)
i = Image.open(image_path) img = Image.open(image_path)
i = ImageOps.exif_transpose(i) output_images = []
image = i.convert("RGB") output_masks = []
image = np.array(image).astype(np.float32) / 255.0 for i in ImageSequence.Iterator(img):
image = torch.from_numpy(image)[None,] i = ImageOps.exif_transpose(i)
if 'A' in i.getbands(): if i.mode == 'I':
mask = np.array(i.getchannel('A')).astype(np.float32) / 255.0 i = i.point(lambda i: i * (1 / 255))
mask = 1. - torch.from_numpy(mask) image = i.convert("RGB")
image = np.array(image).astype(np.float32) / 255.0
image = torch.from_numpy(image)[None,]
if 'A' in i.getbands():
mask = np.array(i.getchannel('A')).astype(np.float32) / 255.0
mask = 1. - torch.from_numpy(mask)
else:
mask = torch.zeros((64,64), dtype=torch.float32, device="cpu")
output_images.append(image)
output_masks.append(mask.unsqueeze(0))
if len(output_images) > 1:
output_image = torch.cat(output_images, dim=0)
output_mask = torch.cat(output_masks, dim=0)
else: else:
mask = torch.zeros((64,64), dtype=torch.float32, device="cpu") output_image = output_images[0]
return (image, mask.unsqueeze(0)) output_mask = output_masks[0]
return (output_image, output_mask)
@classmethod @classmethod
def IS_CHANGED(s, image): def IS_CHANGED(s, image):
@ -1457,6 +1543,8 @@ class LoadImageMask:
i = Image.open(image_path) i = Image.open(image_path)
i = ImageOps.exif_transpose(i) i = ImageOps.exif_transpose(i)
if i.getbands() != ("R", "G", "B", "A"): if i.getbands() != ("R", "G", "B", "A"):
if i.mode == 'I':
i = i.point(lambda i: i * (1 / 255))
i = i.convert("RGBA") i = i.convert("RGBA")
mask = None mask = None
c = channel[0].upper() c = channel[0].upper()
@ -1478,13 +1566,10 @@ class LoadImageMask:
return m.digest().hex() return m.digest().hex()
@classmethod @classmethod
def VALIDATE_INPUTS(s, image, channel): def VALIDATE_INPUTS(s, image):
if not folder_paths.exists_annotated_filepath(image): if not folder_paths.exists_annotated_filepath(image):
return "Invalid image file: {}".format(image) return "Invalid image file: {}".format(image)
if channel not in s._color_channels:
return "Invalid color channel: {}".format(channel)
return True return True
class ImageScale: class ImageScale:
@ -1614,10 +1699,11 @@ class ImagePadForOutpaint:
def expand_image(self, image, left, top, right, bottom, feathering): def expand_image(self, image, left, top, right, bottom, feathering):
d1, d2, d3, d4 = image.size() d1, d2, d3, d4 = image.size()
new_image = torch.zeros( new_image = torch.ones(
(d1, d2 + top + bottom, d3 + left + right, d4), (d1, d2 + top + bottom, d3 + left + right, d4),
dtype=torch.float32, dtype=torch.float32,
) ) * 0.5
new_image[:, top:top + d2, left:left + d3, :] = image new_image[:, top:top + d2, left:left + d3, :] = image
mask = torch.ones( mask = torch.ones(
@ -1683,6 +1769,7 @@ NODE_CLASS_MAPPINGS = {
"ConditioningConcat": ConditioningConcat, "ConditioningConcat": ConditioningConcat,
"ConditioningSetArea": ConditioningSetArea, "ConditioningSetArea": ConditioningSetArea,
"ConditioningSetAreaPercentage": ConditioningSetAreaPercentage, "ConditioningSetAreaPercentage": ConditioningSetAreaPercentage,
"ConditioningSetAreaStrength": ConditioningSetAreaStrength,
"ConditioningSetMask": ConditioningSetMask, "ConditioningSetMask": ConditioningSetMask,
"KSamplerAdvanced": KSamplerAdvanced, "KSamplerAdvanced": KSamplerAdvanced,
"SetLatentNoiseMask": SetLatentNoiseMask, "SetLatentNoiseMask": SetLatentNoiseMask,
@ -1709,6 +1796,7 @@ NODE_CLASS_MAPPINGS = {
"unCLIPCheckpointLoader": unCLIPCheckpointLoader, "unCLIPCheckpointLoader": unCLIPCheckpointLoader,
"GLIGENLoader": GLIGENLoader, "GLIGENLoader": GLIGENLoader,
"GLIGENTextBoxApply": GLIGENTextBoxApply, "GLIGENTextBoxApply": GLIGENTextBoxApply,
"InpaintModelConditioning": InpaintModelConditioning,
"CheckpointLoader": CheckpointLoader, "CheckpointLoader": CheckpointLoader,
"DiffusersLoader": DiffusersLoader, "DiffusersLoader": DiffusersLoader,
@ -1812,11 +1900,11 @@ def load_custom_node(module_path, ignore=set()):
NODE_DISPLAY_NAME_MAPPINGS.update(module.NODE_DISPLAY_NAME_MAPPINGS) NODE_DISPLAY_NAME_MAPPINGS.update(module.NODE_DISPLAY_NAME_MAPPINGS)
return True return True
else: else:
print(f"Skip {module_path} module for custom nodes due to the lack of NODE_CLASS_MAPPINGS.") logging.warning(f"Skip {module_path} module for custom nodes due to the lack of NODE_CLASS_MAPPINGS.")
return False return False
except Exception as e: except Exception as e:
print(traceback.format_exc()) logging.warning(traceback.format_exc())
print(f"Cannot import {module_path} module for custom nodes:", e) logging.warning(f"Cannot import {module_path} module for custom nodes: {e}")
return False return False
def load_custom_nodes(): def load_custom_nodes():
@ -1837,14 +1925,14 @@ def load_custom_nodes():
node_import_times.append((time.perf_counter() - time_before, module_path, success)) node_import_times.append((time.perf_counter() - time_before, module_path, success))
if len(node_import_times) > 0: if len(node_import_times) > 0:
print("\nImport times for custom nodes:") logging.info("\nImport times for custom nodes:")
for n in sorted(node_import_times): for n in sorted(node_import_times):
if n[2]: if n[2]:
import_message = "" import_message = ""
else: else:
import_message = " (IMPORT FAILED)" import_message = " (IMPORT FAILED)"
print("{:6.1f} seconds{}:".format(n[0], import_message), n[1]) logging.info("{:6.1f} seconds{}: {}".format(n[0], import_message, n[1]))
print() logging.info("")
def init_custom_nodes(): def init_custom_nodes():
extras_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras") extras_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras")
@ -1867,9 +1955,31 @@ def init_custom_nodes():
"nodes_model_downscale.py", "nodes_model_downscale.py",
"nodes_images.py", "nodes_images.py",
"nodes_video_model.py", "nodes_video_model.py",
"nodes_sag.py",
"nodes_perpneg.py",
"nodes_stable3d.py",
"nodes_sdupscale.py",
"nodes_photomaker.py",
"nodes_cond.py",
"nodes_morphology.py",
"nodes_stable_cascade.py",
"nodes_differential_diffusion.py",
] ]
import_failed = []
for node_file in extras_files: for node_file in extras_files:
load_custom_node(os.path.join(extras_dir, node_file)) if not load_custom_node(os.path.join(extras_dir, node_file)):
import_failed.append(node_file)
load_custom_nodes() load_custom_nodes()
if len(import_failed) > 0:
logging.warning("WARNING: some comfy_extras/ nodes did not import correctly. This may be because they are missing some dependencies.\n")
for node in import_failed:
logging.warning("IMPORT FAILED: {}".format(node))
logging.warning("\nThis issue might be caused by new missing dependencies added the last time you updated ComfyUI.")
if args.windows_standalone_build:
logging.warning("Please run the update script: update/update_comfyui.bat")
else:
logging.warning("Please do a: pip install -r requirements.txt")
logging.warning("")

3
requirements.txt

@ -1,13 +1,14 @@
torch torch
torchsde torchsde
torchvision
einops einops
transformers>=4.25.1 transformers>=4.25.1
safetensors>=0.3.0 safetensors>=0.3.0
aiohttp aiohttp
accelerate
pyyaml pyyaml
Pillow Pillow
scipy scipy
tqdm tqdm
psutil psutil
kornia>=0.7.1
spandrel==0.3.1 spandrel==0.3.1

159
script_examples/websockets_api_example_ws_images.py

@ -0,0 +1,159 @@
#This is an example that uses the websockets api and the SaveImageWebsocket node to get images directly without
#them being saved to disk
import websocket #NOTE: websocket-client (https://github.com/websocket-client/websocket-client)
import uuid
import json
import urllib.request
import urllib.parse
server_address = "127.0.0.1:8188"
client_id = str(uuid.uuid4())
def queue_prompt(prompt):
p = {"prompt": prompt, "client_id": client_id}
data = json.dumps(p).encode('utf-8')
req = urllib.request.Request("http://{}/prompt".format(server_address), data=data)
return json.loads(urllib.request.urlopen(req).read())
def get_image(filename, subfolder, folder_type):
data = {"filename": filename, "subfolder": subfolder, "type": folder_type}
url_values = urllib.parse.urlencode(data)
with urllib.request.urlopen("http://{}/view?{}".format(server_address, url_values)) as response:
return response.read()
def get_history(prompt_id):
with urllib.request.urlopen("http://{}/history/{}".format(server_address, prompt_id)) as response:
return json.loads(response.read())
def get_images(ws, prompt):
prompt_id = queue_prompt(prompt)['prompt_id']
output_images = {}
current_node = ""
while True:
out = ws.recv()
if isinstance(out, str):
message = json.loads(out)
if message['type'] == 'executing':
data = message['data']
if data['prompt_id'] == prompt_id:
if data['node'] is None:
break #Execution is done
else:
current_node = data['node']
else:
if current_node == 'save_image_websocket_node':
images_output = output_images.get(current_node, [])
images_output.append(out[8:])
output_images[current_node] = images_output
return output_images
prompt_text = """
{
"3": {
"class_type": "KSampler",
"inputs": {
"cfg": 8,
"denoise": 1,
"latent_image": [
"5",
0
],
"model": [
"4",
0
],
"negative": [
"7",
0
],
"positive": [
"6",
0
],
"sampler_name": "euler",
"scheduler": "normal",
"seed": 8566257,
"steps": 20
}
},
"4": {
"class_type": "CheckpointLoaderSimple",
"inputs": {
"ckpt_name": "v1-5-pruned-emaonly.ckpt"
}
},
"5": {
"class_type": "EmptyLatentImage",
"inputs": {
"batch_size": 1,
"height": 512,
"width": 512
}
},
"6": {
"class_type": "CLIPTextEncode",
"inputs": {
"clip": [
"4",
1
],
"text": "masterpiece best quality girl"
}
},
"7": {
"class_type": "CLIPTextEncode",
"inputs": {
"clip": [
"4",
1
],
"text": "bad hands"
}
},
"8": {
"class_type": "VAEDecode",
"inputs": {
"samples": [
"3",
0
],
"vae": [
"4",
2
]
}
},
"save_image_websocket_node": {
"class_type": "SaveImageWebsocket",
"inputs": {
"images": [
"8",
0
]
}
}
}
"""
prompt = json.loads(prompt_text)
#set the text prompt for our positive CLIPTextEncode
prompt["6"]["inputs"]["text"] = "masterpiece best quality man"
#set the seed for our KSampler node
prompt["3"]["inputs"]["seed"] = 5
ws = websocket.WebSocket()
ws.connect("ws://{}/ws?clientId={}".format(server_address, client_id))
images = get_images(ws, prompt)
#Commented out code to display the output images:
# for node_id in images:
# for image_data in images[node_id]:
# from PIL import Image
# import io
# image = Image.open(io.BytesIO(image_data))
# image.show()

58
server.py

@ -15,21 +15,16 @@ from PIL import Image, ImageOps
from PIL.PngImagePlugin import PngInfo from PIL.PngImagePlugin import PngInfo
from io import BytesIO from io import BytesIO
try: import aiohttp
import aiohttp from aiohttp import web
from aiohttp import web import logging
except ImportError:
print("Module 'aiohttp' not installed. Please install it via:")
print("pip install aiohttp")
print("or")
print("pip install -r requirements.txt")
sys.exit()
import mimetypes import mimetypes
from comfy.cli_args import args from comfy.cli_args import args
import comfy.utils import comfy.utils
import comfy.model_management import comfy.model_management
from app.user_manager import UserManager
class BinaryEventTypes: class BinaryEventTypes:
PREVIEW_IMAGE = 1 PREVIEW_IMAGE = 1
@ -39,7 +34,7 @@ async def send_socket_catch_exception(function, message):
try: try:
await function(message) await function(message)
except (aiohttp.ClientError, aiohttp.ClientPayloadError, ConnectionResetError) as err: except (aiohttp.ClientError, aiohttp.ClientPayloadError, ConnectionResetError) as err:
print("send error:", err) logging.warning("send error: {}".format(err))
@web.middleware @web.middleware
async def cache_control(request: web.Request, handler): async def cache_control(request: web.Request, handler):
@ -72,6 +67,7 @@ class PromptServer():
mimetypes.init() mimetypes.init()
mimetypes.types_map['.js'] = 'application/javascript; charset=utf-8' mimetypes.types_map['.js'] = 'application/javascript; charset=utf-8'
self.user_manager = UserManager()
self.supports = ["custom_nodes_from_web"] self.supports = ["custom_nodes_from_web"]
self.prompt_queue = None self.prompt_queue = None
self.loop = loop self.loop = loop
@ -116,7 +112,7 @@ class PromptServer():
async for msg in ws: async for msg in ws:
if msg.type == aiohttp.WSMsgType.ERROR: if msg.type == aiohttp.WSMsgType.ERROR:
print('ws connection closed with exception %s' % ws.exception()) logging.warning('ws connection closed with exception %s' % ws.exception())
finally: finally:
self.sockets.pop(sid, None) self.sockets.pop(sid, None)
return ws return ws
@ -417,8 +413,8 @@ class PromptServer():
try: try:
out[x] = node_info(x) out[x] = node_info(x)
except Exception as e: except Exception as e:
print(f"[ERROR] An error occurred while retrieving information for the '{x}' node.", file=sys.stderr) logging.error(f"[ERROR] An error occurred while retrieving information for the '{x}' node.")
traceback.print_exc() logging.error(traceback.format_exc())
return web.json_response(out) return web.json_response(out)
@routes.get("/object_info/{node_class}") @routes.get("/object_info/{node_class}")
@ -451,7 +447,7 @@ class PromptServer():
@routes.post("/prompt") @routes.post("/prompt")
async def post_prompt(request): async def post_prompt(request):
print("got prompt") logging.info("got prompt")
resp_code = 200 resp_code = 200
out_string = "" out_string = ""
json_data = await request.json() json_data = await request.json()
@ -483,7 +479,7 @@ class PromptServer():
response = {"prompt_id": prompt_id, "number": number, "node_errors": valid[3]} response = {"prompt_id": prompt_id, "number": number, "node_errors": valid[3]}
return web.json_response(response) return web.json_response(response)
else: else:
print("invalid prompt:", valid[1]) logging.warning("invalid prompt: {}".format(valid[1]))
return web.json_response({"error": valid[1], "node_errors": valid[3]}, status=400) return web.json_response({"error": valid[1], "node_errors": valid[3]}, status=400)
else: else:
return web.json_response({"error": "no prompt", "node_errors": []}, status=400) return web.json_response({"error": "no prompt", "node_errors": []}, status=400)
@ -507,6 +503,17 @@ class PromptServer():
nodes.interrupt_processing() nodes.interrupt_processing()
return web.Response(status=200) return web.Response(status=200)
@routes.post("/free")
async def post_free(request):
json_data = await request.json()
unload_models = json_data.get("unload_models", False)
free_memory = json_data.get("free_memory", False)
if unload_models:
self.prompt_queue.set_flag("unload_models", unload_models)
if free_memory:
self.prompt_queue.set_flag("free_memory", free_memory)
return web.Response(status=200)
@routes.post("/history") @routes.post("/history")
async def post_history(request): async def post_history(request):
json_data = await request.json() json_data = await request.json()
@ -521,15 +528,16 @@ class PromptServer():
return web.Response(status=200) return web.Response(status=200)
def add_routes(self): def add_routes(self):
self.user_manager.add_routes(self.routes)
self.app.add_routes(self.routes) self.app.add_routes(self.routes)
for name, dir in nodes.EXTENSION_WEB_DIRS.items(): for name, dir in nodes.EXTENSION_WEB_DIRS.items():
self.app.add_routes([ self.app.add_routes([
web.static('/extensions/' + urllib.parse.quote(name), dir, follow_symlinks=True), web.static('/extensions/' + urllib.parse.quote(name), dir),
]) ])
self.app.add_routes([ self.app.add_routes([
web.static('/', self.web_root, follow_symlinks=True), web.static('/', self.web_root),
]) ])
def get_queue_info(self): def get_queue_info(self):
@ -584,7 +592,8 @@ class PromptServer():
message = self.encode_bytes(event, data) message = self.encode_bytes(event, data)
if sid is None: if sid is None:
for ws in self.sockets.values(): sockets = list(self.sockets.values())
for ws in sockets:
await send_socket_catch_exception(ws.send_bytes, message) await send_socket_catch_exception(ws.send_bytes, message)
elif sid in self.sockets: elif sid in self.sockets:
await send_socket_catch_exception(self.sockets[sid].send_bytes, message) await send_socket_catch_exception(self.sockets[sid].send_bytes, message)
@ -593,7 +602,8 @@ class PromptServer():
message = {"type": event, "data": data} message = {"type": event, "data": data}
if sid is None: if sid is None:
for ws in self.sockets.values(): sockets = list(self.sockets.values())
for ws in sockets:
await send_socket_catch_exception(ws.send_json, message) await send_socket_catch_exception(ws.send_json, message)
elif sid in self.sockets: elif sid in self.sockets:
await send_socket_catch_exception(self.sockets[sid].send_json, message) await send_socket_catch_exception(self.sockets[sid].send_json, message)
@ -616,11 +626,9 @@ class PromptServer():
site = web.TCPSite(runner, address, port) site = web.TCPSite(runner, address, port)
await site.start() await site.start()
if address == '':
address = '0.0.0.0'
if verbose: if verbose:
print("Starting server\n") logging.info("Starting server\n")
print("To see the GUI go to: http://{}:{}".format(address, port)) logging.info("To see the GUI go to: http://{}:{}".format(address, port))
if call_on_start is not None: if call_on_start is not None:
call_on_start(address, port) call_on_start(address, port)
@ -632,7 +640,7 @@ class PromptServer():
try: try:
json_data = handler(json_data) json_data = handler(json_data)
except Exception as e: except Exception as e:
print(f"[ERROR] An error occurred during the on_prompt_handler processing") logging.warning(f"[ERROR] An error occurred during the on_prompt_handler processing")
traceback.print_exc() logging.warning(traceback.format_exc())
return json_data return json_data

9
tests-ui/afterSetup.js

@ -0,0 +1,9 @@
const { start } = require("./utils");
const lg = require("./utils/litegraph");
// Load things once per test file before to ensure its all warmed up for the tests
beforeAll(async () => {
lg.setup(global);
await start({ resetEnv: true });
lg.teardown(global);
});

3
tests-ui/babel.config.json

@ -1,3 +1,4 @@
{ {
"presets": ["@babel/preset-env"] "presets": ["@babel/preset-env"],
"plugins": ["babel-plugin-transform-import-meta"]
} }

2
tests-ui/jest.config.js

@ -2,8 +2,10 @@
const config = { const config = {
testEnvironment: "jsdom", testEnvironment: "jsdom",
setupFiles: ["./globalSetup.js"], setupFiles: ["./globalSetup.js"],
setupFilesAfterEnv: ["./afterSetup.js"],
clearMocks: true, clearMocks: true,
resetModules: true, resetModules: true,
testTimeout: 10000
}; };
module.exports = config; module.exports = config;

Some files were not shown because too many files have changed in this diff Show More

Loading…
Cancel
Save