Browse Source

Merge branch 'comfyanonymous:master' into bugfix/extra_data

pull/820/head
Dr.Lt.Data 10 months ago committed by GitHub
parent
commit
fba067a8d2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
  1. 2
      .github/workflows/test-ui.yaml
  2. 1
      .gitignore
  3. 9
      .vscode/settings.json
  4. 54
      app/app_settings.py
  5. 140
      app/user_manager.py
  6. 6
      comfy/cli_args.py
  7. 4
      comfy/clip_model.py
  8. 23
      comfy/clip_vision.py
  9. 5
      comfy/controlnet.py
  10. 4
      comfy/latent_formats.py
  11. 5
      comfy/ldm/models/autoencoder.py
  12. 36
      comfy/ldm/modules/attention.py
  13. 2
      comfy/ldm/modules/diffusionmodules/model.py
  14. 6
      comfy/ldm/modules/diffusionmodules/openaimodel.py
  15. 16
      comfy/ldm/modules/diffusionmodules/upscaling.py
  16. 6
      comfy/ldm/modules/diffusionmodules/util.py
  17. 4
      comfy/ldm/modules/encoders/noise_aug_modules.py
  18. 28
      comfy/ldm/modules/sub_quadratic_attention.py
  19. 8
      comfy/ldm/modules/temporal_ae.py
  20. 111
      comfy/model_base.py
  21. 8
      comfy/model_detection.py
  22. 69
      comfy/model_management.py
  23. 19
      comfy/model_patcher.py
  24. 92
      comfy/ops.py
  25. 4
      comfy/sample.py
  26. 19
      comfy/samplers.py
  27. 23
      comfy/sd.py
  28. 56
      comfy/supported_models.py
  29. 5
      comfy/taesd/taesd.py
  30. 21
      comfy_extras/nodes_custom_sampler.py
  31. 16
      comfy_extras/nodes_hypertile.py
  32. 4
      comfy_extras/nodes_images.py
  33. 25
      comfy_extras/nodes_latent.py
  34. 3
      comfy_extras/nodes_mask.py
  35. 48
      comfy_extras/nodes_model_advanced.py
  36. 55
      comfy_extras/nodes_perpneg.py
  37. 30
      comfy_extras/nodes_rebatch.py
  38. 12
      comfy_extras/nodes_sag.py
  39. 47
      comfy_extras/nodes_sdupscale.py
  40. 102
      comfy_extras/nodes_stable3d.py
  41. 80
      execution.py
  42. 4
      folder_paths.py
  43. 25
      main.py
  44. 93
      nodes.py
  45. 2
      requirements.txt
  46. 20
      server.py
  47. 9
      tests-ui/afterSetup.js
  48. 3
      tests-ui/babel.config.json
  49. 2
      tests-ui/jest.config.js
  50. 20
      tests-ui/package-lock.json
  51. 1
      tests-ui/package.json
  52. 32
      tests-ui/tests/groupNode.test.js
  53. 295
      tests-ui/tests/users.test.js
  54. 16
      tests-ui/utils/index.js
  55. 36
      tests-ui/utils/setup.js
  56. 9
      web/extensions/core/groupNode.js
  57. 62
      web/extensions/core/nodeTemplates.js
  58. 30
      web/index.html
  59. 45
      web/lib/litegraph.core.js
  60. 100
      web/scripts/api.js
  61. 101
      web/scripts/app.js
  62. 1
      web/scripts/domWidget.js
  63. 269
      web/scripts/ui.js
  64. 32
      web/scripts/ui/dialog.js
  65. 307
      web/scripts/ui/settings.js
  66. 34
      web/scripts/ui/spinner.css
  67. 9
      web/scripts/ui/spinner.js
  68. 135
      web/scripts/ui/userSelection.css
  69. 114
      web/scripts/ui/userSelection.js
  70. 21
      web/scripts/utils.js
  71. 2
      web/style.css

2
.github/workflows/test-ui.yaml

@ -22,5 +22,5 @@ jobs:
run: | run: |
npm ci npm ci
npm run test:generate npm run test:generate
npm test npm test -- --verbose
working-directory: ./tests-ui working-directory: ./tests-ui

1
.gitignore vendored

@ -15,3 +15,4 @@ venv/
!/web/extensions/logging.js.example !/web/extensions/logging.js.example
!/web/extensions/core/ !/web/extensions/core/
/tests-ui/data/object_info.json /tests-ui/data/object_info.json
/user/

9
.vscode/settings.json vendored

@ -1,9 +0,0 @@
{
"path-intellisense.mappings": {
"../": "${workspaceFolder}/web/extensions/core"
},
"[python]": {
"editor.defaultFormatter": "ms-python.autopep8"
},
"python.formatting.provider": "none"
}

54
app/app_settings.py

@ -0,0 +1,54 @@
import os
import json
from aiohttp import web
class AppSettings():
def __init__(self, user_manager):
self.user_manager = user_manager
def get_settings(self, request):
file = self.user_manager.get_request_user_filepath(
request, "comfy.settings.json")
if os.path.isfile(file):
with open(file) as f:
return json.load(f)
else:
return {}
def save_settings(self, request, settings):
file = self.user_manager.get_request_user_filepath(
request, "comfy.settings.json")
with open(file, "w") as f:
f.write(json.dumps(settings, indent=4))
def add_routes(self, routes):
@routes.get("/settings")
async def get_settings(request):
return web.json_response(self.get_settings(request))
@routes.get("/settings/{id}")
async def get_setting(request):
value = None
settings = self.get_settings(request)
setting_id = request.match_info.get("id", None)
if setting_id and setting_id in settings:
value = settings[setting_id]
return web.json_response(value)
@routes.post("/settings")
async def post_settings(request):
settings = self.get_settings(request)
new_settings = await request.json()
self.save_settings(request, {**settings, **new_settings})
return web.Response(status=200)
@routes.post("/settings/{id}")
async def post_setting(request):
setting_id = request.match_info.get("id", None)
if not setting_id:
return web.Response(status=400)
settings = self.get_settings(request)
settings[setting_id] = await request.json()
self.save_settings(request, settings)
return web.Response(status=200)

140
app/user_manager.py

@ -0,0 +1,140 @@
import json
import os
import re
import uuid
from aiohttp import web
from comfy.cli_args import args
from folder_paths import user_directory
from .app_settings import AppSettings
default_user = "default"
users_file = os.path.join(user_directory, "users.json")
class UserManager():
def __init__(self):
global user_directory
self.settings = AppSettings(self)
if not os.path.exists(user_directory):
os.mkdir(user_directory)
if not args.multi_user:
print("****** User settings have been changed to be stored on the server instead of browser storage. ******")
print("****** For multi-user setups add the --multi-user CLI argument to enable multiple user profiles. ******")
if args.multi_user:
if os.path.isfile(users_file):
with open(users_file) as f:
self.users = json.load(f)
else:
self.users = {}
else:
self.users = {"default": "default"}
def get_request_user_id(self, request):
user = "default"
if args.multi_user and "comfy-user" in request.headers:
user = request.headers["comfy-user"]
if user not in self.users:
raise KeyError("Unknown user: " + user)
return user
def get_request_user_filepath(self, request, file, type="userdata", create_dir=True):
global user_directory
if type == "userdata":
root_dir = user_directory
else:
raise KeyError("Unknown filepath type:" + type)
user = self.get_request_user_id(request)
path = user_root = os.path.abspath(os.path.join(root_dir, user))
# prevent leaving /{type}
if os.path.commonpath((root_dir, user_root)) != root_dir:
return None
parent = user_root
if file is not None:
# prevent leaving /{type}/{user}
path = os.path.abspath(os.path.join(user_root, file))
if os.path.commonpath((user_root, path)) != user_root:
return None
if create_dir and not os.path.exists(parent):
os.mkdir(parent)
return path
def add_user(self, name):
name = name.strip()
if not name:
raise ValueError("username not provided")
user_id = re.sub("[^a-zA-Z0-9-_]+", '-', name)
user_id = user_id + "_" + str(uuid.uuid4())
self.users[user_id] = name
global users_file
with open(users_file, "w") as f:
json.dump(self.users, f)
return user_id
def add_routes(self, routes):
self.settings.add_routes(routes)
@routes.get("/users")
async def get_users(request):
if args.multi_user:
return web.json_response({"storage": "server", "users": self.users})
else:
user_dir = self.get_request_user_filepath(request, None, create_dir=False)
return web.json_response({
"storage": "server",
"migrated": os.path.exists(user_dir)
})
@routes.post("/users")
async def post_users(request):
body = await request.json()
username = body["username"]
if username in self.users.values():
return web.json_response({"error": "Duplicate username."}, status=400)
user_id = self.add_user(username)
return web.json_response(user_id)
@routes.get("/userdata/{file}")
async def getuserdata(request):
file = request.match_info.get("file", None)
if not file:
return web.Response(status=400)
path = self.get_request_user_filepath(request, file)
if not path:
return web.Response(status=403)
if not os.path.exists(path):
return web.Response(status=404)
return web.FileResponse(path)
@routes.post("/userdata/{file}")
async def post_userdata(request):
file = request.match_info.get("file", None)
if not file:
return web.Response(status=400)
path = self.get_request_user_filepath(request, file)
if not path:
return web.Response(status=403)
body = await request.read()
with open(path, "wb") as f:
f.write(body)
return web.Response(status=200)

6
comfy/cli_args.py

@ -66,6 +66,8 @@ fpvae_group.add_argument("--fp16-vae", action="store_true", help="Run the VAE in
fpvae_group.add_argument("--fp32-vae", action="store_true", help="Run the VAE in full precision fp32.") fpvae_group.add_argument("--fp32-vae", action="store_true", help="Run the VAE in full precision fp32.")
fpvae_group.add_argument("--bf16-vae", action="store_true", help="Run the VAE in bf16.") fpvae_group.add_argument("--bf16-vae", action="store_true", help="Run the VAE in bf16.")
parser.add_argument("--cpu-vae", action="store_true", help="Run the VAE on the CPU.")
fpte_group = parser.add_mutually_exclusive_group() fpte_group = parser.add_mutually_exclusive_group()
fpte_group.add_argument("--fp8_e4m3fn-text-enc", action="store_true", help="Store text encoder weights in fp8 (e4m3fn variant).") fpte_group.add_argument("--fp8_e4m3fn-text-enc", action="store_true", help="Store text encoder weights in fp8 (e4m3fn variant).")
fpte_group.add_argument("--fp8_e5m2-text-enc", action="store_true", help="Store text encoder weights in fp8 (e5m2 variant).") fpte_group.add_argument("--fp8_e5m2-text-enc", action="store_true", help="Store text encoder weights in fp8 (e5m2 variant).")
@ -102,7 +104,7 @@ vram_group.add_argument("--cpu", action="store_true", help="To use the CPU for e
parser.add_argument("--disable-smart-memory", action="store_true", help="Force ComfyUI to agressively offload to regular ram instead of keeping models in vram when it can.") parser.add_argument("--disable-smart-memory", action="store_true", help="Force ComfyUI to agressively offload to regular ram instead of keeping models in vram when it can.")
parser.add_argument("--deterministic", action="store_true", help="Make pytorch use slower deterministic algorithms when it can. Note that this might not make images deterministic in all cases.")
parser.add_argument("--dont-print-server", action="store_true", help="Don't print server output.") parser.add_argument("--dont-print-server", action="store_true", help="Don't print server output.")
parser.add_argument("--quick-test-for-ci", action="store_true", help="Quick test for CI.") parser.add_argument("--quick-test-for-ci", action="store_true", help="Quick test for CI.")
@ -110,6 +112,8 @@ parser.add_argument("--windows-standalone-build", action="store_true", help="Win
parser.add_argument("--disable-metadata", action="store_true", help="Disable saving prompt metadata in files.") parser.add_argument("--disable-metadata", action="store_true", help="Disable saving prompt metadata in files.")
parser.add_argument("--multi-user", action="store_true", help="Enables per-user storage.")
if comfy.options.args_parsing: if comfy.options.args_parsing:
args = parser.parse_args() args = parser.parse_args()
else: else:

4
comfy/clip_model.py

@ -57,7 +57,7 @@ class CLIPEncoder(torch.nn.Module):
self.layers = torch.nn.ModuleList([CLIPLayer(embed_dim, heads, intermediate_size, intermediate_activation, dtype, device, operations) for i in range(num_layers)]) self.layers = torch.nn.ModuleList([CLIPLayer(embed_dim, heads, intermediate_size, intermediate_activation, dtype, device, operations) for i in range(num_layers)])
def forward(self, x, mask=None, intermediate_output=None): def forward(self, x, mask=None, intermediate_output=None):
optimized_attention = optimized_attention_for_device(x.device, mask=mask is not None) optimized_attention = optimized_attention_for_device(x.device, mask=mask is not None, small_input=True)
if intermediate_output is not None: if intermediate_output is not None:
if intermediate_output < 0: if intermediate_output < 0:
@ -151,7 +151,7 @@ class CLIPVisionEmbeddings(torch.nn.Module):
def forward(self, pixel_values): def forward(self, pixel_values):
embeds = self.patch_embedding(pixel_values).flatten(2).transpose(1, 2) embeds = self.patch_embedding(pixel_values).flatten(2).transpose(1, 2)
return torch.cat([self.class_embedding.expand(pixel_values.shape[0], 1, -1), embeds], dim=1) + self.position_embedding.weight return torch.cat([self.class_embedding.to(embeds.device).expand(pixel_values.shape[0], 1, -1), embeds], dim=1) + self.position_embedding.weight.to(embeds.device)
class CLIPVision(torch.nn.Module): class CLIPVision(torch.nn.Module):

23
comfy/clip_vision.py

@ -19,8 +19,10 @@ class Output:
def clip_preprocess(image, size=224): def clip_preprocess(image, size=224):
mean = torch.tensor([ 0.48145466,0.4578275,0.40821073], device=image.device, dtype=image.dtype) mean = torch.tensor([ 0.48145466,0.4578275,0.40821073], device=image.device, dtype=image.dtype)
std = torch.tensor([0.26862954,0.26130258,0.27577711], device=image.device, dtype=image.dtype) std = torch.tensor([0.26862954,0.26130258,0.27577711], device=image.device, dtype=image.dtype)
scale = (size / min(image.shape[1], image.shape[2])) image = image.movedim(-1, 1)
image = torch.nn.functional.interpolate(image.movedim(-1, 1), size=(round(scale * image.shape[1]), round(scale * image.shape[2])), mode="bicubic", antialias=True) if not (image.shape[2] == size and image.shape[3] == size):
scale = (size / min(image.shape[2], image.shape[3]))
image = torch.nn.functional.interpolate(image, size=(round(scale * image.shape[2]), round(scale * image.shape[3])), mode="bicubic", antialias=True)
h = (image.shape[2] - size)//2 h = (image.shape[2] - size)//2
w = (image.shape[3] - size)//2 w = (image.shape[3] - size)//2
image = image[:,:,h:h+size,w:w+size] image = image[:,:,h:h+size,w:w+size]
@ -34,11 +36,9 @@ class ClipVisionModel():
self.load_device = comfy.model_management.text_encoder_device() self.load_device = comfy.model_management.text_encoder_device()
offload_device = comfy.model_management.text_encoder_offload_device() offload_device = comfy.model_management.text_encoder_offload_device()
self.dtype = torch.float32 self.dtype = comfy.model_management.text_encoder_dtype(self.load_device)
if comfy.model_management.should_use_fp16(self.load_device, prioritize_performance=False): self.model = comfy.clip_model.CLIPVisionModelProjection(config, self.dtype, offload_device, comfy.ops.manual_cast)
self.dtype = torch.float16 self.model.eval()
self.model = comfy.clip_model.CLIPVisionModelProjection(config, self.dtype, offload_device, comfy.ops.disable_weight_init)
self.patcher = comfy.model_patcher.ModelPatcher(self.model, load_device=self.load_device, offload_device=offload_device) self.patcher = comfy.model_patcher.ModelPatcher(self.model, load_device=self.load_device, offload_device=offload_device)
def load_sd(self, sd): def load_sd(self, sd):
@ -46,14 +46,7 @@ class ClipVisionModel():
def encode_image(self, image): def encode_image(self, image):
comfy.model_management.load_model_gpu(self.patcher) comfy.model_management.load_model_gpu(self.patcher)
pixel_values = clip_preprocess(image.to(self.load_device)) pixel_values = clip_preprocess(image.to(self.load_device)).float()
if self.dtype != torch.float32:
precision_scope = torch.autocast
else:
precision_scope = lambda a, b: contextlib.nullcontext(a)
with precision_scope(comfy.model_management.get_autocast_device(self.load_device), torch.float32):
out = self.model(pixel_values=pixel_values, intermediate_output=-2) out = self.model(pixel_values=pixel_values, intermediate_output=-2)
outputs = Output() outputs = Output()

5
comfy/controlnet.py

@ -125,6 +125,9 @@ class ControlBase:
elif prev_val is not None: elif prev_val is not None:
if o[i] is None: if o[i] is None:
o[i] = prev_val o[i] = prev_val
else:
if o[i].shape[0] < prev_val.shape[0]:
o[i] = prev_val + o[i]
else: else:
o[i] += prev_val o[i] += prev_val
return out return out
@ -283,7 +286,7 @@ class ControlLora(ControlNet):
cm = self.control_model.state_dict() cm = self.control_model.state_dict()
for k in sd: for k in sd:
weight = comfy.model_management.resolve_lowvram_weight(sd[k], diffusion_model, k) weight = sd[k]
try: try:
comfy.utils.set_attr(self.control_model, k, weight) comfy.utils.set_attr(self.control_model, k, weight)
except: except:

4
comfy/latent_formats.py

@ -33,3 +33,7 @@ class SDXL(LatentFormat):
[-0.3112, -0.2359, -0.2076] [-0.3112, -0.2359, -0.2076]
] ]
self.taesd_decoder_name = "taesdxl_decoder" self.taesd_decoder_name = "taesdxl_decoder"
class SD_X4(LatentFormat):
def __init__(self):
self.scale_factor = 0.08333

5
comfy/ldm/models/autoencoder.py

@ -8,6 +8,7 @@ from comfy.ldm.modules.distributions.distributions import DiagonalGaussianDistri
from comfy.ldm.util import instantiate_from_config from comfy.ldm.util import instantiate_from_config
from comfy.ldm.modules.ema import LitEma from comfy.ldm.modules.ema import LitEma
import comfy.ops
class DiagonalGaussianRegularizer(torch.nn.Module): class DiagonalGaussianRegularizer(torch.nn.Module):
def __init__(self, sample: bool = True): def __init__(self, sample: bool = True):
@ -161,12 +162,12 @@ class AutoencodingEngineLegacy(AutoencodingEngine):
}, },
**kwargs, **kwargs,
) )
self.quant_conv = torch.nn.Conv2d( self.quant_conv = comfy.ops.disable_weight_init.Conv2d(
(1 + ddconfig["double_z"]) * ddconfig["z_channels"], (1 + ddconfig["double_z"]) * ddconfig["z_channels"],
(1 + ddconfig["double_z"]) * embed_dim, (1 + ddconfig["double_z"]) * embed_dim,
1, 1,
) )
self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1) self.post_quant_conv = comfy.ops.disable_weight_init.Conv2d(embed_dim, ddconfig["z_channels"], 1)
self.embed_dim = embed_dim self.embed_dim = embed_dim
def get_autoencoder_params(self) -> list: def get_autoencoder_params(self) -> list:

36
comfy/ldm/modules/attention.py

@ -104,9 +104,7 @@ def attention_basic(q, k, v, heads, mask=None):
# force cast to fp32 to avoid overflowing # force cast to fp32 to avoid overflowing
if _ATTN_PRECISION =="fp32": if _ATTN_PRECISION =="fp32":
with torch.autocast(enabled=False, device_type = 'cuda'): sim = einsum('b i d, b j d -> b i j', q.float(), k.float()) * scale
q, k = q.float(), k.float()
sim = einsum('b i d, b j d -> b i j', q, k) * scale
else: else:
sim = einsum('b i d, b j d -> b i j', q, k) * scale sim = einsum('b i d, b j d -> b i j', q, k) * scale
@ -179,6 +177,7 @@ def attention_sub_quad(query, key, value, heads, mask=None):
kv_chunk_size_min=kv_chunk_size_min, kv_chunk_size_min=kv_chunk_size_min,
use_checkpoint=False, use_checkpoint=False,
upcast_attention=upcast_attention, upcast_attention=upcast_attention,
mask=mask,
) )
hidden_states = hidden_states.to(dtype) hidden_states = hidden_states.to(dtype)
@ -241,6 +240,12 @@ def attention_split(q, k, v, heads, mask=None):
else: else:
s1 = einsum('b i d, b j d -> b i j', q[:, i:end], k) * scale s1 = einsum('b i d, b j d -> b i j', q[:, i:end], k) * scale
if mask is not None:
if len(mask.shape) == 2:
s1 += mask[i:end]
else:
s1 += mask[:, i:end]
s2 = s1.softmax(dim=-1).to(v.dtype) s2 = s1.softmax(dim=-1).to(v.dtype)
del s1 del s1
first_op_done = True first_op_done = True
@ -296,11 +301,14 @@ def attention_xformers(q, k, v, heads, mask=None):
(q, k, v), (q, k, v),
) )
# actually compute the attention, what we cannot get enough of if mask is not None:
out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None) pad = 8 - q.shape[1] % 8
mask_out = torch.empty([q.shape[0], q.shape[1], q.shape[1] + pad], dtype=q.dtype, device=q.device)
mask_out[:, :, :mask.shape[-1]] = mask
mask = mask_out[:, :, :mask.shape[-1]]
out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=mask)
if exists(mask):
raise NotImplementedError
out = ( out = (
out.unsqueeze(0) out.unsqueeze(0)
.reshape(b, heads, -1, dim_head) .reshape(b, heads, -1, dim_head)
@ -325,7 +333,6 @@ def attention_pytorch(q, k, v, heads, mask=None):
optimized_attention = attention_basic optimized_attention = attention_basic
optimized_attention_masked = attention_basic
if model_management.xformers_enabled(): if model_management.xformers_enabled():
print("Using xformers cross attention") print("Using xformers cross attention")
@ -341,15 +348,18 @@ else:
print("Using sub quadratic optimization for cross attention, if you have memory or speed issues try using: --use-split-cross-attention") print("Using sub quadratic optimization for cross attention, if you have memory or speed issues try using: --use-split-cross-attention")
optimized_attention = attention_sub_quad optimized_attention = attention_sub_quad
if model_management.pytorch_attention_enabled(): optimized_attention_masked = optimized_attention
optimized_attention_masked = attention_pytorch
def optimized_attention_for_device(device, mask=False): def optimized_attention_for_device(device, mask=False, small_input=False):
if device == torch.device("cpu"): #TODO if small_input:
if model_management.pytorch_attention_enabled(): if model_management.pytorch_attention_enabled():
return attention_pytorch return attention_pytorch #TODO: need to confirm but this is probably slightly faster for small inputs in all cases
else: else:
return attention_basic return attention_basic
if device == torch.device("cpu"):
return attention_sub_quad
if mask: if mask:
return optimized_attention_masked return optimized_attention_masked

2
comfy/ldm/modules/diffusionmodules/model.py

@ -41,7 +41,7 @@ def nonlinearity(x):
def Normalize(in_channels, num_groups=32): def Normalize(in_channels, num_groups=32):
return torch.nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True) return ops.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True)
class Upsample(nn.Module): class Upsample(nn.Module):

