Browse Source

Merge remote-tracking branch 'comfyanonymous/master' into replace-chainner-models-spandrel

pull/2146/head
Joey Ballentine 8 months ago
parent
commit
8796bde936
  1. 3
      .ci/nightly/update_windows/update_comfyui_and_python_dependencies.bat
  2. 2
      .ci/nightly/windows_base_files/run_nvidia_gpu.bat
  3. 49
      .ci/update_windows/update.py
  4. 8
      .ci/update_windows/update_comfyui.bat
  5. 3
      .ci/update_windows/update_comfyui_and_python_dependencies.bat
  6. 11
      .ci/update_windows_cu118/update_comfyui_and_python_dependencies.bat
  7. 2
      .github/workflows/test-ui.yaml
  8. 71
      .github/workflows/windows_release_cu118_dependencies.yml
  9. 37
      .github/workflows/windows_release_cu118_dependencies_2.yml
  10. 79
      .github/workflows/windows_release_cu118_package.yml
  11. 7
      .github/workflows/windows_release_dependencies.yml
  12. 33
      .github/workflows/windows_release_nightly_pytorch.yml
  13. 2
      .github/workflows/windows_release_package.yml
  14. 1
      .gitignore
  15. 9
      .vscode/settings.json
  16. 17
      README.md
  17. 54
      app/app_settings.py
  18. 140
      app/user_manager.py
  19. 34
      comfy/cldm/cldm.py
  20. 22
      comfy/cli_args.py
  21. 194
      comfy/clip_model.py
  22. 73
      comfy/clip_vision.py
  23. 1
      comfy/conds.py
  24. 143
      comfy/controlnet.py
  25. 11
      comfy/diffusers_convert.py
  26. 1
      comfy/diffusers_load.py
  27. 37
      comfy/extra_samplers/uni_pc.py
  28. 54
      comfy/gligen.py
  29. 2
      comfy/k_diffusion/sampling.py
  30. 69
      comfy/latent_formats.py
  31. 161
      comfy/ldm/cascade/common.py
  32. 93
      comfy/ldm/cascade/controlnet.py
  33. 255
      comfy/ldm/cascade/stage_a.py
  34. 257
      comfy/ldm/cascade/stage_b.py
  35. 274
      comfy/ldm/cascade/stage_c.py
  36. 95
      comfy/ldm/cascade/stage_c_coder.py
  37. 5
      comfy/ldm/models/autoencoder.py
  38. 123
      comfy/ldm/modules/attention.py
  39. 54
      comfy/ldm/modules/diffusionmodules/model.py
  40. 96
      comfy/ldm/modules/diffusionmodules/openaimodel.py
  41. 16
      comfy/ldm/modules/diffusionmodules/upscaling.py
  42. 67
      comfy/ldm/modules/diffusionmodules/util.py
  43. 8
      comfy/ldm/modules/encoders/noise_aug_modules.py
  44. 31
      comfy/ldm/modules/sub_quadratic_attention.py
  45. 13
      comfy/ldm/modules/temporal_ae.py
  46. 46
      comfy/lora.py
  47. 271
      comfy/model_base.py
  48. 82
      comfy/model_detection.py
  49. 316
      comfy/model_management.py
  50. 252
      comfy/model_patcher.py
  51. 98
      comfy/model_sampling.py
  52. 201
      comfy/ops.py
  53. 12
      comfy/sample.py
  54. 532
      comfy/samplers.py
  55. 215
      comfy/sd.py
  56. 150
      comfy/sd1_clip.py
  57. 6
      comfy/sd2_clip.py
  58. 40
      comfy/sdxl_clip.py
  59. 238
      comfy/supported_models.py
  60. 21
      comfy/supported_models_base.py
  61. 5
      comfy/taesd/taesd.py
  62. 62
      comfy/utils.py
  63. 272
      comfy_extras/nodes_canny.py
  64. 25
      comfy_extras/nodes_cond.py
  65. 104
      comfy_extras/nodes_custom_sampler.py
  66. 42
      comfy_extras/nodes_differential_diffusion.py
  67. 10
      comfy_extras/nodes_freelunch.py
  68. 3
      comfy_extras/nodes_hypernetwork.py
  69. 36
      comfy_extras/nodes_hypertile.py
  70. 26
      comfy_extras/nodes_images.py
  71. 49
      comfy_extras/nodes_latent.py
  72. 22
      comfy_extras/nodes_mask.py
  73. 96
      comfy_extras/nodes_model_advanced.py
  74. 129
      comfy_extras/nodes_model_merging.py
  75. 49
      comfy_extras/nodes_morphology.py
  76. 55
      comfy_extras/nodes_perpneg.py
  77. 187
      comfy_extras/nodes_photomaker.py
  78. 3
      comfy_extras/nodes_post_processing.py
  79. 30
      comfy_extras/nodes_rebatch.py
  80. 170
      comfy_extras/nodes_sag.py
  81. 47
      comfy_extras/nodes_sdupscale.py
  82. 143
      comfy_extras/nodes_stable3d.py
  83. 140
      comfy_extras/nodes_stable_cascade.py
  84. 45
      comfy_extras/nodes_video_model.py
  85. 10
      cuda_malloc.py
  86. 16
      custom_nodes/example_node.py.example
  87. 45
      custom_nodes/websocket_image_save.py
  88. 172
      execution.py
  89. 22
      folder_paths.py
  90. 3
      latent_preview.py
  91. 58
      main.py
  92. 0
      models/photomaker/put_photomaker_models_here
  93. 35
      new_updater.py
  94. 212
      nodes.py
  95. 3
      requirements.txt
  96. 159
      script_examples/websockets_api_example_ws_images.py
  97. 58
      server.py
  98. 9
      tests-ui/afterSetup.js
  99. 3
      tests-ui/babel.config.json
  100. 2
      tests-ui/jest.config.js
  101. Some files were not shown because too many files have changed in this diff Show More

3
.ci/nightly/update_windows/update_comfyui_and_python_dependencies.bat

@ -1,3 +0,0 @@
..\python_embeded\python.exe .\update.py ..\ComfyUI\
..\python_embeded\python.exe -s -m pip install --upgrade --pre torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/nightly/cu121 -r ../ComfyUI/requirements.txt pygit2
pause

2
.ci/nightly/windows_base_files/run_nvidia_gpu.bat

@ -1,2 +0,0 @@
.\python_embeded\python.exe -s ComfyUI\main.py --windows-standalone-build --use-pytorch-cross-attention
pause

49
.ci/update_windows/update.py

@ -1,6 +1,9 @@
import pygit2
from datetime import datetime
import sys
import os
import shutil
import filecmp
def pull(repo, remote_name='origin', branch='master'):
for remote in repo.remotes:
@ -42,7 +45,8 @@ def pull(repo, remote_name='origin', branch='master'):
raise AssertionError('Unknown merge analysis result')
pygit2.option(pygit2.GIT_OPT_SET_OWNER_VALIDATION, 0)
repo = pygit2.Repository(str(sys.argv[1]))
repo_path = str(sys.argv[1])
repo = pygit2.Repository(repo_path)
ident = pygit2.Signature('comfyui', 'comfy@ui')
try:
print("stashing current changes")
@ -51,7 +55,10 @@ except KeyError:
print("nothing to stash")
backup_branch_name = 'backup_branch_{}'.format(datetime.today().strftime('%Y-%m-%d_%H_%M_%S'))
print("creating backup branch: {}".format(backup_branch_name))
repo.branches.local.create(backup_branch_name, repo.head.peel())
try:
repo.branches.local.create(backup_branch_name, repo.head.peel())
except:
pass
print("checking out master branch")
branch = repo.lookup_branch('master')
@ -63,3 +70,41 @@ pull(repo)
print("Done!")
self_update = True
if len(sys.argv) > 2:
self_update = '--skip_self_update' not in sys.argv
update_py_path = os.path.realpath(__file__)
repo_update_py_path = os.path.join(repo_path, ".ci/update_windows/update.py")
cur_path = os.path.dirname(update_py_path)
req_path = os.path.join(cur_path, "current_requirements.txt")
repo_req_path = os.path.join(repo_path, "requirements.txt")
def files_equal(file1, file2):
try:
return filecmp.cmp(file1, file2, shallow=False)
except:
return False
def file_size(f):
try:
return os.path.getsize(f)
except:
return 0
if self_update and not files_equal(update_py_path, repo_update_py_path) and file_size(repo_update_py_path) > 10:
shutil.copy(repo_update_py_path, os.path.join(cur_path, "update_new.py"))
exit()
if not os.path.exists(req_path) or not files_equal(repo_req_path, req_path):
import subprocess
try:
subprocess.check_call([sys.executable, '-s', '-m', 'pip', 'install', '-r', repo_req_path])
shutil.copy(repo_req_path, req_path)
except:
pass

8
.ci/update_windows/update_comfyui.bat

@ -1,2 +1,8 @@
@echo off
..\python_embeded\python.exe .\update.py ..\ComfyUI\
pause
if exist update_new.py (
move /y update_new.py update.py
echo Running updater again since it got updated.
..\python_embeded\python.exe .\update.py ..\ComfyUI\ --skip_self_update
)
if "%~1"=="" pause

3
.ci/update_windows/update_comfyui_and_python_dependencies.bat

@ -1,3 +0,0 @@
..\python_embeded\python.exe .\update.py ..\ComfyUI\
..\python_embeded\python.exe -s -m pip install --upgrade torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu117 xformers -r ../ComfyUI/requirements.txt pygit2
pause

11
.ci/update_windows_cu118/update_comfyui_and_python_dependencies.bat

@ -1,11 +0,0 @@
@echo off
..\python_embeded\python.exe .\update.py ..\ComfyUI\
echo
echo This will try to update pytorch and all python dependencies, if you get an error wait for pytorch/xformers to fix their stuff
echo You should not be running this anyways unless you really have to
echo
echo If you just want to update normally, close this and run update_comfyui.bat instead.
echo
pause
..\python_embeded\python.exe -s -m pip install --upgrade torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu118 xformers -r ../ComfyUI/requirements.txt pygit2
pause

2
.github/workflows/test-ui.yaml

@ -22,5 +22,5 @@ jobs:
run: |
npm ci
npm run test:generate
npm test
npm test -- --verbose
working-directory: ./tests-ui

71
.github/workflows/windows_release_cu118_dependencies.yml

@ -1,71 +0,0 @@
name: "Windows Release cu118 dependencies"
on:
workflow_dispatch:
# push:
# branches:
# - master
jobs:
build_dependencies:
env:
# you need at least cuda 5.0 for some of the stuff compiled here.
TORCH_CUDA_ARCH_LIST: "5.0+PTX 6.0 6.1 7.0 7.5 8.0 8.6 8.9"
FORCE_CUDA: 1
MAX_JOBS: 1 # will crash otherwise
DISTUTILS_USE_SDK: 1 # otherwise distutils will complain on windows about multiple versions of msvc
XFORMERS_BUILD_TYPE: "Release"
runs-on: windows-latest
steps:
- name: Cache Built Dependencies
uses: actions/cache@v3
id: cache-cu118_python_stuff
with:
path: cu118_python_deps.tar
key: ${{ runner.os }}-build-cu118
- if: steps.cache-cu118_python_stuff.outputs.cache-hit != 'true'
uses: actions/checkout@v3
- if: steps.cache-cu118_python_stuff.outputs.cache-hit != 'true'
uses: actions/setup-python@v4
with:
python-version: '3.10.9'
- if: steps.cache-cu118_python_stuff.outputs.cache-hit != 'true'
uses: comfyanonymous/cuda-toolkit@test
id: cuda-toolkit
with:
cuda: '11.8.0'
# copied from xformers github
- name: Setup MSVC
uses: ilammy/msvc-dev-cmd@v1
- name: Configure Pagefile
# windows runners will OOM with many CUDA architectures
# we cheat here with a page file
uses: al-cheb/configure-pagefile-action@v1.3
with:
minimum-size: 2GB
# really unfortunate: https://github.com/ilammy/msvc-dev-cmd#name-conflicts-with-shell-bash
- name: Remove link.exe
shell: bash
run: rm /usr/bin/link
- if: steps.cache-cu118_python_stuff.outputs.cache-hit != 'true'
shell: bash
run: |
python -m pip wheel --no-cache-dir torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu118 -r requirements.txt pygit2 -w ./temp_wheel_dir
python -m pip install --no-cache-dir ./temp_wheel_dir/*
echo installed basic
git clone --recurse-submodules https://github.com/facebookresearch/xformers.git
cd xformers
python -m pip install --no-cache-dir wheel setuptools twine
echo building xformers
python setup.py bdist_wheel -d ../temp_wheel_dir/
cd ..
rm -rf xformers
ls -lah temp_wheel_dir
mv temp_wheel_dir cu118_python_deps
tar cf cu118_python_deps.tar cu118_python_deps

37
.github/workflows/windows_release_cu118_dependencies_2.yml

@ -1,37 +0,0 @@
name: "Windows Release cu118 dependencies 2"
on:
workflow_dispatch:
inputs:
xformers:
description: 'xformers version'
required: true
type: string
default: "xformers"
# push:
# branches:
# - master
jobs:
build_dependencies:
runs-on: windows-latest
steps:
- uses: actions/checkout@v3
- uses: actions/setup-python@v4
with:
python-version: '3.10.9'
- shell: bash
run: |
python -m pip wheel --no-cache-dir torch torchvision torchaudio ${{ inputs.xformers }} --extra-index-url https://download.pytorch.org/whl/cu118 -r requirements.txt pygit2 -w ./temp_wheel_dir
python -m pip install --no-cache-dir ./temp_wheel_dir/*
echo installed basic
ls -lah temp_wheel_dir
mv temp_wheel_dir cu118_python_deps
tar cf cu118_python_deps.tar cu118_python_deps
- uses: actions/cache/save@v3
with:
path: cu118_python_deps.tar
key: ${{ runner.os }}-build-cu118

79
.github/workflows/windows_release_cu118_package.yml

@ -1,79 +0,0 @@
name: "Windows Release cu118 packaging"
on:
workflow_dispatch:
# push:
# branches:
# - master
jobs:
package_comfyui:
permissions:
contents: "write"
packages: "write"
pull-requests: "read"
runs-on: windows-latest
steps:
- uses: actions/cache/restore@v3
id: cache
with:
path: cu118_python_deps.tar
key: ${{ runner.os }}-build-cu118
- shell: bash
run: |
mv cu118_python_deps.tar ../
cd ..
tar xf cu118_python_deps.tar
pwd
ls
- uses: actions/checkout@v3
with:
fetch-depth: 0
persist-credentials: false
- shell: bash
run: |
cd ..
cp -r ComfyUI ComfyUI_copy
curl https://www.python.org/ftp/python/3.10.9/python-3.10.9-embed-amd64.zip -o python_embeded.zip
unzip python_embeded.zip -d python_embeded
cd python_embeded
echo 'import site' >> ./python310._pth
curl https://bootstrap.pypa.io/get-pip.py -o get-pip.py
./python.exe get-pip.py
./python.exe -s -m pip install ../cu118_python_deps/*
sed -i '1i../ComfyUI' ./python310._pth
cd ..
git clone https://github.com/comfyanonymous/taesd
cp taesd/*.pth ./ComfyUI_copy/models/vae_approx/
mkdir ComfyUI_windows_portable
mv python_embeded ComfyUI_windows_portable
mv ComfyUI_copy ComfyUI_windows_portable/ComfyUI
cd ComfyUI_windows_portable
mkdir update
cp -r ComfyUI/.ci/update_windows/* ./update/
cp -r ComfyUI/.ci/update_windows_cu118/* ./update/
cp -r ComfyUI/.ci/windows_base_files/* ./
cd ..
"C:\Program Files\7-Zip\7z.exe" a -t7z -m0=lzma -mx=8 -mfb=64 -md=32m -ms=on -mf=BCJ2 ComfyUI_windows_portable.7z ComfyUI_windows_portable
mv ComfyUI_windows_portable.7z ComfyUI/new_ComfyUI_windows_portable_nvidia_cu118_or_cpu.7z
cd ComfyUI_windows_portable
python_embeded/python.exe -s ComfyUI/main.py --quick-test-for-ci --cpu
ls
- name: Upload binaries to release
uses: svenstaro/upload-release-action@v2
with:
repo_token: ${{ secrets.GITHUB_TOKEN }}
file: new_ComfyUI_windows_portable_nvidia_cu118_or_cpu.7z
tag: "latest"
overwrite: true

7
.github/workflows/windows_release_dependencies.yml

@ -24,7 +24,7 @@ on:
description: 'python patch version'
required: true
type: string
default: "6"
default: "8"
# push:
# branches:
# - master
@ -41,10 +41,9 @@ jobs:
- shell: bash
run: |
echo "@echo off
..\python_embeded\python.exe .\update.py ..\ComfyUI\\
call update_comfyui.bat nopause
echo -
echo This will try to update pytorch and all python dependencies, if you get an error wait for pytorch/xformers to fix their stuff
echo You should not be running this anyways unless you really have to
echo This will try to update pytorch and all python dependencies.
echo -
echo If you just want to update normally, close this and run update_comfyui.bat instead.
echo -

33
.github/workflows/windows_release_nightly_pytorch.yml

@ -2,6 +2,24 @@ name: "Windows Release Nightly pytorch"
on:
workflow_dispatch:
inputs:
cu:
description: 'cuda version'
required: true
type: string
default: "121"
python_minor:
description: 'python minor version'
required: true
type: string
default: "12"
python_patch:
description: 'python patch version'
required: true
type: string
default: "2"
# push:
# branches:
# - master
@ -20,21 +38,21 @@ jobs:
persist-credentials: false
- uses: actions/setup-python@v4
with:
python-version: '3.11.6'
python-version: 3.${{ inputs.python_minor }}.${{ inputs.python_patch }}
- shell: bash
run: |
cd ..
cp -r ComfyUI ComfyUI_copy
curl https://www.python.org/ftp/python/3.11.6/python-3.11.6-embed-amd64.zip -o python_embeded.zip
curl https://www.python.org/ftp/python/3.${{ inputs.python_minor }}.${{ inputs.python_patch }}/python-3.${{ inputs.python_minor }}.${{ inputs.python_patch }}-embed-amd64.zip -o python_embeded.zip
unzip python_embeded.zip -d python_embeded
cd python_embeded
echo 'import site' >> ./python311._pth
echo 'import site' >> ./python3${{ inputs.python_minor }}._pth
curl https://bootstrap.pypa.io/get-pip.py -o get-pip.py
./python.exe get-pip.py
python -m pip wheel torch torchvision torchaudio aiohttp==3.8.5 --pre --extra-index-url https://download.pytorch.org/whl/nightly/cu121 -r ../ComfyUI/requirements.txt pygit2 -w ../temp_wheel_dir
python -m pip wheel torch torchvision torchaudio mpmath==1.3.0 --pre --extra-index-url https://download.pytorch.org/whl/nightly/cu${{ inputs.cu }} -r ../ComfyUI/requirements.txt pygit2 -w ../temp_wheel_dir
ls ../temp_wheel_dir
./python.exe -s -m pip install --pre ../temp_wheel_dir/*
sed -i '1i../ComfyUI' ./python311._pth
sed -i '1i../ComfyUI' ./python3${{ inputs.python_minor }}._pth
cd ..
git clone https://github.com/comfyanonymous/taesd
@ -49,9 +67,10 @@ jobs:
mkdir update
cp -r ComfyUI/.ci/update_windows/* ./update/
cp -r ComfyUI/.ci/windows_base_files/* ./
cp -r ComfyUI/.ci/nightly/update_windows/* ./update/
cp -r ComfyUI/.ci/nightly/windows_base_files/* ./
echo "call update_comfyui.bat nopause
..\python_embeded\python.exe -s -m pip install --upgrade --pre torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/nightly/cu${{ inputs.cu }} -r ../ComfyUI/requirements.txt pygit2
pause" > ./update/update_comfyui_and_python_dependencies.bat
cd ..
"C:\Program Files\7-Zip\7z.exe" a -t7z -m0=lzma -mx=8 -mfb=64 -md=32m -ms=on -mf=BCJ2 ComfyUI_windows_portable_nightly_pytorch.7z ComfyUI_windows_portable_nightly_pytorch

2
.github/workflows/windows_release_package.yml

@ -19,7 +19,7 @@ on:
description: 'python patch version'
required: true
type: string
default: "6"
default: "8"
# push:
# branches:
# - master

1
.gitignore vendored

@ -15,3 +15,4 @@ venv/
!/web/extensions/logging.js.example
!/web/extensions/core/
/tests-ui/data/object_info.json
/user/

9
.vscode/settings.json vendored

@ -1,9 +0,0 @@
{
"path-intellisense.mappings": {
"../": "${workspaceFolder}/web/extensions/core"
},
"[python]": {
"editor.defaultFormatter": "ms-python.autopep8"
},
"python.formatting.provider": "none"
}

17
README.md

@ -11,7 +11,7 @@ This ui will let you design and execute advanced stable diffusion pipelines usin
## Features
- Nodes/graph/flowchart interface to experiment and create complex Stable Diffusion workflows without needing to code anything.
- Fully supports SD1.x, SD2.x, [SDXL](https://comfyanonymous.github.io/ComfyUI_examples/sdxl/) and [Stable Video Diffusion](https://comfyanonymous.github.io/ComfyUI_examples/video/)
- Fully supports SD1.x, SD2.x, [SDXL](https://comfyanonymous.github.io/ComfyUI_examples/sdxl/), [Stable Video Diffusion](https://comfyanonymous.github.io/ComfyUI_examples/video/) and [Stable Cascade](https://comfyanonymous.github.io/ComfyUI_examples/stable_cascade/)
- Asynchronous Queue system
- Many optimizations: Only re-executes the parts of the workflow that changes between executions.
- Command line option: ```--lowvram``` to make it work on GPUs with less than 3GB vram (enabled automatically on GPUs with low vram)
@ -77,6 +77,8 @@ There is a portable standalone build for Windows that should work for running on
Simply download, extract with [7-Zip](https://7-zip.org) and run. Make sure you put your Stable Diffusion checkpoints/models (the huge ckpt/safetensors files) in: ComfyUI\models\checkpoints
If you have trouble extracting it, right click the file -> properties -> unblock
#### How do I share models between another UI and ComfyUI?
See the [Config file](extra_model_paths.yaml.example) to set the search paths for models. In the standalone windows build you can find this file in the ComfyUI directory. Rename this file to extra_model_paths.yaml and edit it with your favorite text editor.
@ -93,23 +95,26 @@ Put your SD checkpoints (the huge ckpt/safetensors files) in: models/checkpoints
Put your VAE in: models/vae
Note: pytorch does not support python 3.12 yet so make sure your python version is 3.11 or earlier.
### AMD GPUs (Linux only)
AMD users can install rocm and pytorch with pip if you don't have it already installed, this is the command to install the stable version:
```pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/rocm5.6```
```pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/rocm5.7```
This is the command to install the nightly with ROCm 5.7 that might have some performance improvements:
This is the command to install the nightly with ROCm 6.0 which might have some performance improvements:
```pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/rocm5.7```
```pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/rocm6.0```
### NVIDIA
Nvidia users should install pytorch using this command:
Nvidia users should install stable pytorch using this command:
```pip install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu121```
This is the command to install pytorch nightly instead which might have performance improvements:
```pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu121```
#### Troubleshooting
If you get the "Torch not compiled with CUDA enabled" error, uninstall torch with:

54
app/app_settings.py

@ -0,0 +1,54 @@
import os
import json
from aiohttp import web
class AppSettings():
def __init__(self, user_manager):
self.user_manager = user_manager
def get_settings(self, request):
file = self.user_manager.get_request_user_filepath(
request, "comfy.settings.json")
if os.path.isfile(file):
with open(file) as f:
return json.load(f)
else:
return {}
def save_settings(self, request, settings):
file = self.user_manager.get_request_user_filepath(
request, "comfy.settings.json")
with open(file, "w") as f:
f.write(json.dumps(settings, indent=4))
def add_routes(self, routes):
@routes.get("/settings")
async def get_settings(request):
return web.json_response(self.get_settings(request))
@routes.get("/settings/{id}")
async def get_setting(request):
value = None
settings = self.get_settings(request)
setting_id = request.match_info.get("id", None)
if setting_id and setting_id in settings:
value = settings[setting_id]
return web.json_response(value)
@routes.post("/settings")
async def post_settings(request):
settings = self.get_settings(request)
new_settings = await request.json()
self.save_settings(request, {**settings, **new_settings})
return web.Response(status=200)
@routes.post("/settings/{id}")
async def post_setting(request):
setting_id = request.match_info.get("id", None)
if not setting_id:
return web.Response(status=400)
settings = self.get_settings(request)
settings[setting_id] = await request.json()
self.save_settings(request, settings)
return web.Response(status=200)

140
app/user_manager.py

@ -0,0 +1,140 @@
import json
import os
import re
import uuid
from aiohttp import web
from comfy.cli_args import args
from folder_paths import user_directory
from .app_settings import AppSettings
default_user = "default"
users_file = os.path.join(user_directory, "users.json")
class UserManager():
def __init__(self):
global user_directory
self.settings = AppSettings(self)
if not os.path.exists(user_directory):
os.mkdir(user_directory)
if not args.multi_user:
print("****** User settings have been changed to be stored on the server instead of browser storage. ******")
print("****** For multi-user setups add the --multi-user CLI argument to enable multiple user profiles. ******")
if args.multi_user:
if os.path.isfile(users_file):
with open(users_file) as f:
self.users = json.load(f)
else:
self.users = {}
else:
self.users = {"default": "default"}
def get_request_user_id(self, request):
user = "default"
if args.multi_user and "comfy-user" in request.headers:
user = request.headers["comfy-user"]
if user not in self.users:
raise KeyError("Unknown user: " + user)
return user
def get_request_user_filepath(self, request, file, type="userdata", create_dir=True):
global user_directory
if type == "userdata":
root_dir = user_directory
else:
raise KeyError("Unknown filepath type:" + type)
user = self.get_request_user_id(request)
path = user_root = os.path.abspath(os.path.join(root_dir, user))
# prevent leaving /{type}
if os.path.commonpath((root_dir, user_root)) != root_dir:
return None
parent = user_root
if file is not None:
# prevent leaving /{type}/{user}
path = os.path.abspath(os.path.join(user_root, file))
if os.path.commonpath((user_root, path)) != user_root:
return None
if create_dir and not os.path.exists(parent):
os.mkdir(parent)
return path
def add_user(self, name):
name = name.strip()
if not name:
raise ValueError("username not provided")
user_id = re.sub("[^a-zA-Z0-9-_]+", '-', name)
user_id = user_id + "_" + str(uuid.uuid4())
self.users[user_id] = name
global users_file
with open(users_file, "w") as f:
json.dump(self.users, f)
return user_id
def add_routes(self, routes):
self.settings.add_routes(routes)
@routes.get("/users")
async def get_users(request):
if args.multi_user:
return web.json_response({"storage": "server", "users": self.users})
else:
user_dir = self.get_request_user_filepath(request, None, create_dir=False)
return web.json_response({
"storage": "server",
"migrated": os.path.exists(user_dir)
})
@routes.post("/users")
async def post_users(request):
body = await request.json()
username = body["username"]
if username in self.users.values():
return web.json_response({"error": "Duplicate username."}, status=400)
user_id = self.add_user(username)
return web.json_response(user_id)
@routes.get("/userdata/{file}")
async def getuserdata(request):
file = request.match_info.get("file", None)
if not file:
return web.Response(status=400)
path = self.get_request_user_filepath(request, file)
if not path:
return web.Response(status=403)
if not os.path.exists(path):
return web.Response(status=404)
return web.FileResponse(path)
@routes.post("/userdata/{file}")
async def post_userdata(request):
file = request.match_info.get("file", None)
if not file:
return web.Response(status=400)
path = self.get_request_user_filepath(request, file)
if not path:
return web.Response(status=403)
body = await request.read()
with open(path, "wb") as f:
f.write(body)
return web.Response(status=200)

34
comfy/cldm/cldm.py

@ -53,7 +53,7 @@ class ControlNet(nn.Module):
transformer_depth_middle=None,
transformer_depth_output=None,
device=None,
operations=comfy.ops,
operations=comfy.ops.disable_weight_init,
**kwargs,
):
super().__init__()
@ -141,24 +141,24 @@ class ControlNet(nn.Module):
)
]
)
self.zero_convs = nn.ModuleList([self.make_zero_conv(model_channels, operations=operations)])
self.zero_convs = nn.ModuleList([self.make_zero_conv(model_channels, operations=operations, dtype=self.dtype, device=device)])
self.input_hint_block = TimestepEmbedSequential(
operations.conv_nd(dims, hint_channels, 16, 3, padding=1),
operations.conv_nd(dims, hint_channels, 16, 3, padding=1, dtype=self.dtype, device=device),
nn.SiLU(),
operations.conv_nd(dims, 16, 16, 3, padding=1),
operations.conv_nd(dims, 16, 16, 3, padding=1, dtype=self.dtype, device=device),
nn.SiLU(),
operations.conv_nd(dims, 16, 32, 3, padding=1, stride=2),
operations.conv_nd(dims, 16, 32, 3, padding=1, stride=2, dtype=self.dtype, device=device),
nn.SiLU(),
operations.conv_nd(dims, 32, 32, 3, padding=1),
operations.conv_nd(dims, 32, 32, 3, padding=1, dtype=self.dtype, device=device),
nn.SiLU(),
operations.conv_nd(dims, 32, 96, 3, padding=1, stride=2),
operations.conv_nd(dims, 32, 96, 3, padding=1, stride=2, dtype=self.dtype, device=device),
nn.SiLU(),
operations.conv_nd(dims, 96, 96, 3, padding=1),
operations.conv_nd(dims, 96, 96, 3, padding=1, dtype=self.dtype, device=device),
nn.SiLU(),
operations.conv_nd(dims, 96, 256, 3, padding=1, stride=2),
operations.conv_nd(dims, 96, 256, 3, padding=1, stride=2, dtype=self.dtype, device=device),
nn.SiLU(),
zero_module(operations.conv_nd(dims, 256, model_channels, 3, padding=1))
operations.conv_nd(dims, 256, model_channels, 3, padding=1, dtype=self.dtype, device=device)
)
self._feature_size = model_channels
@ -206,7 +206,7 @@ class ControlNet(nn.Module):
)
)
self.input_blocks.append(TimestepEmbedSequential(*layers))
self.zero_convs.append(self.make_zero_conv(ch, operations=operations))
self.zero_convs.append(self.make_zero_conv(ch, operations=operations, dtype=self.dtype, device=device))
self._feature_size += ch
input_block_chans.append(ch)
if level != len(channel_mult) - 1:
@ -234,7 +234,7 @@ class ControlNet(nn.Module):
)
ch = out_ch
input_block_chans.append(ch)
self.zero_convs.append(self.make_zero_conv(ch, operations=operations))
self.zero_convs.append(self.make_zero_conv(ch, operations=operations, dtype=self.dtype, device=device))
ds *= 2
self._feature_size += ch
@ -276,14 +276,14 @@ class ControlNet(nn.Module):
operations=operations
)]
self.middle_block = TimestepEmbedSequential(*mid_block)
self.middle_block_out = self.make_zero_conv(ch, operations=operations)
self.middle_block_out = self.make_zero_conv(ch, operations=operations, dtype=self.dtype, device=device)
self._feature_size += ch
def make_zero_conv(self, channels, operations=None):
return TimestepEmbedSequential(zero_module(operations.conv_nd(self.dims, channels, channels, 1, padding=0)))
def make_zero_conv(self, channels, operations=None, dtype=None, device=None):
return TimestepEmbedSequential(operations.conv_nd(self.dims, channels, channels, 1, padding=0, dtype=dtype, device=device))
def forward(self, x, hint, timesteps, context, y=None, **kwargs):
t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False).to(self.dtype)
t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False).to(x.dtype)
emb = self.time_embed(t_emb)
guided_hint = self.input_hint_block(hint, emb, context)
@ -295,7 +295,7 @@ class ControlNet(nn.Module):
assert y.shape[0] == x.shape[0]
emb = emb + self.label_emb(y)
h = x.type(self.dtype)
h = x
for module, zero_conv in zip(self.input_blocks, self.zero_convs):
if guided_hint is not None:
h = module(h, emb, context)

22
comfy/cli_args.py

@ -55,13 +55,19 @@ fp_group = parser.add_mutually_exclusive_group()
fp_group.add_argument("--force-fp32", action="store_true", help="Force fp32 (If this makes your GPU work better please report it).")
fp_group.add_argument("--force-fp16", action="store_true", help="Force fp16.")
parser.add_argument("--bf16-unet", action="store_true", help="Run the UNET in bf16. This should only be used for testing stuff.")
fpunet_group = parser.add_mutually_exclusive_group()
fpunet_group.add_argument("--bf16-unet", action="store_true", help="Run the UNET in bf16. This should only be used for testing stuff.")
fpunet_group.add_argument("--fp16-unet", action="store_true", help="Store unet weights in fp16.")
fpunet_group.add_argument("--fp8_e4m3fn-unet", action="store_true", help="Store unet weights in fp8_e4m3fn.")
fpunet_group.add_argument("--fp8_e5m2-unet", action="store_true", help="Store unet weights in fp8_e5m2.")
fpvae_group = parser.add_mutually_exclusive_group()
fpvae_group.add_argument("--fp16-vae", action="store_true", help="Run the VAE in fp16, might cause black images.")
fpvae_group.add_argument("--fp32-vae", action="store_true", help="Run the VAE in full precision fp32.")
fpvae_group.add_argument("--bf16-vae", action="store_true", help="Run the VAE in bf16.")
parser.add_argument("--cpu-vae", action="store_true", help="Run the VAE on the CPU.")
fpte_group = parser.add_mutually_exclusive_group()
fpte_group.add_argument("--fp8_e4m3fn-text-enc", action="store_true", help="Store text encoder weights in fp8 (e4m3fn variant).")
fpte_group.add_argument("--fp8_e5m2-text-enc", action="store_true", help="Store text encoder weights in fp8 (e5m2 variant).")
@ -98,7 +104,7 @@ vram_group.add_argument("--cpu", action="store_true", help="To use the CPU for e
parser.add_argument("--disable-smart-memory", action="store_true", help="Force ComfyUI to agressively offload to regular ram instead of keeping models in vram when it can.")
parser.add_argument("--deterministic", action="store_true", help="Make pytorch use slower deterministic algorithms when it can. Note that this might not make images deterministic in all cases.")
parser.add_argument("--dont-print-server", action="store_true", help="Don't print server output.")
parser.add_argument("--quick-test-for-ci", action="store_true", help="Quick test for CI.")
@ -106,6 +112,11 @@ parser.add_argument("--windows-standalone-build", action="store_true", help="Win
parser.add_argument("--disable-metadata", action="store_true", help="Disable saving prompt metadata in files.")
parser.add_argument("--multi-user", action="store_true", help="Enables per-user storage.")
parser.add_argument("--verbose", action="store_true", help="Enables more debug prints.")
if comfy.options.args_parsing:
args = parser.parse_args()
else:
@ -116,3 +127,10 @@ if args.windows_standalone_build:
if args.disable_auto_launch:
args.auto_launch = False
import logging
logging_level = logging.INFO
if args.verbose:
logging_level = logging.DEBUG
logging.basicConfig(format="%(message)s", level=logging_level)

194
comfy/clip_model.py

@ -0,0 +1,194 @@
import torch
from comfy.ldm.modules.attention import optimized_attention_for_device
class CLIPAttention(torch.nn.Module):
def __init__(self, embed_dim, heads, dtype, device, operations):
super().__init__()
self.heads = heads
self.q_proj = operations.Linear(embed_dim, embed_dim, bias=True, dtype=dtype, device=device)
self.k_proj = operations.Linear(embed_dim, embed_dim, bias=True, dtype=dtype, device=device)
self.v_proj = operations.Linear(embed_dim, embed_dim, bias=True, dtype=dtype, device=device)
self.out_proj = operations.Linear(embed_dim, embed_dim, bias=True, dtype=dtype, device=device)
def forward(self, x, mask=None, optimized_attention=None):
q = self.q_proj(x)
k = self.k_proj(x)
v = self.v_proj(x)
out = optimized_attention(q, k, v, self.heads, mask)
return self.out_proj(out)
ACTIVATIONS = {"quick_gelu": lambda a: a * torch.sigmoid(1.702 * a),
"gelu": torch.nn.functional.gelu,
}
class CLIPMLP(torch.nn.Module):
def __init__(self, embed_dim, intermediate_size, activation, dtype, device, operations):
super().__init__()
self.fc1 = operations.Linear(embed_dim, intermediate_size, bias=True, dtype=dtype, device=device)
self.activation = ACTIVATIONS[activation]
self.fc2 = operations.Linear(intermediate_size, embed_dim, bias=True, dtype=dtype, device=device)
def forward(self, x):
x = self.fc1(x)
x = self.activation(x)
x = self.fc2(x)
return x
class CLIPLayer(torch.nn.Module):
def __init__(self, embed_dim, heads, intermediate_size, intermediate_activation, dtype, device, operations):
super().__init__()
self.layer_norm1 = operations.LayerNorm(embed_dim, dtype=dtype, device=device)
self.self_attn = CLIPAttention(embed_dim, heads, dtype, device, operations)
self.layer_norm2 = operations.LayerNorm(embed_dim, dtype=dtype, device=device)
self.mlp = CLIPMLP(embed_dim, intermediate_size, intermediate_activation, dtype, device, operations)
def forward(self, x, mask=None, optimized_attention=None):
x += self.self_attn(self.layer_norm1(x), mask, optimized_attention)
x += self.mlp(self.layer_norm2(x))
return x
class CLIPEncoder(torch.nn.Module):
def __init__(self, num_layers, embed_dim, heads, intermediate_size, intermediate_activation, dtype, device, operations):
super().__init__()
self.layers = torch.nn.ModuleList([CLIPLayer(embed_dim, heads, intermediate_size, intermediate_activation, dtype, device, operations) for i in range(num_layers)])
def forward(self, x, mask=None, intermediate_output=None):
optimized_attention = optimized_attention_for_device(x.device, mask=mask is not None, small_input=True)
if intermediate_output is not None:
if intermediate_output < 0:
intermediate_output = len(self.layers) + intermediate_output
intermediate = None
for i, l in enumerate(self.layers):
x = l(x, mask, optimized_attention)
if i == intermediate_output:
intermediate = x.clone()
return x, intermediate
class CLIPEmbeddings(torch.nn.Module):
def __init__(self, embed_dim, vocab_size=49408, num_positions=77, dtype=None, device=None):
super().__init__()
self.token_embedding = torch.nn.Embedding(vocab_size, embed_dim, dtype=dtype, device=device)
self.position_embedding = torch.nn.Embedding(num_positions, embed_dim, dtype=dtype, device=device)
def forward(self, input_tokens):
return self.token_embedding(input_tokens) + self.position_embedding.weight
class CLIPTextModel_(torch.nn.Module):
def __init__(self, config_dict, dtype, device, operations):
num_layers = config_dict["num_hidden_layers"]
embed_dim = config_dict["hidden_size"]
heads = config_dict["num_attention_heads"]
intermediate_size = config_dict["intermediate_size"]
intermediate_activation = config_dict["hidden_act"]
super().__init__()
self.embeddings = CLIPEmbeddings(embed_dim, dtype=torch.float32, device=device)
self.encoder = CLIPEncoder(num_layers, embed_dim, heads, intermediate_size, intermediate_activation, dtype, device, operations)
self.final_layer_norm = operations.LayerNorm(embed_dim, dtype=dtype, device=device)
def forward(self, input_tokens, attention_mask=None, intermediate_output=None, final_layer_norm_intermediate=True):
x = self.embeddings(input_tokens)
mask = None
if attention_mask is not None:
mask = 1.0 - attention_mask.to(x.dtype).reshape((attention_mask.shape[0], 1, -1, attention_mask.shape[-1])).expand(attention_mask.shape[0], 1, attention_mask.shape[-1], attention_mask.shape[-1])
mask = mask.masked_fill(mask.to(torch.bool), float("-inf"))
causal_mask = torch.empty(x.shape[1], x.shape[1], dtype=x.dtype, device=x.device).fill_(float("-inf")).triu_(1)
if mask is not None:
mask += causal_mask
else:
mask = causal_mask
x, i = self.encoder(x, mask=mask, intermediate_output=intermediate_output)
x = self.final_layer_norm(x)
if i is not None and final_layer_norm_intermediate:
i = self.final_layer_norm(i)
pooled_output = x[torch.arange(x.shape[0], device=x.device), input_tokens.to(dtype=torch.int, device=x.device).argmax(dim=-1),]
return x, i, pooled_output
class CLIPTextModel(torch.nn.Module):
def __init__(self, config_dict, dtype, device, operations):
super().__init__()
self.num_layers = config_dict["num_hidden_layers"]
self.text_model = CLIPTextModel_(config_dict, dtype, device, operations)
embed_dim = config_dict["hidden_size"]
self.text_projection = operations.Linear(embed_dim, embed_dim, bias=False, dtype=dtype, device=device)
self.text_projection.weight.copy_(torch.eye(embed_dim))
self.dtype = dtype
def get_input_embeddings(self):
return self.text_model.embeddings.token_embedding
def set_input_embeddings(self, embeddings):
self.text_model.embeddings.token_embedding = embeddings
def forward(self, *args, **kwargs):
x = self.text_model(*args, **kwargs)
out = self.text_projection(x[2])
return (x[0], x[1], out, x[2])
class CLIPVisionEmbeddings(torch.nn.Module):
def __init__(self, embed_dim, num_channels=3, patch_size=14, image_size=224, dtype=None, device=None, operations=None):
super().__init__()
self.class_embedding = torch.nn.Parameter(torch.empty(embed_dim, dtype=dtype, device=device))
self.patch_embedding = operations.Conv2d(
in_channels=num_channels,
out_channels=embed_dim,
kernel_size=patch_size,
stride=patch_size,
bias=False,
dtype=dtype,
device=device
)
num_patches = (image_size // patch_size) ** 2
num_positions = num_patches + 1
self.position_embedding = torch.nn.Embedding(num_positions, embed_dim, dtype=dtype, device=device)
def forward(self, pixel_values):
embeds = self.patch_embedding(pixel_values).flatten(2).transpose(1, 2)
return torch.cat([self.class_embedding.to(embeds.device).expand(pixel_values.shape[0], 1, -1), embeds], dim=1) + self.position_embedding.weight.to(embeds.device)
class CLIPVision(torch.nn.Module):
def __init__(self, config_dict, dtype, device, operations):
super().__init__()
num_layers = config_dict["num_hidden_layers"]
embed_dim = config_dict["hidden_size"]
heads = config_dict["num_attention_heads"]
intermediate_size = config_dict["intermediate_size"]
intermediate_activation = config_dict["hidden_act"]
self.embeddings = CLIPVisionEmbeddings(embed_dim, config_dict["num_channels"], config_dict["patch_size"], config_dict["image_size"], dtype=torch.float32, device=device, operations=operations)
self.pre_layrnorm = operations.LayerNorm(embed_dim)
self.encoder = CLIPEncoder(num_layers, embed_dim, heads, intermediate_size, intermediate_activation, dtype, device, operations)
self.post_layernorm = operations.LayerNorm(embed_dim)
def forward(self, pixel_values, attention_mask=None, intermediate_output=None):
x = self.embeddings(pixel_values)
x = self.pre_layrnorm(x)
#TODO: attention_mask?
x, i = self.encoder(x, mask=None, intermediate_output=intermediate_output)
pooled_output = self.post_layernorm(x[:, 0, :])
return x, i, pooled_output
class CLIPVisionModelProjection(torch.nn.Module):
def __init__(self, config_dict, dtype, device, operations):
super().__init__()
self.vision_model = CLIPVision(config_dict, dtype, device, operations)
self.visual_projection = operations.Linear(config_dict["hidden_size"], config_dict["projection_dim"], bias=False)
def forward(self, *args, **kwargs):
x = self.vision_model(*args, **kwargs)
out = self.visual_projection(x[2])
return (x[0], x[1], out)

73
comfy/clip_vision.py

@ -1,64 +1,62 @@
from transformers import CLIPVisionModelWithProjection, CLIPVisionConfig, modeling_utils
from .utils import load_torch_file, transformers_convert, common_upscale
from .utils import load_torch_file, transformers_convert, state_dict_prefix_replace
import os
import torch
import contextlib
import json
import logging
import comfy.ops
import comfy.model_patcher
import comfy.model_management
import comfy.utils
import comfy.clip_model
class Output:
def __getitem__(self, key):
return getattr(self, key)
def __setitem__(self, key, item):
setattr(self, key, item)
def clip_preprocess(image, size=224):
mean = torch.tensor([ 0.48145466,0.4578275,0.40821073], device=image.device, dtype=image.dtype)
std = torch.tensor([0.26862954,0.26130258,0.27577711], device=image.device, dtype=image.dtype)
scale = (size / min(image.shape[1], image.shape[2]))
image = torch.nn.functional.interpolate(image.movedim(-1, 1), size=(round(scale * image.shape[1]), round(scale * image.shape[2])), mode="bicubic", antialias=True)
h = (image.shape[2] - size)//2
w = (image.shape[3] - size)//2
image = image[:,:,h:h+size,w:w+size]
image = image.movedim(-1, 1)
if not (image.shape[2] == size and image.shape[3] == size):
scale = (size / min(image.shape[2], image.shape[3]))
image = torch.nn.functional.interpolate(image, size=(round(scale * image.shape[2]), round(scale * image.shape[3])), mode="bicubic", antialias=True)
h = (image.shape[2] - size)//2
w = (image.shape[3] - size)//2
image = image[:,:,h:h+size,w:w+size]
image = torch.clip((255. * image), 0, 255).round() / 255.0
return (image - mean.view([3,1,1])) / std.view([3,1,1])
class ClipVisionModel():
def __init__(self, json_config):
config = CLIPVisionConfig.from_json_file(json_config)
with open(json_config) as f:
config = json.load(f)
self.load_device = comfy.model_management.text_encoder_device()
offload_device = comfy.model_management.text_encoder_offload_device()
self.dtype = torch.float32
if comfy.model_management.should_use_fp16(self.load_device, prioritize_performance=False):
self.dtype = torch.float16
with comfy.ops.use_comfy_ops(offload_device, self.dtype):
with modeling_utils.no_init_weights():
self.model = CLIPVisionModelWithProjection(config)
self.model.to(self.dtype)
self.dtype = comfy.model_management.text_encoder_dtype(self.load_device)
self.model = comfy.clip_model.CLIPVisionModelProjection(config, self.dtype, offload_device, comfy.ops.manual_cast)
self.model.eval()
self.patcher = comfy.model_patcher.ModelPatcher(self.model, load_device=self.load_device, offload_device=offload_device)
def load_sd(self, sd):
return self.model.load_state_dict(sd, strict=False)
def get_sd(self):
return self.model.state_dict()
def encode_image(self, image):
comfy.model_management.load_model_gpu(self.patcher)
pixel_values = clip_preprocess(image.to(self.load_device))
if self.dtype != torch.float32:
precision_scope = torch.autocast
else:
precision_scope = lambda a, b: contextlib.nullcontext(a)
with precision_scope(comfy.model_management.get_autocast_device(self.load_device), torch.float32):
outputs = self.model(pixel_values=pixel_values, output_hidden_states=True)
for k in outputs:
t = outputs[k]
if t is not None:
if k == 'hidden_states':
outputs["penultimate_hidden_states"] = t[-2].cpu()
outputs["hidden_states"] = None
else:
outputs[k] = t.cpu()
pixel_values = clip_preprocess(image.to(self.load_device)).float()
out = self.model(pixel_values=pixel_values, intermediate_output=-2)
outputs = Output()
outputs["last_hidden_state"] = out[0].to(comfy.model_management.intermediate_device())
outputs["image_embeds"] = out[2].to(comfy.model_management.intermediate_device())
outputs["penultimate_hidden_states"] = out[1].to(comfy.model_management.intermediate_device())
return outputs
def convert_to_transformers(sd, prefix):
@ -82,6 +80,9 @@ def convert_to_transformers(sd, prefix):
sd['visual_projection.weight'] = sd.pop("{}proj".format(prefix)).transpose(0, 1)
sd = transformers_convert(sd, prefix, "vision_model.", 48)
else:
replace_prefix = {prefix: ""}
sd = state_dict_prefix_replace(sd, replace_prefix)
return sd
def load_clipvision_from_sd(sd, prefix="", convert_keys=False):
@ -99,7 +100,7 @@ def load_clipvision_from_sd(sd, prefix="", convert_keys=False):
clip = ClipVisionModel(json_config)
m, u = clip.load_sd(sd)
if len(m) > 0:
print("missing clip vision:", m)
logging.warning("missing clip vision: {}".format(m))
u = set(u)
keys = list(sd.keys())
for k in keys:

1
comfy/conds.py

@ -1,4 +1,3 @@
import enum
import torch
import math
import comfy.utils

143
comfy/controlnet.py

@ -1,13 +1,16 @@
import torch
import math
import os
import logging
import comfy.utils
import comfy.model_management
import comfy.model_detection
import comfy.model_patcher
import comfy.ops
import comfy.cldm.cldm
import comfy.t2i_adapter.adapter
import comfy.ldm.cascade.controlnet
def broadcast_image_to(tensor, target_batch_size, batched_number):
@ -34,13 +37,15 @@ class ControlBase:
self.cond_hint = None
self.strength = 1.0
self.timestep_percent_range = (0.0, 1.0)
self.global_average_pooling = False
self.timestep_range = None
self.compression_ratio = 8
self.upscale_algorithm = 'nearest-exact'
if device is None:
device = comfy.model_management.get_torch_device()
self.device = device
self.previous_controlnet = None
self.global_average_pooling = False
def set_cond_hint(self, cond_hint, strength=1.0, timestep_percent_range=(0.0, 1.0)):
self.cond_hint_original = cond_hint
@ -75,6 +80,9 @@ class ControlBase:
c.cond_hint_original = self.cond_hint_original
c.strength = self.strength
c.timestep_percent_range = self.timestep_percent_range
c.global_average_pooling = self.global_average_pooling
c.compression_ratio = self.compression_ratio
c.upscale_algorithm = self.upscale_algorithm
def inference_memory_requirements(self, dtype):
if self.previous_controlnet is not None:
@ -123,16 +131,21 @@ class ControlBase:
if o[i] is None:
o[i] = prev_val
else:
o[i] += prev_val
if o[i].shape[0] < prev_val.shape[0]:
o[i] = prev_val + o[i]
else:
o[i] += prev_val
return out
class ControlNet(ControlBase):
def __init__(self, control_model, global_average_pooling=False, device=None):
def __init__(self, control_model, global_average_pooling=False, device=None, load_device=None, manual_cast_dtype=None):
super().__init__(device)
self.control_model = control_model
self.control_model_wrapped = comfy.model_patcher.ModelPatcher(self.control_model, load_device=comfy.model_management.get_torch_device(), offload_device=comfy.model_management.unet_offload_device())
self.load_device = load_device
self.control_model_wrapped = comfy.model_patcher.ModelPatcher(self.control_model, load_device=load_device, offload_device=comfy.model_management.unet_offload_device())
self.global_average_pooling = global_average_pooling
self.model_sampling_current = None
self.manual_cast_dtype = manual_cast_dtype
def get_control(self, x_noisy, t, cond, batched_number):
control_prev = None
@ -146,28 +159,31 @@ class ControlNet(ControlBase):
else:
return None
dtype = self.control_model.dtype
if self.manual_cast_dtype is not None:
dtype = self.manual_cast_dtype
output_dtype = x_noisy.dtype
if self.cond_hint is None or x_noisy.shape[2] * 8 != self.cond_hint.shape[2] or x_noisy.shape[3] * 8 != self.cond_hint.shape[3]:
if self.cond_hint is None or x_noisy.shape[2] * self.compression_ratio != self.cond_hint.shape[2] or x_noisy.shape[3] * self.compression_ratio != self.cond_hint.shape[3]:
if self.cond_hint is not None:
del self.cond_hint
self.cond_hint = None
self.cond_hint = comfy.utils.common_upscale(self.cond_hint_original, x_noisy.shape[3] * 8, x_noisy.shape[2] * 8, 'nearest-exact', "center").to(self.control_model.dtype).to(self.device)
self.cond_hint = comfy.utils.common_upscale(self.cond_hint_original, x_noisy.shape[3] * self.compression_ratio, x_noisy.shape[2] * self.compression_ratio, self.upscale_algorithm, "center").to(dtype).to(self.device)
if x_noisy.shape[0] != self.cond_hint.shape[0]:
self.cond_hint = broadcast_image_to(self.cond_hint, x_noisy.shape[0], batched_number)
context = cond['c_crossattn']
context = cond.get('crossattn_controlnet', cond['c_crossattn'])
y = cond.get('y', None)
if y is not None:
y = y.to(self.control_model.dtype)
y = y.to(dtype)
timestep = self.model_sampling_current.timestep(t)
x_noisy = self.model_sampling_current.calculate_input(t, x_noisy)
control = self.control_model(x=x_noisy.to(self.control_model.dtype), hint=self.cond_hint, timesteps=timestep.float(), context=context.to(self.control_model.dtype), y=y)
control = self.control_model(x=x_noisy.to(dtype), hint=self.cond_hint, timesteps=timestep.float(), context=context.to(dtype), y=y)
return self.control_merge(None, control, control_prev, output_dtype)
def copy(self):
c = ControlNet(self.control_model, global_average_pooling=self.global_average_pooling)
c = ControlNet(self.control_model, global_average_pooling=self.global_average_pooling, load_device=self.load_device, manual_cast_dtype=self.manual_cast_dtype)
self.copy_to(c)
return c
@ -185,7 +201,7 @@ class ControlNet(ControlBase):
super().cleanup()
class ControlLoraOps:
class Linear(torch.nn.Module):
class Linear(torch.nn.Module, comfy.ops.CastWeightBiasOp):
def __init__(self, in_features: int, out_features: int, bias: bool = True,
device=None, dtype=None) -> None:
factory_kwargs = {'device': device, 'dtype': dtype}
@ -198,12 +214,13 @@ class ControlLoraOps:
self.bias = None
def forward(self, input):
weight, bias = comfy.ops.cast_bias_weight(self, input)
if self.up is not None:
return torch.nn.functional.linear(input, self.weight.to(input.device) + (torch.mm(self.up.flatten(start_dim=1), self.down.flatten(start_dim=1))).reshape(self.weight.shape).type(input.dtype), self.bias)
return torch.nn.functional.linear(input, weight + (torch.mm(self.up.flatten(start_dim=1), self.down.flatten(start_dim=1))).reshape(self.weight.shape).type(input.dtype), bias)
else:
return torch.nn.functional.linear(input, self.weight.to(input.device), self.bias)
return torch.nn.functional.linear(input, weight, bias)
class Conv2d(torch.nn.Module):
class Conv2d(torch.nn.Module, comfy.ops.CastWeightBiasOp):
def __init__(
self,
in_channels,
@ -237,16 +254,11 @@ class ControlLoraOps:
def forward(self, input):
weight, bias = comfy.ops.cast_bias_weight(self, input)
if self.up is not None:
return torch.nn.functional.conv2d(input, self.weight.to(input.device) + (torch.mm(self.up.flatten(start_dim=1), self.down.flatten(start_dim=1))).reshape(self.weight.shape).type(input.dtype), self.bias, self.stride, self.padding, self.dilation, self.groups)
return torch.nn.functional.conv2d(input, weight + (torch.mm(self.up.flatten(start_dim=1), self.down.flatten(start_dim=1))).reshape(self.weight.shape).type(input.dtype), bias, self.stride, self.padding, self.dilation, self.groups)
else:
return torch.nn.functional.conv2d(input, self.weight.to(input.device), self.bias, self.stride, self.padding, self.dilation, self.groups)
def conv_nd(self, dims, *args, **kwargs):
if dims == 2:
return self.Conv2d(*args, **kwargs)
else:
raise ValueError(f"unsupported dimensions: {dims}")
return torch.nn.functional.conv2d(input, weight, bias, self.stride, self.padding, self.dilation, self.groups)
class ControlLora(ControlNet):
@ -260,25 +272,34 @@ class ControlLora(ControlNet):
controlnet_config = model.model_config.unet_config.copy()
controlnet_config.pop("out_channels")
controlnet_config["hint_channels"] = self.control_weights["input_hint_block.0.weight"].shape[1]
controlnet_config["operations"] = ControlLoraOps()
self.control_model = comfy.cldm.cldm.ControlNet(**controlnet_config)
self.manual_cast_dtype = model.manual_cast_dtype
dtype = model.get_dtype()
self.control_model.to(dtype)
if self.manual_cast_dtype is None:
class control_lora_ops(ControlLoraOps, comfy.ops.disable_weight_init):
pass
else:
class control_lora_ops(ControlLoraOps, comfy.ops.manual_cast):
pass
dtype = self.manual_cast_dtype
controlnet_config["operations"] = control_lora_ops
controlnet_config["dtype"] = dtype
self.control_model = comfy.cldm.cldm.ControlNet(**controlnet_config)
self.control_model.to(comfy.model_management.get_torch_device())
diffusion_model = model.diffusion_model
sd = diffusion_model.state_dict()
cm = self.control_model.state_dict()
for k in sd:
weight = comfy.model_management.resolve_lowvram_weight(sd[k], diffusion_model, k)
weight = sd[k]
try:
comfy.utils.set_attr(self.control_model, k, weight)
comfy.utils.set_attr_param(self.control_model, k, weight)
except:
pass
for k in self.control_weights:
if k not in {"lora_controlnet"}:
comfy.utils.set_attr(self.control_model, k, self.control_weights[k].to(dtype).to(comfy.model_management.get_torch_device()))
comfy.utils.set_attr_param(self.control_model, k, self.control_weights[k].to(dtype).to(comfy.model_management.get_torch_device()))
def copy(self):
c = ControlLora(self.control_weights, global_average_pooling=self.global_average_pooling)
@ -303,9 +324,10 @@ def load_controlnet(ckpt_path, model=None):
return ControlLora(controlnet_data)
controlnet_config = None
supported_inference_dtypes = None
if "controlnet_cond_embedding.conv_in.weight" in controlnet_data: #diffusers format
unet_dtype = comfy.model_management.unet_dtype()
controlnet_config = comfy.model_detection.unet_config_from_diffusers_unet(controlnet_data, unet_dtype)
controlnet_config = comfy.model_detection.unet_config_from_diffusers_unet(controlnet_data)
diffusers_keys = comfy.utils.unet_to_diffusers(controlnet_config)
diffusers_keys["controlnet_mid_block.weight"] = "middle_block_out.0.weight"
diffusers_keys["controlnet_mid_block.bias"] = "middle_block_out.0.bias"
@ -346,7 +368,7 @@ def load_controlnet(ckpt_path, model=None):
leftover_keys = controlnet_data.keys()
if len(leftover_keys) > 0:
print("leftover keys:", leftover_keys)
logging.warning("leftover keys: {}".format(leftover_keys))
controlnet_data = new_sd
pth_key = 'control_model.zero_convs.0.0.weight'
@ -361,12 +383,24 @@ def load_controlnet(ckpt_path, model=None):
else:
net = load_t2i_adapter(controlnet_data)
if net is None:
print("error checkpoint does not contain controlnet or t2i adapter data", ckpt_path)
logging.error("error checkpoint does not contain controlnet or t2i adapter data {}".format(ckpt_path))
return net
if controlnet_config is None:
model_config = comfy.model_detection.model_config_from_unet(controlnet_data, prefix, True)
supported_inference_dtypes = model_config.supported_inference_dtypes
controlnet_config = model_config.unet_config
load_device = comfy.model_management.get_torch_device()
if supported_inference_dtypes is None:
unet_dtype = comfy.model_management.unet_dtype()
controlnet_config = comfy.model_detection.model_config_from_unet(controlnet_data, prefix, unet_dtype, True).unet_config
else:
unet_dtype = comfy.model_management.unet_dtype(supported_dtypes=supported_inference_dtypes)
manual_cast_dtype = comfy.model_management.unet_manual_cast(unet_dtype, load_device)
if manual_cast_dtype is not None:
controlnet_config["operations"] = comfy.ops.manual_cast
controlnet_config["dtype"] = unet_dtype
controlnet_config.pop("out_channels")
controlnet_config["hint_channels"] = controlnet_data["{}input_hint_block.0.weight".format(prefix)].shape[1]
control_model = comfy.cldm.cldm.ControlNet(**controlnet_config)
@ -384,7 +418,7 @@ def load_controlnet(ckpt_path, model=None):
cd = controlnet_data[x]
cd += model_sd[sd_key].type(cd.dtype).to(cd.device)
else:
print("WARNING: Loaded a diff controlnet without a model. It will very likely not work.")
logging.warning("WARNING: Loaded a diff controlnet without a model. It will very likely not work.")
class WeightsLoader(torch.nn.Module):
pass
@ -393,24 +427,29 @@ def load_controlnet(ckpt_path, model=None):
missing, unexpected = w.load_state_dict(controlnet_data, strict=False)
else:
missing, unexpected = control_model.load_state_dict(controlnet_data, strict=False)
print(missing, unexpected)
control_model = control_model.to(unet_dtype)
if len(missing) > 0:
logging.warning("missing controlnet keys: {}".format(missing))
if len(unexpected) > 0:
logging.debug("unexpected controlnet keys: {}".format(unexpected))
global_average_pooling = False
filename = os.path.splitext(ckpt_path)[0]
if filename.endswith("_shuffle") or filename.endswith("_shuffle_fp16"): #TODO: smarter way of enabling global_average_pooling
global_average_pooling = True
control = ControlNet(control_model, global_average_pooling=global_average_pooling)
control = ControlNet(control_model, global_average_pooling=global_average_pooling, load_device=load_device, manual_cast_dtype=manual_cast_dtype)
return control
class T2IAdapter(ControlBase):
def __init__(self, t2i_model, channels_in, device=None):
def __init__(self, t2i_model, channels_in, compression_ratio, upscale_algorithm, device=None):
super().__init__(device)
self.t2i_model = t2i_model
self.channels_in = channels_in
self.control_input = None
self.compression_ratio = compression_ratio
self.upscale_algorithm = upscale_algorithm
def scale_image_to(self, width, height):
unshuffle_amount = self.t2i_model.unshuffle_amount
@ -430,13 +469,13 @@ class T2IAdapter(ControlBase):
else:
return None
if self.cond_hint is None or x_noisy.shape[2] * 8 != self.cond_hint.shape[2] or x_noisy.shape[3] * 8 != self.cond_hint.shape[3]:
if self.cond_hint is None or x_noisy.shape[2] * self.compression_ratio != self.cond_hint.shape[2] or x_noisy.shape[3] * self.compression_ratio != self.cond_hint.shape[3]:
if self.cond_hint is not None:
del self.cond_hint
self.control_input = None
self.cond_hint = None
width, height = self.scale_image_to(x_noisy.shape[3] * 8, x_noisy.shape[2] * 8)
self.cond_hint = comfy.utils.common_upscale(self.cond_hint_original, width, height, 'nearest-exact', "center").float().to(self.device)
width, height = self.scale_image_to(x_noisy.shape[3] * self.compression_ratio, x_noisy.shape[2] * self.compression_ratio)
self.cond_hint = comfy.utils.common_upscale(self.cond_hint_original, width, height, self.upscale_algorithm, "center").float().to(self.device)
if self.channels_in == 1 and self.cond_hint.shape[1] > 1:
self.cond_hint = torch.mean(self.cond_hint, 1, keepdim=True)
if x_noisy.shape[0] != self.cond_hint.shape[0]:
@ -455,11 +494,14 @@ class T2IAdapter(ControlBase):
return self.control_merge(control_input, mid, control_prev, x_noisy.dtype)
def copy(self):
c = T2IAdapter(self.t2i_model, self.channels_in)
c = T2IAdapter(self.t2i_model, self.channels_in, self.compression_ratio, self.upscale_algorithm)
self.copy_to(c)
return c
def load_t2i_adapter(t2i_data):
compression_ratio = 8
upscale_algorithm = 'nearest-exact'
if 'adapter' in t2i_data:
t2i_data = t2i_data['adapter']
if 'adapter.body.0.resnets.0.block1.weight' in t2i_data: #diffusers format
@ -487,13 +529,22 @@ def load_t2i_adapter(t2i_data):
if cin == 256 or cin == 768:
xl = True
model_ad = comfy.t2i_adapter.adapter.Adapter(cin=cin, channels=[channel, channel*2, channel*4, channel*4][:4], nums_rb=2, ksize=ksize, sk=True, use_conv=use_conv, xl=xl)
elif "backbone.0.0.weight" in keys:
model_ad = comfy.ldm.cascade.controlnet.ControlNet(c_in=t2i_data['backbone.0.0.weight'].shape[1], proj_blocks=[0, 4, 8, 12, 51, 55, 59, 63])
compression_ratio = 32
upscale_algorithm = 'bilinear'
elif "backbone.10.blocks.0.weight" in keys:
model_ad = comfy.ldm.cascade.controlnet.ControlNet(c_in=t2i_data['backbone.0.weight'].shape[1], bottleneck_mode="large", proj_blocks=[0, 4, 8, 12, 51, 55, 59, 63])
compression_ratio = 1
upscale_algorithm = 'nearest-exact'
else:
return None
missing, unexpected = model_ad.load_state_dict(t2i_data)
if len(missing) > 0:
print("t2i missing", missing)
logging.warning("t2i missing {}".format(missing))
if len(unexpected) > 0:
print("t2i unexpected", unexpected)
logging.debug("t2i unexpected {}".format(unexpected))
return T2IAdapter(model_ad, model_ad.input_channels)
return T2IAdapter(model_ad, model_ad.input_channels, compression_ratio, upscale_algorithm)

11
comfy/diffusers_convert.py

@ -1,5 +1,6 @@
import re
import torch
import logging
# conversion code from https://github.com/huggingface/diffusers/blob/main/scripts/convert_diffusers_to_original_stable_diffusion.py
@ -177,7 +178,7 @@ def convert_vae_state_dict(vae_state_dict):
for k, v in new_state_dict.items():
for weight_name in weights_to_convert:
if f"mid.attn_1.{weight_name}.weight" in k:
print(f"Reshaping {k} for SD format")
logging.debug(f"Reshaping {k} for SD format")
new_state_dict[k] = reshape_weight_for_sd(v)
return new_state_dict
@ -237,8 +238,12 @@ def convert_text_enc_state_dict_v20(text_enc_dict, prefix=""):
capture_qkv_bias[k_pre][code2idx[k_code]] = v
continue
relabelled_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], k)
new_state_dict[relabelled_key] = v
text_proj = "transformer.text_projection.weight"
if k.endswith(text_proj):
new_state_dict[k.replace(text_proj, "text_projection")] = v.transpose(0, 1).contiguous()
else:
relabelled_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], k)
new_state_dict[relabelled_key] = v
for k_pre, tensors in capture_qkv_weight.items():
if None in tensors:

1
comfy/diffusers_load.py

@ -1,4 +1,3 @@
import json
import os
import comfy.sd

37
comfy/extra_samplers/uni_pc.py

@ -358,9 +358,6 @@ class UniPC:
thresholding=False,
max_val=1.,
variant='bh1',
noise_mask=None,
masked_image=None,
noise=None,
):
"""Construct a UniPC.
@ -372,9 +369,6 @@ class UniPC:
self.predict_x0 = predict_x0
self.thresholding = thresholding
self.max_val = max_val
self.noise_mask = noise_mask
self.masked_image = masked_image
self.noise = noise
def dynamic_thresholding_fn(self, x0, t=None):
"""
@ -391,10 +385,7 @@ class UniPC:
"""
Return the noise prediction model.
"""
if self.noise_mask is not None:
return self.model(x, t) * self.noise_mask
else:
return self.model(x, t)
return self.model(x, t)
def data_prediction_fn(self, x, t):
"""
@ -409,8 +400,6 @@ class UniPC:
s = torch.quantile(torch.abs(x0).reshape((x0.shape[0], -1)), p, dim=1)
s = expand_dims(torch.maximum(s, self.max_val * torch.ones_like(s).to(s.device)), dims)
x0 = torch.clamp(x0, -s, s) / s
if self.noise_mask is not None:
x0 = x0 * self.noise_mask + (1. - self.noise_mask) * self.masked_image
return x0
def model_fn(self, x, t):
@ -723,8 +712,6 @@ class UniPC:
assert timesteps.shape[0] - 1 == steps
# with torch.no_grad():
for step_index in trange(steps, disable=disable_pbar):
if self.noise_mask is not None:
x = x * self.noise_mask + (1. - self.noise_mask) * (self.masked_image * self.noise_schedule.marginal_alpha(timesteps[step_index]) + self.noise * self.noise_schedule.marginal_std(timesteps[step_index]))
if step_index == 0:
vec_t = timesteps[0].expand((x.shape[0]))
model_prev_list = [self.model_fn(x, vec_t)]
@ -766,7 +753,7 @@ class UniPC:
model_x = self.model_fn(x, vec_t)
model_prev_list[-1] = model_x
if callback is not None:
callback(step_index, model_prev_list[-1], x, steps)
callback({'x': x, 'i': step_index, 'denoised': model_prev_list[-1]})
else:
raise NotImplementedError()
# if denoise_to_zero:
@ -858,7 +845,7 @@ def predict_eps_sigma(model, input, sigma_in, **kwargs):
return (input - model(input, sigma_in, **kwargs)) / sigma
def sample_unipc(model, noise, image, sigmas, max_denoise, extra_args=None, callback=None, disable=False, noise_mask=None, variant='bh1'):
def sample_unipc(model, noise, sigmas, extra_args=None, callback=None, disable=False, variant='bh1'):
timesteps = sigmas.clone()
if sigmas[-1] == 0:
timesteps = sigmas[:]
@ -867,16 +854,7 @@ def sample_unipc(model, noise, image, sigmas, max_denoise, extra_args=None, call
timesteps = sigmas.clone()
ns = SigmaConvert()
if image is not None:
img = image * ns.marginal_alpha(timesteps[0])
if max_denoise:
noise_mult = 1.0
else:
noise_mult = ns.marginal_std(timesteps[0])
img += noise * noise_mult
else:
img = noise
noise = noise / torch.sqrt(1.0 + timesteps[0] ** 2.0)
model_type = "noise"
model_fn = model_wrapper(
@ -888,7 +866,10 @@ def sample_unipc(model, noise, image, sigmas, max_denoise, extra_args=None, call
)
order = min(3, len(timesteps) - 2)
uni_pc = UniPC(model_fn, ns, predict_x0=True, thresholding=False, noise_mask=noise_mask, masked_image=image, noise=noise, variant=variant)
x = uni_pc.sample(img, timesteps=timesteps, skip_type="time_uniform", method="multistep", order=order, lower_order_final=True, callback=callback, disable_pbar=disable)
uni_pc = UniPC(model_fn, ns, predict_x0=True, thresholding=False, variant=variant)
x = uni_pc.sample(noise, timesteps=timesteps, skip_type="time_uniform", method="multistep", order=order, lower_order_final=True, callback=callback, disable_pbar=disable)
x /= ns.marginal_alpha(timesteps[-1])
return x
def sample_unipc_bh2(model, noise, sigmas, extra_args=None, callback=None, disable=False):
return sample_unipc(model, noise, sigmas, extra_args, callback, disable, variant='bh2')

54
comfy/gligen.py

@ -1,8 +1,9 @@
import torch
from torch import nn, einsum
from torch import nn
from .ldm.modules.attention import CrossAttention
from inspect import isfunction
import comfy.ops
ops = comfy.ops.manual_cast
def exists(val):
return val is not None
@ -22,7 +23,7 @@ def default(val, d):
class GEGLU(nn.Module):
def __init__(self, dim_in, dim_out):
super().__init__()
self.proj = nn.Linear(dim_in, dim_out * 2)
self.proj = ops.Linear(dim_in, dim_out * 2)
def forward(self, x):
x, gate = self.proj(x).chunk(2, dim=-1)
@ -35,14 +36,14 @@ class FeedForward(nn.Module):
inner_dim = int(dim * mult)
dim_out = default(dim_out, dim)
project_in = nn.Sequential(
nn.Linear(dim, inner_dim),
ops.Linear(dim, inner_dim),
nn.GELU()
) if not glu else GEGLU(dim, inner_dim)
self.net = nn.Sequential(
project_in,
nn.Dropout(dropout),
nn.Linear(inner_dim, dim_out)
ops.Linear(inner_dim, dim_out)
)
def forward(self, x):
@ -57,11 +58,12 @@ class GatedCrossAttentionDense(nn.Module):
query_dim=query_dim,
context_dim=context_dim,
heads=n_heads,
dim_head=d_head)
dim_head=d_head,
operations=ops)
self.ff = FeedForward(query_dim, glu=True)
self.norm1 = nn.LayerNorm(query_dim)
self.norm2 = nn.LayerNorm(query_dim)
self.norm1 = ops.LayerNorm(query_dim)
self.norm2 = ops.LayerNorm(query_dim)
self.register_parameter('alpha_attn', nn.Parameter(torch.tensor(0.)))
self.register_parameter('alpha_dense', nn.Parameter(torch.tensor(0.)))
@ -87,17 +89,18 @@ class GatedSelfAttentionDense(nn.Module):
# we need a linear projection since we need cat visual feature and obj
# feature
self.linear = nn.Linear(context_dim, query_dim)
self.linear = ops.Linear(context_dim, query_dim)
self.attn = CrossAttention(
query_dim=query_dim,
context_dim=query_dim,
heads=n_heads,
dim_head=d_head)
dim_head=d_head,
operations=ops)
self.ff = FeedForward(query_dim, glu=True)
self.norm1 = nn.LayerNorm(query_dim)
self.norm2 = nn.LayerNorm(query_dim)
self.norm1 = ops.LayerNorm(query_dim)
self.norm2 = ops.LayerNorm(query_dim)
self.register_parameter('alpha_attn', nn.Parameter(torch.tensor(0.)))
self.register_parameter('alpha_dense', nn.Parameter(torch.tensor(0.)))
@ -126,14 +129,14 @@ class GatedSelfAttentionDense2(nn.Module):
# we need a linear projection since we need cat visual feature and obj
# feature
self.linear = nn.Linear(context_dim, query_dim)
self.linear = ops.Linear(context_dim, query_dim)
self.attn = CrossAttention(
query_dim=query_dim, context_dim=query_dim, dim_head=d_head)
query_dim=query_dim, context_dim=query_dim, dim_head=d_head, operations=ops)
self.ff = FeedForward(query_dim, glu=True)
self.norm1 = nn.LayerNorm(query_dim)
self.norm2 = nn.LayerNorm(query_dim)
self.norm1 = ops.LayerNorm(query_dim)
self.norm2 = ops.LayerNorm(query_dim)
self.register_parameter('alpha_attn', nn.Parameter(torch.tensor(0.)))
self.register_parameter('alpha_dense', nn.Parameter(torch.tensor(0.)))
@ -201,11 +204,11 @@ class PositionNet(nn.Module):
self.position_dim = fourier_freqs * 2 * 4 # 2 is sin&cos, 4 is xyxy
self.linears = nn.Sequential(
nn.Linear(self.in_dim + self.position_dim, 512),
ops.Linear(self.in_dim + self.position_dim, 512),
nn.SiLU(),
nn.Linear(512, 512),
ops.Linear(512, 512),
nn.SiLU(),
nn.Linear(512, out_dim),
ops.Linear(512, out_dim),
)
self.null_positive_feature = torch.nn.Parameter(
@ -215,16 +218,15 @@ class PositionNet(nn.Module):
def forward(self, boxes, masks, positive_embeddings):
B, N, _ = boxes.shape
dtype = self.linears[0].weight.dtype
masks = masks.unsqueeze(-1).to(dtype)
positive_embeddings = positive_embeddings.to(dtype)
masks = masks.unsqueeze(-1)
positive_embeddings = positive_embeddings
# embedding position (it may includes padding as placeholder)
xyxy_embedding = self.fourier_embedder(boxes.to(dtype)) # B*N*4 --> B*N*C
xyxy_embedding = self.fourier_embedder(boxes) # B*N*4 --> B*N*C
# learnable null embedding
positive_null = self.null_positive_feature.view(1, 1, -1)
xyxy_null = self.null_position_feature.view(1, 1, -1)
positive_null = self.null_positive_feature.to(device=boxes.device, dtype=boxes.dtype).view(1, 1, -1)
xyxy_null = self.null_position_feature.to(device=boxes.device, dtype=boxes.dtype).view(1, 1, -1)
# replace padding with learnable null embedding
positive_embeddings = positive_embeddings * \
@ -251,7 +253,7 @@ class Gligen(nn.Module):
def func(x, extra_options):
key = extra_options["transformer_index"]
module = self.module_list[key]
return module(x, objs)
return module(x, objs.to(device=x.device, dtype=x.dtype))
return func
def set_position(self, latent_image_shape, position_params, device):

2
comfy/k_diffusion/sampling.py

@ -748,7 +748,7 @@ def sample_lcm(model, x, sigmas, extra_args=None, callback=None, disable=None, n
x = denoised
if sigmas[i + 1] > 0:
x += sigmas[i + 1] * noise_sampler(sigmas[i], sigmas[i + 1])
x = model.inner_model.inner_model.model_sampling.noise_scaling(sigmas[i + 1], noise_sampler(sigmas[i], sigmas[i + 1]), x)
return x

69
comfy/latent_formats.py

@ -1,3 +1,4 @@
import torch
class LatentFormat:
scale_factor = 1.0
@ -33,3 +34,71 @@ class SDXL(LatentFormat):
[-0.3112, -0.2359, -0.2076]
]
self.taesd_decoder_name = "taesdxl_decoder"
class SDXL_Playground_2_5(LatentFormat):
def __init__(self):
self.scale_factor = 0.5
self.latents_mean = torch.tensor([-1.6574, 1.886, -1.383, 2.5155]).view(1, 4, 1, 1)
self.latents_std = torch.tensor([8.4927, 5.9022, 6.5498, 5.2299]).view(1, 4, 1, 1)
self.latent_rgb_factors = [
# R G B
[ 0.3920, 0.4054, 0.4549],
[-0.2634, -0.0196, 0.0653],
[ 0.0568, 0.1687, -0.0755],
[-0.3112, -0.2359, -0.2076]
]
self.taesd_decoder_name = "taesdxl_decoder"
def process_in(self, latent):
latents_mean = self.latents_mean.to(latent.device, latent.dtype)
latents_std = self.latents_std.to(latent.device, latent.dtype)
return (latent - latents_mean) * self.scale_factor / latents_std
def process_out(self, latent):
latents_mean = self.latents_mean.to(latent.device, latent.dtype)
latents_std = self.latents_std.to(latent.device, latent.dtype)
return latent * latents_std / self.scale_factor + latents_mean
class SD_X4(LatentFormat):
def __init__(self):
self.scale_factor = 0.08333
self.latent_rgb_factors = [
[-0.2340, -0.3863, -0.3257],
[ 0.0994, 0.0885, -0.0908],
[-0.2833, -0.2349, -0.3741],
[ 0.2523, -0.0055, -0.1651]
]
class SC_Prior(LatentFormat):
def __init__(self):
self.scale_factor = 1.0
self.latent_rgb_factors = [
[-0.0326, -0.0204, -0.0127],
[-0.1592, -0.0427, 0.0216],
[ 0.0873, 0.0638, -0.0020],
[-0.0602, 0.0442, 0.1304],
[ 0.0800, -0.0313, -0.1796],
[-0.0810, -0.0638, -0.1581],
[ 0.1791, 0.1180, 0.0967],
[ 0.0740, 0.1416, 0.0432],
[-0.1745, -0.1888, -0.1373],
[ 0.2412, 0.1577, 0.0928],
[ 0.1908, 0.0998, 0.0682],
[ 0.0209, 0.0365, -0.0092],
[ 0.0448, -0.0650, -0.1728],
[-0.1658, -0.1045, -0.1308],
[ 0.0542, 0.1545, 0.1325],
[-0.0352, -0.1672, -0.2541]
]
class SC_B(LatentFormat):
def __init__(self):
self.scale_factor = 1.0 / 0.43
self.latent_rgb_factors = [
[ 0.1121, 0.2006, 0.1023],
[-0.2093, -0.0222, -0.0195],
[-0.3087, -0.1535, 0.0366],
[ 0.0290, -0.1574, -0.4078]
]

161
comfy/ldm/cascade/common.py

@ -0,0 +1,161 @@
"""
This file is part of ComfyUI.
Copyright (C) 2024 Stability AI
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU General Public License for more details.
You should have received a copy of the GNU General Public License
along with this program. If not, see <https://www.gnu.org/licenses/>.
"""
import torch
import torch.nn as nn
from comfy.ldm.modules.attention import optimized_attention
class Linear(torch.nn.Linear):
def reset_parameters(self):
return None
class Conv2d(torch.nn.Conv2d):
def reset_parameters(self):
return None
class OptimizedAttention(nn.Module):
def __init__(self, c, nhead, dropout=0.0, dtype=None, device=None, operations=None):
super().__init__()
self.heads = nhead
self.to_q = operations.Linear(c, c, bias=True, dtype=dtype, device=device)
self.to_k = operations.Linear(c, c, bias=True, dtype=dtype, device=device)
self.to_v = operations.Linear(c, c, bias=True, dtype=dtype, device=device)
self.out_proj = operations.Linear(c, c, bias=True, dtype=dtype, device=device)
def forward(self, q, k, v):
q = self.to_q(q)
k = self.to_k(k)
v = self.to_v(v)
out = optimized_attention(q, k, v, self.heads)
return self.out_proj(out)
class Attention2D(nn.Module):
def __init__(self, c, nhead, dropout=0.0, dtype=None, device=None, operations=None):
super().__init__()
self.attn = OptimizedAttention(c, nhead, dtype=dtype, device=device, operations=operations)
# self.attn = nn.MultiheadAttention(c, nhead, dropout=dropout, bias=True, batch_first=True, dtype=dtype, device=device)
def forward(self, x, kv, self_attn=False):
orig_shape = x.shape
x = x.view(x.size(0), x.size(1), -1).permute(0, 2, 1) # Bx4xHxW -> Bx(HxW)x4
if self_attn:
kv = torch.cat([x, kv], dim=1)
# x = self.attn(x, kv, kv, need_weights=False)[0]
x = self.attn(x, kv, kv)
x = x.permute(0, 2, 1).view(*orig_shape)
return x
def LayerNorm2d_op(operations):
class LayerNorm2d(operations.LayerNorm):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
def forward(self, x):
return super().forward(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
return LayerNorm2d
class GlobalResponseNorm(nn.Module):
"from https://github.com/facebookresearch/ConvNeXt-V2/blob/3608f67cc1dae164790c5d0aead7bf2d73d9719b/models/utils.py#L105"
def __init__(self, dim, dtype=None, device=None):
super().__init__()
self.gamma = nn.Parameter(torch.zeros(1, 1, 1, dim, dtype=dtype, device=device))
self.beta = nn.Parameter(torch.zeros(1, 1, 1, dim, dtype=dtype, device=device))
def forward(self, x):
Gx = torch.norm(x, p=2, dim=(1, 2), keepdim=True)
Nx = Gx / (Gx.mean(dim=-1, keepdim=True) + 1e-6)
return self.gamma.to(device=x.device, dtype=x.dtype) * (x * Nx) + self.beta.to(device=x.device, dtype=x.dtype) + x
class ResBlock(nn.Module):
def __init__(self, c, c_skip=0, kernel_size=3, dropout=0.0, dtype=None, device=None, operations=None): # , num_heads=4, expansion=2):
super().__init__()
self.depthwise = operations.Conv2d(c, c, kernel_size=kernel_size, padding=kernel_size // 2, groups=c, dtype=dtype, device=device)
# self.depthwise = SAMBlock(c, num_heads, expansion)
self.norm = LayerNorm2d_op(operations)(c, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
self.channelwise = nn.Sequential(
operations.Linear(c + c_skip, c * 4, dtype=dtype, device=device),
nn.GELU(),
GlobalResponseNorm(c * 4, dtype=dtype, device=device),
nn.Dropout(dropout),
operations.Linear(c * 4, c, dtype=dtype, device=device)
)
def forward(self, x, x_skip=None):
x_res = x
x = self.norm(self.depthwise(x))
if x_skip is not None:
x = torch.cat([x, x_skip], dim=1)
x = self.channelwise(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
return x + x_res
class AttnBlock(nn.Module):
def __init__(self, c, c_cond, nhead, self_attn=True, dropout=0.0, dtype=None, device=None, operations=None):
super().__init__()
self.self_attn = self_attn
self.norm = LayerNorm2d_op(operations)(c, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
self.attention = Attention2D(c, nhead, dropout, dtype=dtype, device=device, operations=operations)
self.kv_mapper = nn.Sequential(
nn.SiLU(),
operations.Linear(c_cond, c, dtype=dtype, device=device)
)
def forward(self, x, kv):
kv = self.kv_mapper(kv)
x = x + self.attention(self.norm(x), kv, self_attn=self.self_attn)
return x
class FeedForwardBlock(nn.Module):
def __init__(self, c, dropout=0.0, dtype=None, device=None, operations=None):
super().__init__()
self.norm = LayerNorm2d_op(operations)(c, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
self.channelwise = nn.Sequential(
operations.Linear(c, c * 4, dtype=dtype, device=device),
nn.GELU(),
GlobalResponseNorm(c * 4, dtype=dtype, device=device),
nn.Dropout(dropout),
operations.Linear(c * 4, c, dtype=dtype, device=device)
)
def forward(self, x):
x = x + self.channelwise(self.norm(x).permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
return x
class TimestepBlock(nn.Module):
def __init__(self, c, c_timestep, conds=['sca'], dtype=None, device=None, operations=None):
super().__init__()
self.mapper = operations.Linear(c_timestep, c * 2, dtype=dtype, device=device)
self.conds = conds
for cname in conds:
setattr(self, f"mapper_{cname}", operations.Linear(c_timestep, c * 2, dtype=dtype, device=device))
def forward(self, x, t):
t = t.chunk(len(self.conds) + 1, dim=1)
a, b = self.mapper(t[0])[:, :, None, None].chunk(2, dim=1)
for i, c in enumerate(self.conds):
ac, bc = getattr(self, f"mapper_{c}")(t[i + 1])[:, :, None, None].chunk(2, dim=1)
a, b = a + ac, b + bc
return x * (1 + a) + b

93
comfy/ldm/cascade/controlnet.py

@ -0,0 +1,93 @@
"""
This file is part of ComfyUI.
Copyright (C) 2024 Stability AI
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU General Public License for more details.
You should have received a copy of the GNU General Public License
along with this program. If not, see <https://www.gnu.org/licenses/>.
"""
import torch
import torchvision
from torch import nn
from .common import LayerNorm2d_op
class CNetResBlock(nn.Module):
def __init__(self, c, dtype=None, device=None, operations=None):
super().__init__()
self.blocks = nn.Sequential(
LayerNorm2d_op(operations)(c, dtype=dtype, device=device),
nn.GELU(),
operations.Conv2d(c, c, kernel_size=3, padding=1),
LayerNorm2d_op(operations)(c, dtype=dtype, device=device),
nn.GELU(),
operations.Conv2d(c, c, kernel_size=3, padding=1),
)
def forward(self, x):
return x + self.blocks(x)
class ControlNet(nn.Module):
def __init__(self, c_in=3, c_proj=2048, proj_blocks=None, bottleneck_mode=None, dtype=None, device=None, operations=nn):
super().__init__()
if bottleneck_mode is None:
bottleneck_mode = 'effnet'
self.proj_blocks = proj_blocks
if bottleneck_mode == 'effnet':
embd_channels = 1280
self.backbone = torchvision.models.efficientnet_v2_s().features.eval()
if c_in != 3:
in_weights = self.backbone[0][0].weight.data
self.backbone[0][0] = operations.Conv2d(c_in, 24, kernel_size=3, stride=2, bias=False, dtype=dtype, device=device)
if c_in > 3:
# nn.init.constant_(self.backbone[0][0].weight, 0)
self.backbone[0][0].weight.data[:, :3] = in_weights[:, :3].clone()
else:
self.backbone[0][0].weight.data = in_weights[:, :c_in].clone()
elif bottleneck_mode == 'simple':
embd_channels = c_in
self.backbone = nn.Sequential(
operations.Conv2d(embd_channels, embd_channels * 4, kernel_size=3, padding=1, dtype=dtype, device=device),
nn.LeakyReLU(0.2, inplace=True),
operations.Conv2d(embd_channels * 4, embd_channels, kernel_size=3, padding=1, dtype=dtype, device=device),
)
elif bottleneck_mode == 'large':
self.backbone = nn.Sequential(
operations.Conv2d(c_in, 4096 * 4, kernel_size=1, dtype=dtype, device=device),
nn.LeakyReLU(0.2, inplace=True),
operations.Conv2d(4096 * 4, 1024, kernel_size=1, dtype=dtype, device=device),
*[CNetResBlock(1024, dtype=dtype, device=device, operations=operations) for _ in range(8)],
operations.Conv2d(1024, 1280, kernel_size=1, dtype=dtype, device=device),
)
embd_channels = 1280
else:
raise ValueError(f'Unknown bottleneck mode: {bottleneck_mode}')
self.projections = nn.ModuleList()
for _ in range(len(proj_blocks)):
self.projections.append(nn.Sequential(
operations.Conv2d(embd_channels, embd_channels, kernel_size=1, bias=False, dtype=dtype, device=device),
nn.LeakyReLU(0.2, inplace=True),
operations.Conv2d(embd_channels, c_proj, kernel_size=1, bias=False, dtype=dtype, device=device),
))
# nn.init.constant_(self.projections[-1][-1].weight, 0) # zero output projection
self.xl = False
self.input_channels = c_in
self.unshuffle_amount = 8
def forward(self, x):
x = self.backbone(x)
proj_outputs = [None for _ in range(max(self.proj_blocks) + 1)]
for i, idx in enumerate(self.proj_blocks):
proj_outputs[idx] = self.projections[i](x)
return proj_outputs

255
comfy/ldm/cascade/stage_a.py

@ -0,0 +1,255 @@
"""
This file is part of ComfyUI.
Copyright (C) 2024 Stability AI
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU General Public License for more details.
You should have received a copy of the GNU General Public License
along with this program. If not, see <https://www.gnu.org/licenses/>.
"""
import torch
from torch import nn
from torch.autograd import Function
class vector_quantize(Function):
@staticmethod
def forward(ctx, x, codebook):
with torch.no_grad():
codebook_sqr = torch.sum(codebook ** 2, dim=1)
x_sqr = torch.sum(x ** 2, dim=1, keepdim=True)
dist = torch.addmm(codebook_sqr + x_sqr, x, codebook.t(), alpha=-2.0, beta=1.0)
_, indices = dist.min(dim=1)
ctx.save_for_backward(indices, codebook)
ctx.mark_non_differentiable(indices)
nn = torch.index_select(codebook, 0, indices)
return nn, indices
@staticmethod
def backward(ctx, grad_output, grad_indices):
grad_inputs, grad_codebook = None, None
if ctx.needs_input_grad[0]:
grad_inputs = grad_output.clone()
if ctx.needs_input_grad[1]:
# Gradient wrt. the codebook
indices, codebook = ctx.saved_tensors
grad_codebook = torch.zeros_like(codebook)
grad_codebook.index_add_(0, indices, grad_output)
return (grad_inputs, grad_codebook)
class VectorQuantize(nn.Module):
def __init__(self, embedding_size, k, ema_decay=0.99, ema_loss=False):
"""
Takes an input of variable size (as long as the last dimension matches the embedding size).
Returns one tensor containing the nearest neigbour embeddings to each of the inputs,
with the same size as the input, vq and commitment components for the loss as a touple
in the second output and the indices of the quantized vectors in the third:
quantized, (vq_loss, commit_loss), indices
"""
super(VectorQuantize, self).__init__()
self.codebook = nn.Embedding(k, embedding_size)
self.codebook.weight.data.uniform_(-1./k, 1./k)
self.vq = vector_quantize.apply
self.ema_decay = ema_decay
self.ema_loss = ema_loss
if ema_loss:
self.register_buffer('ema_element_count', torch.ones(k))
self.register_buffer('ema_weight_sum', torch.zeros_like(self.codebook.weight))
def _laplace_smoothing(self, x, epsilon):
n = torch.sum(x)
return ((x + epsilon) / (n + x.size(0) * epsilon) * n)
def _updateEMA(self, z_e_x, indices):
mask = nn.functional.one_hot(indices, self.ema_element_count.size(0)).float()
elem_count = mask.sum(dim=0)
weight_sum = torch.mm(mask.t(), z_e_x)
self.ema_element_count = (self.ema_decay * self.ema_element_count) + ((1-self.ema_decay) * elem_count)
self.ema_element_count = self._laplace_smoothing(self.ema_element_count, 1e-5)
self.ema_weight_sum = (self.ema_decay * self.ema_weight_sum) + ((1-self.ema_decay) * weight_sum)
self.codebook.weight.data = self.ema_weight_sum / self.ema_element_count.unsqueeze(-1)
def idx2vq(self, idx, dim=-1):
q_idx = self.codebook(idx)
if dim != -1:
q_idx = q_idx.movedim(-1, dim)
return q_idx
def forward(self, x, get_losses=True, dim=-1):
if dim != -1:
x = x.movedim(dim, -1)
z_e_x = x.contiguous().view(-1, x.size(-1)) if len(x.shape) > 2 else x
z_q_x, indices = self.vq(z_e_x, self.codebook.weight.detach())
vq_loss, commit_loss = None, None
if self.ema_loss and self.training:
self._updateEMA(z_e_x.detach(), indices.detach())
# pick the graded embeddings after updating the codebook in order to have a more accurate commitment loss
z_q_x_grd = torch.index_select(self.codebook.weight, dim=0, index=indices)
if get_losses:
vq_loss = (z_q_x_grd - z_e_x.detach()).pow(2).mean()
commit_loss = (z_e_x - z_q_x_grd.detach()).pow(2).mean()
z_q_x = z_q_x.view(x.shape)
if dim != -1:
z_q_x = z_q_x.movedim(-1, dim)
return z_q_x, (vq_loss, commit_loss), indices.view(x.shape[:-1])
class ResBlock(nn.Module):
def __init__(self, c, c_hidden):
super().__init__()
# depthwise/attention
self.norm1 = nn.LayerNorm(c, elementwise_affine=False, eps=1e-6)
self.depthwise = nn.Sequential(
nn.ReplicationPad2d(1),
nn.Conv2d(c, c, kernel_size=3, groups=c)
)
# channelwise
self.norm2 = nn.LayerNorm(c, elementwise_affine=False, eps=1e-6)
self.channelwise = nn.Sequential(
nn.Linear(c, c_hidden),
nn.GELU(),
nn.Linear(c_hidden, c),
)
self.gammas = nn.Parameter(torch.zeros(6), requires_grad=True)
# Init weights
def _basic_init(module):
if isinstance(module, nn.Linear) or isinstance(module, nn.Conv2d):
torch.nn.init.xavier_uniform_(module.weight)
if module.bias is not None:
nn.init.constant_(module.bias, 0)
self.apply(_basic_init)
def _norm(self, x, norm):
return norm(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
def forward(self, x):
mods = self.gammas
x_temp = self._norm(x, self.norm1) * (1 + mods[0]) + mods[1]
try:
x = x + self.depthwise(x_temp) * mods[2]
except: #operation not implemented for bf16
x_temp = self.depthwise[0](x_temp.float()).to(x.dtype)
x = x + self.depthwise[1](x_temp) * mods[2]
x_temp = self._norm(x, self.norm2) * (1 + mods[3]) + mods[4]
x = x + self.channelwise(x_temp.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) * mods[5]
return x
class StageA(nn.Module):
def __init__(self, levels=2, bottleneck_blocks=12, c_hidden=384, c_latent=4, codebook_size=8192):
super().__init__()
self.c_latent = c_latent
c_levels = [c_hidden // (2 ** i) for i in reversed(range(levels))]
# Encoder blocks
self.in_block = nn.Sequential(
nn.PixelUnshuffle(2),
nn.Conv2d(3 * 4, c_levels[0], kernel_size=1)
)
down_blocks = []
for i in range(levels):
if i > 0:
down_blocks.append(nn.Conv2d(c_levels[i - 1], c_levels[i], kernel_size=4, stride=2, padding=1))
block = ResBlock(c_levels[i], c_levels[i] * 4)
down_blocks.append(block)
down_blocks.append(nn.Sequential(
nn.Conv2d(c_levels[-1], c_latent, kernel_size=1, bias=False),
nn.BatchNorm2d(c_latent), # then normalize them to have mean 0 and std 1
))
self.down_blocks = nn.Sequential(*down_blocks)
self.down_blocks[0]
self.codebook_size = codebook_size
self.vquantizer = VectorQuantize(c_latent, k=codebook_size)
# Decoder blocks
up_blocks = [nn.Sequential(
nn.Conv2d(c_latent, c_levels[-1], kernel_size=1)
)]
for i in range(levels):
for j in range(bottleneck_blocks if i == 0 else 1):
block = ResBlock(c_levels[levels - 1 - i], c_levels[levels - 1 - i] * 4)
up_blocks.append(block)
if i < levels - 1:
up_blocks.append(
nn.ConvTranspose2d(c_levels[levels - 1 - i], c_levels[levels - 2 - i], kernel_size=4, stride=2,
padding=1))
self.up_blocks = nn.Sequential(*up_blocks)
self.out_block = nn.Sequential(
nn.Conv2d(c_levels[0], 3 * 4, kernel_size=1),
nn.PixelShuffle(2),
)
def encode(self, x, quantize=False):
x = self.in_block(x)
x = self.down_blocks(x)
if quantize:
qe, (vq_loss, commit_loss), indices = self.vquantizer.forward(x, dim=1)
return qe, x, indices, vq_loss + commit_loss * 0.25
else:
return x
def decode(self, x):
x = self.up_blocks(x)
x = self.out_block(x)
return x
def forward(self, x, quantize=False):
qe, x, _, vq_loss = self.encode(x, quantize)
x = self.decode(qe)
return x, vq_loss
class Discriminator(nn.Module):
def __init__(self, c_in=3, c_cond=0, c_hidden=512, depth=6):
super().__init__()
d = max(depth - 3, 3)
layers = [
nn.utils.spectral_norm(nn.Conv2d(c_in, c_hidden // (2 ** d), kernel_size=3, stride=2, padding=1)),
nn.LeakyReLU(0.2),
]
for i in range(depth - 1):
c_in = c_hidden // (2 ** max((d - i), 0))
c_out = c_hidden // (2 ** max((d - 1 - i), 0))
layers.append(nn.utils.spectral_norm(nn.Conv2d(c_in, c_out, kernel_size=3, stride=2, padding=1)))
layers.append(nn.InstanceNorm2d(c_out))
layers.append(nn.LeakyReLU(0.2))
self.encoder = nn.Sequential(*layers)
self.shuffle = nn.Conv2d((c_hidden + c_cond) if c_cond > 0 else c_hidden, 1, kernel_size=1)
self.logits = nn.Sigmoid()
def forward(self, x, cond=None):
x = self.encoder(x)
if cond is not None:
cond = cond.view(cond.size(0), cond.size(1), 1, 1, ).expand(-1, -1, x.size(-2), x.size(-1))
x = torch.cat([x, cond], dim=1)
x = self.shuffle(x)
x = self.logits(x)
return x

257
comfy/ldm/cascade/stage_b.py

@ -0,0 +1,257 @@
"""
This file is part of ComfyUI.
Copyright (C) 2024 Stability AI
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU General Public License for more details.
You should have received a copy of the GNU General Public License
along with this program. If not, see <https://www.gnu.org/licenses/>.
"""
import math
import numpy as np
import torch
from torch import nn
from .common import AttnBlock, LayerNorm2d_op, ResBlock, FeedForwardBlock, TimestepBlock
class StageB(nn.Module):
def __init__(self, c_in=4, c_out=4, c_r=64, patch_size=2, c_cond=1280, c_hidden=[320, 640, 1280, 1280],
nhead=[-1, -1, 20, 20], blocks=[[2, 6, 28, 6], [6, 28, 6, 2]],
block_repeat=[[1, 1, 1, 1], [3, 3, 2, 2]], level_config=['CT', 'CT', 'CTA', 'CTA'], c_clip=1280,
c_clip_seq=4, c_effnet=16, c_pixels=3, kernel_size=3, dropout=[0, 0, 0.0, 0.0], self_attn=True,
t_conds=['sca'], stable_cascade_stage=None, dtype=None, device=None, operations=None):
super().__init__()
self.dtype = dtype
self.c_r = c_r
self.t_conds = t_conds
self.c_clip_seq = c_clip_seq
if not isinstance(dropout, list):
dropout = [dropout] * len(c_hidden)
if not isinstance(self_attn, list):
self_attn = [self_attn] * len(c_hidden)
# CONDITIONING
self.effnet_mapper = nn.Sequential(
operations.Conv2d(c_effnet, c_hidden[0] * 4, kernel_size=1, dtype=dtype, device=device),
nn.GELU(),
operations.Conv2d(c_hidden[0] * 4, c_hidden[0], kernel_size=1, dtype=dtype, device=device),
LayerNorm2d_op(operations)(c_hidden[0], elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
)
self.pixels_mapper = nn.Sequential(
operations.Conv2d(c_pixels, c_hidden[0] * 4, kernel_size=1, dtype=dtype, device=device),
nn.GELU(),
operations.Conv2d(c_hidden[0] * 4, c_hidden[0], kernel_size=1, dtype=dtype, device=device),
LayerNorm2d_op(operations)(c_hidden[0], elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
)
self.clip_mapper = operations.Linear(c_clip, c_cond * c_clip_seq, dtype=dtype, device=device)
self.clip_norm = operations.LayerNorm(c_cond, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
self.embedding = nn.Sequential(
nn.PixelUnshuffle(patch_size),
operations.Conv2d(c_in * (patch_size ** 2), c_hidden[0], kernel_size=1, dtype=dtype, device=device),
LayerNorm2d_op(operations)(c_hidden[0], elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
)
def get_block(block_type, c_hidden, nhead, c_skip=0, dropout=0, self_attn=True):
if block_type == 'C':
return ResBlock(c_hidden, c_skip, kernel_size=kernel_size, dropout=dropout, dtype=dtype, device=device, operations=operations)
elif block_type == 'A':
return AttnBlock(c_hidden, c_cond, nhead, self_attn=self_attn, dropout=dropout, dtype=dtype, device=device, operations=operations)
elif block_type == 'F':
return FeedForwardBlock(c_hidden, dropout=dropout, dtype=dtype, device=device, operations=operations)
elif block_type == 'T':
return TimestepBlock(c_hidden, c_r, conds=t_conds, dtype=dtype, device=device, operations=operations)
else:
raise Exception(f'Block type {block_type} not supported')
# BLOCKS
# -- down blocks
self.down_blocks = nn.ModuleList()
self.down_downscalers = nn.ModuleList()
self.down_repeat_mappers = nn.ModuleList()
for i in range(len(c_hidden)):
if i > 0:
self.down_downscalers.append(nn.Sequential(
LayerNorm2d_op(operations)(c_hidden[i - 1], elementwise_affine=False, eps=1e-6, dtype=dtype, device=device),
operations.Conv2d(c_hidden[i - 1], c_hidden[i], kernel_size=2, stride=2, dtype=dtype, device=device),
))
else:
self.down_downscalers.append(nn.Identity())
down_block = nn.ModuleList()
for _ in range(blocks[0][i]):
for block_type in level_config[i]:
block = get_block(block_type, c_hidden[i], nhead[i], dropout=dropout[i], self_attn=self_attn[i])
down_block.append(block)
self.down_blocks.append(down_block)
if block_repeat is not None:
block_repeat_mappers = nn.ModuleList()
for _ in range(block_repeat[0][i] - 1):
block_repeat_mappers.append(operations.Conv2d(c_hidden[i], c_hidden[i], kernel_size=1, dtype=dtype, device=device))
self.down_repeat_mappers.append(block_repeat_mappers)
# -- up blocks
self.up_blocks = nn.ModuleList()
self.up_upscalers = nn.ModuleList()
self.up_repeat_mappers = nn.ModuleList()
for i in reversed(range(len(c_hidden))):
if i > 0:
self.up_upscalers.append(nn.Sequential(
LayerNorm2d_op(operations)(c_hidden[i], elementwise_affine=False, eps=1e-6, dtype=dtype, device=device),
operations.ConvTranspose2d(c_hidden[i], c_hidden[i - 1], kernel_size=2, stride=2, dtype=dtype, device=device),
))
else:
self.up_upscalers.append(nn.Identity())
up_block = nn.ModuleList()
for j in range(blocks[1][::-1][i]):
for k, block_type in enumerate(level_config[i]):
c_skip = c_hidden[i] if i < len(c_hidden) - 1 and j == k == 0 else 0
block = get_block(block_type, c_hidden[i], nhead[i], c_skip=c_skip, dropout=dropout[i],
self_attn=self_attn[i])
up_block.append(block)
self.up_blocks.append(up_block)
if block_repeat is not None:
block_repeat_mappers = nn.ModuleList()
for _ in range(block_repeat[1][::-1][i] - 1):
block_repeat_mappers.append(operations.Conv2d(c_hidden[i], c_hidden[i], kernel_size=1, dtype=dtype, device=device))
self.up_repeat_mappers.append(block_repeat_mappers)
# OUTPUT
self.clf = nn.Sequential(
LayerNorm2d_op(operations)(c_hidden[0], elementwise_affine=False, eps=1e-6, dtype=dtype, device=device),
operations.Conv2d(c_hidden[0], c_out * (patch_size ** 2), kernel_size=1, dtype=dtype, device=device),
nn.PixelShuffle(patch_size),
)
# --- WEIGHT INIT ---
# self.apply(self._init_weights) # General init
# nn.init.normal_(self.clip_mapper.weight, std=0.02) # conditionings
# nn.init.normal_(self.effnet_mapper[0].weight, std=0.02) # conditionings
# nn.init.normal_(self.effnet_mapper[2].weight, std=0.02) # conditionings
# nn.init.normal_(self.pixels_mapper[0].weight, std=0.02) # conditionings
# nn.init.normal_(self.pixels_mapper[2].weight, std=0.02) # conditionings
# torch.nn.init.xavier_uniform_(self.embedding[1].weight, 0.02) # inputs
# nn.init.constant_(self.clf[1].weight, 0) # outputs
#
# # blocks
# for level_block in self.down_blocks + self.up_blocks:
# for block in level_block:
# if isinstance(block, ResBlock) or isinstance(block, FeedForwardBlock):
# block.channelwise[-1].weight.data *= np.sqrt(1 / sum(blocks[0]))
# elif isinstance(block, TimestepBlock):
# for layer in block.modules():
# if isinstance(layer, nn.Linear):
# nn.init.constant_(layer.weight, 0)
#
# def _init_weights(self, m):
# if isinstance(m, (nn.Conv2d, nn.Linear)):
# torch.nn.init.xavier_uniform_(m.weight)
# if m.bias is not None:
# nn.init.constant_(m.bias, 0)
def gen_r_embedding(self, r, max_positions=10000):
r = r * max_positions
half_dim = self.c_r // 2
emb = math.log(max_positions) / (half_dim - 1)
emb = torch.arange(half_dim, device=r.device).float().mul(-emb).exp()
emb = r[:, None] * emb[None, :]
emb = torch.cat([emb.sin(), emb.cos()], dim=1)
if self.c_r % 2 == 1: # zero pad
emb = nn.functional.pad(emb, (0, 1), mode='constant')
return emb
def gen_c_embeddings(self, clip):
if len(clip.shape) == 2:
clip = clip.unsqueeze(1)
clip = self.clip_mapper(clip).view(clip.size(0), clip.size(1) * self.c_clip_seq, -1)
clip = self.clip_norm(clip)
return clip
def _down_encode(self, x, r_embed, clip):
level_outputs = []
block_group = zip(self.down_blocks, self.down_downscalers, self.down_repeat_mappers)
for down_block, downscaler, repmap in block_group:
x = downscaler(x)
for i in range(len(repmap) + 1):
for block in down_block:
if isinstance(block, ResBlock) or (
hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module,
ResBlock)):
x = block(x)
elif isinstance(block, AttnBlock) or (
hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module,
AttnBlock)):
x = block(x, clip)
elif isinstance(block, TimestepBlock) or (
hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module,
TimestepBlock)):
x = block(x, r_embed)
else:
x = block(x)
if i < len(repmap):
x = repmap[i](x)
level_outputs.insert(0, x)
return level_outputs
def _up_decode(self, level_outputs, r_embed, clip):
x = level_outputs[0]
block_group = zip(self.up_blocks, self.up_upscalers, self.up_repeat_mappers)
for i, (up_block, upscaler, repmap) in enumerate(block_group):
for j in range(len(repmap) + 1):
for k, block in enumerate(up_block):
if isinstance(block, ResBlock) or (
hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module,
ResBlock)):
skip = level_outputs[i] if k == 0 and i > 0 else None
if skip is not None and (x.size(-1) != skip.size(-1) or x.size(-2) != skip.size(-2)):
x = torch.nn.functional.interpolate(x, skip.shape[-2:], mode='bilinear',
align_corners=True)
x = block(x, skip)
elif isinstance(block, AttnBlock) or (
hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module,
AttnBlock)):
x = block(x, clip)
elif isinstance(block, TimestepBlock) or (
hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module,
TimestepBlock)):
x = block(x, r_embed)
else:
x = block(x)
if j < len(repmap):
x = repmap[j](x)
x = upscaler(x)
return x
def forward(self, x, r, effnet, clip, pixels=None, **kwargs):
if pixels is None:
pixels = x.new_zeros(x.size(0), 3, 8, 8)
# Process the conditioning embeddings
r_embed = self.gen_r_embedding(r).to(dtype=x.dtype)
for c in self.t_conds:
t_cond = kwargs.get(c, torch.zeros_like(r))
r_embed = torch.cat([r_embed, self.gen_r_embedding(t_cond).to(dtype=x.dtype)], dim=1)
clip = self.gen_c_embeddings(clip)
# Model Blocks
x = self.embedding(x)
x = x + self.effnet_mapper(
nn.functional.interpolate(effnet, size=x.shape[-2:], mode='bilinear', align_corners=True))
x = x + nn.functional.interpolate(self.pixels_mapper(pixels), size=x.shape[-2:], mode='bilinear',
align_corners=True)
level_outputs = self._down_encode(x, r_embed, clip)
x = self._up_decode(level_outputs, r_embed, clip)
return self.clf(x)
def update_weights_ema(self, src_model, beta=0.999):
for self_params, src_params in zip(self.parameters(), src_model.parameters()):
self_params.data = self_params.data * beta + src_params.data.clone().to(self_params.device) * (1 - beta)
for self_buffers, src_buffers in zip(self.buffers(), src_model.buffers()):
self_buffers.data = self_buffers.data * beta + src_buffers.data.clone().to(self_buffers.device) * (1 - beta)

274
comfy/ldm/cascade/stage_c.py

@ -0,0 +1,274 @@
"""
This file is part of ComfyUI.
Copyright (C) 2024 Stability AI
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU General Public License for more details.
You should have received a copy of the GNU General Public License
along with this program. If not, see <https://www.gnu.org/licenses/>.
"""
import torch
from torch import nn
import numpy as np
import math
from .common import AttnBlock, LayerNorm2d_op, ResBlock, FeedForwardBlock, TimestepBlock
# from .controlnet import ControlNetDeliverer
class UpDownBlock2d(nn.Module):
def __init__(self, c_in, c_out, mode, enabled=True, dtype=None, device=None, operations=None):
super().__init__()
assert mode in ['up', 'down']
interpolation = nn.Upsample(scale_factor=2 if mode == 'up' else 0.5, mode='bilinear',
align_corners=True) if enabled else nn.Identity()
mapping = operations.Conv2d(c_in, c_out, kernel_size=1, dtype=dtype, device=device)
self.blocks = nn.ModuleList([interpolation, mapping] if mode == 'up' else [mapping, interpolation])
def forward(self, x):
for block in self.blocks:
x = block(x)
return x
class StageC(nn.Module):
def __init__(self, c_in=16, c_out=16, c_r=64, patch_size=1, c_cond=2048, c_hidden=[2048, 2048], nhead=[32, 32],
blocks=[[8, 24], [24, 8]], block_repeat=[[1, 1], [1, 1]], level_config=['CTA', 'CTA'],
c_clip_text=1280, c_clip_text_pooled=1280, c_clip_img=768, c_clip_seq=4, kernel_size=3,
dropout=[0.0, 0.0], self_attn=True, t_conds=['sca', 'crp'], switch_level=[False], stable_cascade_stage=None,
dtype=None, device=None, operations=None):
super().__init__()
self.dtype = dtype
self.c_r = c_r
self.t_conds = t_conds
self.c_clip_seq = c_clip_seq
if not isinstance(dropout, list):
dropout = [dropout] * len(c_hidden)
if not isinstance(self_attn, list):
self_attn = [self_attn] * len(c_hidden)
# CONDITIONING
self.clip_txt_mapper = operations.Linear(c_clip_text, c_cond, dtype=dtype, device=device)
self.clip_txt_pooled_mapper = operations.Linear(c_clip_text_pooled, c_cond * c_clip_seq, dtype=dtype, device=device)
self.clip_img_mapper = operations.Linear(c_clip_img, c_cond * c_clip_seq, dtype=dtype, device=device)
self.clip_norm = operations.LayerNorm(c_cond, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
self.embedding = nn.Sequential(
nn.PixelUnshuffle(patch_size),
operations.Conv2d(c_in * (patch_size ** 2), c_hidden[0], kernel_size=1, dtype=dtype, device=device),
LayerNorm2d_op(operations)(c_hidden[0], elementwise_affine=False, eps=1e-6)
)
def get_block(block_type, c_hidden, nhead, c_skip=0, dropout=0, self_attn=True):
if block_type == 'C':
return ResBlock(c_hidden, c_skip, kernel_size=kernel_size, dropout=dropout, dtype=dtype, device=device, operations=operations)
elif block_type == 'A':
return AttnBlock(c_hidden, c_cond, nhead, self_attn=self_attn, dropout=dropout, dtype=dtype, device=device, operations=operations)
elif block_type == 'F':
return FeedForwardBlock(c_hidden, dropout=dropout, dtype=dtype, device=device, operations=operations)
elif block_type == 'T':
return TimestepBlock(c_hidden, c_r, conds=t_conds, dtype=dtype, device=device, operations=operations)
else:
raise Exception(f'Block type {block_type} not supported')
# BLOCKS
# -- down blocks
self.down_blocks = nn.ModuleList()
self.down_downscalers = nn.ModuleList()
self.down_repeat_mappers = nn.ModuleList()
for i in range(len(c_hidden)):
if i > 0:
self.down_downscalers.append(nn.Sequential(
LayerNorm2d_op(operations)(c_hidden[i - 1], elementwise_affine=False, eps=1e-6),
UpDownBlock2d(c_hidden[i - 1], c_hidden[i], mode='down', enabled=switch_level[i - 1], dtype=dtype, device=device, operations=operations)
))
else:
self.down_downscalers.append(nn.Identity())
down_block = nn.ModuleList()
for _ in range(blocks[0][i]):
for block_type in level_config[i]:
block = get_block(block_type, c_hidden[i], nhead[i], dropout=dropout[i], self_attn=self_attn[i])
down_block.append(block)
self.down_blocks.append(down_block)
if block_repeat is not None:
block_repeat_mappers = nn.ModuleList()
for _ in range(block_repeat[0][i] - 1):
block_repeat_mappers.append(operations.Conv2d(c_hidden[i], c_hidden[i], kernel_size=1, dtype=dtype, device=device))
self.down_repeat_mappers.append(block_repeat_mappers)
# -- up blocks
self.up_blocks = nn.ModuleList()
self.up_upscalers = nn.ModuleList()
self.up_repeat_mappers = nn.ModuleList()
for i in reversed(range(len(c_hidden))):
if i > 0:
self.up_upscalers.append(nn.Sequential(
LayerNorm2d_op(operations)(c_hidden[i], elementwise_affine=False, eps=1e-6),
UpDownBlock2d(c_hidden[i], c_hidden[i - 1], mode='up', enabled=switch_level[i - 1], dtype=dtype, device=device, operations=operations)
))
else:
self.up_upscalers.append(nn.Identity())
up_block = nn.ModuleList()
for j in range(blocks[1][::-1][i]):
for k, block_type in enumerate(level_config[i]):
c_skip = c_hidden[i] if i < len(c_hidden) - 1 and j == k == 0 else 0
block = get_block(block_type, c_hidden[i], nhead[i], c_skip=c_skip, dropout=dropout[i],
self_attn=self_attn[i])
up_block.append(block)
self.up_blocks.append(up_block)
if block_repeat is not None:
block_repeat_mappers = nn.ModuleList()
for _ in range(block_repeat[1][::-1][i] - 1):
block_repeat_mappers.append(operations.Conv2d(c_hidden[i], c_hidden[i], kernel_size=1, dtype=dtype, device=device))
self.up_repeat_mappers.append(block_repeat_mappers)
# OUTPUT
self.clf = nn.Sequential(
LayerNorm2d_op(operations)(c_hidden[0], elementwise_affine=False, eps=1e-6, dtype=dtype, device=device),
operations.Conv2d(c_hidden[0], c_out * (patch_size ** 2), kernel_size=1, dtype=dtype, device=device),
nn.PixelShuffle(patch_size),
)
# --- WEIGHT INIT ---
# self.apply(self._init_weights) # General init
# nn.init.normal_(self.clip_txt_mapper.weight, std=0.02) # conditionings
# nn.init.normal_(self.clip_txt_pooled_mapper.weight, std=0.02) # conditionings
# nn.init.normal_(self.clip_img_mapper.weight, std=0.02) # conditionings
# torch.nn.init.xavier_uniform_(self.embedding[1].weight, 0.02) # inputs
# nn.init.constant_(self.clf[1].weight, 0) # outputs
#
# # blocks
# for level_block in self.down_blocks + self.up_blocks:
# for block in level_block:
# if isinstance(block, ResBlock) or isinstance(block, FeedForwardBlock):
# block.channelwise[-1].weight.data *= np.sqrt(1 / sum(blocks[0]))
# elif isinstance(block, TimestepBlock):
# for layer in block.modules():
# if isinstance(layer, nn.Linear):
# nn.init.constant_(layer.weight, 0)
#
# def _init_weights(self, m):
# if isinstance(m, (nn.Conv2d, nn.Linear)):
# torch.nn.init.xavier_uniform_(m.weight)
# if m.bias is not None:
# nn.init.constant_(m.bias, 0)
def gen_r_embedding(self, r, max_positions=10000):
r = r * max_positions
half_dim = self.c_r // 2
emb = math.log(max_positions) / (half_dim - 1)
emb = torch.arange(half_dim, device=r.device).float().mul(-emb).exp()
emb = r[:, None] * emb[None, :]
emb = torch.cat([emb.sin(), emb.cos()], dim=1)
if self.c_r % 2 == 1: # zero pad
emb = nn.functional.pad(emb, (0, 1), mode='constant')
return emb
def gen_c_embeddings(self, clip_txt, clip_txt_pooled, clip_img):
clip_txt = self.clip_txt_mapper(clip_txt)
if len(clip_txt_pooled.shape) == 2:
clip_txt_pooled = clip_txt_pooled.unsqueeze(1)
if len(clip_img.shape) == 2:
clip_img = clip_img.unsqueeze(1)
clip_txt_pool = self.clip_txt_pooled_mapper(clip_txt_pooled).view(clip_txt_pooled.size(0), clip_txt_pooled.size(1) * self.c_clip_seq, -1)
clip_img = self.clip_img_mapper(clip_img).view(clip_img.size(0), clip_img.size(1) * self.c_clip_seq, -1)
clip = torch.cat([clip_txt, clip_txt_pool, clip_img], dim=1)
clip = self.clip_norm(clip)
return clip
def _down_encode(self, x, r_embed, clip, cnet=None):
level_outputs = []
block_group = zip(self.down_blocks, self.down_downscalers, self.down_repeat_mappers)
for down_block, downscaler, repmap in block_group:
x = downscaler(x)
for i in range(len(repmap) + 1):
for block in down_block:
if isinstance(block, ResBlock) or (
hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module,
ResBlock)):
if cnet is not None:
next_cnet = cnet.pop()
if next_cnet is not None:
x = x + nn.functional.interpolate(next_cnet, size=x.shape[-2:], mode='bilinear',
align_corners=True).to(x.dtype)
x = block(x)
elif isinstance(block, AttnBlock) or (
hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module,
AttnBlock)):
x = block(x, clip)
elif isinstance(block, TimestepBlock) or (
hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module,
TimestepBlock)):
x = block(x, r_embed)
else:
x = block(x)
if i < len(repmap):
x = repmap[i](x)
level_outputs.insert(0, x)
return level_outputs
def _up_decode(self, level_outputs, r_embed, clip, cnet=None):
x = level_outputs[0]
block_group = zip(self.up_blocks, self.up_upscalers, self.up_repeat_mappers)
for i, (up_block, upscaler, repmap) in enumerate(block_group):
for j in range(len(repmap) + 1):
for k, block in enumerate(up_block):
if isinstance(block, ResBlock) or (
hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module,
ResBlock)):
skip = level_outputs[i] if k == 0 and i > 0 else None
if skip is not None and (x.size(-1) != skip.size(-1) or x.size(-2) != skip.size(-2)):
x = torch.nn.functional.interpolate(x, skip.shape[-2:], mode='bilinear',
align_corners=True)
if cnet is not None:
next_cnet = cnet.pop()
if next_cnet is not None:
x = x + nn.functional.interpolate(next_cnet, size=x.shape[-2:], mode='bilinear',
align_corners=True).to(x.dtype)
x = block(x, skip)
elif isinstance(block, AttnBlock) or (
hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module,
AttnBlock)):
x = block(x, clip)
elif isinstance(block, TimestepBlock) or (
hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module,
TimestepBlock)):
x = block(x, r_embed)
else:
x = block(x)
if j < len(repmap):
x = repmap[j](x)
x = upscaler(x)
return x
def forward(self, x, r, clip_text, clip_text_pooled, clip_img, control=None, **kwargs):
# Process the conditioning embeddings
r_embed = self.gen_r_embedding(r).to(dtype=x.dtype)
for c in self.t_conds:
t_cond = kwargs.get(c, torch.zeros_like(r))
r_embed = torch.cat([r_embed, self.gen_r_embedding(t_cond).to(dtype=x.dtype)], dim=1)
clip = self.gen_c_embeddings(clip_text, clip_text_pooled, clip_img)
if control is not None:
cnet = control.get("input")
else:
cnet = None
# Model Blocks
x = self.embedding(x)
level_outputs = self._down_encode(x, r_embed, clip, cnet)
x = self._up_decode(level_outputs, r_embed, clip, cnet)
return self.clf(x)
def update_weights_ema(self, src_model, beta=0.999):
for self_params, src_params in zip(self.parameters(), src_model.parameters()):
self_params.data = self_params.data * beta + src_params.data.clone().to(self_params.device) * (1 - beta)
for self_buffers, src_buffers in zip(self.buffers(), src_model.buffers()):
self_buffers.data = self_buffers.data * beta + src_buffers.data.clone().to(self_buffers.device) * (1 - beta)

95
comfy/ldm/cascade/stage_c_coder.py

@ -0,0 +1,95 @@
"""
This file is part of ComfyUI.
Copyright (C) 2024 Stability AI
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU General Public License for more details.
You should have received a copy of the GNU General Public License
along with this program. If not, see <https://www.gnu.org/licenses/>.
"""
import torch
import torchvision
from torch import nn
# EfficientNet
class EfficientNetEncoder(nn.Module):
def __init__(self, c_latent=16):
super().__init__()
self.backbone = torchvision.models.efficientnet_v2_s().features.eval()
self.mapper = nn.Sequential(
nn.Conv2d(1280, c_latent, kernel_size=1, bias=False),
nn.BatchNorm2d(c_latent, affine=False), # then normalize them to have mean 0 and std 1
)
self.mean = nn.Parameter(torch.tensor([0.485, 0.456, 0.406]))
self.std = nn.Parameter(torch.tensor([0.229, 0.224, 0.225]))
def forward(self, x):
x = x * 0.5 + 0.5
x = (x - self.mean.view([3,1,1])) / self.std.view([3,1,1])
o = self.mapper(self.backbone(x))
return o
# Fast Decoder for Stage C latents. E.g. 16 x 24 x 24 -> 3 x 192 x 192
class Previewer(nn.Module):
def __init__(self, c_in=16, c_hidden=512, c_out=3):
super().__init__()
self.blocks = nn.Sequential(
nn.Conv2d(c_in, c_hidden, kernel_size=1), # 16 channels to 512 channels
nn.GELU(),
nn.BatchNorm2d(c_hidden),
nn.Conv2d(c_hidden, c_hidden, kernel_size=3, padding=1),
nn.GELU(),
nn.BatchNorm2d(c_hidden),
nn.ConvTranspose2d(c_hidden, c_hidden // 2, kernel_size=2, stride=2), # 16 -> 32
nn.GELU(),
nn.BatchNorm2d(c_hidden // 2),
nn.Conv2d(c_hidden // 2, c_hidden // 2, kernel_size=3, padding=1),
nn.GELU(),
nn.BatchNorm2d(c_hidden // 2),
nn.ConvTranspose2d(c_hidden // 2, c_hidden // 4, kernel_size=2, stride=2), # 32 -> 64
nn.GELU(),
nn.BatchNorm2d(c_hidden // 4),
nn.Conv2d(c_hidden // 4, c_hidden // 4, kernel_size=3, padding=1),
nn.GELU(),
nn.BatchNorm2d(c_hidden // 4),
nn.ConvTranspose2d(c_hidden // 4, c_hidden // 4, kernel_size=2, stride=2), # 64 -> 128
nn.GELU(),
nn.BatchNorm2d(c_hidden // 4),
nn.Conv2d(c_hidden // 4, c_hidden // 4, kernel_size=3, padding=1),
nn.GELU(),
nn.BatchNorm2d(c_hidden // 4),
nn.Conv2d(c_hidden // 4, c_out, kernel_size=1),
)
def forward(self, x):
return (self.blocks(x) - 0.5) * 2.0
class StageC_coder(nn.Module):
def __init__(self):
super().__init__()
self.previewer = Previewer()
self.encoder = EfficientNetEncoder()
def encode(self, x):
return self.encoder(x)
def decode(self, x):
return self.previewer(x)

5
comfy/ldm/models/autoencoder.py

@ -8,6 +8,7 @@ from comfy.ldm.modules.distributions.distributions import DiagonalGaussianDistri
from comfy.ldm.util import instantiate_from_config
from comfy.ldm.modules.ema import LitEma
import comfy.ops
class DiagonalGaussianRegularizer(torch.nn.Module):
def __init__(self, sample: bool = True):
@ -161,12 +162,12 @@ class AutoencodingEngineLegacy(AutoencodingEngine):
},
**kwargs,
)
self.quant_conv = torch.nn.Conv2d(
self.quant_conv = comfy.ops.disable_weight_init.Conv2d(
(1 + ddconfig["double_z"]) * ddconfig["z_channels"],
(1 + ddconfig["double_z"]) * embed_dim,
1,
)
self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
self.post_quant_conv = comfy.ops.disable_weight_init.Conv2d(embed_dim, ddconfig["z_channels"], 1)
self.embed_dim = embed_dim
def get_autoencoder_params(self) -> list:

123
comfy/ldm/modules/attention.py

@ -1,12 +1,10 @@
from inspect import isfunction
import math
import torch
import torch.nn.functional as F
from torch import nn, einsum
from einops import rearrange, repeat
from typing import Optional, Any
from functools import partial
import logging
from .diffusionmodules.util import checkpoint, AlphaBlender, timestep_embedding
from .sub_quadratic_attention import efficient_dot_product_attention
@ -19,10 +17,11 @@ if model_management.xformers_enabled():
from comfy.cli_args import args
import comfy.ops
ops = comfy.ops.disable_weight_init
# CrossAttn precision handling
if args.dont_upcast_attention:
print("disabling upcasting of attention")
logging.info("disabling upcasting of attention")
_ATTN_PRECISION = "fp16"
else:
_ATTN_PRECISION = "fp32"
@ -55,7 +54,7 @@ def init_(tensor):
# feedforward
class GEGLU(nn.Module):
def __init__(self, dim_in, dim_out, dtype=None, device=None, operations=comfy.ops):
def __init__(self, dim_in, dim_out, dtype=None, device=None, operations=ops):
super().__init__()
self.proj = operations.Linear(dim_in, dim_out * 2, dtype=dtype, device=device)
@ -65,7 +64,7 @@ class GEGLU(nn.Module):
class FeedForward(nn.Module):
def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0., dtype=None, device=None, operations=comfy.ops):
def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0., dtype=None, device=None, operations=ops):
super().__init__()
inner_dim = int(dim * mult)
dim_out = default(dim_out, dim)
@ -83,16 +82,6 @@ class FeedForward(nn.Module):
def forward(self, x):
return self.net(x)
def zero_module(module):
"""
Zero out the parameters of a module and return it.
"""
for p in module.parameters():
p.detach().zero_()
return module
def Normalize(in_channels, dtype=None, device=None):
return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True, dtype=dtype, device=device)
@ -113,19 +102,25 @@ def attention_basic(q, k, v, heads, mask=None):
# force cast to fp32 to avoid overflowing
if _ATTN_PRECISION =="fp32":
with torch.autocast(enabled=False, device_type = 'cuda'):
q, k = q.float(), k.float()
sim = einsum('b i d, b j d -> b i j', q, k) * scale
sim = einsum('b i d, b j d -> b i j', q.float(), k.float()) * scale
else:
sim = einsum('b i d, b j d -> b i j', q, k) * scale
del q, k
if exists(mask):
mask = rearrange(mask, 'b ... -> b (...)')
max_neg_value = -torch.finfo(sim.dtype).max
mask = repeat(mask, 'b j -> (b h) () j', h=h)
sim.masked_fill_(~mask, max_neg_value)
if mask.dtype == torch.bool:
mask = rearrange(mask, 'b ... -> b (...)') #TODO: check if this bool part matches pytorch attention
max_neg_value = -torch.finfo(sim.dtype).max
mask = repeat(mask, 'b j -> (b h) () j', h=h)
sim.masked_fill_(~mask, max_neg_value)
else:
if len(mask.shape) == 2:
bs = 1
else:
bs = mask.shape[0]
mask = mask.reshape(bs, -1, mask.shape[-2], mask.shape[-1]).expand(b, heads, -1, -1).reshape(-1, mask.shape[-2], mask.shape[-1])
sim.add_(mask)
# attention, what we cannot get enough of
sim = sim.softmax(dim=-1)
@ -176,6 +171,13 @@ def attention_sub_quad(query, key, value, heads, mask=None):
if query_chunk_size is None:
query_chunk_size = 512
if mask is not None:
if len(mask.shape) == 2:
bs = 1
else:
bs = mask.shape[0]
mask = mask.reshape(bs, -1, mask.shape[-2], mask.shape[-1]).expand(b, heads, -1, -1).reshape(-1, mask.shape[-2], mask.shape[-1])
hidden_states = efficient_dot_product_attention(
query,
key,
@ -185,6 +187,7 @@ def attention_sub_quad(query, key, value, heads, mask=None):
kv_chunk_size_min=kv_chunk_size_min,
use_checkpoint=False,
upcast_attention=upcast_attention,
mask=mask,
)
hidden_states = hidden_states.to(dtype)
@ -233,6 +236,13 @@ def attention_split(q, k, v, heads, mask=None):
raise RuntimeError(f'Not enough memory, use lower resolution (max approx. {max_res}x{max_res}). '
f'Need: {mem_required/64/gb:0.1f}GB free, Have:{mem_free_total/gb:0.1f}GB free')
if mask is not None:
if len(mask.shape) == 2:
bs = 1
else:
bs = mask.shape[0]
mask = mask.reshape(bs, -1, mask.shape[-2], mask.shape[-1]).expand(b, heads, -1, -1).reshape(-1, mask.shape[-2], mask.shape[-1])
# print("steps", steps, mem_required, mem_free_total, modifier, q.element_size(), tensor_size)
first_op_done = False
cleared_cache = False
@ -247,6 +257,12 @@ def attention_split(q, k, v, heads, mask=None):
else:
s1 = einsum('b i d, b j d -> b i j', q[:, i:end], k) * scale
if mask is not None:
if len(mask.shape) == 2:
s1 += mask[i:end]
else:
s1 += mask[:, i:end]
s2 = s1.softmax(dim=-1).to(v.dtype)
del s1
first_op_done = True
@ -259,12 +275,12 @@ def attention_split(q, k, v, heads, mask=None):
model_management.soft_empty_cache(True)
if cleared_cache == False:
cleared_cache = True
print("out of memory error, emptying cache and trying again")
logging.warning("out of memory error, emptying cache and trying again")
continue
steps *= 2
if steps > 64:
raise e
print("out of memory error, increasing steps and trying again", steps)
logging.warning("out of memory error, increasing steps and trying again {}".format(steps))
else:
raise e
@ -302,11 +318,14 @@ def attention_xformers(q, k, v, heads, mask=None):
(q, k, v),
)
# actually compute the attention, what we cannot get enough of
out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None)
if mask is not None:
pad = 8 - q.shape[1] % 8
mask_out = torch.empty([q.shape[0], q.shape[1], q.shape[1] + pad], dtype=q.dtype, device=q.device)
mask_out[:, :, :mask.shape[-1]] = mask
mask = mask_out[:, :, :mask.shape[-1]]
out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=mask)
if exists(mask):
raise NotImplementedError
out = (
out.unsqueeze(0)
.reshape(b, heads, -1, dim_head)
@ -331,27 +350,41 @@ def attention_pytorch(q, k, v, heads, mask=None):
optimized_attention = attention_basic
optimized_attention_masked = attention_basic
if model_management.xformers_enabled():
print("Using xformers cross attention")
logging.info("Using xformers cross attention")
optimized_attention = attention_xformers
elif model_management.pytorch_attention_enabled():
print("Using pytorch cross attention")
logging.info("Using pytorch cross attention")
optimized_attention = attention_pytorch
else:
if args.use_split_cross_attention:
print("Using split optimization for cross attention")
logging.info("Using split optimization for cross attention")
optimized_attention = attention_split
else:
print("Using sub quadratic optimization for cross attention, if you have memory or speed issues try using: --use-split-cross-attention")
logging.info("Using sub quadratic optimization for cross attention, if you have memory or speed issues try using: --use-split-cross-attention")
optimized_attention = attention_sub_quad
if model_management.pytorch_attention_enabled():
optimized_attention_masked = attention_pytorch
optimized_attention_masked = optimized_attention
def optimized_attention_for_device(device, mask=False, small_input=False):
if small_input:
if model_management.pytorch_attention_enabled():
return attention_pytorch #TODO: need to confirm but this is probably slightly faster for small inputs in all cases
else:
return attention_basic
if device == torch.device("cpu"):
return attention_sub_quad
if mask:
return optimized_attention_masked
return optimized_attention
class CrossAttention(nn.Module):
def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., dtype=None, device=None, operations=comfy.ops):
def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., dtype=None, device=None, operations=ops):
super().__init__()
inner_dim = dim_head * heads
context_dim = default(context_dim, query_dim)
@ -384,7 +417,7 @@ class CrossAttention(nn.Module):
class BasicTransformerBlock(nn.Module):
def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True, ff_in=False, inner_dim=None,
disable_self_attn=False, disable_temporal_crossattention=False, switch_temporal_ca_to_sa=False, dtype=None, device=None, operations=comfy.ops):
disable_self_attn=False, disable_temporal_crossattention=False, switch_temporal_ca_to_sa=False, dtype=None, device=None, operations=ops):
super().__init__()
self.ff_in = ff_in or inner_dim is not None
@ -394,7 +427,7 @@ class BasicTransformerBlock(nn.Module):
self.is_res = inner_dim == dim
if self.ff_in:
self.norm_in = nn.LayerNorm(dim, dtype=dtype, device=device)
self.norm_in = operations.LayerNorm(dim, dtype=dtype, device=device)
self.ff_in = FeedForward(dim, dim_out=inner_dim, dropout=dropout, glu=gated_ff, dtype=dtype, device=device, operations=operations)
self.disable_self_attn = disable_self_attn
@ -414,10 +447,10 @@ class BasicTransformerBlock(nn.Module):
self.attn2 = CrossAttention(query_dim=inner_dim, context_dim=context_dim_attn2,
heads=n_heads, dim_head=d_head, dropout=dropout, dtype=dtype, device=device, operations=operations) # is self-attn if context is none
self.norm2 = nn.LayerNorm(inner_dim, dtype=dtype, device=device)
self.norm2 = operations.LayerNorm(inner_dim, dtype=dtype, device=device)
self.norm1 = nn.LayerNorm(inner_dim, dtype=dtype, device=device)
self.norm3 = nn.LayerNorm(inner_dim, dtype=dtype, device=device)
self.norm1 = operations.LayerNorm(inner_dim, dtype=dtype, device=device)
self.norm3 = operations.LayerNorm(inner_dim, dtype=dtype, device=device)
self.checkpoint = checkpoint
self.n_heads = n_heads
self.d_head = d_head
@ -553,13 +586,13 @@ class SpatialTransformer(nn.Module):
def __init__(self, in_channels, n_heads, d_head,
depth=1, dropout=0., context_dim=None,
disable_self_attn=False, use_linear=False,
use_checkpoint=True, dtype=None, device=None, operations=comfy.ops):
use_checkpoint=True, dtype=None, device=None, operations=ops):
super().__init__()
if exists(context_dim) and not isinstance(context_dim, list):
context_dim = [context_dim] * depth
self.in_channels = in_channels
inner_dim = n_heads * d_head
self.norm = Normalize(in_channels, dtype=dtype, device=device)
self.norm = operations.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True, dtype=dtype, device=device)
if not use_linear:
self.proj_in = operations.Conv2d(in_channels,
inner_dim,
@ -627,7 +660,7 @@ class SpatialVideoTransformer(SpatialTransformer):
disable_self_attn=False,
disable_temporal_crossattention=False,
max_time_embed_period: int = 10000,
dtype=None, device=None, operations=comfy.ops
dtype=None, device=None, operations=ops
):
super().__init__(
in_channels,

54
comfy/ldm/modules/diffusionmodules/model.py

@ -5,9 +5,11 @@ import torch.nn as nn
import numpy as np
from einops import rearrange
from typing import Optional, Any
import logging
from comfy import model_management
import comfy.ops
ops = comfy.ops.disable_weight_init
if model_management.xformers_enabled_vae():
import xformers
@ -40,7 +42,7 @@ def nonlinearity(x):
def Normalize(in_channels, num_groups=32):
return torch.nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True)
return ops.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True)
class Upsample(nn.Module):
@ -48,7 +50,7 @@ class Upsample(nn.Module):
super().__init__()
self.with_conv = with_conv
if self.with_conv:
self.conv = comfy.ops.Conv2d(in_channels,
self.conv = ops.Conv2d(in_channels,
in_channels,
kernel_size=3,
stride=1,
@ -78,7 +80,7 @@ class Downsample(nn.Module):
self.with_conv = with_conv
if self.with_conv:
# no asymmetric padding in torch conv, must do it ourselves
self.conv = comfy.ops.Conv2d(in_channels,
self.conv = ops.Conv2d(in_channels,
in_channels,
kernel_size=3,
stride=2,
@ -105,30 +107,30 @@ class ResnetBlock(nn.Module):
self.swish = torch.nn.SiLU(inplace=True)
self.norm1 = Normalize(in_channels)
self.conv1 = comfy.ops.Conv2d(in_channels,
self.conv1 = ops.Conv2d(in_channels,
out_channels,
kernel_size=3,
stride=1,
padding=1)
if temb_channels > 0:
self.temb_proj = comfy.ops.Linear(temb_channels,
self.temb_proj = ops.Linear(temb_channels,
out_channels)
self.norm2 = Normalize(out_channels)
self.dropout = torch.nn.Dropout(dropout, inplace=True)
self.conv2 = comfy.ops.Conv2d(out_channels,
self.conv2 = ops.Conv2d(out_channels,
out_channels,
kernel_size=3,
stride=1,
padding=1)
if self.in_channels != self.out_channels:
if self.use_conv_shortcut:
self.conv_shortcut = comfy.ops.Conv2d(in_channels,
self.conv_shortcut = ops.Conv2d(in_channels,
out_channels,
kernel_size=3,
stride=1,
padding=1)
else:
self.nin_shortcut = comfy.ops.Conv2d(in_channels,
self.nin_shortcut = ops.Conv2d(in_channels,
out_channels,
kernel_size=1,
stride=1,
@ -189,7 +191,7 @@ def slice_attention(q, k, v):
steps *= 2
if steps > 128:
raise e
print("out of memory error, increasing steps and trying again", steps)
logging.warning("out of memory error, increasing steps and trying again {}".format(steps))
return r1
@ -234,7 +236,7 @@ def pytorch_attention(q, k, v):
out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False)
out = out.transpose(2, 3).reshape(B, C, H, W)
except model_management.OOM_EXCEPTION as e:
print("scaled_dot_product_attention OOMed: switched to slice attention")
logging.warning("scaled_dot_product_attention OOMed: switched to slice attention")
out = slice_attention(q.view(B, -1, C), k.view(B, -1, C).transpose(1, 2), v.view(B, -1, C).transpose(1, 2)).reshape(B, C, H, W)
return out
@ -245,35 +247,35 @@ class AttnBlock(nn.Module):
self.in_channels = in_channels
self.norm = Normalize(in_channels)
self.q = comfy.ops.Conv2d(in_channels,
self.q = ops.Conv2d(in_channels,
in_channels,
kernel_size=1,
stride=1,
padding=0)
self.k = comfy.ops.Conv2d(in_channels,
self.k = ops.Conv2d(in_channels,
in_channels,
kernel_size=1,
stride=1,
padding=0)
self.v = comfy.ops.Conv2d(in_channels,
self.v = ops.Conv2d(in_channels,
in_channels,
kernel_size=1,
stride=1,
padding=0)
self.proj_out = comfy.ops.Conv2d(in_channels,
self.proj_out = ops.Conv2d(in_channels,
in_channels,
kernel_size=1,
stride=1,
padding=0)
if model_management.xformers_enabled_vae():
print("Using xformers attention in VAE")
logging.info("Using xformers attention in VAE")
self.optimized_attention = xformers_attention
elif model_management.pytorch_attention_enabled():
print("Using pytorch attention in VAE")
logging.info("Using pytorch attention in VAE")
self.optimized_attention = pytorch_attention
else:
print("Using split attention in VAE")
logging.info("Using split attention in VAE")
self.optimized_attention = normal_attention
def forward(self, x):
@ -312,14 +314,14 @@ class Model(nn.Module):
# timestep embedding
self.temb = nn.Module()
self.temb.dense = nn.ModuleList([
comfy.ops.Linear(self.ch,
ops.Linear(self.ch,
self.temb_ch),
comfy.ops.Linear(self.temb_ch,
ops.Linear(self.temb_ch,
self.temb_ch),
])
# downsampling
self.conv_in = comfy.ops.Conv2d(in_channels,
self.conv_in = ops.Conv2d(in_channels,
self.ch,
kernel_size=3,
stride=1,
@ -388,7 +390,7 @@ class Model(nn.Module):
# end
self.norm_out = Normalize(block_in)
self.conv_out = comfy.ops.Conv2d(block_in,
self.conv_out = ops.Conv2d(block_in,
out_ch,
kernel_size=3,
stride=1,
@ -461,7 +463,7 @@ class Encoder(nn.Module):
self.in_channels = in_channels
# downsampling
self.conv_in = comfy.ops.Conv2d(in_channels,
self.conv_in = ops.Conv2d(in_channels,
self.ch,
kernel_size=3,
stride=1,
@ -506,7 +508,7 @@ class Encoder(nn.Module):
# end
self.norm_out = Normalize(block_in)
self.conv_out = comfy.ops.Conv2d(block_in,
self.conv_out = ops.Conv2d(block_in,
2*z_channels if double_z else z_channels,
kernel_size=3,
stride=1,
@ -541,7 +543,7 @@ class Decoder(nn.Module):
def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
resolution, z_channels, give_pre_end=False, tanh_out=False, use_linear_attn=False,
conv_out_op=comfy.ops.Conv2d,
conv_out_op=ops.Conv2d,
resnet_op=ResnetBlock,
attn_op=AttnBlock,
**ignorekwargs):
@ -561,11 +563,11 @@ class Decoder(nn.Module):
block_in = ch*ch_mult[self.num_resolutions-1]
curr_res = resolution // 2**(self.num_resolutions-1)
self.z_shape = (1,z_channels,curr_res,curr_res)
print("Working with z of shape {} = {} dimensions.".format(
logging.debug("Working with z of shape {} = {} dimensions.".format(
self.z_shape, np.prod(self.z_shape)))
# z to block_in
self.conv_in = comfy.ops.Conv2d(z_channels,
self.conv_in = ops.Conv2d(z_channels,
block_in,
kernel_size=3,
stride=1,

96
comfy/ldm/modules/diffusionmodules/openaimodel.py

@ -1,24 +1,22 @@
from abc import abstractmethod
import math
import numpy as np
import torch as th
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from functools import partial
import logging
from .util import (
checkpoint,
avg_pool_nd,
zero_module,
normalization,
timestep_embedding,
AlphaBlender,
)
from ..attention import SpatialTransformer, SpatialVideoTransformer, default
from comfy.ldm.util import exists
import comfy.ops
ops = comfy.ops.disable_weight_init
class TimestepBlock(nn.Module):
"""
@ -70,7 +68,7 @@ class Upsample(nn.Module):
upsampling occurs in the inner-two dimensions.
"""
def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1, dtype=None, device=None, operations=comfy.ops):
def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1, dtype=None, device=None, operations=ops):
super().__init__()
self.channels = channels
self.out_channels = out_channels or channels
@ -106,7 +104,7 @@ class Downsample(nn.Module):
downsampling occurs in the inner-two dimensions.
"""
def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1, dtype=None, device=None, operations=comfy.ops):
def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1, dtype=None, device=None, operations=ops):
super().__init__()
self.channels = channels
self.out_channels = out_channels or channels
@ -159,7 +157,7 @@ class ResBlock(TimestepBlock):
skip_t_emb=False,
dtype=None,
device=None,
operations=comfy.ops
operations=ops
):
super().__init__()
self.channels = channels
@ -177,7 +175,7 @@ class ResBlock(TimestepBlock):
padding = kernel_size // 2
self.in_layers = nn.Sequential(
nn.GroupNorm(32, channels, dtype=dtype, device=device),
operations.GroupNorm(32, channels, dtype=dtype, device=device),
nn.SiLU(),
operations.conv_nd(dims, channels, self.out_channels, kernel_size, padding=padding, dtype=dtype, device=device),
)
@ -206,12 +204,11 @@ class ResBlock(TimestepBlock):
),
)
self.out_layers = nn.Sequential(
nn.GroupNorm(32, self.out_channels, dtype=dtype, device=device),
operations.GroupNorm(32, self.out_channels, dtype=dtype, device=device),
nn.SiLU(),
nn.Dropout(p=dropout),
zero_module(
operations.conv_nd(dims, self.out_channels, self.out_channels, kernel_size, padding=padding, dtype=dtype, device=device)
),
operations.conv_nd(dims, self.out_channels, self.out_channels, kernel_size, padding=padding, dtype=dtype, device=device)
,
)
if self.out_channels == channels:
@ -285,7 +282,7 @@ class VideoResBlock(ResBlock):
down: bool = False,
dtype=None,
device=None,
operations=comfy.ops
operations=ops
):
super().__init__(
channels,
@ -363,7 +360,7 @@ def apply_control(h, control, name):
try:
h += ctrl
except:
print("warning control could not be applied", h.shape, ctrl.shape)
logging.warning("warning control could not be applied {} {}".format(h.shape, ctrl.shape))
return h
class UNetModel(nn.Module):
@ -435,12 +432,9 @@ class UNetModel(nn.Module):
disable_temporal_crossattention=False,
max_ddpm_temb_period=10000,
device=None,
operations=comfy.ops,
operations=ops,
):
super().__init__()
assert use_spatial_transformer == True, "use_spatial_transformer has to be true"
if use_spatial_transformer:
assert context_dim is not None, 'Fool!! You forgot to include the dimension of your cross-attention conditioning...'
if context_dim is not None:
assert use_spatial_transformer, 'Fool!! You forgot to use the spatial transformer for your cross-attention conditioning...'
@ -457,7 +451,6 @@ class UNetModel(nn.Module):
if num_head_channels == -1:
assert num_heads != -1, 'Either num_heads or num_head_channels has to be set'
self.image_size = image_size
self.in_channels = in_channels
self.model_channels = model_channels
self.out_channels = out_channels
@ -492,7 +485,6 @@ class UNetModel(nn.Module):
self.predict_codebook_ids = n_embed is not None
self.default_num_video_frames = None
self.default_image_only_indicator = None
time_embed_dim = model_channels * 4
self.time_embed = nn.Sequential(
@ -503,9 +495,9 @@ class UNetModel(nn.Module):
if self.num_classes is not None:
if isinstance(self.num_classes, int):
self.label_emb = nn.Embedding(num_classes, time_embed_dim)
self.label_emb = nn.Embedding(num_classes, time_embed_dim, dtype=self.dtype, device=device)
elif self.num_classes == "continuous":
print("setting up linear c_adm embedding layer")
logging.debug("setting up linear c_adm embedding layer")
self.label_emb = nn.Linear(1, time_embed_dim)
elif self.num_classes == "sequential":
assert adm_in_channels is not None
@ -582,7 +574,7 @@ class UNetModel(nn.Module):
up=False,
dtype=None,
device=None,
operations=comfy.ops
operations=ops
):
if self.use_temporal_resblocks:
return VideoResBlock(
@ -716,27 +708,30 @@ class UNetModel(nn.Module):
device=device,
operations=operations
)]
if transformer_depth_middle >= 0:
mid_block += [get_attention_layer( # always uses a self-attn
ch, num_heads, dim_head, depth=transformer_depth_middle, context_dim=context_dim,
disable_self_attn=disable_middle_self_attn, use_checkpoint=use_checkpoint
),
get_resblock(
merge_factor=merge_factor,
merge_strategy=merge_strategy,
video_kernel_size=video_kernel_size,
ch=ch,
time_embed_dim=time_embed_dim,
dropout=dropout,
out_channels=None,
dims=dims,
use_checkpoint=use_checkpoint,
use_scale_shift_norm=use_scale_shift_norm,
dtype=self.dtype,
device=device,
operations=operations
)]
self.middle_block = TimestepEmbedSequential(*mid_block)
self.middle_block = None
if transformer_depth_middle >= -1:
if transformer_depth_middle >= 0:
mid_block += [get_attention_layer( # always uses a self-attn
ch, num_heads, dim_head, depth=transformer_depth_middle, context_dim=context_dim,
disable_self_attn=disable_middle_self_attn, use_checkpoint=use_checkpoint
),
get_resblock(
merge_factor=merge_factor,
merge_strategy=merge_strategy,
video_kernel_size=video_kernel_size,
ch=ch,
time_embed_dim=time_embed_dim,
dropout=dropout,
out_channels=None,
dims=dims,
use_checkpoint=use_checkpoint,
use_scale_shift_norm=use_scale_shift_norm,
dtype=self.dtype,
device=device,
operations=operations
)]
self.middle_block = TimestepEmbedSequential(*mid_block)
self._feature_size += ch
self.output_blocks = nn.ModuleList([])
@ -810,13 +805,13 @@ class UNetModel(nn.Module):
self._feature_size += ch
self.out = nn.Sequential(
nn.GroupNorm(32, ch, dtype=self.dtype, device=device),
operations.GroupNorm(32, ch, dtype=self.dtype, device=device),
nn.SiLU(),
zero_module(operations.conv_nd(dims, model_channels, out_channels, 3, padding=1, dtype=self.dtype, device=device)),
)
if self.predict_codebook_ids:
self.id_predictor = nn.Sequential(
nn.GroupNorm(32, ch, dtype=self.dtype, device=device),
operations.GroupNorm(32, ch, dtype=self.dtype, device=device),
operations.conv_nd(dims, model_channels, n_embed, 1, dtype=self.dtype, device=device),
#nn.LogSoftmax(dim=1) # change to cross_entropy and produce non-normalized logits
)
@ -835,21 +830,21 @@ class UNetModel(nn.Module):
transformer_patches = transformer_options.get("patches", {})
num_video_frames = kwargs.get("num_video_frames", self.default_num_video_frames)
image_only_indicator = kwargs.get("image_only_indicator", self.default_image_only_indicator)
image_only_indicator = kwargs.get("image_only_indicator", None)
time_context = kwargs.get("time_context", None)
assert (y is not None) == (
self.num_classes is not None
), "must specify y if and only if the model is class-conditional"
hs = []
t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False).to(self.dtype)
t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False).to(x.dtype)
emb = self.time_embed(t_emb)
if self.num_classes is not None:
assert y.shape[0] == x.shape[0]
emb = emb + self.label_emb(y)
h = x.type(self.dtype)
h = x
for id, module in enumerate(self.input_blocks):
transformer_options["block"] = ("input", id)
h = forward_timestep_embed(module, h, emb, context, transformer_options, time_context=time_context, num_video_frames=num_video_frames, image_only_indicator=image_only_indicator)
@ -866,7 +861,8 @@ class UNetModel(nn.Module):
h = p(h, transformer_options)
transformer_options["block"] = ("middle", 0)
h = forward_timestep_embed(self.middle_block, h, emb, context, transformer_options, time_context=time_context, num_video_frames=num_video_frames, image_only_indicator=image_only_indicator)
if self.middle_block is not None:
h = forward_timestep_embed(self.middle_block, h, emb, context, transformer_options, time_context=time_context, num_video_frames=num_video_frames, image_only_indicator=image_only_indicator)
h = apply_control(h, control, 'middle')

16
comfy/ldm/modules/diffusionmodules/upscaling.py

@ -41,10 +41,14 @@ class AbstractLowScaleModel(nn.Module):
self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod)))
self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod - 1)))
def q_sample(self, x_start, t, noise=None):
noise = default(noise, lambda: torch.randn_like(x_start))
return (extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start +
extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise)
def q_sample(self, x_start, t, noise=None, seed=None):
if noise is None:
if seed is None:
noise = torch.randn_like(x_start)
else:
noise = torch.randn(x_start.size(), dtype=x_start.dtype, layout=x_start.layout, generator=torch.manual_seed(seed)).to(x_start.device)
return (extract_into_tensor(self.sqrt_alphas_cumprod.to(x_start.device), t, x_start.shape) * x_start +
extract_into_tensor(self.sqrt_one_minus_alphas_cumprod.to(x_start.device), t, x_start.shape) * noise)
def forward(self, x):
return x, None
@ -69,12 +73,12 @@ class ImageConcatWithNoiseAugmentation(AbstractLowScaleModel):
super().__init__(noise_schedule_config=noise_schedule_config)
self.max_noise_level = max_noise_level
def forward(self, x, noise_level=None):
def forward(self, x, noise_level=None, seed=None):
if noise_level is None:
noise_level = torch.randint(0, self.max_noise_level, (x.shape[0],), device=x.device).long()
else:
assert isinstance(noise_level, torch.Tensor)
z = self.q_sample(x, noise_level)
z = self.q_sample(x, noise_level, seed=seed)
return z, noise_level

67
comfy/ldm/modules/diffusionmodules/util.py

@ -16,7 +16,6 @@ import numpy as np
from einops import repeat, rearrange
from comfy.ldm.util import instantiate_from_config
import comfy.ops
class AlphaBlender(nn.Module):
strategies = ["learned", "fixed", "learned_with_images"]
@ -47,23 +46,25 @@ class AlphaBlender(nn.Module):
else:
raise ValueError(f"unknown merge strategy {self.merge_strategy}")
def get_alpha(self, image_only_indicator: torch.Tensor) -> torch.Tensor:
def get_alpha(self, image_only_indicator: torch.Tensor, device) -> torch.Tensor:
# skip_time_mix = rearrange(repeat(skip_time_mix, 'b -> (b t) () () ()', t=t), '(b t) 1 ... -> b 1 t ...', t=t)
if self.merge_strategy == "fixed":
# make shape compatible
# alpha = repeat(self.mix_factor, '1 -> b () t () ()', t=t, b=bs)
alpha = self.mix_factor
alpha = self.mix_factor.to(device)
elif self.merge_strategy == "learned":
alpha = torch.sigmoid(self.mix_factor)
alpha = torch.sigmoid(self.mix_factor.to(device))
# make shape compatible
# alpha = repeat(alpha, '1 -> s () ()', s = t * bs)
elif self.merge_strategy == "learned_with_images":
assert image_only_indicator is not None, "need image_only_indicator ..."
alpha = torch.where(
image_only_indicator.bool(),
torch.ones(1, 1, device=image_only_indicator.device),
rearrange(torch.sigmoid(self.mix_factor), "... -> ... 1"),
)
if image_only_indicator is None:
alpha = rearrange(torch.sigmoid(self.mix_factor.to(device)), "... -> ... 1")
else:
alpha = torch.where(
image_only_indicator.bool(),
torch.ones(1, 1, device=image_only_indicator.device),
rearrange(torch.sigmoid(self.mix_factor.to(image_only_indicator.device)), "... -> ... 1"),
)
alpha = rearrange(alpha, self.rearrange_pattern)
# make shape compatible
# alpha = repeat(alpha, '1 -> s () ()', s = t * bs)
@ -77,7 +78,7 @@ class AlphaBlender(nn.Module):
x_temporal,
image_only_indicator=None,
) -> torch.Tensor:
alpha = self.get_alpha(image_only_indicator)
alpha = self.get_alpha(image_only_indicator, x_spatial.device)
x = (
alpha.to(x_spatial.dtype) * x_spatial
+ (1.0 - alpha).to(x_spatial.dtype) * x_temporal
@ -99,7 +100,7 @@ def make_beta_schedule(schedule, n_timestep, linear_start=1e-4, linear_end=2e-2,
alphas = torch.cos(alphas).pow(2)
alphas = alphas / alphas[0]
betas = 1 - alphas[1:] / alphas[:-1]
betas = np.clip(betas, a_min=0, a_max=0.999)
betas = torch.clamp(betas, min=0, max=0.999)
elif schedule == "squaredcos_cap_v2": # used for karlo prior
# return early
@ -114,7 +115,7 @@ def make_beta_schedule(schedule, n_timestep, linear_start=1e-4, linear_end=2e-2,
betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) ** 0.5
else:
raise ValueError(f"schedule '{schedule}' unknown.")
return betas.numpy()
return betas
def make_ddim_timesteps(ddim_discr_method, num_ddim_timesteps, num_ddpm_timesteps, verbose=True):
@ -273,46 +274,6 @@ def mean_flat(tensor):
return tensor.mean(dim=list(range(1, len(tensor.shape))))
def normalization(channels, dtype=None):
"""
Make a standard normalization layer.
:param channels: number of input channels.
:return: an nn.Module for normalization.
"""
return GroupNorm32(32, channels, dtype=dtype)
# PyTorch 1.7 has SiLU, but we support PyTorch 1.5.
class SiLU(nn.Module):
def forward(self, x):
return x * torch.sigmoid(x)
class GroupNorm32(nn.GroupNorm):
def forward(self, x):
return super().forward(x.float()).type(x.dtype)
def conv_nd(dims, *args, **kwargs):
"""
Create a 1D, 2D, or 3D convolution module.
"""
if dims == 1:
return nn.Conv1d(*args, **kwargs)
elif dims == 2:
return comfy.ops.Conv2d(*args, **kwargs)
elif dims == 3:
return nn.Conv3d(*args, **kwargs)
raise ValueError(f"unsupported dimensions: {dims}")
def linear(*args, **kwargs):
"""
Create a linear module.
"""
return comfy.ops.Linear(*args, **kwargs)
def avg_pool_nd(dims, *args, **kwargs):
"""
Create a 1D, 2D, or 3D average pooling module.

8
comfy/ldm/modules/encoders/noise_aug_modules.py

@ -15,21 +15,21 @@ class CLIPEmbeddingNoiseAugmentation(ImageConcatWithNoiseAugmentation):
def scale(self, x):
# re-normalize to centered mean and unit variance
x = (x - self.data_mean) * 1. / self.data_std
x = (x - self.data_mean.to(x.device)) * 1. / self.data_std.to(x.device)
return x
def unscale(self, x):
# back to original data stats
x = (x * self.data_std) + self.data_mean
x = (x * self.data_std.to(x.device)) + self.data_mean.to(x.device)
return x
def forward(self, x, noise_level=None):
def forward(self, x, noise_level=None, seed=None):
if noise_level is None:
noise_level = torch.randint(0, self.max_noise_level, (x.shape[0],), device=x.device).long()
else:
assert isinstance(noise_level, torch.Tensor)
x = self.scale(x)
z = self.q_sample(x, noise_level)
z = self.q_sample(x, noise_level, seed=seed)
z = self.unscale(z)
noise_level = self.time_embed(noise_level)
return z, noise_level

31
comfy/ldm/modules/sub_quadratic_attention.py

@ -14,6 +14,7 @@ import torch
from torch import Tensor
from torch.utils.checkpoint import checkpoint
import math
import logging
try:
from typing import Optional, NamedTuple, List, Protocol
@ -61,6 +62,7 @@ def _summarize_chunk(
value: Tensor,
scale: float,
upcast_attention: bool,
mask,
) -> AttnChunk:
if upcast_attention:
with torch.autocast(enabled=False, device_type = 'cuda'):
@ -84,6 +86,8 @@ def _summarize_chunk(
max_score, _ = torch.max(attn_weights, -1, keepdim=True)
max_score = max_score.detach()
attn_weights -= max_score
if mask is not None:
attn_weights += mask
torch.exp(attn_weights, out=attn_weights)
exp_weights = attn_weights.to(value.dtype)
exp_values = torch.bmm(exp_weights, value)
@ -96,11 +100,12 @@ def _query_chunk_attention(
value: Tensor,
summarize_chunk: SummarizeChunk,
kv_chunk_size: int,
mask,
) -> Tensor:
batch_x_heads, k_channels_per_head, k_tokens = key_t.shape
_, _, v_channels_per_head = value.shape
def chunk_scanner(chunk_idx: int) -> AttnChunk:
def chunk_scanner(chunk_idx: int, mask) -> AttnChunk:
key_chunk = dynamic_slice(
key_t,
(0, 0, chunk_idx),
@ -111,10 +116,13 @@ def _query_chunk_attention(
(0, chunk_idx, 0),
(batch_x_heads, kv_chunk_size, v_channels_per_head)
)
return summarize_chunk(query, key_chunk, value_chunk)
if mask is not None:
mask = mask[:,:,chunk_idx:chunk_idx + kv_chunk_size]
return summarize_chunk(query, key_chunk, value_chunk, mask=mask)
chunks: List[AttnChunk] = [
chunk_scanner(chunk) for chunk in torch.arange(0, k_tokens, kv_chunk_size)
chunk_scanner(chunk, mask) for chunk in torch.arange(0, k_tokens, kv_chunk_size)
]
acc_chunk = AttnChunk(*map(torch.stack, zip(*chunks)))
chunk_values, chunk_weights, chunk_max = acc_chunk
@ -135,6 +143,7 @@ def _get_attention_scores_no_kv_chunking(
value: Tensor,
scale: float,
upcast_attention: bool,
mask,
) -> Tensor:
if upcast_attention:
with torch.autocast(enabled=False, device_type = 'cuda'):
@ -156,11 +165,13 @@ def _get_attention_scores_no_kv_chunking(
beta=0,
)
if mask is not None:
attn_scores += mask
try:
attn_probs = attn_scores.softmax(dim=-1)
del attn_scores
except model_management.OOM_EXCEPTION:
print("ran out of memory while running softmax in _get_attention_scores_no_kv_chunking, trying slower in place softmax instead")
logging.warning("ran out of memory while running softmax in _get_attention_scores_no_kv_chunking, trying slower in place softmax instead")
attn_scores -= attn_scores.max(dim=-1, keepdim=True).values
torch.exp(attn_scores, out=attn_scores)
summed = torch.sum(attn_scores, dim=-1, keepdim=True)
@ -183,6 +194,7 @@ def efficient_dot_product_attention(
kv_chunk_size_min: Optional[int] = None,
use_checkpoint=True,
upcast_attention=False,
mask = None,
):
"""Computes efficient dot-product attention given query, transposed key, and value.
This is efficient version of attention presented in
@ -209,6 +221,9 @@ def efficient_dot_product_attention(
if kv_chunk_size_min is not None:
kv_chunk_size = max(kv_chunk_size, kv_chunk_size_min)
if mask is not None and len(mask.shape) == 2:
mask = mask.unsqueeze(0)
def get_query_chunk(chunk_idx: int) -> Tensor:
return dynamic_slice(
query,
@ -216,6 +231,12 @@ def efficient_dot_product_attention(
(batch_x_heads, min(query_chunk_size, q_tokens), q_channels_per_head)
)
def get_mask_chunk(chunk_idx: int) -> Tensor:
if mask is None:
return None
chunk = min(query_chunk_size, q_tokens)
return mask[:,chunk_idx:chunk_idx + chunk]
summarize_chunk: SummarizeChunk = partial(_summarize_chunk, scale=scale, upcast_attention=upcast_attention)
summarize_chunk: SummarizeChunk = partial(checkpoint, summarize_chunk) if use_checkpoint else summarize_chunk
compute_query_chunk_attn: ComputeQueryChunkAttn = partial(
@ -237,6 +258,7 @@ def efficient_dot_product_attention(
query=query,
key_t=key_t,
value=value,
mask=mask,
)
# TODO: maybe we should use torch.empty_like(query) to allocate storage in-advance,
@ -246,6 +268,7 @@ def efficient_dot_product_attention(
query=get_query_chunk(i * query_chunk_size),
key_t=key_t,
value=value,
mask=get_mask_chunk(i * query_chunk_size)
) for i in range(math.ceil(q_tokens / query_chunk_size))
], dim=1)
return res

13
comfy/ldm/modules/temporal_ae.py

@ -5,6 +5,7 @@ import torch
from einops import rearrange, repeat
import comfy.ops
ops = comfy.ops.disable_weight_init
from .diffusionmodules.model import (
AttnBlock,
@ -81,14 +82,14 @@ class VideoResBlock(ResnetBlock):
x = self.time_stack(x, temb)
alpha = self.get_alpha(bs=b // timesteps)
alpha = self.get_alpha(bs=b // timesteps).to(x.device)
x = alpha * x + (1.0 - alpha) * x_mix
x = rearrange(x, "b c t h w -> (b t) c h w")
return x
class AE3DConv(torch.nn.Conv2d):
class AE3DConv(ops.Conv2d):
def __init__(self, in_channels, out_channels, video_kernel_size=3, *args, **kwargs):
super().__init__(in_channels, out_channels, *args, **kwargs)
if isinstance(video_kernel_size, Iterable):
@ -96,7 +97,7 @@ class AE3DConv(torch.nn.Conv2d):
else:
padding = int(video_kernel_size // 2)
self.time_mix_conv = torch.nn.Conv3d(
self.time_mix_conv = ops.Conv3d(
in_channels=out_channels,
out_channels=out_channels,
kernel_size=video_kernel_size,
@ -130,9 +131,9 @@ class AttnVideoBlock(AttnBlock):
time_embed_dim = self.in_channels * 4
self.video_time_embed = torch.nn.Sequential(
comfy.ops.Linear(self.in_channels, time_embed_dim),
ops.Linear(self.in_channels, time_embed_dim),
torch.nn.SiLU(),
comfy.ops.Linear(time_embed_dim, self.in_channels),
ops.Linear(time_embed_dim, self.in_channels),
)
self.merge_strategy = merge_strategy
@ -166,7 +167,7 @@ class AttnVideoBlock(AttnBlock):
emb = emb[:, None, :]
x_mix = x_mix + emb
alpha = self.get_alpha()
alpha = self.get_alpha().to(x.device)
x_mix = self.time_mix_block(x_mix, timesteps=timesteps)
x = alpha * x + (1.0 - alpha) * x_mix # alpha merge

46
comfy/lora.py

@ -1,4 +1,5 @@
import comfy.utils
import logging
LORA_CLIP_MAP = {
"mlp.fc1": "mlp_fc1",
@ -20,6 +21,12 @@ def load_lora(lora, to_load):
alpha = lora[alpha_name].item()
loaded_keys.add(alpha_name)
dora_scale_name = "{}.dora_scale".format(x)
dora_scale = None
if dora_scale_name in lora.keys():
dora_scale = lora[dora_scale_name]
loaded_keys.add(dora_scale_name)
regular_lora = "{}.lora_up.weight".format(x)
diffusers_lora = "{}_lora.up.weight".format(x)
transformers_lora = "{}.lora_linear_layer.up.weight".format(x)
@ -43,7 +50,7 @@ def load_lora(lora, to_load):
if mid_name is not None and mid_name in lora.keys():
mid = lora[mid_name]
loaded_keys.add(mid_name)
patch_dict[to_load[x]] = (lora[A_name], lora[B_name], alpha, mid)
patch_dict[to_load[x]] = ("lora", (lora[A_name], lora[B_name], alpha, mid, dora_scale))
loaded_keys.add(A_name)
loaded_keys.add(B_name)
@ -64,7 +71,7 @@ def load_lora(lora, to_load):
loaded_keys.add(hada_t1_name)
loaded_keys.add(hada_t2_name)
patch_dict[to_load[x]] = (lora[hada_w1_a_name], lora[hada_w1_b_name], alpha, lora[hada_w2_a_name], lora[hada_w2_b_name], hada_t1, hada_t2)
patch_dict[to_load[x]] = ("loha", (lora[hada_w1_a_name], lora[hada_w1_b_name], alpha, lora[hada_w2_a_name], lora[hada_w2_b_name], hada_t1, hada_t2, dora_scale))
loaded_keys.add(hada_w1_a_name)
loaded_keys.add(hada_w1_b_name)
loaded_keys.add(hada_w2_a_name)
@ -116,8 +123,19 @@ def load_lora(lora, to_load):
loaded_keys.add(lokr_t2_name)
if (lokr_w1 is not None) or (lokr_w2 is not None) or (lokr_w1_a is not None) or (lokr_w2_a is not None):
patch_dict[to_load[x]] = (lokr_w1, lokr_w2, alpha, lokr_w1_a, lokr_w1_b, lokr_w2_a, lokr_w2_b, lokr_t2)
patch_dict[to_load[x]] = ("lokr", (lokr_w1, lokr_w2, alpha, lokr_w1_a, lokr_w1_b, lokr_w2_a, lokr_w2_b, lokr_t2, dora_scale))
#glora
a1_name = "{}.a1.weight".format(x)
a2_name = "{}.a2.weight".format(x)
b1_name = "{}.b1.weight".format(x)
b2_name = "{}.b2.weight".format(x)
if a1_name in lora:
patch_dict[to_load[x]] = ("glora", (lora[a1_name], lora[a2_name], lora[b1_name], lora[b2_name], alpha, dora_scale))
loaded_keys.add(a1_name)
loaded_keys.add(a2_name)
loaded_keys.add(b1_name)
loaded_keys.add(b2_name)
w_norm_name = "{}.w_norm".format(x)
b_norm_name = "{}.b_norm".format(x)
@ -126,26 +144,26 @@ def load_lora(lora, to_load):
if w_norm is not None:
loaded_keys.add(w_norm_name)
patch_dict[to_load[x]] = (w_norm,)
patch_dict[to_load[x]] = ("diff", (w_norm,))
if b_norm is not None:
loaded_keys.add(b_norm_name)
patch_dict["{}.bias".format(to_load[x][:-len(".weight")])] = (b_norm,)
patch_dict["{}.bias".format(to_load[x][:-len(".weight")])] = ("diff", (b_norm,))
diff_name = "{}.diff".format(x)
diff_weight = lora.get(diff_name, None)
if diff_weight is not None:
patch_dict[to_load[x]] = (diff_weight,)
patch_dict[to_load[x]] = ("diff", (diff_weight,))
loaded_keys.add(diff_name)
diff_bias_name = "{}.diff_b".format(x)
diff_bias = lora.get(diff_bias_name, None)
if diff_bias is not None:
patch_dict["{}.bias".format(to_load[x][:-len(".weight")])] = (diff_bias,)
patch_dict["{}.bias".format(to_load[x][:-len(".weight")])] = ("diff", (diff_bias,))
loaded_keys.add(diff_bias_name)
for x in lora.keys():
if x not in loaded_keys:
print("lora key not loaded", x)
logging.warning("lora key not loaded: {}".format(x))
return patch_dict
def model_lora_keys_clip(model, key_map={}):
@ -186,6 +204,15 @@ def model_lora_keys_clip(model, key_map={}):
key_map[lora_key] = k
lora_key = "text_encoder.text_model.encoder.layers.{}.{}".format(b, c) #diffusers lora
key_map[lora_key] = k
lora_key = "lora_prior_te_text_model_encoder_layers_{}_{}".format(b, LORA_CLIP_MAP[c]) #cascade lora: TODO put lora key prefix in the model config
key_map[lora_key] = k
k = "clip_g.transformer.text_projection.weight"
if k in sdk:
key_map["lora_prior_te_text_projection"] = k #cascade lora?
# key_map["text_encoder.text_projection"] = k #TODO: check if other lora have the text_projection too
# key_map["lora_te_text_projection"] = k
return key_map
@ -196,6 +223,7 @@ def model_lora_keys_unet(model, key_map={}):
if k.startswith("diffusion_model.") and k.endswith(".weight"):
key_lora = k[len("diffusion_model."):-len(".weight")].replace(".", "_")
key_map["lora_unet_{}".format(key_lora)] = k
key_map["lora_prior_unet_{}".format(key_lora)] = k #cascade lora: TODO put lora key prefix in the model config
diffusers_keys = comfy.utils.unet_to_diffusers(model.model_config.unet_config)
for k in diffusers_keys:

271
comfy/model_base.py

@ -1,9 +1,13 @@
import torch
from comfy.ldm.modules.diffusionmodules.openaimodel import UNetModel
import logging
from comfy.ldm.modules.diffusionmodules.openaimodel import UNetModel, Timestep
from comfy.ldm.cascade.stage_c import StageC
from comfy.ldm.cascade.stage_b import StageB
from comfy.ldm.modules.encoders.noise_aug_modules import CLIPEmbeddingNoiseAugmentation
from comfy.ldm.modules.diffusionmodules.openaimodel import Timestep
from comfy.ldm.modules.diffusionmodules.upscaling import ImageConcatWithNoiseAugmentation
import comfy.model_management
import comfy.conds
import comfy.ops
from enum import Enum
from . import utils
@ -11,9 +15,11 @@ class ModelType(Enum):
EPS = 1
V_PREDICTION = 2
V_PREDICTION_EDM = 3
STABLE_CASCADE = 4
EDM = 5
from comfy.model_sampling import EPS, V_PREDICTION, ModelSamplingDiscrete, ModelSamplingContinuousEDM
from comfy.model_sampling import EPS, V_PREDICTION, EDM, ModelSamplingDiscrete, ModelSamplingContinuousEDM, StableCascadeSampling
def model_sampling(model_config, model_type):
@ -26,6 +32,12 @@ def model_sampling(model_config, model_type):
elif model_type == ModelType.V_PREDICTION_EDM:
c = V_PREDICTION
s = ModelSamplingContinuousEDM
elif model_type == ModelType.STABLE_CASCADE:
c = EPS
s = StableCascadeSampling
elif model_type == ModelType.EDM:
c = EDM
s = ModelSamplingContinuousEDM
class ModelSampling(s, c):
pass
@ -34,15 +46,20 @@ def model_sampling(model_config, model_type):
class BaseModel(torch.nn.Module):
def __init__(self, model_config, model_type=ModelType.EPS, device=None):
def __init__(self, model_config, model_type=ModelType.EPS, device=None, unet_model=UNetModel):
super().__init__()
unet_config = model_config.unet_config
self.latent_format = model_config.latent_format
self.model_config = model_config
self.manual_cast_dtype = model_config.manual_cast_dtype
if not unet_config.get("disable_unet_model_creation", False):
self.diffusion_model = UNetModel(**unet_config, device=device)
if self.manual_cast_dtype is not None:
operations = comfy.ops.manual_cast
else:
operations = comfy.ops.disable_weight_init
self.diffusion_model = unet_model(**unet_config, device=device, operations=operations)
self.model_type = model_type
self.model_sampling = model_sampling(model_config, model_type)
@ -50,8 +67,8 @@ class BaseModel(torch.nn.Module):
if self.adm_channels is None:
self.adm_channels = 0
self.inpaint_model = False
print("model_type", model_type.name)
print("adm", self.adm_channels)
logging.info("model_type {}".format(model_type.name))
logging.debug("adm {}".format(self.adm_channels))
def apply_model(self, x, t, c_concat=None, c_crossattn=None, control=None, transformer_options={}, **kwargs):
sigma = t
@ -61,15 +78,21 @@ class BaseModel(torch.nn.Module):
context = c_crossattn
dtype = self.get_dtype()
if self.manual_cast_dtype is not None:
dtype = self.manual_cast_dtype
xc = xc.to(dtype)
t = self.model_sampling.timestep(t).float()
context = context.to(dtype)
extra_conds = {}
for o in kwargs:
extra = kwargs[o]
if hasattr(extra, "to"):
extra = extra.to(dtype)
if hasattr(extra, "dtype"):
if extra.dtype != torch.int and extra.dtype != torch.long:
extra = extra.to(dtype)
extra_conds[o] = extra
model_output = self.diffusion_model(xc, t, context=context, control=control, transformer_options=transformer_options, **extra_conds).float()
return self.model_sampling.calculate_denoised(sigma, model_output, x)
@ -87,11 +110,29 @@ class BaseModel(torch.nn.Module):
if self.inpaint_model:
concat_keys = ("mask", "masked_image")
cond_concat = []
denoise_mask = kwargs.get("denoise_mask", None)
latent_image = kwargs.get("latent_image", None)
denoise_mask = kwargs.get("concat_mask", kwargs.get("denoise_mask", None))
concat_latent_image = kwargs.get("concat_latent_image", None)
if concat_latent_image is None:
concat_latent_image = kwargs.get("latent_image", None)
else:
concat_latent_image = self.process_latent_in(concat_latent_image)
noise = kwargs.get("noise", None)
device = kwargs["device"]
if concat_latent_image.shape[1:] != noise.shape[1:]:
concat_latent_image = utils.common_upscale(concat_latent_image, noise.shape[-1], noise.shape[-2], "bilinear", "center")
concat_latent_image = utils.resize_to_batch_size(concat_latent_image, noise.shape[0])
if len(denoise_mask.shape) == len(noise.shape):
denoise_mask = denoise_mask[:,:1]
denoise_mask = denoise_mask.reshape((-1, 1, denoise_mask.shape[-2], denoise_mask.shape[-1]))
if denoise_mask.shape[-2:] != noise.shape[-2:]:
denoise_mask = utils.common_upscale(denoise_mask, noise.shape[-1], noise.shape[-2], "bilinear", "center")
denoise_mask = utils.resize_to_batch_size(denoise_mask.round(), noise.shape[0])
def blank_inpaint_image_like(latent_image):
blank_image = torch.ones_like(latent_image)
# these are the values for "zero" in pixel space translated to latent space
@ -104,9 +145,9 @@ class BaseModel(torch.nn.Module):
for ck in concat_keys:
if denoise_mask is not None:
if ck == "mask":
cond_concat.append(denoise_mask[:,:1].to(device))
cond_concat.append(denoise_mask.to(device))
elif ck == "masked_image":
cond_concat.append(latent_image.to(device)) #NOTE: the latent_image should be masked by the mask in pixel space
cond_concat.append(concat_latent_image.to(device)) #NOTE: the latent_image should be masked by the mask in pixel space
else:
if ck == "mask":
cond_concat.append(torch.ones_like(noise)[:,:1])
@ -114,9 +155,23 @@ class BaseModel(torch.nn.Module):
cond_concat.append(blank_inpaint_image_like(noise))
data = torch.cat(cond_concat, dim=1)
out['c_concat'] = comfy.conds.CONDNoiseShape(data)
adm = self.encode_adm(**kwargs)
if adm is not None:
out['y'] = comfy.conds.CONDRegular(adm)
cross_attn = kwargs.get("cross_attn", None)
if cross_attn is not None:
out['c_crossattn'] = comfy.conds.CONDCrossAttn(cross_attn)
cross_attn_cnet = kwargs.get("cross_attn_controlnet", None)
if cross_attn_cnet is not None:
out['crossattn_controlnet'] = comfy.conds.CONDCrossAttn(cross_attn_cnet)
c_concat = kwargs.get("noise_concat", None)
if c_concat is not None:
out['c_concat'] = comfy.conds.CONDNoiseShape(data)
return out
def load_model_weights(self, sd, unet_prefix=""):
@ -129,10 +184,10 @@ class BaseModel(torch.nn.Module):
to_load = self.model_config.process_unet_state_dict(to_load)
m, u = self.diffusion_model.load_state_dict(to_load, strict=False)
if len(m) > 0:
print("unet missing:", m)
logging.warning("unet missing: {}".format(m))
if len(u) > 0:
print("unet unexpected:", u)
logging.warning("unet unexpected: {}".format(u))
del to_load
return self
@ -142,39 +197,47 @@ class BaseModel(torch.nn.Module):
def process_latent_out(self, latent):
return self.latent_format.process_out(latent)
def state_dict_for_saving(self, clip_state_dict, vae_state_dict):
clip_state_dict = self.model_config.process_clip_state_dict_for_saving(clip_state_dict)
unet_sd = self.diffusion_model.state_dict()
unet_state_dict = {}
for k in unet_sd:
unet_state_dict[k] = comfy.model_management.resolve_lowvram_weight(unet_sd[k], self.diffusion_model, k)
def state_dict_for_saving(self, clip_state_dict=None, vae_state_dict=None, clip_vision_state_dict=None):
extra_sds = []
if clip_state_dict is not None:
extra_sds.append(self.model_config.process_clip_state_dict_for_saving(clip_state_dict))
if vae_state_dict is not None:
extra_sds.append(self.model_config.process_vae_state_dict_for_saving(vae_state_dict))
if clip_vision_state_dict is not None:
extra_sds.append(self.model_config.process_clip_vision_state_dict_for_saving(clip_vision_state_dict))
unet_state_dict = self.diffusion_model.state_dict()
unet_state_dict = self.model_config.process_unet_state_dict_for_saving(unet_state_dict)
vae_state_dict = self.model_config.process_vae_state_dict_for_saving(vae_state_dict)
if self.get_dtype() == torch.float16:
clip_state_dict = utils.convert_sd_to(clip_state_dict, torch.float16)
vae_state_dict = utils.convert_sd_to(vae_state_dict, torch.float16)
extra_sds = map(lambda sd: utils.convert_sd_to(sd, torch.float16), extra_sds)
if self.model_type == ModelType.V_PREDICTION:
unet_state_dict["v_pred"] = torch.tensor([])
return {**unet_state_dict, **vae_state_dict, **clip_state_dict}
for sd in extra_sds:
unet_state_dict.update(sd)
return unet_state_dict
def set_inpaint(self):
self.inpaint_model = True
def memory_required(self, input_shape):
if comfy.model_management.xformers_enabled() or comfy.model_management.pytorch_attention_flash_attention():
dtype = self.get_dtype()
if self.manual_cast_dtype is not None:
dtype = self.manual_cast_dtype
#TODO: this needs to be tweaked
area = input_shape[0] * input_shape[2] * input_shape[3]
return (area * comfy.model_management.dtype_size(self.get_dtype()) / 50) * (1024 * 1024)
return (area * comfy.model_management.dtype_size(dtype) / 50) * (1024 * 1024)
else:
#TODO: this formula might be too aggressive since I tweaked the sub-quad and split algorithms to use less memory.
area = input_shape[0] * input_shape[2] * input_shape[3]
return (((area * 0.6) / 0.9) + 1024) * (1024 * 1024)
def unclip_adm(unclip_conditioning, device, noise_augmentor, noise_augment_merge=0.0):
def unclip_adm(unclip_conditioning, device, noise_augmentor, noise_augment_merge=0.0, seed=None):
adm_inputs = []
weights = []
noise_aug = []
@ -183,7 +246,7 @@ def unclip_adm(unclip_conditioning, device, noise_augmentor, noise_augment_merge
weight = unclip_cond["strength"]
noise_augment = unclip_cond["noise_augmentation"]
noise_level = round((noise_augmentor.max_noise_level - 1) * noise_augment)
c_adm, noise_level_emb = noise_augmentor(adm_cond.to(device), noise_level=torch.tensor([noise_level], device=device))
c_adm, noise_level_emb = noise_augmentor(adm_cond.to(device), noise_level=torch.tensor([noise_level], device=device), seed=seed)
adm_out = torch.cat((c_adm, noise_level_emb), 1) * weight
weights.append(weight)
noise_aug.append(noise_augment)
@ -209,11 +272,11 @@ class SD21UNCLIP(BaseModel):
if unclip_conditioning is None:
return torch.zeros((1, self.adm_channels))
else:
return unclip_adm(unclip_conditioning, device, self.noise_augmentor, kwargs.get("unclip_noise_augment_merge", 0.05))
return unclip_adm(unclip_conditioning, device, self.noise_augmentor, kwargs.get("unclip_noise_augment_merge", 0.05), kwargs.get("seed", 0) - 10)
def sdxl_pooled(args, noise_augmentor):
if "unclip_conditioning" in args:
return unclip_adm(args.get("unclip_conditioning", None), args["device"], noise_augmentor)[:,:1280]
return unclip_adm(args.get("unclip_conditioning", None), args["device"], noise_augmentor, seed=args.get("seed", 0) - 10)[:,:1280]
else:
return args["pooled_output"]
@ -303,13 +366,157 @@ class SVD_img2vid(BaseModel):
if latent_image.shape[1:] != noise.shape[1:]:
latent_image = utils.common_upscale(latent_image, noise.shape[-1], noise.shape[-2], "bilinear", "center")
latent_image = utils.repeat_to_batch_size(latent_image, noise.shape[0])
latent_image = utils.resize_to_batch_size(latent_image, noise.shape[0])
out['c_concat'] = comfy.conds.CONDNoiseShape(latent_image)
cross_attn = kwargs.get("cross_attn", None)
if cross_attn is not None:
out['c_crossattn'] = comfy.conds.CONDCrossAttn(cross_attn)
if "time_conditioning" in kwargs:
out["time_context"] = comfy.conds.CONDCrossAttn(kwargs["time_conditioning"])
out['image_only_indicator'] = comfy.conds.CONDConstant(torch.zeros((1,), device=device))
out['num_video_frames'] = comfy.conds.CONDConstant(noise.shape[0])
return out
class SV3D_u(SVD_img2vid):
def encode_adm(self, **kwargs):
augmentation = kwargs.get("augmentation_level", 0)
out = []
out.append(self.embedder(torch.flatten(torch.Tensor([augmentation]))))
flat = torch.flatten(torch.cat(out)).unsqueeze(dim=0)
return flat
class SV3D_p(SVD_img2vid):
def __init__(self, model_config, model_type=ModelType.V_PREDICTION_EDM, device=None):
super().__init__(model_config, model_type, device=device)
self.embedder_512 = Timestep(512)
def encode_adm(self, **kwargs):
augmentation = kwargs.get("augmentation_level", 0)
elevation = kwargs.get("elevation", 0) #elevation and azimuth are in degrees here
azimuth = kwargs.get("azimuth", 0)
noise = kwargs.get("noise", None)
out = []
out.append(self.embedder(torch.flatten(torch.Tensor([augmentation]))))
out.append(self.embedder_512(torch.deg2rad(torch.fmod(torch.flatten(90 - torch.Tensor([elevation])), 360.0))))
out.append(self.embedder_512(torch.deg2rad(torch.fmod(torch.flatten(torch.Tensor([azimuth])), 360.0))))
out = list(map(lambda a: utils.resize_to_batch_size(a, noise.shape[0]), out))
return torch.cat(out, dim=1)
class Stable_Zero123(BaseModel):
def __init__(self, model_config, model_type=ModelType.EPS, device=None, cc_projection_weight=None, cc_projection_bias=None):
super().__init__(model_config, model_type, device=device)
self.cc_projection = comfy.ops.manual_cast.Linear(cc_projection_weight.shape[1], cc_projection_weight.shape[0], dtype=self.get_dtype(), device=device)
self.cc_projection.weight.copy_(cc_projection_weight)
self.cc_projection.bias.copy_(cc_projection_bias)
def extra_conds(self, **kwargs):
out = {}
latent_image = kwargs.get("concat_latent_image", None)
noise = kwargs.get("noise", None)
if latent_image is None:
latent_image = torch.zeros_like(noise)
if latent_image.shape[1:] != noise.shape[1:]:
latent_image = utils.common_upscale(latent_image, noise.shape[-1], noise.shape[-2], "bilinear", "center")
latent_image = utils.resize_to_batch_size(latent_image, noise.shape[0])
out['c_concat'] = comfy.conds.CONDNoiseShape(latent_image)
cross_attn = kwargs.get("cross_attn", None)
if cross_attn is not None:
if cross_attn.shape[-1] != 768:
cross_attn = self.cc_projection(cross_attn)
out['c_crossattn'] = comfy.conds.CONDCrossAttn(cross_attn)
return out
class SD_X4Upscaler(BaseModel):
def __init__(self, model_config, model_type=ModelType.V_PREDICTION, device=None):
super().__init__(model_config, model_type, device=device)
self.noise_augmentor = ImageConcatWithNoiseAugmentation(noise_schedule_config={"linear_start": 0.0001, "linear_end": 0.02}, max_noise_level=350)
def extra_conds(self, **kwargs):
out = {}
image = kwargs.get("concat_image", None)
noise = kwargs.get("noise", None)
noise_augment = kwargs.get("noise_augmentation", 0.0)
device = kwargs["device"]
seed = kwargs["seed"] - 10
noise_level = round((self.noise_augmentor.max_noise_level) * noise_augment)
if image is None:
image = torch.zeros_like(noise)[:,:3]
if image.shape[1:] != noise.shape[1:]:
image = utils.common_upscale(image.to(device), noise.shape[-1], noise.shape[-2], "bilinear", "center")
noise_level = torch.tensor([noise_level], device=device)
if noise_augment > 0:
image, noise_level = self.noise_augmentor(image.to(device), noise_level=noise_level, seed=seed)
image = utils.resize_to_batch_size(image, noise.shape[0])
out['c_concat'] = comfy.conds.CONDNoiseShape(image)
out['y'] = comfy.conds.CONDRegular(noise_level)
return out
class StableCascade_C(BaseModel):
def __init__(self, model_config, model_type=ModelType.STABLE_CASCADE, device=None):
super().__init__(model_config, model_type, device=device, unet_model=StageC)
self.diffusion_model.eval().requires_grad_(False)
def extra_conds(self, **kwargs):
out = {}
clip_text_pooled = kwargs["pooled_output"]
if clip_text_pooled is not None:
out['clip_text_pooled'] = comfy.conds.CONDRegular(clip_text_pooled)
if "unclip_conditioning" in kwargs:
embeds = []
for unclip_cond in kwargs["unclip_conditioning"]:
weight = unclip_cond["strength"]
embeds.append(unclip_cond["clip_vision_output"].image_embeds.unsqueeze(0) * weight)
clip_img = torch.cat(embeds, dim=1)
else:
clip_img = torch.zeros((1, 1, 768))
out["clip_img"] = comfy.conds.CONDRegular(clip_img)
out["sca"] = comfy.conds.CONDRegular(torch.zeros((1,)))
out["crp"] = comfy.conds.CONDRegular(torch.zeros((1,)))
cross_attn = kwargs.get("cross_attn", None)
if cross_attn is not None:
out['clip_text'] = comfy.conds.CONDCrossAttn(cross_attn)
return out
class StableCascade_B(BaseModel):
def __init__(self, model_config, model_type=ModelType.STABLE_CASCADE, device=None):
super().__init__(model_config, model_type, device=device, unet_model=StageB)
self.diffusion_model.eval().requires_grad_(False)
def extra_conds(self, **kwargs):
out = {}
noise = kwargs.get("noise", None)
clip_text_pooled = kwargs["pooled_output"]
if clip_text_pooled is not None:
out['clip'] = comfy.conds.CONDRegular(clip_text_pooled)
#size of prior doesn't really matter if zeros because it gets resized but I still want it to get batched
prior = kwargs.get("stable_cascade_prior", torch.zeros((1, 16, (noise.shape[2] * 4) // 42, (noise.shape[3] * 4) // 42), dtype=noise.dtype, layout=noise.layout, device=noise.device))
out["effnet"] = comfy.conds.CONDRegular(prior)
out["sca"] = comfy.conds.CONDRegular(torch.zeros((1,)))
return out

82
comfy/model_detection.py

@ -1,5 +1,6 @@
import comfy.supported_models
import comfy.supported_models_base
import logging
def count_blocks(state_dict_keys, prefix_string):
count = 0
@ -28,13 +29,41 @@ def calculate_transformer_depth(prefix, state_dict_keys, state_dict):
return last_transformer_depth, context_dim, use_linear_in_transformer, time_stack
return None
def detect_unet_config(state_dict, key_prefix, dtype):
def detect_unet_config(state_dict, key_prefix):
state_dict_keys = list(state_dict.keys())
if '{}clf.1.weight'.format(key_prefix) in state_dict_keys: #stable cascade
unet_config = {}
text_mapper_name = '{}clip_txt_mapper.weight'.format(key_prefix)
if text_mapper_name in state_dict_keys:
unet_config['stable_cascade_stage'] = 'c'
w = state_dict[text_mapper_name]
if w.shape[0] == 1536: #stage c lite
unet_config['c_cond'] = 1536
unet_config['c_hidden'] = [1536, 1536]
unet_config['nhead'] = [24, 24]
unet_config['blocks'] = [[4, 12], [12, 4]]
elif w.shape[0] == 2048: #stage c full
unet_config['c_cond'] = 2048
elif '{}clip_mapper.weight'.format(key_prefix) in state_dict_keys:
unet_config['stable_cascade_stage'] = 'b'
w = state_dict['{}down_blocks.1.0.channelwise.0.weight'.format(key_prefix)]
if w.shape[-1] == 640:
unet_config['c_hidden'] = [320, 640, 1280, 1280]
unet_config['nhead'] = [-1, -1, 20, 20]
unet_config['blocks'] = [[2, 6, 28, 6], [6, 28, 6, 2]]
unet_config['block_repeat'] = [[1, 1, 1, 1], [3, 3, 2, 2]]
elif w.shape[-1] == 576: #stage b lite
unet_config['c_hidden'] = [320, 576, 1152, 1152]
unet_config['nhead'] = [-1, 9, 18, 18]
unet_config['blocks'] = [[2, 4, 14, 4], [4, 14, 4, 2]]
unet_config['block_repeat'] = [[1, 1, 1, 1], [2, 2, 2, 2]]
return unet_config
unet_config = {
"use_checkpoint": False,
"image_size": 32,
"out_channels": 4,
"use_spatial_transformer": True,
"legacy": False
}
@ -46,10 +75,15 @@ def detect_unet_config(state_dict, key_prefix, dtype):
else:
unet_config["adm_in_channels"] = None
unet_config["dtype"] = dtype
model_channels = state_dict['{}input_blocks.0.0.weight'.format(key_prefix)].shape[0]
in_channels = state_dict['{}input_blocks.0.0.weight'.format(key_prefix)].shape[1]
out_key = '{}out.2.weight'.format(key_prefix)
if out_key in state_dict:
out_channels = state_dict[out_key].shape[0]
else:
out_channels = 4
num_res_blocks = []
channel_mult = []
attention_resolutions = []
@ -118,10 +152,13 @@ def detect_unet_config(state_dict, key_prefix, dtype):
channel_mult.append(last_channel_mult)
if "{}middle_block.1.proj_in.weight".format(key_prefix) in state_dict_keys:
transformer_depth_middle = count_blocks(state_dict_keys, '{}middle_block.1.transformer_blocks.'.format(key_prefix) + '{}')
else:
elif "{}middle_block.0.in_layers.0.weight".format(key_prefix) in state_dict_keys:
transformer_depth_middle = -1
else:
transformer_depth_middle = -2
unet_config["in_channels"] = in_channels
unet_config["out_channels"] = out_channels
unet_config["model_channels"] = model_channels
unet_config["num_res_blocks"] = num_res_blocks
unet_config["transformer_depth"] = transformer_depth
@ -150,11 +187,11 @@ def model_config_from_unet_config(unet_config):
if model_config.matches(unet_config):
return model_config(unet_config)
print("no match", unet_config)
logging.error("no match {}".format(unet_config))
return None
def model_config_from_unet(state_dict, unet_key_prefix, dtype, use_base_if_no_match=False):
unet_config = detect_unet_config(state_dict, unet_key_prefix, dtype)
def model_config_from_unet(state_dict, unet_key_prefix, use_base_if_no_match=False):
unet_config = detect_unet_config(state_dict, unet_key_prefix)
model_config = model_config_from_unet_config(unet_config)
if model_config is None and use_base_if_no_match:
return comfy.supported_models_base.BASE(unet_config)
@ -200,7 +237,7 @@ def convert_config(unet_config):
return new_config
def unet_config_from_diffusers_unet(state_dict, dtype):
def unet_config_from_diffusers_unet(state_dict, dtype=None):
match = {}
transformer_depth = []
@ -208,6 +245,7 @@ def unet_config_from_diffusers_unet(state_dict, dtype):
down_blocks = count_blocks(state_dict, "down_blocks.{}")
for i in range(down_blocks):
attn_blocks = count_blocks(state_dict, "down_blocks.{}.attentions.".format(i) + '{}')
res_blocks = count_blocks(state_dict, "down_blocks.{}.resnets.".format(i) + '{}')
for ab in range(attn_blocks):
transformer_count = count_blocks(state_dict, "down_blocks.{}.attentions.{}.transformer_blocks.".format(i, ab) + '{}')
transformer_depth.append(transformer_count)
@ -216,8 +254,8 @@ def unet_config_from_diffusers_unet(state_dict, dtype):
attn_res *= 2
if attn_blocks == 0:
transformer_depth.append(0)
transformer_depth.append(0)
for i in range(res_blocks):
transformer_depth.append(0)
match["transformer_depth"] = transformer_depth
@ -289,7 +327,25 @@ def unet_config_from_diffusers_unet(state_dict, dtype):
'channel_mult': [1, 2, 4], 'transformer_depth_middle': -1, 'use_linear_in_transformer': True, 'context_dim': 2048, 'num_head_channels': 64,
'use_temporal_attention': False, 'use_temporal_resblock': False}
supported_models = [SDXL, SDXL_refiner, SD21, SD15, SD21_uncliph, SD21_unclipl, SDXL_mid_cnet, SDXL_small_cnet, SDXL_diffusers_inpaint, SSD_1B]
Segmind_Vega = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
'num_classes': 'sequential', 'adm_in_channels': 2816, 'dtype': dtype, 'in_channels': 4, 'model_channels': 320,
'num_res_blocks': [2, 2, 2], 'transformer_depth': [0, 0, 1, 1, 2, 2], 'transformer_depth_output': [0, 0, 0, 1, 1, 1, 2, 2, 2],
'channel_mult': [1, 2, 4], 'transformer_depth_middle': -1, 'use_linear_in_transformer': True, 'context_dim': 2048, 'num_head_channels': 64,
'use_temporal_attention': False, 'use_temporal_resblock': False}
KOALA_700M = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
'num_classes': 'sequential', 'adm_in_channels': 2816, 'dtype': dtype, 'in_channels': 4, 'model_channels': 320,
'num_res_blocks': [1, 1, 1], 'transformer_depth': [0, 2, 5], 'transformer_depth_output': [0, 0, 2, 2, 5, 5],
'channel_mult': [1, 2, 4], 'transformer_depth_middle': -2, 'use_linear_in_transformer': True, 'context_dim': 2048, 'num_head_channels': 64,
'use_temporal_attention': False, 'use_temporal_resblock': False}
KOALA_1B = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
'num_classes': 'sequential', 'adm_in_channels': 2816, 'dtype': dtype, 'in_channels': 4, 'model_channels': 320,
'num_res_blocks': [1, 1, 1], 'transformer_depth': [0, 2, 6], 'transformer_depth_output': [0, 0, 2, 2, 6, 6],
'channel_mult': [1, 2, 4], 'transformer_depth_middle': 6, 'use_linear_in_transformer': True, 'context_dim': 2048, 'num_head_channels': 64,
'use_temporal_attention': False, 'use_temporal_resblock': False}
supported_models = [SDXL, SDXL_refiner, SD21, SD15, SD21_uncliph, SD21_unclipl, SDXL_mid_cnet, SDXL_small_cnet, SDXL_diffusers_inpaint, SSD_1B, Segmind_Vega, KOALA_700M, KOALA_1B]
for unet_config in supported_models:
matches = True
@ -301,8 +357,8 @@ def unet_config_from_diffusers_unet(state_dict, dtype):
return convert_config(unet_config)
return None
def model_config_from_diffusers_unet(state_dict, dtype):
unet_config = unet_config_from_diffusers_unet(state_dict, dtype)
def model_config_from_diffusers_unet(state_dict):
unet_config = unet_config_from_diffusers_unet(state_dict)
if unet_config is not None:
return model_config_from_unet_config(unet_config)
return None

316
comfy/model_management.py

@ -1,4 +1,5 @@
import psutil
import logging
from enum import Enum
from comfy.cli_args import args
import comfy.utils
@ -28,6 +29,10 @@ total_vram = 0
lowvram_available = True
xpu_available = False
if args.deterministic:
logging.info("Using deterministic algorithms for pytorch")
torch.use_deterministic_algorithms(True, warn_only=True)
directml_enabled = False
if args.directml is not None:
import torch_directml
@ -37,7 +42,7 @@ if args.directml is not None:
directml_device = torch_directml.device()
else:
directml_device = torch_directml.device(device_index)
print("Using directml with device:", torch_directml.device_name(device_index))
logging.info("Using directml with device: {}".format(torch_directml.device_name(device_index)))
# torch_directml.disable_tiled_resources(True)
lowvram_available = False #TODO: need to find a way to get free memory in directml before this can be enabled by default.
@ -113,10 +118,10 @@ def get_total_memory(dev=None, torch_total_too=False):
total_vram = get_total_memory(get_torch_device()) / (1024 * 1024)
total_ram = psutil.virtual_memory().total / (1024 * 1024)
print("Total VRAM {:0.0f} MB, total RAM {:0.0f} MB".format(total_vram, total_ram))
logging.info("Total VRAM {:0.0f} MB, total RAM {:0.0f} MB".format(total_vram, total_ram))
if not args.normalvram and not args.cpu:
if lowvram_available and total_vram <= 4096:
print("Trying to enable lowvram mode because your GPU seems to have 4GB or less. If you don't want this use: --normalvram")
logging.warning("Trying to enable lowvram mode because your GPU seems to have 4GB or less. If you don't want this use: --normalvram")
set_vram_to = VRAMState.LOW_VRAM
try:
@ -139,12 +144,10 @@ else:
pass
try:
XFORMERS_VERSION = xformers.version.__version__
print("xformers version:", XFORMERS_VERSION)
logging.info("xformers version: {}".format(XFORMERS_VERSION))
if XFORMERS_VERSION.startswith("0.0.18"):
print()
print("WARNING: This version of xformers has a major bug where you will get black images when generating high resolution images.")
print("Please downgrade or upgrade xformers to a different version.")
print()
logging.warning("\nWARNING: This version of xformers has a major bug where you will get black images when generating high resolution images.")
logging.warning("Please downgrade or upgrade xformers to a different version.\n")
XFORMERS_ENABLED_VAE = False
except:
pass
@ -171,7 +174,7 @@ try:
if int(torch_version[0]) >= 2:
if ENABLE_PYTORCH_ATTENTION == False and args.use_split_cross_attention == False and args.use_quad_cross_attention == False:
ENABLE_PYTORCH_ATTENTION = True
if torch.cuda.is_bf16_supported():
if torch.cuda.is_bf16_supported() and torch.cuda.get_device_properties(torch.cuda.current_device()).major >= 8:
VAE_DTYPE = torch.bfloat16
if is_intel_xpu():
if args.use_split_cross_attention == False and args.use_quad_cross_attention == False:
@ -182,6 +185,9 @@ except:
if is_intel_xpu():
VAE_DTYPE = torch.bfloat16
if args.cpu_vae:
VAE_DTYPE = torch.float32
if args.fp16_vae:
VAE_DTYPE = torch.float16
elif args.bf16_vae:
@ -206,23 +212,16 @@ elif args.highvram or args.gpu_only:
FORCE_FP32 = False
FORCE_FP16 = False
if args.force_fp32:
print("Forcing FP32, if this improves things please report it.")
logging.info("Forcing FP32, if this improves things please report it.")
FORCE_FP32 = True
if args.force_fp16:
print("Forcing FP16.")
logging.info("Forcing FP16.")
FORCE_FP16 = True
if lowvram_available:
try:
import accelerate
if set_vram_to in (VRAMState.LOW_VRAM, VRAMState.NO_VRAM):
vram_state = set_vram_to
except Exception as e:
import traceback
print(traceback.format_exc())
print("ERROR: LOW VRAM MODE NEEDS accelerate.")
lowvram_available = False
if set_vram_to in (VRAMState.LOW_VRAM, VRAMState.NO_VRAM):
vram_state = set_vram_to
if cpu_state != CPUState.GPU:
@ -231,12 +230,12 @@ if cpu_state != CPUState.GPU:
if cpu_state == CPUState.MPS:
vram_state = VRAMState.SHARED
print(f"Set vram state to: {vram_state.name}")
logging.info(f"Set vram state to: {vram_state.name}")
DISABLE_SMART_MEMORY = args.disable_smart_memory
if DISABLE_SMART_MEMORY:
print("Disabling smart memory management")
logging.info("Disabling smart memory management")
def get_torch_device_name(device):
if hasattr(device, 'type'):
@ -254,19 +253,27 @@ def get_torch_device_name(device):
return "CUDA {}: {}".format(device, torch.cuda.get_device_name(device))
try:
print("Device:", get_torch_device_name(get_torch_device()))
logging.info("Device: {}".format(get_torch_device_name(get_torch_device())))
except:
print("Could not pick default device.")
logging.warning("Could not pick default device.")
print("VAE dtype:", VAE_DTYPE)
logging.info("VAE dtype: {}".format(VAE_DTYPE))
current_loaded_models = []
def module_size(module):
module_mem = 0
sd = module.state_dict()
for k in sd:
t = sd[k]
module_mem += t.nelement() * t.element_size()
return module_mem
class LoadedModel:
def __init__(self, model):
self.model = model
self.model_accelerated = False
self.device = model.load_device
self.weights_loaded = False
def model_memory(self):
return self.model.model_size()
@ -278,38 +285,33 @@ class LoadedModel:
return self.model_memory()
def model_load(self, lowvram_model_memory=0):
patch_model_to = None
if lowvram_model_memory == 0:
patch_model_to = self.device
patch_model_to = self.device
self.model.model_patches_to(self.device)
self.model.model_patches_to(self.model.model_dtype())
load_weights = not self.weights_loaded
try:
self.real_model = self.model.patch_model(device_to=patch_model_to) #TODO: do something with loras and offloading to CPU
if lowvram_model_memory > 0 and load_weights:
self.real_model = self.model.patch_model_lowvram(device_to=patch_model_to, lowvram_model_memory=lowvram_model_memory)
else:
self.real_model = self.model.patch_model(device_to=patch_model_to, patch_weights=load_weights)
except Exception as e:
self.model.unpatch_model(self.model.offload_device)
self.model_unload()
raise e
if lowvram_model_memory > 0:
print("loading in lowvram mode", lowvram_model_memory/(1024 * 1024))
device_map = accelerate.infer_auto_device_map(self.real_model, max_memory={0: "{}MiB".format(lowvram_model_memory // (1024 * 1024)), "cpu": "16GiB"})
accelerate.dispatch_model(self.real_model, device_map=device_map, main_device=self.device)
self.model_accelerated = True
if is_intel_xpu() and not args.disable_ipex_optimize:
self.real_model = torch.xpu.optimize(self.real_model.eval(), inplace=True, auto_kernel_selection=True, graph_mode=True)
self.weights_loaded = True
return self.real_model
def model_unload(self):
if self.model_accelerated:
accelerate.hooks.remove_hook_from_submodules(self.real_model)
self.model_accelerated = False
self.model.unpatch_model(self.model.offload_device)
def model_unload(self, unpatch_weights=True):
self.model.unpatch_model(self.model.offload_device, unpatch_weights=unpatch_weights)
self.model.model_patches_to(self.model.offload_device)
self.weights_loaded = self.weights_loaded and not unpatch_weights
def __eq__(self, other):
return self.model is other.model
@ -317,31 +319,57 @@ class LoadedModel:
def minimum_inference_memory():
return (1024 * 1024 * 1024)
def unload_model_clones(model):
def unload_model_clones(model, unload_weights_only=True, force_unload=True):
to_unload = []
for i in range(len(current_loaded_models)):
if model.is_clone(current_loaded_models[i].model):
to_unload = [i] + to_unload
if len(to_unload) == 0:
return None
same_weights = 0
for i in to_unload:
if model.clone_has_same_weights(current_loaded_models[i].model):
same_weights += 1
if same_weights == len(to_unload):
unload_weight = False
else:
unload_weight = True
if not force_unload:
if unload_weights_only and unload_weight == False:
return None
for i in to_unload:
print("unload clone", i)
current_loaded_models.pop(i).model_unload()
logging.debug("unload clone {} {}".format(i, unload_weight))
current_loaded_models.pop(i).model_unload(unpatch_weights=unload_weight)
return unload_weight
def free_memory(memory_required, device, keep_loaded=[]):
unloaded_model = False
unloaded_model = []
can_unload = []
for i in range(len(current_loaded_models) -1, -1, -1):
if not DISABLE_SMART_MEMORY:
if get_free_memory(device) > memory_required:
break
shift_model = current_loaded_models[i]
if shift_model.device == device:
if shift_model not in keep_loaded:
m = current_loaded_models.pop(i)
m.model_unload()
del m
unloaded_model = True
can_unload.append((sys.getrefcount(shift_model.model), shift_model.model_memory(), i))
if unloaded_model:
for x in sorted(can_unload):
i = x[-1]
if not DISABLE_SMART_MEMORY:
if get_free_memory(device) > memory_required:
break
current_loaded_models[i].model_unload()
unloaded_model.append(i)
for i in sorted(unloaded_model, reverse=True):
current_loaded_models.pop(i)
if len(unloaded_model) > 0:
soft_empty_cache()
else:
if vram_state != VRAMState.HIGH_VRAM:
@ -366,7 +394,7 @@ def load_models_gpu(models, memory_required=0):
models_already_loaded.append(loaded_model)
else:
if hasattr(x, "model"):
print(f"Requested to load {x.model.__class__.__name__}")
logging.info(f"Requested to load {x.model.__class__.__name__}")
models_to_load.append(loaded_model)
if len(models_to_load) == 0:
@ -376,17 +404,22 @@ def load_models_gpu(models, memory_required=0):
free_memory(extra_mem, d, models_already_loaded)
return
print(f"Loading {len(models_to_load)} new model{'s' if len(models_to_load) > 1 else ''}")
logging.info(f"Loading {len(models_to_load)} new model{'s' if len(models_to_load) > 1 else ''}")
total_memory_required = {}
for loaded_model in models_to_load:
unload_model_clones(loaded_model.model)
unload_model_clones(loaded_model.model, unload_weights_only=True, force_unload=False) #unload clones where the weights are different
total_memory_required[loaded_model.device] = total_memory_required.get(loaded_model.device, 0) + loaded_model.model_memory_required(loaded_model.device)
for device in total_memory_required:
if device != torch.device("cpu"):
free_memory(total_memory_required[device] * 1.3 + extra_mem, device, models_already_loaded)
for loaded_model in models_to_load:
weights_unloaded = unload_model_clones(loaded_model.model, unload_weights_only=False, force_unload=False) #unload the rest of the clones where the weights can stay loaded
if weights_unloaded is not None:
loaded_model.weights_loaded = not weights_unloaded
for loaded_model in models_to_load:
model = loaded_model.model
torch_dev = model.load_device
@ -398,14 +431,14 @@ def load_models_gpu(models, memory_required=0):
if lowvram_available and (vram_set_state == VRAMState.LOW_VRAM or vram_set_state == VRAMState.NORMAL_VRAM):
model_size = loaded_model.model_memory_required(torch_dev)
current_free_mem = get_free_memory(torch_dev)
lowvram_model_memory = int(max(256 * (1024 * 1024), (current_free_mem - 1024 * (1024 * 1024)) / 1.3 ))
lowvram_model_memory = int(max(64 * (1024 * 1024), (current_free_mem - 1024 * (1024 * 1024)) / 1.3 ))
if model_size > (current_free_mem - inference_memory): #only switch to lowvram if really necessary
vram_set_state = VRAMState.LOW_VRAM
else:
lowvram_model_memory = 0
if vram_set_state == VRAMState.NO_VRAM:
lowvram_model_memory = 256 * 1024 * 1024
lowvram_model_memory = 64 * 1024 * 1024
cur_loaded_model = loaded_model.model_load(lowvram_model_memory)
current_loaded_models.insert(0, loaded_model)
@ -430,6 +463,13 @@ def dtype_size(dtype):
dtype_size = 4
if dtype == torch.float16 or dtype == torch.bfloat16:
dtype_size = 2
elif dtype == torch.float32:
dtype_size = 4
else:
try:
dtype_size = dtype.itemsize
except: #Old pytorch doesn't have .itemsize
pass
return dtype_size
def unet_offload_device():
@ -456,13 +496,44 @@ def unet_inital_load_device(parameters, dtype):
else:
return cpu_dev
def unet_dtype(device=None, model_params=0):
def unet_dtype(device=None, model_params=0, supported_dtypes=[torch.float16, torch.bfloat16, torch.float32]):
if args.bf16_unet:
return torch.bfloat16
if should_use_fp16(device=device, model_params=model_params):
if args.fp16_unet:
return torch.float16
if args.fp8_e4m3fn_unet:
return torch.float8_e4m3fn
if args.fp8_e5m2_unet:
return torch.float8_e5m2
if should_use_fp16(device=device, model_params=model_params, manual_cast=True):
if torch.float16 in supported_dtypes:
return torch.float16
if should_use_bf16(device, model_params=model_params, manual_cast=True):
if torch.bfloat16 in supported_dtypes:
return torch.bfloat16
return torch.float32
# None means no manual cast
def unet_manual_cast(weight_dtype, inference_device, supported_dtypes=[torch.float16, torch.bfloat16, torch.float32]):
if weight_dtype == torch.float32:
return None
fp16_supported = should_use_fp16(inference_device, prioritize_performance=False)
if fp16_supported and weight_dtype == torch.float16:
return None
bf16_supported = should_use_bf16(inference_device)
if bf16_supported and weight_dtype == torch.bfloat16:
return None
if fp16_supported and torch.float16 in supported_dtypes:
return torch.float16
elif bf16_supported and torch.bfloat16 in supported_dtypes:
return torch.bfloat16
else:
return torch.float32
def text_encoder_offload_device():
if args.gpu_only:
return get_torch_device()
@ -492,12 +563,21 @@ def text_encoder_dtype(device=None):
elif args.fp32_text_enc:
return torch.float32
if should_use_fp16(device, prioritize_performance=False):
if is_device_cpu(device):
return torch.float16
return torch.float16
def intermediate_device():
if args.gpu_only:
return get_torch_device()
else:
return torch.float32
return torch.device("cpu")
def vae_device():
if args.cpu_vae:
return torch.device("cpu")
return get_torch_device()
def vae_offload_device():
@ -515,6 +595,22 @@ def get_autocast_device(dev):
return dev.type
return "cuda"
def supports_dtype(device, dtype): #TODO
if dtype == torch.float32:
return True
if is_device_cpu(device):
return False
if dtype == torch.float16:
return True
if dtype == torch.bfloat16:
return True
return False
def device_supports_non_blocking(device):
if is_device_mps(device):
return False #pytorch bug? mps doesn't support non blocking
return True
def cast_to_device(tensor, device, dtype, copy=False):
device_supports_cast = False
if tensor.dtype == torch.float32 or tensor.dtype == torch.float16:
@ -525,15 +621,17 @@ def cast_to_device(tensor, device, dtype, copy=False):
elif is_intel_xpu():
device_supports_cast = True
non_blocking = device_supports_non_blocking(device)
if device_supports_cast:
if copy:
if tensor.device == device:
return tensor.to(dtype, copy=copy)
return tensor.to(device, copy=copy).to(dtype)
return tensor.to(dtype, copy=copy, non_blocking=non_blocking)
return tensor.to(device, copy=copy, non_blocking=non_blocking).to(dtype, non_blocking=non_blocking)
else:
return tensor.to(device).to(dtype)
return tensor.to(device, non_blocking=non_blocking).to(dtype, non_blocking=non_blocking)
else:
return tensor.to(dtype).to(device, copy=copy)
return tensor.to(device, dtype, copy=copy, non_blocking=non_blocking)
def xformers_enabled():
global directml_enabled
@ -606,19 +704,22 @@ def mps_mode():
global cpu_state
return cpu_state == CPUState.MPS
def is_device_cpu(device):
def is_device_type(device, type):
if hasattr(device, 'type'):
if (device.type == 'cpu'):
if (device.type == type):
return True
return False
def is_device_cpu(device):
return is_device_type(device, 'cpu')
def is_device_mps(device):
if hasattr(device, 'type'):
if (device.type == 'mps'):
return True
return False
return is_device_type(device, 'mps')
def is_device_cuda(device):
return is_device_type(device, 'cuda')
def should_use_fp16(device=None, model_params=0, prioritize_performance=True):
def should_use_fp16(device=None, model_params=0, prioritize_performance=True, manual_cast=False):
global directml_enabled
if device is not None:
@ -628,9 +729,9 @@ def should_use_fp16(device=None, model_params=0, prioritize_performance=True):
if FORCE_FP16:
return True
if device is not None: #TODO
if device is not None:
if is_device_mps(device):
return False
return True
if FORCE_FP32:
return False
@ -638,16 +739,22 @@ def should_use_fp16(device=None, model_params=0, prioritize_performance=True):
if directml_enabled:
return False
if cpu_mode() or mps_mode():
return False #TODO ?
if mps_mode():
return True
if cpu_mode():
return False
if is_intel_xpu():
return True
if torch.cuda.is_bf16_supported():
if torch.version.hip:
return True
props = torch.cuda.get_device_properties("cuda")
if props.major >= 8:
return True
if props.major < 6:
return False
@ -655,12 +762,12 @@ def should_use_fp16(device=None, model_params=0, prioritize_performance=True):
#FP16 is confirmed working on a 1080 (GP104) but it's a bit slower than FP32 so it should only be enabled
#when the model doesn't actually fit on the card
#TODO: actually test if GP106 and others have the same type of behavior
nvidia_10_series = ["1080", "1070", "titan x", "p3000", "p3200", "p4000", "p4200", "p5000", "p5200", "p6000", "1060", "1050"]
nvidia_10_series = ["1080", "1070", "titan x", "p3000", "p3200", "p4000", "p4200", "p5000", "p5200", "p6000", "1060", "1050", "p40", "p100", "p6", "p4"]
for x in nvidia_10_series:
if x in props.name.lower():
fp16_works = True
if fp16_works:
if fp16_works or manual_cast:
free_model_memory = (get_free_memory() * 0.9 - minimum_inference_memory())
if (not prioritize_performance) or model_params * 4 > free_model_memory:
return True
@ -676,6 +783,43 @@ def should_use_fp16(device=None, model_params=0, prioritize_performance=True):
return True
def should_use_bf16(device=None, model_params=0, prioritize_performance=True, manual_cast=False):
if device is not None:
if is_device_cpu(device): #TODO ? bf16 works on CPU but is extremely slow
return False
if device is not None: #TODO not sure about mps bf16 support
if is_device_mps(device):
return False
if FORCE_FP32:
return False
if directml_enabled:
return False
if cpu_mode() or mps_mode():
return False
if is_intel_xpu():
return True
if device is None:
device = torch.device("cuda")
props = torch.cuda.get_device_properties(device)
if props.major >= 8:
return True
bf16_works = torch.cuda.is_bf16_supported()
if bf16_works or manual_cast:
free_model_memory = (get_free_memory() * 0.9 - minimum_inference_memory())
if (not prioritize_performance) or model_params * 4 > free_model_memory:
return True
return False
def soft_empty_cache(force=False):
global cpu_state
if cpu_state == CPUState.MPS:
@ -687,11 +831,11 @@ def soft_empty_cache(force=False):
torch.cuda.empty_cache()
torch.cuda.ipc_collect()
def resolve_lowvram_weight(weight, model, key):
if weight.device == torch.device("meta"): #lowvram NOTE: this depends on the inner working of the accelerate library so it might break.
key_split = key.split('.') # I have no idea why they don't just leave the weight there instead of using the meta device.
op = comfy.utils.get_attr(model, '.'.join(key_split[:-1]))
weight = op._hf_hook.weights_map[key_split[-1]]
def unload_all_models():
free_memory(1e30, get_torch_device())
def resolve_lowvram_weight(weight, model, key): #TODO: remove
return weight
#TODO: might be cleaner to put this somewhere else

252
comfy/model_patcher.py

@ -1,10 +1,24 @@
import torch
import copy
import inspect
import logging
import uuid
import comfy.utils
import comfy.model_management
def apply_weight_decompose(dora_scale, weight):
weight_norm = (
weight.transpose(0, 1)
.reshape(weight.shape[1], -1)
.norm(dim=1, keepdim=True)
.reshape(weight.shape[1], *[1] * (weight.dim() - 1))
.transpose(0, 1)
)
return weight * (dora_scale / weight_norm)
class ModelPatcher:
def __init__(self, model, load_device, offload_device, size=0, current_device=None, weight_inplace_update=False):
self.size = size
@ -23,28 +37,29 @@ class ModelPatcher:
self.current_device = current_device
self.weight_inplace_update = weight_inplace_update
self.model_lowvram = False
self.patches_uuid = uuid.uuid4()
def model_size(self):
if self.size > 0:
return self.size
model_sd = self.model.state_dict()
size = 0
for k in model_sd:
t = model_sd[k]
size += t.nelement() * t.element_size()
self.size = size
self.size = comfy.model_management.module_size(self.model)
self.model_keys = set(model_sd.keys())
return size
return self.size
def clone(self):
n = ModelPatcher(self.model, self.load_device, self.offload_device, self.size, self.current_device, weight_inplace_update=self.weight_inplace_update)
n.patches = {}
for k in self.patches:
n.patches[k] = self.patches[k][:]
n.patches_uuid = self.patches_uuid
n.object_patches = self.object_patches.copy()
n.model_options = copy.deepcopy(self.model_options)
n.model_keys = self.model_keys
n.backup = self.backup
n.object_patches_backup = self.object_patches_backup
return n
def is_clone(self, other):
@ -52,31 +67,58 @@ class ModelPatcher:
return True
return False
def clone_has_same_weights(self, clone):
if not self.is_clone(clone):
return False
if len(self.patches) == 0 and len(clone.patches) == 0:
return True
if self.patches_uuid == clone.patches_uuid:
if len(self.patches) != len(clone.patches):
logging.warning("WARNING: something went wrong, same patch uuid but different length of patches.")
else:
return True
def memory_required(self, input_shape):
return self.model.memory_required(input_shape=input_shape)
def set_model_sampler_cfg_function(self, sampler_cfg_function):
def set_model_sampler_cfg_function(self, sampler_cfg_function, disable_cfg1_optimization=False):
if len(inspect.signature(sampler_cfg_function).parameters) == 3:
self.model_options["sampler_cfg_function"] = lambda args: sampler_cfg_function(args["cond"], args["uncond"], args["cond_scale"]) #Old way
else:
self.model_options["sampler_cfg_function"] = sampler_cfg_function
if disable_cfg1_optimization:
self.model_options["disable_cfg1_optimization"] = True
def set_model_sampler_post_cfg_function(self, post_cfg_function, disable_cfg1_optimization=False):
self.model_options["sampler_post_cfg_function"] = self.model_options.get("sampler_post_cfg_function", []) + [post_cfg_function]
if disable_cfg1_optimization:
self.model_options["disable_cfg1_optimization"] = True
def set_model_unet_function_wrapper(self, unet_wrapper_function):
self.model_options["model_function_wrapper"] = unet_wrapper_function
def set_model_denoise_mask_function(self, denoise_mask_function):
self.model_options["denoise_mask_function"] = denoise_mask_function
def set_model_patch(self, patch, name):
to = self.model_options["transformer_options"]
if "patches" not in to:
to["patches"] = {}
to["patches"][name] = to["patches"].get(name, []) + [patch]
def set_model_patch_replace(self, patch, name, block_name, number):
def set_model_patch_replace(self, patch, name, block_name, number, transformer_index=None):
to = self.model_options["transformer_options"]
if "patches_replace" not in to:
to["patches_replace"] = {}
if name not in to["patches_replace"]:
to["patches_replace"][name] = {}
to["patches_replace"][name][(block_name, number)] = patch
if transformer_index is not None:
block = (block_name, number, transformer_index)
else:
block = (block_name, number)
to["patches_replace"][name][block] = patch
def set_model_attn1_patch(self, patch):
self.set_model_patch(patch, "attn1_patch")
@ -84,11 +126,11 @@ class ModelPatcher:
def set_model_attn2_patch(self, patch):
self.set_model_patch(patch, "attn2_patch")
def set_model_attn1_replace(self, patch, block_name, number):
self.set_model_patch_replace(patch, "attn1", block_name, number)
def set_model_attn1_replace(self, patch, block_name, number, transformer_index=None):
self.set_model_patch_replace(patch, "attn1", block_name, number, transformer_index)
def set_model_attn2_replace(self, patch, block_name, number):
self.set_model_patch_replace(patch, "attn2", block_name, number)
def set_model_attn2_replace(self, patch, block_name, number, transformer_index=None):
self.set_model_patch_replace(patch, "attn2", block_name, number, transformer_index)
def set_model_attn1_output_patch(self, patch):
self.set_model_patch(patch, "attn1_output_patch")
@ -142,6 +184,7 @@ class ModelPatcher:
current_patches.append((strength_patch, patches[k], strength_model))
self.patches[k] = current_patches
self.patches_uuid = uuid.uuid4()
return list(p)
def get_key_patches(self, filter_prefix=None):
@ -167,41 +210,87 @@ class ModelPatcher:
sd.pop(k)
return sd
def patch_model(self, device_to=None):
def patch_weight_to_device(self, key, device_to=None):
if key not in self.patches:
return
weight = comfy.utils.get_attr(self.model, key)
inplace_update = self.weight_inplace_update
if key not in self.backup:
self.backup[key] = weight.to(device=self.offload_device, copy=inplace_update)
if device_to is not None:
temp_weight = comfy.model_management.cast_to_device(weight, device_to, torch.float32, copy=True)
else:
temp_weight = weight.to(torch.float32, copy=True)
out_weight = self.calculate_weight(self.patches[key], temp_weight, key).to(weight.dtype)
if inplace_update:
comfy.utils.copy_to_param(self.model, key, out_weight)
else:
comfy.utils.set_attr_param(self.model, key, out_weight)
def patch_model(self, device_to=None, patch_weights=True):
for k in self.object_patches:
old = getattr(self.model, k)
old = comfy.utils.set_attr(self.model, k, self.object_patches[k])
if k not in self.object_patches_backup:
self.object_patches_backup[k] = old
setattr(self.model, k, self.object_patches[k])
model_sd = self.model_state_dict()
for key in self.patches:
if key not in model_sd:
print("could not patch. key doesn't exist in model:", key)
continue
weight = model_sd[key]
inplace_update = self.weight_inplace_update
if patch_weights:
model_sd = self.model_state_dict()
for key in self.patches:
if key not in model_sd:
logging.warning("could not patch. key doesn't exist in model: {}".format(key))
continue
if key not in self.backup:
self.backup[key] = weight.to(device=self.offload_device, copy=inplace_update)
self.patch_weight_to_device(key, device_to)
if device_to is not None:
temp_weight = comfy.model_management.cast_to_device(weight, device_to, torch.float32, copy=True)
else:
temp_weight = weight.to(torch.float32, copy=True)
out_weight = self.calculate_weight(self.patches[key], temp_weight, key).to(weight.dtype)
if inplace_update:
comfy.utils.copy_to_param(self.model, key, out_weight)
else:
comfy.utils.set_attr(self.model, key, out_weight)
del temp_weight
self.model.to(device_to)
self.current_device = device_to
if device_to is not None:
self.model.to(device_to)
self.current_device = device_to
return self.model
def patch_model_lowvram(self, device_to=None, lowvram_model_memory=0):
self.patch_model(device_to, patch_weights=False)
logging.info("loading in lowvram mode {}".format(lowvram_model_memory/(1024 * 1024)))
class LowVramPatch:
def __init__(self, key, model_patcher):
self.key = key
self.model_patcher = model_patcher
def __call__(self, weight):
return self.model_patcher.calculate_weight(self.model_patcher.patches[self.key], weight, self.key)
mem_counter = 0
for n, m in self.model.named_modules():
lowvram_weight = False
if hasattr(m, "comfy_cast_weights"):
module_mem = comfy.model_management.module_size(m)
if mem_counter + module_mem >= lowvram_model_memory:
lowvram_weight = True
weight_key = "{}.weight".format(n)
bias_key = "{}.bias".format(n)
if lowvram_weight:
if weight_key in self.patches:
m.weight_function = LowVramPatch(weight_key, self)
if bias_key in self.patches:
m.bias_function = LowVramPatch(weight_key, self)
m.prev_comfy_cast_weights = m.comfy_cast_weights
m.comfy_cast_weights = True
else:
if hasattr(m, "weight"):
self.patch_weight_to_device(weight_key, device_to)
self.patch_weight_to_device(bias_key, device_to)
m.to(device_to)
mem_counter += comfy.model_management.module_size(m)
logging.debug("lowvram: loaded module regularly {}".format(m))
self.model_lowvram = True
return self.model
def calculate_weight(self, patches, weight, key):
@ -217,15 +306,22 @@ class ModelPatcher:
v = (self.calculate_weight(v[1:], v[0].clone(), key), )
if len(v) == 1:
patch_type = "diff"
elif len(v) == 2:
patch_type = v[0]
v = v[1]
if patch_type == "diff":
w1 = v[0]
if alpha != 0.0:
if w1.shape != weight.shape:
print("WARNING SHAPE MISMATCH {} WEIGHT NOT MERGED {} != {}".format(key, w1.shape, weight.shape))
logging.warning("WARNING SHAPE MISMATCH {} WEIGHT NOT MERGED {} != {}".format(key, w1.shape, weight.shape))
else:
weight += alpha * comfy.model_management.cast_to_device(w1, weight.device, weight.dtype)
elif len(v) == 4: #lora/locon
elif patch_type == "lora": #lora/locon
mat1 = comfy.model_management.cast_to_device(v[0], weight.device, torch.float32)
mat2 = comfy.model_management.cast_to_device(v[1], weight.device, torch.float32)
dora_scale = v[4]
if v[2] is not None:
alpha *= v[2] / mat2.shape[0]
if v[3] is not None:
@ -235,9 +331,11 @@ class ModelPatcher:
mat2 = torch.mm(mat2.transpose(0, 1).flatten(start_dim=1), mat3.transpose(0, 1).flatten(start_dim=1)).reshape(final_shape).transpose(0, 1)
try:
weight += (alpha * torch.mm(mat1.flatten(start_dim=1), mat2.flatten(start_dim=1))).reshape(weight.shape).type(weight.dtype)
if dora_scale is not None:
weight = apply_weight_decompose(comfy.model_management.cast_to_device(dora_scale, weight.device, torch.float32), weight)
except Exception as e:
print("ERROR", key, e)
elif len(v) == 8: #lokr
logging.error("ERROR {} {} {}".format(patch_type, key, e))
elif patch_type == "lokr":
w1 = v[0]
w2 = v[1]
w1_a = v[3]
@ -245,6 +343,7 @@ class ModelPatcher:
w2_a = v[5]
w2_b = v[6]
t2 = v[7]
dora_scale = v[8]
dim = None
if w1 is None:
@ -274,15 +373,18 @@ class ModelPatcher:
try:
weight += alpha * torch.kron(w1, w2).reshape(weight.shape).type(weight.dtype)
if dora_scale is not None:
weight = apply_weight_decompose(comfy.model_management.cast_to_device(dora_scale, weight.device, torch.float32), weight)
except Exception as e:
print("ERROR", key, e)
else: #loha
logging.error("ERROR {} {} {}".format(patch_type, key, e))
elif patch_type == "loha":
w1a = v[0]
w1b = v[1]
if v[2] is not None:
alpha *= v[2] / w1b.shape[0]
w2a = v[3]
w2b = v[4]
dora_scale = v[7]
if v[5] is not None: #cp decomposition
t1 = v[5]
t2 = v[6]
@ -303,29 +405,61 @@ class ModelPatcher:
try:
weight += (alpha * m1 * m2).reshape(weight.shape).type(weight.dtype)
if dora_scale is not None:
weight = apply_weight_decompose(comfy.model_management.cast_to_device(dora_scale, weight.device, torch.float32), weight)
except Exception as e:
print("ERROR", key, e)
logging.error("ERROR {} {} {}".format(patch_type, key, e))
elif patch_type == "glora":
if v[4] is not None:
alpha *= v[4] / v[0].shape[0]
dora_scale = v[5]
a1 = comfy.model_management.cast_to_device(v[0].flatten(start_dim=1), weight.device, torch.float32)
a2 = comfy.model_management.cast_to_device(v[1].flatten(start_dim=1), weight.device, torch.float32)
b1 = comfy.model_management.cast_to_device(v[2].flatten(start_dim=1), weight.device, torch.float32)
b2 = comfy.model_management.cast_to_device(v[3].flatten(start_dim=1), weight.device, torch.float32)
try:
weight += ((torch.mm(b2, b1) + torch.mm(torch.mm(weight.flatten(start_dim=1), a2), a1)) * alpha).reshape(weight.shape).type(weight.dtype)
if dora_scale is not None:
weight = apply_weight_decompose(comfy.model_management.cast_to_device(dora_scale, weight.device, torch.float32), weight)
except Exception as e:
logging.error("ERROR {} {} {}".format(patch_type, key, e))
else:
logging.warning("patch type not recognized {} {}".format(patch_type, key))
return weight
def unpatch_model(self, device_to=None):
keys = list(self.backup.keys())
def unpatch_model(self, device_to=None, unpatch_weights=True):
if unpatch_weights:
if self.model_lowvram:
for m in self.model.modules():
if hasattr(m, "prev_comfy_cast_weights"):
m.comfy_cast_weights = m.prev_comfy_cast_weights
del m.prev_comfy_cast_weights
m.weight_function = None
m.bias_function = None
if self.weight_inplace_update:
for k in keys:
comfy.utils.copy_to_param(self.model, k, self.backup[k])
else:
for k in keys:
comfy.utils.set_attr(self.model, k, self.backup[k])
self.model_lowvram = False
self.backup = {}
keys = list(self.backup.keys())
if device_to is not None:
self.model.to(device_to)
self.current_device = device_to
if self.weight_inplace_update:
for k in keys:
comfy.utils.copy_to_param(self.model, k, self.backup[k])
else:
for k in keys:
comfy.utils.set_attr_param(self.model, k, self.backup[k])
self.backup.clear()
if device_to is not None:
self.model.to(device_to)
self.current_device = device_to
keys = list(self.object_patches_backup.keys())
for k in keys:
setattr(self.model, k, self.object_patches_backup[k])
comfy.utils.set_attr(self.model, k, self.object_patches_backup[k])
self.object_patches_backup = {}

98
comfy/model_sampling.py

@ -1,5 +1,4 @@
import torch
import numpy as np
from comfy.ldm.modules.diffusionmodules.util import make_beta_schedule
import math
@ -12,20 +11,43 @@ class EPS:
sigma = sigma.view(sigma.shape[:1] + (1,) * (model_output.ndim - 1))
return model_input - model_output * sigma
def noise_scaling(self, sigma, noise, latent_image, max_denoise=False):
if max_denoise:
noise = noise * torch.sqrt(1.0 + sigma ** 2.0)
else:
noise = noise * sigma
noise += latent_image
return noise
def inverse_noise_scaling(self, sigma, latent):
return latent
class V_PREDICTION(EPS):
def calculate_denoised(self, sigma, model_output, model_input):
sigma = sigma.view(sigma.shape[:1] + (1,) * (model_output.ndim - 1))
return model_input * self.sigma_data ** 2 / (sigma ** 2 + self.sigma_data ** 2) - model_output * sigma * self.sigma_data / (sigma ** 2 + self.sigma_data ** 2) ** 0.5
class EDM(V_PREDICTION):
def calculate_denoised(self, sigma, model_output, model_input):
sigma = sigma.view(sigma.shape[:1] + (1,) * (model_output.ndim - 1))
return model_input * self.sigma_data ** 2 / (sigma ** 2 + self.sigma_data ** 2) + model_output * sigma * self.sigma_data / (sigma ** 2 + self.sigma_data ** 2) ** 0.5
class ModelSamplingDiscrete(torch.nn.Module):
def __init__(self, model_config=None):
super().__init__()
beta_schedule = "linear"
if model_config is not None:
beta_schedule = model_config.sampling_settings.get("beta_schedule", beta_schedule)
self._register_schedule(given_betas=None, beta_schedule=beta_schedule, timesteps=1000, linear_start=0.00085, linear_end=0.012, cosine_s=8e-3)
sampling_settings = model_config.sampling_settings
else:
sampling_settings = {}
beta_schedule = sampling_settings.get("beta_schedule", "linear")
linear_start = sampling_settings.get("linear_start", 0.00085)
linear_end = sampling_settings.get("linear_end", 0.012)
self._register_schedule(given_betas=None, beta_schedule=beta_schedule, timesteps=1000, linear_start=linear_start, linear_end=linear_end, cosine_s=8e-3)
self.sigma_data = 1.0
def _register_schedule(self, given_betas=None, beta_schedule="linear", timesteps=1000,
@ -35,8 +57,7 @@ class ModelSamplingDiscrete(torch.nn.Module):
else:
betas = make_beta_schedule(beta_schedule, timesteps, linear_start=linear_start, linear_end=linear_end, cosine_s=cosine_s)
alphas = 1. - betas
alphas_cumprod = torch.tensor(np.cumprod(alphas, axis=0), dtype=torch.float32)
# alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1])
alphas_cumprod = torch.cumprod(alphas, dim=0)
timesteps, = betas.shape
self.num_timesteps = int(timesteps)
@ -51,8 +72,8 @@ class ModelSamplingDiscrete(torch.nn.Module):
self.set_sigmas(sigmas)
def set_sigmas(self, sigmas):
self.register_buffer('sigmas', sigmas)
self.register_buffer('log_sigmas', sigmas.log())
self.register_buffer('sigmas', sigmas.float())
self.register_buffer('log_sigmas', sigmas.log().float())
@property
def sigma_min(self):
@ -87,8 +108,6 @@ class ModelSamplingDiscrete(torch.nn.Module):
class ModelSamplingContinuousEDM(torch.nn.Module):
def __init__(self, model_config=None):
super().__init__()
self.sigma_data = 1.0
if model_config is not None:
sampling_settings = model_config.sampling_settings
else:
@ -96,9 +115,11 @@ class ModelSamplingContinuousEDM(torch.nn.Module):
sigma_min = sampling_settings.get("sigma_min", 0.002)
sigma_max = sampling_settings.get("sigma_max", 120.0)
self.set_sigma_range(sigma_min, sigma_max)
sigma_data = sampling_settings.get("sigma_data", 1.0)
self.set_parameters(sigma_min, sigma_max, sigma_data)
def set_sigma_range(self, sigma_min, sigma_max):
def set_parameters(self, sigma_min, sigma_max, sigma_data):
self.sigma_data = sigma_data
sigmas = torch.linspace(math.log(sigma_min), math.log(sigma_max), 1000).exp()
self.register_buffer('sigmas', sigmas) #for compatibility with some schedulers
@ -127,3 +148,56 @@ class ModelSamplingContinuousEDM(torch.nn.Module):
log_sigma_min = math.log(self.sigma_min)
return math.exp((math.log(self.sigma_max) - log_sigma_min) * percent + log_sigma_min)
class StableCascadeSampling(ModelSamplingDiscrete):
def __init__(self, model_config=None):
super().__init__()
if model_config is not None:
sampling_settings = model_config.sampling_settings
else:
sampling_settings = {}
self.set_parameters(sampling_settings.get("shift", 1.0))
def set_parameters(self, shift=1.0, cosine_s=8e-3):
self.shift = shift
self.cosine_s = torch.tensor(cosine_s)
self._init_alpha_cumprod = torch.cos(self.cosine_s / (1 + self.cosine_s) * torch.pi * 0.5) ** 2
#This part is just for compatibility with some schedulers in the codebase
self.num_timesteps = 10000
sigmas = torch.empty((self.num_timesteps), dtype=torch.float32)
for x in range(self.num_timesteps):
t = (x + 1) / self.num_timesteps
sigmas[x] = self.sigma(t)
self.set_sigmas(sigmas)
def sigma(self, timestep):
alpha_cumprod = (torch.cos((timestep + self.cosine_s) / (1 + self.cosine_s) * torch.pi * 0.5) ** 2 / self._init_alpha_cumprod)
if self.shift != 1.0:
var = alpha_cumprod
logSNR = (var/(1-var)).log()
logSNR += 2 * torch.log(1.0 / torch.tensor(self.shift))
alpha_cumprod = logSNR.sigmoid()
alpha_cumprod = alpha_cumprod.clamp(0.0001, 0.9999)
return ((1 - alpha_cumprod) / alpha_cumprod) ** 0.5
def timestep(self, sigma):
var = 1 / ((sigma * sigma) + 1)
var = var.clamp(0, 1.0)
s, min_var = self.cosine_s.to(var.device), self._init_alpha_cumprod.to(var.device)
t = (((var * min_var) ** 0.5).acos() / (torch.pi * 0.5)) * (1 + s) - s
return t
def percent_to_sigma(self, percent):
if percent <= 0.0:
return 999999999.9
if percent >= 1.0:
return 0.0
percent = 1.0 - percent
return self.sigma(torch.tensor(percent))

201
comfy/ops.py

@ -1,40 +1,163 @@
"""
This file is part of ComfyUI.
Copyright (C) 2024 Stability AI
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU General Public License for more details.
You should have received a copy of the GNU General Public License
along with this program. If not, see <https://www.gnu.org/licenses/>.
"""
import torch
from contextlib import contextmanager
class Linear(torch.nn.Linear):
def reset_parameters(self):
return None
class Conv2d(torch.nn.Conv2d):
def reset_parameters(self):
return None
class Conv3d(torch.nn.Conv3d):
def reset_parameters(self):
return None
def conv_nd(dims, *args, **kwargs):
if dims == 2:
return Conv2d(*args, **kwargs)
elif dims == 3:
return Conv3d(*args, **kwargs)
else:
raise ValueError(f"unsupported dimensions: {dims}")
@contextmanager
def use_comfy_ops(device=None, dtype=None): # Kind of an ugly hack but I can't think of a better way
old_torch_nn_linear = torch.nn.Linear
force_device = device
force_dtype = dtype
def linear_with_dtype(in_features: int, out_features: int, bias: bool = True, device=None, dtype=None):
if force_device is not None:
device = force_device
if force_dtype is not None:
dtype = force_dtype
return Linear(in_features, out_features, bias=bias, device=device, dtype=dtype)
torch.nn.Linear = linear_with_dtype
try:
yield
finally:
torch.nn.Linear = old_torch_nn_linear
import comfy.model_management
def cast_bias_weight(s, input):
bias = None
non_blocking = comfy.model_management.device_supports_non_blocking(input.device)
if s.bias is not None:
bias = s.bias.to(device=input.device, dtype=input.dtype, non_blocking=non_blocking)
if s.bias_function is not None:
bias = s.bias_function(bias)
weight = s.weight.to(device=input.device, dtype=input.dtype, non_blocking=non_blocking)
if s.weight_function is not None:
weight = s.weight_function(weight)
return weight, bias
class CastWeightBiasOp:
comfy_cast_weights = False
weight_function = None
bias_function = None
class disable_weight_init:
class Linear(torch.nn.Linear, CastWeightBiasOp):
def reset_parameters(self):
return None
def forward_comfy_cast_weights(self, input):
weight, bias = cast_bias_weight(self, input)
return torch.nn.functional.linear(input, weight, bias)
def forward(self, *args, **kwargs):
if self.comfy_cast_weights:
return self.forward_comfy_cast_weights(*args, **kwargs)
else:
return super().forward(*args, **kwargs)
class Conv2d(torch.nn.Conv2d, CastWeightBiasOp):
def reset_parameters(self):
return None
def forward_comfy_cast_weights(self, input):
weight, bias = cast_bias_weight(self, input)
return self._conv_forward(input, weight, bias)
def forward(self, *args, **kwargs):
if self.comfy_cast_weights:
return self.forward_comfy_cast_weights(*args, **kwargs)
else:
return super().forward(*args, **kwargs)
class Conv3d(torch.nn.Conv3d, CastWeightBiasOp):
def reset_parameters(self):
return None
def forward_comfy_cast_weights(self, input):
weight, bias = cast_bias_weight(self, input)
return self._conv_forward(input, weight, bias)
def forward(self, *args, **kwargs):
if self.comfy_cast_weights:
return self.forward_comfy_cast_weights(*args, **kwargs)
else:
return super().forward(*args, **kwargs)
class GroupNorm(torch.nn.GroupNorm, CastWeightBiasOp):
def reset_parameters(self):
return None
def forward_comfy_cast_weights(self, input):
weight, bias = cast_bias_weight(self, input)
return torch.nn.functional.group_norm(input, self.num_groups, weight, bias, self.eps)
def forward(self, *args, **kwargs):
if self.comfy_cast_weights:
return self.forward_comfy_cast_weights(*args, **kwargs)
else:
return super().forward(*args, **kwargs)
class LayerNorm(torch.nn.LayerNorm, CastWeightBiasOp):
def reset_parameters(self):
return None
def forward_comfy_cast_weights(self, input):
if self.weight is not None:
weight, bias = cast_bias_weight(self, input)
else:
weight = None
bias = None
return torch.nn.functional.layer_norm(input, self.normalized_shape, weight, bias, self.eps)
def forward(self, *args, **kwargs):
if self.comfy_cast_weights:
return self.forward_comfy_cast_weights(*args, **kwargs)
else:
return super().forward(*args, **kwargs)
class ConvTranspose2d(torch.nn.ConvTranspose2d, CastWeightBiasOp):
def reset_parameters(self):
return None
def forward_comfy_cast_weights(self, input, output_size=None):
num_spatial_dims = 2
output_padding = self._output_padding(
input, output_size, self.stride, self.padding, self.kernel_size,
num_spatial_dims, self.dilation)
weight, bias = cast_bias_weight(self, input)
return torch.nn.functional.conv_transpose2d(
input, weight, bias, self.stride, self.padding,
output_padding, self.groups, self.dilation)
def forward(self, *args, **kwargs):
if self.comfy_cast_weights:
return self.forward_comfy_cast_weights(*args, **kwargs)
else:
return super().forward(*args, **kwargs)
@classmethod
def conv_nd(s, dims, *args, **kwargs):
if dims == 2:
return s.Conv2d(*args, **kwargs)
elif dims == 3:
return s.Conv3d(*args, **kwargs)
else:
raise ValueError(f"unsupported dimensions: {dims}")
class manual_cast(disable_weight_init):
class Linear(disable_weight_init.Linear):
comfy_cast_weights = True
class Conv2d(disable_weight_init.Conv2d):
comfy_cast_weights = True
class Conv3d(disable_weight_init.Conv3d):
comfy_cast_weights = True
class GroupNorm(disable_weight_init.GroupNorm):
comfy_cast_weights = True
class LayerNorm(disable_weight_init.LayerNorm):
comfy_cast_weights = True
class ConvTranspose2d(disable_weight_init.ConvTranspose2d):
comfy_cast_weights = True

12
comfy/sample.py

@ -28,7 +28,6 @@ def prepare_noise(latent_image, seed, noise_inds=None):
def prepare_mask(noise_mask, shape, device):
"""ensures noise mask is of proper dimensions"""
noise_mask = torch.nn.functional.interpolate(noise_mask.reshape((-1, 1, noise_mask.shape[-2], noise_mask.shape[-1])), size=(shape[2], shape[3]), mode="bilinear")
noise_mask = noise_mask.round()
noise_mask = torch.cat([noise_mask] * shape[1], dim=1)
noise_mask = comfy.utils.repeat_to_batch_size(noise_mask, shape[0])
noise_mask = noise_mask.to(device)
@ -47,7 +46,8 @@ def convert_cond(cond):
temp = c[1].copy()
model_conds = temp.get("model_conds", {})
if c[0] is not None:
model_conds["c_crossattn"] = comfy.conds.CONDCrossAttn(c[0])
model_conds["c_crossattn"] = comfy.conds.CONDCrossAttn(c[0]) #TODO: remove
temp["cross_attn"] = c[0]
temp["model_conds"] = model_conds
out.append(temp)
return out
@ -98,10 +98,10 @@ def sample(model, noise, steps, cfg, sampler_name, scheduler, positive, negative
sampler = comfy.samplers.KSampler(real_model, steps=steps, device=model.load_device, sampler=sampler_name, scheduler=scheduler, denoise=denoise, model_options=model.model_options)
samples = sampler.sample(noise, positive_copy, negative_copy, cfg=cfg, latent_image=latent_image, start_step=start_step, last_step=last_step, force_full_denoise=force_full_denoise, denoise_mask=noise_mask, sigmas=sigmas, callback=callback, disable_pbar=disable_pbar, seed=seed)
samples = samples.cpu()
samples = samples.to(comfy.model_management.intermediate_device())
cleanup_additional_models(models)
cleanup_additional_models(set(get_models_from_cond(positive, "control") + get_models_from_cond(negative, "control")))
cleanup_additional_models(set(get_models_from_cond(positive_copy, "control") + get_models_from_cond(negative_copy, "control")))
return samples
def sample_custom(model, noise, cfg, sampler, sigmas, positive, negative, latent_image, noise_mask=None, callback=None, disable_pbar=False, seed=None):
@ -111,8 +111,8 @@ def sample_custom(model, noise, cfg, sampler, sigmas, positive, negative, latent
sigmas = sigmas.to(model.load_device)
samples = comfy.samplers.sample(real_model, noise, positive_copy, negative_copy, cfg, model.load_device, sampler, sigmas, model_options=model.model_options, latent_image=latent_image, denoise_mask=noise_mask, callback=callback, disable_pbar=disable_pbar, seed=seed)
samples = samples.cpu()
samples = samples.to(comfy.model_management.intermediate_device())
cleanup_additional_models(models)
cleanup_additional_models(set(get_models_from_cond(positive, "control") + get_models_from_cond(negative, "control")))
cleanup_additional_models(set(get_models_from_cond(positive_copy, "control") + get_models_from_cond(negative_copy, "control")))
return samples

532
comfy/samplers.py

@ -1,260 +1,266 @@
from .k_diffusion import sampling as k_diffusion_sampling
from .extra_samplers import uni_pc
import torch
import enum
import collections
from comfy import model_management
import math
from comfy import model_base
import comfy.utils
import comfy.conds
import logging
def get_area_and_mult(conds, x_in, timestep_in):
area = (x_in.shape[2], x_in.shape[3], 0, 0)
strength = 1.0
if 'timestep_start' in conds:
timestep_start = conds['timestep_start']
if timestep_in[0] > timestep_start:
return None
if 'timestep_end' in conds:
timestep_end = conds['timestep_end']
if timestep_in[0] < timestep_end:
return None
if 'area' in conds:
area = conds['area']
if 'strength' in conds:
strength = conds['strength']
input_x = x_in[:,:,area[2]:area[0] + area[2],area[3]:area[1] + area[3]]
if 'mask' in conds:
# Scale the mask to the size of the input
# The mask should have been resized as we began the sampling process
mask_strength = 1.0
if "mask_strength" in conds:
mask_strength = conds["mask_strength"]
mask = conds['mask']
assert(mask.shape[1] == x_in.shape[2])
assert(mask.shape[2] == x_in.shape[3])
mask = mask[:,area[2]:area[0] + area[2],area[3]:area[1] + area[3]] * mask_strength
mask = mask.unsqueeze(1).repeat(input_x.shape[0] // mask.shape[0], input_x.shape[1], 1, 1)
else:
mask = torch.ones_like(input_x)
mult = mask * strength
if 'mask' not in conds:
rr = 8
if area[2] != 0:
for t in range(rr):
mult[:,:,t:1+t,:] *= ((1.0/rr) * (t + 1))
if (area[0] + area[2]) < x_in.shape[2]:
for t in range(rr):
mult[:,:,area[0] - 1 - t:area[0] - t,:] *= ((1.0/rr) * (t + 1))
if area[3] != 0:
for t in range(rr):
mult[:,:,:,t:1+t] *= ((1.0/rr) * (t + 1))
if (area[1] + area[3]) < x_in.shape[3]:
for t in range(rr):
mult[:,:,:,area[1] - 1 - t:area[1] - t] *= ((1.0/rr) * (t + 1))
conditioning = {}
model_conds = conds["model_conds"]
for c in model_conds:
conditioning[c] = model_conds[c].process_cond(batch_size=x_in.shape[0], device=x_in.device, area=area)
control = conds.get('control', None)
patches = None
if 'gligen' in conds:
gligen = conds['gligen']
patches = {}
gligen_type = gligen[0]
gligen_model = gligen[1]
if gligen_type == "position":
gligen_patch = gligen_model.model.set_position(input_x.shape, gligen[2], input_x.device)
else:
gligen_patch = gligen_model.model.set_empty(input_x.shape, input_x.device)
patches['middle_patch'] = [gligen_patch]
cond_obj = collections.namedtuple('cond_obj', ['input_x', 'mult', 'conditioning', 'area', 'control', 'patches'])
return cond_obj(input_x, mult, conditioning, area, control, patches)
def cond_equal_size(c1, c2):
if c1 is c2:
return True
if c1.keys() != c2.keys():
return False
for k in c1:
if not c1[k].can_concat(c2[k]):
return False
return True
def can_concat_cond(c1, c2):
if c1.input_x.shape != c2.input_x.shape:
return False
def objects_concatable(obj1, obj2):
if (obj1 is None) != (obj2 is None):
return False
if obj1 is not None:
if obj1 is not obj2:
return False
return True
if not objects_concatable(c1.control, c2.control):
return False
#The main sampling function shared by all the samplers
#Returns denoised
def sampling_function(model, x, timestep, uncond, cond, cond_scale, model_options={}, seed=None):
def get_area_and_mult(conds, x_in, timestep_in):
area = (x_in.shape[2], x_in.shape[3], 0, 0)
strength = 1.0
if 'timestep_start' in conds:
timestep_start = conds['timestep_start']
if timestep_in[0] > timestep_start:
return None
if 'timestep_end' in conds:
timestep_end = conds['timestep_end']
if timestep_in[0] < timestep_end:
return None
if 'area' in conds:
area = conds['area']
if 'strength' in conds:
strength = conds['strength']
input_x = x_in[:,:,area[2]:area[0] + area[2],area[3]:area[1] + area[3]]
if 'mask' in conds:
# Scale the mask to the size of the input
# The mask should have been resized as we began the sampling process
mask_strength = 1.0
if "mask_strength" in conds:
mask_strength = conds["mask_strength"]
mask = conds['mask']
assert(mask.shape[1] == x_in.shape[2])
assert(mask.shape[2] == x_in.shape[3])
mask = mask[:,area[2]:area[0] + area[2],area[3]:area[1] + area[3]] * mask_strength
mask = mask.unsqueeze(1).repeat(input_x.shape[0] // mask.shape[0], input_x.shape[1], 1, 1)
else:
mask = torch.ones_like(input_x)
mult = mask * strength
if 'mask' not in conds:
rr = 8
if area[2] != 0:
for t in range(rr):
mult[:,:,t:1+t,:] *= ((1.0/rr) * (t + 1))
if (area[0] + area[2]) < x_in.shape[2]:
for t in range(rr):
mult[:,:,area[0] - 1 - t:area[0] - t,:] *= ((1.0/rr) * (t + 1))
if area[3] != 0:
for t in range(rr):
mult[:,:,:,t:1+t] *= ((1.0/rr) * (t + 1))
if (area[1] + area[3]) < x_in.shape[3]:
for t in range(rr):
mult[:,:,:,area[1] - 1 - t:area[1] - t] *= ((1.0/rr) * (t + 1))
conditionning = {}
model_conds = conds["model_conds"]
for c in model_conds:
conditionning[c] = model_conds[c].process_cond(batch_size=x_in.shape[0], device=x_in.device, area=area)
control = None
if 'control' in conds:
control = conds['control']
patches = None
if 'gligen' in conds:
gligen = conds['gligen']
patches = {}
gligen_type = gligen[0]
gligen_model = gligen[1]
if gligen_type == "position":
gligen_patch = gligen_model.model.set_position(input_x.shape, gligen[2], input_x.device)
else:
gligen_patch = gligen_model.model.set_empty(input_x.shape, input_x.device)
if not objects_concatable(c1.patches, c2.patches):
return False
patches['middle_patch'] = [gligen_patch]
return cond_equal_size(c1.conditioning, c2.conditioning)
return (input_x, mult, conditionning, area, control, patches)
def cond_cat(c_list):
c_crossattn = []
c_concat = []
c_adm = []
crossattn_max_len = 0
def cond_equal_size(c1, c2):
if c1 is c2:
return True
if c1.keys() != c2.keys():
return False
for k in c1:
if not c1[k].can_concat(c2[k]):
return False
return True
temp = {}
for x in c_list:
for k in x:
cur = temp.get(k, [])
cur.append(x[k])
temp[k] = cur
def can_concat_cond(c1, c2):
if c1[0].shape != c2[0].shape:
return False
out = {}
for k in temp:
conds = temp[k]
out[k] = conds[0].concat(conds[1:])
#control
if (c1[4] is None) != (c2[4] is None):
return False
if c1[4] is not None:
if c1[4] is not c2[4]:
return False
return out
#patches
if (c1[5] is None) != (c2[5] is None):
return False
if (c1[5] is not None):
if c1[5] is not c2[5]:
return False
return cond_equal_size(c1[2], c2[2])
def cond_cat(c_list):
c_crossattn = []
c_concat = []
c_adm = []
crossattn_max_len = 0
temp = {}
for x in c_list:
for k in x:
cur = temp.get(k, [])
cur.append(x[k])
temp[k] = cur
out = {}
for k in temp:
conds = temp[k]
out[k] = conds[0].concat(conds[1:])
return out
def calc_cond_uncond_batch(model, cond, uncond, x_in, timestep, model_options):
out_cond = torch.zeros_like(x_in)
out_count = torch.ones_like(x_in) * 1e-37
out_uncond = torch.zeros_like(x_in)
out_uncond_count = torch.ones_like(x_in) * 1e-37
COND = 0
UNCOND = 1
to_run = []
for x in cond:
p = get_area_and_mult(x, x_in, timestep)
if p is None:
continue
to_run += [(p, COND)]
if uncond is not None:
for x in uncond:
p = get_area_and_mult(x, x_in, timestep)
if p is None:
continue
to_run += [(p, UNCOND)]
while len(to_run) > 0:
first = to_run[0]
first_shape = first[0][0].shape
to_batch_temp = []
for x in range(len(to_run)):
if can_concat_cond(to_run[x][0], first[0]):
to_batch_temp += [x]
to_batch_temp.reverse()
to_batch = to_batch_temp[:1]
free_memory = model_management.get_free_memory(x_in.device)
for i in range(1, len(to_batch_temp) + 1):
batch_amount = to_batch_temp[:len(to_batch_temp)//i]
input_shape = [len(batch_amount) * first_shape[0]] + list(first_shape)[1:]
if model.memory_required(input_shape) < free_memory:
to_batch = batch_amount
break
input_x = []
mult = []
c = []
cond_or_uncond = []
area = []
control = None
patches = None
for x in to_batch:
o = to_run.pop(x)
p = o[0]
input_x += [p[0]]
mult += [p[1]]
c += [p[2]]
area += [p[3]]
cond_or_uncond += [o[1]]
control = p[4]
patches = p[5]
batch_chunks = len(cond_or_uncond)
input_x = torch.cat(input_x)
c = cond_cat(c)
timestep_ = torch.cat([timestep] * batch_chunks)
if control is not None:
c['control'] = control.get_control(input_x, timestep_, c, len(cond_or_uncond))
transformer_options = {}
if 'transformer_options' in model_options:
transformer_options = model_options['transformer_options'].copy()
if patches is not None:
if "patches" in transformer_options:
cur_patches = transformer_options["patches"].copy()
for p in patches:
if p in cur_patches:
cur_patches[p] = cur_patches[p] + patches[p]
else:
cur_patches[p] = patches[p]
else:
transformer_options["patches"] = patches
def calc_cond_uncond_batch(model, cond, uncond, x_in, timestep, model_options):
out_cond = torch.zeros_like(x_in)
out_count = torch.ones_like(x_in) * 1e-37
transformer_options["cond_or_uncond"] = cond_or_uncond[:]
transformer_options["sigmas"] = timestep
out_uncond = torch.zeros_like(x_in)
out_uncond_count = torch.ones_like(x_in) * 1e-37
c['transformer_options'] = transformer_options
COND = 0
UNCOND = 1
if 'model_function_wrapper' in model_options:
output = model_options['model_function_wrapper'](model.apply_model, {"input": input_x, "timestep": timestep_, "c": c, "cond_or_uncond": cond_or_uncond}).chunk(batch_chunks)
else:
output = model.apply_model(input_x, timestep_, **c).chunk(batch_chunks)
del input_x
to_run = []
for x in cond:
p = get_area_and_mult(x, x_in, timestep)
if p is None:
continue
for o in range(batch_chunks):
if cond_or_uncond[o] == COND:
out_cond[:,:,area[o][2]:area[o][0] + area[o][2],area[o][3]:area[o][1] + area[o][3]] += output[o] * mult[o]
out_count[:,:,area[o][2]:area[o][0] + area[o][2],area[o][3]:area[o][1] + area[o][3]] += mult[o]
to_run += [(p, COND)]
if uncond is not None:
for x in uncond:
p = get_area_and_mult(x, x_in, timestep)
if p is None:
continue
to_run += [(p, UNCOND)]
while len(to_run) > 0:
first = to_run[0]
first_shape = first[0][0].shape
to_batch_temp = []
for x in range(len(to_run)):
if can_concat_cond(to_run[x][0], first[0]):
to_batch_temp += [x]
to_batch_temp.reverse()
to_batch = to_batch_temp[:1]
free_memory = model_management.get_free_memory(x_in.device)
for i in range(1, len(to_batch_temp) + 1):
batch_amount = to_batch_temp[:len(to_batch_temp)//i]
input_shape = [len(batch_amount) * first_shape[0]] + list(first_shape)[1:]
if model.memory_required(input_shape) < free_memory:
to_batch = batch_amount
break
input_x = []
mult = []
c = []
cond_or_uncond = []
area = []
control = None
patches = None
for x in to_batch:
o = to_run.pop(x)
p = o[0]
input_x.append(p.input_x)
mult.append(p.mult)
c.append(p.conditioning)
area.append(p.area)
cond_or_uncond.append(o[1])
control = p.control
patches = p.patches
batch_chunks = len(cond_or_uncond)
input_x = torch.cat(input_x)
c = cond_cat(c)
timestep_ = torch.cat([timestep] * batch_chunks)
if control is not None:
c['control'] = control.get_control(input_x, timestep_, c, len(cond_or_uncond))
transformer_options = {}
if 'transformer_options' in model_options:
transformer_options = model_options['transformer_options'].copy()
if patches is not None:
if "patches" in transformer_options:
cur_patches = transformer_options["patches"].copy()
for p in patches:
if p in cur_patches:
cur_patches[p] = cur_patches[p] + patches[p]
else:
out_uncond[:,:,area[o][2]:area[o][0] + area[o][2],area[o][3]:area[o][1] + area[o][3]] += output[o] * mult[o]
out_uncond_count[:,:,area[o][2]:area[o][0] + area[o][2],area[o][3]:area[o][1] + area[o][3]] += mult[o]
del mult
cur_patches[p] = patches[p]
transformer_options["patches"] = cur_patches
else:
transformer_options["patches"] = patches
out_cond /= out_count
del out_count
out_uncond /= out_uncond_count
del out_uncond_count
return out_cond, out_uncond
transformer_options["cond_or_uncond"] = cond_or_uncond[:]
transformer_options["sigmas"] = timestep
c['transformer_options'] = transformer_options
if math.isclose(cond_scale, 1.0):
uncond = None
if 'model_function_wrapper' in model_options:
output = model_options['model_function_wrapper'](model.apply_model, {"input": input_x, "timestep": timestep_, "c": c, "cond_or_uncond": cond_or_uncond}).chunk(batch_chunks)
else:
output = model.apply_model(input_x, timestep_, **c).chunk(batch_chunks)
del input_x
for o in range(batch_chunks):
if cond_or_uncond[o] == COND:
out_cond[:,:,area[o][2]:area[o][0] + area[o][2],area[o][3]:area[o][1] + area[o][3]] += output[o] * mult[o]
out_count[:,:,area[o][2]:area[o][0] + area[o][2],area[o][3]:area[o][1] + area[o][3]] += mult[o]
else:
out_uncond[:,:,area[o][2]:area[o][0] + area[o][2],area[o][3]:area[o][1] + area[o][3]] += output[o] * mult[o]
out_uncond_count[:,:,area[o][2]:area[o][0] + area[o][2],area[o][3]:area[o][1] + area[o][3]] += mult[o]
del mult
out_cond /= out_count
del out_count
out_uncond /= out_uncond_count
del out_uncond_count
return out_cond, out_uncond
#The main sampling function shared by all the samplers
#Returns denoised
def sampling_function(model, x, timestep, uncond, cond, cond_scale, model_options={}, seed=None):
if math.isclose(cond_scale, 1.0) and model_options.get("disable_cfg1_optimization", False) == False:
uncond_ = None
else:
uncond_ = uncond
cond, uncond = calc_cond_uncond_batch(model, cond, uncond, x, timestep, model_options)
cond_pred, uncond_pred = calc_cond_uncond_batch(model, cond, uncond_, x, timestep, model_options)
if "sampler_cfg_function" in model_options:
args = {"cond": x - cond, "uncond": x - uncond, "cond_scale": cond_scale, "timestep": timestep, "input": x, "sigma": timestep}
return x - model_options["sampler_cfg_function"](args)
args = {"cond": x - cond_pred, "uncond": x - uncond_pred, "cond_scale": cond_scale, "timestep": timestep, "input": x, "sigma": timestep,
"cond_denoised": cond_pred, "uncond_denoised": uncond_pred, "model": model, "model_options": model_options}
cfg_result = x - model_options["sampler_cfg_function"](args)
else:
return uncond + (cond - uncond) * cond_scale
cfg_result = uncond_pred + (cond_pred - uncond_pred) * cond_scale
for fn in model_options.get("sampler_post_cfg_function", []):
args = {"denoised": cfg_result, "cond": cond, "uncond": uncond, "model": model, "uncond_denoised": uncond_pred, "cond_denoised": cond_pred,
"sigma": timestep, "model_options": model_options, "input": x}
cfg_result = fn(args)
return cfg_result
class CFGNoisePredictor(torch.nn.Module):
def __init__(self, model):
@ -267,19 +273,19 @@ class CFGNoisePredictor(torch.nn.Module):
return self.apply_model(*args, **kwargs)
class KSamplerX0Inpaint(torch.nn.Module):
def __init__(self, model):
def __init__(self, model, sigmas):
super().__init__()
self.inner_model = model
self.sigmas = sigmas
def forward(self, x, sigma, uncond, cond, cond_scale, denoise_mask, model_options={}, seed=None):
if denoise_mask is not None:
if "denoise_mask_function" in model_options:
denoise_mask = model_options["denoise_mask_function"](sigma, denoise_mask, extra_options={"model": self.inner_model, "sigmas": self.sigmas})
latent_mask = 1. - denoise_mask
x = x * denoise_mask + (self.latent_image + self.noise * sigma.reshape([sigma.shape[0]] + [1] * (len(self.noise.shape) - 1))) * latent_mask
x = x * denoise_mask + self.inner_model.inner_model.model_sampling.noise_scaling(sigma.reshape([sigma.shape[0]] + [1] * (len(self.noise.shape) - 1)), self.noise, self.latent_image) * latent_mask
out = self.inner_model(x, sigma, cond=cond, uncond=uncond, cond_scale=cond_scale, model_options=model_options, seed=seed)
if denoise_mask is not None:
out *= denoise_mask
if denoise_mask is not None:
out += self.latent_image * latent_mask
out = out * denoise_mask + self.latent_image * latent_mask
return out
def simple_scheduler(model, steps):
@ -294,7 +300,7 @@ def simple_scheduler(model, steps):
def ddim_scheduler(model, steps):
s = model.model_sampling
sigs = []
ss = len(s.sigmas) // steps
ss = max(len(s.sigmas) // steps, 1)
x = 1
while x < len(s.sigmas):
sigs += [float(s.sigmas[x])]
@ -512,14 +518,6 @@ class Sampler:
sigma = float(sigmas[0])
return math.isclose(max_sigma, sigma, rel_tol=1e-05) or sigma > max_sigma
class UNIPC(Sampler):
def sample(self, model_wrap, sigmas, extra_args, callback, noise, latent_image=None, denoise_mask=None, disable_pbar=False):
return uni_pc.sample_unipc(model_wrap, noise, latent_image, sigmas, max_denoise=self.max_denoise(model_wrap, sigmas), extra_args=extra_args, noise_mask=denoise_mask, callback=callback, disable=disable_pbar)
class UNIPCBH2(Sampler):
def sample(self, model_wrap, sigmas, extra_args, callback, noise, latent_image=None, denoise_mask=None, disable_pbar=False):
return uni_pc.sample_unipc(model_wrap, noise, latent_image, sigmas, max_denoise=self.max_denoise(model_wrap, sigmas), extra_args=extra_args, noise_mask=denoise_mask, callback=callback, variant='bh2', disable=disable_pbar)
KSAMPLER_NAMES = ["euler", "euler_ancestral", "heun", "heunpp2","dpm_2", "dpm_2_ancestral",
"lms", "dpm_fast", "dpm_adaptive", "dpmpp_2s_ancestral", "dpmpp_sde", "dpmpp_sde_gpu",
"dpmpp_2m", "dpmpp_2m_sde", "dpmpp_2m_sde_gpu", "dpmpp_3m_sde", "dpmpp_3m_sde_gpu", "ddpm", "lcm"]
@ -532,7 +530,7 @@ class KSAMPLER(Sampler):
def sample(self, model_wrap, sigmas, extra_args, callback, noise, latent_image=None, denoise_mask=None, disable_pbar=False):
extra_args["denoise_mask"] = denoise_mask
model_k = KSamplerX0Inpaint(model_wrap)
model_k = KSamplerX0Inpaint(model_wrap, sigmas)
model_k.latent_image = latent_image
if self.inpaint_options.get("random", False): #TODO: Should this be the default?
generator = torch.manual_seed(extra_args.get("seed", 41) + 1)
@ -540,20 +538,15 @@ class KSAMPLER(Sampler):
else:
model_k.noise = noise
if self.max_denoise(model_wrap, sigmas):
noise = noise * torch.sqrt(1.0 + sigmas[0] ** 2.0)
else:
noise = noise * sigmas[0]
noise = model_wrap.inner_model.model_sampling.noise_scaling(sigmas[0], noise, latent_image, self.max_denoise(model_wrap, sigmas))
k_callback = None
total_steps = len(sigmas) - 1
if callback is not None:
k_callback = lambda x: callback(x["i"], x["denoised"], x["x"], total_steps)
if latent_image is not None:
noise += latent_image
samples = self.sampler_function(model_k, noise, sigmas, extra_args=extra_args, callback=k_callback, disable=disable_pbar, **self.extra_options)
samples = model_wrap.inner_model.model_sampling.inverse_noise_scaling(sigmas[-1], samples)
return samples
@ -567,11 +560,11 @@ def ksampler(sampler_name, extra_options={}, inpaint_options={}):
return k_diffusion_sampling.sample_dpm_fast(model, noise, sigma_min, sigmas[0], total_steps, extra_args=extra_args, callback=callback, disable=disable)
sampler_function = dpm_fast_function
elif sampler_name == "dpm_adaptive":
def dpm_adaptive_function(model, noise, sigmas, extra_args, callback, disable):
def dpm_adaptive_function(model, noise, sigmas, extra_args, callback, disable, **extra_options):
sigma_min = sigmas[-1]
if sigma_min == 0:
sigma_min = sigmas[-2]
return k_diffusion_sampling.sample_dpm_adaptive(model, noise, sigma_min, sigmas[0], extra_args=extra_args, callback=callback, disable=disable)
return k_diffusion_sampling.sample_dpm_adaptive(model, noise, sigma_min, sigmas[0], extra_args=extra_args, callback=callback, disable=disable, **extra_options)
sampler_function = dpm_adaptive_function
else:
sampler_function = getattr(k_diffusion_sampling, "sample_{}".format(sampler_name))
@ -594,6 +587,13 @@ def sample(model, noise, positive, negative, cfg, device, sampler, sigmas, model
calculate_start_end_timesteps(model, negative)
calculate_start_end_timesteps(model, positive)
if latent_image is not None and torch.count_nonzero(latent_image) > 0: #Don't shift the empty latent image.
latent_image = model.process_latent_in(latent_image)
if hasattr(model, 'extra_conds'):
positive = encode_model_conds(model.extra_conds, positive, noise, device, "positive", latent_image=latent_image, denoise_mask=denoise_mask, seed=seed)
negative = encode_model_conds(model.extra_conds, negative, noise, device, "negative", latent_image=latent_image, denoise_mask=denoise_mask, seed=seed)
#make sure each cond area has an opposite one with the same area
for c in positive:
create_cond_with_same_area_if_none(negative, c)
@ -605,13 +605,6 @@ def sample(model, noise, positive, negative, cfg, device, sampler, sigmas, model
apply_empty_x_to_equal_area(list(filter(lambda c: c.get('control_apply_to_uncond', False) == True, positive)), negative, 'control', lambda cond_cnets, x: cond_cnets[x])
apply_empty_x_to_equal_area(positive, negative, 'gligen', lambda cond_cnets, x: cond_cnets[x])
if latent_image is not None:
latent_image = model.process_latent_in(latent_image)
if hasattr(model, 'extra_conds'):
positive = encode_model_conds(model.extra_conds, positive, noise, device, "positive", latent_image=latent_image, denoise_mask=denoise_mask)
negative = encode_model_conds(model.extra_conds, negative, noise, device, "negative", latent_image=latent_image, denoise_mask=denoise_mask)
extra_args = {"cond":positive, "uncond":negative, "cond_scale": cfg, "model_options": model_options, "seed":seed}
samples = sampler.sample(model_wrap, sigmas, extra_args, callback, noise, latent_image, denoise_mask, disable_pbar)
@ -634,14 +627,14 @@ def calculate_sigmas_scheduler(model, scheduler_name, steps):
elif scheduler_name == "sgm_uniform":
sigmas = normal_scheduler(model, steps, sgm=True)
else:
print("error invalid scheduler", self.scheduler)
logging.error("error invalid scheduler {}".format(scheduler_name))
return sigmas
def sampler_object(name):
if name == "uni_pc":
sampler = UNIPC()
sampler = KSAMPLER(uni_pc.sample_unipc)
elif name == "uni_pc_bh2":
sampler = UNIPCBH2()
sampler = KSAMPLER(uni_pc.sample_unipc_bh2)
elif name == "ddim":
sampler = ksampler("euler", inpaint_options={"random": True})
else:
@ -651,6 +644,7 @@ def sampler_object(name):
class KSampler:
SCHEDULERS = SCHEDULER_NAMES
SAMPLERS = SAMPLER_NAMES
DISCARD_PENULTIMATE_SIGMA_SAMPLERS = set(('dpm_2', 'dpm_2_ancestral', 'uni_pc', 'uni_pc_bh2'))
def __init__(self, model, steps, device, sampler=None, scheduler=None, denoise=None, model_options={}):
self.model = model
@ -669,7 +663,7 @@ class KSampler:
sigmas = None
discard_penultimate_sigma = False
if self.sampler in ['dpm_2', 'dpm_2_ancestral', 'uni_pc', 'uni_pc_bh2']:
if self.sampler in self.DISCARD_PENULTIMATE_SIGMA_SAMPLERS:
steps += 1
discard_penultimate_sigma = True

215
comfy/sd.py

@ -1,10 +1,12 @@
import torch
import contextlib
import math
from enum import Enum
import logging
from comfy import model_management
from .ldm.util import instantiate_from_config
from .ldm.models.autoencoder import AutoencoderKL, AutoencodingEngine
from .ldm.cascade.stage_a import StageA
from .ldm.cascade.stage_c_coder import StageC_coder
import yaml
import comfy.utils
@ -36,7 +38,7 @@ def load_model_weights(model, sd):
w = sd.pop(x)
del w
if len(m) > 0:
print("missing", m)
logging.warning("missing {}".format(m))
return model
def load_clip_weights(model, sd):
@ -51,7 +53,7 @@ def load_clip_weights(model, sd):
if ids.dtype == torch.float32:
sd['cond_stage_model.transformer.text_model.embeddings.position_ids'] = ids.round()
sd = comfy.utils.transformers_convert(sd, "cond_stage_model.model.", "cond_stage_model.transformer.text_model.", 24)
sd = comfy.utils.clip_text_transformers_convert(sd, "cond_stage_model.model.", "cond_stage_model.transformer.")
return load_model_weights(model, sd)
@ -80,7 +82,7 @@ def load_lora_for_models(model, clip, lora, strength_model, strength_clip):
k1 = set(k1)
for x in loaded:
if (x not in k) and (x not in k1):
print("NOT LOADED", x)
logging.warning("NOT LOADED {}".format(x))
return (new_modelpatcher, new_clip)
@ -122,10 +124,13 @@ class CLIP:
return self.tokenizer.tokenize_with_weights(text, return_word_ids)
def encode_from_tokens(self, tokens, return_pooled=False):
self.cond_stage_model.reset_clip_options()
if self.layer_idx is not None:
self.cond_stage_model.clip_layer(self.layer_idx)
else:
self.cond_stage_model.reset_clip_layer()
self.cond_stage_model.set_clip_options({"layer": self.layer_idx})
if return_pooled == "unprojected":
self.cond_stage_model.set_clip_options({"projected_pooled": False})
self.load_model()
cond, pooled = self.cond_stage_model.encode_token_weights(tokens)
@ -137,8 +142,11 @@ class CLIP:
tokens = self.tokenize(text)
return self.encode_from_tokens(tokens)
def load_sd(self, sd):
return self.cond_stage_model.load_sd(sd)
def load_sd(self, sd, full_model=False):
if full_model:
return self.cond_stage_model.load_state_dict(sd, strict=False)
else:
return self.cond_stage_model.load_sd(sd)
def get_sd(self):
return self.cond_stage_model.state_dict()
@ -151,12 +159,17 @@ class CLIP:
return self.patcher.get_key_patches()
class VAE:
def __init__(self, sd=None, device=None, config=None):
def __init__(self, sd=None, device=None, config=None, dtype=None):
if 'decoder.up_blocks.0.resnets.0.norm1.weight' in sd.keys(): #diffusers format
sd = diffusers_convert.convert_vae_state_dict(sd)
self.memory_used_encode = lambda shape, dtype: (1767 * shape[2] * shape[3]) * model_management.dtype_size(dtype) #These are for AutoencoderKL and need tweaking (should be lower)
self.memory_used_decode = lambda shape, dtype: (2178 * shape[2] * shape[3] * 64) * model_management.dtype_size(dtype)
self.downscale_ratio = 8
self.upscale_ratio = 8
self.latent_channels = 4
self.process_input = lambda image: image * 2.0 - 1.0
self.process_output = lambda image: torch.clamp((image + 1.0) / 2.0, min=0.0, max=1.0)
if config is None:
if "decoder.mid.block_1.mix_factor" in sd:
@ -169,9 +182,43 @@ class VAE:
decoder_config={'target': "comfy.ldm.modules.temporal_ae.VideoDecoder", 'params': decoder_config})
elif "taesd_decoder.1.weight" in sd:
self.first_stage_model = comfy.taesd.taesd.TAESD()
elif "vquantizer.codebook.weight" in sd: #VQGan: stage a of stable cascade
self.first_stage_model = StageA()
self.downscale_ratio = 4
self.upscale_ratio = 4
#TODO
#self.memory_used_encode
#self.memory_used_decode
self.process_input = lambda image: image
self.process_output = lambda image: image
elif "backbone.1.0.block.0.1.num_batches_tracked" in sd: #effnet: encoder for stage c latent of stable cascade
self.first_stage_model = StageC_coder()
self.downscale_ratio = 32
self.latent_channels = 16
new_sd = {}
for k in sd:
new_sd["encoder.{}".format(k)] = sd[k]
sd = new_sd
elif "blocks.11.num_batches_tracked" in sd: #previewer: decoder for stage c latent of stable cascade
self.first_stage_model = StageC_coder()
self.latent_channels = 16
new_sd = {}
for k in sd:
new_sd["previewer.{}".format(k)] = sd[k]
sd = new_sd
elif "encoder.backbone.1.0.block.0.1.num_batches_tracked" in sd: #combined effnet and previewer for stable cascade
self.first_stage_model = StageC_coder()
self.downscale_ratio = 32
self.latent_channels = 16
else:
#default SD1.x/SD2.x VAE parameters
ddconfig = {'double_z': True, 'z_channels': 4, 'resolution': 256, 'in_channels': 3, 'out_ch': 3, 'ch': 128, 'ch_mult': [1, 2, 4, 4], 'num_res_blocks': 2, 'attn_resolutions': [], 'dropout': 0.0}
if 'encoder.down.2.downsample.conv.weight' not in sd: #Stable diffusion x4 upscaler VAE
ddconfig['ch_mult'] = [1, 2, 4]
self.downscale_ratio = 4
self.upscale_ratio = 4
self.first_stage_model = AutoencoderKL(ddconfig=ddconfig, embed_dim=4)
else:
self.first_stage_model = AutoencoderKL(**(config['params']))
@ -179,32 +226,44 @@ class VAE:
m, u = self.first_stage_model.load_state_dict(sd, strict=False)
if len(m) > 0:
print("Missing VAE keys", m)
logging.warning("Missing VAE keys {}".format(m))
if len(u) > 0:
print("Leftover VAE keys", u)
logging.debug("Leftover VAE keys {}".format(u))
if device is None:
device = model_management.vae_device()
self.device = device
offload_device = model_management.vae_offload_device()
self.vae_dtype = model_management.vae_dtype()
if dtype is None:
dtype = model_management.vae_dtype()
self.vae_dtype = dtype
self.first_stage_model.to(self.vae_dtype)
self.output_device = model_management.intermediate_device()
self.patcher = comfy.model_patcher.ModelPatcher(self.first_stage_model, load_device=self.device, offload_device=offload_device)
def vae_encode_crop_pixels(self, pixels):
x = (pixels.shape[1] // self.downscale_ratio) * self.downscale_ratio
y = (pixels.shape[2] // self.downscale_ratio) * self.downscale_ratio
if pixels.shape[1] != x or pixels.shape[2] != y:
x_offset = (pixels.shape[1] % self.downscale_ratio) // 2
y_offset = (pixels.shape[2] % self.downscale_ratio) // 2
pixels = pixels[:, x_offset:x + x_offset, y_offset:y + y_offset, :]
return pixels
def decode_tiled_(self, samples, tile_x=64, tile_y=64, overlap = 16):
steps = samples.shape[0] * comfy.utils.get_tiled_scale_steps(samples.shape[3], samples.shape[2], tile_x, tile_y, overlap)
steps += samples.shape[0] * comfy.utils.get_tiled_scale_steps(samples.shape[3], samples.shape[2], tile_x // 2, tile_y * 2, overlap)
steps += samples.shape[0] * comfy.utils.get_tiled_scale_steps(samples.shape[3], samples.shape[2], tile_x * 2, tile_y // 2, overlap)
pbar = comfy.utils.ProgressBar(steps)
decode_fn = lambda a: (self.first_stage_model.decode(a.to(self.vae_dtype).to(self.device)) + 1.0).float()
output = torch.clamp((
(comfy.utils.tiled_scale(samples, decode_fn, tile_x // 2, tile_y * 2, overlap, upscale_amount = 8, pbar = pbar) +
comfy.utils.tiled_scale(samples, decode_fn, tile_x * 2, tile_y // 2, overlap, upscale_amount = 8, pbar = pbar) +
comfy.utils.tiled_scale(samples, decode_fn, tile_x, tile_y, overlap, upscale_amount = 8, pbar = pbar))
/ 3.0) / 2.0, min=0.0, max=1.0)
decode_fn = lambda a: self.first_stage_model.decode(a.to(self.vae_dtype).to(self.device)).float()
output = self.process_output(
(comfy.utils.tiled_scale(samples, decode_fn, tile_x // 2, tile_y * 2, overlap, upscale_amount = self.upscale_ratio, output_device=self.output_device, pbar = pbar) +
comfy.utils.tiled_scale(samples, decode_fn, tile_x * 2, tile_y // 2, overlap, upscale_amount = self.upscale_ratio, output_device=self.output_device, pbar = pbar) +
comfy.utils.tiled_scale(samples, decode_fn, tile_x, tile_y, overlap, upscale_amount = self.upscale_ratio, output_device=self.output_device, pbar = pbar))
/ 3.0)
return output
def encode_tiled_(self, pixel_samples, tile_x=512, tile_y=512, overlap = 64):
@ -213,10 +272,10 @@ class VAE:
steps += pixel_samples.shape[0] * comfy.utils.get_tiled_scale_steps(pixel_samples.shape[3], pixel_samples.shape[2], tile_x * 2, tile_y // 2, overlap)
pbar = comfy.utils.ProgressBar(steps)
encode_fn = lambda a: self.first_stage_model.encode((2. * a - 1.).to(self.vae_dtype).to(self.device)).float()
samples = comfy.utils.tiled_scale(pixel_samples, encode_fn, tile_x, tile_y, overlap, upscale_amount = (1/8), out_channels=4, pbar=pbar)
samples += comfy.utils.tiled_scale(pixel_samples, encode_fn, tile_x * 2, tile_y // 2, overlap, upscale_amount = (1/8), out_channels=4, pbar=pbar)
samples += comfy.utils.tiled_scale(pixel_samples, encode_fn, tile_x // 2, tile_y * 2, overlap, upscale_amount = (1/8), out_channels=4, pbar=pbar)
encode_fn = lambda a: self.first_stage_model.encode((self.process_input(a)).to(self.vae_dtype).to(self.device)).float()
samples = comfy.utils.tiled_scale(pixel_samples, encode_fn, tile_x, tile_y, overlap, upscale_amount = (1/self.downscale_ratio), out_channels=self.latent_channels, output_device=self.output_device, pbar=pbar)
samples += comfy.utils.tiled_scale(pixel_samples, encode_fn, tile_x * 2, tile_y // 2, overlap, upscale_amount = (1/self.downscale_ratio), out_channels=self.latent_channels, output_device=self.output_device, pbar=pbar)
samples += comfy.utils.tiled_scale(pixel_samples, encode_fn, tile_x // 2, tile_y * 2, overlap, upscale_amount = (1/self.downscale_ratio), out_channels=self.latent_channels, output_device=self.output_device, pbar=pbar)
samples /= 3.0
return samples
@ -228,15 +287,15 @@ class VAE:
batch_number = int(free_memory / memory_used)
batch_number = max(1, batch_number)
pixel_samples = torch.empty((samples_in.shape[0], 3, round(samples_in.shape[2] * 8), round(samples_in.shape[3] * 8)), device="cpu")
pixel_samples = torch.empty((samples_in.shape[0], 3, round(samples_in.shape[2] * self.upscale_ratio), round(samples_in.shape[3] * self.upscale_ratio)), device=self.output_device)
for x in range(0, samples_in.shape[0], batch_number):
samples = samples_in[x:x+batch_number].to(self.vae_dtype).to(self.device)
pixel_samples[x:x+batch_number] = torch.clamp((self.first_stage_model.decode(samples).cpu().float() + 1.0) / 2.0, min=0.0, max=1.0)
pixel_samples[x:x+batch_number] = self.process_output(self.first_stage_model.decode(samples).to(self.output_device).float())
except model_management.OOM_EXCEPTION as e:
print("Warning: Ran out of memory when regular VAE decoding, retrying with tiled VAE decoding.")
logging.warning("Warning: Ran out of memory when regular VAE decoding, retrying with tiled VAE decoding.")
pixel_samples = self.decode_tiled_(samples_in)
pixel_samples = pixel_samples.cpu().movedim(1,-1)
pixel_samples = pixel_samples.to(self.output_device).movedim(1,-1)
return pixel_samples
def decode_tiled(self, samples, tile_x=64, tile_y=64, overlap = 16):
@ -245,6 +304,7 @@ class VAE:
return output.movedim(1,-1)
def encode(self, pixel_samples):
pixel_samples = self.vae_encode_crop_pixels(pixel_samples)
pixel_samples = pixel_samples.movedim(-1,1)
try:
memory_used = self.memory_used_encode(pixel_samples.shape, self.vae_dtype)
@ -252,18 +312,19 @@ class VAE:
free_memory = model_management.get_free_memory(self.device)
batch_number = int(free_memory / memory_used)
batch_number = max(1, batch_number)
samples = torch.empty((pixel_samples.shape[0], 4, round(pixel_samples.shape[2] // 8), round(pixel_samples.shape[3] // 8)), device="cpu")
samples = torch.empty((pixel_samples.shape[0], self.latent_channels, round(pixel_samples.shape[2] // self.downscale_ratio), round(pixel_samples.shape[3] // self.downscale_ratio)), device=self.output_device)
for x in range(0, pixel_samples.shape[0], batch_number):
pixels_in = (2. * pixel_samples[x:x+batch_number] - 1.).to(self.vae_dtype).to(self.device)
samples[x:x+batch_number] = self.first_stage_model.encode(pixels_in).cpu().float()
pixels_in = self.process_input(pixel_samples[x:x+batch_number]).to(self.vae_dtype).to(self.device)
samples[x:x+batch_number] = self.first_stage_model.encode(pixels_in).to(self.output_device).float()
except model_management.OOM_EXCEPTION as e:
print("Warning: Ran out of memory when regular VAE encoding, retrying with tiled VAE encoding.")
logging.warning("Warning: Ran out of memory when regular VAE encoding, retrying with tiled VAE encoding.")
samples = self.encode_tiled_(pixel_samples)
return samples
def encode_tiled(self, pixel_samples, tile_x=512, tile_y=512, overlap = 64):
pixel_samples = self.vae_encode_crop_pixels(pixel_samples)
model_management.load_model_gpu(self.patcher)
pixel_samples = pixel_samples.movedim(-1,1)
samples = self.encode_tiled_(pixel_samples, tile_x=tile_x, tile_y=tile_y, overlap=overlap)
@ -290,8 +351,11 @@ def load_style_model(ckpt_path):
model.load_state_dict(model_data)
return StyleModel(model)
class CLIPType(Enum):
STABLE_DIFFUSION = 1
STABLE_CASCADE = 2
def load_clip(ckpt_paths, embedding_directory=None):
def load_clip(ckpt_paths, embedding_directory=None, clip_type=CLIPType.STABLE_DIFFUSION):
clip_data = []
for p in ckpt_paths:
clip_data.append(comfy.utils.load_torch_file(p, safe_load=True))
@ -301,14 +365,21 @@ def load_clip(ckpt_paths, embedding_directory=None):
for i in range(len(clip_data)):
if "transformer.resblocks.0.ln_1.weight" in clip_data[i]:
clip_data[i] = comfy.utils.transformers_convert(clip_data[i], "", "text_model.", 32)
clip_data[i] = comfy.utils.clip_text_transformers_convert(clip_data[i], "", "")
else:
if "text_projection" in clip_data[i]:
clip_data[i]["text_projection.weight"] = clip_data[i]["text_projection"].transpose(0, 1) #old models saved with the CLIPSave node
clip_target = EmptyClass()
clip_target.params = {}
if len(clip_data) == 1:
if "text_model.encoder.layers.30.mlp.fc1.weight" in clip_data[0]:
clip_target.clip = sdxl_clip.SDXLRefinerClipModel
clip_target.tokenizer = sdxl_clip.SDXLTokenizer
if clip_type == CLIPType.STABLE_CASCADE:
clip_target.clip = sdxl_clip.StableCascadeClipModel
clip_target.tokenizer = sdxl_clip.StableCascadeTokenizer
else:
clip_target.clip = sdxl_clip.SDXLRefinerClipModel
clip_target.tokenizer = sdxl_clip.SDXLTokenizer
elif "text_model.encoder.layers.22.mlp.fc1.weight" in clip_data[0]:
clip_target.clip = sd2_clip.SD2ClipModel
clip_target.tokenizer = sd2_clip.SD2Tokenizer
@ -323,10 +394,10 @@ def load_clip(ckpt_paths, embedding_directory=None):
for c in clip_data:
m, u = clip.load_sd(c)
if len(m) > 0:
print("clip missing:", m)
logging.warning("clip missing: {}".format(m))
if len(u) > 0:
print("clip unexpected:", u)
logging.debug("clip unexpected: {}".format(u))
return clip
def load_gligen(ckpt_path):
@ -431,12 +502,13 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
clip_target = None
parameters = comfy.utils.calculate_parameters(sd, "model.diffusion_model.")
unet_dtype = model_management.unet_dtype(model_params=parameters)
load_device = model_management.get_torch_device()
class WeightsLoader(torch.nn.Module):
pass
model_config = model_detection.model_config_from_unet(sd, "model.diffusion_model.")
unet_dtype = model_management.unet_dtype(model_params=parameters, supported_dtypes=model_config.supported_inference_dtypes)
manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device, model_config.supported_inference_dtypes)
model_config.set_inference_dtype(unet_dtype, manual_cast_dtype)
model_config = model_detection.model_config_from_unet(sd, "model.diffusion_model.", unet_dtype)
if model_config is None:
raise RuntimeError("ERROR: Could not detect model type of: {}".format(ckpt_path))
@ -451,27 +523,33 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
model.load_model_weights(sd, "model.diffusion_model.")
if output_vae:
vae_sd = comfy.utils.state_dict_prefix_replace(sd, {"first_stage_model.": ""}, filter_keys=True)
vae_sd = comfy.utils.state_dict_prefix_replace(sd, {k: "" for k in model_config.vae_key_prefix}, filter_keys=True)
vae_sd = model_config.process_vae_state_dict(vae_sd)
vae = VAE(sd=vae_sd)
if output_clip:
w = WeightsLoader()
clip_target = model_config.clip_target()
if clip_target is not None:
clip = CLIP(clip_target, embedding_directory=embedding_directory)
w.cond_stage_model = clip.cond_stage_model
sd = model_config.process_clip_state_dict(sd)
load_model_weights(w, sd)
clip_sd = model_config.process_clip_state_dict(sd)
if len(clip_sd) > 0:
clip = CLIP(clip_target, embedding_directory=embedding_directory)
m, u = clip.load_sd(clip_sd, full_model=True)
if len(m) > 0:
logging.warning("clip missing: {}".format(m))
if len(u) > 0:
logging.debug("clip unexpected {}:".format(u))
else:
logging.warning("no CLIP/text encoder weights in checkpoint, the text encoder model will not be loaded.")
left_over = sd.keys()
if len(left_over) > 0:
print("left over keys:", left_over)
logging.debug("left over keys: {}".format(left_over))
if output_model:
model_patcher = comfy.model_patcher.ModelPatcher(model, load_device=model_management.get_torch_device(), offload_device=model_management.unet_offload_device(), current_device=inital_load_device)
model_patcher = comfy.model_patcher.ModelPatcher(model, load_device=load_device, offload_device=model_management.unet_offload_device(), current_device=inital_load_device)
if inital_load_device != torch.device("cpu"):
print("loaded straight to GPU")
logging.info("loaded straight to GPU")
model_management.load_model_gpu(model_patcher)
return (model_patcher, clip, vae, clipvision)
@ -480,14 +558,16 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
def load_unet_state_dict(sd): #load unet in diffusers format
parameters = comfy.utils.calculate_parameters(sd)
unet_dtype = model_management.unet_dtype(model_params=parameters)
if "input_blocks.0.0.weight" in sd: #ldm
model_config = model_detection.model_config_from_unet(sd, "", unet_dtype)
load_device = model_management.get_torch_device()
if "input_blocks.0.0.weight" in sd or 'clf.1.weight' in sd: #ldm or stable cascade
model_config = model_detection.model_config_from_unet(sd, "")
if model_config is None:
return None
new_sd = sd
else: #diffusers
model_config = model_detection.model_config_from_diffusers_unet(sd, unet_dtype)
model_config = model_detection.model_config_from_diffusers_unet(sd)
if model_config is None:
return None
@ -498,25 +578,36 @@ def load_unet_state_dict(sd): #load unet in diffusers format
if k in sd:
new_sd[diffusers_keys[k]] = sd.pop(k)
else:
print(diffusers_keys[k], k)
logging.warning("{} {}".format(diffusers_keys[k], k))
offload_device = model_management.unet_offload_device()
unet_dtype = model_management.unet_dtype(model_params=parameters, supported_dtypes=model_config.supported_inference_dtypes)
manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device, model_config.supported_inference_dtypes)
model_config.set_inference_dtype(unet_dtype, manual_cast_dtype)
model = model_config.get_model(new_sd, "")
model = model.to(offload_device)
model.load_model_weights(new_sd, "")
left_over = sd.keys()
if len(left_over) > 0:
print("left over keys in unet:", left_over)
return comfy.model_patcher.ModelPatcher(model, load_device=model_management.get_torch_device(), offload_device=offload_device)
logging.info("left over keys in unet: {}".format(left_over))
return comfy.model_patcher.ModelPatcher(model, load_device=load_device, offload_device=offload_device)
def load_unet(unet_path):
sd = comfy.utils.load_torch_file(unet_path)
model = load_unet_state_dict(sd)
if model is None:
print("ERROR UNSUPPORTED UNET", unet_path)
logging.error("ERROR UNSUPPORTED UNET {}".format(unet_path))
raise RuntimeError("ERROR: Could not detect model type of: {}".format(unet_path))
return model
def save_checkpoint(output_path, model, clip, vae, metadata=None):
model_management.load_models_gpu([model, clip.load_model()])
sd = model.model.state_dict_for_saving(clip.get_sd(), vae.get_sd())
def save_checkpoint(output_path, model, clip=None, vae=None, clip_vision=None, metadata=None):
clip_sd = None
load_models = [model]
if clip is not None:
load_models.append(clip.load_model())
clip_sd = clip.get_sd()
model_management.load_models_gpu(load_models)
clip_vision_sd = clip_vision.get_sd() if clip_vision is not None else None
sd = model.model.state_dict_for_saving(clip_sd, vae.get_sd(), clip_vision_sd)
comfy.utils.save_torch_file(sd, output_path, metadata=metadata)

150
comfy/sd1_clip.py

@ -1,12 +1,14 @@
import os
from transformers import CLIPTokenizer, CLIPTextModel, CLIPTextConfig, modeling_utils
from transformers import CLIPTokenizer
import comfy.ops
import torch
import traceback
import zipfile
from . import model_management
import contextlib
import comfy.clip_model
import json
import logging
def gen_empty_tokens(special_tokens, length):
start_token = special_tokens.get("start", None)
@ -37,7 +39,7 @@ class ClipTokenWeightEncoder:
out, pooled = self.encode(to_encode)
if pooled is not None:
first_pooled = pooled[0:1].cpu()
first_pooled = pooled[0:1].to(model_management.intermediate_device())
else:
first_pooled = pooled
@ -54,8 +56,8 @@ class ClipTokenWeightEncoder:
output.append(z)
if (len(output) == 0):
return out[-1:].cpu(), first_pooled
return torch.cat(output, dim=-2).cpu(), first_pooled
return out[-1:].to(model_management.intermediate_device()), first_pooled
return torch.cat(output, dim=-2).to(model_management.intermediate_device()), first_pooled
class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
"""Uses the CLIP transformer encoder for text (from huggingface)"""
@ -65,31 +67,19 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
"hidden"
]
def __init__(self, version="openai/clip-vit-large-patch14", device="cpu", max_length=77,
freeze=True, layer="last", layer_idx=None, textmodel_json_config=None, textmodel_path=None, dtype=None,
special_tokens={"start": 49406, "end": 49407, "pad": 49407},layer_norm_hidden_state=True, config_class=CLIPTextConfig,
model_class=CLIPTextModel, inner_name="text_model"): # clip-vit-base-patch32
freeze=True, layer="last", layer_idx=None, textmodel_json_config=None, dtype=None, model_class=comfy.clip_model.CLIPTextModel,
special_tokens={"start": 49406, "end": 49407, "pad": 49407}, layer_norm_hidden_state=True, enable_attention_masks=False, return_projected_pooled=True): # clip-vit-base-patch32
super().__init__()
assert layer in self.LAYERS
self.num_layers = 12
if textmodel_path is not None:
self.transformer = model_class.from_pretrained(textmodel_path)
else:
if textmodel_json_config is None:
textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "sd1_clip_config.json")
config = config_class.from_json_file(textmodel_json_config)
self.num_layers = config.num_hidden_layers
with comfy.ops.use_comfy_ops(device, dtype):
with modeling_utils.no_init_weights():
self.transformer = model_class(config)
self.inner_name = inner_name
if dtype is not None:
self.transformer.to(dtype)
inner_model = getattr(self.transformer, self.inner_name)
if hasattr(inner_model, "embeddings"):
inner_model.embeddings.to(torch.float32)
else:
self.transformer.set_input_embeddings(self.transformer.get_input_embeddings().to(torch.float32))
if textmodel_json_config is None:
textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "sd1_clip_config.json")
with open(textmodel_json_config) as f:
config = json.load(f)
self.transformer = model_class(config, dtype, device, comfy.ops.manual_cast)
self.num_layers = self.transformer.num_layers
self.max_length = max_length
if freeze:
@ -97,16 +87,18 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
self.layer = layer
self.layer_idx = None
self.special_tokens = special_tokens
self.text_projection = torch.nn.Parameter(torch.eye(self.transformer.get_input_embeddings().weight.shape[1]))
self.logit_scale = torch.nn.Parameter(torch.tensor(4.6055))
self.enable_attention_masks = False
self.enable_attention_masks = enable_attention_masks
self.layer_norm_hidden_state = layer_norm_hidden_state
self.return_projected_pooled = return_projected_pooled
if layer == "hidden":
assert layer_idx is not None
assert abs(layer_idx) <= self.num_layers
self.clip_layer(layer_idx)
self.layer_default = (self.layer, self.layer_idx)
assert abs(layer_idx) < self.num_layers
self.set_clip_options({"layer": layer_idx})
self.options_default = (self.layer, self.layer_idx, self.return_projected_pooled)
def freeze(self):
self.transformer = self.transformer.eval()
@ -114,16 +106,19 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
for param in self.parameters():
param.requires_grad = False
def clip_layer(self, layer_idx):
if abs(layer_idx) >= self.num_layers:
def set_clip_options(self, options):
layer_idx = options.get("layer", self.layer_idx)
self.return_projected_pooled = options.get("projected_pooled", self.return_projected_pooled)
if layer_idx is None or abs(layer_idx) > self.num_layers:
self.layer = "last"
else:
self.layer = "hidden"
self.layer_idx = layer_idx
def reset_clip_layer(self):
self.layer = self.layer_default[0]
self.layer_idx = self.layer_default[1]
def reset_clip_options(self):
self.layer = self.options_default[0]
self.layer_idx = self.options_default[1]
self.return_projected_pooled = self.options_default[2]
def set_up_textual_embeddings(self, tokens, current_embeds):
out_tokens = []
@ -143,7 +138,7 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
tokens_temp += [next_new_token]
next_new_token += 1
else:
print("WARNING: shape mismatch when trying to apply embedding, embedding will be ignored", y.shape[0], current_embeds.weight.shape[1])
logging.warning("WARNING: shape mismatch when trying to apply embedding, embedding will be ignored {} != {}".format(y.shape[0], current_embeds.weight.shape[1]))
while len(tokens_temp) < len(x):
tokens_temp += [self.special_tokens["pad"]]
out_tokens += [tokens_temp]
@ -170,51 +165,37 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
tokens = self.set_up_textual_embeddings(tokens, backup_embeds)
tokens = torch.LongTensor(tokens).to(device)
if getattr(self.transformer, self.inner_name).final_layer_norm.weight.dtype != torch.float32:
precision_scope = torch.autocast
attention_mask = None
if self.enable_attention_masks:
attention_mask = torch.zeros_like(tokens)
max_token = self.transformer.get_input_embeddings().weight.shape[0] - 1
for x in range(attention_mask.shape[0]):
for y in range(attention_mask.shape[1]):
attention_mask[x, y] = 1
if tokens[x, y] == max_token:
break
outputs = self.transformer(tokens, attention_mask, intermediate_output=self.layer_idx, final_layer_norm_intermediate=self.layer_norm_hidden_state)
self.transformer.set_input_embeddings(backup_embeds)
if self.layer == "last":
z = outputs[0]
else:
precision_scope = lambda a, dtype: contextlib.nullcontext(a)
with precision_scope(model_management.get_autocast_device(device), dtype=torch.float32):
attention_mask = None
if self.enable_attention_masks:
attention_mask = torch.zeros_like(tokens)
max_token = self.transformer.get_input_embeddings().weight.shape[0] - 1
for x in range(attention_mask.shape[0]):
for y in range(attention_mask.shape[1]):
attention_mask[x, y] = 1
if tokens[x, y] == max_token:
break
outputs = self.transformer(input_ids=tokens, attention_mask=attention_mask, output_hidden_states=self.layer=="hidden")
self.transformer.set_input_embeddings(backup_embeds)
if self.layer == "last":
z = outputs.last_hidden_state
elif self.layer == "pooled":
z = outputs.pooler_output[:, None, :]
else:
z = outputs.hidden_states[self.layer_idx]
if self.layer_norm_hidden_state:
z = getattr(self.transformer, self.inner_name).final_layer_norm(z)
z = outputs[1]
if hasattr(outputs, "pooler_output"):
pooled_output = outputs.pooler_output.float()
else:
pooled_output = None
pooled_output = None
if len(outputs) >= 3:
if not self.return_projected_pooled and len(outputs) >= 4 and outputs[3] is not None:
pooled_output = outputs[3].float()
elif outputs[2] is not None:
pooled_output = outputs[2].float()
if self.text_projection is not None and pooled_output is not None:
pooled_output = pooled_output.float().to(self.text_projection.device) @ self.text_projection.float()
return z.float(), pooled_output
def encode(self, tokens):
return self(tokens)
def load_sd(self, sd):
if "text_projection" in sd:
self.text_projection[:] = sd.pop("text_projection")
if "text_projection.weight" in sd:
self.text_projection[:] = sd.pop("text_projection.weight").transpose(0, 1)
return self.transformer.load_state_dict(sd, strict=False)
def parse_parentheses(string):
@ -349,9 +330,7 @@ def load_embed(embedding_name, embedding_directory, embedding_size, embed_key=No
else:
embed = torch.load(embed_path, map_location="cpu")
except Exception as e:
print(traceback.format_exc())
print()
print("error loading embedding, skipping loading:", embedding_name)
logging.warning("{}\n\nerror loading embedding, skipping loading: {}".format(traceback.format_exc(), embedding_name))
return None
if embed_out is None:
@ -375,11 +354,12 @@ def load_embed(embedding_name, embedding_directory, embedding_size, embed_key=No
return embed_out
class SDTokenizer:
def __init__(self, tokenizer_path=None, max_length=77, pad_with_end=True, embedding_directory=None, embedding_size=768, embedding_key='clip_l', tokenizer_class=CLIPTokenizer, has_start_token=True, pad_to_max_length=True):
def __init__(self, tokenizer_path=None, max_length=77, pad_with_end=True, embedding_directory=None, embedding_size=768, embedding_key='clip_l', tokenizer_class=CLIPTokenizer, has_start_token=True, pad_to_max_length=True, min_length=None):
if tokenizer_path is None:
tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "sd1_tokenizer")
self.tokenizer = tokenizer_class.from_pretrained(tokenizer_path)
self.max_length = max_length
self.min_length = min_length
empty = self.tokenizer('')["input_ids"]
if has_start_token:
@ -441,7 +421,7 @@ class SDTokenizer:
embedding_name = word[len(self.embedding_identifier):].strip('\n')
embed, leftover = self._try_get_embedding(embedding_name)
if embed is None:
print(f"warning, embedding:{embedding_name} does not exist, ignoring")
logging.warning(f"warning, embedding:{embedding_name} does not exist, ignoring")
else:
if len(embed.shape) == 1:
tokens.append([(embed, weight)])
@ -491,6 +471,8 @@ class SDTokenizer:
batch.append((self.end_token, 1.0, 0))
if self.pad_to_max_length:
batch.extend([(pad_token, 1.0, 0)] * (self.max_length - len(batch)))
if self.min_length is not None and len(batch) < self.min_length:
batch.extend([(pad_token, 1.0, 0)] * (self.min_length - len(batch)))
if not return_word_ids:
batched_tokens = [[(t, w) for t, w,_ in x] for x in batched_tokens]
@ -524,11 +506,11 @@ class SD1ClipModel(torch.nn.Module):
self.clip = "clip_{}".format(self.clip_name)
setattr(self, self.clip, clip_model(device=device, dtype=dtype, **kwargs))
def clip_layer(self, layer_idx):
getattr(self, self.clip).clip_layer(layer_idx)
def set_clip_options(self, options):
getattr(self, self.clip).set_clip_options(options)
def reset_clip_layer(self):
getattr(self, self.clip).reset_clip_layer()
def reset_clip_options(self):
getattr(self, self.clip).reset_clip_options()
def encode_token_weights(self, token_weight_pairs):
token_weight_pairs = token_weight_pairs[self.clip_name]

6
comfy/sd2_clip.py

@ -3,13 +3,13 @@ import torch
import os
class SD2ClipHModel(sd1_clip.SDClipModel):
def __init__(self, arch="ViT-H-14", device="cpu", max_length=77, freeze=True, layer="penultimate", layer_idx=None, textmodel_path=None, dtype=None):
def __init__(self, arch="ViT-H-14", device="cpu", max_length=77, freeze=True, layer="penultimate", layer_idx=None, dtype=None):
if layer == "penultimate":
layer="hidden"
layer_idx=23
layer_idx=-2
textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "sd2_clip_config.json")
super().__init__(device=device, freeze=freeze, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, textmodel_path=textmodel_path, dtype=dtype, special_tokens={"start": 49406, "end": 49407, "pad": 0})
super().__init__(device=device, freeze=freeze, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"start": 49406, "end": 49407, "pad": 0})
class SD2ClipHTokenizer(sd1_clip.SDTokenizer):
def __init__(self, tokenizer_path=None, embedding_directory=None):

40
comfy/sdxl_clip.py

@ -3,13 +3,13 @@ import torch
import os
class SDXLClipG(sd1_clip.SDClipModel):
def __init__(self, device="cpu", max_length=77, freeze=True, layer="penultimate", layer_idx=None, textmodel_path=None, dtype=None):
def __init__(self, device="cpu", max_length=77, freeze=True, layer="penultimate", layer_idx=None, dtype=None):
if layer == "penultimate":
layer="hidden"
layer_idx=-2
textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_config_bigg.json")
super().__init__(device=device, freeze=freeze, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, textmodel_path=textmodel_path, dtype=dtype,
super().__init__(device=device, freeze=freeze, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype,
special_tokens={"start": 49406, "end": 49407, "pad": 0}, layer_norm_hidden_state=False)
def load_sd(self, sd):
@ -37,16 +37,16 @@ class SDXLTokenizer:
class SDXLClipModel(torch.nn.Module):
def __init__(self, device="cpu", dtype=None):
super().__init__()
self.clip_l = sd1_clip.SDClipModel(layer="hidden", layer_idx=11, device=device, dtype=dtype, layer_norm_hidden_state=False)
self.clip_l = sd1_clip.SDClipModel(layer="hidden", layer_idx=-2, device=device, dtype=dtype, layer_norm_hidden_state=False)
self.clip_g = SDXLClipG(device=device, dtype=dtype)
def clip_layer(self, layer_idx):
self.clip_l.clip_layer(layer_idx)
self.clip_g.clip_layer(layer_idx)
def set_clip_options(self, options):
self.clip_l.set_clip_options(options)
self.clip_g.set_clip_options(options)
def reset_clip_layer(self):
self.clip_g.reset_clip_layer()
self.clip_l.reset_clip_layer()
def reset_clip_options(self):
self.clip_g.reset_clip_options()
self.clip_l.reset_clip_options()
def encode_token_weights(self, token_weight_pairs):
token_weight_pairs_g = token_weight_pairs["g"]
@ -64,3 +64,25 @@ class SDXLClipModel(torch.nn.Module):
class SDXLRefinerClipModel(sd1_clip.SD1ClipModel):
def __init__(self, device="cpu", dtype=None):
super().__init__(device=device, dtype=dtype, clip_name="g", clip_model=SDXLClipG)
class StableCascadeClipGTokenizer(sd1_clip.SDTokenizer):
def __init__(self, tokenizer_path=None, embedding_directory=None):
super().__init__(tokenizer_path, pad_with_end=True, embedding_directory=embedding_directory, embedding_size=1280, embedding_key='clip_g')
class StableCascadeTokenizer(sd1_clip.SD1Tokenizer):
def __init__(self, embedding_directory=None):
super().__init__(embedding_directory=embedding_directory, clip_name="g", tokenizer=StableCascadeClipGTokenizer)
class StableCascadeClipG(sd1_clip.SDClipModel):
def __init__(self, device="cpu", max_length=77, freeze=True, layer="hidden", layer_idx=-1, dtype=None):
textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_config_bigg.json")
super().__init__(device=device, freeze=freeze, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype,
special_tokens={"start": 49406, "end": 49407, "pad": 49407}, layer_norm_hidden_state=False, enable_attention_masks=True)
def load_sd(self, sd):
return super().load_sd(sd)
class StableCascadeClipModel(sd1_clip.SD1ClipModel):
def __init__(self, device="cpu", dtype=None):
super().__init__(device=device, dtype=dtype, clip_name="g", clip_model=StableCascadeClipG)

238
comfy/supported_models.py

@ -40,11 +40,16 @@ class SD15(supported_models_base.BASE):
state_dict['cond_stage_model.transformer.text_model.embeddings.position_ids'] = ids.round()
replace_prefix = {}
replace_prefix["cond_stage_model."] = "cond_stage_model.clip_l."
state_dict = utils.state_dict_prefix_replace(state_dict, replace_prefix)
replace_prefix["cond_stage_model."] = "clip_l."
state_dict = utils.state_dict_prefix_replace(state_dict, replace_prefix, filter_keys=True)
return state_dict
def process_clip_state_dict_for_saving(self, state_dict):
pop_keys = ["clip_l.transformer.text_projection.weight", "clip_l.logit_scale"]
for p in pop_keys:
if p in state_dict:
state_dict.pop(p)
replace_prefix = {"clip_l.": "cond_stage_model."}
return utils.state_dict_prefix_replace(state_dict, replace_prefix)
@ -72,10 +77,10 @@ class SD20(supported_models_base.BASE):
def process_clip_state_dict(self, state_dict):
replace_prefix = {}
replace_prefix["conditioner.embedders.0.model."] = "cond_stage_model.model." #SD2 in sgm format
state_dict = utils.state_dict_prefix_replace(state_dict, replace_prefix)
state_dict = utils.transformers_convert(state_dict, "cond_stage_model.model.", "cond_stage_model.clip_h.transformer.text_model.", 24)
replace_prefix["conditioner.embedders.0.model."] = "clip_h." #SD2 in sgm format
replace_prefix["cond_stage_model.model."] = "clip_h."
state_dict = utils.state_dict_prefix_replace(state_dict, replace_prefix, filter_keys=True)
state_dict = utils.clip_text_transformers_convert(state_dict, "clip_h.", "clip_h.transformer.")
return state_dict
def process_clip_state_dict_for_saving(self, state_dict):
@ -131,11 +136,10 @@ class SDXLRefiner(supported_models_base.BASE):
def process_clip_state_dict(self, state_dict):
keys_to_replace = {}
replace_prefix = {}
replace_prefix["conditioner.embedders.0.model."] = "clip_g."
state_dict = utils.state_dict_prefix_replace(state_dict, replace_prefix, filter_keys=True)
state_dict = utils.transformers_convert(state_dict, "conditioner.embedders.0.model.", "cond_stage_model.clip_g.transformer.text_model.", 32)
keys_to_replace["conditioner.embedders.0.model.text_projection"] = "cond_stage_model.clip_g.text_projection"
keys_to_replace["conditioner.embedders.0.model.logit_scale"] = "cond_stage_model.clip_g.logit_scale"
state_dict = utils.clip_text_transformers_convert(state_dict, "clip_g.", "clip_g.transformer.")
state_dict = utils.state_dict_key_replace(state_dict, keys_to_replace)
return state_dict
@ -164,7 +168,13 @@ class SDXL(supported_models_base.BASE):
latent_format = latent_formats.SDXL
def model_type(self, state_dict, prefix=""):
if "v_pred" in state_dict:
if 'edm_mean' in state_dict and 'edm_std' in state_dict: #Playground V2.5
self.latent_format = latent_formats.SDXL_Playground_2_5()
self.sampling_settings["sigma_data"] = 0.5
self.sampling_settings["sigma_max"] = 80.0
self.sampling_settings["sigma_min"] = 0.002
return model_base.ModelType.EDM
elif "v_pred" in state_dict:
return model_base.ModelType.V_PREDICTION
else:
return model_base.ModelType.EPS
@ -179,26 +189,28 @@ class SDXL(supported_models_base.BASE):
keys_to_replace = {}
replace_prefix = {}
replace_prefix["conditioner.embedders.0.transformer.text_model"] = "cond_stage_model.clip_l.transformer.text_model"
state_dict = utils.transformers_convert(state_dict, "conditioner.embedders.1.model.", "cond_stage_model.clip_g.transformer.text_model.", 32)
keys_to_replace["conditioner.embedders.1.model.text_projection"] = "cond_stage_model.clip_g.text_projection"
keys_to_replace["conditioner.embedders.1.model.text_projection.weight"] = "cond_stage_model.clip_g.text_projection"
keys_to_replace["conditioner.embedders.1.model.logit_scale"] = "cond_stage_model.clip_g.logit_scale"
replace_prefix["conditioner.embedders.0.transformer.text_model"] = "clip_l.transformer.text_model"
replace_prefix["conditioner.embedders.1.model."] = "clip_g."
state_dict = utils.state_dict_prefix_replace(state_dict, replace_prefix, filter_keys=True)
state_dict = utils.state_dict_prefix_replace(state_dict, replace_prefix)
state_dict = utils.state_dict_key_replace(state_dict, keys_to_replace)
state_dict = utils.clip_text_transformers_convert(state_dict, "clip_g.", "clip_g.transformer.")
return state_dict
def process_clip_state_dict_for_saving(self, state_dict):
replace_prefix = {}
keys_to_replace = {}
state_dict_g = diffusers_convert.convert_text_enc_state_dict_v20(state_dict, "clip_g")
if "clip_g.transformer.text_model.embeddings.position_ids" in state_dict_g:
state_dict_g.pop("clip_g.transformer.text_model.embeddings.position_ids")
for k in state_dict:
if k.startswith("clip_l"):
state_dict_g[k] = state_dict[k]
state_dict_g["clip_l.transformer.text_model.embeddings.position_ids"] = torch.arange(77).expand((1, -1))
pop_keys = ["clip_l.transformer.text_projection.weight", "clip_l.logit_scale"]
for p in pop_keys:
if p in state_dict_g:
state_dict_g.pop(p)
replace_prefix["clip_g"] = "conditioner.embedders.1.model"
replace_prefix["clip_l"] = "conditioner.embedders.0"
state_dict_g = utils.state_dict_prefix_replace(state_dict_g, replace_prefix)
@ -217,6 +229,36 @@ class SSD1B(SDXL):
"use_temporal_attention": False,
}
class Segmind_Vega(SDXL):
unet_config = {
"model_channels": 320,
"use_linear_in_transformer": True,
"transformer_depth": [0, 0, 1, 1, 2, 2],
"context_dim": 2048,
"adm_in_channels": 2816,
"use_temporal_attention": False,
}
class KOALA_700M(SDXL):
unet_config = {
"model_channels": 320,
"use_linear_in_transformer": True,
"transformer_depth": [0, 2, 5],
"context_dim": 2048,
"adm_in_channels": 2816,
"use_temporal_attention": False,
}
class KOALA_1B(SDXL):
unet_config = {
"model_channels": 320,
"use_linear_in_transformer": True,
"transformer_depth": [0, 2, 6],
"context_dim": 2048,
"adm_in_channels": 2816,
"use_temporal_attention": False,
}
class SVD_img2vid(supported_models_base.BASE):
unet_config = {
"model_channels": 320,
@ -242,5 +284,161 @@ class SVD_img2vid(supported_models_base.BASE):
def clip_target(self):
return None
models = [SD15, SD20, SD21UnclipL, SD21UnclipH, SDXLRefiner, SDXL, SSD1B]
class SV3D_u(SVD_img2vid):
unet_config = {
"model_channels": 320,
"in_channels": 8,
"use_linear_in_transformer": True,
"transformer_depth": [1, 1, 1, 1, 1, 1, 0, 0],
"context_dim": 1024,
"adm_in_channels": 256,
"use_temporal_attention": True,
"use_temporal_resblock": True
}
vae_key_prefix = ["conditioner.embedders.1.encoder."]
def get_model(self, state_dict, prefix="", device=None):
out = model_base.SV3D_u(self, device=device)
return out
class SV3D_p(SV3D_u):
unet_config = {
"model_channels": 320,
"in_channels": 8,
"use_linear_in_transformer": True,
"transformer_depth": [1, 1, 1, 1, 1, 1, 0, 0],
"context_dim": 1024,
"adm_in_channels": 1280,
"use_temporal_attention": True,
"use_temporal_resblock": True
}
def get_model(self, state_dict, prefix="", device=None):
out = model_base.SV3D_p(self, device=device)
return out
class Stable_Zero123(supported_models_base.BASE):
unet_config = {
"context_dim": 768,
"model_channels": 320,
"use_linear_in_transformer": False,
"adm_in_channels": None,
"use_temporal_attention": False,
"in_channels": 8,
}
unet_extra_config = {
"num_heads": 8,
"num_head_channels": -1,
}
clip_vision_prefix = "cond_stage_model.model.visual."
latent_format = latent_formats.SD15
def get_model(self, state_dict, prefix="", device=None):
out = model_base.Stable_Zero123(self, device=device, cc_projection_weight=state_dict["cc_projection.weight"], cc_projection_bias=state_dict["cc_projection.bias"])
return out
def clip_target(self):
return None
class SD_X4Upscaler(SD20):
unet_config = {
"context_dim": 1024,
"model_channels": 256,
'in_channels': 7,
"use_linear_in_transformer": True,
"adm_in_channels": None,
"use_temporal_attention": False,
}
unet_extra_config = {
"disable_self_attentions": [True, True, True, False],
"num_classes": 1000,
"num_heads": 8,
"num_head_channels": -1,
}
latent_format = latent_formats.SD_X4
sampling_settings = {
"linear_start": 0.0001,
"linear_end": 0.02,
}
def get_model(self, state_dict, prefix="", device=None):
out = model_base.SD_X4Upscaler(self, device=device)
return out
class Stable_Cascade_C(supported_models_base.BASE):
unet_config = {
"stable_cascade_stage": 'c',
}
unet_extra_config = {}
latent_format = latent_formats.SC_Prior
supported_inference_dtypes = [torch.bfloat16, torch.float32]
sampling_settings = {
"shift": 2.0,
}
vae_key_prefix = ["vae."]
text_encoder_key_prefix = ["text_encoder."]
clip_vision_prefix = "clip_l_vision."
def process_unet_state_dict(self, state_dict):
key_list = list(state_dict.keys())
for y in ["weight", "bias"]:
suffix = "in_proj_{}".format(y)
keys = filter(lambda a: a.endswith(suffix), key_list)
for k_from in keys:
weights = state_dict.pop(k_from)
prefix = k_from[:-(len(suffix) + 1)]
shape_from = weights.shape[0] // 3
for x in range(3):
p = ["to_q", "to_k", "to_v"]
k_to = "{}.{}.{}".format(prefix, p[x], y)
state_dict[k_to] = weights[shape_from*x:shape_from*(x + 1)]
return state_dict
def process_clip_state_dict(self, state_dict):
state_dict = utils.state_dict_prefix_replace(state_dict, {k: "" for k in self.text_encoder_key_prefix}, filter_keys=True)
if "clip_g.text_projection" in state_dict:
state_dict["clip_g.transformer.text_projection.weight"] = state_dict.pop("clip_g.text_projection").transpose(0, 1)
return state_dict
def get_model(self, state_dict, prefix="", device=None):
out = model_base.StableCascade_C(self, device=device)
return out
def clip_target(self):
return supported_models_base.ClipTarget(sdxl_clip.StableCascadeTokenizer, sdxl_clip.StableCascadeClipModel)
class Stable_Cascade_B(Stable_Cascade_C):
unet_config = {
"stable_cascade_stage": 'b',
}
unet_extra_config = {}
latent_format = latent_formats.SC_B
supported_inference_dtypes = [torch.float16, torch.bfloat16, torch.float32]
sampling_settings = {
"shift": 1.0,
}
clip_vision_prefix = None
def get_model(self, state_dict, prefix="", device=None):
out = model_base.StableCascade_B(self, device=device)
return out
models = [Stable_Zero123, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p]
models += [SVD_img2vid]

21
comfy/supported_models_base.py

@ -21,11 +21,16 @@ class BASE:
noise_aug_config = None
sampling_settings = {}
latent_format = latent_formats.LatentFormat
vae_key_prefix = ["first_stage_model."]
text_encoder_key_prefix = ["cond_stage_model."]
supported_inference_dtypes = [torch.float16, torch.bfloat16, torch.float32]
manual_cast_dtype = None
@classmethod
def matches(s, unet_config):
for k in s.unet_config:
if s.unet_config[k] != unet_config[k]:
if k not in unet_config or s.unet_config[k] != unet_config[k]:
return False
return True
@ -51,6 +56,7 @@ class BASE:
return out
def process_clip_state_dict(self, state_dict):
state_dict = utils.state_dict_prefix_replace(state_dict, {k: "" for k in self.text_encoder_key_prefix}, filter_keys=True)
return state_dict
def process_unet_state_dict(self, state_dict):
@ -60,7 +66,13 @@ class BASE:
return state_dict
def process_clip_state_dict_for_saving(self, state_dict):
replace_prefix = {"": "cond_stage_model."}
replace_prefix = {"": self.text_encoder_key_prefix[0]}
return utils.state_dict_prefix_replace(state_dict, replace_prefix)
def process_clip_vision_state_dict_for_saving(self, state_dict):
replace_prefix = {}
if self.clip_vision_prefix is not None:
replace_prefix[""] = self.clip_vision_prefix
return utils.state_dict_prefix_replace(state_dict, replace_prefix)
def process_unet_state_dict_for_saving(self, state_dict):
@ -68,6 +80,9 @@ class BASE:
return utils.state_dict_prefix_replace(state_dict, replace_prefix)
def process_vae_state_dict_for_saving(self, state_dict):
replace_prefix = {"": "first_stage_model."}
replace_prefix = {"": self.vae_key_prefix[0]}
return utils.state_dict_prefix_replace(state_dict, replace_prefix)
def set_inference_dtype(self, dtype, manual_cast_dtype):
self.unet_config['dtype'] = dtype
self.manual_cast_dtype = manual_cast_dtype

5
comfy/taesd/taesd.py

@ -7,9 +7,10 @@ import torch
import torch.nn as nn
import comfy.utils
import comfy.ops
def conv(n_in, n_out, **kwargs):
return nn.Conv2d(n_in, n_out, 3, padding=1, **kwargs)
return comfy.ops.disable_weight_init.Conv2d(n_in, n_out, 3, padding=1, **kwargs)
class Clamp(nn.Module):
def forward(self, x):
@ -19,7 +20,7 @@ class Block(nn.Module):
def __init__(self, n_in, n_out):
super().__init__()
self.conv = nn.Sequential(conv(n_in, n_out), nn.ReLU(), conv(n_out, n_out), nn.ReLU(), conv(n_out, n_out))
self.skip = nn.Conv2d(n_in, n_out, 1, bias=False) if n_in != n_out else nn.Identity()
self.skip = comfy.ops.disable_weight_init.Conv2d(n_in, n_out, 1, bias=False) if n_in != n_out else nn.Identity()
self.fuse = nn.ReLU()
def forward(self, x):
return self.fuse(self.conv(x) + self.skip(x))

62
comfy/utils.py

@ -5,6 +5,7 @@ import comfy.checkpoint_pickle
import safetensors.torch
import numpy as np
from PIL import Image
import logging
def load_torch_file(ckpt, safe_load=False, device=None):
if device is None:
@ -14,14 +15,14 @@ def load_torch_file(ckpt, safe_load=False, device=None):
else:
if safe_load:
if not 'weights_only' in torch.load.__code__.co_varnames:
print("Warning torch.load doesn't support weights_only on this pytorch version, loading unsafely.")
logging.warning("Warning torch.load doesn't support weights_only on this pytorch version, loading unsafely.")
safe_load = False
if safe_load:
pl_sd = torch.load(ckpt, map_location=device, weights_only=True)
else:
pl_sd = torch.load(ckpt, map_location=device, pickle_module=comfy.checkpoint_pickle)
if "global_step" in pl_sd:
print(f"Global Step: {pl_sd['global_step']}")
logging.debug(f"Global Step: {pl_sd['global_step']}")
if "state_dict" in pl_sd:
sd = pl_sd["state_dict"]
else:
@ -98,8 +99,22 @@ def transformers_convert(sd, prefix_from, prefix_to, number):
p = ["self_attn.q_proj", "self_attn.k_proj", "self_attn.v_proj"]
k_to = "{}encoder.layers.{}.{}.{}".format(prefix_to, resblock, p[x], y)
sd[k_to] = weights[shape_from*x:shape_from*(x + 1)]
return sd
def clip_text_transformers_convert(sd, prefix_from, prefix_to):
sd = transformers_convert(sd, prefix_from, "{}text_model.".format(prefix_to), 32)
tp = "{}text_projection.weight".format(prefix_from)
if tp in sd:
sd["{}text_projection.weight".format(prefix_to)] = sd.pop(tp)
tp = "{}text_projection".format(prefix_from)
if tp in sd:
sd["{}text_projection.weight".format(prefix_to)] = sd.pop(tp).transpose(0, 1).contiguous()
return sd
UNET_MAP_ATTENTIONS = {
"proj_in.weight",
"proj_in.bias",
@ -169,6 +184,8 @@ UNET_MAP_BASIC = {
}
def unet_to_diffusers(unet_config):
if "num_res_blocks" not in unet_config:
return {}
num_res_blocks = unet_config["num_res_blocks"]
channel_mult = unet_config["channel_mult"]
transformer_depth = unet_config["transformer_depth"][:]
@ -239,6 +256,26 @@ def repeat_to_batch_size(tensor, batch_size):
return tensor.repeat([math.ceil(batch_size / tensor.shape[0])] + [1] * (len(tensor.shape) - 1))[:batch_size]
return tensor
def resize_to_batch_size(tensor, batch_size):
in_batch_size = tensor.shape[0]
if in_batch_size == batch_size:
return tensor
if batch_size <= 1:
return tensor[:batch_size]
output = torch.empty([batch_size] + list(tensor.shape)[1:], dtype=tensor.dtype, device=tensor.device)
if batch_size < in_batch_size:
scale = (in_batch_size - 1) / (batch_size - 1)
for i in range(batch_size):
output[i] = tensor[min(round(i * scale), in_batch_size - 1)]
else:
scale = in_batch_size / batch_size
for i in range(batch_size):
output[i] = tensor[min(math.floor((i + 0.5) * scale), in_batch_size - 1)]
return output
def convert_sd_to(state_dict, dtype):
keys = list(state_dict.keys())
for k in keys:
@ -258,8 +295,11 @@ def set_attr(obj, attr, value):
for name in attrs[:-1]:
obj = getattr(obj, name)
prev = getattr(obj, attrs[-1])
setattr(obj, attrs[-1], torch.nn.Parameter(value, requires_grad=False))
del prev
setattr(obj, attrs[-1], value)
return prev
def set_attr_param(obj, attr, value):
return set_attr(obj, attr, torch.nn.Parameter(value, requires_grad=False))
def copy_to_param(obj, attr, value):
# inplace update tensor instead of replacing it
@ -356,7 +396,7 @@ def lanczos(samples, width, height):
images = [image.resize((width, height), resample=Image.Resampling.LANCZOS) for image in images]
images = [torch.from_numpy(np.array(image).astype(np.float32) / 255.0).movedim(-1, 0) for image in images]
result = torch.stack(images)
return result
return result.to(samples.device, samples.dtype)
def common_upscale(samples, width, height, upscale_method, crop):
if crop == "center":
@ -385,17 +425,19 @@ def get_tiled_scale_steps(width, height, tile_x, tile_y, overlap):
return math.ceil((height / (tile_y - overlap))) * math.ceil((width / (tile_x - overlap)))
@torch.inference_mode()
def tiled_scale(samples, function, tile_x=64, tile_y=64, overlap = 8, upscale_amount = 4, out_channels = 3, pbar = None):
output = torch.empty((samples.shape[0], out_channels, round(samples.shape[2] * upscale_amount), round(samples.shape[3] * upscale_amount)), device="cpu")
def tiled_scale(samples, function, tile_x=64, tile_y=64, overlap = 8, upscale_amount = 4, out_channels = 3, output_device="cpu", pbar = None):
output = torch.empty((samples.shape[0], out_channels, round(samples.shape[2] * upscale_amount), round(samples.shape[3] * upscale_amount)), device=output_device)
for b in range(samples.shape[0]):
s = samples[b:b+1]
out = torch.zeros((s.shape[0], out_channels, round(s.shape[2] * upscale_amount), round(s.shape[3] * upscale_amount)), device="cpu")
out_div = torch.zeros((s.shape[0], out_channels, round(s.shape[2] * upscale_amount), round(s.shape[3] * upscale_amount)), device="cpu")
out = torch.zeros((s.shape[0], out_channels, round(s.shape[2] * upscale_amount), round(s.shape[3] * upscale_amount)), device=output_device)
out_div = torch.zeros((s.shape[0], out_channels, round(s.shape[2] * upscale_amount), round(s.shape[3] * upscale_amount)), device=output_device)
for y in range(0, s.shape[2], tile_y - overlap):
for x in range(0, s.shape[3], tile_x - overlap):
x = max(0, min(s.shape[-1] - overlap, x))
y = max(0, min(s.shape[-2] - overlap, y))
s_in = s[:,:,y:y+tile_y,x:x+tile_x]
ps = function(s_in).cpu()
ps = function(s_in).to(output_device)
mask = torch.ones_like(ps)
feather = round(overlap * upscale_amount)
for t in range(feather):

272
comfy_extras/nodes_canny.py

@ -5,275 +5,7 @@ import torch
import torch.nn.functional as F
import comfy.model_management
def get_canny_nms_kernel(device=None, dtype=None):
"""Utility function that returns 3x3 kernels for the Canny Non-maximal suppression."""
return torch.tensor(
[
[[[0.0, 0.0, 0.0], [0.0, 1.0, -1.0], [0.0, 0.0, 0.0]]],
[[[0.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, -1.0]]],
[[[0.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, -1.0, 0.0]]],
[[[0.0, 0.0, 0.0], [0.0, 1.0, 0.0], [-1.0, 0.0, 0.0]]],
[[[0.0, 0.0, 0.0], [-1.0, 1.0, 0.0], [0.0, 0.0, 0.0]]],
[[[-1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 0.0]]],
[[[0.0, -1.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 0.0]]],
[[[0.0, 0.0, -1.0], [0.0, 1.0, 0.0], [0.0, 0.0, 0.0]]],
],
device=device,
dtype=dtype,
)
def get_hysteresis_kernel(device=None, dtype=None):
"""Utility function that returns the 3x3 kernels for the Canny hysteresis."""
return torch.tensor(
[
[[[0.0, 0.0, 0.0], [0.0, 0.0, 1.0], [0.0, 0.0, 0.0]]],
[[[0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 1.0]]],
[[[0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 1.0, 0.0]]],
[[[0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [1.0, 0.0, 0.0]]],
[[[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [0.0, 0.0, 0.0]]],
[[[1.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]],
[[[0.0, 1.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]],
[[[0.0, 0.0, 1.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]],
],
device=device,
dtype=dtype,
)
def gaussian_blur_2d(img, kernel_size, sigma):
ksize_half = (kernel_size - 1) * 0.5
x = torch.linspace(-ksize_half, ksize_half, steps=kernel_size)
pdf = torch.exp(-0.5 * (x / sigma).pow(2))
x_kernel = pdf / pdf.sum()
x_kernel = x_kernel.to(device=img.device, dtype=img.dtype)
kernel2d = torch.mm(x_kernel[:, None], x_kernel[None, :])
kernel2d = kernel2d.expand(img.shape[-3], 1, kernel2d.shape[0], kernel2d.shape[1])
padding = [kernel_size // 2, kernel_size // 2, kernel_size // 2, kernel_size // 2]
img = torch.nn.functional.pad(img, padding, mode="reflect")
img = torch.nn.functional.conv2d(img, kernel2d, groups=img.shape[-3])
return img
def get_sobel_kernel2d(device=None, dtype=None):
kernel_x = torch.tensor([[-1.0, 0.0, 1.0], [-2.0, 0.0, 2.0], [-1.0, 0.0, 1.0]], device=device, dtype=dtype)
kernel_y = kernel_x.transpose(0, 1)
return torch.stack([kernel_x, kernel_y])
def spatial_gradient(input, normalized: bool = True):
r"""Compute the first order image derivative in both x and y using a Sobel operator.
.. image:: _static/img/spatial_gradient.png
Args:
input: input image tensor with shape :math:`(B, C, H, W)`.
mode: derivatives modality, can be: `sobel` or `diff`.
order: the order of the derivatives.
normalized: whether the output is normalized.
Return:
the derivatives of the input feature map. with shape :math:`(B, C, 2, H, W)`.
.. note::
See a working example `here <https://kornia-tutorials.readthedocs.io/en/latest/
filtering_edges.html>`__.
Examples:
>>> input = torch.rand(1, 3, 4, 4)
>>> output = spatial_gradient(input) # 1x3x2x4x4
>>> output.shape
torch.Size([1, 3, 2, 4, 4])
"""
# KORNIA_CHECK_IS_TENSOR(input)
# KORNIA_CHECK_SHAPE(input, ['B', 'C', 'H', 'W'])
# allocate kernel
kernel = get_sobel_kernel2d(device=input.device, dtype=input.dtype)
if normalized:
kernel = normalize_kernel2d(kernel)
# prepare kernel
b, c, h, w = input.shape
tmp_kernel = kernel[:, None, ...]
# Pad with "replicate for spatial dims, but with zeros for channel
spatial_pad = [kernel.size(1) // 2, kernel.size(1) // 2, kernel.size(2) // 2, kernel.size(2) // 2]
out_channels: int = 2
padded_inp = torch.nn.functional.pad(input.reshape(b * c, 1, h, w), spatial_pad, 'replicate')
out = F.conv2d(padded_inp, tmp_kernel, groups=1, padding=0, stride=1)
return out.reshape(b, c, out_channels, h, w)
def rgb_to_grayscale(image, rgb_weights = None):
r"""Convert a RGB image to grayscale version of image.
.. image:: _static/img/rgb_to_grayscale.png
The image data is assumed to be in the range of (0, 1).
Args:
image: RGB image to be converted to grayscale with shape :math:`(*,3,H,W)`.
rgb_weights: Weights that will be applied on each channel (RGB).
The sum of the weights should add up to one.
Returns:
grayscale version of the image with shape :math:`(*,1,H,W)`.
.. note::
See a working example `here <https://kornia-tutorials.readthedocs.io/en/latest/
color_conversions.html>`__.
Example:
>>> input = torch.rand(2, 3, 4, 5)
>>> gray = rgb_to_grayscale(input) # 2x1x4x5
"""
if len(image.shape) < 3 or image.shape[-3] != 3:
raise ValueError(f"Input size must have a shape of (*, 3, H, W). Got {image.shape}")
if rgb_weights is None:
# 8 bit images
if image.dtype == torch.uint8:
rgb_weights = torch.tensor([76, 150, 29], device=image.device, dtype=torch.uint8)
# floating point images
elif image.dtype in (torch.float16, torch.float32, torch.float64):
rgb_weights = torch.tensor([0.299, 0.587, 0.114], device=image.device, dtype=image.dtype)
else:
raise TypeError(f"Unknown data type: {image.dtype}")
else:
# is tensor that we make sure is in the same device/dtype
rgb_weights = rgb_weights.to(image)
# unpack the color image channels with RGB order
r: Tensor = image[..., 0:1, :, :]
g: Tensor = image[..., 1:2, :, :]
b: Tensor = image[..., 2:3, :, :]
w_r, w_g, w_b = rgb_weights.unbind()
return w_r * r + w_g * g + w_b * b
def canny(
input,
low_threshold = 0.1,
high_threshold = 0.2,
kernel_size = 5,
sigma = 1,
hysteresis = True,
eps = 1e-6,
):
r"""Find edges of the input image and filters them using the Canny algorithm.
.. image:: _static/img/canny.png
Args:
input: input image tensor with shape :math:`(B,C,H,W)`.
low_threshold: lower threshold for the hysteresis procedure.
high_threshold: upper threshold for the hysteresis procedure.
kernel_size: the size of the kernel for the gaussian blur.
sigma: the standard deviation of the kernel for the gaussian blur.
hysteresis: if True, applies the hysteresis edge tracking.
Otherwise, the edges are divided between weak (0.5) and strong (1) edges.
eps: regularization number to avoid NaN during backprop.
Returns:
- the canny edge magnitudes map, shape of :math:`(B,1,H,W)`.
- the canny edge detection filtered by thresholds and hysteresis, shape of :math:`(B,1,H,W)`.
.. note::
See a working example `here <https://kornia-tutorials.readthedocs.io/en/latest/
canny.html>`__.
Example:
>>> input = torch.rand(5, 3, 4, 4)
>>> magnitude, edges = canny(input) # 5x3x4x4
>>> magnitude.shape
torch.Size([5, 1, 4, 4])
>>> edges.shape
torch.Size([5, 1, 4, 4])
"""
# KORNIA_CHECK_IS_TENSOR(input)
# KORNIA_CHECK_SHAPE(input, ['B', 'C', 'H', 'W'])
# KORNIA_CHECK(
# low_threshold <= high_threshold,
# "Invalid input thresholds. low_threshold should be smaller than the high_threshold. Got: "
# f"{low_threshold}>{high_threshold}",
# )
# KORNIA_CHECK(0 < low_threshold < 1, f'Invalid low threshold. Should be in range (0, 1). Got: {low_threshold}')
# KORNIA_CHECK(0 < high_threshold < 1, f'Invalid high threshold. Should be in range (0, 1). Got: {high_threshold}')
device = input.device
dtype = input.dtype
# To Grayscale
if input.shape[1] == 3:
input = rgb_to_grayscale(input)
# Gaussian filter
blurred: Tensor = gaussian_blur_2d(input, kernel_size, sigma)
# Compute the gradients
gradients: Tensor = spatial_gradient(blurred, normalized=False)
# Unpack the edges
gx: Tensor = gradients[:, :, 0]
gy: Tensor = gradients[:, :, 1]
# Compute gradient magnitude and angle
magnitude: Tensor = torch.sqrt(gx * gx + gy * gy + eps)
angle: Tensor = torch.atan2(gy, gx)
# Radians to Degrees
angle = 180.0 * angle / math.pi
# Round angle to the nearest 45 degree
angle = torch.round(angle / 45) * 45
# Non-maximal suppression
nms_kernels: Tensor = get_canny_nms_kernel(device, dtype)
nms_magnitude: Tensor = F.conv2d(magnitude, nms_kernels, padding=nms_kernels.shape[-1] // 2)
# Get the indices for both directions
positive_idx: Tensor = (angle / 45) % 8
positive_idx = positive_idx.long()
negative_idx: Tensor = ((angle / 45) + 4) % 8
negative_idx = negative_idx.long()
# Apply the non-maximum suppression to the different directions
channel_select_filtered_positive: Tensor = torch.gather(nms_magnitude, 1, positive_idx)
channel_select_filtered_negative: Tensor = torch.gather(nms_magnitude, 1, negative_idx)
channel_select_filtered: Tensor = torch.stack(
[channel_select_filtered_positive, channel_select_filtered_negative], 1
)
is_max: Tensor = channel_select_filtered.min(dim=1)[0] > 0.0
magnitude = magnitude * is_max
# Threshold
edges: Tensor = F.threshold(magnitude, low_threshold, 0.0)
low: Tensor = magnitude > low_threshold
high: Tensor = magnitude > high_threshold
edges = low * 0.5 + high * 0.5
edges = edges.to(dtype)
# Hysteresis
if hysteresis:
edges_old: Tensor = -torch.ones(edges.shape, device=edges.device, dtype=dtype)
hysteresis_kernels: Tensor = get_hysteresis_kernel(device, dtype)
while ((edges_old - edges).abs() != 0).any():
weak: Tensor = (edges == 0.5).float()
strong: Tensor = (edges == 1).float()
hysteresis_magnitude: Tensor = F.conv2d(
edges, hysteresis_kernels, padding=hysteresis_kernels.shape[-1] // 2
)
hysteresis_magnitude = (hysteresis_magnitude == 1).any(1, keepdim=True).to(dtype)
hysteresis_magnitude = hysteresis_magnitude * weak + strong
edges_old = edges.clone()
edges = hysteresis_magnitude + (hysteresis_magnitude == 0) * weak * 0.5
edges = hysteresis_magnitude
return magnitude, edges
from kornia.filters import canny
class Canny:
@ -291,7 +23,7 @@ class Canny:
def detect_edge(self, image, low_threshold, high_threshold):
output = canny(image.to(comfy.model_management.get_torch_device()).movedim(-1, 1), low_threshold, high_threshold)
img_out = output[1].cpu().repeat(1, 3, 1, 1).movedim(1, -1)
img_out = output[1].to(comfy.model_management.intermediate_device()).repeat(1, 3, 1, 1).movedim(1, -1)
return (img_out,)
NODE_CLASS_MAPPINGS = {

25
comfy_extras/nodes_cond.py

@ -0,0 +1,25 @@
class CLIPTextEncodeControlnet:
@classmethod
def INPUT_TYPES(s):
return {"required": {"clip": ("CLIP", ), "conditioning": ("CONDITIONING", ), "text": ("STRING", {"multiline": True})}}
RETURN_TYPES = ("CONDITIONING",)
FUNCTION = "encode"
CATEGORY = "_for_testing/conditioning"
def encode(self, clip, conditioning, text):
tokens = clip.tokenize(text)
cond, pooled = clip.encode_from_tokens(tokens, return_pooled=True)
c = []
for t in conditioning:
n = [t[0], t[1].copy()]
n[1]['cross_attn_controlnet'] = cond
n[1]['pooled_output_controlnet'] = pooled
c.append(n)
return (c, )
NODE_CLASS_MAPPINGS = {
"CLIPTextEncodeControlnet": CLIPTextEncodeControlnet
}

104
comfy_extras/nodes_custom_sampler.py

@ -13,6 +13,7 @@ class BasicScheduler:
{"model": ("MODEL",),
"scheduler": (comfy.samplers.SCHEDULER_NAMES, ),
"steps": ("INT", {"default": 20, "min": 1, "max": 10000}),
"denoise": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}),
}
}
RETURN_TYPES = ("SIGMAS",)
@ -20,8 +21,14 @@ class BasicScheduler:
FUNCTION = "get_sigmas"
def get_sigmas(self, model, scheduler, steps):
sigmas = comfy.samplers.calculate_sigmas_scheduler(model.model, scheduler, steps).cpu()
def get_sigmas(self, model, scheduler, steps, denoise):
total_steps = steps
if denoise < 1.0:
total_steps = int(steps/denoise)
comfy.model_management.load_models_gpu([model])
sigmas = comfy.samplers.calculate_sigmas_scheduler(model.model, scheduler, total_steps).cpu()
sigmas = sigmas[-(steps + 1):]
return (sigmas, )
@ -87,6 +94,7 @@ class SDTurboScheduler:
return {"required":
{"model": ("MODEL",),
"steps": ("INT", {"default": 1, "min": 1, "max": 10}),
"denoise": ("FLOAT", {"default": 1.0, "min": 0, "max": 1.0, "step": 0.01}),
}
}
RETURN_TYPES = ("SIGMAS",)
@ -94,8 +102,10 @@ class SDTurboScheduler:
FUNCTION = "get_sigmas"
def get_sigmas(self, model, steps):
timesteps = torch.flip(torch.arange(1, 11) * 100 - 1, (0,))[:steps]
def get_sigmas(self, model, steps, denoise):
start_step = 10 - int(10 * denoise)
timesteps = torch.flip(torch.arange(1, 11) * 100 - 1, (0,))[start_step:start_step + steps]
comfy.model_management.load_models_gpu([model])
sigmas = model.model.model_sampling.sigma(timesteps)
sigmas = torch.cat([sigmas, sigmas.new_zeros([1])])
return (sigmas, )
@ -171,6 +181,28 @@ class KSamplerSelect:
sampler = comfy.samplers.sampler_object(sampler_name)
return (sampler, )
class SamplerDPMPP_3M_SDE:
@classmethod
def INPUT_TYPES(s):
return {"required":
{"eta": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0, "step":0.01, "round": False}),
"s_noise": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0, "step":0.01, "round": False}),
"noise_device": (['gpu', 'cpu'], ),
}
}
RETURN_TYPES = ("SAMPLER",)
CATEGORY = "sampling/custom_sampling/samplers"
FUNCTION = "get_sampler"
def get_sampler(self, eta, s_noise, noise_device):
if noise_device == 'cpu':
sampler_name = "dpmpp_3m_sde"
else:
sampler_name = "dpmpp_3m_sde_gpu"
sampler = comfy.samplers.ksampler(sampler_name, {"eta": eta, "s_noise": s_noise})
return (sampler, )
class SamplerDPMPP_2M_SDE:
@classmethod
def INPUT_TYPES(s):
@ -218,6 +250,66 @@ class SamplerDPMPP_SDE:
sampler = comfy.samplers.ksampler(sampler_name, {"eta": eta, "s_noise": s_noise, "r": r})
return (sampler, )
class SamplerEulerAncestral:
@classmethod
def INPUT_TYPES(s):
return {"required":
{"eta": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0, "step":0.01, "round": False}),
"s_noise": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0, "step":0.01, "round": False}),
}
}
RETURN_TYPES = ("SAMPLER",)
CATEGORY = "sampling/custom_sampling/samplers"
FUNCTION = "get_sampler"
def get_sampler(self, eta, s_noise):
sampler = comfy.samplers.ksampler("euler_ancestral", {"eta": eta, "s_noise": s_noise})
return (sampler, )
class SamplerLMS:
@classmethod
def INPUT_TYPES(s):
return {"required":
{"order": ("INT", {"default": 4, "min": 1, "max": 100}),
}
}
RETURN_TYPES = ("SAMPLER",)
CATEGORY = "sampling/custom_sampling/samplers"
FUNCTION = "get_sampler"
def get_sampler(self, order):
sampler = comfy.samplers.ksampler("lms", {"order": order})
return (sampler, )
class SamplerDPMAdaptative:
@classmethod
def INPUT_TYPES(s):
return {"required":
{"order": ("INT", {"default": 3, "min": 2, "max": 3}),
"rtol": ("FLOAT", {"default": 0.05, "min": 0.0, "max": 100.0, "step":0.01, "round": False}),
"atol": ("FLOAT", {"default": 0.0078, "min": 0.0, "max": 100.0, "step":0.01, "round": False}),
"h_init": ("FLOAT", {"default": 0.05, "min": 0.0, "max": 100.0, "step":0.01, "round": False}),
"pcoeff": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 100.0, "step":0.01, "round": False}),
"icoeff": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0, "step":0.01, "round": False}),
"dcoeff": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 100.0, "step":0.01, "round": False}),
"accept_safety": ("FLOAT", {"default": 0.81, "min": 0.0, "max": 100.0, "step":0.01, "round": False}),
"eta": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 100.0, "step":0.01, "round": False}),
"s_noise": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0, "step":0.01, "round": False}),
}
}
RETURN_TYPES = ("SAMPLER",)
CATEGORY = "sampling/custom_sampling/samplers"
FUNCTION = "get_sampler"
def get_sampler(self, order, rtol, atol, h_init, pcoeff, icoeff, dcoeff, accept_safety, eta, s_noise):
sampler = comfy.samplers.ksampler("dpm_adaptive", {"order": order, "rtol": rtol, "atol": atol, "h_init": h_init, "pcoeff": pcoeff,
"icoeff": icoeff, "dcoeff": dcoeff, "accept_safety": accept_safety, "eta": eta,
"s_noise":s_noise })
return (sampler, )
class SamplerCustom:
@classmethod
def INPUT_TYPES(s):
@ -278,8 +370,12 @@ NODE_CLASS_MAPPINGS = {
"VPScheduler": VPScheduler,
"SDTurboScheduler": SDTurboScheduler,
"KSamplerSelect": KSamplerSelect,
"SamplerEulerAncestral": SamplerEulerAncestral,
"SamplerLMS": SamplerLMS,
"SamplerDPMPP_3M_SDE": SamplerDPMPP_3M_SDE,
"SamplerDPMPP_2M_SDE": SamplerDPMPP_2M_SDE,
"SamplerDPMPP_SDE": SamplerDPMPP_SDE,
"SamplerDPMAdaptative": SamplerDPMAdaptative,
"SplitSigmas": SplitSigmas,
"FlipSigmas": FlipSigmas,
}

42
comfy_extras/nodes_differential_diffusion.py

@ -0,0 +1,42 @@
# code adapted from https://github.com/exx8/differential-diffusion
import torch
class DifferentialDiffusion():
@classmethod
def INPUT_TYPES(s):
return {"required": {"model": ("MODEL", ),
}}
RETURN_TYPES = ("MODEL",)
FUNCTION = "apply"
CATEGORY = "_for_testing"
INIT = False
def apply(self, model):
model = model.clone()
model.set_model_denoise_mask_function(self.forward)
return (model,)
def forward(self, sigma: torch.Tensor, denoise_mask: torch.Tensor, extra_options: dict):
model = extra_options["model"]
step_sigmas = extra_options["sigmas"]
sigma_to = model.inner_model.model_sampling.sigma_min
if step_sigmas[-1] > sigma_to:
sigma_to = step_sigmas[-1]
sigma_from = step_sigmas[0]
ts_from = model.inner_model.model_sampling.timestep(sigma_from)
ts_to = model.inner_model.model_sampling.timestep(sigma_to)
current_ts = model.inner_model.model_sampling.timestep(sigma[0])
threshold = (current_ts - ts_to) / (ts_from - ts_to)
return (denoise_mask >= threshold).to(denoise_mask.dtype)
NODE_CLASS_MAPPINGS = {
"DifferentialDiffusion": DifferentialDiffusion,
}
NODE_DISPLAY_NAME_MAPPINGS = {
"DifferentialDiffusion": "Differential Diffusion",
}

10
comfy_extras/nodes_freelunch.py

@ -1,7 +1,7 @@
#code originally taken from: https://github.com/ChenyangSi/FreeU (under MIT License)
import torch
import logging
def Fourier_filter(x, threshold, scale):
# FFT
@ -34,7 +34,7 @@ class FreeU:
RETURN_TYPES = ("MODEL",)
FUNCTION = "patch"
CATEGORY = "_for_testing"
CATEGORY = "model_patches"
def patch(self, model, b1, b2, s1, s2):
model_channels = model.model.model_config.unet_config["model_channels"]
@ -49,7 +49,7 @@ class FreeU:
try:
hsp = Fourier_filter(hsp, threshold=1, scale=scale[1])
except:
print("Device", hsp.device, "does not support the torch.fft functions used in the FreeU node, switching to CPU.")
logging.warning("Device {} does not support the torch.fft functions used in the FreeU node, switching to CPU.".format(hsp.device))
on_cpu_devices[hsp.device] = True
hsp = Fourier_filter(hsp.cpu(), threshold=1, scale=scale[1]).to(hsp.device)
else:
@ -73,7 +73,7 @@ class FreeU_V2:
RETURN_TYPES = ("MODEL",)
FUNCTION = "patch"
CATEGORY = "_for_testing"
CATEGORY = "model_patches"
def patch(self, model, b1, b2, s1, s2):
model_channels = model.model.model_config.unet_config["model_channels"]
@ -95,7 +95,7 @@ class FreeU_V2:
try:
hsp = Fourier_filter(hsp, threshold=1, scale=scale[1])
except:
print("Device", hsp.device, "does not support the torch.fft functions used in the FreeU node, switching to CPU.")
logging.warning("Device {} does not support the torch.fft functions used in the FreeU node, switching to CPU.".format(hsp.device))
on_cpu_devices[hsp.device] = True
hsp = Fourier_filter(hsp.cpu(), threshold=1, scale=scale[1]).to(hsp.device)
else:

3
comfy_extras/nodes_hypernetwork.py

@ -1,6 +1,7 @@
import comfy.utils
import folder_paths
import torch
import logging
def load_hypernetwork_patch(path, strength):
sd = comfy.utils.load_torch_file(path, safe_load=True)
@ -23,7 +24,7 @@ def load_hypernetwork_patch(path, strength):
}
if activation_func not in valid_activation:
print("Unsupported Hypernetwork format, if you report it I might implement it.", path, " ", activation_func, is_layer_norm, use_dropout, activate_output, last_layer_dropout)
logging.error("Unsupported Hypernetwork format, if you report it I might implement it. {} {} {} {} {} {}".format(path, activation_func, is_layer_norm, use_dropout, activate_output, last_layer_dropout))
return None
out = {}

36
comfy_extras/nodes_hypertile.py

@ -2,9 +2,10 @@
import math
from einops import rearrange
import random
# Use torch rng for consistency across generations
from torch import randint
def random_divisor(value: int, min_value: int, /, max_options: int = 1, counter = 0) -> int:
def random_divisor(value: int, min_value: int, /, max_options: int = 1) -> int:
min_value = min(min_value, value)
# All big divisors of value (inclusive)
@ -12,8 +13,10 @@ def random_divisor(value: int, min_value: int, /, max_options: int = 1, counter
ns = [value // i for i in divisors[:max_options]] # has at least 1 element
random.seed(counter)
idx = random.randint(0, len(ns) - 1)
if len(ns) - 1 > 0:
idx = randint(low=0, high=len(ns) - 1, size=(1,)).item()
else:
idx = 0
return ns[idx]
@ -29,34 +32,31 @@ class HyperTile:
RETURN_TYPES = ("MODEL",)
FUNCTION = "patch"
CATEGORY = "_for_testing"
CATEGORY = "model_patches"
def patch(self, model, tile_size, swap_size, max_depth, scale_depth):
model_channels = model.model.model_config.unet_config["model_channels"]
apply_to = set()
temp = model_channels
for x in range(max_depth + 1):
apply_to.add(temp)
temp *= 2
latent_tile_size = max(32, tile_size) // 8
self.temp = None
self.counter = 1
def hypertile_in(q, k, v, extra_options):
if q.shape[-1] in apply_to:
model_chans = q.shape[-2]
orig_shape = extra_options['original_shape']
apply_to = []
for i in range(max_depth + 1):
apply_to.append((orig_shape[-2] / (2 ** i)) * (orig_shape[-1] / (2 ** i)))
if model_chans in apply_to:
shape = extra_options["original_shape"]
aspect_ratio = shape[-1] / shape[-2]
hw = q.size(1)
h, w = round(math.sqrt(hw * aspect_ratio)), round(math.sqrt(hw / aspect_ratio))
factor = 2**((q.shape[-1] // model_channels) - 1) if scale_depth else 1
nh = random_divisor(h, latent_tile_size * factor, swap_size, self.counter)
self.counter += 1
nw = random_divisor(w, latent_tile_size * factor, swap_size, self.counter)
self.counter += 1
factor = (2 ** apply_to.index(model_chans)) if scale_depth else 1
nh = random_divisor(h, latent_tile_size * factor, swap_size)
nw = random_divisor(w, latent_tile_size * factor, swap_size)
if nh * nw > 1:
q = rearrange(q, "b (nh h nw w) c -> (b nh nw) (h w) c", h=h // nh, w=w // nw, nh=nh, nw=nw)

26
comfy_extras/nodes_images.py

@ -37,7 +37,7 @@ class RepeatImageBatch:
@classmethod
def INPUT_TYPES(s):
return {"required": { "image": ("IMAGE",),
"amount": ("INT", {"default": 1, "min": 1, "max": 64}),
"amount": ("INT", {"default": 1, "min": 1, "max": 4096}),
}}
RETURN_TYPES = ("IMAGE",)
FUNCTION = "repeat"
@ -48,6 +48,25 @@ class RepeatImageBatch:
s = image.repeat((amount, 1,1,1))
return (s,)
class ImageFromBatch:
@classmethod
def INPUT_TYPES(s):
return {"required": { "image": ("IMAGE",),
"batch_index": ("INT", {"default": 0, "min": 0, "max": 4095}),
"length": ("INT", {"default": 1, "min": 1, "max": 4096}),
}}
RETURN_TYPES = ("IMAGE",)
FUNCTION = "frombatch"
CATEGORY = "image/batch"
def frombatch(self, image, batch_index, length):
s_in = image
batch_index = min(s_in.shape[0] - 1, batch_index)
length = min(s_in.shape[0] - batch_index, length)
s = s_in[batch_index:batch_index + length].clone()
return (s,)
class SaveAnimatedWEBP:
def __init__(self):
self.output_dir = folder_paths.get_output_directory()
@ -74,7 +93,7 @@ class SaveAnimatedWEBP:
OUTPUT_NODE = True
CATEGORY = "_for_testing"
CATEGORY = "image/animation"
def save_images(self, images, fps, filename_prefix, lossless, quality, method, num_frames=0, prompt=None, extra_pnginfo=None):
method = self.methods.get(method)
@ -136,7 +155,7 @@ class SaveAnimatedPNG:
OUTPUT_NODE = True
CATEGORY = "_for_testing"
CATEGORY = "image/animation"
def save_images(self, images, fps, compress_level, filename_prefix="ComfyUI", prompt=None, extra_pnginfo=None):
filename_prefix += self.prefix_append
@ -170,6 +189,7 @@ class SaveAnimatedPNG:
NODE_CLASS_MAPPINGS = {
"ImageCrop": ImageCrop,
"RepeatImageBatch": RepeatImageBatch,
"ImageFromBatch": ImageFromBatch,
"SaveAnimatedWEBP": SaveAnimatedWEBP,
"SaveAnimatedPNG": SaveAnimatedPNG,
}

49
comfy_extras/nodes_latent.py

@ -3,9 +3,7 @@ import torch
def reshape_latent_to(target_shape, latent):
if latent.shape[1:] != target_shape[1:]:
latent.movedim(1, -1)
latent = comfy.utils.common_upscale(latent, target_shape[3], target_shape[2], "bilinear", "center")
latent.movedim(-1, 1)
return comfy.utils.repeat_to_batch_size(latent, target_shape[0])
@ -102,9 +100,56 @@ class LatentInterpolate:
samples_out["samples"] = st * (m1 * ratio + m2 * (1.0 - ratio))
return (samples_out,)
class LatentBatch:
@classmethod
def INPUT_TYPES(s):
return {"required": { "samples1": ("LATENT",), "samples2": ("LATENT",)}}
RETURN_TYPES = ("LATENT",)
FUNCTION = "batch"
CATEGORY = "latent/batch"
def batch(self, samples1, samples2):
samples_out = samples1.copy()
s1 = samples1["samples"]
s2 = samples2["samples"]
if s1.shape[1:] != s2.shape[1:]:
s2 = comfy.utils.common_upscale(s2, s1.shape[3], s1.shape[2], "bilinear", "center")
s = torch.cat((s1, s2), dim=0)
samples_out["samples"] = s
samples_out["batch_index"] = samples1.get("batch_index", [x for x in range(0, s1.shape[0])]) + samples2.get("batch_index", [x for x in range(0, s2.shape[0])])
return (samples_out,)
class LatentBatchSeedBehavior:
@classmethod
def INPUT_TYPES(s):
return {"required": { "samples": ("LATENT",),
"seed_behavior": (["random", "fixed"],{"default": "fixed"}),}}
RETURN_TYPES = ("LATENT",)
FUNCTION = "op"
CATEGORY = "latent/advanced"
def op(self, samples, seed_behavior):
samples_out = samples.copy()
latent = samples["samples"]
if seed_behavior == "random":
if 'batch_index' in samples_out:
samples_out.pop('batch_index')
elif seed_behavior == "fixed":
batch_number = samples_out.get("batch_index", [0])[0]
samples_out["batch_index"] = [batch_number] * latent.shape[0]
return (samples_out,)
NODE_CLASS_MAPPINGS = {
"LatentAdd": LatentAdd,
"LatentSubtract": LatentSubtract,
"LatentMultiply": LatentMultiply,
"LatentInterpolate": LatentInterpolate,
"LatentBatch": LatentBatch,
"LatentBatchSeedBehavior": LatentBatchSeedBehavior,
}

22
comfy_extras/nodes_mask.py

@ -6,6 +6,7 @@ import comfy.utils
from nodes import MAX_RESOLUTION
def composite(destination, source, x, y, mask = None, multiplier = 8, resize_source = False):
source = source.to(destination.device)
if resize_source:
source = torch.nn.functional.interpolate(source, size=(destination.shape[2], destination.shape[3]), mode="bilinear")
@ -20,7 +21,7 @@ def composite(destination, source, x, y, mask = None, multiplier = 8, resize_sou
if mask is None:
mask = torch.ones_like(source)
else:
mask = mask.clone()
mask = mask.to(destination.device, copy=True)
mask = torch.nn.functional.interpolate(mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1])), size=(source.shape[2], source.shape[3]), mode="bilinear")
mask = comfy.utils.repeat_to_batch_size(mask, source.shape[0])
@ -340,6 +341,24 @@ class GrowMask:
out.append(output)
return (torch.stack(out, dim=0),)
class ThresholdMask:
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"mask": ("MASK",),
"value": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 1.0, "step": 0.01}),
}
}
CATEGORY = "mask"
RETURN_TYPES = ("MASK",)
FUNCTION = "image_to_mask"
def image_to_mask(self, mask, value):
mask = (mask > value).float()
return (mask,)
NODE_CLASS_MAPPINGS = {
@ -354,6 +373,7 @@ NODE_CLASS_MAPPINGS = {
"MaskComposite": MaskComposite,
"FeatherMask": FeatherMask,
"GrowMask": GrowMask,
"ThresholdMask": ThresholdMask,
}
NODE_DISPLAY_NAME_MAPPINGS = {

96
comfy_extras/nodes_model_advanced.py

@ -1,6 +1,7 @@
import folder_paths
import comfy.sd
import comfy.model_sampling
import comfy.latent_formats
import torch
class LCM(comfy.model_sampling.EPS):
@ -17,41 +18,23 @@ class LCM(comfy.model_sampling.EPS):
return c_out * x0 + c_skip * model_input
class ModelSamplingDiscreteDistilled(torch.nn.Module):
original_timesteps = 50
def __init__(self):
super().__init__()
self.sigma_data = 1.0
timesteps = 1000
beta_start = 0.00085
beta_end = 0.012
class X0(comfy.model_sampling.EPS):
def calculate_denoised(self, sigma, model_output, model_input):
return model_output
betas = torch.linspace(beta_start**0.5, beta_end**0.5, timesteps, dtype=torch.float32) ** 2
alphas = 1.0 - betas
alphas_cumprod = torch.cumprod(alphas, dim=0)
class ModelSamplingDiscreteDistilled(comfy.model_sampling.ModelSamplingDiscrete):
original_timesteps = 50
self.skip_steps = timesteps // self.original_timesteps
def __init__(self, model_config=None):
super().__init__(model_config)
self.skip_steps = self.num_timesteps // self.original_timesteps
alphas_cumprod_valid = torch.zeros((self.original_timesteps), dtype=torch.float32)
sigmas_valid = torch.zeros((self.original_timesteps), dtype=torch.float32)
for x in range(self.original_timesteps):
alphas_cumprod_valid[self.original_timesteps - 1 - x] = alphas_cumprod[timesteps - 1 - x * self.skip_steps]
sigmas = ((1 - alphas_cumprod_valid) / alphas_cumprod_valid) ** 0.5
self.set_sigmas(sigmas)
def set_sigmas(self, sigmas):
self.register_buffer('sigmas', sigmas)
self.register_buffer('log_sigmas', sigmas.log())
@property
def sigma_min(self):
return self.sigmas[0]
sigmas_valid[self.original_timesteps - 1 - x] = self.sigmas[self.num_timesteps - 1 - x * self.skip_steps]
@property
def sigma_max(self):
return self.sigmas[-1]
self.set_sigmas(sigmas_valid)
def timestep(self, sigma):
log_sigma = sigma.log()
@ -66,14 +49,6 @@ class ModelSamplingDiscreteDistilled(torch.nn.Module):
log_sigma = (1 - w) * self.log_sigmas[low_idx] + w * self.log_sigmas[high_idx]
return log_sigma.exp().to(timestep.device)
def percent_to_sigma(self, percent):
if percent <= 0.0:
return 999999999.9
if percent >= 1.0:
return 0.0
percent = 1.0 - percent
return self.sigma(torch.tensor(percent * 999.0)).item()
def rescale_zero_terminal_snr_sigmas(sigmas):
alphas_cumprod = 1 / ((sigmas * sigmas) + 1)
@ -98,7 +73,7 @@ class ModelSamplingDiscrete:
@classmethod
def INPUT_TYPES(s):
return {"required": { "model": ("MODEL",),
"sampling": (["eps", "v_prediction", "lcm"],),
"sampling": (["eps", "v_prediction", "lcm", "x0"],),
"zsnr": ("BOOLEAN", {"default": False}),
}}
@ -118,22 +93,50 @@ class ModelSamplingDiscrete:
elif sampling == "lcm":
sampling_type = LCM
sampling_base = ModelSamplingDiscreteDistilled
elif sampling == "x0":
sampling_type = X0
class ModelSamplingAdvanced(sampling_base, sampling_type):
pass
model_sampling = ModelSamplingAdvanced()
model_sampling = ModelSamplingAdvanced(model.model.model_config)
if zsnr:
model_sampling.set_sigmas(rescale_zero_terminal_snr_sigmas(model_sampling.sigmas))
m.add_object_patch("model_sampling", model_sampling)
return (m, )
class ModelSamplingStableCascade:
@classmethod
def INPUT_TYPES(s):
return {"required": { "model": ("MODEL",),
"shift": ("FLOAT", {"default": 2.0, "min": 0.0, "max": 100.0, "step":0.01}),
}}
RETURN_TYPES = ("MODEL",)
FUNCTION = "patch"
CATEGORY = "advanced/model"
def patch(self, model, shift):
m = model.clone()
sampling_base = comfy.model_sampling.StableCascadeSampling
sampling_type = comfy.model_sampling.EPS
class ModelSamplingAdvanced(sampling_base, sampling_type):
pass
model_sampling = ModelSamplingAdvanced(model.model.model_config)
model_sampling.set_parameters(shift)
m.add_object_patch("model_sampling", model_sampling)
return (m, )
class ModelSamplingContinuousEDM:
@classmethod
def INPUT_TYPES(s):
return {"required": { "model": ("MODEL",),
"sampling": (["v_prediction", "eps"],),
"sampling": (["v_prediction", "edm_playground_v2.5", "eps"],),
"sigma_max": ("FLOAT", {"default": 120.0, "min": 0.0, "max": 1000.0, "step":0.001, "round": False}),
"sigma_min": ("FLOAT", {"default": 0.002, "min": 0.0, "max": 1000.0, "step":0.001, "round": False}),
}}
@ -146,17 +149,25 @@ class ModelSamplingContinuousEDM:
def patch(self, model, sampling, sigma_max, sigma_min):
m = model.clone()
latent_format = None
sigma_data = 1.0
if sampling == "eps":
sampling_type = comfy.model_sampling.EPS
elif sampling == "v_prediction":
sampling_type = comfy.model_sampling.V_PREDICTION
elif sampling == "edm_playground_v2.5":
sampling_type = comfy.model_sampling.EDM
sigma_data = 0.5
latent_format = comfy.latent_formats.SDXL_Playground_2_5()
class ModelSamplingAdvanced(comfy.model_sampling.ModelSamplingContinuousEDM, sampling_type):
pass
model_sampling = ModelSamplingAdvanced()
model_sampling.set_sigma_range(sigma_min, sigma_max)
model_sampling = ModelSamplingAdvanced(model.model.model_config)
model_sampling.set_parameters(sigma_min, sigma_max, sigma_data)
m.add_object_patch("model_sampling", model_sampling)
if latent_format is not None:
m.add_object_patch("latent_format", latent_format)
return (m, )
class RescaleCFG:
@ -201,5 +212,6 @@ class RescaleCFG:
NODE_CLASS_MAPPINGS = {
"ModelSamplingDiscrete": ModelSamplingDiscrete,
"ModelSamplingContinuousEDM": ModelSamplingContinuousEDM,
"ModelSamplingStableCascade": ModelSamplingStableCascade,
"RescaleCFG": RescaleCFG,
}

129
comfy_extras/nodes_model_merging.py

@ -87,6 +87,50 @@ class CLIPMergeSimple:
m.add_patches({k: kp[k]}, 1.0 - ratio, ratio)
return (m, )
class CLIPSubtract:
@classmethod
def INPUT_TYPES(s):
return {"required": { "clip1": ("CLIP",),
"clip2": ("CLIP",),
"multiplier": ("FLOAT", {"default": 1.0, "min": -10.0, "max": 10.0, "step": 0.01}),
}}
RETURN_TYPES = ("CLIP",)
FUNCTION = "merge"
CATEGORY = "advanced/model_merging"
def merge(self, clip1, clip2, multiplier):
m = clip1.clone()
kp = clip2.get_key_patches()
for k in kp:
if k.endswith(".position_ids") or k.endswith(".logit_scale"):
continue
m.add_patches({k: kp[k]}, - multiplier, multiplier)
return (m, )
class CLIPAdd:
@classmethod
def INPUT_TYPES(s):
return {"required": { "clip1": ("CLIP",),
"clip2": ("CLIP",),
}}
RETURN_TYPES = ("CLIP",)
FUNCTION = "merge"
CATEGORY = "advanced/model_merging"
def merge(self, clip1, clip2):
m = clip1.clone()
kp = clip2.get_key_patches()
for k in kp:
if k.endswith(".position_ids") or k.endswith(".logit_scale"):
continue
m.add_patches({k: kp[k]}, 1.0, 1.0)
return (m, )
class ModelMergeBlocks:
@classmethod
def INPUT_TYPES(s):
@ -119,6 +163,48 @@ class ModelMergeBlocks:
m.add_patches({k: kp[k]}, 1.0 - ratio, ratio)
return (m, )
def save_checkpoint(model, clip=None, vae=None, clip_vision=None, filename_prefix=None, output_dir=None, prompt=None, extra_pnginfo=None):
full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, output_dir)
prompt_info = ""
if prompt is not None:
prompt_info = json.dumps(prompt)
metadata = {}
enable_modelspec = True
if isinstance(model.model, comfy.model_base.SDXL):
metadata["modelspec.architecture"] = "stable-diffusion-xl-v1-base"
elif isinstance(model.model, comfy.model_base.SDXLRefiner):
metadata["modelspec.architecture"] = "stable-diffusion-xl-v1-refiner"
else:
enable_modelspec = False
if enable_modelspec:
metadata["modelspec.sai_model_spec"] = "1.0.0"
metadata["modelspec.implementation"] = "sgm"
metadata["modelspec.title"] = "{} {}".format(filename, counter)
#TODO:
# "stable-diffusion-v1", "stable-diffusion-v1-inpainting", "stable-diffusion-v2-512",
# "stable-diffusion-v2-768-v", "stable-diffusion-v2-unclip-l", "stable-diffusion-v2-unclip-h",
# "v2-inpainting"
if model.model.model_type == comfy.model_base.ModelType.EPS:
metadata["modelspec.predict_key"] = "epsilon"
elif model.model.model_type == comfy.model_base.ModelType.V_PREDICTION:
metadata["modelspec.predict_key"] = "v"
if not args.disable_metadata:
metadata["prompt"] = prompt_info
if extra_pnginfo is not None:
for x in extra_pnginfo:
metadata[x] = json.dumps(extra_pnginfo[x])
output_checkpoint = f"{filename}_{counter:05}_.safetensors"
output_checkpoint = os.path.join(full_output_folder, output_checkpoint)
comfy.sd.save_checkpoint(output_checkpoint, model, clip, vae, clip_vision, metadata=metadata)
class CheckpointSave:
def __init__(self):
self.output_dir = folder_paths.get_output_directory()
@ -137,46 +223,7 @@ class CheckpointSave:
CATEGORY = "advanced/model_merging"
def save(self, model, clip, vae, filename_prefix, prompt=None, extra_pnginfo=None):
full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, self.output_dir)
prompt_info = ""
if prompt is not None:
prompt_info = json.dumps(prompt)
metadata = {}
enable_modelspec = True
if isinstance(model.model, comfy.model_base.SDXL):
metadata["modelspec.architecture"] = "stable-diffusion-xl-v1-base"
elif isinstance(model.model, comfy.model_base.SDXLRefiner):
metadata["modelspec.architecture"] = "stable-diffusion-xl-v1-refiner"
else:
enable_modelspec = False
if enable_modelspec:
metadata["modelspec.sai_model_spec"] = "1.0.0"
metadata["modelspec.implementation"] = "sgm"
metadata["modelspec.title"] = "{} {}".format(filename, counter)
#TODO:
# "stable-diffusion-v1", "stable-diffusion-v1-inpainting", "stable-diffusion-v2-512",
# "stable-diffusion-v2-768-v", "stable-diffusion-v2-unclip-l", "stable-diffusion-v2-unclip-h",
# "v2-inpainting"
if model.model.model_type == comfy.model_base.ModelType.EPS:
metadata["modelspec.predict_key"] = "epsilon"
elif model.model.model_type == comfy.model_base.ModelType.V_PREDICTION:
metadata["modelspec.predict_key"] = "v"
if not args.disable_metadata:
metadata["prompt"] = prompt_info
if extra_pnginfo is not None:
for x in extra_pnginfo:
metadata[x] = json.dumps(extra_pnginfo[x])
output_checkpoint = f"{filename}_{counter:05}_.safetensors"
output_checkpoint = os.path.join(full_output_folder, output_checkpoint)
comfy.sd.save_checkpoint(output_checkpoint, model, clip, vae, metadata=metadata)
save_checkpoint(model, clip=clip, vae=vae, filename_prefix=filename_prefix, output_dir=self.output_dir, prompt=prompt, extra_pnginfo=extra_pnginfo)
return {}
class CLIPSave:
@ -276,6 +323,8 @@ NODE_CLASS_MAPPINGS = {
"ModelMergeAdd": ModelAdd,
"CheckpointSave": CheckpointSave,
"CLIPMergeSimple": CLIPMergeSimple,
"CLIPMergeSubtract": CLIPSubtract,
"CLIPMergeAdd": CLIPAdd,
"CLIPSave": CLIPSave,
"VAESave": VAESave,
}

49
comfy_extras/nodes_morphology.py

@ -0,0 +1,49 @@
import torch
import comfy.model_management
from kornia.morphology import dilation, erosion, opening, closing, gradient, top_hat, bottom_hat
class Morphology:
@classmethod
def INPUT_TYPES(s):
return {"required": {"image": ("IMAGE",),
"operation": (["erode", "dilate", "open", "close", "gradient", "bottom_hat", "top_hat"],),
"kernel_size": ("INT", {"default": 3, "min": 3, "max": 999, "step": 1}),
}}
RETURN_TYPES = ("IMAGE",)
FUNCTION = "process"
CATEGORY = "image/postprocessing"
def process(self, image, operation, kernel_size):
device = comfy.model_management.get_torch_device()
kernel = torch.ones(kernel_size, kernel_size, device=device)
image_k = image.to(device).movedim(-1, 1)
if operation == "erode":
output = erosion(image_k, kernel)
elif operation == "dilate":
output = dilation(image_k, kernel)
elif operation == "open":
output = opening(image_k, kernel)
elif operation == "close":
output = closing(image_k, kernel)
elif operation == "gradient":
output = gradient(image_k, kernel)
elif operation == "top_hat":
output = top_hat(image_k, kernel)
elif operation == "bottom_hat":
output = bottom_hat(image_k, kernel)
else:
raise ValueError(f"Invalid operation {operation} for morphology. Must be one of 'erode', 'dilate', 'open', 'close', 'gradient', 'tophat', 'bottomhat'")
img_out = output.to(comfy.model_management.intermediate_device()).movedim(1, -1)
return (img_out,)
NODE_CLASS_MAPPINGS = {
"Morphology": Morphology,
}
NODE_DISPLAY_NAME_MAPPINGS = {
"Morphology": "ImageMorphology",
}

55
comfy_extras/nodes_perpneg.py

@ -0,0 +1,55 @@
import torch
import comfy.model_management
import comfy.sample
import comfy.samplers
import comfy.utils
class PerpNeg:
@classmethod
def INPUT_TYPES(s):
return {"required": {"model": ("MODEL", ),
"empty_conditioning": ("CONDITIONING", ),
"neg_scale": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0, "step": 0.01}),
}}
RETURN_TYPES = ("MODEL",)
FUNCTION = "patch"
CATEGORY = "_for_testing"
def patch(self, model, empty_conditioning, neg_scale):
m = model.clone()
nocond = comfy.sample.convert_cond(empty_conditioning)
def cfg_function(args):
model = args["model"]
noise_pred_pos = args["cond_denoised"]
noise_pred_neg = args["uncond_denoised"]
cond_scale = args["cond_scale"]
x = args["input"]
sigma = args["sigma"]
model_options = args["model_options"]
nocond_processed = comfy.samplers.encode_model_conds(model.extra_conds, nocond, x, x.device, "negative")
(noise_pred_nocond, _) = comfy.samplers.calc_cond_uncond_batch(model, nocond_processed, None, x, sigma, model_options)
pos = noise_pred_pos - noise_pred_nocond
neg = noise_pred_neg - noise_pred_nocond
perp = neg - ((torch.mul(neg, pos).sum())/(torch.norm(pos)**2)) * pos
perp_neg = perp * neg_scale
cfg_result = noise_pred_nocond + cond_scale*(pos - perp_neg)
cfg_result = x - cfg_result
return cfg_result
m.set_model_sampler_cfg_function(cfg_function)
return (m, )
NODE_CLASS_MAPPINGS = {
"PerpNeg": PerpNeg,
}
NODE_DISPLAY_NAME_MAPPINGS = {
"PerpNeg": "Perp-Neg",
}

187
comfy_extras/nodes_photomaker.py

@ -0,0 +1,187 @@
import torch
import torch.nn as nn
import folder_paths
import comfy.clip_model
import comfy.clip_vision
import comfy.ops
# code for model from: https://github.com/TencentARC/PhotoMaker/blob/main/photomaker/model.py under Apache License Version 2.0
VISION_CONFIG_DICT = {
"hidden_size": 1024,
"image_size": 224,
"intermediate_size": 4096,
"num_attention_heads": 16,
"num_channels": 3,
"num_hidden_layers": 24,
"patch_size": 14,
"projection_dim": 768,
"hidden_act": "quick_gelu",
}
class MLP(nn.Module):
def __init__(self, in_dim, out_dim, hidden_dim, use_residual=True, operations=comfy.ops):
super().__init__()
if use_residual:
assert in_dim == out_dim
self.layernorm = operations.LayerNorm(in_dim)
self.fc1 = operations.Linear(in_dim, hidden_dim)
self.fc2 = operations.Linear(hidden_dim, out_dim)
self.use_residual = use_residual
self.act_fn = nn.GELU()
def forward(self, x):
residual = x
x = self.layernorm(x)
x = self.fc1(x)
x = self.act_fn(x)
x = self.fc2(x)
if self.use_residual:
x = x + residual
return x
class FuseModule(nn.Module):
def __init__(self, embed_dim, operations):
super().__init__()
self.mlp1 = MLP(embed_dim * 2, embed_dim, embed_dim, use_residual=False, operations=operations)
self.mlp2 = MLP(embed_dim, embed_dim, embed_dim, use_residual=True, operations=operations)
self.layer_norm = operations.LayerNorm(embed_dim)
def fuse_fn(self, prompt_embeds, id_embeds):
stacked_id_embeds = torch.cat([prompt_embeds, id_embeds], dim=-1)
stacked_id_embeds = self.mlp1(stacked_id_embeds) + prompt_embeds
stacked_id_embeds = self.mlp2(stacked_id_embeds)
stacked_id_embeds = self.layer_norm(stacked_id_embeds)
return stacked_id_embeds
def forward(
self,
prompt_embeds,
id_embeds,
class_tokens_mask,
) -> torch.Tensor:
# id_embeds shape: [b, max_num_inputs, 1, 2048]
id_embeds = id_embeds.to(prompt_embeds.dtype)
num_inputs = class_tokens_mask.sum().unsqueeze(0) # TODO: check for training case
batch_size, max_num_inputs = id_embeds.shape[:2]
# seq_length: 77
seq_length = prompt_embeds.shape[1]
# flat_id_embeds shape: [b*max_num_inputs, 1, 2048]
flat_id_embeds = id_embeds.view(
-1, id_embeds.shape[-2], id_embeds.shape[-1]
)
# valid_id_mask [b*max_num_inputs]
valid_id_mask = (
torch.arange(max_num_inputs, device=flat_id_embeds.device)[None, :]
< num_inputs[:, None]
)
valid_id_embeds = flat_id_embeds[valid_id_mask.flatten()]
prompt_embeds = prompt_embeds.view(-1, prompt_embeds.shape[-1])
class_tokens_mask = class_tokens_mask.view(-1)
valid_id_embeds = valid_id_embeds.view(-1, valid_id_embeds.shape[-1])
# slice out the image token embeddings
image_token_embeds = prompt_embeds[class_tokens_mask]
stacked_id_embeds = self.fuse_fn(image_token_embeds, valid_id_embeds)
assert class_tokens_mask.sum() == stacked_id_embeds.shape[0], f"{class_tokens_mask.sum()} != {stacked_id_embeds.shape[0]}"
prompt_embeds.masked_scatter_(class_tokens_mask[:, None], stacked_id_embeds.to(prompt_embeds.dtype))
updated_prompt_embeds = prompt_embeds.view(batch_size, seq_length, -1)
return updated_prompt_embeds
class PhotoMakerIDEncoder(comfy.clip_model.CLIPVisionModelProjection):
def __init__(self):
self.load_device = comfy.model_management.text_encoder_device()
offload_device = comfy.model_management.text_encoder_offload_device()
dtype = comfy.model_management.text_encoder_dtype(self.load_device)
super().__init__(VISION_CONFIG_DICT, dtype, offload_device, comfy.ops.manual_cast)
self.visual_projection_2 = comfy.ops.manual_cast.Linear(1024, 1280, bias=False)
self.fuse_module = FuseModule(2048, comfy.ops.manual_cast)
def forward(self, id_pixel_values, prompt_embeds, class_tokens_mask):
b, num_inputs, c, h, w = id_pixel_values.shape
id_pixel_values = id_pixel_values.view(b * num_inputs, c, h, w)
shared_id_embeds = self.vision_model(id_pixel_values)[2]
id_embeds = self.visual_projection(shared_id_embeds)
id_embeds_2 = self.visual_projection_2(shared_id_embeds)
id_embeds = id_embeds.view(b, num_inputs, 1, -1)
id_embeds_2 = id_embeds_2.view(b, num_inputs, 1, -1)
id_embeds = torch.cat((id_embeds, id_embeds_2), dim=-1)
updated_prompt_embeds = self.fuse_module(prompt_embeds, id_embeds, class_tokens_mask)
return updated_prompt_embeds
class PhotoMakerLoader:
@classmethod
def INPUT_TYPES(s):
return {"required": { "photomaker_model_name": (folder_paths.get_filename_list("photomaker"), )}}
RETURN_TYPES = ("PHOTOMAKER",)
FUNCTION = "load_photomaker_model"
CATEGORY = "_for_testing/photomaker"
def load_photomaker_model(self, photomaker_model_name):
photomaker_model_path = folder_paths.get_full_path("photomaker", photomaker_model_name)
photomaker_model = PhotoMakerIDEncoder()
data = comfy.utils.load_torch_file(photomaker_model_path, safe_load=True)
if "id_encoder" in data:
data = data["id_encoder"]
photomaker_model.load_state_dict(data)
return (photomaker_model,)
class PhotoMakerEncode:
@classmethod
def INPUT_TYPES(s):
return {"required": { "photomaker": ("PHOTOMAKER",),
"image": ("IMAGE",),
"clip": ("CLIP", ),
"text": ("STRING", {"multiline": True, "default": "photograph of photomaker"}),
}}
RETURN_TYPES = ("CONDITIONING",)
FUNCTION = "apply_photomaker"
CATEGORY = "_for_testing/photomaker"
def apply_photomaker(self, photomaker, image, clip, text):
special_token = "photomaker"
pixel_values = comfy.clip_vision.clip_preprocess(image.to(photomaker.load_device)).float()
try:
index = text.split(" ").index(special_token) + 1
except ValueError:
index = -1
tokens = clip.tokenize(text, return_word_ids=True)
out_tokens = {}
for k in tokens:
out_tokens[k] = []
for t in tokens[k]:
f = list(filter(lambda x: x[2] != index, t))
while len(f) < len(t):
f.append(t[-1])
out_tokens[k].append(f)
cond, pooled = clip.encode_from_tokens(out_tokens, return_pooled=True)
if index > 0:
token_index = index - 1
num_id_images = 1
class_tokens_mask = [True if token_index <= i < token_index+num_id_images else False for i in range(77)]
out = photomaker(id_pixel_values=pixel_values.unsqueeze(0), prompt_embeds=cond.to(photomaker.load_device),
class_tokens_mask=torch.tensor(class_tokens_mask, dtype=torch.bool, device=photomaker.load_device).unsqueeze(0))
else:
out = cond
return ([[out, {"pooled_output": pooled}]], )
NODE_CLASS_MAPPINGS = {
"PhotoMakerLoader": PhotoMakerLoader,
"PhotoMakerEncode": PhotoMakerEncode,
}

3
comfy_extras/nodes_post_processing.py

@ -33,6 +33,7 @@ class Blend:
CATEGORY = "image/postprocessing"
def blend_images(self, image1: torch.Tensor, image2: torch.Tensor, blend_factor: float, blend_mode: str):
image2 = image2.to(image1.device)
if image1.shape != image2.shape:
image2 = image2.permute(0, 3, 1, 2)
image2 = comfy.utils.common_upscale(image2, image1.shape[2], image1.shape[1], upscale_method='bicubic', crop='center')
@ -226,7 +227,7 @@ class Sharpen:
batch_size, height, width, channels = image.shape
kernel_size = sharpen_radius * 2 + 1
kernel = gaussian_kernel(kernel_size, sigma) * -(alpha*10)
kernel = gaussian_kernel(kernel_size, sigma, device=image.device) * -(alpha*10)
center = kernel_size // 2
kernel[center, center] = kernel[center, center] - kernel.sum() + 1.0
kernel = kernel.repeat(channels, 1, 1).unsqueeze(1)

30
comfy_extras/nodes_rebatch.py

@ -99,10 +99,40 @@ class LatentRebatch:
return (output_list,)
class ImageRebatch:
@classmethod
def INPUT_TYPES(s):
return {"required": { "images": ("IMAGE",),
"batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}),
}}
RETURN_TYPES = ("IMAGE",)
INPUT_IS_LIST = True
OUTPUT_IS_LIST = (True, )
FUNCTION = "rebatch"
CATEGORY = "image/batch"
def rebatch(self, images, batch_size):
batch_size = batch_size[0]
output_list = []
all_images = []
for img in images:
for i in range(img.shape[0]):
all_images.append(img[i:i+1])
for i in range(0, len(all_images), batch_size):
output_list.append(torch.cat(all_images[i:i+batch_size], dim=0))
return (output_list,)
NODE_CLASS_MAPPINGS = {
"RebatchLatents": LatentRebatch,
"RebatchImages": ImageRebatch,
}
NODE_DISPLAY_NAME_MAPPINGS = {
"RebatchLatents": "Rebatch Latents",
"RebatchImages": "Rebatch Images",
}

170
comfy_extras/nodes_sag.py

@ -0,0 +1,170 @@
import torch
from torch import einsum
import torch.nn.functional as F
import math
from einops import rearrange, repeat
import os
from comfy.ldm.modules.attention import optimized_attention, _ATTN_PRECISION
import comfy.samplers
# from comfy/ldm/modules/attention.py
# but modified to return attention scores as well as output
def attention_basic_with_sim(q, k, v, heads, mask=None):
b, _, dim_head = q.shape
dim_head //= heads
scale = dim_head ** -0.5
h = heads
q, k, v = map(
lambda t: t.unsqueeze(3)
.reshape(b, -1, heads, dim_head)
.permute(0, 2, 1, 3)
.reshape(b * heads, -1, dim_head)
.contiguous(),
(q, k, v),
)
# force cast to fp32 to avoid overflowing
if _ATTN_PRECISION =="fp32":
sim = einsum('b i d, b j d -> b i j', q.float(), k.float()) * scale
else:
sim = einsum('b i d, b j d -> b i j', q, k) * scale
del q, k
if mask is not None:
mask = rearrange(mask, 'b ... -> b (...)')
max_neg_value = -torch.finfo(sim.dtype).max
mask = repeat(mask, 'b j -> (b h) () j', h=h)
sim.masked_fill_(~mask, max_neg_value)
# attention, what we cannot get enough of
sim = sim.softmax(dim=-1)
out = einsum('b i j, b j d -> b i d', sim.to(v.dtype), v)
out = (
out.unsqueeze(0)
.reshape(b, heads, -1, dim_head)
.permute(0, 2, 1, 3)
.reshape(b, -1, heads * dim_head)
)
return (out, sim)
def create_blur_map(x0, attn, sigma=3.0, threshold=1.0):
# reshape and GAP the attention map
_, hw1, hw2 = attn.shape
b, _, lh, lw = x0.shape
attn = attn.reshape(b, -1, hw1, hw2)
# Global Average Pool
mask = attn.mean(1, keepdim=False).sum(1, keepdim=False) > threshold
ratio = 2**(math.ceil(math.sqrt(lh * lw / hw1)) - 1).bit_length()
mid_shape = [math.ceil(lh / ratio), math.ceil(lw / ratio)]
# Reshape
mask = (
mask.reshape(b, *mid_shape)
.unsqueeze(1)
.type(attn.dtype)
)
# Upsample
mask = F.interpolate(mask, (lh, lw))
blurred = gaussian_blur_2d(x0, kernel_size=9, sigma=sigma)
blurred = blurred * mask + x0 * (1 - mask)
return blurred
def gaussian_blur_2d(img, kernel_size, sigma):
ksize_half = (kernel_size - 1) * 0.5
x = torch.linspace(-ksize_half, ksize_half, steps=kernel_size)
pdf = torch.exp(-0.5 * (x / sigma).pow(2))
x_kernel = pdf / pdf.sum()
x_kernel = x_kernel.to(device=img.device, dtype=img.dtype)
kernel2d = torch.mm(x_kernel[:, None], x_kernel[None, :])
kernel2d = kernel2d.expand(img.shape[-3], 1, kernel2d.shape[0], kernel2d.shape[1])
padding = [kernel_size // 2, kernel_size // 2, kernel_size // 2, kernel_size // 2]
img = F.pad(img, padding, mode="reflect")
img = F.conv2d(img, kernel2d, groups=img.shape[-3])
return img
class SelfAttentionGuidance:
@classmethod
def INPUT_TYPES(s):
return {"required": { "model": ("MODEL",),
"scale": ("FLOAT", {"default": 0.5, "min": -2.0, "max": 5.0, "step": 0.1}),
"blur_sigma": ("FLOAT", {"default": 2.0, "min": 0.0, "max": 10.0, "step": 0.1}),
}}
RETURN_TYPES = ("MODEL",)
FUNCTION = "patch"
CATEGORY = "_for_testing"
def patch(self, model, scale, blur_sigma):
m = model.clone()
attn_scores = None
# TODO: make this work properly with chunked batches
# currently, we can only save the attn from one UNet call
def attn_and_record(q, k, v, extra_options):
nonlocal attn_scores
# if uncond, save the attention scores
heads = extra_options["n_heads"]
cond_or_uncond = extra_options["cond_or_uncond"]
b = q.shape[0] // len(cond_or_uncond)
if 1 in cond_or_uncond:
uncond_index = cond_or_uncond.index(1)
# do the entire attention operation, but save the attention scores to attn_scores
(out, sim) = attention_basic_with_sim(q, k, v, heads=heads)
# when using a higher batch size, I BELIEVE the result batch dimension is [uc1, ... ucn, c1, ... cn]
n_slices = heads * b
attn_scores = sim[n_slices * uncond_index:n_slices * (uncond_index+1)]
return out
else:
return optimized_attention(q, k, v, heads=heads)
def post_cfg_function(args):
nonlocal attn_scores
uncond_attn = attn_scores
sag_scale = scale
sag_sigma = blur_sigma
sag_threshold = 1.0
model = args["model"]
uncond_pred = args["uncond_denoised"]
uncond = args["uncond"]
cfg_result = args["denoised"]
sigma = args["sigma"]
model_options = args["model_options"]
x = args["input"]
if min(cfg_result.shape[2:]) <= 4: #skip when too small to add padding
return cfg_result
# create the adversarially blurred image
degraded = create_blur_map(uncond_pred, uncond_attn, sag_sigma, sag_threshold)
degraded_noised = degraded + x - uncond_pred
# call into the UNet
(sag, _) = comfy.samplers.calc_cond_uncond_batch(model, uncond, None, degraded_noised, sigma, model_options)
return cfg_result + (degraded - sag) * sag_scale
m.set_model_sampler_post_cfg_function(post_cfg_function, disable_cfg1_optimization=True)
# from diffusers:
# unet.mid_block.attentions[0].transformer_blocks[0].attn1.patch
m.set_model_attn1_replace(attn_and_record, "middle", 0, 0)
return (m, )
NODE_CLASS_MAPPINGS = {
"SelfAttentionGuidance": SelfAttentionGuidance,
}
NODE_DISPLAY_NAME_MAPPINGS = {
"SelfAttentionGuidance": "Self-Attention Guidance",
}

47
comfy_extras/nodes_sdupscale.py

@ -0,0 +1,47 @@
import torch
import nodes
import comfy.utils
class SD_4XUpscale_Conditioning:
@classmethod
def INPUT_TYPES(s):
return {"required": { "images": ("IMAGE",),
"positive": ("CONDITIONING",),
"negative": ("CONDITIONING",),
"scale_ratio": ("FLOAT", {"default": 4.0, "min": 0.0, "max": 10.0, "step": 0.01}),
"noise_augmentation": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.001}),
}}
RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT")
RETURN_NAMES = ("positive", "negative", "latent")
FUNCTION = "encode"
CATEGORY = "conditioning/upscale_diffusion"
def encode(self, images, positive, negative, scale_ratio, noise_augmentation):
width = max(1, round(images.shape[-2] * scale_ratio))
height = max(1, round(images.shape[-3] * scale_ratio))
pixels = comfy.utils.common_upscale((images.movedim(-1,1) * 2.0) - 1.0, width // 4, height // 4, "bilinear", "center")
out_cp = []
out_cn = []
for t in positive:
n = [t[0], t[1].copy()]
n[1]['concat_image'] = pixels
n[1]['noise_augmentation'] = noise_augmentation
out_cp.append(n)
for t in negative:
n = [t[0], t[1].copy()]
n[1]['concat_image'] = pixels
n[1]['noise_augmentation'] = noise_augmentation
out_cn.append(n)
latent = torch.zeros([images.shape[0], 4, height // 4, width // 4])
return (out_cp, out_cn, {"samples":latent})
NODE_CLASS_MAPPINGS = {
"SD_4XUpscale_Conditioning": SD_4XUpscale_Conditioning,
}

143
comfy_extras/nodes_stable3d.py

@ -0,0 +1,143 @@
import torch
import nodes
import comfy.utils
def camera_embeddings(elevation, azimuth):
elevation = torch.as_tensor([elevation])
azimuth = torch.as_tensor([azimuth])
embeddings = torch.stack(
[
torch.deg2rad(
(90 - elevation) - (90)
), # Zero123 polar is 90-elevation
torch.sin(torch.deg2rad(azimuth)),
torch.cos(torch.deg2rad(azimuth)),
torch.deg2rad(
90 - torch.full_like(elevation, 0)
),
], dim=-1).unsqueeze(1)
return embeddings
class StableZero123_Conditioning:
@classmethod
def INPUT_TYPES(s):
return {"required": { "clip_vision": ("CLIP_VISION",),
"init_image": ("IMAGE",),
"vae": ("VAE",),
"width": ("INT", {"default": 256, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 8}),
"height": ("INT", {"default": 256, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 8}),
"batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}),
"elevation": ("FLOAT", {"default": 0.0, "min": -180.0, "max": 180.0, "step": 0.1, "round": False}),
"azimuth": ("FLOAT", {"default": 0.0, "min": -180.0, "max": 180.0, "step": 0.1, "round": False}),
}}
RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT")
RETURN_NAMES = ("positive", "negative", "latent")
FUNCTION = "encode"
CATEGORY = "conditioning/3d_models"
def encode(self, clip_vision, init_image, vae, width, height, batch_size, elevation, azimuth):
output = clip_vision.encode_image(init_image)
pooled = output.image_embeds.unsqueeze(0)
pixels = comfy.utils.common_upscale(init_image.movedim(-1,1), width, height, "bilinear", "center").movedim(1,-1)
encode_pixels = pixels[:,:,:,:3]
t = vae.encode(encode_pixels)
cam_embeds = camera_embeddings(elevation, azimuth)
cond = torch.cat([pooled, cam_embeds.to(pooled.device).repeat((pooled.shape[0], 1, 1))], dim=-1)
positive = [[cond, {"concat_latent_image": t}]]
negative = [[torch.zeros_like(pooled), {"concat_latent_image": torch.zeros_like(t)}]]
latent = torch.zeros([batch_size, 4, height // 8, width // 8])
return (positive, negative, {"samples":latent})
class StableZero123_Conditioning_Batched:
@classmethod
def INPUT_TYPES(s):
return {"required": { "clip_vision": ("CLIP_VISION",),
"init_image": ("IMAGE",),
"vae": ("VAE",),
"width": ("INT", {"default": 256, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 8}),
"height": ("INT", {"default": 256, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 8}),
"batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}),
"elevation": ("FLOAT", {"default": 0.0, "min": -180.0, "max": 180.0, "step": 0.1, "round": False}),
"azimuth": ("FLOAT", {"default": 0.0, "min": -180.0, "max": 180.0, "step": 0.1, "round": False}),
"elevation_batch_increment": ("FLOAT", {"default": 0.0, "min": -180.0, "max": 180.0, "step": 0.1, "round": False}),
"azimuth_batch_increment": ("FLOAT", {"default": 0.0, "min": -180.0, "max": 180.0, "step": 0.1, "round": False}),
}}
RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT")
RETURN_NAMES = ("positive", "negative", "latent")
FUNCTION = "encode"
CATEGORY = "conditioning/3d_models"
def encode(self, clip_vision, init_image, vae, width, height, batch_size, elevation, azimuth, elevation_batch_increment, azimuth_batch_increment):
output = clip_vision.encode_image(init_image)
pooled = output.image_embeds.unsqueeze(0)
pixels = comfy.utils.common_upscale(init_image.movedim(-1,1), width, height, "bilinear", "center").movedim(1,-1)
encode_pixels = pixels[:,:,:,:3]
t = vae.encode(encode_pixels)
cam_embeds = []
for i in range(batch_size):
cam_embeds.append(camera_embeddings(elevation, azimuth))
elevation += elevation_batch_increment
azimuth += azimuth_batch_increment
cam_embeds = torch.cat(cam_embeds, dim=0)
cond = torch.cat([comfy.utils.repeat_to_batch_size(pooled, batch_size), cam_embeds], dim=-1)
positive = [[cond, {"concat_latent_image": t}]]
negative = [[torch.zeros_like(pooled), {"concat_latent_image": torch.zeros_like(t)}]]
latent = torch.zeros([batch_size, 4, height // 8, width // 8])
return (positive, negative, {"samples":latent, "batch_index": [0] * batch_size})
class SV3D_Conditioning:
@classmethod
def INPUT_TYPES(s):
return {"required": { "clip_vision": ("CLIP_VISION",),
"init_image": ("IMAGE",),
"vae": ("VAE",),
"width": ("INT", {"default": 576, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 8}),
"height": ("INT", {"default": 576, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 8}),
"video_frames": ("INT", {"default": 21, "min": 1, "max": 4096}),
"elevation": ("FLOAT", {"default": 0.0, "min": -90.0, "max": 90.0, "step": 0.1, "round": False}),
}}
RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT")
RETURN_NAMES = ("positive", "negative", "latent")
FUNCTION = "encode"
CATEGORY = "conditioning/3d_models"
def encode(self, clip_vision, init_image, vae, width, height, video_frames, elevation):
output = clip_vision.encode_image(init_image)
pooled = output.image_embeds.unsqueeze(0)
pixels = comfy.utils.common_upscale(init_image.movedim(-1,1), width, height, "bilinear", "center").movedim(1,-1)
encode_pixels = pixels[:,:,:,:3]
t = vae.encode(encode_pixels)
azimuth = 0
azimuth_increment = 360 / (max(video_frames, 2) - 1)
elevations = []
azimuths = []
for i in range(video_frames):
elevations.append(elevation)
azimuths.append(azimuth)
azimuth += azimuth_increment
positive = [[pooled, {"concat_latent_image": t, "elevation": elevations, "azimuth": azimuths}]]
negative = [[torch.zeros_like(pooled), {"concat_latent_image": torch.zeros_like(t), "elevation": elevations, "azimuth": azimuths}]]
latent = torch.zeros([video_frames, 4, height // 8, width // 8])
return (positive, negative, {"samples":latent})
NODE_CLASS_MAPPINGS = {
"StableZero123_Conditioning": StableZero123_Conditioning,
"StableZero123_Conditioning_Batched": StableZero123_Conditioning_Batched,
"SV3D_Conditioning": SV3D_Conditioning,
}

140
comfy_extras/nodes_stable_cascade.py

@ -0,0 +1,140 @@
"""
This file is part of ComfyUI.
Copyright (C) 2024 Stability AI
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU General Public License for more details.
You should have received a copy of the GNU General Public License
along with this program. If not, see <https://www.gnu.org/licenses/>.
"""
import torch
import nodes
import comfy.utils
class StableCascade_EmptyLatentImage:
def __init__(self, device="cpu"):
self.device = device
@classmethod
def INPUT_TYPES(s):
return {"required": {
"width": ("INT", {"default": 1024, "min": 256, "max": nodes.MAX_RESOLUTION, "step": 8}),
"height": ("INT", {"default": 1024, "min": 256, "max": nodes.MAX_RESOLUTION, "step": 8}),
"compression": ("INT", {"default": 42, "min": 4, "max": 128, "step": 1}),
"batch_size": ("INT", {"default": 1, "min": 1, "max": 4096})
}}
RETURN_TYPES = ("LATENT", "LATENT")
RETURN_NAMES = ("stage_c", "stage_b")
FUNCTION = "generate"
CATEGORY = "latent/stable_cascade"
def generate(self, width, height, compression, batch_size=1):
c_latent = torch.zeros([batch_size, 16, height // compression, width // compression])
b_latent = torch.zeros([batch_size, 4, height // 4, width // 4])
return ({
"samples": c_latent,
}, {
"samples": b_latent,
})
class StableCascade_StageC_VAEEncode:
def __init__(self, device="cpu"):
self.device = device
@classmethod
def INPUT_TYPES(s):
return {"required": {
"image": ("IMAGE",),
"vae": ("VAE", ),
"compression": ("INT", {"default": 42, "min": 4, "max": 128, "step": 1}),
}}
RETURN_TYPES = ("LATENT", "LATENT")
RETURN_NAMES = ("stage_c", "stage_b")
FUNCTION = "generate"
CATEGORY = "latent/stable_cascade"
def generate(self, image, vae, compression):
width = image.shape[-2]
height = image.shape[-3]
out_width = (width // compression) * vae.downscale_ratio
out_height = (height // compression) * vae.downscale_ratio
s = comfy.utils.common_upscale(image.movedim(-1,1), out_width, out_height, "bicubic", "center").movedim(1,-1)
c_latent = vae.encode(s[:,:,:,:3])
b_latent = torch.zeros([c_latent.shape[0], 4, (height // 8) * 2, (width // 8) * 2])
return ({
"samples": c_latent,
}, {
"samples": b_latent,
})
class StableCascade_StageB_Conditioning:
@classmethod
def INPUT_TYPES(s):
return {"required": { "conditioning": ("CONDITIONING",),
"stage_c": ("LATENT",),
}}
RETURN_TYPES = ("CONDITIONING",)
FUNCTION = "set_prior"
CATEGORY = "conditioning/stable_cascade"
def set_prior(self, conditioning, stage_c):
c = []
for t in conditioning:
d = t[1].copy()
d['stable_cascade_prior'] = stage_c['samples']
n = [t[0], d]
c.append(n)
return (c, )
class StableCascade_SuperResolutionControlnet:
def __init__(self, device="cpu"):
self.device = device
@classmethod
def INPUT_TYPES(s):
return {"required": {
"image": ("IMAGE",),
"vae": ("VAE", ),
}}
RETURN_TYPES = ("IMAGE", "LATENT", "LATENT")
RETURN_NAMES = ("controlnet_input", "stage_c", "stage_b")
FUNCTION = "generate"
CATEGORY = "_for_testing/stable_cascade"
def generate(self, image, vae):
width = image.shape[-2]
height = image.shape[-3]
batch_size = image.shape[0]
controlnet_input = vae.encode(image[:,:,:,:3]).movedim(1, -1)
c_latent = torch.zeros([batch_size, 16, height // 16, width // 16])
b_latent = torch.zeros([batch_size, 4, height // 2, width // 2])
return (controlnet_input, {
"samples": c_latent,
}, {
"samples": b_latent,
})
NODE_CLASS_MAPPINGS = {
"StableCascade_EmptyLatentImage": StableCascade_EmptyLatentImage,
"StableCascade_StageB_Conditioning": StableCascade_StageB_Conditioning,
"StableCascade_StageC_VAEEncode": StableCascade_StageC_VAEEncode,
"StableCascade_SuperResolutionControlnet": StableCascade_SuperResolutionControlnet,
}

45
comfy_extras/nodes_video_model.py

@ -3,6 +3,7 @@ import torch
import comfy.utils
import comfy.sd
import folder_paths
import comfy_extras.nodes_model_merging
class ImageOnlyCheckpointLoader:
@ -78,10 +79,54 @@ class VideoLinearCFGGuidance:
m.set_model_sampler_cfg_function(linear_cfg)
return (m, )
class VideoTriangleCFGGuidance:
@classmethod
def INPUT_TYPES(s):
return {"required": { "model": ("MODEL",),
"min_cfg": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0, "step":0.5, "round": 0.01}),
}}
RETURN_TYPES = ("MODEL",)
FUNCTION = "patch"
CATEGORY = "sampling/video_models"
def patch(self, model, min_cfg):
def linear_cfg(args):
cond = args["cond"]
uncond = args["uncond"]
cond_scale = args["cond_scale"]
period = 1.0
values = torch.linspace(0, 1, cond.shape[0], device=cond.device)
values = 2 * (values / period - torch.floor(values / period + 0.5)).abs()
scale = (values * (cond_scale - min_cfg) + min_cfg).reshape((cond.shape[0], 1, 1, 1))
return uncond + scale * (cond - uncond)
m = model.clone()
m.set_model_sampler_cfg_function(linear_cfg)
return (m, )
class ImageOnlyCheckpointSave(comfy_extras.nodes_model_merging.CheckpointSave):
CATEGORY = "_for_testing"
@classmethod
def INPUT_TYPES(s):
return {"required": { "model": ("MODEL",),
"clip_vision": ("CLIP_VISION",),
"vae": ("VAE",),
"filename_prefix": ("STRING", {"default": "checkpoints/ComfyUI"}),},
"hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"},}
def save(self, model, clip_vision, vae, filename_prefix, prompt=None, extra_pnginfo=None):
comfy_extras.nodes_model_merging.save_checkpoint(model, clip_vision=clip_vision, vae=vae, filename_prefix=filename_prefix, output_dir=self.output_dir, prompt=prompt, extra_pnginfo=extra_pnginfo)
return {}
NODE_CLASS_MAPPINGS = {
"ImageOnlyCheckpointLoader": ImageOnlyCheckpointLoader,
"SVD_img2vid_Conditioning": SVD_img2vid_Conditioning,
"VideoLinearCFGGuidance": VideoLinearCFGGuidance,
"VideoTriangleCFGGuidance": VideoTriangleCFGGuidance,
"ImageOnlyCheckpointSave": ImageOnlyCheckpointSave,
}
NODE_DISPLAY_NAME_MAPPINGS = {

10
cuda_malloc.py

@ -1,6 +1,7 @@
import os
import importlib.util
from comfy.cli_args import args
import subprocess
#Can't use pytorch to get the GPU names because the cuda malloc has to be set before the first import.
def get_gpu_names():
@ -34,14 +35,19 @@ def get_gpu_names():
return gpu_names
return enum_display_devices()
else:
return set()
gpu_names = set()
out = subprocess.check_output(['nvidia-smi', '-L'])
for l in out.split(b'\n'):
if len(l) > 0:
gpu_names.add(l.decode('utf-8').split(' (UUID')[0])
return gpu_names
blacklist = {"GeForce GTX TITAN X", "GeForce GTX 980", "GeForce GTX 970", "GeForce GTX 960", "GeForce GTX 950", "GeForce 945M",
"GeForce 940M", "GeForce 930M", "GeForce 920M", "GeForce 910M", "GeForce GTX 750", "GeForce GTX 745", "Quadro K620",
"Quadro K1200", "Quadro K2200", "Quadro M500", "Quadro M520", "Quadro M600", "Quadro M620", "Quadro M1000",
"Quadro M1200", "Quadro M2000", "Quadro M2200", "Quadro M3000", "Quadro M4000", "Quadro M5000", "Quadro M5500", "Quadro M6000",
"GeForce MX110", "GeForce MX130", "GeForce 830M", "GeForce 840M", "GeForce GTX 850M", "GeForce GTX 860M",
"GeForce GTX 1650", "GeForce GTX 1630"
"GeForce GTX 1650", "GeForce GTX 1630", "Tesla M4", "Tesla M6", "Tesla M10", "Tesla M40", "Tesla M60"
}
def cuda_malloc_supported():

16
custom_nodes/example_node.py.example

@ -6,6 +6,8 @@ class Example:
-------------
INPUT_TYPES (dict):
Tell the main program input parameters of nodes.
IS_CHANGED:
optional method to control when the node is re executed.
Attributes
----------
@ -89,6 +91,20 @@ class Example:
image = 1.0 - image
return (image,)
"""
The node will always be re executed if any of the inputs change but
this method can be used to force the node to execute again even when the inputs don't change.
You can make this node return a number or a string. This value will be compared to the one returned the last time the node was
executed, if it is different the node will be executed again.
This method is used in the core repo for the LoadImage node where they return the image hash as a string, if the image hash
changes between executions the LoadImage node is executed again.
"""
#@classmethod
#def IS_CHANGED(s, image, string_field, int_field, float_field, print_to_screen):
# return ""
# Set the web directory, any .js file in that directory will be loaded by the frontend as a frontend extension
# WEB_DIRECTORY = "./somejs"
# A dictionary that contains all nodes you want to export with their names
# NOTE: names should be globally unique

45
custom_nodes/websocket_image_save.py

@ -0,0 +1,45 @@
from PIL import Image, ImageOps
from io import BytesIO
import numpy as np
import struct
import comfy.utils
import time
#You can use this node to save full size images through the websocket, the
#images will be sent in exactly the same format as the image previews: as
#binary images on the websocket with a 8 byte header indicating the type
#of binary message (first 4 bytes) and the image format (next 4 bytes).
#Note that no metadata will be put in the images saved with this node.
class SaveImageWebsocket:
@classmethod
def INPUT_TYPES(s):
return {"required":
{"images": ("IMAGE", ),}
}
RETURN_TYPES = ()
FUNCTION = "save_images"
OUTPUT_NODE = True
CATEGORY = "api/image"
def save_images(self, images):
pbar = comfy.utils.ProgressBar(images.shape[0])
step = 0
for image in images:
i = 255. * image.cpu().numpy()
img = Image.fromarray(np.clip(i, 0, 255).astype(np.uint8))
pbar.update_absolute(step, images.shape[0], ("PNG", img, None))
step += 1
return {}
def IS_CHANGED(s, images):
return time.time()
NODE_CLASS_MAPPINGS = {
"SaveImageWebsocket": SaveImageWebsocket,
}

172
execution.py

@ -1,12 +1,11 @@
import os
import sys
import copy
import json
import logging
import threading
import heapq
import traceback
import gc
import inspect
from typing import List, Literal, NamedTuple, Optional
import torch
import nodes
@ -36,8 +35,7 @@ def get_input_data(inputs, class_def, unique_id, outputs={}, prompt={}, extra_da
if h[x] == "PROMPT":
input_data_all[x] = [prompt]
if h[x] == "EXTRA_PNGINFO":
if "extra_pnginfo" in extra_data:
input_data_all[x] = [extra_data['extra_pnginfo']]
input_data_all[x] = [extra_data.get('extra_pnginfo', None)]
if h[x] == "UNIQUE_ID":
input_data_all[x] = [unique_id]
return input_data_all
@ -195,8 +193,12 @@ def recursive_execute(server, prompt, outputs, current_item, extra_data, execute
return (True, None, None)
def recursive_will_execute(prompt, outputs, current_item):
def recursive_will_execute(prompt, outputs, current_item, memo={}):
unique_id = current_item
if unique_id in memo:
return memo[unique_id]
inputs = prompt[unique_id]['inputs']
will_execute = []
if unique_id in outputs:
@ -208,9 +210,10 @@ def recursive_will_execute(prompt, outputs, current_item):
input_unique_id = input_data[0]
output_index = input_data[1]
if input_unique_id not in outputs:
will_execute += recursive_will_execute(prompt, outputs, input_unique_id)
will_execute += recursive_will_execute(prompt, outputs, input_unique_id, memo)
return will_execute + [unique_id]
memo[unique_id] = will_execute + [unique_id]
return memo[unique_id]
def recursive_output_delete_if_changed(prompt, old_prompt, outputs, current_item):
unique_id = current_item
@ -267,11 +270,21 @@ def recursive_output_delete_if_changed(prompt, old_prompt, outputs, current_item
class PromptExecutor:
def __init__(self, server):
self.server = server
self.reset()
def reset(self):
self.outputs = {}
self.object_storage = {}
self.outputs_ui = {}
self.status_messages = []
self.success = True
self.old_prompt = {}
self.server = server
def add_message(self, event, data, broadcast: bool):
self.status_messages.append((event, data))
if self.server.client_id is not None or broadcast:
self.server.send_sync(event, data, self.server.client_id)
def handle_execution_error(self, prompt_id, prompt, current_outputs, executed, error, ex):
node_id = error["node_id"]
@ -286,22 +299,21 @@ class PromptExecutor:
"node_type": class_type,
"executed": list(executed),
}
self.server.send_sync("execution_interrupted", mes, self.server.client_id)
self.add_message("execution_interrupted", mes, broadcast=True)
else:
if self.server.client_id is not None:
mes = {
"prompt_id": prompt_id,
"node_id": node_id,
"node_type": class_type,
"executed": list(executed),
"exception_message": error["exception_message"],
"exception_type": error["exception_type"],
"traceback": error["traceback"],
"current_inputs": error["current_inputs"],
"current_outputs": error["current_outputs"],
}
self.server.send_sync("execution_error", mes, self.server.client_id)
mes = {
"prompt_id": prompt_id,
"node_id": node_id,
"node_type": class_type,
"executed": list(executed),
"exception_message": error["exception_message"],
"exception_type": error["exception_type"],
"traceback": error["traceback"],
"current_inputs": error["current_inputs"],
"current_outputs": error["current_outputs"],
}
self.add_message("execution_error", mes, broadcast=False)
# Next, remove the subsequent outputs since they will not be executed
to_delete = []
@ -323,8 +335,8 @@ class PromptExecutor:
else:
self.server.client_id = None
if self.server.client_id is not None:
self.server.send_sync("execution_start", { "prompt_id": prompt_id}, self.server.client_id)
self.status_messages = []
self.add_message("execution_start", { "prompt_id": prompt_id}, broadcast=False)
with torch.inference_mode():
#delete cached outputs if nodes don't exist for them
@ -356,9 +368,9 @@ class PromptExecutor:
d = self.outputs_ui.pop(x)
del d
comfy.model_management.cleanup_models()
if self.server.client_id is not None:
self.server.send_sync("execution_cached", { "nodes": list(current_outputs) , "prompt_id": prompt_id}, self.server.client_id)
self.add_message("execution_cached",
{ "nodes": list(current_outputs) , "prompt_id": prompt_id},
broadcast=False)
executed = set()
output_node_id = None
to_execute = []
@ -368,20 +380,23 @@ class PromptExecutor:
while len(to_execute) > 0:
#always execute the output that depends on the least amount of unexecuted nodes first
to_execute = sorted(list(map(lambda a: (len(recursive_will_execute(prompt, self.outputs, a[-1])), a[-1]), to_execute)))
memo = {}
to_execute = sorted(list(map(lambda a: (len(recursive_will_execute(prompt, self.outputs, a[-1], memo)), a[-1]), to_execute)))
output_node_id = to_execute.pop(0)[-1]
# This call shouldn't raise anything if there's an error deep in
# the actual SD code, instead it will report the node where the
# error was raised
success, error, ex = recursive_execute(self.server, prompt, self.outputs, output_node_id, extra_data, executed, prompt_id, self.outputs_ui, self.object_storage)
if success is not True:
self.success, error, ex = recursive_execute(self.server, prompt, self.outputs, output_node_id, extra_data, executed, prompt_id, self.outputs_ui, self.object_storage)
if self.success is not True:
self.handle_execution_error(prompt_id, prompt, current_outputs, executed, error, ex)
break
for x in executed:
self.old_prompt[x] = copy.deepcopy(prompt[x])
self.server.last_node_id = None
if comfy.model_management.DISABLE_SMART_MEMORY:
comfy.model_management.unload_all_models()
@ -400,6 +415,10 @@ def validate_inputs(prompt, item, validated):
errors = []
valid = True
validate_function_inputs = []
if hasattr(obj_class, "VALIDATE_INPUTS"):
validate_function_inputs = inspect.getfullargspec(obj_class.VALIDATE_INPUTS).args
for x in required_inputs:
if x not in inputs:
error = {
@ -529,29 +548,7 @@ def validate_inputs(prompt, item, validated):
errors.append(error)
continue
if hasattr(obj_class, "VALIDATE_INPUTS"):
input_data_all = get_input_data(inputs, obj_class, unique_id)
#ret = obj_class.VALIDATE_INPUTS(**input_data_all)
ret = map_node_over_list(obj_class, input_data_all, "VALIDATE_INPUTS")
for i, r in enumerate(ret):
if r is not True:
details = f"{x}"
if r is not False:
details += f" - {str(r)}"
error = {
"type": "custom_validation_failed",
"message": "Custom validation failed for node",
"details": details,
"extra_info": {
"input_name": x,
"input_config": info,
"received_value": val,
}
}
errors.append(error)
continue
else:
if x not in validate_function_inputs:
if isinstance(type_input, list):
if val not in type_input:
input_config = info
@ -578,6 +575,35 @@ def validate_inputs(prompt, item, validated):
errors.append(error)
continue
if len(validate_function_inputs) > 0:
input_data_all = get_input_data(inputs, obj_class, unique_id)
input_filtered = {}
for x in input_data_all:
if x in validate_function_inputs:
input_filtered[x] = input_data_all[x]
#ret = obj_class.VALIDATE_INPUTS(**input_filtered)
ret = map_node_over_list(obj_class, input_filtered, "VALIDATE_INPUTS")
for x in input_filtered:
for i, r in enumerate(ret):
if r is not True:
details = f"{x}"
if r is not False:
details += f" - {str(r)}"
error = {
"type": "custom_validation_failed",
"message": "Custom validation failed for node",
"details": details,
"extra_info": {
"input_name": x,
"input_config": info,
"received_value": val,
}
}
errors.append(error)
continue
if len(errors) > 0 or valid is not True:
ret = (False, errors, unique_id)
else:
@ -692,6 +718,7 @@ class PromptQueue:
self.queue = []
self.currently_running = {}
self.history = {}
self.flags = {}
server.prompt_queue = self
def put(self, item):
@ -713,14 +740,27 @@ class PromptQueue:
self.server.queue_updated()
return (item, i)
def task_done(self, item_id, outputs):
class ExecutionStatus(NamedTuple):
status_str: Literal['success', 'error']
completed: bool
messages: List[str]
def task_done(self, item_id, outputs,
status: Optional['PromptQueue.ExecutionStatus']):
with self.mutex:
prompt = self.currently_running.pop(item_id)
if len(self.history) > MAXIMUM_HISTORY_SIZE:
self.history.pop(next(iter(self.history)))
self.history[prompt[1]] = { "prompt": prompt, "outputs": {} }
for o in outputs:
self.history[prompt[1]]["outputs"][o] = outputs[o]
status_dict: Optional[dict] = None
if status is not None:
status_dict = copy.deepcopy(status._asdict())
self.history[prompt[1]] = {
"prompt": prompt,
"outputs": copy.deepcopy(outputs),
'status': status_dict,
}
self.server.queue_updated()
def get_current_queue(self):
@ -778,3 +818,17 @@ class PromptQueue:
def delete_history_item(self, id_to_delete):
with self.mutex:
self.history.pop(id_to_delete, None)
def set_flag(self, name, data):
with self.mutex:
self.flags[name] = data
self.not_empty.notify()
def get_flags(self, reset=True):
with self.mutex:
if reset:
ret = self.flags
self.flags = {}
return ret
else:
return self.flags.copy()

22
folder_paths.py

@ -29,11 +29,14 @@ folder_names_and_paths["custom_nodes"] = ([os.path.join(base_path, "custom_nodes
folder_names_and_paths["hypernetworks"] = ([os.path.join(models_dir, "hypernetworks")], supported_pt_extensions)
folder_names_and_paths["photomaker"] = ([os.path.join(models_dir, "photomaker")], supported_pt_extensions)
folder_names_and_paths["classifiers"] = ([os.path.join(models_dir, "classifiers")], {""})
output_directory = os.path.join(os.path.dirname(os.path.realpath(__file__)), "output")
temp_directory = os.path.join(os.path.dirname(os.path.realpath(__file__)), "temp")
input_directory = os.path.join(os.path.dirname(os.path.realpath(__file__)), "input")
user_directory = os.path.join(os.path.dirname(os.path.realpath(__file__)), "user")
filename_list_cache = {}
@ -137,15 +140,27 @@ def recursive_search(directory, excluded_dir_names=None):
excluded_dir_names = []
result = []
dirs = {directory: os.path.getmtime(directory)}
dirs = {}
# Attempt to add the initial directory to dirs with error handling
try:
dirs[directory] = os.path.getmtime(directory)
except FileNotFoundError:
print(f"Warning: Unable to access {directory}. Skipping this path.")
for dirpath, subdirs, filenames in os.walk(directory, followlinks=True, topdown=True):
subdirs[:] = [d for d in subdirs if d not in excluded_dir_names]
for file_name in filenames:
relative_path = os.path.relpath(os.path.join(dirpath, file_name), directory)
result.append(relative_path)
for d in subdirs:
path = os.path.join(dirpath, d)
dirs[path] = os.path.getmtime(path)
try:
dirs[path] = os.path.getmtime(path)
except FileNotFoundError:
print(f"Warning: Unable to access {path}. Skipping this path.")
continue
return result, dirs
def filter_files_extensions(files, extensions):
@ -184,8 +199,7 @@ def cached_filename_list_(folder_name):
if folder_name not in filename_list_cache:
return None
out = filename_list_cache[folder_name]
if time.perf_counter() < (out[2] + 0.5):
return out
for x in out[1]:
time_modified = out[1][x]
folder = x

3
latent_preview.py

@ -6,6 +6,7 @@ from comfy.cli_args import args, LatentPreviewMethod
from comfy.taesd.taesd import TAESD
import folder_paths
import comfy.utils
import logging
MAX_PREVIEW_RESOLUTION = 512
@ -70,7 +71,7 @@ def get_previewer(device, latent_format):
taesd = TAESD(None, taesd_decoder_path).to(device)
previewer = TAESDPreviewerImpl(taesd)
else:
print("Warning: TAESD previews enabled, but could not find models/vae_approx/{}".format(latent_format.taesd_decoder_name))
logging.warning("Warning: TAESD previews enabled, but could not find models/vae_approx/{}".format(latent_format.taesd_decoder_name))
if previewer is None:
if latent_format.latent_rgb_factors is not None:

58
main.py

@ -54,15 +54,19 @@ import threading
import gc
from comfy.cli_args import args
import logging
if os.name == "nt":
import logging
logging.getLogger("xformers").addFilter(lambda record: 'A matching Triton is not available' not in record.getMessage())
if __name__ == "__main__":
if args.cuda_device is not None:
os.environ['CUDA_VISIBLE_DEVICES'] = str(args.cuda_device)
print("Set cuda device to:", args.cuda_device)
logging.info("Set cuda device to: {}".format(args.cuda_device))
if args.deterministic:
if 'CUBLAS_WORKSPACE_CONFIG' not in os.environ:
os.environ['CUBLAS_WORKSPACE_CONFIG'] = ":4096:8"
import cuda_malloc
@ -84,7 +88,7 @@ def cuda_malloc_warning():
if b in device_name:
cuda_malloc_warning = True
if cuda_malloc_warning:
print("\nWARNING: this card most likely does not support cuda-malloc, if you get \"CUDA error\" please run ComfyUI with: --disable-cuda-malloc\n")
logging.warning("\nWARNING: this card most likely does not support cuda-malloc, if you get \"CUDA error\" please run ComfyUI with: --disable-cuda-malloc\n")
def prompt_worker(q, server):
e = execution.PromptExecutor(server)
@ -93,7 +97,7 @@ def prompt_worker(q, server):
gc_collect_interval = 10.0
while True:
timeout = None
timeout = 1000.0
if need_gc:
timeout = max(gc_collect_interval - (current_time - last_gc_collect), 0.0)
@ -102,19 +106,40 @@ def prompt_worker(q, server):
item, item_id = queue_item
execution_start_time = time.perf_counter()
prompt_id = item[1]
server.last_prompt_id = prompt_id
e.execute(item[2], prompt_id, item[3], item[4])
need_gc = True
q.task_done(item_id, e.outputs_ui)
q.task_done(item_id,
e.outputs_ui,
status=execution.PromptQueue.ExecutionStatus(
status_str='success' if e.success else 'error',
completed=e.success,
messages=e.status_messages))
if server.client_id is not None:
server.send_sync("executing", { "node": None, "prompt_id": prompt_id }, server.client_id)
current_time = time.perf_counter()
execution_time = current_time - execution_start_time
print("Prompt executed in {:.2f} seconds".format(execution_time))
logging.info("Prompt executed in {:.2f} seconds".format(execution_time))
flags = q.get_flags()
free_memory = flags.get("free_memory", False)
if flags.get("unload_models", free_memory):
comfy.model_management.unload_all_models()
need_gc = True
last_gc_collect = 0
if free_memory:
e.reset()
need_gc = True
last_gc_collect = 0
if need_gc:
current_time = time.perf_counter()
if (current_time - last_gc_collect) > gc_collect_interval:
comfy.model_management.cleanup_models()
gc.collect()
comfy.model_management.soft_empty_cache()
last_gc_collect = current_time
@ -127,7 +152,9 @@ async def run(server, address='', port=8188, verbose=True, call_on_start=None):
def hijack_progress(server):
def hook(value, total, preview_image):
comfy.model_management.throw_exception_if_processing_interrupted()
server.send_sync("progress", {"value": value, "max": total}, server.client_id)
progress = {"value": value, "max": total, "prompt_id": server.last_prompt_id, "node": server.last_node_id}
server.send_sync("progress", progress, server.client_id)
if preview_image is not None:
server.send_sync(BinaryEventTypes.UNENCODED_PREVIEW_IMAGE, preview_image, server.client_id)
comfy.utils.set_progress_bar_global_hook(hook)
@ -156,17 +183,24 @@ def load_extra_path_config(yaml_path):
full_path = y
if base_path is not None:
full_path = os.path.join(base_path, full_path)
print("Adding extra search path", x, full_path)
logging.info("Adding extra search path {} {}".format(x, full_path))
folder_paths.add_model_folder_path(x, full_path)
if __name__ == "__main__":
if args.temp_directory:
temp_dir = os.path.join(os.path.abspath(args.temp_directory), "temp")
print(f"Setting temp directory to: {temp_dir}")
logging.info(f"Setting temp directory to: {temp_dir}")
folder_paths.set_temp_directory(temp_dir)
cleanup_temp()
if args.windows_standalone_build:
try:
import new_updater
new_updater.update_windows_updater()
except:
pass
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
server = server.PromptServer(loop)
@ -191,7 +225,7 @@ if __name__ == "__main__":
if args.output_directory:
output_dir = os.path.abspath(args.output_directory)
print(f"Setting output directory to: {output_dir}")
logging.info(f"Setting output directory to: {output_dir}")
folder_paths.set_output_directory(output_dir)
#These are the default folders that checkpoints, clip and vae models will be saved to when using CheckpointSave, etc.. nodes
@ -201,7 +235,7 @@ if __name__ == "__main__":
if args.input_directory:
input_dir = os.path.abspath(args.input_directory)
print(f"Setting input directory to: {input_dir}")
logging.info(f"Setting input directory to: {input_dir}")
folder_paths.set_input_directory(input_dir)
if args.quick_test_for_ci:
@ -219,6 +253,6 @@ if __name__ == "__main__":
try:
loop.run_until_complete(run(server, address=args.listen, port=args.port, verbose=not args.dont_print_server, call_on_start=call_on_start))
except KeyboardInterrupt:
print("\nStopped server")
logging.info("\nStopped server")
cleanup_temp()

0
models/photomaker/put_photomaker_models_here

35
new_updater.py

@ -0,0 +1,35 @@
import os
import shutil
base_path = os.path.dirname(os.path.realpath(__file__))
def update_windows_updater():
top_path = os.path.dirname(base_path)
updater_path = os.path.join(base_path, ".ci/update_windows/update.py")
bat_path = os.path.join(base_path, ".ci/update_windows/update_comfyui.bat")
dest_updater_path = os.path.join(top_path, "update/update.py")
dest_bat_path = os.path.join(top_path, "update/update_comfyui.bat")
dest_bat_deps_path = os.path.join(top_path, "update/update_comfyui_and_python_dependencies.bat")
try:
with open(dest_bat_path, 'rb') as f:
contents = f.read()
except:
return
if not contents.startswith(b"..\\python_embeded\\python.exe .\\update.py"):
return
shutil.copy(updater_path, dest_updater_path)
try:
with open(dest_bat_deps_path, 'rb') as f:
contents = f.read()
contents = contents.replace(b'..\\python_embeded\\python.exe .\\update.py ..\\ComfyUI\\', b'call update_comfyui.bat nopause')
with open(dest_bat_deps_path, 'wb') as f:
f.write(contents)
except:
pass
shutil.copy(bat_path, dest_bat_path)
print("Updated the windows standalone package updater.")

212
nodes.py

@ -8,8 +8,9 @@ import traceback
import math
import time
import random
import logging
from PIL import Image, ImageOps
from PIL import Image, ImageOps, ImageSequence
from PIL.PngImagePlugin import PngInfo
import numpy as np
import safetensors.torch
@ -40,7 +41,7 @@ def before_node_execution():
def interrupt_processing(value=True):
comfy.model_management.interrupt_current_processing(value)
MAX_RESOLUTION=8192
MAX_RESOLUTION=16384
class CLIPTextEncode:
@classmethod
@ -83,7 +84,7 @@ class ConditioningAverage :
out = []
if len(conditioning_from) > 1:
print("Warning: ConditioningAverage conditioning_from contains more than 1 cond, only the first one will actually be applied to conditioning_to.")
logging.warning("Warning: ConditioningAverage conditioning_from contains more than 1 cond, only the first one will actually be applied to conditioning_to.")
cond_from = conditioning_from[0][0]
pooled_output_from = conditioning_from[0][1].get("pooled_output", None)
@ -122,7 +123,7 @@ class ConditioningConcat:
out = []
if len(conditioning_from) > 1:
print("Warning: ConditioningConcat conditioning_from contains more than 1 cond, only the first one will actually be applied to conditioning_to.")
logging.warning("Warning: ConditioningConcat conditioning_from contains more than 1 cond, only the first one will actually be applied to conditioning_to.")
cond_from = conditioning_from[0][0]
@ -184,6 +185,26 @@ class ConditioningSetAreaPercentage:
c.append(n)
return (c, )
class ConditioningSetAreaStrength:
@classmethod
def INPUT_TYPES(s):
return {"required": {"conditioning": ("CONDITIONING", ),
"strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
}}
RETURN_TYPES = ("CONDITIONING",)
FUNCTION = "append"
CATEGORY = "conditioning"
def append(self, conditioning, strength):
c = []
for t in conditioning:
n = [t[0], t[1].copy()]
n[1]['strength'] = strength
c.append(n)
return (c, )
class ConditioningSetMask:
@classmethod
def INPUT_TYPES(s):
@ -289,18 +310,7 @@ class VAEEncode:
CATEGORY = "latent"
@staticmethod
def vae_encode_crop_pixels(pixels):
x = (pixels.shape[1] // 8) * 8
y = (pixels.shape[2] // 8) * 8
if pixels.shape[1] != x or pixels.shape[2] != y:
x_offset = (pixels.shape[1] % 8) // 2
y_offset = (pixels.shape[2] % 8) // 2
pixels = pixels[:, x_offset:x + x_offset, y_offset:y + y_offset, :]
return pixels
def encode(self, vae, pixels):
pixels = self.vae_encode_crop_pixels(pixels)
t = vae.encode(pixels[:,:,:,:3])
return ({"samples":t}, )
@ -316,7 +326,6 @@ class VAEEncodeTiled:
CATEGORY = "_for_testing"
def encode(self, vae, pixels, tile_size):
pixels = VAEEncode.vae_encode_crop_pixels(pixels)
t = vae.encode_tiled(pixels[:,:,:,:3], tile_x=tile_size, tile_y=tile_size, )
return ({"samples":t}, )
@ -330,14 +339,14 @@ class VAEEncodeForInpaint:
CATEGORY = "latent/inpaint"
def encode(self, vae, pixels, mask, grow_mask_by=6):
x = (pixels.shape[1] // 8) * 8
y = (pixels.shape[2] // 8) * 8
x = (pixels.shape[1] // vae.downscale_ratio) * vae.downscale_ratio
y = (pixels.shape[2] // vae.downscale_ratio) * vae.downscale_ratio
mask = torch.nn.functional.interpolate(mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1])), size=(pixels.shape[1], pixels.shape[2]), mode="bilinear")
pixels = pixels.clone()
if pixels.shape[1] != x or pixels.shape[2] != y:
x_offset = (pixels.shape[1] % 8) // 2
y_offset = (pixels.shape[2] % 8) // 2
x_offset = (pixels.shape[1] % vae.downscale_ratio) // 2
y_offset = (pixels.shape[2] % vae.downscale_ratio) // 2
pixels = pixels[:,x_offset:x + x_offset, y_offset:y + y_offset,:]
mask = mask[:,:,x_offset:x + x_offset, y_offset:y + y_offset]
@ -359,6 +368,62 @@ class VAEEncodeForInpaint:
return ({"samples":t, "noise_mask": (mask_erosion[:,:,:x,:y].round())}, )
class InpaintModelConditioning:
@classmethod
def INPUT_TYPES(s):
return {"required": {"positive": ("CONDITIONING", ),
"negative": ("CONDITIONING", ),
"vae": ("VAE", ),
"pixels": ("IMAGE", ),
"mask": ("MASK", ),
}}
RETURN_TYPES = ("CONDITIONING","CONDITIONING","LATENT")
RETURN_NAMES = ("positive", "negative", "latent")
FUNCTION = "encode"
CATEGORY = "conditioning/inpaint"
def encode(self, positive, negative, pixels, vae, mask):
x = (pixels.shape[1] // 8) * 8
y = (pixels.shape[2] // 8) * 8
mask = torch.nn.functional.interpolate(mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1])), size=(pixels.shape[1], pixels.shape[2]), mode="bilinear")
orig_pixels = pixels
pixels = orig_pixels.clone()
if pixels.shape[1] != x or pixels.shape[2] != y:
x_offset = (pixels.shape[1] % 8) // 2
y_offset = (pixels.shape[2] % 8) // 2
pixels = pixels[:,x_offset:x + x_offset, y_offset:y + y_offset,:]
mask = mask[:,:,x_offset:x + x_offset, y_offset:y + y_offset]
m = (1.0 - mask.round()).squeeze(1)
for i in range(3):
pixels[:,:,:,i] -= 0.5
pixels[:,:,:,i] *= m
pixels[:,:,:,i] += 0.5
concat_latent = vae.encode(pixels)
orig_latent = vae.encode(orig_pixels)
out_latent = {}
out_latent["samples"] = orig_latent
out_latent["noise_mask"] = mask
out = []
for conditioning in [positive, negative]:
c = []
for t in conditioning:
d = t[1].copy()
d["concat_latent_image"] = concat_latent
d["concat_mask"] = mask
n = [t[0], d]
c.append(n)
out.append(c)
return (out[0], out[1], out_latent)
class SaveLatent:
def __init__(self):
self.output_dir = folder_paths.get_output_directory()
@ -778,15 +843,20 @@ class CLIPLoader:
@classmethod
def INPUT_TYPES(s):
return {"required": { "clip_name": (folder_paths.get_filename_list("clip"), ),
"type": (["stable_diffusion", "stable_cascade"], ),
}}
RETURN_TYPES = ("CLIP",)
FUNCTION = "load_clip"
CATEGORY = "advanced/loaders"
def load_clip(self, clip_name):
def load_clip(self, clip_name, type="stable_diffusion"):
clip_type = comfy.sd.CLIPType.STABLE_DIFFUSION
if type == "stable_cascade":
clip_type = comfy.sd.CLIPType.STABLE_CASCADE
clip_path = folder_paths.get_full_path("clip", clip_name)
clip = comfy.sd.load_clip(ckpt_paths=[clip_path], embedding_directory=folder_paths.get_folder_paths("embeddings"))
clip = comfy.sd.load_clip(ckpt_paths=[clip_path], embedding_directory=folder_paths.get_folder_paths("embeddings"), clip_type=clip_type)
return (clip,)
class DualCLIPLoader:
@ -934,7 +1004,7 @@ class GLIGENTextBoxApply:
def append(self, conditioning_to, clip, gligen_textbox_model, text, width, height, x, y):
c = []
cond, cond_pooled = clip.encode_from_tokens(clip.tokenize(text), return_pooled=True)
cond, cond_pooled = clip.encode_from_tokens(clip.tokenize(text), return_pooled="unprojected")
for t in conditioning_to:
n = [t[0], t[1].copy()]
position_params = [(cond_pooled, height // 8, width // 8, y // 8, x // 8)]
@ -947,8 +1017,8 @@ class GLIGENTextBoxApply:
return (c, )
class EmptyLatentImage:
def __init__(self, device="cpu"):
self.device = device
def __init__(self):
self.device = comfy.model_management.intermediate_device()
@classmethod
def INPUT_TYPES(s):
@ -961,7 +1031,7 @@ class EmptyLatentImage:
CATEGORY = "latent"
def generate(self, width, height, batch_size=1):
latent = torch.zeros([batch_size, 4, height // 8, width // 8])
latent = torch.zeros([batch_size, 4, height // 8, width // 8], device=self.device)
return ({"samples":latent}, )
@ -1358,7 +1428,7 @@ class SaveImage:
filename_prefix += self.prefix_append
full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, self.output_dir, images[0].shape[1], images[0].shape[0])
results = list()
for image in images:
for (batch_number, image) in enumerate(images):
i = 255. * image.cpu().numpy()
img = Image.fromarray(np.clip(i, 0, 255).astype(np.uint8))
metadata = None
@ -1370,7 +1440,8 @@ class SaveImage:
for x in extra_pnginfo:
metadata.add_text(x, json.dumps(extra_pnginfo[x]))
file = f"{filename}_{counter:05}_.png"
filename_with_batch_num = filename.replace("%batch_num%", str(batch_number))
file = f"{filename_with_batch_num}_{counter:05}_.png"
img.save(os.path.join(full_output_folder, file), pnginfo=metadata, compress_level=self.compress_level)
results.append({
"filename": file,
@ -1410,17 +1481,32 @@ class LoadImage:
FUNCTION = "load_image"
def load_image(self, image):
image_path = folder_paths.get_annotated_filepath(image)
i = Image.open(image_path)
i = ImageOps.exif_transpose(i)
image = i.convert("RGB")
image = np.array(image).astype(np.float32) / 255.0
image = torch.from_numpy(image)[None,]
if 'A' in i.getbands():
mask = np.array(i.getchannel('A')).astype(np.float32) / 255.0
mask = 1. - torch.from_numpy(mask)
img = Image.open(image_path)
output_images = []
output_masks = []
for i in ImageSequence.Iterator(img):
i = ImageOps.exif_transpose(i)
if i.mode == 'I':
i = i.point(lambda i: i * (1 / 255))
image = i.convert("RGB")
image = np.array(image).astype(np.float32) / 255.0
image = torch.from_numpy(image)[None,]
if 'A' in i.getbands():
mask = np.array(i.getchannel('A')).astype(np.float32) / 255.0
mask = 1. - torch.from_numpy(mask)
else:
mask = torch.zeros((64,64), dtype=torch.float32, device="cpu")
output_images.append(image)
output_masks.append(mask.unsqueeze(0))
if len(output_images) > 1:
output_image = torch.cat(output_images, dim=0)
output_mask = torch.cat(output_masks, dim=0)
else:
mask = torch.zeros((64,64), dtype=torch.float32, device="cpu")
return (image, mask.unsqueeze(0))
output_image = output_images[0]
output_mask = output_masks[0]
return (output_image, output_mask)
@classmethod
def IS_CHANGED(s, image):
@ -1457,6 +1543,8 @@ class LoadImageMask:
i = Image.open(image_path)
i = ImageOps.exif_transpose(i)
if i.getbands() != ("R", "G", "B", "A"):
if i.mode == 'I':
i = i.point(lambda i: i * (1 / 255))
i = i.convert("RGBA")
mask = None
c = channel[0].upper()
@ -1478,13 +1566,10 @@ class LoadImageMask:
return m.digest().hex()
@classmethod
def VALIDATE_INPUTS(s, image, channel):
def VALIDATE_INPUTS(s, image):
if not folder_paths.exists_annotated_filepath(image):
return "Invalid image file: {}".format(image)
if channel not in s._color_channels:
return "Invalid color channel: {}".format(channel)
return True
class ImageScale:
@ -1614,10 +1699,11 @@ class ImagePadForOutpaint:
def expand_image(self, image, left, top, right, bottom, feathering):
d1, d2, d3, d4 = image.size()
new_image = torch.zeros(
new_image = torch.ones(
(d1, d2 + top + bottom, d3 + left + right, d4),
dtype=torch.float32,
)
) * 0.5
new_image[:, top:top + d2, left:left + d3, :] = image
mask = torch.ones(
@ -1683,6 +1769,7 @@ NODE_CLASS_MAPPINGS = {
"ConditioningConcat": ConditioningConcat,
"ConditioningSetArea": ConditioningSetArea,
"ConditioningSetAreaPercentage": ConditioningSetAreaPercentage,
"ConditioningSetAreaStrength": ConditioningSetAreaStrength,
"ConditioningSetMask": ConditioningSetMask,
"KSamplerAdvanced": KSamplerAdvanced,
"SetLatentNoiseMask": SetLatentNoiseMask,
@ -1709,6 +1796,7 @@ NODE_CLASS_MAPPINGS = {
"unCLIPCheckpointLoader": unCLIPCheckpointLoader,
"GLIGENLoader": GLIGENLoader,
"GLIGENTextBoxApply": GLIGENTextBoxApply,
"InpaintModelConditioning": InpaintModelConditioning,
"CheckpointLoader": CheckpointLoader,
"DiffusersLoader": DiffusersLoader,
@ -1812,11 +1900,11 @@ def load_custom_node(module_path, ignore=set()):
NODE_DISPLAY_NAME_MAPPINGS.update(module.NODE_DISPLAY_NAME_MAPPINGS)
return True
else:
print(f"Skip {module_path} module for custom nodes due to the lack of NODE_CLASS_MAPPINGS.")
logging.warning(f"Skip {module_path} module for custom nodes due to the lack of NODE_CLASS_MAPPINGS.")
return False
except Exception as e:
print(traceback.format_exc())
print(f"Cannot import {module_path} module for custom nodes:", e)
logging.warning(traceback.format_exc())
logging.warning(f"Cannot import {module_path} module for custom nodes: {e}")
return False
def load_custom_nodes():
@ -1837,14 +1925,14 @@ def load_custom_nodes():
node_import_times.append((time.perf_counter() - time_before, module_path, success))
if len(node_import_times) > 0:
print("\nImport times for custom nodes:")
logging.info("\nImport times for custom nodes:")
for n in sorted(node_import_times):
if n[2]:
import_message = ""
else:
import_message = " (IMPORT FAILED)"
print("{:6.1f} seconds{}:".format(n[0], import_message), n[1])
print()
logging.info("{:6.1f} seconds{}: {}".format(n[0], import_message, n[1]))
logging.info("")
def init_custom_nodes():
extras_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras")
@ -1867,9 +1955,31 @@ def init_custom_nodes():
"nodes_model_downscale.py",
"nodes_images.py",
"nodes_video_model.py",
"nodes_sag.py",
"nodes_perpneg.py",
"nodes_stable3d.py",
"nodes_sdupscale.py",
"nodes_photomaker.py",
"nodes_cond.py",
"nodes_morphology.py",
"nodes_stable_cascade.py",
"nodes_differential_diffusion.py",
]
import_failed = []
for node_file in extras_files:
load_custom_node(os.path.join(extras_dir, node_file))
if not load_custom_node(os.path.join(extras_dir, node_file)):
import_failed.append(node_file)
load_custom_nodes()
if len(import_failed) > 0:
logging.warning("WARNING: some comfy_extras/ nodes did not import correctly. This may be because they are missing some dependencies.\n")
for node in import_failed:
logging.warning("IMPORT FAILED: {}".format(node))
logging.warning("\nThis issue might be caused by new missing dependencies added the last time you updated ComfyUI.")
if args.windows_standalone_build:
logging.warning("Please run the update script: update/update_comfyui.bat")
else:
logging.warning("Please do a: pip install -r requirements.txt")
logging.warning("")

3
requirements.txt

@ -1,13 +1,14 @@
torch
torchsde
torchvision
einops
transformers>=4.25.1
safetensors>=0.3.0
aiohttp
accelerate
pyyaml
Pillow
scipy
tqdm
psutil
kornia>=0.7.1
spandrel==0.3.1

159
script_examples/websockets_api_example_ws_images.py

@ -0,0 +1,159 @@
#This is an example that uses the websockets api and the SaveImageWebsocket node to get images directly without
#them being saved to disk
import websocket #NOTE: websocket-client (https://github.com/websocket-client/websocket-client)
import uuid
import json
import urllib.request
import urllib.parse
server_address = "127.0.0.1:8188"
client_id = str(uuid.uuid4())
def queue_prompt(prompt):
p = {"prompt": prompt, "client_id": client_id}
data = json.dumps(p).encode('utf-8')
req = urllib.request.Request("http://{}/prompt".format(server_address), data=data)
return json.loads(urllib.request.urlopen(req).read())
def get_image(filename, subfolder, folder_type):
data = {"filename": filename, "subfolder": subfolder, "type": folder_type}
url_values = urllib.parse.urlencode(data)
with urllib.request.urlopen("http://{}/view?{}".format(server_address, url_values)) as response:
return response.read()
def get_history(prompt_id):
with urllib.request.urlopen("http://{}/history/{}".format(server_address, prompt_id)) as response:
return json.loads(response.read())
def get_images(ws, prompt):
prompt_id = queue_prompt(prompt)['prompt_id']
output_images = {}
current_node = ""
while True:
out = ws.recv()
if isinstance(out, str):
message = json.loads(out)
if message['type'] == 'executing':
data = message['data']
if data['prompt_id'] == prompt_id:
if data['node'] is None:
break #Execution is done
else:
current_node = data['node']
else:
if current_node == 'save_image_websocket_node':
images_output = output_images.get(current_node, [])
images_output.append(out[8:])
output_images[current_node] = images_output
return output_images
prompt_text = """
{
"3": {
"class_type": "KSampler",
"inputs": {
"cfg": 8,
"denoise": 1,
"latent_image": [
"5",
0
],
"model": [
"4",
0
],
"negative": [
"7",
0
],
"positive": [
"6",
0
],
"sampler_name": "euler",
"scheduler": "normal",
"seed": 8566257,
"steps": 20
}
},
"4": {
"class_type": "CheckpointLoaderSimple",
"inputs": {
"ckpt_name": "v1-5-pruned-emaonly.ckpt"
}
},
"5": {
"class_type": "EmptyLatentImage",
"inputs": {
"batch_size": 1,
"height": 512,
"width": 512
}
},
"6": {
"class_type": "CLIPTextEncode",
"inputs": {
"clip": [
"4",
1
],
"text": "masterpiece best quality girl"
}
},
"7": {
"class_type": "CLIPTextEncode",
"inputs": {
"clip": [
"4",
1
],
"text": "bad hands"
}
},
"8": {
"class_type": "VAEDecode",
"inputs": {
"samples": [
"3",
0
],
"vae": [
"4",
2
]
}
},
"save_image_websocket_node": {
"class_type": "SaveImageWebsocket",
"inputs": {
"images": [
"8",
0
]
}
}
}
"""
prompt = json.loads(prompt_text)
#set the text prompt for our positive CLIPTextEncode
prompt["6"]["inputs"]["text"] = "masterpiece best quality man"
#set the seed for our KSampler node
prompt["3"]["inputs"]["seed"] = 5
ws = websocket.WebSocket()
ws.connect("ws://{}/ws?clientId={}".format(server_address, client_id))
images = get_images(ws, prompt)
#Commented out code to display the output images:
# for node_id in images:
# for image_data in images[node_id]:
# from PIL import Image
# import io
# image = Image.open(io.BytesIO(image_data))
# image.show()

58
server.py

@ -15,21 +15,16 @@ from PIL import Image, ImageOps
from PIL.PngImagePlugin import PngInfo
from io import BytesIO
try:
import aiohttp
from aiohttp import web
except ImportError:
print("Module 'aiohttp' not installed. Please install it via:")
print("pip install aiohttp")
print("or")
print("pip install -r requirements.txt")
sys.exit()
import aiohttp
from aiohttp import web
import logging
import mimetypes
from comfy.cli_args import args
import comfy.utils
import comfy.model_management
from app.user_manager import UserManager
class BinaryEventTypes:
PREVIEW_IMAGE = 1
@ -39,7 +34,7 @@ async def send_socket_catch_exception(function, message):
try:
await function(message)
except (aiohttp.ClientError, aiohttp.ClientPayloadError, ConnectionResetError) as err:
print("send error:", err)
logging.warning("send error: {}".format(err))
@web.middleware
async def cache_control(request: web.Request, handler):
@ -72,6 +67,7 @@ class PromptServer():
mimetypes.init()
mimetypes.types_map['.js'] = 'application/javascript; charset=utf-8'
self.user_manager = UserManager()
self.supports = ["custom_nodes_from_web"]
self.prompt_queue = None
self.loop = loop
@ -116,7 +112,7 @@ class PromptServer():
async for msg in ws:
if msg.type == aiohttp.WSMsgType.ERROR:
print('ws connection closed with exception %s' % ws.exception())
logging.warning('ws connection closed with exception %s' % ws.exception())
finally:
self.sockets.pop(sid, None)
return ws
@ -417,8 +413,8 @@ class PromptServer():
try:
out[x] = node_info(x)
except Exception as e:
print(f"[ERROR] An error occurred while retrieving information for the '{x}' node.", file=sys.stderr)
traceback.print_exc()
logging.error(f"[ERROR] An error occurred while retrieving information for the '{x}' node.")
logging.error(traceback.format_exc())
return web.json_response(out)
@routes.get("/object_info/{node_class}")
@ -451,7 +447,7 @@ class PromptServer():
@routes.post("/prompt")
async def post_prompt(request):
print("got prompt")
logging.info("got prompt")
resp_code = 200
out_string = ""
json_data = await request.json()
@ -483,7 +479,7 @@ class PromptServer():
response = {"prompt_id": prompt_id, "number": number, "node_errors": valid[3]}
return web.json_response(response)
else:
print("invalid prompt:", valid[1])
logging.warning("invalid prompt: {}".format(valid[1]))
return web.json_response({"error": valid[1], "node_errors": valid[3]}, status=400)
else:
return web.json_response({"error": "no prompt", "node_errors": []}, status=400)
@ -507,6 +503,17 @@ class PromptServer():
nodes.interrupt_processing()
return web.Response(status=200)
@routes.post("/free")
async def post_free(request):
json_data = await request.json()
unload_models = json_data.get("unload_models", False)
free_memory = json_data.get("free_memory", False)
if unload_models:
self.prompt_queue.set_flag("unload_models", unload_models)
if free_memory:
self.prompt_queue.set_flag("free_memory", free_memory)
return web.Response(status=200)
@routes.post("/history")
async def post_history(request):
json_data = await request.json()
@ -521,15 +528,16 @@ class PromptServer():
return web.Response(status=200)
def add_routes(self):
self.user_manager.add_routes(self.routes)
self.app.add_routes(self.routes)
for name, dir in nodes.EXTENSION_WEB_DIRS.items():
self.app.add_routes([
web.static('/extensions/' + urllib.parse.quote(name), dir, follow_symlinks=True),
web.static('/extensions/' + urllib.parse.quote(name), dir),
])
self.app.add_routes([
web.static('/', self.web_root, follow_symlinks=True),
web.static('/', self.web_root),
])
def get_queue_info(self):
@ -584,7 +592,8 @@ class PromptServer():
message = self.encode_bytes(event, data)
if sid is None:
for ws in self.sockets.values():
sockets = list(self.sockets.values())
for ws in sockets:
await send_socket_catch_exception(ws.send_bytes, message)
elif sid in self.sockets:
await send_socket_catch_exception(self.sockets[sid].send_bytes, message)
@ -593,7 +602,8 @@ class PromptServer():
message = {"type": event, "data": data}
if sid is None:
for ws in self.sockets.values():
sockets = list(self.sockets.values())
for ws in sockets:
await send_socket_catch_exception(ws.send_json, message)
elif sid in self.sockets:
await send_socket_catch_exception(self.sockets[sid].send_json, message)
@ -616,11 +626,9 @@ class PromptServer():
site = web.TCPSite(runner, address, port)
await site.start()
if address == '':
address = '0.0.0.0'
if verbose:
print("Starting server\n")
print("To see the GUI go to: http://{}:{}".format(address, port))
logging.info("Starting server\n")
logging.info("To see the GUI go to: http://{}:{}".format(address, port))
if call_on_start is not None:
call_on_start(address, port)
@ -632,7 +640,7 @@ class PromptServer():
try:
json_data = handler(json_data)
except Exception as e:
print(f"[ERROR] An error occurred during the on_prompt_handler processing")
traceback.print_exc()
logging.warning(f"[ERROR] An error occurred during the on_prompt_handler processing")
logging.warning(traceback.format_exc())
return json_data

9
tests-ui/afterSetup.js

@ -0,0 +1,9 @@
const { start } = require("./utils");
const lg = require("./utils/litegraph");
// Load things once per test file before to ensure its all warmed up for the tests
beforeAll(async () => {
lg.setup(global);
await start({ resetEnv: true });
lg.teardown(global);
});

3
tests-ui/babel.config.json

@ -1,3 +1,4 @@
{
"presets": ["@babel/preset-env"]
"presets": ["@babel/preset-env"],
"plugins": ["babel-plugin-transform-import-meta"]
}

2
tests-ui/jest.config.js

@ -2,8 +2,10 @@
const config = {
testEnvironment: "jsdom",
setupFiles: ["./globalSetup.js"],
setupFilesAfterEnv: ["./afterSetup.js"],
clearMocks: true,
resetModules: true,
testTimeout: 10000
};
module.exports = config;

Some files were not shown because too many files have changed in this diff Show More

Loading…
Cancel
Save