6
comfy/ldm/modules/diffusionmodules/openaimodel.py

@ -437,9 +437,6 @@ class UNetModel(nn.Module):
operations=ops, operations=ops,
): ):
super().__init__() super().__init__()
assert use_spatial_transformer == True, "use_spatial_transformer has to be true"
if use_spatial_transformer:
assert context_dim is not None, 'Fool!! You forgot to include the dimension of your cross-attention conditioning...'
if context_dim is not None: if context_dim is not None:
assert use_spatial_transformer, 'Fool!! You forgot to use the spatial transformer for your cross-attention conditioning...' assert use_spatial_transformer, 'Fool!! You forgot to use the spatial transformer for your cross-attention conditioning...'
@ -456,7 +453,6 @@ class UNetModel(nn.Module):
if num_head_channels == -1: if num_head_channels == -1:
assert num_heads != -1, 'Either num_heads or num_head_channels has to be set' assert num_heads != -1, 'Either num_heads or num_head_channels has to be set'
self.image_size = image_size
self.in_channels = in_channels self.in_channels = in_channels
self.model_channels = model_channels self.model_channels = model_channels
self.out_channels = out_channels self.out_channels = out_channels
@ -502,7 +498,7 @@ class UNetModel(nn.Module):
if self.num_classes is not None: if self.num_classes is not None:
if isinstance(self.num_classes, int): if isinstance(self.num_classes, int):
self.label_emb = nn.Embedding(num_classes, time_embed_dim) self.label_emb = nn.Embedding(num_classes, time_embed_dim, dtype=self.dtype, device=device)
elif self.num_classes == "continuous": elif self.num_classes == "continuous":
print("setting up linear c_adm embedding layer") print("setting up linear c_adm embedding layer")
self.label_emb = nn.Linear(1, time_embed_dim) self.label_emb = nn.Linear(1, time_embed_dim)

16
comfy/ldm/modules/diffusionmodules/upscaling.py

@ -41,10 +41,14 @@ class AbstractLowScaleModel(nn.Module):
self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod))) self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod)))
self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod - 1))) self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod - 1)))
def q_sample(self, x_start, t, noise=None): def q_sample(self, x_start, t, noise=None, seed=None):
noise = default(noise, lambda: torch.randn_like(x_start)) if noise is None:
return (extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start + if seed is None:
extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise) noise = torch.randn_like(x_start)
else:
noise = torch.randn(x_start.size(), dtype=x_start.dtype, layout=x_start.layout, generator=torch.manual_seed(seed)).to(x_start.device)
return (extract_into_tensor(self.sqrt_alphas_cumprod.to(x_start.device), t, x_start.shape) * x_start +
extract_into_tensor(self.sqrt_one_minus_alphas_cumprod.to(x_start.device), t, x_start.shape) * noise)
def forward(self, x): def forward(self, x):
return x, None return x, None
@ -69,12 +73,12 @@ class ImageConcatWithNoiseAugmentation(AbstractLowScaleModel):
super().__init__(noise_schedule_config=noise_schedule_config) super().__init__(noise_schedule_config=noise_schedule_config)
self.max_noise_level = max_noise_level self.max_noise_level = max_noise_level
def forward(self, x, noise_level=None): def forward(self, x, noise_level=None, seed=None):
if noise_level is None: if noise_level is None:
noise_level = torch.randint(0, self.max_noise_level, (x.shape[0],), device=x.device).long() noise_level = torch.randint(0, self.max_noise_level, (x.shape[0],), device=x.device).long()
else: else:
assert isinstance(noise_level, torch.Tensor) assert isinstance(noise_level, torch.Tensor)
z = self.q_sample(x, noise_level) z = self.q_sample(x, noise_level, seed=seed)
return z, noise_level return z, noise_level

6
comfy/ldm/modules/diffusionmodules/util.py

@ -51,9 +51,9 @@ class AlphaBlender(nn.Module):
if self.merge_strategy == "fixed": if self.merge_strategy == "fixed":
# make shape compatible # make shape compatible
# alpha = repeat(self.mix_factor, '1 -> b () t () ()', t=t, b=bs) # alpha = repeat(self.mix_factor, '1 -> b () t () ()', t=t, b=bs)
alpha = self.mix_factor alpha = self.mix_factor.to(image_only_indicator.device)
elif self.merge_strategy == "learned": elif self.merge_strategy == "learned":
alpha = torch.sigmoid(self.mix_factor) alpha = torch.sigmoid(self.mix_factor.to(image_only_indicator.device))
# make shape compatible # make shape compatible
# alpha = repeat(alpha, '1 -> s () ()', s = t * bs) # alpha = repeat(alpha, '1 -> s () ()', s = t * bs)
elif self.merge_strategy == "learned_with_images": elif self.merge_strategy == "learned_with_images":
@ -61,7 +61,7 @@ class AlphaBlender(nn.Module):
alpha = torch.where( alpha = torch.where(
image_only_indicator.bool(), image_only_indicator.bool(),
torch.ones(1, 1, device=image_only_indicator.device), torch.ones(1, 1, device=image_only_indicator.device),
rearrange(torch.sigmoid(self.mix_factor), "... -> ... 1"), rearrange(torch.sigmoid(self.mix_factor.to(image_only_indicator.device)), "... -> ... 1"),
) )
alpha = rearrange(alpha, self.rearrange_pattern) alpha = rearrange(alpha, self.rearrange_pattern)
# make shape compatible # make shape compatible

4
comfy/ldm/modules/encoders/noise_aug_modules.py

@ -15,12 +15,12 @@ class CLIPEmbeddingNoiseAugmentation(ImageConcatWithNoiseAugmentation):
def scale(self, x): def scale(self, x):
# re-normalize to centered mean and unit variance # re-normalize to centered mean and unit variance
x = (x - self.data_mean) * 1. / self.data_std x = (x - self.data_mean.to(x.device)) * 1. / self.data_std.to(x.device)
return x return x
def unscale(self, x): def unscale(self, x):
# back to original data stats # back to original data stats
x = (x * self.data_std) + self.data_mean x = (x * self.data_std.to(x.device)) + self.data_mean.to(x.device)
return x return x
def forward(self, x, noise_level=None): def forward(self, x, noise_level=None):

28
comfy/ldm/modules/sub_quadratic_attention.py

@ -61,6 +61,7 @@ def _summarize_chunk(
value: Tensor, value: Tensor,
scale: float, scale: float,
upcast_attention: bool, upcast_attention: bool,
mask,
) -> AttnChunk: ) -> AttnChunk:
if upcast_attention: if upcast_attention:
with torch.autocast(enabled=False, device_type = 'cuda'): with torch.autocast(enabled=False, device_type = 'cuda'):
@ -84,6 +85,8 @@ def _summarize_chunk(
max_score, _ = torch.max(attn_weights, -1, keepdim=True) max_score, _ = torch.max(attn_weights, -1, keepdim=True)
max_score = max_score.detach() max_score = max_score.detach()
attn_weights -= max_score attn_weights -= max_score
if mask is not None:
attn_weights += mask
torch.exp(attn_weights, out=attn_weights) torch.exp(attn_weights, out=attn_weights)
exp_weights = attn_weights.to(value.dtype) exp_weights = attn_weights.to(value.dtype)
exp_values = torch.bmm(exp_weights, value) exp_values = torch.bmm(exp_weights, value)
@ -96,11 +99,12 @@ def _query_chunk_attention(
value: Tensor, value: Tensor,
summarize_chunk: SummarizeChunk, summarize_chunk: SummarizeChunk,
kv_chunk_size: int, kv_chunk_size: int,
mask,
) -> Tensor: ) -> Tensor:
batch_x_heads, k_channels_per_head, k_tokens = key_t.shape batch_x_heads, k_channels_per_head, k_tokens = key_t.shape
_, _, v_channels_per_head = value.shape _, _, v_channels_per_head = value.shape
def chunk_scanner(chunk_idx: int) -> AttnChunk: def chunk_scanner(chunk_idx: int, mask) -> AttnChunk:
key_chunk = dynamic_slice( key_chunk = dynamic_slice(
key_t, key_t,
(0, 0, chunk_idx), (0, 0, chunk_idx),
@ -111,10 +115,13 @@ def _query_chunk_attention(
(0, chunk_idx, 0), (0, chunk_idx, 0),
(batch_x_heads, kv_chunk_size, v_channels_per_head) (batch_x_heads, kv_chunk_size, v_channels_per_head)
) )
return summarize_chunk(query, key_chunk, value_chunk) if mask is not None:
mask = mask[:,:,chunk_idx:chunk_idx + kv_chunk_size]
return summarize_chunk(query, key_chunk, value_chunk, mask=mask)
chunks: List[AttnChunk] = [ chunks: List[AttnChunk] = [
chunk_scanner(chunk) for chunk in torch.arange(0, k_tokens, kv_chunk_size) chunk_scanner(chunk, mask) for chunk in torch.arange(0, k_tokens, kv_chunk_size)
] ]
acc_chunk = AttnChunk(*map(torch.stack, zip(*chunks))) acc_chunk = AttnChunk(*map(torch.stack, zip(*chunks)))
chunk_values, chunk_weights, chunk_max = acc_chunk chunk_values, chunk_weights, chunk_max = acc_chunk
@ -135,6 +142,7 @@ def _get_attention_scores_no_kv_chunking(
value: Tensor, value: Tensor,
scale: float, scale: float,
upcast_attention: bool, upcast_attention: bool,
mask,
) -> Tensor: ) -> Tensor:
if upcast_attention: if upcast_attention:
with torch.autocast(enabled=False, device_type = 'cuda'): with torch.autocast(enabled=False, device_type = 'cuda'):
@ -156,6 +164,8 @@ def _get_attention_scores_no_kv_chunking(
beta=0, beta=0,
) )
if mask is not None:
attn_scores += mask
try: try:
attn_probs = attn_scores.softmax(dim=-1) attn_probs = attn_scores.softmax(dim=-1)
del attn_scores del attn_scores
@ -183,6 +193,7 @@ def efficient_dot_product_attention(
kv_chunk_size_min: Optional[int] = None, kv_chunk_size_min: Optional[int] = None,
use_checkpoint=True, use_checkpoint=True,
upcast_attention=False, upcast_attention=False,
mask = None,
): ):
"""Computes efficient dot-product attention given query, transposed key, and value. """Computes efficient dot-product attention given query, transposed key, and value.
This is efficient version of attention presented in This is efficient version of attention presented in
@ -209,6 +220,9 @@ def efficient_dot_product_attention(
if kv_chunk_size_min is not None: if kv_chunk_size_min is not None:
kv_chunk_size = max(kv_chunk_size, kv_chunk_size_min) kv_chunk_size = max(kv_chunk_size, kv_chunk_size_min)
if mask is not None and len(mask.shape) == 2:
mask = mask.unsqueeze(0)
def get_query_chunk(chunk_idx: int) -> Tensor: def get_query_chunk(chunk_idx: int) -> Tensor:
return dynamic_slice( return dynamic_slice(
query, query,
@ -216,6 +230,12 @@ def efficient_dot_product_attention(
(batch_x_heads, min(query_chunk_size, q_tokens), q_channels_per_head) (batch_x_heads, min(query_chunk_size, q_tokens), q_channels_per_head)
) )
def get_mask_chunk(chunk_idx: int) -> Tensor:
if mask is None:
return None
chunk = min(query_chunk_size, q_tokens)
return mask[:,chunk_idx:chunk_idx + chunk]
summarize_chunk: SummarizeChunk = partial(_summarize_chunk, scale=scale, upcast_attention=upcast_attention) summarize_chunk: SummarizeChunk = partial(_summarize_chunk, scale=scale, upcast_attention=upcast_attention)
summarize_chunk: SummarizeChunk = partial(checkpoint, summarize_chunk) if use_checkpoint else summarize_chunk summarize_chunk: SummarizeChunk = partial(checkpoint, summarize_chunk) if use_checkpoint else summarize_chunk
compute_query_chunk_attn: ComputeQueryChunkAttn = partial( compute_query_chunk_attn: ComputeQueryChunkAttn = partial(
@ -237,6 +257,7 @@ def efficient_dot_product_attention(
query=query, query=query,
key_t=key_t, key_t=key_t,
value=value, value=value,
mask=mask,
) )
# TODO: maybe we should use torch.empty_like(query) to allocate storage in-advance, # TODO: maybe we should use torch.empty_like(query) to allocate storage in-advance,
@ -246,6 +267,7 @@ def efficient_dot_product_attention(
query=get_query_chunk(i * query_chunk_size), query=get_query_chunk(i * query_chunk_size),
key_t=key_t, key_t=key_t,
value=value, value=value,
mask=get_mask_chunk(i * query_chunk_size)
) for i in range(math.ceil(q_tokens / query_chunk_size)) ) for i in range(math.ceil(q_tokens / query_chunk_size))
], dim=1) ], dim=1)
return res return res

8
comfy/ldm/modules/temporal_ae.py

@ -82,14 +82,14 @@ class VideoResBlock(ResnetBlock):
x = self.time_stack(x, temb) x = self.time_stack(x, temb)
alpha = self.get_alpha(bs=b // timesteps) alpha = self.get_alpha(bs=b // timesteps).to(x.device)
x = alpha * x + (1.0 - alpha) * x_mix x = alpha * x + (1.0 - alpha) * x_mix
x = rearrange(x, "b c t h w -> (b t) c h w") x = rearrange(x, "b c t h w -> (b t) c h w")
return x return x
class AE3DConv(torch.nn.Conv2d): class AE3DConv(ops.Conv2d):
def __init__(self, in_channels, out_channels, video_kernel_size=3, *args, **kwargs): def __init__(self, in_channels, out_channels, video_kernel_size=3, *args, **kwargs):
super().__init__(in_channels, out_channels, *args, **kwargs) super().__init__(in_channels, out_channels, *args, **kwargs)
if isinstance(video_kernel_size, Iterable): if isinstance(video_kernel_size, Iterable):
@ -97,7 +97,7 @@ class AE3DConv(torch.nn.Conv2d):
else: else:
padding = int(video_kernel_size // 2) padding = int(video_kernel_size // 2)
self.time_mix_conv = torch.nn.Conv3d( self.time_mix_conv = ops.Conv3d(
in_channels=out_channels, in_channels=out_channels,
out_channels=out_channels, out_channels=out_channels,
kernel_size=video_kernel_size, kernel_size=video_kernel_size,
@ -167,7 +167,7 @@ class AttnVideoBlock(AttnBlock):
emb = emb[:, None, :] emb = emb[:, None, :]
x_mix = x_mix + emb x_mix = x_mix + emb
alpha = self.get_alpha() alpha = self.get_alpha().to(x.device)
x_mix = self.time_mix_block(x_mix, timesteps=timesteps) x_mix = self.time_mix_block(x_mix, timesteps=timesteps)
x = alpha * x + (1.0 - alpha) * x_mix # alpha merge x = alpha * x + (1.0 - alpha) * x_mix # alpha merge

111
comfy/model_base.py

@ -1,7 +1,7 @@
import torch import torch
from comfy.ldm.modules.diffusionmodules.openaimodel import UNetModel from comfy.ldm.modules.diffusionmodules.openaimodel import UNetModel, Timestep
from comfy.ldm.modules.encoders.noise_aug_modules import CLIPEmbeddingNoiseAugmentation from comfy.ldm.modules.encoders.noise_aug_modules import CLIPEmbeddingNoiseAugmentation
from comfy.ldm.modules.diffusionmodules.openaimodel import Timestep from comfy.ldm.modules.diffusionmodules.upscaling import ImageConcatWithNoiseAugmentation
import comfy.model_management import comfy.model_management
import comfy.conds import comfy.conds
import comfy.ops import comfy.ops
@ -78,7 +78,8 @@ class BaseModel(torch.nn.Module):
extra_conds = {} extra_conds = {}
for o in kwargs: for o in kwargs:
extra = kwargs[o] extra = kwargs[o]
if hasattr(extra, "to"): if hasattr(extra, "dtype"):
if extra.dtype != torch.int and extra.dtype != torch.long:
extra = extra.to(dtype) extra = extra.to(dtype)
extra_conds[o] = extra extra_conds[o] = extra
@ -99,11 +100,29 @@ class BaseModel(torch.nn.Module):
if self.inpaint_model: if self.inpaint_model:
concat_keys = ("mask", "masked_image") concat_keys = ("mask", "masked_image")
cond_concat = [] cond_concat = []
denoise_mask = kwargs.get("denoise_mask", None) denoise_mask = kwargs.get("concat_mask", kwargs.get("denoise_mask", None))
latent_image = kwargs.get("latent_image", None) concat_latent_image = kwargs.get("concat_latent_image", None)
if concat_latent_image is None:
concat_latent_image = kwargs.get("latent_image", None)
else:
concat_latent_image = self.process_latent_in(concat_latent_image)
noise = kwargs.get("noise", None) noise = kwargs.get("noise", None)
device = kwargs["device"] device = kwargs["device"]
if concat_latent_image.shape[1:] != noise.shape[1:]:
concat_latent_image = utils.common_upscale(concat_latent_image, noise.shape[-1], noise.shape[-2], "bilinear", "center")
concat_latent_image = utils.resize_to_batch_size(concat_latent_image, noise.shape[0])
if len(denoise_mask.shape) == len(noise.shape):
denoise_mask = denoise_mask[:,:1]
denoise_mask = denoise_mask.reshape((-1, 1, denoise_mask.shape[-2], denoise_mask.shape[-1]))
if denoise_mask.shape[-2:] != noise.shape[-2:]:
denoise_mask = utils.common_upscale(denoise_mask, noise.shape[-1], noise.shape[-2], "bilinear", "center")
denoise_mask = utils.resize_to_batch_size(denoise_mask.round(), noise.shape[0])
def blank_inpaint_image_like(latent_image): def blank_inpaint_image_like(latent_image):
blank_image = torch.ones_like(latent_image) blank_image = torch.ones_like(latent_image)
# these are the values for "zero" in pixel space translated to latent space # these are the values for "zero" in pixel space translated to latent space
@ -116,9 +135,9 @@ class BaseModel(torch.nn.Module):
for ck in concat_keys: for ck in concat_keys:
if denoise_mask is not None: if denoise_mask is not None:
if ck == "mask": if ck == "mask":
cond_concat.append(denoise_mask[:,:1].to(device)) cond_concat.append(denoise_mask.to(device))
elif ck == "masked_image": elif ck == "masked_image":
cond_concat.append(latent_image.to(device)) #NOTE: the latent_image should be masked by the mask in pixel space cond_concat.append(concat_latent_image.to(device)) #NOTE: the latent_image should be masked by the mask in pixel space
else: else:
if ck == "mask": if ck == "mask":
cond_concat.append(torch.ones_like(noise)[:,:1]) cond_concat.append(torch.ones_like(noise)[:,:1])
@ -126,9 +145,15 @@ class BaseModel(torch.nn.Module):
cond_concat.append(blank_inpaint_image_like(noise)) cond_concat.append(blank_inpaint_image_like(noise))
data = torch.cat(cond_concat, dim=1) data = torch.cat(cond_concat, dim=1)
out['c_concat'] = comfy.conds.CONDNoiseShape(data) out['c_concat'] = comfy.conds.CONDNoiseShape(data)
adm = self.encode_adm(**kwargs) adm = self.encode_adm(**kwargs)
if adm is not None: if adm is not None:
out['y'] = comfy.conds.CONDRegular(adm) out['y'] = comfy.conds.CONDRegular(adm)
cross_attn = kwargs.get("cross_attn", None)
if cross_attn is not None:
out['c_crossattn'] = comfy.conds.CONDCrossAttn(cross_attn)
return out return out
def load_model_weights(self, sd, unet_prefix=""): def load_model_weights(self, sd, unet_prefix=""):
@ -156,11 +181,7 @@ class BaseModel(torch.nn.Module):
def state_dict_for_saving(self, clip_state_dict, vae_state_dict): def state_dict_for_saving(self, clip_state_dict, vae_state_dict):
clip_state_dict = self.model_config.process_clip_state_dict_for_saving(clip_state_dict) clip_state_dict = self.model_config.process_clip_state_dict_for_saving(clip_state_dict)
unet_sd = self.diffusion_model.state_dict() unet_state_dict = self.diffusion_model.state_dict()
unet_state_dict = {}
for k in unet_sd:
unet_state_dict[k] = comfy.model_management.resolve_lowvram_weight(unet_sd[k], self.diffusion_model, k)
unet_state_dict = self.model_config.process_unet_state_dict_for_saving(unet_state_dict) unet_state_dict = self.model_config.process_unet_state_dict_for_saving(unet_state_dict)
vae_state_dict = self.model_config.process_vae_state_dict_for_saving(vae_state_dict) vae_state_dict = self.model_config.process_vae_state_dict_for_saving(vae_state_dict)
if self.get_dtype() == torch.float16: if self.get_dtype() == torch.float16:
@ -322,9 +343,75 @@ class SVD_img2vid(BaseModel):
out['c_concat'] = comfy.conds.CONDNoiseShape(latent_image) out['c_concat'] = comfy.conds.CONDNoiseShape(latent_image)
cross_attn = kwargs.get("cross_attn", None)
if cross_attn is not None:
out['c_crossattn'] = comfy.conds.CONDCrossAttn(cross_attn)
if "time_conditioning" in kwargs: if "time_conditioning" in kwargs:
out["time_context"] = comfy.conds.CONDCrossAttn(kwargs["time_conditioning"]) out["time_context"] = comfy.conds.CONDCrossAttn(kwargs["time_conditioning"])
out['image_only_indicator'] = comfy.conds.CONDConstant(torch.zeros((1,), device=device)) out['image_only_indicator'] = comfy.conds.CONDConstant(torch.zeros((1,), device=device))
out['num_video_frames'] = comfy.conds.CONDConstant(noise.shape[0]) out['num_video_frames'] = comfy.conds.CONDConstant(noise.shape[0])
return out return out
class Stable_Zero123(BaseModel):
def __init__(self, model_config, model_type=ModelType.EPS, device=None, cc_projection_weight=None, cc_projection_bias=None):
super().__init__(model_config, model_type, device=device)
self.cc_projection = comfy.ops.manual_cast.Linear(cc_projection_weight.shape[1], cc_projection_weight.shape[0], dtype=self.get_dtype(), device=device)
self.cc_projection.weight.copy_(cc_projection_weight)
self.cc_projection.bias.copy_(cc_projection_bias)
def extra_conds(self, **kwargs):
out = {}
latent_image = kwargs.get("concat_latent_image", None)
noise = kwargs.get("noise", None)
if latent_image is None:
latent_image = torch.zeros_like(noise)
if latent_image.shape[1:] != noise.shape[1:]:
latent_image = utils.common_upscale(latent_image, noise.shape[-1], noise.shape[-2], "bilinear", "center")
latent_image = utils.resize_to_batch_size(latent_image, noise.shape[0])
out['c_concat'] = comfy.conds.CONDNoiseShape(latent_image)
cross_attn = kwargs.get("cross_attn", None)
if cross_attn is not None:
if cross_attn.shape[-1] != 768:
cross_attn = self.cc_projection(cross_attn)
out['c_crossattn'] = comfy.conds.CONDCrossAttn(cross_attn)
return out
class SD_X4Upscaler(BaseModel):
def __init__(self, model_config, model_type=ModelType.V_PREDICTION, device=None):
super().__init__(model_config, model_type, device=device)
self.noise_augmentor = ImageConcatWithNoiseAugmentation(noise_schedule_config={"linear_start": 0.0001, "linear_end": 0.02}, max_noise_level=350)
def extra_conds(self, **kwargs):
out = {}
image = kwargs.get("concat_image", None)
noise = kwargs.get("noise", None)
noise_augment = kwargs.get("noise_augmentation", 0.0)
device = kwargs["device"]
seed = kwargs["seed"] - 10
noise_level = round((self.noise_augmentor.max_noise_level) * noise_augment)
if image is None:
image = torch.zeros_like(noise)[:,:3]
if image.shape[1:] != noise.shape[1:]:
image = utils.common_upscale(image.to(device), noise.shape[-1], noise.shape[-2], "bilinear", "center")
noise_level = torch.tensor([noise_level], device=device)
if noise_augment > 0:
image, noise_level = self.noise_augmentor(image.to(device), noise_level=noise_level, seed=seed)
image = utils.resize_to_batch_size(image, noise.shape[0])
out['c_concat'] = comfy.conds.CONDNoiseShape(image)
out['y'] = comfy.conds.CONDRegular(noise_level)
return out

8
comfy/model_detection.py

@ -34,7 +34,6 @@ def detect_unet_config(state_dict, key_prefix, dtype):
unet_config = { unet_config = {
"use_checkpoint": False, "use_checkpoint": False,
"image_size": 32, "image_size": 32,
"out_channels": 4,
"use_spatial_transformer": True, "use_spatial_transformer": True,
"legacy": False "legacy": False
} }
@ -50,6 +49,12 @@ def detect_unet_config(state_dict, key_prefix, dtype):
model_channels = state_dict['{}input_blocks.0.0.weight'.format(key_prefix)].shape[0] model_channels = state_dict['{}input_blocks.0.0.weight'.format(key_prefix)].shape[0]
in_channels = state_dict['{}input_blocks.0.0.weight'.format(key_prefix)].shape[1] in_channels = state_dict['{}input_blocks.0.0.weight'.format(key_prefix)].shape[1]
out_key = '{}out.2.weight'.format(key_prefix)
if out_key in state_dict:
out_channels = state_dict[out_key].shape[0]
else:
out_channels = 4
num_res_blocks = [] num_res_blocks = []
channel_mult = [] channel_mult = []
attention_resolutions = [] attention_resolutions = []
@ -122,6 +127,7 @@ def detect_unet_config(state_dict, key_prefix, dtype):
transformer_depth_middle = -1 transformer_depth_middle = -1
unet_config["in_channels"] = in_channels unet_config["in_channels"] = in_channels
unet_config["out_channels"] = out_channels
unet_config["model_channels"] = model_channels unet_config["model_channels"] = model_channels
unet_config["num_res_blocks"] = num_res_blocks unet_config["num_res_blocks"] = num_res_blocks
unet_config["transformer_depth"] = transformer_depth unet_config["transformer_depth"] = transformer_depth

69
comfy/model_management.py

@ -28,6 +28,10 @@ total_vram = 0
lowvram_available = True lowvram_available = True
xpu_available = False xpu_available = False
if args.deterministic:
print("Using deterministic algorithms for pytorch")
torch.use_deterministic_algorithms(True, warn_only=True)
directml_enabled = False directml_enabled = False
if args.directml is not None: if args.directml is not None:
import torch_directml import torch_directml
@ -182,6 +186,9 @@ except:
if is_intel_xpu(): if is_intel_xpu():
VAE_DTYPE = torch.bfloat16 VAE_DTYPE = torch.bfloat16
if args.cpu_vae:
VAE_DTYPE = torch.float32
if args.fp16_vae: if args.fp16_vae:
VAE_DTYPE = torch.float16 VAE_DTYPE = torch.float16
elif args.bf16_vae: elif args.bf16_vae:
@ -214,15 +221,8 @@ if args.force_fp16:
FORCE_FP16 = True FORCE_FP16 = True
if lowvram_available: if lowvram_available:
try:
import accelerate
if set_vram_to in (VRAMState.LOW_VRAM, VRAMState.NO_VRAM): if set_vram_to in (VRAMState.LOW_VRAM, VRAMState.NO_VRAM):
vram_state = set_vram_to vram_state = set_vram_to
except Exception as e:
import traceback
print(traceback.format_exc())
print("ERROR: LOW VRAM MODE NEEDS accelerate.")
lowvram_available = False
if cpu_state != CPUState.GPU: if cpu_state != CPUState.GPU:
@ -262,6 +262,14 @@ print("VAE dtype:", VAE_DTYPE)
current_loaded_models = [] current_loaded_models = []
def module_size(module):
module_mem = 0
sd = module.state_dict()
for k in sd:
t = sd[k]
module_mem += t.nelement() * t.element_size()
return module_mem
class LoadedModel: class LoadedModel:
def __init__(self, model): def __init__(self, model):
self.model = model self.model = model
@ -294,8 +302,20 @@ class LoadedModel:
if lowvram_model_memory > 0: if lowvram_model_memory > 0:
print("loading in lowvram mode", lowvram_model_memory/(1024 * 1024)) print("loading in lowvram mode", lowvram_model_memory/(1024 * 1024))
device_map = accelerate.infer_auto_device_map(self.real_model, max_memory={0: "{}MiB".format(lowvram_model_memory // (1024 * 1024)), "cpu": "16GiB"}) mem_counter = 0
accelerate.dispatch_model(self.real_model, device_map=device_map, main_device=self.device) for m in self.real_model.modules():
if hasattr(m, "comfy_cast_weights"):
m.prev_comfy_cast_weights = m.comfy_cast_weights
m.comfy_cast_weights = True
module_mem = module_size(m)
if mem_counter + module_mem < lowvram_model_memory:
m.to(self.device)
mem_counter += module_mem
elif hasattr(m, "weight"): #only modules with comfy_cast_weights can be set to lowvram mode
m.to(self.device)
mem_counter += module_size(m)
print("lowvram: loaded module regularly", m)
self.model_accelerated = True self.model_accelerated = True
if is_intel_xpu() and not args.disable_ipex_optimize: if is_intel_xpu() and not args.disable_ipex_optimize:
@ -305,7 +325,11 @@ class LoadedModel:
def model_unload(self): def model_unload(self):
if self.model_accelerated: if self.model_accelerated:
accelerate.hooks.remove_hook_from_submodules(self.real_model) for m in self.real_model.modules():
if hasattr(m, "prev_comfy_cast_weights"):
m.comfy_cast_weights = m.prev_comfy_cast_weights
del m.prev_comfy_cast_weights
self.model_accelerated = False self.model_accelerated = False
self.model.unpatch_model(self.model.offload_device) self.model.unpatch_model(self.model.offload_device)
@ -398,14 +422,14 @@ def load_models_gpu(models, memory_required=0):
if lowvram_available and (vram_set_state == VRAMState.LOW_VRAM or vram_set_state == VRAMState.NORMAL_VRAM): if lowvram_available and (vram_set_state == VRAMState.LOW_VRAM or vram_set_state == VRAMState.NORMAL_VRAM):
model_size = loaded_model.model_memory_required(torch_dev) model_size = loaded_model.model_memory_required(torch_dev)
current_free_mem = get_free_memory(torch_dev) current_free_mem = get_free_memory(torch_dev)
lowvram_model_memory = int(max(256 * (1024 * 1024), (current_free_mem - 1024 * (1024 * 1024)) / 1.3 )) lowvram_model_memory = int(max(64 * (1024 * 1024), (current_free_mem - 1024 * (1024 * 1024)) / 1.3 ))
if model_size > (current_free_mem - inference_memory): #only switch to lowvram if really necessary if model_size > (current_free_mem - inference_memory): #only switch to lowvram if really necessary
vram_set_state = VRAMState.LOW_VRAM vram_set_state = VRAMState.LOW_VRAM
else: else:
lowvram_model_memory = 0 lowvram_model_memory = 0
if vram_set_state == VRAMState.NO_VRAM: if vram_set_state == VRAMState.NO_VRAM:
lowvram_model_memory = 256 * 1024 * 1024 lowvram_model_memory = 64 * 1024 * 1024
cur_loaded_model = loaded_model.model_load(lowvram_model_memory) cur_loaded_model = loaded_model.model_load(lowvram_model_memory)
current_loaded_models.insert(0, loaded_model) current_loaded_models.insert(0, loaded_model)
@ -534,6 +558,8 @@ def intermediate_device():
return torch.device("cpu") return torch.device("cpu")
def vae_device(): def vae_device():
if args.cpu_vae:
return torch.device("cpu")
return get_torch_device() return get_torch_device()
def vae_offload_device(): def vae_offload_device():
@ -562,6 +588,11 @@ def supports_dtype(device, dtype): #TODO
return True return True
return False return False
def device_supports_non_blocking(device):
if is_device_mps(device):
return False #pytorch bug? mps doesn't support non blocking
return True
def cast_to_device(tensor, device, dtype, copy=False): def cast_to_device(tensor, device, dtype, copy=False):
device_supports_cast = False device_supports_cast = False
if tensor.dtype == torch.float32 or tensor.dtype == torch.float16: if tensor.dtype == torch.float32 or tensor.dtype == torch.float16:
@ -572,9 +603,7 @@ def cast_to_device(tensor, device, dtype, copy=False):
elif is_intel_xpu(): elif is_intel_xpu():
device_supports_cast = True device_supports_cast = True
non_blocking = True non_blocking = device_supports_non_blocking(device)
if is_device_mps(device):
non_blocking = False #pytorch bug? mps doesn't support non blocking
if device_supports_cast: if device_supports_cast:
if copy: if copy:
@ -738,11 +767,11 @@ def soft_empty_cache(force=False):
torch.cuda.empty_cache() torch.cuda.empty_cache()
torch.cuda.ipc_collect() torch.cuda.ipc_collect()
def resolve_lowvram_weight(weight, model, key): def unload_all_models():
if weight.device == torch.device("meta"): #lowvram NOTE: this depends on the inner working of the accelerate library so it might break. free_memory(1e30, get_torch_device())
key_split = key.split('.') # I have no idea why they don't just leave the weight there instead of using the meta device.
op = comfy.utils.get_attr(model, '.'.join(key_split[:-1]))
weight = op._hf_hook.weights_map[key_split[-1]] def resolve_lowvram_weight(weight, model, key): #TODO: remove
return weight return weight
#TODO: might be cleaner to put this somewhere else #TODO: might be cleaner to put this somewhere else

19
comfy/model_patcher.py

@ -28,13 +28,9 @@ class ModelPatcher:
if self.size > 0: if self.size > 0:
return self.size return self.size
model_sd = self.model.state_dict() model_sd = self.model.state_dict()
size = 0 self.size = comfy.model_management.module_size(self.model)
for k in model_sd:
t = model_sd[k]
size += t.nelement() * t.element_size()
self.size = size
self.model_keys = set(model_sd.keys()) self.model_keys = set(model_sd.keys())
return size return self.size
def clone(self): def clone(self):
n = ModelPatcher(self.model, self.load_device, self.offload_device, self.size, self.current_device, weight_inplace_update=self.weight_inplace_update) n = ModelPatcher(self.model, self.load_device, self.offload_device, self.size, self.current_device, weight_inplace_update=self.weight_inplace_update)
@ -55,14 +51,18 @@ class ModelPatcher:
def memory_required(self, input_shape): def memory_required(self, input_shape):
return self.model.memory_required(input_shape=input_shape) return self.model.memory_required(input_shape=input_shape)
def set_model_sampler_cfg_function(self, sampler_cfg_function): def set_model_sampler_cfg_function(self, sampler_cfg_function, disable_cfg1_optimization=False):
if len(inspect.signature(sampler_cfg_function).parameters) == 3: if len(inspect.signature(sampler_cfg_function).parameters) == 3:
self.model_options["sampler_cfg_function"] = lambda args: sampler_cfg_function(args["cond"], args["uncond"], args["cond_scale"]) #Old way self.model_options["sampler_cfg_function"] = lambda args: sampler_cfg_function(args["cond"], args["uncond"], args["cond_scale"]) #Old way
else: else:
self.model_options["sampler_cfg_function"] = sampler_cfg_function self.model_options["sampler_cfg_function"] = sampler_cfg_function
if disable_cfg1_optimization:
self.model_options["disable_cfg1_optimization"] = True
def set_model_sampler_post_cfg_function(self, post_cfg_function): def set_model_sampler_post_cfg_function(self, post_cfg_function, disable_cfg1_optimization=False):
self.model_options["sampler_post_cfg_function"] = self.model_options.get("sampler_post_cfg_function", []) + [post_cfg_function] self.model_options["sampler_post_cfg_function"] = self.model_options.get("sampler_post_cfg_function", []) + [post_cfg_function]
if disable_cfg1_optimization:
self.model_options["disable_cfg1_optimization"] = True
def set_model_unet_function_wrapper(self, unet_wrapper_function): def set_model_unet_function_wrapper(self, unet_wrapper_function):
self.model_options["model_function_wrapper"] = unet_wrapper_function self.model_options["model_function_wrapper"] = unet_wrapper_function
@ -174,13 +174,14 @@ class ModelPatcher:
sd.pop(k) sd.pop(k)
return sd return sd
def patch_model(self, device_to=None): def patch_model(self, device_to=None, patch_weights=True):
for k in self.object_patches: for k in self.object_patches:
old = getattr(self.model, k) old = getattr(self.model, k)
if k not in self.object_patches_backup: if k not in self.object_patches_backup:
self.object_patches_backup[k] = old self.object_patches_backup[k] = old
setattr(self.model, k, self.object_patches[k]) setattr(self.model, k, self.object_patches[k])
if patch_weights:
model_sd = self.model_state_dict() model_sd = self.model_state_dict()
for key in self.patches: for key in self.patches:
if key not in model_sd: if key not in model_sd:

92
comfy/ops.py

@ -1,27 +1,93 @@
import torch import torch
from contextlib import contextmanager from contextlib import contextmanager
import comfy.model_management
def cast_bias_weight(s, input):
bias = None
non_blocking = comfy.model_management.device_supports_non_blocking(input.device)
if s.bias is not None:
bias = s.bias.to(device=input.device, dtype=input.dtype, non_blocking=non_blocking)
weight = s.weight.to(device=input.device, dtype=input.dtype, non_blocking=non_blocking)
return weight, bias
class disable_weight_init: class disable_weight_init:
class Linear(torch.nn.Linear): class Linear(torch.nn.Linear):
comfy_cast_weights = False
def reset_parameters(self): def reset_parameters(self):
return None return None
def forward_comfy_cast_weights(self, input):
weight, bias = cast_bias_weight(self, input)
return torch.nn.functional.linear(input, weight, bias)
def forward(self, *args, **kwargs):
if self.comfy_cast_weights:
return self.forward_comfy_cast_weights(*args, **kwargs)
else:
return super().forward(*args, **kwargs)
class Conv2d(torch.nn.Conv2d): class Conv2d(torch.nn.Conv2d):
comfy_cast_weights = False
def reset_parameters(self): def reset_parameters(self):
return None return None
def forward_comfy_cast_weights(self, input):
weight, bias = cast_bias_weight(self, input)
return self._conv_forward(input, weight, bias)
def forward(self, *args, **kwargs):
if self.comfy_cast_weights:
return self.forward_comfy_cast_weights(*args, **kwargs)
else:
return super().forward(*args, **kwargs)
class Conv3d(torch.nn.Conv3d): class Conv3d(torch.nn.Conv3d):
comfy_cast_weights = False
def reset_parameters(self): def reset_parameters(self):
return None return None
def forward_comfy_cast_weights(self, input):
weight, bias = cast_bias_weight(self, input)
return self._conv_forward(input, weight, bias)
def forward(self, *args, **kwargs):
if self.comfy_cast_weights:
return self.forward_comfy_cast_weights(*args, **kwargs)
else:
return super().forward(*args, **kwargs)
class GroupNorm(torch.nn.GroupNorm): class GroupNorm(torch.nn.GroupNorm):
comfy_cast_weights = False
def reset_parameters(self): def reset_parameters(self):
return None return None
def forward_comfy_cast_weights(self, input):
weight, bias = cast_bias_weight(self, input)
return torch.nn.functional.group_norm(input, self.num_groups, weight, bias, self.eps)
def forward(self, *args, **kwargs):
if self.comfy_cast_weights:
return self.forward_comfy_cast_weights(*args, **kwargs)
else:
return super().forward(*args, **kwargs)
class LayerNorm(torch.nn.LayerNorm): class LayerNorm(torch.nn.LayerNorm):
comfy_cast_weights = False
def reset_parameters(self): def reset_parameters(self):
return None return None
def forward_comfy_cast_weights(self, input):
weight, bias = cast_bias_weight(self, input)
return torch.nn.functional.layer_norm(input, self.normalized_shape, weight, bias, self.eps)
def forward(self, *args, **kwargs):
if self.comfy_cast_weights:
return self.forward_comfy_cast_weights(*args, **kwargs)
else:
return super().forward(*args, **kwargs)
@classmethod @classmethod
def conv_nd(s, dims, *args, **kwargs): def conv_nd(s, dims, *args, **kwargs):
if dims == 2: if dims == 2:
@ -31,35 +97,19 @@ class disable_weight_init:
else: else:
raise ValueError(f"unsupported dimensions: {dims}") raise ValueError(f"unsupported dimensions: {dims}")
def cast_bias_weight(s, input):
bias = None
if s.bias is not None:
bias = s.bias.to(device=input.device, dtype=input.dtype)
weight = s.weight.to(device=input.device, dtype=input.dtype)
return weight, bias
class manual_cast(disable_weight_init): class manual_cast(disable_weight_init):
class Linear(disable_weight_init.Linear): class Linear(disable_weight_init.Linear):
def forward(self, input): comfy_cast_weights = True
weight, bias = cast_bias_weight(self, input)
return torch.nn.functional.linear(input, weight, bias)
class Conv2d(disable_weight_init.Conv2d): class Conv2d(disable_weight_init.Conv2d):
def forward(self, input): comfy_cast_weights = True
weight, bias = cast_bias_weight(self, input)
return self._conv_forward(input, weight, bias)
class Conv3d(disable_weight_init.Conv3d): class Conv3d(disable_weight_init.Conv3d):
def forward(self, input): comfy_cast_weights = True
weight, bias = cast_bias_weight(self, input)
return self._conv_forward(input, weight, bias)
class GroupNorm(disable_weight_init.GroupNorm): class GroupNorm(disable_weight_init.GroupNorm):
def forward(self, input): comfy_cast_weights = True
weight, bias = cast_bias_weight(self, input)
return torch.nn.functional.group_norm(input, self.num_groups, weight, bias, self.eps)
class LayerNorm(disable_weight_init.LayerNorm): class LayerNorm(disable_weight_init.LayerNorm):
def forward(self, input): comfy_cast_weights = True
weight, bias = cast_bias_weight(self, input)
return torch.nn.functional.layer_norm(input, self.normalized_shape, weight, bias, self.eps)

4
comfy/sample.py

@ -28,7 +28,6 @@ def prepare_noise(latent_image, seed, noise_inds=None):
def prepare_mask(noise_mask, shape, device): def prepare_mask(noise_mask, shape, device):
"""ensures noise mask is of proper dimensions""" """ensures noise mask is of proper dimensions"""
noise_mask = torch.nn.functional.interpolate(noise_mask.reshape((-1, 1, noise_mask.shape[-2], noise_mask.shape[-1])), size=(shape[2], shape[3]), mode="bilinear") noise_mask = torch.nn.functional.interpolate(noise_mask.reshape((-1, 1, noise_mask.shape[-2], noise_mask.shape[-1])), size=(shape[2], shape[3]), mode="bilinear")
noise_mask = noise_mask.round()
noise_mask = torch.cat([noise_mask] * shape[1], dim=1) noise_mask = torch.cat([noise_mask] * shape[1], dim=1)
noise_mask = comfy.utils.repeat_to_batch_size(noise_mask, shape[0]) noise_mask = comfy.utils.repeat_to_batch_size(noise_mask, shape[0])
noise_mask = noise_mask.to(device) noise_mask = noise_mask.to(device)
@ -47,7 +46,8 @@ def convert_cond(cond):
temp = c[1].copy() temp = c[1].copy()
model_conds = temp.get("model_conds", {}) model_conds = temp.get("model_conds", {})
if c[0] is not None: if c[0] is not None:
model_conds["c_crossattn"] = comfy.conds.CONDCrossAttn(c[0]) model_conds["c_crossattn"] = comfy.conds.CONDCrossAttn(c[0]) #TODO: remove
temp["cross_attn"] = c[0]
temp["model_conds"] = model_conds temp["model_conds"] = model_conds
out.append(temp) out.append(temp)
return out return out

19
comfy/samplers.py

@ -244,14 +244,15 @@ def calc_cond_uncond_batch(model, cond, uncond, x_in, timestep, model_options):
#The main sampling function shared by all the samplers #The main sampling function shared by all the samplers
#Returns denoised #Returns denoised
def sampling_function(model, x, timestep, uncond, cond, cond_scale, model_options={}, seed=None): def sampling_function(model, x, timestep, uncond, cond, cond_scale, model_options={}, seed=None):
if math.isclose(cond_scale, 1.0): if math.isclose(cond_scale, 1.0) and model_options.get("disable_cfg1_optimization", False) == False:
uncond_ = None uncond_ = None
else: else:
uncond_ = uncond uncond_ = uncond
cond_pred, uncond_pred = calc_cond_uncond_batch(model, cond, uncond_, x, timestep, model_options) cond_pred, uncond_pred = calc_cond_uncond_batch(model, cond, uncond_, x, timestep, model_options)
if "sampler_cfg_function" in model_options: if "sampler_cfg_function" in model_options:
args = {"cond": x - cond_pred, "uncond": x - uncond_pred, "cond_scale": cond_scale, "timestep": timestep, "input": x, "sigma": timestep} args = {"cond": x - cond_pred, "uncond": x - uncond_pred, "cond_scale": cond_scale, "timestep": timestep, "input": x, "sigma": timestep,
"cond_denoised": cond_pred, "uncond_denoised": uncond_pred, "model": model, "model_options": model_options}
cfg_result = x - model_options["sampler_cfg_function"](args) cfg_result = x - model_options["sampler_cfg_function"](args)
else: else:
cfg_result = uncond_pred + (cond_pred - uncond_pred) * cond_scale cfg_result = uncond_pred + (cond_pred - uncond_pred) * cond_scale
@ -598,6 +599,13 @@ def sample(model, noise, positive, negative, cfg, device, sampler, sigmas, model
calculate_start_end_timesteps(model, negative) calculate_start_end_timesteps(model, negative)
calculate_start_end_timesteps(model, positive) calculate_start_end_timesteps(model, positive)
if latent_image is not None:
latent_image = model.process_latent_in(latent_image)
if hasattr(model, 'extra_conds'):
positive = encode_model_conds(model.extra_conds, positive, noise, device, "positive", latent_image=latent_image, denoise_mask=denoise_mask, seed=seed)
negative = encode_model_conds(model.extra_conds, negative, noise, device, "negative", latent_image=latent_image, denoise_mask=denoise_mask, seed=seed)
#make sure each cond area has an opposite one with the same area #make sure each cond area has an opposite one with the same area
for c in positive: for c in positive:
create_cond_with_same_area_if_none(negative, c) create_cond_with_same_area_if_none(negative, c)
@ -609,13 +617,6 @@ def sample(model, noise, positive, negative, cfg, device, sampler, sigmas, model
apply_empty_x_to_equal_area(list(filter(lambda c: c.get('control_apply_to_uncond', False) == True, positive)), negative, 'control', lambda cond_cnets, x: cond_cnets[x]) apply_empty_x_to_equal_area(list(filter(lambda c: c.get('control_apply_to_uncond', False) == True, positive)), negative, 'control', lambda cond_cnets, x: cond_cnets[x])
apply_empty_x_to_equal_area(positive, negative, 'gligen', lambda cond_cnets, x: cond_cnets[x]) apply_empty_x_to_equal_area(positive, negative, 'gligen', lambda cond_cnets, x: cond_cnets[x])
if latent_image is not None:
latent_image = model.process_latent_in(latent_image)
if hasattr(model, 'extra_conds'):
positive = encode_model_conds(model.extra_conds, positive, noise, device, "positive", latent_image=latent_image, denoise_mask=denoise_mask)
negative = encode_model_conds(model.extra_conds, negative, noise, device, "negative", latent_image=latent_image, denoise_mask=denoise_mask)
extra_args = {"cond":positive, "uncond":negative, "cond_scale": cfg, "model_options": model_options, "seed":seed} extra_args = {"cond":positive, "uncond":negative, "cond_scale": cfg, "model_options": model_options, "seed":seed}
samples = sampler.sample(model_wrap, sigmas, extra_args, callback, noise, latent_image, denoise_mask, disable_pbar) samples = sampler.sample(model_wrap, sigmas, extra_args, callback, noise, latent_image, denoise_mask, disable_pbar)

23
comfy/sd.py

@ -157,6 +157,8 @@ class VAE:
self.memory_used_encode = lambda shape, dtype: (1767 * shape[2] * shape[3]) * model_management.dtype_size(dtype) #These are for AutoencoderKL and need tweaking (should be lower) self.memory_used_encode = lambda shape, dtype: (1767 * shape[2] * shape[3]) * model_management.dtype_size(dtype) #These are for AutoencoderKL and need tweaking (should be lower)
self.memory_used_decode = lambda shape, dtype: (2178 * shape[2] * shape[3] * 64) * model_management.dtype_size(dtype) self.memory_used_decode = lambda shape, dtype: (2178 * shape[2] * shape[3] * 64) * model_management.dtype_size(dtype)
self.downscale_ratio = 8
self.latent_channels = 4
if config is None: if config is None:
if "decoder.mid.block_1.mix_factor" in sd: if "decoder.mid.block_1.mix_factor" in sd:
@ -172,6 +174,11 @@ class VAE:
else: else:
#default SD1.x/SD2.x VAE parameters #default SD1.x/SD2.x VAE parameters
ddconfig = {'double_z': True, 'z_channels': 4, 'resolution': 256, 'in_channels': 3, 'out_ch': 3, 'ch': 128, 'ch_mult': [1, 2, 4, 4], 'num_res_blocks': 2, 'attn_resolutions': [], 'dropout': 0.0} ddconfig = {'double_z': True, 'z_channels': 4, 'resolution': 256, 'in_channels': 3, 'out_ch': 3, 'ch': 128, 'ch_mult': [1, 2, 4, 4], 'num_res_blocks': 2, 'attn_resolutions': [], 'dropout': 0.0}
if 'encoder.down.2.downsample.conv.weight' not in sd: #Stable diffusion x4 upscaler VAE
ddconfig['ch_mult'] = [1, 2, 4]
self.downscale_ratio = 4
self.first_stage_model = AutoencoderKL(ddconfig=ddconfig, embed_dim=4) self.first_stage_model = AutoencoderKL(ddconfig=ddconfig, embed_dim=4)
else: else:
self.first_stage_model = AutoencoderKL(**(config['params'])) self.first_stage_model = AutoencoderKL(**(config['params']))
@ -204,9 +211,9 @@ class VAE:
decode_fn = lambda a: (self.first_stage_model.decode(a.to(self.vae_dtype).to(self.device)) + 1.0).float() decode_fn = lambda a: (self.first_stage_model.decode(a.to(self.vae_dtype).to(self.device)) + 1.0).float()
output = torch.clamp(( output = torch.clamp((
(comfy.utils.tiled_scale(samples, decode_fn, tile_x // 2, tile_y * 2, overlap, upscale_amount = 8, output_device=self.output_device, pbar = pbar) + (comfy.utils.tiled_scale(samples, decode_fn, tile_x // 2, tile_y * 2, overlap, upscale_amount = self.downscale_ratio, output_device=self.output_device, pbar = pbar) +
comfy.utils.tiled_scale(samples, decode_fn, tile_x * 2, tile_y // 2, overlap, upscale_amount = 8, output_device=self.output_device, pbar = pbar) + comfy.utils.tiled_scale(samples, decode_fn, tile_x * 2, tile_y // 2, overlap, upscale_amount = self.downscale_ratio, output_device=self.output_device, pbar = pbar) +
comfy.utils.tiled_scale(samples, decode_fn, tile_x, tile_y, overlap, upscale_amount = 8, output_device=self.output_device, pbar = pbar)) comfy.utils.tiled_scale(samples, decode_fn, tile_x, tile_y, overlap, upscale_amount = self.downscale_ratio, output_device=self.output_device, pbar = pbar))
/ 3.0) / 2.0, min=0.0, max=1.0) / 3.0) / 2.0, min=0.0, max=1.0)
return output return output
@ -217,9 +224,9 @@ class VAE:
pbar = comfy.utils.ProgressBar(steps) pbar = comfy.utils.ProgressBar(steps)
encode_fn = lambda a: self.first_stage_model.encode((2. * a - 1.).to(self.vae_dtype).to(self.device)).float() encode_fn = lambda a: self.first_stage_model.encode((2. * a - 1.).to(self.vae_dtype).to(self.device)).float()
samples = comfy.utils.tiled_scale(pixel_samples, encode_fn, tile_x, tile_y, overlap, upscale_amount = (1/8), out_channels=4, output_device=self.output_device, pbar=pbar) samples = comfy.utils.tiled_scale(pixel_samples, encode_fn, tile_x, tile_y, overlap, upscale_amount = (1/self.downscale_ratio), out_channels=self.latent_channels, output_device=self.output_device, pbar=pbar)
samples += comfy.utils.tiled_scale(pixel_samples, encode_fn, tile_x * 2, tile_y // 2, overlap, upscale_amount = (1/8), out_channels=4, output_device=self.output_device, pbar=pbar) samples += comfy.utils.tiled_scale(pixel_samples, encode_fn, tile_x * 2, tile_y // 2, overlap, upscale_amount = (1/self.downscale_ratio), out_channels=self.latent_channels, output_device=self.output_device, pbar=pbar)
samples += comfy.utils.tiled_scale(pixel_samples, encode_fn, tile_x // 2, tile_y * 2, overlap, upscale_amount = (1/8), out_channels=4, output_device=self.output_device, pbar=pbar) samples += comfy.utils.tiled_scale(pixel_samples, encode_fn, tile_x // 2, tile_y * 2, overlap, upscale_amount = (1/self.downscale_ratio), out_channels=self.latent_channels, output_device=self.output_device, pbar=pbar)
samples /= 3.0 samples /= 3.0
return samples return samples
@ -231,7 +238,7 @@ class VAE:
batch_number = int(free_memory / memory_used) batch_number = int(free_memory / memory_used)
batch_number = max(1, batch_number) batch_number = max(1, batch_number)
pixel_samples = torch.empty((samples_in.shape[0], 3, round(samples_in.shape[2] * 8), round(samples_in.shape[3] * 8)), device=self.output_device) pixel_samples = torch.empty((samples_in.shape[0], 3, round(samples_in.shape[2] * self.downscale_ratio), round(samples_in.shape[3] * self.downscale_ratio)), device=self.output_device)
for x in range(0, samples_in.shape[0], batch_number): for x in range(0, samples_in.shape[0], batch_number):
samples = samples_in[x:x+batch_number].to(self.vae_dtype).to(self.device) samples = samples_in[x:x+batch_number].to(self.vae_dtype).to(self.device)
pixel_samples[x:x+batch_number] = torch.clamp((self.first_stage_model.decode(samples).to(self.output_device).float() + 1.0) / 2.0, min=0.0, max=1.0) pixel_samples[x:x+batch_number] = torch.clamp((self.first_stage_model.decode(samples).to(self.output_device).float() + 1.0) / 2.0, min=0.0, max=1.0)
@ -255,7 +262,7 @@ class VAE:
free_memory = model_management.get_free_memory(self.device) free_memory = model_management.get_free_memory(self.device)
batch_number = int(free_memory / memory_used) batch_number = int(free_memory / memory_used)
batch_number = max(1, batch_number) batch_number = max(1, batch_number)
samples = torch.empty((pixel_samples.shape[0], 4, round(pixel_samples.shape[2] // 8), round(pixel_samples.shape[3] // 8)), device=self.output_device) samples = torch.empty((pixel_samples.shape[0], self.latent_channels, round(pixel_samples.shape[2] // self.downscale_ratio), round(pixel_samples.shape[3] // self.downscale_ratio)), device=self.output_device)
for x in range(0, pixel_samples.shape[0], batch_number): for x in range(0, pixel_samples.shape[0], batch_number):
pixels_in = (2. * pixel_samples[x:x+batch_number] - 1.).to(self.vae_dtype).to(self.device) pixels_in = (2. * pixel_samples[x:x+batch_number] - 1.).to(self.vae_dtype).to(self.device)
samples[x:x+batch_number] = self.first_stage_model.encode(pixels_in).to(self.output_device).float() samples[x:x+batch_number] = self.first_stage_model.encode(pixels_in).to(self.output_device).float()

56
comfy/supported_models.py

@ -252,5 +252,59 @@ class SVD_img2vid(supported_models_base.BASE):
def clip_target(self): def clip_target(self):
return None return None
models = [SD15, SD20, SD21UnclipL, SD21UnclipH, SDXLRefiner, SDXL, SSD1B, Segmind_Vega] class Stable_Zero123(supported_models_base.BASE):
unet_config = {
"context_dim": 768,
"model_channels": 320,
"use_linear_in_transformer": False,
"adm_in_channels": None,
"use_temporal_attention": False,
"in_channels": 8,
}
unet_extra_config = {
"num_heads": 8,
"num_head_channels": -1,
}
clip_vision_prefix = "cond_stage_model.model.visual."
latent_format = latent_formats.SD15
def get_model(self, state_dict, prefix="", device=None):
out = model_base.Stable_Zero123(self, device=device, cc_projection_weight=state_dict["cc_projection.weight"], cc_projection_bias=state_dict["cc_projection.bias"])
return out
def clip_target(self):
return None
class SD_X4Upscaler(SD20):
unet_config = {
"context_dim": 1024,
"model_channels": 256,
'in_channels': 7,
"use_linear_in_transformer": True,
"adm_in_channels": None,
"use_temporal_attention": False,
}
unet_extra_config = {
"disable_self_attentions": [True, True, True, False],
"num_classes": 1000,
"num_heads": 8,
"num_head_channels": -1,
}
latent_format = latent_formats.SD_X4
sampling_settings = {
"linear_start": 0.0001,
"linear_end": 0.02,
}
def get_model(self, state_dict, prefix="", device=None):
out = model_base.SD_X4Upscaler(self, device=device)
return out
models = [Stable_Zero123, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXLRefiner, SDXL, SSD1B, Segmind_Vega, SD_X4Upscaler]
models += [SVD_img2vid] models += [SVD_img2vid]

5
comfy/taesd/taesd.py

@ -7,9 +7,10 @@ import torch
import torch.nn as nn import torch.nn as nn
import comfy.utils import comfy.utils
import comfy.ops
def conv(n_in, n_out, **kwargs): def conv(n_in, n_out, **kwargs):
return nn.Conv2d(n_in, n_out, 3, padding=1, **kwargs) return comfy.ops.disable_weight_init.Conv2d(n_in, n_out, 3, padding=1, **kwargs)
class Clamp(nn.Module): class Clamp(nn.Module):
def forward(self, x): def forward(self, x):
@ -19,7 +20,7 @@ class Block(nn.Module):
def __init__(self, n_in, n_out): def __init__(self, n_in, n_out):
super().__init__() super().__init__()
self.conv = nn.Sequential(conv(n_in, n_out), nn.ReLU(), conv(n_out, n_out), nn.ReLU(), conv(n_out, n_out)) self.conv = nn.Sequential(conv(n_in, n_out), nn.ReLU(), conv(n_out, n_out), nn.ReLU(), conv(n_out, n_out))
self.skip = nn.Conv2d(n_in, n_out, 1, bias=False) if n_in != n_out else nn.Identity() self.skip = comfy.ops.disable_weight_init.Conv2d(n_in, n_out, 1, bias=False) if n_in != n_out else nn.Identity()
self.fuse = nn.ReLU() self.fuse = nn.ReLU()
def forward(self, x): def forward(self, x):
return self.fuse(self.conv(x) + self.skip(x)) return self.fuse(self.conv(x) + self.skip(x))

21
comfy_extras/nodes_custom_sampler.py

@ -13,6 +13,7 @@ class BasicScheduler:
{"model": ("MODEL",), {"model": ("MODEL",),
"scheduler": (comfy.samplers.SCHEDULER_NAMES, ), "scheduler": (comfy.samplers.SCHEDULER_NAMES, ),
"steps": ("INT", {"default": 20, "min": 1, "max": 10000}), "steps": ("INT", {"default": 20, "min": 1, "max": 10000}),
"denoise": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}),
} }
} }
RETURN_TYPES = ("SIGMAS",) RETURN_TYPES = ("SIGMAS",)
@ -20,8 +21,14 @@ class BasicScheduler:
FUNCTION = "get_sigmas" FUNCTION = "get_sigmas"
def get_sigmas(self, model, scheduler, steps): def get_sigmas(self, model, scheduler, steps, denoise):
sigmas = comfy.samplers.calculate_sigmas_scheduler(model.model, scheduler, steps).cpu() total_steps = steps
if denoise < 1.0:
total_steps = int(steps/denoise)
comfy.model_management.load_models_gpu([model])
sigmas = comfy.samplers.calculate_sigmas_scheduler(model.model, scheduler, total_steps).cpu()
sigmas = sigmas[-(steps + 1):]
return (sigmas, ) return (sigmas, )
@ -87,6 +94,7 @@ class SDTurboScheduler:
return {"required": return {"required":
{"model": ("MODEL",), {"model": ("MODEL",),
"steps": ("INT", {"default": 1, "min": 1, "max": 10}), "steps": ("INT", {"default": 1, "min": 1, "max": 10}),
"denoise": ("FLOAT", {"default": 1.0, "min": 0, "max": 1.0, "step": 0.01}),
} }
} }
RETURN_TYPES = ("SIGMAS",) RETURN_TYPES = ("SIGMAS",)
@ -94,9 +102,12 @@ class SDTurboScheduler:
FUNCTION = "get_sigmas" FUNCTION = "get_sigmas"
def get_sigmas(self, model, steps): def get_sigmas(self, model, steps, denoise):
timesteps = torch.flip(torch.arange(1, 11) * 100 - 1, (0,))[:steps] start_step = 10 - int(10 * denoise)
sigmas = model.model.model_sampling.sigma(timesteps) timesteps = torch.flip(torch.arange(1, 11) * 100 - 1, (0,))[start_step:start_step + steps]
inner_model = model.patch_model(patch_weights=False)
sigmas = inner_model.model_sampling.sigma(timesteps)
model.unpatch_model()
sigmas = torch.cat([sigmas, sigmas.new_zeros([1])]) sigmas = torch.cat([sigmas, sigmas.new_zeros([1])])
return (sigmas, ) return (sigmas, )

16
comfy_extras/nodes_hypertile.py

@ -37,24 +37,24 @@ class HyperTile:
def patch(self, model, tile_size, swap_size, max_depth, scale_depth): def patch(self, model, tile_size, swap_size, max_depth, scale_depth):
model_channels = model.model.model_config.unet_config["model_channels"] model_channels = model.model.model_config.unet_config["model_channels"]
apply_to = set()
temp = model_channels
for x in range(max_depth + 1):
apply_to.add(temp)
temp *= 2
latent_tile_size = max(32, tile_size) // 8 latent_tile_size = max(32, tile_size) // 8
self.temp = None self.temp = None
def hypertile_in(q, k, v, extra_options): def hypertile_in(q, k, v, extra_options):
if q.shape[-1] in apply_to: model_chans = q.shape[-2]
orig_shape = extra_options['original_shape']
apply_to = []
for i in range(max_depth + 1):
apply_to.append((orig_shape[-2] / (2 ** i)) * (orig_shape[-1] / (2 ** i)))
if model_chans in apply_to:
shape = extra_options["original_shape"] shape = extra_options["original_shape"]
aspect_ratio = shape[-1] / shape[-2] aspect_ratio = shape[-1] / shape[-2]
hw = q.size(1) hw = q.size(1)
h, w = round(math.sqrt(hw * aspect_ratio)), round(math.sqrt(hw / aspect_ratio)) h, w = round(math.sqrt(hw * aspect_ratio)), round(math.sqrt(hw / aspect_ratio))
factor = 2**((q.shape[-1] // model_channels) - 1) if scale_depth else 1 factor = (2 ** apply_to.index(model_chans)) if scale_depth else 1
nh = random_divisor(h, latent_tile_size * factor, swap_size) nh = random_divisor(h, latent_tile_size * factor, swap_size)
nw = random_divisor(w, latent_tile_size * factor, swap_size) nw = random_divisor(w, latent_tile_size * factor, swap_size)

4
comfy_extras/nodes_images.py

@ -74,7 +74,7 @@ class SaveAnimatedWEBP:
OUTPUT_NODE = True OUTPUT_NODE = True
CATEGORY = "_for_testing" CATEGORY = "image/animation"
def save_images(self, images, fps, filename_prefix, lossless, quality, method, num_frames=0, prompt=None, extra_pnginfo=None): def save_images(self, images, fps, filename_prefix, lossless, quality, method, num_frames=0, prompt=None, extra_pnginfo=None):
method = self.methods.get(method) method = self.methods.get(method)
@ -136,7 +136,7 @@ class SaveAnimatedPNG:
OUTPUT_NODE = True OUTPUT_NODE = True
CATEGORY = "_for_testing" CATEGORY = "image/animation"
def save_images(self, images, fps, compress_level, filename_prefix="ComfyUI", prompt=None, extra_pnginfo=None): def save_images(self, images, fps, compress_level, filename_prefix="ComfyUI", prompt=None, extra_pnginfo=None):
filename_prefix += self.prefix_append filename_prefix += self.prefix_append

25
comfy_extras/nodes_latent.py

@ -3,9 +3,7 @@ import torch
def reshape_latent_to(target_shape, latent): def reshape_latent_to(target_shape, latent):
if latent.shape[1:] != target_shape[1:]: if latent.shape[1:] != target_shape[1:]:
latent.movedim(1, -1)
latent = comfy.utils.common_upscale(latent, target_shape[3], target_shape[2], "bilinear", "center") latent = comfy.utils.common_upscale(latent, target_shape[3], target_shape[2], "bilinear", "center")
latent.movedim(-1, 1)
return comfy.utils.repeat_to_batch_size(latent, target_shape[0]) return comfy.utils.repeat_to_batch_size(latent, target_shape[0])
@ -102,9 +100,32 @@ class LatentInterpolate:
samples_out["samples"] = st * (m1 * ratio + m2 * (1.0 - ratio)) samples_out["samples"] = st * (m1 * ratio + m2 * (1.0 - ratio))
return (samples_out,) return (samples_out,)
class LatentBatch:
@classmethod
def INPUT_TYPES(s):
return {"required": { "samples1": ("LATENT",), "samples2": ("LATENT",)}}
RETURN_TYPES = ("LATENT",)
FUNCTION = "batch"
CATEGORY = "latent/batch"
def batch(self, samples1, samples2):
samples_out = samples1.copy()
s1 = samples1["samples"]
s2 = samples2["samples"]
if s1.shape[1:] != s2.shape[1:]:
s2 = comfy.utils.common_upscale(s2, s1.shape[3], s1.shape[2], "bilinear", "center")
s = torch.cat((s1, s2), dim=0)
samples_out["samples"] = s
samples_out["batch_index"] = samples1.get("batch_index", [x for x in range(0, s1.shape[0])]) + samples2.get("batch_index", [x for x in range(0, s2.shape[0])])
return (samples_out,)
NODE_CLASS_MAPPINGS = { NODE_CLASS_MAPPINGS = {
"LatentAdd": LatentAdd, "LatentAdd": LatentAdd,
"LatentSubtract": LatentSubtract, "LatentSubtract": LatentSubtract,
"LatentMultiply": LatentMultiply, "LatentMultiply": LatentMultiply,
"LatentInterpolate": LatentInterpolate, "LatentInterpolate": LatentInterpolate,
"LatentBatch": LatentBatch,
} }

3
comfy_extras/nodes_mask.py

@ -6,6 +6,7 @@ import comfy.utils
from nodes import MAX_RESOLUTION from nodes import MAX_RESOLUTION
def composite(destination, source, x, y, mask = None, multiplier = 8, resize_source = False): def composite(destination, source, x, y, mask = None, multiplier = 8, resize_source = False):
source = source.to(destination.device)
if resize_source: if resize_source:
source = torch.nn.functional.interpolate(source, size=(destination.shape[2], destination.shape[3]), mode="bilinear") source = torch.nn.functional.interpolate(source, size=(destination.shape[2], destination.shape[3]), mode="bilinear")
@ -20,7 +21,7 @@ def composite(destination, source, x, y, mask = None, multiplier = 8, resize_sou
if mask is None: if mask is None:
mask = torch.ones_like(source) mask = torch.ones_like(source)
else: else:
mask = mask.clone() mask = mask.to(destination.device, copy=True)
mask = torch.nn.functional.interpolate(mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1])), size=(source.shape[2], source.shape[3]), mode="bilinear") mask = torch.nn.functional.interpolate(mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1])), size=(source.shape[2], source.shape[3]), mode="bilinear")
mask = comfy.utils.repeat_to_batch_size(mask, source.shape[0]) mask = comfy.utils.repeat_to_batch_size(mask, source.shape[0])

48
comfy_extras/nodes_model_advanced.py

@ -17,41 +17,19 @@ class LCM(comfy.model_sampling.EPS):
return c_out * x0 + c_skip * model_input return c_out * x0 + c_skip * model_input
class ModelSamplingDiscreteDistilled(torch.nn.Module): class ModelSamplingDiscreteDistilled(comfy.model_sampling.ModelSamplingDiscrete):
original_timesteps = 50 original_timesteps = 50
def __init__(self): def __init__(self, model_config=None):
super().__init__() super().__init__(model_config)
self.sigma_data = 1.0
timesteps = 1000
beta_start = 0.00085
beta_end = 0.012
betas = torch.linspace(beta_start**0.5, beta_end**0.5, timesteps, dtype=torch.float32) ** 2 self.skip_steps = self.num_timesteps // self.original_timesteps
alphas = 1.0 - betas
alphas_cumprod = torch.cumprod(alphas, dim=0)
self.skip_steps = timesteps // self.original_timesteps sigmas_valid = torch.zeros((self.original_timesteps), dtype=torch.float32)
alphas_cumprod_valid = torch.zeros((self.original_timesteps), dtype=torch.float32)
for x in range(self.original_timesteps): for x in range(self.original_timesteps):
alphas_cumprod_valid[self.original_timesteps - 1 - x] = alphas_cumprod[timesteps - 1 - x * self.skip_steps] sigmas_valid[self.original_timesteps - 1 - x] = self.sigmas[self.num_timesteps - 1 - x * self.skip_steps]
sigmas = ((1 - alphas_cumprod_valid) / alphas_cumprod_valid) ** 0.5
self.set_sigmas(sigmas)
def set_sigmas(self, sigmas): self.set_sigmas(sigmas_valid)
self.register_buffer('sigmas', sigmas)
self.register_buffer('log_sigmas', sigmas.log())
@property
def sigma_min(self):
return self.sigmas[0]
@property
def sigma_max(self):
return self.sigmas[-1]
def timestep(self, sigma): def timestep(self, sigma):
log_sigma = sigma.log() log_sigma = sigma.log()
@ -66,14 +44,6 @@ class ModelSamplingDiscreteDistilled(torch.nn.Module):
log_sigma = (1 - w) * self.log_sigmas[low_idx] + w * self.log_sigmas[high_idx] log_sigma = (1 - w) * self.log_sigmas[low_idx] + w * self.log_sigmas[high_idx]
return log_sigma.exp().to(timestep.device) return log_sigma.exp().to(timestep.device)
def percent_to_sigma(self, percent):
if percent <= 0.0:
return 999999999.9
if percent >= 1.0:
return 0.0
percent = 1.0 - percent
return self.sigma(torch.tensor(percent * 999.0)).item()
def rescale_zero_terminal_snr_sigmas(sigmas): def rescale_zero_terminal_snr_sigmas(sigmas):
alphas_cumprod = 1 / ((sigmas * sigmas) + 1) alphas_cumprod = 1 / ((sigmas * sigmas) + 1)
@ -122,7 +92,7 @@ class ModelSamplingDiscrete:
class ModelSamplingAdvanced(sampling_base, sampling_type): class ModelSamplingAdvanced(sampling_base, sampling_type):
pass pass
model_sampling = ModelSamplingAdvanced() model_sampling = ModelSamplingAdvanced(model.model.model_config)
if zsnr: if zsnr:
model_sampling.set_sigmas(rescale_zero_terminal_snr_sigmas(model_sampling.sigmas)) model_sampling.set_sigmas(rescale_zero_terminal_snr_sigmas(model_sampling.sigmas))
@ -154,7 +124,7 @@ class ModelSamplingContinuousEDM:
class ModelSamplingAdvanced(comfy.model_sampling.ModelSamplingContinuousEDM, sampling_type): class ModelSamplingAdvanced(comfy.model_sampling.ModelSamplingContinuousEDM, sampling_type):
pass pass
model_sampling = ModelSamplingAdvanced() model_sampling = ModelSamplingAdvanced(model.model.model_config)
model_sampling.set_sigma_range(sigma_min, sigma_max) model_sampling.set_sigma_range(sigma_min, sigma_max)
m.add_object_patch("model_sampling", model_sampling) m.add_object_patch("model_sampling", model_sampling)
return (m, ) return (m, )

55
comfy_extras/nodes_perpneg.py

@ -0,0 +1,55 @@
import torch
import comfy.model_management
import comfy.sample
import comfy.samplers
import comfy.utils
class PerpNeg:
@classmethod
def INPUT_TYPES(s):
return {"required": {"model": ("MODEL", ),
"empty_conditioning": ("CONDITIONING", ),
"neg_scale": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0}),
}}
RETURN_TYPES = ("MODEL",)
FUNCTION = "patch"
CATEGORY = "_for_testing"
def patch(self, model, empty_conditioning, neg_scale):
m = model.clone()
nocond = comfy.sample.convert_cond(empty_conditioning)
def cfg_function(args):
model = args["model"]
noise_pred_pos = args["cond_denoised"]
noise_pred_neg = args["uncond_denoised"]
cond_scale = args["cond_scale"]
x = args["input"]
sigma = args["sigma"]
model_options = args["model_options"]
nocond_processed = comfy.samplers.encode_model_conds(model.extra_conds, nocond, x, x.device, "negative")
(noise_pred_nocond, _) = comfy.samplers.calc_cond_uncond_batch(model, nocond_processed, None, x, sigma, model_options)
pos = noise_pred_pos - noise_pred_nocond
neg = noise_pred_neg - noise_pred_nocond
perp = ((torch.mul(pos, neg).sum())/(torch.norm(neg)**2)) * neg
perp_neg = perp * neg_scale
cfg_result = noise_pred_nocond + cond_scale*(pos - perp_neg)
cfg_result = x - cfg_result
return cfg_result
m.set_model_sampler_cfg_function(cfg_function)
return (m, )
NODE_CLASS_MAPPINGS = {
"PerpNeg": PerpNeg,
}
NODE_DISPLAY_NAME_MAPPINGS = {
"PerpNeg": "Perp-Neg",
}

30
comfy_extras/nodes_rebatch.py

@ -99,10 +99,40 @@ class LatentRebatch:
return (output_list,) return (output_list,)
class ImageRebatch:
@classmethod
def INPUT_TYPES(s):
return {"required": { "images": ("IMAGE",),
"batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}),
}}
RETURN_TYPES = ("IMAGE",)
INPUT_IS_LIST = True
OUTPUT_IS_LIST = (True, )
FUNCTION = "rebatch"
CATEGORY = "image/batch"
def rebatch(self, images, batch_size):
batch_size = batch_size[0]
output_list = []
all_images = []
for img in images:
for i in range(img.shape[0]):
all_images.append(img[i:i+1])
for i in range(0, len(all_images), batch_size):
output_list.append(torch.cat(all_images[i:i+batch_size], dim=0))
return (output_list,)
NODE_CLASS_MAPPINGS = { NODE_CLASS_MAPPINGS = {
"RebatchLatents": LatentRebatch, "RebatchLatents": LatentRebatch,
"RebatchImages": ImageRebatch,
} }
NODE_DISPLAY_NAME_MAPPINGS = { NODE_DISPLAY_NAME_MAPPINGS = {
"RebatchLatents": "Rebatch Latents", "RebatchLatents": "Rebatch Latents",
"RebatchImages": "Rebatch Images",
} }

12
comfy_extras/nodes_sag.py

@ -27,9 +27,7 @@ def attention_basic_with_sim(q, k, v, heads, mask=None):
# force cast to fp32 to avoid overflowing # force cast to fp32 to avoid overflowing
if _ATTN_PRECISION =="fp32": if _ATTN_PRECISION =="fp32":
with torch.autocast(enabled=False, device_type = 'cuda'): sim = einsum('b i d, b j d -> b i j', q.float(), k.float()) * scale
q, k = q.float(), k.float()
sim = einsum('b i d, b j d -> b i j', q, k) * scale
else: else:
sim = einsum('b i d, b j d -> b i j', q, k) * scale sim = einsum('b i d, b j d -> b i j', q, k) * scale
@ -60,7 +58,7 @@ def create_blur_map(x0, attn, sigma=3.0, threshold=1.0):
attn = attn.reshape(b, -1, hw1, hw2) attn = attn.reshape(b, -1, hw1, hw2)
# Global Average Pool # Global Average Pool
mask = attn.mean(1, keepdim=False).sum(1, keepdim=False) > threshold mask = attn.mean(1, keepdim=False).sum(1, keepdim=False) > threshold
ratio = round(math.sqrt(lh * lw / hw1)) ratio = 2**(math.ceil(math.sqrt(lh * lw / hw1)) - 1).bit_length()
mid_shape = [math.ceil(lh / ratio), math.ceil(lw / ratio)] mid_shape = [math.ceil(lh / ratio), math.ceil(lw / ratio)]
# Reshape # Reshape
@ -111,7 +109,6 @@ class SelfAttentionGuidance:
m = model.clone() m = model.clone()
attn_scores = None attn_scores = None
mid_block_shape = None
# TODO: make this work properly with chunked batches # TODO: make this work properly with chunked batches
# currently, we can only save the attn from one UNet call # currently, we can only save the attn from one UNet call
@ -134,7 +131,6 @@ class SelfAttentionGuidance:
def post_cfg_function(args): def post_cfg_function(args):
nonlocal attn_scores nonlocal attn_scores
nonlocal mid_block_shape
uncond_attn = attn_scores uncond_attn = attn_scores
sag_scale = scale sag_scale = scale
@ -147,6 +143,8 @@ class SelfAttentionGuidance:
sigma = args["sigma"] sigma = args["sigma"]
model_options = args["model_options"] model_options = args["model_options"]
x = args["input"] x = args["input"]
if min(cfg_result.shape[2:]) <= 4: #skip when too small to add padding
return cfg_result
# create the adversarially blurred image # create the adversarially blurred image
degraded = create_blur_map(uncond_pred, uncond_attn, sag_sigma, sag_threshold) degraded = create_blur_map(uncond_pred, uncond_attn, sag_sigma, sag_threshold)
@ -155,7 +153,7 @@ class SelfAttentionGuidance:
(sag, _) = comfy.samplers.calc_cond_uncond_batch(model, uncond, None, degraded_noised, sigma, model_options) (sag, _) = comfy.samplers.calc_cond_uncond_batch(model, uncond, None, degraded_noised, sigma, model_options)
return cfg_result + (degraded - sag) * sag_scale return cfg_result + (degraded - sag) * sag_scale
m.set_model_sampler_post_cfg_function(post_cfg_function) m.set_model_sampler_post_cfg_function(post_cfg_function, disable_cfg1_optimization=True)
# from diffusers: # from diffusers:
# unet.mid_block.attentions[0].transformer_blocks[0].attn1.patch # unet.mid_block.attentions[0].transformer_blocks[0].attn1.patch

47
comfy_extras/nodes_sdupscale.py

@ -0,0 +1,47 @@
import torch
import nodes
import comfy.utils
class SD_4XUpscale_Conditioning:
@classmethod
def INPUT_TYPES(s):
return {"required": { "images": ("IMAGE",),
"positive": ("CONDITIONING",),
"negative": ("CONDITIONING",),
"scale_ratio": ("FLOAT", {"default": 4.0, "min": 0.0, "max": 10.0, "step": 0.01}),
"noise_augmentation": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.001}),
}}
RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT")
RETURN_NAMES = ("positive", "negative", "latent")
FUNCTION = "encode"
CATEGORY = "conditioning/upscale_diffusion"
def encode(self, images, positive, negative, scale_ratio, noise_augmentation):
width = max(1, round(images.shape[-2] * scale_ratio))
height = max(1, round(images.shape[-3] * scale_ratio))
pixels = comfy.utils.common_upscale((images.movedim(-1,1) * 2.0) - 1.0, width // 4, height // 4, "bilinear", "center")
out_cp = []
out_cn = []
for t in positive:
n = [t[0], t[1].copy()]
n[1]['concat_image'] = pixels
n[1]['noise_augmentation'] = noise_augmentation
out_cp.append(n)
for t in negative:
n = [t[0], t[1].copy()]
n[1]['concat_image'] = pixels
n[1]['noise_augmentation'] = noise_augmentation
out_cn.append(n)
latent = torch.zeros([images.shape[0], 4, height // 4, width // 4])
return (out_cp, out_cn, {"samples":latent})
NODE_CLASS_MAPPINGS = {
"SD_4XUpscale_Conditioning": SD_4XUpscale_Conditioning,
}

102
comfy_extras/nodes_stable3d.py

@ -0,0 +1,102 @@
import torch
import nodes
import comfy.utils
def camera_embeddings(elevation, azimuth):
elevation = torch.as_tensor([elevation])
azimuth = torch.as_tensor([azimuth])
embeddings = torch.stack(
[
torch.deg2rad(
(90 - elevation) - (90)
), # Zero123 polar is 90-elevation
torch.sin(torch.deg2rad(azimuth)),
torch.cos(torch.deg2rad(azimuth)),
torch.deg2rad(
90 - torch.full_like(elevation, 0)
),
], dim=-1).unsqueeze(1)
return embeddings
class StableZero123_Conditioning:
@classmethod
def INPUT_TYPES(s):
return {"required": { "clip_vision": ("CLIP_VISION",),
"init_image": ("IMAGE",),
"vae": ("VAE",),
"width": ("INT", {"default": 256, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 8}),
"height": ("INT", {"default": 256, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 8}),
"batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}),
"elevation": ("FLOAT", {"default": 0.0, "min": -180.0, "max": 180.0}),
"azimuth": ("FLOAT", {"default": 0.0, "min": -180.0, "max": 180.0}),
}}
RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT")
RETURN_NAMES = ("positive", "negative", "latent")
FUNCTION = "encode"
CATEGORY = "conditioning/3d_models"
def encode(self, clip_vision, init_image, vae, width, height, batch_size, elevation, azimuth):
output = clip_vision.encode_image(init_image)
pooled = output.image_embeds.unsqueeze(0)
pixels = comfy.utils.common_upscale(init_image.movedim(-1,1), width, height, "bilinear", "center").movedim(1,-1)
encode_pixels = pixels[:,:,:,:3]
t = vae.encode(encode_pixels)
cam_embeds = camera_embeddings(elevation, azimuth)
cond = torch.cat([pooled, cam_embeds.repeat((pooled.shape[0], 1, 1))], dim=-1)
positive = [[cond, {"concat_latent_image": t}]]
negative = [[torch.zeros_like(pooled), {"concat_latent_image": torch.zeros_like(t)}]]
latent = torch.zeros([batch_size, 4, height // 8, width // 8])
return (positive, negative, {"samples":latent})
class StableZero123_Conditioning_Batched:
@classmethod
def INPUT_TYPES(s):
return {"required": { "clip_vision": ("CLIP_VISION",),
"init_image": ("IMAGE",),
"vae": ("VAE",),
"width": ("INT", {"default": 256, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 8}),
"height": ("INT", {"default": 256, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 8}),
"batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}),
"elevation": ("FLOAT", {"default": 0.0, "min": -180.0, "max": 180.0}),
"azimuth": ("FLOAT", {"default": 0.0, "min": -180.0, "max": 180.0}),
"elevation_batch_increment": ("FLOAT", {"default": 0.0, "min": -180.0, "max": 180.0}),
"azimuth_batch_increment": ("FLOAT", {"default": 0.0, "min": -180.0, "max": 180.0}),
}}
RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT")
RETURN_NAMES = ("positive", "negative", "latent")
FUNCTION = "encode"
CATEGORY = "conditioning/3d_models"
def encode(self, clip_vision, init_image, vae, width, height, batch_size, elevation, azimuth, elevation_batch_increment, azimuth_batch_increment):
output = clip_vision.encode_image(init_image)
pooled = output.image_embeds.unsqueeze(0)
pixels = comfy.utils.common_upscale(init_image.movedim(-1,1), width, height, "bilinear", "center").movedim(1,-1)
encode_pixels = pixels[:,:,:,:3]
t = vae.encode(encode_pixels)
cam_embeds = []
for i in range(batch_size):
cam_embeds.append(camera_embeddings(elevation, azimuth))
elevation += elevation_batch_increment
azimuth += azimuth_batch_increment
cam_embeds = torch.cat(cam_embeds, dim=0)
cond = torch.cat([comfy.utils.repeat_to_batch_size(pooled, batch_size), cam_embeds], dim=-1)
positive = [[cond, {"concat_latent_image": t}]]
negative = [[torch.zeros_like(pooled), {"concat_latent_image": torch.zeros_like(t)}]]
latent = torch.zeros([batch_size, 4, height // 8, width // 8])
return (positive, negative, {"samples":latent, "batch_index": [0] * batch_size})
NODE_CLASS_MAPPINGS = {
"StableZero123_Conditioning": StableZero123_Conditioning,
"StableZero123_Conditioning_Batched": StableZero123_Conditioning_Batched,
}

80
execution.py

@ -7,6 +7,7 @@ import threading
import heapq import heapq
import traceback import traceback
import gc import gc
import inspect
import torch import torch
import nodes import nodes
@ -267,11 +268,14 @@ def recursive_output_delete_if_changed(prompt, old_prompt, outputs, current_item
class PromptExecutor: class PromptExecutor:
def __init__(self, server): def __init__(self, server):
self.server = server
self.reset()
def reset(self):
self.outputs = {} self.outputs = {}
self.object_storage = {} self.object_storage = {}
self.outputs_ui = {} self.outputs_ui = {}
self.old_prompt = {} self.old_prompt = {}
self.server = server
def handle_execution_error(self, prompt_id, prompt, current_outputs, executed, error, ex): def handle_execution_error(self, prompt_id, prompt, current_outputs, executed, error, ex):
node_id = error["node_id"] node_id = error["node_id"]
@ -382,6 +386,8 @@ class PromptExecutor:
for x in executed: for x in executed:
self.old_prompt[x] = copy.deepcopy(prompt[x]) self.old_prompt[x] = copy.deepcopy(prompt[x])
self.server.last_node_id = None self.server.last_node_id = None
if comfy.model_management.DISABLE_SMART_MEMORY:
comfy.model_management.unload_all_models()
@ -400,6 +406,10 @@ def validate_inputs(prompt, item, validated):
errors = [] errors = []
valid = True valid = True
validate_function_inputs = []
if hasattr(obj_class, "VALIDATE_INPUTS"):
validate_function_inputs = inspect.getfullargspec(obj_class.VALIDATE_INPUTS).args
for x in required_inputs: for x in required_inputs:
if x not in inputs: if x not in inputs:
error = { error = {
@ -529,29 +539,7 @@ def validate_inputs(prompt, item, validated):
errors.append(error) errors.append(error)
continue continue
if hasattr(obj_class, "VALIDATE_INPUTS"): if x not in validate_function_inputs:
input_data_all = get_input_data(inputs, obj_class, unique_id)
#ret = obj_class.VALIDATE_INPUTS(**input_data_all)
ret = map_node_over_list(obj_class, input_data_all, "VALIDATE_INPUTS")
for i, r in enumerate(ret):
if r is not True:
details = f"{x}"
if r is not False:
details += f" - {str(r)}"
error = {
"type": "custom_validation_failed",
"message": "Custom validation failed for node",
"details": details,
"extra_info": {
"input_name": x,
"input_config": info,
"received_value": val,
}
}
errors.append(error)
continue
else:
if isinstance(type_input, list): if isinstance(type_input, list):
if val not in type_input: if val not in type_input:
input_config = info input_config = info
@ -578,6 +566,35 @@ def validate_inputs(prompt, item, validated):
errors.append(error) errors.append(error)
continue continue
if len(validate_function_inputs) > 0:
input_data_all = get_input_data(inputs, obj_class, unique_id)
input_filtered = {}
for x in input_data_all:
if x in validate_function_inputs:
input_filtered[x] = input_data_all[x]
#ret = obj_class.VALIDATE_INPUTS(**input_filtered)
ret = map_node_over_list(obj_class, input_filtered, "VALIDATE_INPUTS")
for x in input_filtered:
for i, r in enumerate(ret):
if r is not True:
details = f"{x}"
if r is not False:
details += f" - {str(r)}"
error = {
"type": "custom_validation_failed",
"message": "Custom validation failed for node",
"details": details,
"extra_info": {
"input_name": x,
"input_config": info,
"received_value": val,
}
}
errors.append(error)
continue
if len(errors) > 0 or valid is not True: if len(errors) > 0 or valid is not True:
ret = (False, errors, unique_id) ret = (False, errors, unique_id)
else: else:
@ -692,6 +709,7 @@ class PromptQueue:
self.queue = [] self.queue = []
self.currently_running = {} self.currently_running = {}
self.history = {} self.history = {}
self.flags = {}
server.prompt_queue = self server.prompt_queue = self
def put(self, item): def put(self, item):
@ -778,3 +796,17 @@ class PromptQueue:
def delete_history_item(self, id_to_delete): def delete_history_item(self, id_to_delete):
with self.mutex: with self.mutex:
self.history.pop(id_to_delete, None) self.history.pop(id_to_delete, None)
def set_flag(self, name, data):
with self.mutex:
self.flags[name] = data
self.not_empty.notify()
def get_flags(self, reset=True):
with self.mutex:
if reset:
ret = self.flags
self.flags = {}
return ret
else:
return self.flags.copy()

4
folder_paths.py

@ -34,6 +34,7 @@ folder_names_and_paths["classifiers"] = ([os.path.join(models_dir, "classifiers"
output_directory = os.path.join(os.path.dirname(os.path.realpath(__file__)), "output") output_directory = os.path.join(os.path.dirname(os.path.realpath(__file__)), "output")
temp_directory = os.path.join(os.path.dirname(os.path.realpath(__file__)), "temp") temp_directory = os.path.join(os.path.dirname(os.path.realpath(__file__)), "temp")
input_directory = os.path.join(os.path.dirname(os.path.realpath(__file__)), "input") input_directory = os.path.join(os.path.dirname(os.path.realpath(__file__)), "input")
user_directory = os.path.join(os.path.dirname(os.path.realpath(__file__)), "user")
filename_list_cache = {} filename_list_cache = {}
@ -184,8 +185,7 @@ def cached_filename_list_(folder_name):
if folder_name not in filename_list_cache: if folder_name not in filename_list_cache:
return None return None
out = filename_list_cache[folder_name] out = filename_list_cache[folder_name]
if time.perf_counter() < (out[2] + 0.5):
return out
for x in out[1]: for x in out[1]:
time_modified = out[1][x] time_modified = out[1][x]
folder = x folder = x

25
main.py

@ -64,6 +64,10 @@ if __name__ == "__main__":
os.environ['CUDA_VISIBLE_DEVICES'] = str(args.cuda_device) os.environ['CUDA_VISIBLE_DEVICES'] = str(args.cuda_device)
print("Set cuda device to:", args.cuda_device) print("Set cuda device to:", args.cuda_device)
if args.deterministic:
if 'CUBLAS_WORKSPACE_CONFIG' not in os.environ:
os.environ['CUBLAS_WORKSPACE_CONFIG'] = ":4096:8"
import cuda_malloc import cuda_malloc
import comfy.utils import comfy.utils
@ -93,7 +97,7 @@ def prompt_worker(q, server):
gc_collect_interval = 10.0 gc_collect_interval = 10.0
while True: while True:
timeout = None timeout = 1000.0
if need_gc: if need_gc:
timeout = max(gc_collect_interval - (current_time - last_gc_collect), 0.0) timeout = max(gc_collect_interval - (current_time - last_gc_collect), 0.0)
@ -102,6 +106,8 @@ def prompt_worker(q, server):
item, item_id = queue_item item, item_id = queue_item
execution_start_time = time.perf_counter() execution_start_time = time.perf_counter()
prompt_id = item[1] prompt_id = item[1]
server.last_prompt_id = prompt_id
e.execute(item[2], prompt_id, item[3], item[4]) e.execute(item[2], prompt_id, item[3], item[4])
need_gc = True need_gc = True
q.task_done(item_id, e.outputs_ui) q.task_done(item_id, e.outputs_ui)
@ -112,6 +118,19 @@ def prompt_worker(q, server):
execution_time = current_time - execution_start_time execution_time = current_time - execution_start_time
print("Prompt executed in {:.2f} seconds".format(execution_time)) print("Prompt executed in {:.2f} seconds".format(execution_time))
flags = q.get_flags()
free_memory = flags.get("free_memory", False)
if flags.get("unload_models", free_memory):
comfy.model_management.unload_all_models()
need_gc = True
last_gc_collect = 0
if free_memory:
e.reset()
need_gc = True
last_gc_collect = 0
if need_gc: if need_gc:
current_time = time.perf_counter() current_time = time.perf_counter()
if (current_time - last_gc_collect) > gc_collect_interval: if (current_time - last_gc_collect) > gc_collect_interval:
@ -127,7 +146,9 @@ async def run(server, address='', port=8188, verbose=True, call_on_start=None):
def hijack_progress(server): def hijack_progress(server):
def hook(value, total, preview_image): def hook(value, total, preview_image):
comfy.model_management.throw_exception_if_processing_interrupted() comfy.model_management.throw_exception_if_processing_interrupted()
server.send_sync("progress", {"value": value, "max": total}, server.client_id) progress = {"value": value, "max": total, "prompt_id": server.last_prompt_id, "node": server.last_node_id}
server.send_sync("progress", progress, server.client_id)
if preview_image is not None: if preview_image is not None:
server.send_sync(BinaryEventTypes.UNENCODED_PREVIEW_IMAGE, preview_image, server.client_id) server.send_sync(BinaryEventTypes.UNENCODED_PREVIEW_IMAGE, preview_image, server.client_id)
comfy.utils.set_progress_bar_global_hook(hook) comfy.utils.set_progress_bar_global_hook(hook)

93
nodes.py

@ -9,7 +9,7 @@ import math
import time import time
import random import random
from PIL import Image, ImageOps from PIL import Image, ImageOps, ImageSequence
from PIL.PngImagePlugin import PngInfo from PIL.PngImagePlugin import PngInfo
import numpy as np import numpy as np
import safetensors.torch import safetensors.torch
@ -359,6 +359,62 @@ class VAEEncodeForInpaint:
return ({"samples":t, "noise_mask": (mask_erosion[:,:,:x,:y].round())}, ) return ({"samples":t, "noise_mask": (mask_erosion[:,:,:x,:y].round())}, )
class InpaintModelConditioning:
@classmethod
def INPUT_TYPES(s):
return {"required": {"positive": ("CONDITIONING", ),
"negative": ("CONDITIONING", ),
"vae": ("VAE", ),
"pixels": ("IMAGE", ),
"mask": ("MASK", ),
}}
RETURN_TYPES = ("CONDITIONING","CONDITIONING","LATENT")
RETURN_NAMES = ("positive", "negative", "latent")
FUNCTION = "encode"
CATEGORY = "conditioning/inpaint"
def encode(self, positive, negative, pixels, vae, mask):
x = (pixels.shape[1] // 8) * 8
y = (pixels.shape[2] // 8) * 8
mask = torch.nn.functional.interpolate(mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1])), size=(pixels.shape[1], pixels.shape[2]), mode="bilinear")
orig_pixels = pixels
pixels = orig_pixels.clone()
if pixels.shape[1] != x or pixels.shape[2] != y:
x_offset = (pixels.shape[1] % 8) // 2
y_offset = (pixels.shape[2] % 8) // 2
pixels = pixels[:,x_offset:x + x_offset, y_offset:y + y_offset,:]
mask = mask[:,:,x_offset:x + x_offset, y_offset:y + y_offset]
m = (1.0 - mask.round()).squeeze(1)
for i in range(3):
pixels[:,:,:,i] -= 0.5
pixels[:,:,:,i] *= m
pixels[:,:,:,i] += 0.5
concat_latent = vae.encode(pixels)
orig_latent = vae.encode(orig_pixels)
out_latent = {}
out_latent["samples"] = orig_latent
out_latent["noise_mask"] = mask
out = []
for conditioning in [positive, negative]:
c = []
for t in conditioning:
d = t[1].copy()
d["concat_latent_image"] = concat_latent
d["concat_mask"] = mask
n = [t[0], d]
c.append(n)
out.append(c)
return (out[0], out[1], out_latent)
class SaveLatent: class SaveLatent:
def __init__(self): def __init__(self):
self.output_dir = folder_paths.get_output_directory() self.output_dir = folder_paths.get_output_directory()
@ -1410,8 +1466,13 @@ class LoadImage:
FUNCTION = "load_image" FUNCTION = "load_image"
def load_image(self, image): def load_image(self, image):
image_path = folder_paths.get_annotated_filepath(image) image_path = folder_paths.get_annotated_filepath(image)
i = Image.open(image_path) img = Image.open(image_path)
output_images = []
output_masks = []
for i in ImageSequence.Iterator(img):
i = ImageOps.exif_transpose(i) i = ImageOps.exif_transpose(i)
if i.mode == 'I':
i = i.point(lambda i: i * (1 / 255))
image = i.convert("RGB") image = i.convert("RGB")
image = np.array(image).astype(np.float32) / 255.0 image = np.array(image).astype(np.float32) / 255.0
image = torch.from_numpy(image)[None,] image = torch.from_numpy(image)[None,]
@ -1420,7 +1481,17 @@ class LoadImage:
mask = 1. - torch.from_numpy(mask) mask = 1. - torch.from_numpy(mask)
else: else:
mask = torch.zeros((64,64), dtype=torch.float32, device="cpu") mask = torch.zeros((64,64), dtype=torch.float32, device="cpu")
return (image, mask.unsqueeze(0)) output_images.append(image)
output_masks.append(mask.unsqueeze(0))
if len(output_images) > 1:
output_image = torch.cat(output_images, dim=0)
output_mask = torch.cat(output_masks, dim=0)
else:
output_image = output_images[0]
output_mask = output_masks[0]
return (output_image, output_mask)
@classmethod @classmethod
def IS_CHANGED(s, image): def IS_CHANGED(s, image):
@ -1457,6 +1528,8 @@ class LoadImageMask:
i = Image.open(image_path) i = Image.open(image_path)
i = ImageOps.exif_transpose(i) i = ImageOps.exif_transpose(i)
if i.getbands() != ("R", "G", "B", "A"): if i.getbands() != ("R", "G", "B", "A"):
if i.mode == 'I':
i = i.point(lambda i: i * (1 / 255))
i = i.convert("RGBA") i = i.convert("RGBA")
mask = None mask = None
c = channel[0].upper() c = channel[0].upper()
@ -1478,13 +1551,10 @@ class LoadImageMask:
return m.digest().hex() return m.digest().hex()
@classmethod @classmethod
def VALIDATE_INPUTS(s, image, channel): def VALIDATE_INPUTS(s, image):
if not folder_paths.exists_annotated_filepath(image): if not folder_paths.exists_annotated_filepath(image):
return "Invalid image file: {}".format(image) return "Invalid image file: {}".format(image)
if channel not in s._color_channels:
return "Invalid color channel: {}".format(channel)
return True return True
class ImageScale: class ImageScale:
@ -1614,10 +1684,11 @@ class ImagePadForOutpaint:
def expand_image(self, image, left, top, right, bottom, feathering): def expand_image(self, image, left, top, right, bottom, feathering):
d1, d2, d3, d4 = image.size() d1, d2, d3, d4 = image.size()
new_image = torch.zeros( new_image = torch.ones(
(d1, d2 + top + bottom, d3 + left + right, d4), (d1, d2 + top + bottom, d3 + left + right, d4),
dtype=torch.float32, dtype=torch.float32,
) ) * 0.5
new_image[:, top:top + d2, left:left + d3, :] = image new_image[:, top:top + d2, left:left + d3, :] = image
mask = torch.ones( mask = torch.ones(
@ -1709,6 +1780,7 @@ NODE_CLASS_MAPPINGS = {
"unCLIPCheckpointLoader": unCLIPCheckpointLoader, "unCLIPCheckpointLoader": unCLIPCheckpointLoader,
"GLIGENLoader": GLIGENLoader, "GLIGENLoader": GLIGENLoader,
"GLIGENTextBoxApply": GLIGENTextBoxApply, "GLIGENTextBoxApply": GLIGENTextBoxApply,
"InpaintModelConditioning": InpaintModelConditioning,
"CheckpointLoader": CheckpointLoader, "CheckpointLoader": CheckpointLoader,
"DiffusersLoader": DiffusersLoader, "DiffusersLoader": DiffusersLoader,
@ -1868,6 +1940,9 @@ def init_custom_nodes():
"nodes_images.py", "nodes_images.py",
"nodes_video_model.py", "nodes_video_model.py",
"nodes_sag.py", "nodes_sag.py",
"nodes_perpneg.py",
"nodes_stable3d.py",
"nodes_sdupscale.py",
] ]
for node_file in extras_files: for node_file in extras_files:

2
requirements.txt

@ -1,10 +1,10 @@
torch torch
torchsde torchsde
torchvision
einops einops
transformers>=4.25.1 transformers>=4.25.1
safetensors>=0.3.0 safetensors>=0.3.0
aiohttp aiohttp
accelerate
pyyaml pyyaml
Pillow Pillow
scipy scipy

20
server.py

@ -30,6 +30,7 @@ from comfy.cli_args import args
import comfy.utils import comfy.utils
import comfy.model_management import comfy.model_management
from app.user_manager import UserManager
class BinaryEventTypes: class BinaryEventTypes:
PREVIEW_IMAGE = 1 PREVIEW_IMAGE = 1
@ -72,6 +73,7 @@ class PromptServer():
mimetypes.init() mimetypes.init()
mimetypes.types_map['.js'] = 'application/javascript; charset=utf-8' mimetypes.types_map['.js'] = 'application/javascript; charset=utf-8'
self.user_manager = UserManager()
self.supports = ["custom_nodes_from_web"] self.supports = ["custom_nodes_from_web"]
self.prompt_queue = None self.prompt_queue = None
self.loop = loop self.loop = loop
@ -507,6 +509,17 @@ class PromptServer():
nodes.interrupt_processing() nodes.interrupt_processing()
return web.Response(status=200) return web.Response(status=200)
@routes.post("/free")
async def post_free(request):
json_data = await request.json()
unload_models = json_data.get("unload_models", False)
free_memory = json_data.get("free_memory", False)
if unload_models:
self.prompt_queue.set_flag("unload_models", unload_models)
if free_memory:
self.prompt_queue.set_flag("free_memory", free_memory)
return web.Response(status=200)
@routes.post("/history") @routes.post("/history")
async def post_history(request): async def post_history(request):
json_data = await request.json() json_data = await request.json()
@ -521,6 +534,7 @@ class PromptServer():
return web.Response(status=200) return web.Response(status=200)
def add_routes(self): def add_routes(self):
self.user_manager.add_routes(self.routes)
self.app.add_routes(self.routes) self.app.add_routes(self.routes)
for name, dir in nodes.EXTENSION_WEB_DIRS.items(): for name, dir in nodes.EXTENSION_WEB_DIRS.items():
@ -584,7 +598,8 @@ class PromptServer():
message = self.encode_bytes(event, data) message = self.encode_bytes(event, data)
if sid is None: if sid is None:
for ws in self.sockets.values(): sockets = list(self.sockets.values())
for ws in sockets:
await send_socket_catch_exception(ws.send_bytes, message) await send_socket_catch_exception(ws.send_bytes, message)
elif sid in self.sockets: elif sid in self.sockets:
await send_socket_catch_exception(self.sockets[sid].send_bytes, message) await send_socket_catch_exception(self.sockets[sid].send_bytes, message)
@ -593,7 +608,8 @@ class PromptServer():
message = {"type": event, "data": data} message = {"type": event, "data": data}
if sid is None: if sid is None:
for ws in self.sockets.values(): sockets = list(self.sockets.values())
for ws in sockets:
await send_socket_catch_exception(ws.send_json, message) await send_socket_catch_exception(ws.send_json, message)
elif sid in self.sockets: elif sid in self.sockets:
await send_socket_catch_exception(self.sockets[sid].send_json, message) await send_socket_catch_exception(self.sockets[sid].send_json, message)

9
tests-ui/afterSetup.js

@ -0,0 +1,9 @@
const { start } = require("./utils");
const lg = require("./utils/litegraph");
// Load things once per test file before to ensure its all warmed up for the tests
beforeAll(async () => {
lg.setup(global);
await start({ resetEnv: true });
lg.teardown(global);
});

3
tests-ui/babel.config.json

@ -1,3 +1,4 @@
{ {
"presets": ["@babel/preset-env"] "presets": ["@babel/preset-env"],
"plugins": ["babel-plugin-transform-import-meta"]
} }

2
tests-ui/jest.config.js

@ -2,8 +2,10 @@
const config = { const config = {
testEnvironment: "jsdom", testEnvironment: "jsdom",
setupFiles: ["./globalSetup.js"], setupFiles: ["./globalSetup.js"],
setupFilesAfterEnv: ["./afterSetup.js"],
clearMocks: true, clearMocks: true,
resetModules: true, resetModules: true,
testTimeout: 10000
}; };
module.exports = config; module.exports = config;

20
tests-ui/package-lock.json generated

@ -11,6 +11,7 @@
"devDependencies": { "devDependencies": {
"@babel/preset-env": "^7.22.20", "@babel/preset-env": "^7.22.20",
"@types/jest": "^29.5.5", "@types/jest": "^29.5.5",
"babel-plugin-transform-import-meta": "^2.2.1",
"jest": "^29.7.0", "jest": "^29.7.0",
"jest-environment-jsdom": "^29.7.0" "jest-environment-jsdom": "^29.7.0"
} }
@ -2591,6 +2592,19 @@
"@babel/core": "^7.4.0 || ^8.0.0-0 <8.0.0" "@babel/core": "^7.4.0 || ^8.0.0-0 <8.0.0"
} }
}, },
"node_modules/babel-plugin-transform-import-meta": {
"version": "2.2.1",
"resolved": "https://registry.npmjs.org/babel-plugin-transform-import-meta/-/babel-plugin-transform-import-meta-2.2.1.tgz",
"integrity": "sha512-AxNh27Pcg8Kt112RGa3Vod2QS2YXKKJ6+nSvRtv7qQTJAdx0MZa4UHZ4lnxHUWA2MNbLuZQv5FVab4P1CoLOWw==",
"dev": true,
"dependencies": {
"@babel/template": "^7.4.4",
"tslib": "^2.4.0"
},
"peerDependencies": {
"@babel/core": "^7.10.0"
}
},
"node_modules/babel-preset-current-node-syntax": { "node_modules/babel-preset-current-node-syntax": {
"version": "1.0.1", "version": "1.0.1",
"resolved": "https://registry.npmjs.org/babel-preset-current-node-syntax/-/babel-preset-current-node-syntax-1.0.1.tgz", "resolved": "https://registry.npmjs.org/babel-preset-current-node-syntax/-/babel-preset-current-node-syntax-1.0.1.tgz",
@ -5233,6 +5247,12 @@
"node": ">=12" "node": ">=12"
} }
}, },
"node_modules/tslib": {
"version": "2.6.2",
"resolved": "https://registry.npmjs.org/tslib/-/tslib-2.6.2.tgz",
"integrity": "sha512-AEYxH93jGFPn/a2iVAwW87VuUIkR1FVUKB77NwMF7nBTDkDrrT/Hpt/IrCJ0QXhW27jTBDcf5ZY7w6RiqTMw2Q==",
"dev": true
},
"node_modules/type-detect": { "node_modules/type-detect": {
"version": "4.0.8", "version": "4.0.8",
"resolved": "https://registry.npmjs.org/type-detect/-/type-detect-4.0.8.tgz", "resolved": "https://registry.npmjs.org/type-detect/-/type-detect-4.0.8.tgz",

1
tests-ui/package.json

@ -24,6 +24,7 @@
"devDependencies": { "devDependencies": {
"@babel/preset-env": "^7.22.20", "@babel/preset-env": "^7.22.20",
"@types/jest": "^29.5.5", "@types/jest": "^29.5.5",
"babel-plugin-transform-import-meta": "^2.2.1",
"jest": "^29.7.0", "jest": "^29.7.0",
"jest-environment-jsdom": "^29.7.0" "jest-environment-jsdom": "^29.7.0"
} }

32
tests-ui/tests/groupNode.test.js

@ -970,4 +970,36 @@ describe("group node", () => {
}); });
}); });
}); });
test("converted inputs with linked widgets map values correctly on creation", async () => {
const { ez, graph, app } = await start();
const k1 = ez.KSampler();
const k2 = ez.KSampler();
k1.widgets.seed.convertToInput();
k2.widgets.seed.convertToInput();
const rr = ez.Reroute();
rr.outputs[0].connectTo(k1.inputs.seed);
rr.outputs[0].connectTo(k2.inputs.seed);
const group = await convertToGroup(app, graph, "test", [k1, k2, rr]);
expect(group.widgets.steps.value).toBe(20);
expect(group.widgets.cfg.value).toBe(8);
expect(group.widgets.scheduler.value).toBe("normal");
expect(group.widgets["KSampler steps"].value).toBe(20);
expect(group.widgets["KSampler cfg"].value).toBe(8);
expect(group.widgets["KSampler scheduler"].value).toBe("normal");
});
test("allow multiple of the same node type to be added", async () => {
const { ez, graph, app } = await start();
const nodes = [...Array(10)].map(() => ez.ImageScaleBy());
const group = await convertToGroup(app, graph, "test", nodes);
expect(group.inputs.length).toBe(10);
expect(group.outputs.length).toBe(10);
expect(group.widgets.length).toBe(20);
expect(group.widgets.map((w) => w.widget.name)).toStrictEqual(
[...Array(10)]
.map((_, i) => `${i > 0 ? "ImageScaleBy " : ""}${i > 1 ? i + " " : ""}`)
.flatMap((p) => [`${p}upscale_method`, `${p}scale_by`])
);
});
}); });

295
tests-ui/tests/users.test.js

@ -0,0 +1,295 @@
// @ts-check
/// <reference path="../node_modules/@types/jest/index.d.ts" />
const { start } = require("../utils");
const lg = require("../utils/litegraph");
describe("users", () => {
beforeEach(() => {
lg.setup(global);
});
afterEach(() => {
lg.teardown(global);
});
function expectNoUserScreen() {
// Ensure login isnt visible
const selection = document.querySelectorAll("#comfy-user-selection")?.[0];
expect(selection["style"].display).toBe("none");
const menu = document.querySelectorAll(".comfy-menu")?.[0];
expect(window.getComputedStyle(menu)?.display).not.toBe("none");
}
describe("multi-user", () => {
function mockAddStylesheet() {
const utils = require("../../web/scripts/utils");
utils.addStylesheet = jest.fn().mockReturnValue(Promise.resolve());
}
async function waitForUserScreenShow() {
mockAddStylesheet();
// Wait for "show" to be called
const { UserSelectionScreen } = require("../../web/scripts/ui/userSelection");
let resolve, reject;
const fn = UserSelectionScreen.prototype.show;
const p = new Promise((res, rej) => {
resolve = res;
reject = rej;
});
jest.spyOn(UserSelectionScreen.prototype, "show").mockImplementation(async (...args) => {
const res = fn(...args);
await new Promise(process.nextTick); // wait for promises to resolve
resolve();
return res;
});
// @ts-ignore
setTimeout(() => reject("timeout waiting for UserSelectionScreen to be shown."), 500);
await p;
await new Promise(process.nextTick); // wait for promises to resolve
}
async function testUserScreen(onShown, users) {
if (!users) {
users = {};
}
const starting = start({
resetEnv: true,
userConfig: { storage: "server", users },
});
// Ensure no current user
expect(localStorage["Comfy.userId"]).toBeFalsy();
expect(localStorage["Comfy.userName"]).toBeFalsy();
await waitForUserScreenShow();
const selection = document.querySelectorAll("#comfy-user-selection")?.[0];
expect(selection).toBeTruthy();
// Ensure login is visible
expect(window.getComputedStyle(selection)?.display).not.toBe("none");
// Ensure menu is hidden
const menu = document.querySelectorAll(".comfy-menu")?.[0];
expect(window.getComputedStyle(menu)?.display).toBe("none");
const isCreate = await onShown(selection);
// Submit form
selection.querySelectorAll("form")[0].submit();
await new Promise(process.nextTick); // wait for promises to resolve
// Wait for start
const s = await starting;
// Ensure login is removed
expect(document.querySelectorAll("#comfy-user-selection")).toHaveLength(0);
expect(window.getComputedStyle(menu)?.display).not.toBe("none");
// Ensure settings + templates are saved
const { api } = require("../../web/scripts/api");
expect(api.createUser).toHaveBeenCalledTimes(+isCreate);
expect(api.storeSettings).toHaveBeenCalledTimes(+isCreate);
expect(api.storeUserData).toHaveBeenCalledTimes(+isCreate);
if (isCreate) {
expect(api.storeUserData).toHaveBeenCalledWith("comfy.templates.json", null, { stringify: false });
expect(s.app.isNewUserSession).toBeTruthy();
} else {
expect(s.app.isNewUserSession).toBeFalsy();
}
return { users, selection, ...s };
}
it("allows user creation if no users", async () => {
const { users } = await testUserScreen((selection) => {
// Ensure we have no users flag added
expect(selection.classList.contains("no-users")).toBeTruthy();
// Enter a username
const input = selection.getElementsByTagName("input")[0];
input.focus();
input.value = "Test User";
return true;
});
expect(users).toStrictEqual({
"Test User!": "Test User",
});
expect(localStorage["Comfy.userId"]).toBe("Test User!");
expect(localStorage["Comfy.userName"]).toBe("Test User");
});
it("allows user creation if no current user but other users", async () => {
const users = {
"Test User 2!": "Test User 2",
};
await testUserScreen((selection) => {
expect(selection.classList.contains("no-users")).toBeFalsy();
// Enter a username
const input = selection.getElementsByTagName("input")[0];
input.focus();
input.value = "Test User 3";
return true;
}, users);
expect(users).toStrictEqual({
"Test User 2!": "Test User 2",
"Test User 3!": "Test User 3",
});
expect(localStorage["Comfy.userId"]).toBe("Test User 3!");
expect(localStorage["Comfy.userName"]).toBe("Test User 3");
});
it("allows user selection if no current user but other users", async () => {
const users = {
"A!": "A",
"B!": "B",
"C!": "C",
};
await testUserScreen((selection) => {
expect(selection.classList.contains("no-users")).toBeFalsy();
// Check user list
const select = selection.getElementsByTagName("select")[0];
const options = select.getElementsByTagName("option");
expect(
[...options]
.filter((o) => !o.disabled)
.reduce((p, n) => {
p[n.getAttribute("value")] = n.textContent;
return p;
}, {})
).toStrictEqual(users);
// Select an option
select.focus();
select.value = options[2].value;
return false;
}, users);
expect(users).toStrictEqual(users);
expect(localStorage["Comfy.userId"]).toBe("B!");
expect(localStorage["Comfy.userName"]).toBe("B");
});
it("doesnt show user screen if current user", async () => {
const starting = start({
resetEnv: true,
userConfig: {
storage: "server",
users: {
"User!": "User",
},
},
localStorage: {
"Comfy.userId": "User!",
"Comfy.userName": "User",
},
});
await new Promise(process.nextTick); // wait for promises to resolve
expectNoUserScreen();
await starting;
});
it("allows user switching", async () => {
const { app } = await start({
resetEnv: true,
userConfig: {
storage: "server",
users: {
"User!": "User",
},
},
localStorage: {
"Comfy.userId": "User!",
"Comfy.userName": "User",
},
});
// cant actually test switching user easily but can check the setting is present
expect(app.ui.settings.settingsLookup["Comfy.SwitchUser"]).toBeTruthy();
});
});
describe("single-user", () => {
it("doesnt show user creation if no default user", async () => {
const { app } = await start({
resetEnv: true,
userConfig: { migrated: false, storage: "server" },
});
expectNoUserScreen();
// It should store the settings
const { api } = require("../../web/scripts/api");
expect(api.storeSettings).toHaveBeenCalledTimes(1);
expect(api.storeUserData).toHaveBeenCalledTimes(1);
expect(api.storeUserData).toHaveBeenCalledWith("comfy.templates.json", null, { stringify: false });
expect(app.isNewUserSession).toBeTruthy();
});
it("doesnt show user creation if default user", async () => {
const { app } = await start({
resetEnv: true,
userConfig: { migrated: true, storage: "server" },
});
expectNoUserScreen();
// It should store the settings
const { api } = require("../../web/scripts/api");
expect(api.storeSettings).toHaveBeenCalledTimes(0);
expect(api.storeUserData).toHaveBeenCalledTimes(0);
expect(app.isNewUserSession).toBeFalsy();
});
it("doesnt allow user switching", async () => {
const { app } = await start({
resetEnv: true,
userConfig: { migrated: true, storage: "server" },
});
expectNoUserScreen();
expect(app.ui.settings.settingsLookup["Comfy.SwitchUser"]).toBeFalsy();
});
});
describe("browser-user", () => {
it("doesnt show user creation if no default user", async () => {
const { app } = await start({
resetEnv: true,
userConfig: { migrated: false, storage: "browser" },
});
expectNoUserScreen();
// It should store the settings
const { api } = require("../../web/scripts/api");
expect(api.storeSettings).toHaveBeenCalledTimes(0);
expect(api.storeUserData).toHaveBeenCalledTimes(0);
expect(app.isNewUserSession).toBeFalsy();
});
it("doesnt show user creation if default user", async () => {
const { app } = await start({
resetEnv: true,
userConfig: { migrated: true, storage: "server" },
});
expectNoUserScreen();
// It should store the settings
const { api } = require("../../web/scripts/api");
expect(api.storeSettings).toHaveBeenCalledTimes(0);
expect(api.storeUserData).toHaveBeenCalledTimes(0);
expect(app.isNewUserSession).toBeFalsy();
});
it("doesnt allow user switching", async () => {
const { app } = await start({
resetEnv: true,
userConfig: { migrated: true, storage: "browser" },
});
expectNoUserScreen();
expect(app.ui.settings.settingsLookup["Comfy.SwitchUser"]).toBeFalsy();
});
});
});

16
tests-ui/utils/index.js

@ -1,10 +1,18 @@
const { mockApi } = require("./setup"); const { mockApi } = require("./setup");
const { Ez } = require("./ezgraph"); const { Ez } = require("./ezgraph");
const lg = require("./litegraph"); const lg = require("./litegraph");
const fs = require("fs");
const path = require("path");
const html = fs.readFileSync(path.resolve(__dirname, "../../web/index.html"))
/** /**
* *
* @param { Parameters<mockApi>[0] & { resetEnv?: boolean, preSetup?(app): Promise<void> } } config * @param { Parameters<typeof mockApi>[0] & {
* resetEnv?: boolean,
* preSetup?(app): Promise<void>,
* localStorage?: Record<string, string>
* } } config
* @returns * @returns
*/ */
export async function start(config = {}) { export async function start(config = {}) {
@ -12,12 +20,18 @@ export async function start(config = {}) {
jest.resetModules(); jest.resetModules();
jest.resetAllMocks(); jest.resetAllMocks();
lg.setup(global); lg.setup(global);
localStorage.clear();
sessionStorage.clear();
} }
Object.assign(localStorage, config.localStorage ?? {});
document.body.innerHTML = html;
mockApi(config); mockApi(config);
const { app } = require("../../web/scripts/app"); const { app } = require("../../web/scripts/app");
config.preSetup?.(app); config.preSetup?.(app);
await app.setup(); await app.setup();
return { ...Ez.graph(app, global["LiteGraph"], global["LGraphCanvas"]), app }; return { ...Ez.graph(app, global["LiteGraph"], global["LGraphCanvas"]), app };
} }

36
tests-ui/utils/setup.js

@ -18,9 +18,21 @@ function* walkSync(dir) {
*/ */
/** /**
* @param { { mockExtensions?: string[], mockNodeDefs?: Record<string, ComfyObjectInfo> } } config * @param {{
* mockExtensions?: string[],
* mockNodeDefs?: Record<string, ComfyObjectInfo>,
* settings?: Record<string, string>
* userConfig?: {storage: "server" | "browser", users?: Record<string, any>, migrated?: boolean },
* userData?: Record<string, any>
* }} config
*/ */
export function mockApi({ mockExtensions, mockNodeDefs } = {}) { export function mockApi(config = {}) {
let { mockExtensions, mockNodeDefs, userConfig, settings, userData } = {
userConfig,
settings: {},
userData: {},
...config,
};
if (!mockExtensions) { if (!mockExtensions) {
mockExtensions = Array.from(walkSync(path.resolve("../web/extensions/core"))) mockExtensions = Array.from(walkSync(path.resolve("../web/extensions/core")))
.filter((x) => x.endsWith(".js")) .filter((x) => x.endsWith(".js"))
@ -40,6 +52,26 @@ export function mockApi({ mockExtensions, mockNodeDefs } = {}) {
getNodeDefs: jest.fn(() => mockNodeDefs), getNodeDefs: jest.fn(() => mockNodeDefs),
init: jest.fn(), init: jest.fn(),
apiURL: jest.fn((x) => "../../web/" + x), apiURL: jest.fn((x) => "../../web/" + x),
createUser: jest.fn((username) => {
if(username in userConfig.users) {
return { status: 400, json: () => "Duplicate" }
}
userConfig.users[username + "!"] = username;
return { status: 200, json: () => username + "!" }
}),
getUserConfig: jest.fn(() => userConfig ?? { storage: "browser", migrated: false }),
getSettings: jest.fn(() => settings),
storeSettings: jest.fn((v) => Object.assign(settings, v)),
getUserData: jest.fn((f) => {
if (f in userData) {
return { status: 200, json: () => userData[f] };
} else {
return { status: 404 };
}
}),
storeUserData: jest.fn((file, data) => {
userData[file] = data;
}),
}; };
jest.mock("../../web/scripts/api", () => ({ jest.mock("../../web/scripts/api", () => ({
get api() { get api() {

9
web/extensions/core/groupNode.js

@ -331,16 +331,17 @@ export class GroupNodeConfig {
getInputConfig(node, inputName, seenInputs, config, extra) { getInputConfig(node, inputName, seenInputs, config, extra) {
let name = node.inputs?.find((inp) => inp.name === inputName)?.label ?? inputName; let name = node.inputs?.find((inp) => inp.name === inputName)?.label ?? inputName;
let key = name;
let prefix = ""; let prefix = "";
// Special handling for primitive to include the title if it is set rather than just "value" // Special handling for primitive to include the title if it is set rather than just "value"
if ((node.type === "PrimitiveNode" && node.title) || name in seenInputs) { if ((node.type === "PrimitiveNode" && node.title) || name in seenInputs) {
prefix = `${node.title ?? node.type} `; prefix = `${node.title ?? node.type} `;
name = `${prefix}${inputName}`; key = name = `${prefix}${inputName}`;
if (name in seenInputs) { if (name in seenInputs) {
name = `${prefix}${seenInputs[name]} ${inputName}`; name = `${prefix}${seenInputs[name]} ${inputName}`;
} }
} }
seenInputs[name] = (seenInputs[name] ?? 1) + 1; seenInputs[key] = (seenInputs[key] ?? 1) + 1;
if (inputName === "seed" || inputName === "noise_seed") { if (inputName === "seed" || inputName === "noise_seed") {
if (!extra) extra = {}; if (!extra) extra = {};
@ -1010,10 +1011,10 @@ export class GroupNodeHandler {
const newName = map[oldName]; const newName = map[oldName];
const widgetIndex = this.node.widgets.findIndex((w) => w.name === newName); const widgetIndex = this.node.widgets.findIndex((w) => w.name === newName);
const mainWidget = this.node.widgets[widgetIndex]; const mainWidget = this.node.widgets[widgetIndex];
if (this.populatePrimitive(node, nodeId, oldName, i, linkedShift)) { if (this.populatePrimitive(node, nodeId, oldName, i, linkedShift) || widgetIndex === -1) {
// Find the inner widget and shift by the number of linked widgets as they will have been removed too // Find the inner widget and shift by the number of linked widgets as they will have been removed too
const innerWidget = this.innerNodes[nodeId].widgets?.find((w) => w.name === oldName); const innerWidget = this.innerNodes[nodeId].widgets?.find((w) => w.name === oldName);
linkedShift += innerWidget.linkedWidgets?.length ?? 0; linkedShift += innerWidget?.linkedWidgets?.length ?? 0;
} }
if (widgetIndex === -1) { if (widgetIndex === -1) {
continue; continue;

62
web/extensions/core/nodeTemplates.js

@ -1,4 +1,5 @@
import { app } from "../../scripts/app.js"; import { app } from "../../scripts/app.js";
import { api } from "../../scripts/api.js";
import { ComfyDialog, $el } from "../../scripts/ui.js"; import { ComfyDialog, $el } from "../../scripts/ui.js";
import { GroupNodeConfig, GroupNodeHandler } from "./groupNode.js"; import { GroupNodeConfig, GroupNodeHandler } from "./groupNode.js";
@ -20,16 +21,20 @@ import { GroupNodeConfig, GroupNodeHandler } from "./groupNode.js";
// Open the manage dialog and Drag and drop elements using the "Name:" label as handle // Open the manage dialog and Drag and drop elements using the "Name:" label as handle
const id = "Comfy.NodeTemplates"; const id = "Comfy.NodeTemplates";
const file = "comfy.templates.json";
class ManageTemplates extends ComfyDialog { class ManageTemplates extends ComfyDialog {
constructor() { constructor() {
super(); super();
this.load().then((v) => {
this.templates = v;
});
this.element.classList.add("comfy-manage-templates"); this.element.classList.add("comfy-manage-templates");
this.templates = this.load();
this.draggedEl = null; this.draggedEl = null;
this.saveVisualCue = null; this.saveVisualCue = null;
this.emptyImg = new Image(); this.emptyImg = new Image();
this.emptyImg.src = 'data:image/gif;base64,R0lGODlhAQABAIAAAAUEBAAAACwAAAAAAQABAAACAkQBADs='; this.emptyImg.src = "data:image/gif;base64,R0lGODlhAQABAIAAAAUEBAAAACwAAAAAAQABAAACAkQBADs=";
this.importInput = $el("input", { this.importInput = $el("input", {
type: "file", type: "file",
@ -67,32 +72,65 @@ class ManageTemplates extends ComfyDialog {
return btns; return btns;
} }
load() { async load() {
const templates = localStorage.getItem(id); let templates = [];
if (templates) { if (app.storageLocation === "server") {
return JSON.parse(templates); if (app.isNewUserSession) {
// New user so migrate existing templates
const json = localStorage.getItem(id);
if (json) {
templates = JSON.parse(json);
}
await api.storeUserData(file, json, { stringify: false });
} else {
const res = await api.getUserData(file);
if (res.status === 200) {
try {
templates = await res.json();
} catch (error) {
}
} else if (res.status !== 404) {
console.error(res.status + " " + res.statusText);
}
}
} else { } else {
return []; const json = localStorage.getItem(id);
if (json) {
templates = JSON.parse(json);
} }
} }
store() { return templates ?? [];
}
async store() {
if(app.storageLocation === "server") {
const templates = JSON.stringify(this.templates, undefined, 4);
localStorage.setItem(id, templates); // Backwards compatibility
try {
await api.storeUserData(file, templates, { stringify: false });
} catch (error) {
console.error(error);
alert(error.message);
}
} else {
localStorage.setItem(id, JSON.stringify(this.templates)); localStorage.setItem(id, JSON.stringify(this.templates));
} }
}
async importAll() { async importAll() {
for (const file of this.importInput.files) { for (const file of this.importInput.files) {
if (file.type === "application/json" || file.name.endsWith(".json")) { if (file.type === "application/json" || file.name.endsWith(".json")) {
const reader = new FileReader(); const reader = new FileReader();
reader.onload = async () => { reader.onload = async () => {
var importFile = JSON.parse(reader.result); const importFile = JSON.parse(reader.result);
if (importFile && importFile?.templates) { if (importFile?.templates) {
for (const template of importFile.templates) { for (const template of importFile.templates) {
if (template?.name && template?.data) { if (template?.name && template?.data) {
this.templates.push(template); this.templates.push(template);
} }
} }
this.store(); await this.store();
} }
}; };
await reader.readAsText(file); await reader.readAsText(file);
@ -159,7 +197,7 @@ class ManageTemplates extends ComfyDialog {
e.currentTarget.style.border = "1px dashed transparent"; e.currentTarget.style.border = "1px dashed transparent";
e.currentTarget.removeAttribute("draggable"); e.currentTarget.removeAttribute("draggable");
// rearrange the elements in the localStorage // rearrange the elements
this.element.querySelectorAll('.tempateManagerRow').forEach((el,i) => { this.element.querySelectorAll('.tempateManagerRow').forEach((el,i) => {
var prev_i = el.dataset.id; var prev_i = el.dataset.id;

30
web/index.html

@ -16,5 +16,33 @@
window.graph = app.graph; window.graph = app.graph;
</script> </script>
</head> </head>
<body class="litegraph"></body> <body class="litegraph">
<div id="comfy-user-selection" class="comfy-user-selection" style="display: none;">
<main class="comfy-user-selection-inner">
<h1>ComfyUI</h1>
<form>
<section>
<label>New user:
<input placeholder="Enter a username" />
</label>
</section>
<div class="comfy-user-existing">
<span class="or-separator">OR</span>
<section>
<label>
Existing user:
<select>
<option hidden disabled selected value> Select a user </option>
</select>
</label>
</section>
</div>
<footer>
<span class="comfy-user-error">&nbsp;</span>
<button class="comfy-btn comfy-user-button-next">Next</button>
</footer>
</form>
</main>
</div>
</body>
</html> </html>

45
web/lib/litegraph.core.js

@ -48,7 +48,7 @@
EVENT_LINK_COLOR: "#A86", EVENT_LINK_COLOR: "#A86",
CONNECTING_LINK_COLOR: "#AFA", CONNECTING_LINK_COLOR: "#AFA",
MAX_NUMBER_OF_NODES: 1000, //avoid infinite loops MAX_NUMBER_OF_NODES: 10000, //avoid infinite loops
DEFAULT_POSITION: [100, 100], //default node position DEFAULT_POSITION: [100, 100], //default node position
VALID_SHAPES: ["default", "box", "round", "card"], //,"circle" VALID_SHAPES: ["default", "box", "round", "card"], //,"circle"
@ -3788,16 +3788,42 @@
/** /**
* returns the bounding of the object, used for rendering purposes * returns the bounding of the object, used for rendering purposes
* bounding is: [topleft_cornerx, topleft_cornery, width, height]
* @method getBounding * @method getBounding
* @return {Float32Array[4]} the total size * @param out {Float32Array[4]?} [optional] a place to store the output, to free garbage
* @param compute_outer {boolean?} [optional] set to true to include the shadow and connection points in the bounding calculation
* @return {Float32Array[4]} the bounding box in format of [topleft_cornerx, topleft_cornery, width, height]
*/ */
LGraphNode.prototype.getBounding = function(out) { LGraphNode.prototype.getBounding = function(out, compute_outer) {
out = out || new Float32Array(4); out = out || new Float32Array(4);
out[0] = this.pos[0] - 4; const nodePos = this.pos;
out[1] = this.pos[1] - LiteGraph.NODE_TITLE_HEIGHT; const isCollapsed = this.flags.collapsed;
out[2] = this.flags.collapsed ? (this._collapsed_width || LiteGraph.NODE_COLLAPSED_WIDTH) : this.size[0] + 4; const nodeSize = this.size;
out[3] = this.flags.collapsed ? LiteGraph.NODE_TITLE_HEIGHT : this.size[1] + LiteGraph.NODE_TITLE_HEIGHT;
let left_offset = 0;
// 1 offset due to how nodes are rendered
let right_offset = 1 ;
let top_offset = 0;
let bottom_offset = 0;
if (compute_outer) {
// 4 offset for collapsed node connection points
left_offset = 4;
// 6 offset for right shadow and collapsed node connection points
right_offset = 6 + left_offset;
// 4 offset for collapsed nodes top connection points
top_offset = 4;
// 5 offset for bottom shadow and collapsed node connection points
bottom_offset = 5 + top_offset;
}
out[0] = nodePos[0] - left_offset;
out[1] = nodePos[1] - LiteGraph.NODE_TITLE_HEIGHT - top_offset;
out[2] = isCollapsed ?
(this._collapsed_width || LiteGraph.NODE_COLLAPSED_WIDTH) + right_offset :
nodeSize[0] + right_offset;
out[3] = isCollapsed ?
LiteGraph.NODE_TITLE_HEIGHT + bottom_offset :
nodeSize[1] + LiteGraph.NODE_TITLE_HEIGHT + bottom_offset;
if (this.onBounding) { if (this.onBounding) {
this.onBounding(out); this.onBounding(out);
@ -7674,7 +7700,7 @@ LGraphNode.prototype.executeAction = function(action)
continue; continue;
} }
if (!overlapBounding(this.visible_area, n.getBounding(temp))) { if (!overlapBounding(this.visible_area, n.getBounding(temp, true))) {
continue; continue;
} //out of the visible area } //out of the visible area
@ -11336,6 +11362,7 @@ LGraphNode.prototype.executeAction = function(action)
name_element.innerText = title; name_element.innerText = title;
var value_element = dialog.querySelector(".value"); var value_element = dialog.querySelector(".value");
value_element.value = value; value_element.value = value;
value_element.select();
var input = value_element; var input = value_element;
input.addEventListener("keydown", function(e) { input.addEventListener("keydown", function(e) {

100
web/scripts/api.js

@ -12,6 +12,13 @@ class ComfyApi extends EventTarget {
} }
fetchApi(route, options) { fetchApi(route, options) {
if (!options) {
options = {};
}
if (!options.headers) {
options.headers = {};
}
options.headers["Comfy-User"] = this.user;
return fetch(this.apiURL(route), options); return fetch(this.apiURL(route), options);
} }
@ -315,6 +322,99 @@ class ComfyApi extends EventTarget {
async interrupt() { async interrupt() {
await this.#postItem("interrupt", null); await this.#postItem("interrupt", null);
} }
/**
* Gets user configuration data and where data should be stored
* @returns { Promise<{ storage: "server" | "browser", users?: Promise<string, unknown>, migrated?: boolean }> }
*/
async getUserConfig() {
return (await this.fetchApi("/users")).json();
}
/**
* Creates a new user
* @param { string } username
* @returns The fetch response
*/
createUser(username) {
return this.fetchApi("/users", {
method: "POST",
headers: {
"Content-Type": "application/json",
},
body: JSON.stringify({ username }),
});
}
/**
* Gets all setting values for the current user
* @returns { Promise<string, unknown> } A dictionary of id -> value
*/
async getSettings() {
return (await this.fetchApi("/settings")).json();
}
/**
* Gets a setting for the current user
* @param { string } id The id of the setting to fetch
* @returns { Promise<unknown> } The setting value
*/
async getSetting(id) {
return (await this.fetchApi(`/settings/${encodeURIComponent(id)}`)).json();
}
/**
* Stores a dictionary of settings for the current user
* @param { Record<string, unknown> } settings Dictionary of setting id -> value to save
* @returns { Promise<void> }
*/
async storeSettings(settings) {
return this.fetchApi(`/settings`, {
method: "POST",
body: JSON.stringify(settings)
});
}
/**
* Stores a setting for the current user
* @param { string } id The id of the setting to update
* @param { unknown } value The value of the setting
* @returns { Promise<void> }
*/
async storeSetting(id, value) {
return this.fetchApi(`/settings/${encodeURIComponent(id)}`, {
method: "POST",
body: JSON.stringify(value)
});
}
/**
* Gets a user data file for the current user
* @param { string } file The name of the userdata file to load
* @param { RequestInit } [options]
* @returns { Promise<unknown> } The fetch response object
*/
async getUserData(file, options) {
return this.fetchApi(`/userdata/${encodeURIComponent(file)}`, options);
}
/**
* Stores a user data file for the current user
* @param { string } file The name of the userdata file to save
* @param { unknown } data The data to save to the file
* @param { RequestInit & { stringify?: boolean, throwOnError?: boolean } } [options]
* @returns { Promise<void> }
*/
async storeUserData(file, data, options = { stringify: true, throwOnError: true }) {
const resp = await this.fetchApi(`/userdata/${encodeURIComponent(file)}`, {
method: "POST",
body: options?.stringify ? JSON.stringify(data) : data,
...options,
});
if (resp.status !== 200) {
throw new Error(`Error storing user data file '${file}': ${resp.status} ${(await resp).statusText}`);
}
}
} }
export const api = new ComfyApi(); export const api = new ComfyApi();

101
web/scripts/app.js

@ -1291,10 +1291,92 @@ export class ComfyApp {
await Promise.all(extensionPromises); await Promise.all(extensionPromises);
} }
async #migrateSettings() {
this.isNewUserSession = true;
// Store all current settings
const settings = Object.keys(this.ui.settings).reduce((p, n) => {
const v = localStorage[`Comfy.Settings.${n}`];
if (v) {
try {
p[n] = JSON.parse(v);
} catch (error) {}
}
return p;
}, {});
await api.storeSettings(settings);
}
async #setUser() {
const userConfig = await api.getUserConfig();
this.storageLocation = userConfig.storage;
if (typeof userConfig.migrated == "boolean") {
// Single user mode migrated true/false for if the default user is created
if (!userConfig.migrated && this.storageLocation === "server") {
// Default user not created yet
await this.#migrateSettings();
}
return;
}
this.multiUserServer = true;
let user = localStorage["Comfy.userId"];
const users = userConfig.users ?? {};
if (!user || !users[user]) {
// This will rarely be hit so move the loading to on demand
const { UserSelectionScreen } = await import("./ui/userSelection.js");
this.ui.menuContainer.style.display = "none";
const { userId, username, created } = await new UserSelectionScreen().show(users, user);
this.ui.menuContainer.style.display = "";
user = userId;
localStorage["Comfy.userName"] = username;
localStorage["Comfy.userId"] = user;
if (created) {
api.user = user;
await this.#migrateSettings();
}
}
api.user = user;
this.ui.settings.addSetting({
id: "Comfy.SwitchUser",
name: "Switch User",
type: (name) => {
let currentUser = localStorage["Comfy.userName"];
if (currentUser) {
currentUser = ` (${currentUser})`;
}
return $el("tr", [
$el("td", [
$el("label", {
textContent: name,
}),
]),
$el("td", [
$el("button", {
textContent: name + (currentUser ?? ""),
onclick: () => {
delete localStorage["Comfy.userId"];
delete localStorage["Comfy.userName"];
window.location.reload();
},
}),
]),
]);
},
});
}
/** /**
* Set up the app on the page * Set up the app on the page
*/ */
async setup() { async setup() {
await this.#setUser();
await this.ui.settings.load();
await this.#loadExtensions(); await this.#loadExtensions();
// Create and mount the LiteGraph in the DOM // Create and mount the LiteGraph in the DOM
@ -1781,10 +1863,19 @@ export class ComfyApp {
} }
} }
output[String(node.id)] = { let node_data = {
inputs, inputs,
class_type: node.comfyClass, class_type: node.comfyClass,
}; };
if (this.ui.settings.getSettingValue("Comfy.DevMode")) {
// Ignored by the backend.
node_data["_meta"] = {
title: node.title,
}
}
output[String(node.id)] = node_data;
} }
} }
@ -2011,12 +2102,8 @@ export class ComfyApp {
async refreshComboInNodes() { async refreshComboInNodes() {
const defs = await api.getNodeDefs(); const defs = await api.getNodeDefs();
for(const nodeId in LiteGraph.registered_node_types) { for (const nodeId in defs) {
const node = LiteGraph.registered_node_types[nodeId]; this.registerNodeDef(nodeId, defs[nodeId]);
const nodeDef = defs[nodeId];
if(!nodeDef) continue;
node.nodeData = nodeDef;
} }
for(let nodeNum in this.graph._nodes) { for(let nodeNum in this.graph._nodes) {

1
web/scripts/domWidget.js

@ -177,6 +177,7 @@ LGraphCanvas.prototype.computeVisibleNodes = function () {
for (const w of node.widgets) { for (const w of node.widgets) {
if (w.element) { if (w.element) {
w.element.hidden = hidden; w.element.hidden = hidden;
w.element.style.display = hidden ? "none" : undefined;
if (hidden) { if (hidden) {
w.options.onHide?.(w); w.options.onHide?.(w);
} }

269
web/scripts/ui.js

@ -1,4 +1,8 @@
import {api} from "./api.js"; import { api } from "./api.js";
import { ComfyDialog as _ComfyDialog } from "./ui/dialog.js";
import { ComfySettingsDialog } from "./ui/settings.js";
export const ComfyDialog = _ComfyDialog;
export function $el(tag, propsOrChildren, children) { export function $el(tag, propsOrChildren, children) {
const split = tag.split("."); const split = tag.split(".");
@ -167,267 +171,6 @@ function dragElement(dragEl, settings) {
} }
} }
export class ComfyDialog {
constructor() {
this.element = $el("div.comfy-modal", {parent: document.body}, [
$el("div.comfy-modal-content", [$el("p", {$: (p) => (this.textElement = p)}), ...this.createButtons()]),
]);
}
createButtons() {
return [
$el("button", {
type: "button",
textContent: "Close",
onclick: () => this.close(),
}),
];
}
close() {
this.element.style.display = "none";
}
show(html) {
if (typeof html === "string") {
this.textElement.innerHTML = html;
} else {
this.textElement.replaceChildren(html);
}
this.element.style.display = "flex";
}
}
class ComfySettingsDialog extends ComfyDialog {
constructor() {
super();
this.element = $el("dialog", {
id: "comfy-settings-dialog",
parent: document.body,
}, [
$el("table.comfy-modal-content.comfy-table", [
$el("caption", {textContent: "Settings"}),
$el("tbody", {$: (tbody) => (this.textElement = tbody)}),
$el("button", {
type: "button",
textContent: "Close",
style: {
cursor: "pointer",
},
onclick: () => {
this.element.close();
},
}),
]),
]);
this.settings = [];
}
getSettingValue(id, defaultValue) {
const settingId = "Comfy.Settings." + id;
const v = localStorage[settingId];
return v == null ? defaultValue : JSON.parse(v);
}
setSettingValue(id, value) {
const settingId = "Comfy.Settings." + id;
localStorage[settingId] = JSON.stringify(value);
}
addSetting({id, name, type, defaultValue, onChange, attrs = {}, tooltip = "", options = undefined}) {
if (!id) {
throw new Error("Settings must have an ID");
}
if (this.settings.find((s) => s.id === id)) {
throw new Error(`Setting ${id} of type ${type} must have a unique ID.`);
}
const settingId = `Comfy.Settings.${id}`;
const v = localStorage[settingId];
let value = v == null ? defaultValue : JSON.parse(v);
// Trigger initial setting of value
if (onChange) {
onChange(value, undefined);
}
this.settings.push({
render: () => {
const setter = (v) => {
if (onChange) {
onChange(v, value);
}
localStorage[settingId] = JSON.stringify(v);
value = v;
};
value = this.getSettingValue(id, defaultValue);
let element;
const htmlID = id.replaceAll(".", "-");
const labelCell = $el("td", [
$el("label", {
for: htmlID,
classList: [tooltip !== "" ? "comfy-tooltip-indicator" : ""],
textContent: name,
})
]);
if (typeof type === "function") {
element = type(name, setter, value, attrs);
} else {
switch (type) {
case "boolean":
element = $el("tr", [
labelCell,
$el("td", [
$el("input", {
id: htmlID,
type: "checkbox",
checked: value,
onchange: (event) => {
const isChecked = event.target.checked;
if (onChange !== undefined) {
onChange(isChecked)
}
this.setSettingValue(id, isChecked);
},
}),
]),
])
break;
case "number":
element = $el("tr", [
labelCell,
$el("td", [
$el("input", {
type,
value,
id: htmlID,
oninput: (e) => {
setter(e.target.value);
},
...attrs
}),
]),
]);
break;
case "slider":
element = $el("tr", [
labelCell,
$el("td", [
$el("div", {
style: {
display: "grid",
gridAutoFlow: "column",
},
}, [
$el("input", {
...attrs,
value,
type: "range",
oninput: (e) => {
setter(e.target.value);
e.target.nextElementSibling.value = e.target.value;
},
}),
$el("input", {
...attrs,
value,
id: htmlID,
type: "number",
style: {maxWidth: "4rem"},
oninput: (e) => {
setter(e.target.value);
e.target.previousElementSibling.value = e.target.value;
},
}),
]),
]),
]);
break;
case "combo":
element = $el("tr", [
labelCell,
$el("td", [
$el(
"select",
{
oninput: (e) => {
setter(e.target.value);
},
},
(typeof options === "function" ? options(value) : options || []).map((opt) => {
if (typeof opt === "string") {
opt = { text: opt };
}
const v = opt.value ?? opt.text;
return $el("option", {
value: v,
textContent: opt.text,
selected: value + "" === v + "",
});
})
),
]),
]);
break;
case "text":
default:
if (type !== "text") {
console.warn(`Unsupported setting type '${type}, defaulting to text`);
}
element = $el("tr", [
labelCell,
$el("td", [
$el("input", {
value,
id: htmlID,
oninput: (e) => {
setter(e.target.value);
},
...attrs,
}),
]),
]);
break;
}
}
if (tooltip) {
element.title = tooltip;
}
return element;
},
});
const self = this;
return {
get value() {
return self.getSettingValue(id, defaultValue);
},
set value(v) {
self.setSettingValue(id, v);
},
};
}
show() {
this.textElement.replaceChildren(
$el("tr", {
style: {display: "none"},
}, [
$el("th"),
$el("th", {style: {width: "33%"}})
]),
...this.settings.map((s) => s.render()),
)
this.element.showModal();
}
}
class ComfyList { class ComfyList {
#type; #type;
#text; #text;
@ -526,7 +269,7 @@ export class ComfyUI {
constructor(app) { constructor(app) {
this.app = app; this.app = app;
this.dialog = new ComfyDialog(); this.dialog = new ComfyDialog();
this.settings = new ComfySettingsDialog(); this.settings = new ComfySettingsDialog(app);
this.batchCount = 1; this.batchCount = 1;
this.lastQueueSize = 0; this.lastQueueSize = 0;

32
web/scripts/ui/dialog.js

@ -0,0 +1,32 @@
import { $el } from "../ui.js";
export class ComfyDialog {
constructor() {
this.element = $el("div.comfy-modal", { parent: document.body }, [
$el("div.comfy-modal-content", [$el("p", { $: (p) => (this.textElement = p) }), ...this.createButtons()]),
]);
}
createButtons() {
return [
$el("button", {
type: "button",
textContent: "Close",
onclick: () => this.close(),
}),
];
}
close() {
this.element.style.display = "none";
}
show(html) {
if (typeof html === "string") {
this.textElement.innerHTML = html;
} else {
this.textElement.replaceChildren(html);
}
this.element.style.display = "flex";
}
}

307
web/scripts/ui/settings.js

@ -0,0 +1,307 @@
import { $el } from "../ui.js";
import { api } from "../api.js";
import { ComfyDialog } from "./dialog.js";
export class ComfySettingsDialog extends ComfyDialog {
constructor(app) {
super();
this.app = app;
this.settingsValues = {};
this.settingsLookup = {};
this.element = $el(
"dialog",
{
id: "comfy-settings-dialog",
parent: document.body,
},
[
$el("table.comfy-modal-content.comfy-table", [
$el("caption", { textContent: "Settings" }),
$el("tbody", { $: (tbody) => (this.textElement = tbody) }),
$el("button", {
type: "button",
textContent: "Close",
style: {
cursor: "pointer",
},
onclick: () => {
this.element.close();
},
}),
]),
]
);
}
get settings() {
return Object.values(this.settingsLookup);
}
async load() {
if (this.app.storageLocation === "browser") {
this.settingsValues = localStorage;
} else {
this.settingsValues = await api.getSettings();
}
// Trigger onChange for any settings added before load
for (const id in this.settingsLookup) {
this.settingsLookup[id].onChange?.(this.settingsValues[this.getId(id)]);
}
}
getId(id) {
if (this.app.storageLocation === "browser") {
id = "Comfy.Settings." + id;
}
return id;
}
getSettingValue(id, defaultValue) {
let value = this.settingsValues[this.getId(id)];
if(value != null) {
if(this.app.storageLocation === "browser") {
try {
value = JSON.parse(value);
} catch (error) {
}
}
}
return value ?? defaultValue;
}
async setSettingValueAsync(id, value) {
const json = JSON.stringify(value);
localStorage["Comfy.Settings." + id] = json; // backwards compatibility for extensions keep setting in storage
let oldValue = this.getSettingValue(id, undefined);
this.settingsValues[this.getId(id)] = value;
if (id in this.settingsLookup) {
this.settingsLookup[id].onChange?.(value, oldValue);
}
await api.storeSetting(id, value);
}
setSettingValue(id, value) {
this.setSettingValueAsync(id, value).catch((err) => {
alert(`Error saving setting '${id}'`);
console.error(err);
});
}
addSetting({ id, name, type, defaultValue, onChange, attrs = {}, tooltip = "", options = undefined }) {
if (!id) {
throw new Error("Settings must have an ID");
}
if (id in this.settingsLookup) {
throw new Error(`Setting ${id} of type ${type} must have a unique ID.`);
}
let skipOnChange = false;
let value = this.getSettingValue(id);
if (value == null) {
if (this.app.isNewUserSession) {
// Check if we have a localStorage value but not a setting value and we are a new user
const localValue = localStorage["Comfy.Settings." + id];
if (localValue) {
value = JSON.parse(localValue);
this.setSettingValue(id, value); // Store on the server
}
}
if (value == null) {
value = defaultValue;
}
}
// Trigger initial setting of value
if (!skipOnChange) {
onChange?.(value, undefined);
}
this.settingsLookup[id] = {
id,
onChange,
name,
render: () => {
const setter = (v) => {
if (onChange) {
onChange(v, value);
}
this.setSettingValue(id, v);
value = v;
};
value = this.getSettingValue(id, defaultValue);
let element;
const htmlID = id.replaceAll(".", "-");
const labelCell = $el("td", [
$el("label", {
for: htmlID,
classList: [tooltip !== "" ? "comfy-tooltip-indicator" : ""],
textContent: name,
}),
]);
if (typeof type === "function") {
element = type(name, setter, value, attrs);
} else {
switch (type) {
case "boolean":
element = $el("tr", [
labelCell,
$el("td", [
$el("input", {
id: htmlID,
type: "checkbox",
checked: value,
onchange: (event) => {
const isChecked = event.target.checked;
if (onChange !== undefined) {
onChange(isChecked);
}
this.setSettingValue(id, isChecked);
},
}),
]),
]);
break;
case "number":
element = $el("tr", [
labelCell,
$el("td", [
$el("input", {
type,
value,
id: htmlID,
oninput: (e) => {
setter(e.target.value);
},
...attrs,
}),
]),
]);
break;
case "slider":
element = $el("tr", [
labelCell,
$el("td", [
$el(
"div",
{
style: {
display: "grid",
gridAutoFlow: "column",
},
},
[
$el("input", {
...attrs,
value,
type: "range",
oninput: (e) => {
setter(e.target.value);
e.target.nextElementSibling.value = e.target.value;
},
}),
$el("input", {
...attrs,
value,
id: htmlID,
type: "number",
style: { maxWidth: "4rem" },
oninput: (e) => {
setter(e.target.value);
e.target.previousElementSibling.value = e.target.value;
},
}),
]
),
]),
]);
break;
case "combo":
element = $el("tr", [
labelCell,
$el("td", [
$el(
"select",
{
oninput: (e) => {
setter(e.target.value);
},
},
(typeof options === "function" ? options(value) : options || []).map((opt) => {
if (typeof opt === "string") {
opt = { text: opt };
}
const v = opt.value ?? opt.text;
return $el("option", {
value: v,
textContent: opt.text,
selected: value + "" === v + "",
});
})
),
]),
]);
break;
case "text":
default:
if (type !== "text") {
console.warn(`Unsupported setting type '${type}, defaulting to text`);
}
element = $el("tr", [
labelCell,
$el("td", [
$el("input", {
value,
id: htmlID,
oninput: (e) => {
setter(e.target.value);
},
...attrs,
}),
]),
]);
break;
}
}
if (tooltip) {
element.title = tooltip;
}
return element;
},
};
const self = this;
return {
get value() {
return self.getSettingValue(id, defaultValue);
},
set value(v) {
self.setSettingValue(id, v);
},
};
}
show() {
this.textElement.replaceChildren(
$el(
"tr",
{
style: { display: "none" },
},
[$el("th"), $el("th", { style: { width: "33%" } })]
),
...this.settings.sort((a, b) => a.name.localeCompare(b.name)).map((s) => s.render())
);
this.element.showModal();
}
}

34
web/scripts/ui/spinner.css

@ -0,0 +1,34 @@
.lds-ring {
display: inline-block;
position: relative;
width: 1em;
height: 1em;
}
.lds-ring div {
box-sizing: border-box;
display: block;
position: absolute;
width: 100%;
height: 100%;
border: 0.15em solid #fff;
border-radius: 50%;
animation: lds-ring 1.2s cubic-bezier(0.5, 0, 0.5, 1) infinite;
border-color: #fff transparent transparent transparent;
}
.lds-ring div:nth-child(1) {
animation-delay: -0.45s;
}
.lds-ring div:nth-child(2) {
animation-delay: -0.3s;
}
.lds-ring div:nth-child(3) {
animation-delay: -0.15s;
}
@keyframes lds-ring {
0% {
transform: rotate(0deg);
}
100% {
transform: rotate(360deg);
}
}

9
web/scripts/ui/spinner.js

@ -0,0 +1,9 @@
import { addStylesheet } from "../utils.js";
addStylesheet(import.meta.url);
export function createSpinner() {
const div = document.createElement("div");
div.innerHTML = `<div class="lds-ring"><div></div><div></div><div></div><div></div></div>`;
return div.firstElementChild;
}

135
web/scripts/ui/userSelection.css

@ -0,0 +1,135 @@
.comfy-user-selection {
width: 100vw;
height: 100vh;
position: absolute;
top: 0;
left: 0;
z-index: 999;
display: flex;
align-items: center;
justify-content: center;
font-family: sans-serif;
background: linear-gradient(var(--tr-even-bg-color), var(--tr-odd-bg-color));
}
.comfy-user-selection-inner {
background: var(--comfy-menu-bg);
margin-top: -30vh;
padding: 20px 40px;
border-radius: 10px;
min-width: 365px;
position: relative;
box-shadow: 0 0 20px rgba(0, 0, 0, 0.3);
}
.comfy-user-selection-inner form {
width: 100%;
display: flex;
flex-direction: column;
align-items: center;
}
.comfy-user-selection-inner h1 {
margin: 10px 0 30px 0;
font-weight: normal;
}
.comfy-user-selection-inner label {
display: flex;
flex-direction: column;
width: 100%;
}
.comfy-user-selection input,
.comfy-user-selection select {
background-color: var(--comfy-input-bg);
color: var(--input-text);
border: 0;
border-radius: 5px;
padding: 5px;
margin-top: 10px;
}
.comfy-user-selection input::placeholder {
color: var(--descrip-text);
opacity: 1;
}
.comfy-user-existing {
width: 100%;
}
.no-users .comfy-user-existing {
display: none;
}
.comfy-user-selection-inner .or-separator {
margin: 10px 0;
padding: 10px;
display: block;
text-align: center;
width: 100%;
color: var(--descrip-text);
}
.comfy-user-selection-inner .or-separator {
overflow: hidden;
text-align: center;
margin-left: -10px;
}
.comfy-user-selection-inner .or-separator::before,
.comfy-user-selection-inner .or-separator::after {
content: "";
background-color: var(--border-color);
position: relative;
height: 1px;
vertical-align: middle;
display: inline-block;
width: calc(50% - 20px);
top: -1px;
}
.comfy-user-selection-inner .or-separator::before {
right: 10px;
margin-left: -50%;
}
.comfy-user-selection-inner .or-separator::after {
left: 10px;
margin-right: -50%;
}
.comfy-user-selection-inner section {
width: 100%;
padding: 10px;
margin: -10px;
transition: background-color 0.2s;
}
.comfy-user-selection-inner section.selected {
background: var(--border-color);
border-radius: 5px;
}
.comfy-user-selection-inner footer {
display: flex;
flex-direction: column;
align-items: center;
margin-top: 20px;
}
.comfy-user-selection-inner .comfy-user-error {
color: var(--error-text);
margin-bottom: 10px;
}
.comfy-user-button-next {
font-size: 16px;
padding: 6px 10px;
width: 100px;
display: flex;
gap: 5px;
align-items: center;
justify-content: center;
}

114
web/scripts/ui/userSelection.js

@ -0,0 +1,114 @@
import { api } from "../api.js";
import { $el } from "../ui.js";
import { addStylesheet } from "../utils.js";
import { createSpinner } from "./spinner.js";
export class UserSelectionScreen {
async show(users, user) {
// This will rarely be hit so move the loading to on demand
await addStylesheet(import.meta.url);
const userSelection = document.getElementById("comfy-user-selection");
userSelection.style.display = "";
return new Promise((resolve) => {
const input = userSelection.getElementsByTagName("input")[0];
const select = userSelection.getElementsByTagName("select")[0];
const inputSection = input.closest("section");
const selectSection = select.closest("section");
const form = userSelection.getElementsByTagName("form")[0];
const error = userSelection.getElementsByClassName("comfy-user-error")[0];
const button = userSelection.getElementsByClassName("comfy-user-button-next")[0];
let inputActive = null;
input.addEventListener("focus", () => {
inputSection.classList.add("selected");
selectSection.classList.remove("selected");
inputActive = true;
});
select.addEventListener("focus", () => {
inputSection.classList.remove("selected");
selectSection.classList.add("selected");
inputActive = false;
select.style.color = "";
});
select.addEventListener("blur", () => {
if (!select.value) {
select.style.color = "var(--descrip-text)";
}
});
form.addEventListener("submit", async (e) => {
e.preventDefault();
if (inputActive == null) {
error.textContent = "Please enter a username or select an existing user.";
} else if (inputActive) {
const username = input.value.trim();
if (!username) {
error.textContent = "Please enter a username.";
return;
}
// Create new user
input.disabled = select.disabled = input.readonly = select.readonly = true;
const spinner = createSpinner();
button.prepend(spinner);
try {
const resp = await api.createUser(username);
if (resp.status >= 300) {
let message = "Error creating user: " + resp.status + " " + resp.statusText;
try {
const res = await resp.json();
if(res.error) {
message = res.error;
}
} catch (error) {
}
throw new Error(message);
}
resolve({ username, userId: await resp.json(), created: true });
} catch (err) {
spinner.remove();
error.textContent = err.message ?? err.statusText ?? err ?? "An unknown error occurred.";
input.disabled = select.disabled = input.readonly = select.readonly = false;
return;
}
} else if (!select.value) {
error.textContent = "Please select an existing user.";
return;
} else {
resolve({ username: users[select.value], userId: select.value, created: false });
}
});
if (user) {
const name = localStorage["Comfy.userName"];
if (name) {
input.value = name;
}
}
if (input.value) {
// Focus the input, do this separately as sometimes browsers like to fill in the value
input.focus();
}
const userIds = Object.keys(users ?? {});
if (userIds.length) {
for (const u of userIds) {
$el("option", { textContent: users[u], value: u, parent: select });
}
select.style.color = "var(--descrip-text)";
if (select.value) {
// Focus the select, do this separately as sometimes browsers like to fill in the value
select.focus();
}
} else {
userSelection.classList.add("no-users");
input.focus();
}
}).then((r) => {
userSelection.remove();
return r;
});
}
}

21
web/scripts/utils.js

@ -1,3 +1,5 @@
import { $el } from "./ui.js";
// Simple date formatter // Simple date formatter
const parts = { const parts = {
d: (d) => d.getDate(), d: (d) => d.getDate(),
@ -65,3 +67,22 @@ export function applyTextReplacements(app, value) {
return ((widget.value ?? "") + "").replaceAll(/\/|\\/g, "_"); return ((widget.value ?? "") + "").replaceAll(/\/|\\/g, "_");
}); });
} }
export async function addStylesheet(urlOrFile, relativeTo) {
return new Promise((res, rej) => {
let url;
if (urlOrFile.endsWith(".js")) {
url = urlOrFile.substr(0, urlOrFile.length - 2) + "css";
} else {
url = new URL(urlOrFile, relativeTo ?? `${window.location.protocol}//${window.location.host}`).toString();
}
$el("link", {
parent: document.head,
rel: "stylesheet",
type: "text/css",
href: url,
onload: res,
onerror: rej,
});
});
}

2
web/style.css

@ -121,6 +121,7 @@ body {
width: 100%; width: 100%;
} }
.comfy-btn,
.comfy-menu > button, .comfy-menu > button,
.comfy-menu-btns button, .comfy-menu-btns button,
.comfy-menu .comfy-list button, .comfy-menu .comfy-list button,
@ -133,6 +134,7 @@ body {
margin-top: 2px; margin-top: 2px;
} }
.comfy-btn:hover:not(:disabled),
.comfy-menu > button:hover, .comfy-menu > button:hover,
.comfy-menu-btns button:hover, .comfy-menu-btns button:hover,
.comfy-menu .comfy-list button:hover, .comfy-menu .comfy-list button:hover,

Loading…
Cancel
Save