diff --git a/README.md b/README.md
index ff3ab642..a94a212a 100644
--- a/README.md
+++ b/README.md
@@ -11,7 +11,7 @@ This ui will let you design and execute advanced stable diffusion pipelines usin
## Features
- Nodes/graph/flowchart interface to experiment and create complex Stable Diffusion workflows without needing to code anything.
-- Fully supports SD1.x, SD2.x, [SDXL](https://comfyanonymous.github.io/ComfyUI_examples/sdxl/) and [Stable Video Diffusion](https://comfyanonymous.github.io/ComfyUI_examples/video/)
+- Fully supports SD1.x, SD2.x, [SDXL](https://comfyanonymous.github.io/ComfyUI_examples/sdxl/), [Stable Video Diffusion](https://comfyanonymous.github.io/ComfyUI_examples/video/) and [Stable Cascade](https://comfyanonymous.github.io/ComfyUI_examples/stable_cascade/)
- Asynchronous Queue system
- Many optimizations: Only re-executes the parts of the workflow that change between executions.
- Command line option: ```--lowvram``` to make it work on GPUs with less than 3GB vram (enabled automatically on GPUs with low vram)
diff --git a/comfy/clip_model.py b/comfy/clip_model.py
index 09e7bbca..9b82a246 100644
--- a/comfy/clip_model.py
+++ b/comfy/clip_model.py
@@ -97,7 +97,7 @@ class CLIPTextModel_(torch.nn.Module):
x = self.embeddings(input_tokens)
mask = None
if attention_mask is not None:
- mask = 1.0 - attention_mask.to(x.dtype).unsqueeze(1).unsqueeze(1).expand(attention_mask.shape[0], 1, attention_mask.shape[-1], attention_mask.shape[-1])
+ mask = 1.0 - attention_mask.to(x.dtype).reshape((attention_mask.shape[0], 1, -1, attention_mask.shape[-1])).expand(attention_mask.shape[0], 1, attention_mask.shape[-1], attention_mask.shape[-1])
mask = mask.masked_fill(mask.to(torch.bool), float("-inf"))
causal_mask = torch.empty(x.shape[1], x.shape[1], dtype=x.dtype, device=x.device).fill_(float("-inf")).triu_(1)
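
Aside: the reshape above generalizes the old `unsqueeze`-based path so that both a 2D `[batch, tokens]` padding mask and an already-expanded mask broadcast to the additive `[batch, 1, tokens, tokens]` form that is combined with the causal mask. A minimal sketch of the same transformation (shapes here are hypothetical):

```python
import torch

# hypothetical padding mask: 1 = keep, 0 = pad (batch=2, tokens=4)
attention_mask = torch.tensor([[1., 1., 1., 0.],
                               [1., 1., 0., 0.]])
b, t = attention_mask.shape

# same reshape/expand as the patched line: -> [batch, 1, tokens, tokens]
mask = 1.0 - attention_mask.reshape((b, 1, -1, t)).expand(b, 1, t, t)
mask = mask.masked_fill(mask.to(torch.bool), float("-inf"))
print(mask.shape)  # torch.Size([2, 1, 4, 4])
```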
diff --git a/comfy/controlnet.py b/comfy/controlnet.py
index d9d990a7..41619758 100644
--- a/comfy/controlnet.py
+++ b/comfy/controlnet.py
@@ -318,9 +318,10 @@ def load_controlnet(ckpt_path, model=None):
return ControlLora(controlnet_data)
controlnet_config = None
+ supported_inference_dtypes = None
+
if "controlnet_cond_embedding.conv_in.weight" in controlnet_data: #diffusers format
- unet_dtype = comfy.model_management.unet_dtype()
- controlnet_config = comfy.model_detection.unet_config_from_diffusers_unet(controlnet_data, unet_dtype)
+ controlnet_config = comfy.model_detection.unet_config_from_diffusers_unet(controlnet_data)
diffusers_keys = comfy.utils.unet_to_diffusers(controlnet_config)
diffusers_keys["controlnet_mid_block.weight"] = "middle_block_out.0.weight"
diffusers_keys["controlnet_mid_block.bias"] = "middle_block_out.0.bias"
@@ -380,12 +381,20 @@ def load_controlnet(ckpt_path, model=None):
return net
if controlnet_config is None:
- unet_dtype = comfy.model_management.unet_dtype()
- controlnet_config = comfy.model_detection.model_config_from_unet(controlnet_data, prefix, unet_dtype, True).unet_config
+ model_config = comfy.model_detection.model_config_from_unet(controlnet_data, prefix, True)
+ supported_inference_dtypes = model_config.supported_inference_dtypes
+ controlnet_config = model_config.unet_config
+
load_device = comfy.model_management.get_torch_device()
+ if supported_inference_dtypes is None:
+ unet_dtype = comfy.model_management.unet_dtype()
+ else:
+ unet_dtype = comfy.model_management.unet_dtype(supported_dtypes=supported_inference_dtypes)
+
manual_cast_dtype = comfy.model_management.unet_manual_cast(unet_dtype, load_device)
if manual_cast_dtype is not None:
controlnet_config["operations"] = comfy.ops.manual_cast
+ controlnet_config["dtype"] = unet_dtype
controlnet_config.pop("out_channels")
controlnet_config["hint_channels"] = controlnet_data["{}input_hint_block.0.weight".format(prefix)].shape[1]
control_model = comfy.cldm.cldm.ControlNet(**controlnet_config)
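
Aside: the rewritten branch defers the dtype decision until after model detection, asking the detected model config for its supported inference dtypes and only falling back to the global default when none are reported. A hedged sketch of that decision with simplified stand-in names (`pick_unet_dtype` and its arguments are hypothetical, not ComfyUI API):

```python
import torch

def pick_unet_dtype(supported_inference_dtypes, default_dtype=torch.float16):
    # Prefer the dtypes the detected model config reports; otherwise fall
    # back to a global default (ComfyUI uses comfy.model_management.unet_dtype()).
    if supported_inference_dtypes is None:
        return default_dtype
    return supported_inference_dtypes[0]

print(pick_unet_dtype(None))                             # torch.float16
print(pick_unet_dtype([torch.bfloat16, torch.float32]))  # torch.bfloat16
```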
diff --git a/comfy/gligen.py b/comfy/gligen.py
index 71892dfb..59252276 100644
--- a/comfy/gligen.py
+++ b/comfy/gligen.py
@@ -2,7 +2,8 @@ import torch
from torch import nn
from .ldm.modules.attention import CrossAttention
from inspect import isfunction
-
+import comfy.ops
+ops = comfy.ops.manual_cast
def exists(val):
return val is not None
@@ -22,7 +23,7 @@ def default(val, d):
class GEGLU(nn.Module):
def __init__(self, dim_in, dim_out):
super().__init__()
- self.proj = nn.Linear(dim_in, dim_out * 2)
+ self.proj = ops.Linear(dim_in, dim_out * 2)
def forward(self, x):
x, gate = self.proj(x).chunk(2, dim=-1)
@@ -35,14 +36,14 @@ class FeedForward(nn.Module):
inner_dim = int(dim * mult)
dim_out = default(dim_out, dim)
project_in = nn.Sequential(
- nn.Linear(dim, inner_dim),
+ ops.Linear(dim, inner_dim),
nn.GELU()
) if not glu else GEGLU(dim, inner_dim)
self.net = nn.Sequential(
project_in,
nn.Dropout(dropout),
- nn.Linear(inner_dim, dim_out)
+ ops.Linear(inner_dim, dim_out)
)
def forward(self, x):
@@ -57,11 +58,12 @@ class GatedCrossAttentionDense(nn.Module):
query_dim=query_dim,
context_dim=context_dim,
heads=n_heads,
- dim_head=d_head)
+ dim_head=d_head,
+ operations=ops)
self.ff = FeedForward(query_dim, glu=True)
- self.norm1 = nn.LayerNorm(query_dim)
- self.norm2 = nn.LayerNorm(query_dim)
+ self.norm1 = ops.LayerNorm(query_dim)
+ self.norm2 = ops.LayerNorm(query_dim)
self.register_parameter('alpha_attn', nn.Parameter(torch.tensor(0.)))
self.register_parameter('alpha_dense', nn.Parameter(torch.tensor(0.)))
@@ -87,17 +89,18 @@ class GatedSelfAttentionDense(nn.Module):
# we need a linear projection since we need cat visual feature and obj
# feature
- self.linear = nn.Linear(context_dim, query_dim)
+ self.linear = ops.Linear(context_dim, query_dim)
self.attn = CrossAttention(
query_dim=query_dim,
context_dim=query_dim,
heads=n_heads,
- dim_head=d_head)
+ dim_head=d_head,
+ operations=ops)
self.ff = FeedForward(query_dim, glu=True)
- self.norm1 = nn.LayerNorm(query_dim)
- self.norm2 = nn.LayerNorm(query_dim)
+ self.norm1 = ops.LayerNorm(query_dim)
+ self.norm2 = ops.LayerNorm(query_dim)
self.register_parameter('alpha_attn', nn.Parameter(torch.tensor(0.)))
self.register_parameter('alpha_dense', nn.Parameter(torch.tensor(0.)))
@@ -126,14 +129,14 @@ class GatedSelfAttentionDense2(nn.Module):
# we need a linear projection since we need cat visual feature and obj
# feature
- self.linear = nn.Linear(context_dim, query_dim)
+ self.linear = ops.Linear(context_dim, query_dim)
self.attn = CrossAttention(
- query_dim=query_dim, context_dim=query_dim, dim_head=d_head)
+ query_dim=query_dim, context_dim=query_dim, dim_head=d_head, operations=ops)
self.ff = FeedForward(query_dim, glu=True)
- self.norm1 = nn.LayerNorm(query_dim)
- self.norm2 = nn.LayerNorm(query_dim)
+ self.norm1 = ops.LayerNorm(query_dim)
+ self.norm2 = ops.LayerNorm(query_dim)
self.register_parameter('alpha_attn', nn.Parameter(torch.tensor(0.)))
self.register_parameter('alpha_dense', nn.Parameter(torch.tensor(0.)))
@@ -201,11 +204,11 @@ class PositionNet(nn.Module):
self.position_dim = fourier_freqs * 2 * 4 # 2 is sin&cos, 4 is xyxy
self.linears = nn.Sequential(
- nn.Linear(self.in_dim + self.position_dim, 512),
+ ops.Linear(self.in_dim + self.position_dim, 512),
nn.SiLU(),
- nn.Linear(512, 512),
+ ops.Linear(512, 512),
nn.SiLU(),
- nn.Linear(512, out_dim),
+ ops.Linear(512, out_dim),
)
self.null_positive_feature = torch.nn.Parameter(
@@ -215,16 +218,15 @@ class PositionNet(nn.Module):
def forward(self, boxes, masks, positive_embeddings):
B, N, _ = boxes.shape
- dtype = self.linears[0].weight.dtype
- masks = masks.unsqueeze(-1).to(dtype)
- positive_embeddings = positive_embeddings.to(dtype)
+ masks = masks.unsqueeze(-1)
+ positive_embeddings = positive_embeddings
# embedding position (it may include padding as a placeholder)
- xyxy_embedding = self.fourier_embedder(boxes.to(dtype)) # B*N*4 --> B*N*C
+ xyxy_embedding = self.fourier_embedder(boxes) # B*N*4 --> B*N*C
# learnable null embedding
- positive_null = self.null_positive_feature.view(1, 1, -1)
- xyxy_null = self.null_position_feature.view(1, 1, -1)
+ positive_null = self.null_positive_feature.to(device=boxes.device, dtype=boxes.dtype).view(1, 1, -1)
+ xyxy_null = self.null_position_feature.to(device=boxes.device, dtype=boxes.dtype).view(1, 1, -1)
# replace padding with learnable null embedding
positive_embeddings = positive_embeddings * \
@@ -251,7 +253,7 @@ class Gligen(nn.Module):
def func(x, extra_options):
key = extra_options["transformer_index"]
module = self.module_list[key]
- return module(x, objs)
+ return module(x, objs.to(device=x.device, dtype=x.dtype))
return func
def set_position(self, latent_image_shape, position_params, device):
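
Aside: the point of swapping `nn.Linear`/`nn.LayerNorm` for their `comfy.ops.manual_cast` equivalents is that weights can stay in their stored dtype while inputs arrive in another; the op casts parameters to the input's dtype and device at call time, which is why the explicit `.to(dtype)` calls in `PositionNet.forward` could be dropped. A toy illustration of the idea (not ComfyUI's actual `manual_cast` implementation):

```python
import torch
import torch.nn as nn

class CastLinear(nn.Linear):
    """Toy 'manual cast' linear: parameters are cast to the input's
    dtype/device on every call instead of being converted up front."""
    def forward(self, x):
        w = self.weight.to(device=x.device, dtype=x.dtype)
        b = self.bias.to(device=x.device, dtype=x.dtype) if self.bias is not None else None
        return nn.functional.linear(x, w, b)

layer = CastLinear(4, 2)                        # parameters created in float32
x = torch.randn(3, 4, dtype=torch.bfloat16)
print(layer(x).dtype)                           # torch.bfloat16
```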
diff --git a/comfy/latent_formats.py b/comfy/latent_formats.py
index 2252a075..03fd59e3 100644
--- a/comfy/latent_formats.py
+++ b/comfy/latent_formats.py
@@ -37,3 +37,41 @@ class SDXL(LatentFormat):
class SD_X4(LatentFormat):
def __init__(self):
self.scale_factor = 0.08333
+ self.latent_rgb_factors = [
+ [-0.2340, -0.3863, -0.3257],
+ [ 0.0994, 0.0885, -0.0908],
+ [-0.2833, -0.2349, -0.3741],
+ [ 0.2523, -0.0055, -0.1651]
+ ]
+
+class SC_Prior(LatentFormat):
+ def __init__(self):
+ self.scale_factor = 1.0
+ self.latent_rgb_factors = [
+ [-0.0326, -0.0204, -0.0127],
+ [-0.1592, -0.0427, 0.0216],
+ [ 0.0873, 0.0638, -0.0020],
+ [-0.0602, 0.0442, 0.1304],
+ [ 0.0800, -0.0313, -0.1796],
+ [-0.0810, -0.0638, -0.1581],
+ [ 0.1791, 0.1180, 0.0967],
+ [ 0.0740, 0.1416, 0.0432],
+ [-0.1745, -0.1888, -0.1373],
+ [ 0.2412, 0.1577, 0.0928],
+ [ 0.1908, 0.0998, 0.0682],
+ [ 0.0209, 0.0365, -0.0092],
+ [ 0.0448, -0.0650, -0.1728],
+ [-0.1658, -0.1045, -0.1308],
+ [ 0.0542, 0.1545, 0.1325],
+ [-0.0352, -0.1672, -0.2541]
+ ]
+
+class SC_B(LatentFormat):
+ def __init__(self):
+ self.scale_factor = 1.0
+ self.latent_rgb_factors = [
+ [ 0.1121, 0.2006, 0.1023],
+ [-0.2093, -0.0222, -0.0195],
+ [-0.3087, -0.1535, 0.0366],
+ [ 0.0290, -0.1574, -0.4078]
+ ]
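
Aside: `latent_rgb_factors` is the per-channel linear map used for quick latent previews; each latent channel contributes one RGB triple. A minimal sketch of applying such factors (the latent shape here is hypothetical):

```python
import torch

latent = torch.randn(4, 8, 8)  # hypothetical Stage B latent: 4 channels, 8x8
latent_rgb_factors = torch.tensor([
    [ 0.1121,  0.2006,  0.1023],
    [-0.2093, -0.0222, -0.0195],
    [-0.3087, -0.1535,  0.0366],
    [ 0.0290, -0.1574, -0.4078],
])  # [channels, 3], the SC_B values above

# project channels -> RGB: result is [3, H, W]
preview = torch.einsum("chw,cr->rhw", latent, latent_rgb_factors)
print(preview.shape)  # torch.Size([3, 8, 8])
```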
diff --git a/comfy/ldm/cascade/common.py b/comfy/ldm/cascade/common.py
new file mode 100644
index 00000000..124902c0
--- /dev/null
+++ b/comfy/ldm/cascade/common.py
@@ -0,0 +1,161 @@
+"""
+ This file is part of ComfyUI.
+ Copyright (C) 2024 Stability AI
+
+ This program is free software: you can redistribute it and/or modify
+ it under the terms of the GNU General Public License as published by
+ the Free Software Foundation, either version 3 of the License, or
+ (at your option) any later version.
+
+ This program is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU General Public License for more details.
+
+ You should have received a copy of the GNU General Public License
+ along with this program. If not, see <https://www.gnu.org/licenses/>.
+"""
+
+import torch
+import torch.nn as nn
+from comfy.ldm.modules.attention import optimized_attention
+
+class Linear(torch.nn.Linear):
+ def reset_parameters(self):
+ return None
+
+class Conv2d(torch.nn.Conv2d):
+ def reset_parameters(self):
+ return None
+
+class OptimizedAttention(nn.Module):
+ def __init__(self, c, nhead, dropout=0.0, dtype=None, device=None, operations=None):
+ super().__init__()
+ self.heads = nhead
+
+ self.to_q = operations.Linear(c, c, bias=True, dtype=dtype, device=device)
+ self.to_k = operations.Linear(c, c, bias=True, dtype=dtype, device=device)
+ self.to_v = operations.Linear(c, c, bias=True, dtype=dtype, device=device)
+
+ self.out_proj = operations.Linear(c, c, bias=True, dtype=dtype, device=device)
+
+ def forward(self, q, k, v):
+ q = self.to_q(q)
+ k = self.to_k(k)
+ v = self.to_v(v)
+
+ out = optimized_attention(q, k, v, self.heads)
+
+ return self.out_proj(out)
+
+class Attention2D(nn.Module):
+ def __init__(self, c, nhead, dropout=0.0, dtype=None, device=None, operations=None):
+ super().__init__()
+ self.attn = OptimizedAttention(c, nhead, dtype=dtype, device=device, operations=operations)
+ # self.attn = nn.MultiheadAttention(c, nhead, dropout=dropout, bias=True, batch_first=True, dtype=dtype, device=device)
+
+ def forward(self, x, kv, self_attn=False):
+ orig_shape = x.shape
+ x = x.view(x.size(0), x.size(1), -1).permute(0, 2, 1) # BxCxHxW -> Bx(HxW)xC
+ if self_attn:
+ kv = torch.cat([x, kv], dim=1)
+ # x = self.attn(x, kv, kv, need_weights=False)[0]
+ x = self.attn(x, kv, kv)
+ x = x.permute(0, 2, 1).view(*orig_shape)
+ return x
+
+
+def LayerNorm2d_op(operations):
+ class LayerNorm2d(operations.LayerNorm):
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+
+ def forward(self, x):
+ return super().forward(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
+ return LayerNorm2d
+
+class GlobalResponseNorm(nn.Module):
+ "from https://github.com/facebookresearch/ConvNeXt-V2/blob/3608f67cc1dae164790c5d0aead7bf2d73d9719b/models/utils.py#L105"
+ def __init__(self, dim, dtype=None, device=None):
+ super().__init__()
+ self.gamma = nn.Parameter(torch.zeros(1, 1, 1, dim, dtype=dtype, device=device))
+ self.beta = nn.Parameter(torch.zeros(1, 1, 1, dim, dtype=dtype, device=device))
+
+ def forward(self, x):
+ Gx = torch.norm(x, p=2, dim=(1, 2), keepdim=True)
+ Nx = Gx / (Gx.mean(dim=-1, keepdim=True) + 1e-6)
+ return self.gamma.to(device=x.device, dtype=x.dtype) * (x * Nx) + self.beta.to(device=x.device, dtype=x.dtype) + x
+
+
+class ResBlock(nn.Module):
+ def __init__(self, c, c_skip=0, kernel_size=3, dropout=0.0, dtype=None, device=None, operations=None): # , num_heads=4, expansion=2):
+ super().__init__()
+ self.depthwise = operations.Conv2d(c, c, kernel_size=kernel_size, padding=kernel_size // 2, groups=c, dtype=dtype, device=device)
+ # self.depthwise = SAMBlock(c, num_heads, expansion)
+ self.norm = LayerNorm2d_op(operations)(c, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
+ self.channelwise = nn.Sequential(
+ operations.Linear(c + c_skip, c * 4, dtype=dtype, device=device),
+ nn.GELU(),
+ GlobalResponseNorm(c * 4, dtype=dtype, device=device),
+ nn.Dropout(dropout),
+ operations.Linear(c * 4, c, dtype=dtype, device=device)
+ )
+
+ def forward(self, x, x_skip=None):
+ x_res = x
+ x = self.norm(self.depthwise(x))
+ if x_skip is not None:
+ x = torch.cat([x, x_skip], dim=1)
+ x = self.channelwise(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
+ return x + x_res
+
+
+class AttnBlock(nn.Module):
+ def __init__(self, c, c_cond, nhead, self_attn=True, dropout=0.0, dtype=None, device=None, operations=None):
+ super().__init__()
+ self.self_attn = self_attn
+ self.norm = LayerNorm2d_op(operations)(c, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
+ self.attention = Attention2D(c, nhead, dropout, dtype=dtype, device=device, operations=operations)
+ self.kv_mapper = nn.Sequential(
+ nn.SiLU(),
+ operations.Linear(c_cond, c, dtype=dtype, device=device)
+ )
+
+ def forward(self, x, kv):
+ kv = self.kv_mapper(kv)
+ x = x + self.attention(self.norm(x), kv, self_attn=self.self_attn)
+ return x
+
+
+class FeedForwardBlock(nn.Module):
+ def __init__(self, c, dropout=0.0, dtype=None, device=None, operations=None):
+ super().__init__()
+ self.norm = LayerNorm2d_op(operations)(c, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
+ self.channelwise = nn.Sequential(
+ operations.Linear(c, c * 4, dtype=dtype, device=device),
+ nn.GELU(),
+ GlobalResponseNorm(c * 4, dtype=dtype, device=device),
+ nn.Dropout(dropout),
+ operations.Linear(c * 4, c, dtype=dtype, device=device)
+ )
+
+ def forward(self, x):
+ x = x + self.channelwise(self.norm(x).permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
+ return x
+
+
+class TimestepBlock(nn.Module):
+ def __init__(self, c, c_timestep, conds=['sca'], dtype=None, device=None, operations=None):
+ super().__init__()
+ self.mapper = operations.Linear(c_timestep, c * 2, dtype=dtype, device=device)
+ self.conds = conds
+ for cname in conds:
+ setattr(self, f"mapper_{cname}", operations.Linear(c_timestep, c * 2, dtype=dtype, device=device))
+
+ def forward(self, x, t):
+ t = t.chunk(len(self.conds) + 1, dim=1)
+ a, b = self.mapper(t[0])[:, :, None, None].chunk(2, dim=1)
+ for i, c in enumerate(self.conds):
+ ac, bc = getattr(self, f"mapper_{c}")(t[i + 1])[:, :, None, None].chunk(2, dim=1)
+ a, b = a + ac, b + bc
+ return x * (1 + a) + b
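
Aside on `TimestepBlock`: the pooled timestep embedding `t` is split into one chunk per condition (the base embedding plus each name in `conds`), each chunk is mapped to a scale/shift pair, the pairs are summed, and the feature map is modulated as `x * (1 + a) + b`. A sketch of the same arithmetic with stand-in linear layers (all sizes hypothetical):

```python
import torch
import torch.nn as nn

c, c_timestep, batch = 8, 4, 1
t = torch.randn(batch, c_timestep * 2)           # base chunk + one extra cond ('sca')
base, sca = t.chunk(2, dim=1)

mapper = nn.Linear(c_timestep, c * 2)            # stand-in for self.mapper
mapper_sca = nn.Linear(c_timestep, c * 2)        # stand-in for self.mapper_sca

a, b = mapper(base)[:, :, None, None].chunk(2, dim=1)
ac, bc = mapper_sca(sca)[:, :, None, None].chunk(2, dim=1)
a, b = a + ac, b + bc

x = torch.randn(batch, c, 16, 16)
out = x * (1 + a) + b                            # same modulation as TimestepBlock.forward
print(out.shape)  # torch.Size([1, 8, 16, 16])
```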
diff --git a/comfy/ldm/cascade/stage_a.py b/comfy/ldm/cascade/stage_a.py
new file mode 100644
index 00000000..260ccfc0
--- /dev/null
+++ b/comfy/ldm/cascade/stage_a.py
@@ -0,0 +1,258 @@
+"""
+ This file is part of ComfyUI.
+ Copyright (C) 2024 Stability AI
+
+ This program is free software: you can redistribute it and/or modify
+ it under the terms of the GNU General Public License as published by
+ the Free Software Foundation, either version 3 of the License, or
+ (at your option) any later version.
+
+ This program is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU General Public License for more details.
+
+ You should have received a copy of the GNU General Public License
+ along with this program. If not, see <https://www.gnu.org/licenses/>.
+"""
+
+import torch
+from torch import nn
+from torch.autograd import Function
+
+class vector_quantize(Function):
+ @staticmethod
+ def forward(ctx, x, codebook):
+ with torch.no_grad():
+ codebook_sqr = torch.sum(codebook ** 2, dim=1)
+ x_sqr = torch.sum(x ** 2, dim=1, keepdim=True)
+
+ dist = torch.addmm(codebook_sqr + x_sqr, x, codebook.t(), alpha=-2.0, beta=1.0)
+ _, indices = dist.min(dim=1)
+
+ ctx.save_for_backward(indices, codebook)
+ ctx.mark_non_differentiable(indices)
+
+ nn = torch.index_select(codebook, 0, indices)
+ return nn, indices
+
+ @staticmethod
+ def backward(ctx, grad_output, grad_indices):
+ grad_inputs, grad_codebook = None, None
+
+ if ctx.needs_input_grad[0]:
+ grad_inputs = grad_output.clone()
+ if ctx.needs_input_grad[1]:
+ # Gradient wrt. the codebook
+ indices, codebook = ctx.saved_tensors
+
+ grad_codebook = torch.zeros_like(codebook)
+ grad_codebook.index_add_(0, indices, grad_output)
+
+ return (grad_inputs, grad_codebook)
+
+
+class VectorQuantize(nn.Module):
+ def __init__(self, embedding_size, k, ema_decay=0.99, ema_loss=False):
+ """
+ Takes an input of variable size (as long as the last dimension matches the embedding size).
+ Returns one tensor containing the nearest neighbour embeddings to each of the inputs
+ (same size as the input), the vq and commitment components of the loss as a tuple
+ in the second output, and the indices of the quantized vectors in the third:
+ quantized, (vq_loss, commit_loss), indices
+ """
+ super(VectorQuantize, self).__init__()
+
+ self.codebook = nn.Embedding(k, embedding_size)
+ self.codebook.weight.data.uniform_(-1./k, 1./k)
+ self.vq = vector_quantize.apply
+
+ self.ema_decay = ema_decay
+ self.ema_loss = ema_loss
+ if ema_loss:
+ self.register_buffer('ema_element_count', torch.ones(k))
+ self.register_buffer('ema_weight_sum', torch.zeros_like(self.codebook.weight))
+
+ def _laplace_smoothing(self, x, epsilon):
+ n = torch.sum(x)
+ return ((x + epsilon) / (n + x.size(0) * epsilon) * n)
+
+ def _updateEMA(self, z_e_x, indices):
+ mask = nn.functional.one_hot(indices, self.ema_element_count.size(0)).float()
+ elem_count = mask.sum(dim=0)
+ weight_sum = torch.mm(mask.t(), z_e_x)
+
+ self.ema_element_count = (self.ema_decay * self.ema_element_count) + ((1-self.ema_decay) * elem_count)
+ self.ema_element_count = self._laplace_smoothing(self.ema_element_count, 1e-5)
+ self.ema_weight_sum = (self.ema_decay * self.ema_weight_sum) + ((1-self.ema_decay) * weight_sum)
+
+ self.codebook.weight.data = self.ema_weight_sum / self.ema_element_count.unsqueeze(-1)
+
+ def idx2vq(self, idx, dim=-1):
+ q_idx = self.codebook(idx)
+ if dim != -1:
+ q_idx = q_idx.movedim(-1, dim)
+ return q_idx
+
+ def forward(self, x, get_losses=True, dim=-1):
+ if dim != -1:
+ x = x.movedim(dim, -1)
+ z_e_x = x.contiguous().view(-1, x.size(-1)) if len(x.shape) > 2 else x
+ z_q_x, indices = self.vq(z_e_x, self.codebook.weight.detach())
+ vq_loss, commit_loss = None, None
+ if self.ema_loss and self.training:
+ self._updateEMA(z_e_x.detach(), indices.detach())
+ # pick the graded embeddings after updating the codebook in order to have a more accurate commitment loss
+ z_q_x_grd = torch.index_select(self.codebook.weight, dim=0, index=indices)
+ if get_losses:
+ vq_loss = (z_q_x_grd - z_e_x.detach()).pow(2).mean()
+ commit_loss = (z_e_x - z_q_x_grd.detach()).pow(2).mean()
+
+ z_q_x = z_q_x.view(x.shape)
+ if dim != -1:
+ z_q_x = z_q_x.movedim(-1, dim)
+ return z_q_x, (vq_loss, commit_loss), indices.view(x.shape[:-1])
+
+
+class ResBlock(nn.Module):
+ def __init__(self, c, c_hidden):
+ super().__init__()
+ # depthwise/attention
+ self.norm1 = nn.LayerNorm(c, elementwise_affine=False, eps=1e-6)
+ self.depthwise = nn.Sequential(
+ nn.ReplicationPad2d(1),
+ nn.Conv2d(c, c, kernel_size=3, groups=c)
+ )
+
+ # channelwise
+ self.norm2 = nn.LayerNorm(c, elementwise_affine=False, eps=1e-6)
+ self.channelwise = nn.Sequential(
+ nn.Linear(c, c_hidden),
+ nn.GELU(),
+ nn.Linear(c_hidden, c),
+ )
+
+ self.gammas = nn.Parameter(torch.zeros(6), requires_grad=True)
+
+ # Init weights
+ def _basic_init(module):
+ if isinstance(module, nn.Linear) or isinstance(module, nn.Conv2d):
+ torch.nn.init.xavier_uniform_(module.weight)
+ if module.bias is not None:
+ nn.init.constant_(module.bias, 0)
+
+ self.apply(_basic_init)
+
+ def _norm(self, x, norm):
+ return norm(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
+
+ def forward(self, x):
+ mods = self.gammas
+
+ x_temp = self._norm(x, self.norm1) * (1 + mods[0]) + mods[1]
+ try:
+ x = x + self.depthwise(x_temp) * mods[2]
+ except: #operation not implemented for bf16
+ x_temp = self.depthwise[0](x_temp.float()).to(x.dtype)
+ x = x + self.depthwise[1](x_temp) * mods[2]
+
+ x_temp = self._norm(x, self.norm2) * (1 + mods[3]) + mods[4]
+ x = x + self.channelwise(x_temp.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) * mods[5]
+
+ return x
+
+
+class StageA(nn.Module):
+ def __init__(self, levels=2, bottleneck_blocks=12, c_hidden=384, c_latent=4, codebook_size=8192,
+ scale_factor=0.43): # 0.3764
+ super().__init__()
+ self.c_latent = c_latent
+ self.scale_factor = scale_factor
+ c_levels = [c_hidden // (2 ** i) for i in reversed(range(levels))]
+
+ # Encoder blocks
+ self.in_block = nn.Sequential(
+ nn.PixelUnshuffle(2),
+ nn.Conv2d(3 * 4, c_levels[0], kernel_size=1)
+ )
+ down_blocks = []
+ for i in range(levels):
+ if i > 0:
+ down_blocks.append(nn.Conv2d(c_levels[i - 1], c_levels[i], kernel_size=4, stride=2, padding=1))
+ block = ResBlock(c_levels[i], c_levels[i] * 4)
+ down_blocks.append(block)
+ down_blocks.append(nn.Sequential(
+ nn.Conv2d(c_levels[-1], c_latent, kernel_size=1, bias=False),
+ nn.BatchNorm2d(c_latent), # then normalize them to have mean 0 and std 1
+ ))
+ self.down_blocks = nn.Sequential(*down_blocks)
+ self.down_blocks[0]
+
+ self.codebook_size = codebook_size
+ self.vquantizer = VectorQuantize(c_latent, k=codebook_size)
+
+ # Decoder blocks
+ up_blocks = [nn.Sequential(
+ nn.Conv2d(c_latent, c_levels[-1], kernel_size=1)
+ )]
+ for i in range(levels):
+ for j in range(bottleneck_blocks if i == 0 else 1):
+ block = ResBlock(c_levels[levels - 1 - i], c_levels[levels - 1 - i] * 4)
+ up_blocks.append(block)
+ if i < levels - 1:
+ up_blocks.append(
+ nn.ConvTranspose2d(c_levels[levels - 1 - i], c_levels[levels - 2 - i], kernel_size=4, stride=2,
+ padding=1))
+ self.up_blocks = nn.Sequential(*up_blocks)
+ self.out_block = nn.Sequential(
+ nn.Conv2d(c_levels[0], 3 * 4, kernel_size=1),
+ nn.PixelShuffle(2),
+ )
+
+ def encode(self, x, quantize=False):
+ x = self.in_block(x)
+ x = self.down_blocks(x)
+ if quantize:
+ qe, (vq_loss, commit_loss), indices = self.vquantizer.forward(x, dim=1)
+ return qe / self.scale_factor, x / self.scale_factor, indices, vq_loss + commit_loss * 0.25
+ else:
+ return x / self.scale_factor
+
+ def decode(self, x):
+ x = x * self.scale_factor
+ x = self.up_blocks(x)
+ x = self.out_block(x)
+ return x
+
+ def forward(self, x, quantize=False):
+ qe, x, _, vq_loss = self.encode(x, quantize)
+ x = self.decode(qe)
+ return x, vq_loss
+
+
+class Discriminator(nn.Module):
+ def __init__(self, c_in=3, c_cond=0, c_hidden=512, depth=6):
+ super().__init__()
+ d = max(depth - 3, 3)
+ layers = [
+ nn.utils.spectral_norm(nn.Conv2d(c_in, c_hidden // (2 ** d), kernel_size=3, stride=2, padding=1)),
+ nn.LeakyReLU(0.2),
+ ]
+ for i in range(depth - 1):
+ c_in = c_hidden // (2 ** max((d - i), 0))
+ c_out = c_hidden // (2 ** max((d - 1 - i), 0))
+ layers.append(nn.utils.spectral_norm(nn.Conv2d(c_in, c_out, kernel_size=3, stride=2, padding=1)))
+ layers.append(nn.InstanceNorm2d(c_out))
+ layers.append(nn.LeakyReLU(0.2))
+ self.encoder = nn.Sequential(*layers)
+ self.shuffle = nn.Conv2d((c_hidden + c_cond) if c_cond > 0 else c_hidden, 1, kernel_size=1)
+ self.logits = nn.Sigmoid()
+
+ def forward(self, x, cond=None):
+ x = self.encoder(x)
+ if cond is not None:
+ cond = cond.view(cond.size(0), cond.size(1), 1, 1, ).expand(-1, -1, x.size(-2), x.size(-1))
+ x = torch.cat([x, cond], dim=1)
+ x = self.shuffle(x)
+ x = self.logits(x)
+ return x
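
Aside: `vector_quantize.forward` finds each input vector's nearest codebook entry by squared Euclidean distance, expanded as ||x - c||^2 = ||x||^2 + ||c||^2 - 2 x·c so the lookup reduces to a single `addmm`. A standalone sketch of that lookup (sizes hypothetical):

```python
import torch

x = torch.randn(5, 4)          # 5 input vectors, embedding size 4
codebook = torch.randn(8, 4)   # 8 codebook entries

codebook_sqr = torch.sum(codebook ** 2, dim=1)
x_sqr = torch.sum(x ** 2, dim=1, keepdim=True)
# ||x||^2 + ||c||^2 - 2 x.c, the same addmm used in vector_quantize.forward
dist = torch.addmm(codebook_sqr + x_sqr, x, codebook.t(), alpha=-2.0, beta=1.0)
indices = dist.argmin(dim=1)
quantized = codebook[indices]  # nearest-neighbour embeddings
print(quantized.shape)         # torch.Size([5, 4])
```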
diff --git a/comfy/ldm/cascade/stage_b.py b/comfy/ldm/cascade/stage_b.py
new file mode 100644
index 00000000..6d2c2223
--- /dev/null
+++ b/comfy/ldm/cascade/stage_b.py
@@ -0,0 +1,257 @@
+"""
+ This file is part of ComfyUI.
+ Copyright (C) 2024 Stability AI
+
+ This program is free software: you can redistribute it and/or modify
+ it under the terms of the GNU General Public License as published by
+ the Free Software Foundation, either version 3 of the License, or
+ (at your option) any later version.
+
+ This program is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU General Public License for more details.
+
+ You should have received a copy of the GNU General Public License
+ along with this program. If not, see <https://www.gnu.org/licenses/>.
+"""
+
+import math
+import numpy as np
+import torch
+from torch import nn
+from .common import AttnBlock, LayerNorm2d_op, ResBlock, FeedForwardBlock, TimestepBlock
+
+class StageB(nn.Module):
+ def __init__(self, c_in=4, c_out=4, c_r=64, patch_size=2, c_cond=1280, c_hidden=[320, 640, 1280, 1280],
+ nhead=[-1, -1, 20, 20], blocks=[[2, 6, 28, 6], [6, 28, 6, 2]],
+ block_repeat=[[1, 1, 1, 1], [3, 3, 2, 2]], level_config=['CT', 'CT', 'CTA', 'CTA'], c_clip=1280,
+ c_clip_seq=4, c_effnet=16, c_pixels=3, kernel_size=3, dropout=[0, 0, 0.0, 0.0], self_attn=True,
+ t_conds=['sca'], stable_cascade_stage=None, dtype=None, device=None, operations=None):
+ super().__init__()
+ self.dtype = dtype
+ self.c_r = c_r
+ self.t_conds = t_conds
+ self.c_clip_seq = c_clip_seq
+ if not isinstance(dropout, list):
+ dropout = [dropout] * len(c_hidden)
+ if not isinstance(self_attn, list):
+ self_attn = [self_attn] * len(c_hidden)
+
+ # CONDITIONING
+ self.effnet_mapper = nn.Sequential(
+ operations.Conv2d(c_effnet, c_hidden[0] * 4, kernel_size=1, dtype=dtype, device=device),
+ nn.GELU(),
+ operations.Conv2d(c_hidden[0] * 4, c_hidden[0], kernel_size=1, dtype=dtype, device=device),
+ LayerNorm2d_op(operations)(c_hidden[0], elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
+ )
+ self.pixels_mapper = nn.Sequential(
+ operations.Conv2d(c_pixels, c_hidden[0] * 4, kernel_size=1, dtype=dtype, device=device),
+ nn.GELU(),
+ operations.Conv2d(c_hidden[0] * 4, c_hidden[0], kernel_size=1, dtype=dtype, device=device),
+ LayerNorm2d_op(operations)(c_hidden[0], elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
+ )
+ self.clip_mapper = operations.Linear(c_clip, c_cond * c_clip_seq, dtype=dtype, device=device)
+ self.clip_norm = operations.LayerNorm(c_cond, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
+
+ self.embedding = nn.Sequential(
+ nn.PixelUnshuffle(patch_size),
+ operations.Conv2d(c_in * (patch_size ** 2), c_hidden[0], kernel_size=1, dtype=dtype, device=device),
+ LayerNorm2d_op(operations)(c_hidden[0], elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
+ )
+
+ def get_block(block_type, c_hidden, nhead, c_skip=0, dropout=0, self_attn=True):
+ if block_type == 'C':
+ return ResBlock(c_hidden, c_skip, kernel_size=kernel_size, dropout=dropout, dtype=dtype, device=device, operations=operations)
+ elif block_type == 'A':
+ return AttnBlock(c_hidden, c_cond, nhead, self_attn=self_attn, dropout=dropout, dtype=dtype, device=device, operations=operations)
+ elif block_type == 'F':
+ return FeedForwardBlock(c_hidden, dropout=dropout, dtype=dtype, device=device, operations=operations)
+ elif block_type == 'T':
+ return TimestepBlock(c_hidden, c_r, conds=t_conds, dtype=dtype, device=device, operations=operations)
+ else:
+ raise Exception(f'Block type {block_type} not supported')
+
+ # BLOCKS
+ # -- down blocks
+ self.down_blocks = nn.ModuleList()
+ self.down_downscalers = nn.ModuleList()
+ self.down_repeat_mappers = nn.ModuleList()
+ for i in range(len(c_hidden)):
+ if i > 0:
+ self.down_downscalers.append(nn.Sequential(
+ LayerNorm2d_op(operations)(c_hidden[i - 1], elementwise_affine=False, eps=1e-6, dtype=dtype, device=device),
+ operations.Conv2d(c_hidden[i - 1], c_hidden[i], kernel_size=2, stride=2, dtype=dtype, device=device),
+ ))
+ else:
+ self.down_downscalers.append(nn.Identity())
+ down_block = nn.ModuleList()
+ for _ in range(blocks[0][i]):
+ for block_type in level_config[i]:
+ block = get_block(block_type, c_hidden[i], nhead[i], dropout=dropout[i], self_attn=self_attn[i])
+ down_block.append(block)
+ self.down_blocks.append(down_block)
+ if block_repeat is not None:
+ block_repeat_mappers = nn.ModuleList()
+ for _ in range(block_repeat[0][i] - 1):
+ block_repeat_mappers.append(operations.Conv2d(c_hidden[i], c_hidden[i], kernel_size=1, dtype=dtype, device=device))
+ self.down_repeat_mappers.append(block_repeat_mappers)
+
+ # -- up blocks
+ self.up_blocks = nn.ModuleList()
+ self.up_upscalers = nn.ModuleList()
+ self.up_repeat_mappers = nn.ModuleList()
+ for i in reversed(range(len(c_hidden))):
+ if i > 0:
+ self.up_upscalers.append(nn.Sequential(
+ LayerNorm2d_op(operations)(c_hidden[i], elementwise_affine=False, eps=1e-6, dtype=dtype, device=device),
+ operations.ConvTranspose2d(c_hidden[i], c_hidden[i - 1], kernel_size=2, stride=2, dtype=dtype, device=device),
+ ))
+ else:
+ self.up_upscalers.append(nn.Identity())
+ up_block = nn.ModuleList()
+ for j in range(blocks[1][::-1][i]):
+ for k, block_type in enumerate(level_config[i]):
+ c_skip = c_hidden[i] if i < len(c_hidden) - 1 and j == k == 0 else 0
+ block = get_block(block_type, c_hidden[i], nhead[i], c_skip=c_skip, dropout=dropout[i],
+ self_attn=self_attn[i])
+ up_block.append(block)
+ self.up_blocks.append(up_block)
+ if block_repeat is not None:
+ block_repeat_mappers = nn.ModuleList()
+ for _ in range(block_repeat[1][::-1][i] - 1):
+ block_repeat_mappers.append(operations.Conv2d(c_hidden[i], c_hidden[i], kernel_size=1, dtype=dtype, device=device))
+ self.up_repeat_mappers.append(block_repeat_mappers)
+
+ # OUTPUT
+ self.clf = nn.Sequential(
+ LayerNorm2d_op(operations)(c_hidden[0], elementwise_affine=False, eps=1e-6, dtype=dtype, device=device),
+ operations.Conv2d(c_hidden[0], c_out * (patch_size ** 2), kernel_size=1, dtype=dtype, device=device),
+ nn.PixelShuffle(patch_size),
+ )
+
+ # --- WEIGHT INIT ---
+ # self.apply(self._init_weights) # General init
+ # nn.init.normal_(self.clip_mapper.weight, std=0.02) # conditionings
+ # nn.init.normal_(self.effnet_mapper[0].weight, std=0.02) # conditionings
+ # nn.init.normal_(self.effnet_mapper[2].weight, std=0.02) # conditionings
+ # nn.init.normal_(self.pixels_mapper[0].weight, std=0.02) # conditionings
+ # nn.init.normal_(self.pixels_mapper[2].weight, std=0.02) # conditionings
+ # torch.nn.init.xavier_uniform_(self.embedding[1].weight, 0.02) # inputs
+ # nn.init.constant_(self.clf[1].weight, 0) # outputs
+ #
+ # # blocks
+ # for level_block in self.down_blocks + self.up_blocks:
+ # for block in level_block:
+ # if isinstance(block, ResBlock) or isinstance(block, FeedForwardBlock):
+ # block.channelwise[-1].weight.data *= np.sqrt(1 / sum(blocks[0]))
+ # elif isinstance(block, TimestepBlock):
+ # for layer in block.modules():
+ # if isinstance(layer, nn.Linear):
+ # nn.init.constant_(layer.weight, 0)
+ #
+ # def _init_weights(self, m):
+ # if isinstance(m, (nn.Conv2d, nn.Linear)):
+ # torch.nn.init.xavier_uniform_(m.weight)
+ # if m.bias is not None:
+ # nn.init.constant_(m.bias, 0)
+
+ def gen_r_embedding(self, r, max_positions=10000):
+ r = r * max_positions
+ half_dim = self.c_r // 2
+ emb = math.log(max_positions) / (half_dim - 1)
+ emb = torch.arange(half_dim, device=r.device).float().mul(-emb).exp()
+ emb = r[:, None] * emb[None, :]
+ emb = torch.cat([emb.sin(), emb.cos()], dim=1)
+ if self.c_r % 2 == 1: # zero pad
+ emb = nn.functional.pad(emb, (0, 1), mode='constant')
+ return emb
+
+ def gen_c_embeddings(self, clip):
+ if len(clip.shape) == 2:
+ clip = clip.unsqueeze(1)
+ clip = self.clip_mapper(clip).view(clip.size(0), clip.size(1) * self.c_clip_seq, -1)
+ clip = self.clip_norm(clip)
+ return clip
+
+ def _down_encode(self, x, r_embed, clip):
+ level_outputs = []
+ block_group = zip(self.down_blocks, self.down_downscalers, self.down_repeat_mappers)
+ for down_block, downscaler, repmap in block_group:
+ x = downscaler(x)
+ for i in range(len(repmap) + 1):
+ for block in down_block:
+ if isinstance(block, ResBlock) or (
+ hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module,
+ ResBlock)):
+ x = block(x)
+ elif isinstance(block, AttnBlock) or (
+ hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module,
+ AttnBlock)):
+ x = block(x, clip)
+ elif isinstance(block, TimestepBlock) or (
+ hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module,
+ TimestepBlock)):
+ x = block(x, r_embed)
+ else:
+ x = block(x)
+ if i < len(repmap):
+ x = repmap[i](x)
+ level_outputs.insert(0, x)
+ return level_outputs
+
+ def _up_decode(self, level_outputs, r_embed, clip):
+ x = level_outputs[0]
+ block_group = zip(self.up_blocks, self.up_upscalers, self.up_repeat_mappers)
+ for i, (up_block, upscaler, repmap) in enumerate(block_group):
+ for j in range(len(repmap) + 1):
+ for k, block in enumerate(up_block):
+ if isinstance(block, ResBlock) or (
+ hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module,
+ ResBlock)):
+ skip = level_outputs[i] if k == 0 and i > 0 else None
+ if skip is not None and (x.size(-1) != skip.size(-1) or x.size(-2) != skip.size(-2)):
+ x = torch.nn.functional.interpolate(x, skip.shape[-2:], mode='bilinear',
+ align_corners=True)
+ x = block(x, skip)
+ elif isinstance(block, AttnBlock) or (
+ hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module,
+ AttnBlock)):
+ x = block(x, clip)
+ elif isinstance(block, TimestepBlock) or (
+ hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module,
+ TimestepBlock)):
+ x = block(x, r_embed)
+ else:
+ x = block(x)
+ if j < len(repmap):
+ x = repmap[j](x)
+ x = upscaler(x)
+ return x
+
+ def forward(self, x, r, effnet, clip, pixels=None, **kwargs):
+ if pixels is None:
+ pixels = x.new_zeros(x.size(0), 3, 8, 8)
+
+ # Process the conditioning embeddings
+ r_embed = self.gen_r_embedding(r).to(dtype=x.dtype)
+ for c in self.t_conds:
+ t_cond = kwargs.get(c, torch.zeros_like(r))
+ r_embed = torch.cat([r_embed, self.gen_r_embedding(t_cond).to(dtype=x.dtype)], dim=1)
+ clip = self.gen_c_embeddings(clip)
+
+ # Model Blocks
+ x = self.embedding(x)
+ x = x + self.effnet_mapper(
+ nn.functional.interpolate(effnet, size=x.shape[-2:], mode='bilinear', align_corners=True))
+ x = x + nn.functional.interpolate(self.pixels_mapper(pixels), size=x.shape[-2:], mode='bilinear',
+ align_corners=True)
+ level_outputs = self._down_encode(x, r_embed, clip)
+ x = self._up_decode(level_outputs, r_embed, clip)
+ return self.clf(x)
+
+ def update_weights_ema(self, src_model, beta=0.999):
+ for self_params, src_params in zip(self.parameters(), src_model.parameters()):
+ self_params.data = self_params.data * beta + src_params.data.clone().to(self_params.device) * (1 - beta)
+ for self_buffers, src_buffers in zip(self.buffers(), src_model.buffers()):
+ self_buffers.data = self_buffers.data * beta + src_buffers.data.clone().to(self_buffers.device) * (1 - beta)
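
Aside: both Stage B and Stage C build their timestep conditioning with `gen_r_embedding`, a standard sinusoidal embedding of the (scaled) ratio `r`. A self-contained sketch of the same computation (the function name and sizes here are illustrative):

```python
import math
import torch

def sinusoidal_embedding(r, dim=64, max_positions=10000):
    # mirrors gen_r_embedding: scale r, build exponentially spaced
    # frequencies, then concatenate sin and cos components
    r = r * max_positions
    half_dim = dim // 2
    emb = math.log(max_positions) / (half_dim - 1)
    emb = torch.arange(half_dim, device=r.device).float().mul(-emb).exp()
    emb = r[:, None] * emb[None, :]
    emb = torch.cat([emb.sin(), emb.cos()], dim=1)
    if dim % 2 == 1:  # zero pad odd dimensions
        emb = torch.nn.functional.pad(emb, (0, 1), mode='constant')
    return emb

print(sinusoidal_embedding(torch.tensor([0.25, 0.5]), dim=8).shape)  # torch.Size([2, 8])
```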
diff --git a/comfy/ldm/cascade/stage_c.py b/comfy/ldm/cascade/stage_c.py
new file mode 100644
index 00000000..08e33ade
--- /dev/null
+++ b/comfy/ldm/cascade/stage_c.py
@@ -0,0 +1,271 @@
+"""
+ This file is part of ComfyUI.
+ Copyright (C) 2024 Stability AI
+
+ This program is free software: you can redistribute it and/or modify
+ it under the terms of the GNU General Public License as published by
+ the Free Software Foundation, either version 3 of the License, or
+ (at your option) any later version.
+
+ This program is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU General Public License for more details.
+
+ You should have received a copy of the GNU General Public License
+ along with this program. If not, see <https://www.gnu.org/licenses/>.
+"""
+
+import torch
+from torch import nn
+import numpy as np
+import math
+from .common import AttnBlock, LayerNorm2d_op, ResBlock, FeedForwardBlock, TimestepBlock
+# from .controlnet import ControlNetDeliverer
+
+class UpDownBlock2d(nn.Module):
+ def __init__(self, c_in, c_out, mode, enabled=True, dtype=None, device=None, operations=None):
+ super().__init__()
+ assert mode in ['up', 'down']
+ interpolation = nn.Upsample(scale_factor=2 if mode == 'up' else 0.5, mode='bilinear',
+ align_corners=True) if enabled else nn.Identity()
+ mapping = operations.Conv2d(c_in, c_out, kernel_size=1, dtype=dtype, device=device)
+ self.blocks = nn.ModuleList([interpolation, mapping] if mode == 'up' else [mapping, interpolation])
+
+ def forward(self, x):
+ for block in self.blocks:
+ x = block(x)
+ return x
+
+
+class StageC(nn.Module):
+ def __init__(self, c_in=16, c_out=16, c_r=64, patch_size=1, c_cond=2048, c_hidden=[2048, 2048], nhead=[32, 32],
+ blocks=[[8, 24], [24, 8]], block_repeat=[[1, 1], [1, 1]], level_config=['CTA', 'CTA'],
+ c_clip_text=1280, c_clip_text_pooled=1280, c_clip_img=768, c_clip_seq=4, kernel_size=3,
+ dropout=[0.0, 0.0], self_attn=True, t_conds=['sca', 'crp'], switch_level=[False], stable_cascade_stage=None,
+ dtype=None, device=None, operations=None):
+ super().__init__()
+ self.dtype = dtype
+ self.c_r = c_r
+ self.t_conds = t_conds
+ self.c_clip_seq = c_clip_seq
+ if not isinstance(dropout, list):
+ dropout = [dropout] * len(c_hidden)
+ if not isinstance(self_attn, list):
+ self_attn = [self_attn] * len(c_hidden)
+
+ # CONDITIONING
+ self.clip_txt_mapper = operations.Linear(c_clip_text, c_cond, dtype=dtype, device=device)
+ self.clip_txt_pooled_mapper = operations.Linear(c_clip_text_pooled, c_cond * c_clip_seq, dtype=dtype, device=device)
+ self.clip_img_mapper = operations.Linear(c_clip_img, c_cond * c_clip_seq, dtype=dtype, device=device)
+ self.clip_norm = operations.LayerNorm(c_cond, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
+
+ self.embedding = nn.Sequential(
+ nn.PixelUnshuffle(patch_size),
+ operations.Conv2d(c_in * (patch_size ** 2), c_hidden[0], kernel_size=1, dtype=dtype, device=device),
+ LayerNorm2d_op(operations)(c_hidden[0], elementwise_affine=False, eps=1e-6)
+ )
+
+ def get_block(block_type, c_hidden, nhead, c_skip=0, dropout=0, self_attn=True):
+ if block_type == 'C':
+ return ResBlock(c_hidden, c_skip, kernel_size=kernel_size, dropout=dropout, dtype=dtype, device=device, operations=operations)
+ elif block_type == 'A':
+ return AttnBlock(c_hidden, c_cond, nhead, self_attn=self_attn, dropout=dropout, dtype=dtype, device=device, operations=operations)
+ elif block_type == 'F':
+ return FeedForwardBlock(c_hidden, dropout=dropout, dtype=dtype, device=device, operations=operations)
+ elif block_type == 'T':
+ return TimestepBlock(c_hidden, c_r, conds=t_conds, dtype=dtype, device=device, operations=operations)
+ else:
+ raise Exception(f'Block type {block_type} not supported')
+
+ # BLOCKS
+ # -- down blocks
+ self.down_blocks = nn.ModuleList()
+ self.down_downscalers = nn.ModuleList()
+ self.down_repeat_mappers = nn.ModuleList()
+ for i in range(len(c_hidden)):
+ if i > 0:
+ self.down_downscalers.append(nn.Sequential(
+ LayerNorm2d_op(operations)(c_hidden[i - 1], elementwise_affine=False, eps=1e-6),
+ UpDownBlock2d(c_hidden[i - 1], c_hidden[i], mode='down', enabled=switch_level[i - 1], dtype=dtype, device=device, operations=operations)
+ ))
+ else:
+ self.down_downscalers.append(nn.Identity())
+ down_block = nn.ModuleList()
+ for _ in range(blocks[0][i]):
+ for block_type in level_config[i]:
+ block = get_block(block_type, c_hidden[i], nhead[i], dropout=dropout[i], self_attn=self_attn[i])
+ down_block.append(block)
+ self.down_blocks.append(down_block)
+ if block_repeat is not None:
+ block_repeat_mappers = nn.ModuleList()
+ for _ in range(block_repeat[0][i] - 1):
+ block_repeat_mappers.append(operations.Conv2d(c_hidden[i], c_hidden[i], kernel_size=1, dtype=dtype, device=device))
+ self.down_repeat_mappers.append(block_repeat_mappers)
+
+ # -- up blocks
+ self.up_blocks = nn.ModuleList()
+ self.up_upscalers = nn.ModuleList()
+ self.up_repeat_mappers = nn.ModuleList()
+ for i in reversed(range(len(c_hidden))):
+ if i > 0:
+ self.up_upscalers.append(nn.Sequential(
+ LayerNorm2d_op(operations)(c_hidden[i], elementwise_affine=False, eps=1e-6),
+ UpDownBlock2d(c_hidden[i], c_hidden[i - 1], mode='up', enabled=switch_level[i - 1], dtype=dtype, device=device, operations=operations)
+ ))
+ else:
+ self.up_upscalers.append(nn.Identity())
+ up_block = nn.ModuleList()
+ for j in range(blocks[1][::-1][i]):
+ for k, block_type in enumerate(level_config[i]):
+ c_skip = c_hidden[i] if i < len(c_hidden) - 1 and j == k == 0 else 0
+ block = get_block(block_type, c_hidden[i], nhead[i], c_skip=c_skip, dropout=dropout[i],
+ self_attn=self_attn[i])
+ up_block.append(block)
+ self.up_blocks.append(up_block)
+ if block_repeat is not None:
+ block_repeat_mappers = nn.ModuleList()
+ for _ in range(block_repeat[1][::-1][i] - 1):
+ block_repeat_mappers.append(operations.Conv2d(c_hidden[i], c_hidden[i], kernel_size=1, dtype=dtype, device=device))
+ self.up_repeat_mappers.append(block_repeat_mappers)
+
+ # OUTPUT
+ self.clf = nn.Sequential(
+ LayerNorm2d_op(operations)(c_hidden[0], elementwise_affine=False, eps=1e-6, dtype=dtype, device=device),
+ operations.Conv2d(c_hidden[0], c_out * (patch_size ** 2), kernel_size=1, dtype=dtype, device=device),
+ nn.PixelShuffle(patch_size),
+ )
+
+ # --- WEIGHT INIT ---
+ # self.apply(self._init_weights) # General init
+ # nn.init.normal_(self.clip_txt_mapper.weight, std=0.02) # conditionings
+ # nn.init.normal_(self.clip_txt_pooled_mapper.weight, std=0.02) # conditionings
+ # nn.init.normal_(self.clip_img_mapper.weight, std=0.02) # conditionings
+ # torch.nn.init.xavier_uniform_(self.embedding[1].weight, 0.02) # inputs
+ # nn.init.constant_(self.clf[1].weight, 0) # outputs
+ #
+ # # blocks
+ # for level_block in self.down_blocks + self.up_blocks:
+ # for block in level_block:
+ # if isinstance(block, ResBlock) or isinstance(block, FeedForwardBlock):
+ # block.channelwise[-1].weight.data *= np.sqrt(1 / sum(blocks[0]))
+ # elif isinstance(block, TimestepBlock):
+ # for layer in block.modules():
+ # if isinstance(layer, nn.Linear):
+ # nn.init.constant_(layer.weight, 0)
+ #
+ # def _init_weights(self, m):
+ # if isinstance(m, (nn.Conv2d, nn.Linear)):
+ # torch.nn.init.xavier_uniform_(m.weight)
+ # if m.bias is not None:
+ # nn.init.constant_(m.bias, 0)
+
+ def gen_r_embedding(self, r, max_positions=10000):
+ r = r * max_positions
+ half_dim = self.c_r // 2
+ emb = math.log(max_positions) / (half_dim - 1)
+ emb = torch.arange(half_dim, device=r.device).float().mul(-emb).exp()
+ emb = r[:, None] * emb[None, :]
+ emb = torch.cat([emb.sin(), emb.cos()], dim=1)
+ if self.c_r % 2 == 1: # zero pad
+ emb = nn.functional.pad(emb, (0, 1), mode='constant')
+ return emb
+
+ def gen_c_embeddings(self, clip_txt, clip_txt_pooled, clip_img):
+ clip_txt = self.clip_txt_mapper(clip_txt)
+ if len(clip_txt_pooled.shape) == 2:
+ clip_txt_pooled = clip_txt_pooled.unsqueeze(1)
+ if len(clip_img.shape) == 2:
+ clip_img = clip_img.unsqueeze(1)
+ clip_txt_pool = self.clip_txt_pooled_mapper(clip_txt_pooled).view(clip_txt_pooled.size(0), clip_txt_pooled.size(1) * self.c_clip_seq, -1)
+ clip_img = self.clip_img_mapper(clip_img).view(clip_img.size(0), clip_img.size(1) * self.c_clip_seq, -1)
+ clip = torch.cat([clip_txt, clip_txt_pool, clip_img], dim=1)
+ clip = self.clip_norm(clip)
+ return clip
+
+ def _down_encode(self, x, r_embed, clip, cnet=None):
+ level_outputs = []
+ block_group = zip(self.down_blocks, self.down_downscalers, self.down_repeat_mappers)
+ for down_block, downscaler, repmap in block_group:
+ x = downscaler(x)
+ for i in range(len(repmap) + 1):
+ for block in down_block:
+ if isinstance(block, ResBlock) or (
+ hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module,
+ ResBlock)):
+ if cnet is not None:
+ next_cnet = cnet()
+ if next_cnet is not None:
+ x = x + nn.functional.interpolate(next_cnet, size=x.shape[-2:], mode='bilinear',
+ align_corners=True)
+ x = block(x)
+ elif isinstance(block, AttnBlock) or (
+ hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module,
+ AttnBlock)):
+ x = block(x, clip)
+ elif isinstance(block, TimestepBlock) or (
+ hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module,
+ TimestepBlock)):
+ x = block(x, r_embed)
+ else:
+ x = block(x)
+ if i < len(repmap):
+ x = repmap[i](x)
+ level_outputs.insert(0, x)
+ return level_outputs
+
+ def _up_decode(self, level_outputs, r_embed, clip, cnet=None):
+ x = level_outputs[0]
+ block_group = zip(self.up_blocks, self.up_upscalers, self.up_repeat_mappers)
+ for i, (up_block, upscaler, repmap) in enumerate(block_group):
+ for j in range(len(repmap) + 1):
+ for k, block in enumerate(up_block):
+ if isinstance(block, ResBlock) or (
+ hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module,
+ ResBlock)):
+ skip = level_outputs[i] if k == 0 and i > 0 else None
+ if skip is not None and (x.size(-1) != skip.size(-1) or x.size(-2) != skip.size(-2)):
+ x = torch.nn.functional.interpolate(x, skip.shape[-2:], mode='bilinear',
+ align_corners=True)
+ if cnet is not None:
+ next_cnet = cnet()
+ if next_cnet is not None:
+ x = x + nn.functional.interpolate(next_cnet, size=x.shape[-2:], mode='bilinear',
+ align_corners=True)
+ x = block(x, skip)
+ elif isinstance(block, AttnBlock) or (
+ hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module,
+ AttnBlock)):
+ x = block(x, clip)
+ elif isinstance(block, TimestepBlock) or (
+ hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module,
+ TimestepBlock)):
+ x = block(x, r_embed)
+ else:
+ x = block(x)
+ if j < len(repmap):
+ x = repmap[j](x)
+ x = upscaler(x)
+ return x
+
+ def forward(self, x, r, clip_text, clip_text_pooled, clip_img, cnet=None, **kwargs):
+ # Process the conditioning embeddings
+ r_embed = self.gen_r_embedding(r).to(dtype=x.dtype)
+ for c in self.t_conds:
+ t_cond = kwargs.get(c, torch.zeros_like(r))
+ r_embed = torch.cat([r_embed, self.gen_r_embedding(t_cond).to(dtype=x.dtype)], dim=1)
+ clip = self.gen_c_embeddings(clip_text, clip_text_pooled, clip_img)
+
+ # Model Blocks
+ x = self.embedding(x)
+ if cnet is not None:
+ cnet = ControlNetDeliverer(cnet)  # note: the ControlNetDeliverer import above is commented out, so this path is not wired up yet
+ level_outputs = self._down_encode(x, r_embed, clip, cnet)
+ x = self._up_decode(level_outputs, r_embed, clip, cnet)
+ return self.clf(x)
+
+ def update_weights_ema(self, src_model, beta=0.999):
+ for self_params, src_params in zip(self.parameters(), src_model.parameters()):
+ self_params.data = self_params.data * beta + src_params.data.clone().to(self_params.device) * (1 - beta)
+ for self_buffers, src_buffers in zip(self.buffers(), src_model.buffers()):
+ self_buffers.data = self_buffers.data * beta + src_buffers.data.clone().to(self_buffers.device) * (1 - beta)
diff --git a/comfy/ldm/cascade/stage_c_coder.py b/comfy/ldm/cascade/stage_c_coder.py
new file mode 100644
index 00000000..0cb7c49f
--- /dev/null
+++ b/comfy/ldm/cascade/stage_c_coder.py
@@ -0,0 +1,95 @@
+"""
+ This file is part of ComfyUI.
+ Copyright (C) 2024 Stability AI
+
+ This program is free software: you can redistribute it and/or modify
+ it under the terms of the GNU General Public License as published by
+ the Free Software Foundation, either version 3 of the License, or
+ (at your option) any later version.
+
+ This program is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU General Public License for more details.
+
+ You should have received a copy of the GNU General Public License
+ along with this program. If not, see <https://www.gnu.org/licenses/>.
+"""
+import torch
+import torchvision
+from torch import nn
+
+
+# EfficientNet
+class EfficientNetEncoder(nn.Module):
+ def __init__(self, c_latent=16):
+ super().__init__()
+ self.backbone = torchvision.models.efficientnet_v2_s().features.eval()
+ self.mapper = nn.Sequential(
+ nn.Conv2d(1280, c_latent, kernel_size=1, bias=False),
+ nn.BatchNorm2d(c_latent, affine=False), # then normalize them to have mean 0 and std 1
+ )
+ self.mean = nn.Parameter(torch.tensor([0.485, 0.456, 0.406]))
+ self.std = nn.Parameter(torch.tensor([0.229, 0.224, 0.225]))
+
+ def forward(self, x):
+ x = x * 0.5 + 0.5
+ x = (x - self.mean.view([3,1,1])) / self.std.view([3,1,1])
+ o = self.mapper(self.backbone(x))
+ return o
+
+
+# Fast Decoder for Stage C latents. E.g. 16 x 24 x 24 -> 3 x 192 x 192
+class Previewer(nn.Module):
+ def __init__(self, c_in=16, c_hidden=512, c_out=3):
+ super().__init__()
+ self.blocks = nn.Sequential(
+ nn.Conv2d(c_in, c_hidden, kernel_size=1), # 16 channels to 512 channels
+ nn.GELU(),
+ nn.BatchNorm2d(c_hidden),
+
+ nn.Conv2d(c_hidden, c_hidden, kernel_size=3, padding=1),
+ nn.GELU(),
+ nn.BatchNorm2d(c_hidden),
+
+ nn.ConvTranspose2d(c_hidden, c_hidden // 2, kernel_size=2, stride=2), # 16 -> 32
+ nn.GELU(),
+ nn.BatchNorm2d(c_hidden // 2),
+
+ nn.Conv2d(c_hidden // 2, c_hidden // 2, kernel_size=3, padding=1),
+ nn.GELU(),
+ nn.BatchNorm2d(c_hidden // 2),
+
+ nn.ConvTranspose2d(c_hidden // 2, c_hidden // 4, kernel_size=2, stride=2), # 32 -> 64
+ nn.GELU(),
+ nn.BatchNorm2d(c_hidden // 4),
+
+ nn.Conv2d(c_hidden // 4, c_hidden // 4, kernel_size=3, padding=1),
+ nn.GELU(),
+ nn.BatchNorm2d(c_hidden // 4),
+
+ nn.ConvTranspose2d(c_hidden // 4, c_hidden // 4, kernel_size=2, stride=2), # 64 -> 128
+ nn.GELU(),
+ nn.BatchNorm2d(c_hidden // 4),
+
+ nn.Conv2d(c_hidden // 4, c_hidden // 4, kernel_size=3, padding=1),
+ nn.GELU(),
+ nn.BatchNorm2d(c_hidden // 4),
+
+ nn.Conv2d(c_hidden // 4, c_out, kernel_size=1),
+ )
+
+ def forward(self, x):
+ return (self.blocks(x) - 0.5) * 2.0
+
+class StageC_coder(nn.Module):
+ def __init__(self):
+ super().__init__()
+ self.previewer = Previewer()
+ self.encoder = EfficientNetEncoder()
+
+ def encode(self, x):
+ return self.encoder(x)
+
+ def decode(self, x):
+ return self.previewer(x)
diff --git a/comfy/ldm/modules/attention.py b/comfy/ldm/modules/attention.py
index 9c9cb761..48399bc0 100644
--- a/comfy/ldm/modules/attention.py
+++ b/comfy/ldm/modules/attention.py
@@ -114,7 +114,12 @@ def attention_basic(q, k, v, heads, mask=None):
mask = repeat(mask, 'b j -> (b h) () j', h=h)
sim.masked_fill_(~mask, max_neg_value)
else:
- sim += mask
+ if len(mask.shape) == 2:
+ bs = 1
+ else:
+ bs = mask.shape[0]
+ mask = mask.reshape(bs, -1, mask.shape[-2], mask.shape[-1]).expand(b, heads, -1, -1).reshape(-1, mask.shape[-2], mask.shape[-1])
+ sim.add_(mask)
# attention, what we cannot get enough of
sim = sim.softmax(dim=-1)
@@ -165,6 +170,13 @@ def attention_sub_quad(query, key, value, heads, mask=None):
if query_chunk_size is None:
query_chunk_size = 512
+ if mask is not None:
+ if len(mask.shape) == 2:
+ bs = 1
+ else:
+ bs = mask.shape[0]
+ mask = mask.reshape(bs, -1, mask.shape[-2], mask.shape[-1]).expand(b, heads, -1, -1).reshape(-1, mask.shape[-2], mask.shape[-1])
+
hidden_states = efficient_dot_product_attention(
query,
key,
@@ -223,6 +235,13 @@ def attention_split(q, k, v, heads, mask=None):
raise RuntimeError(f'Not enough memory, use lower resolution (max approx. {max_res}x{max_res}). '
f'Need: {mem_required/64/gb:0.1f}GB free, Have:{mem_free_total/gb:0.1f}GB free')
+ if mask is not None:
+ if len(mask.shape) == 2:
+ bs = 1
+ else:
+ bs = mask.shape[0]
+ mask = mask.reshape(bs, -1, mask.shape[-2], mask.shape[-1]).expand(b, heads, -1, -1).reshape(-1, mask.shape[-2], mask.shape[-1])
+
# print("steps", steps, mem_required, mem_free_total, modifier, q.element_size(), tensor_size)
first_op_done = False
cleared_cache = False
diff --git a/comfy/model_base.py b/comfy/model_base.py
index aafb88e0..421f271b 100644
--- a/comfy/model_base.py
+++ b/comfy/model_base.py
@@ -1,5 +1,7 @@
import torch
from comfy.ldm.modules.diffusionmodules.openaimodel import UNetModel, Timestep
+from comfy.ldm.cascade.stage_c import StageC
+from comfy.ldm.cascade.stage_b import StageB
from comfy.ldm.modules.encoders.noise_aug_modules import CLIPEmbeddingNoiseAugmentation
from comfy.ldm.modules.diffusionmodules.upscaling import ImageConcatWithNoiseAugmentation
import comfy.model_management
@@ -12,9 +14,10 @@ class ModelType(Enum):
EPS = 1
V_PREDICTION = 2
V_PREDICTION_EDM = 3
+ STABLE_CASCADE = 4
-from comfy.model_sampling import EPS, V_PREDICTION, ModelSamplingDiscrete, ModelSamplingContinuousEDM
+from comfy.model_sampling import EPS, V_PREDICTION, ModelSamplingDiscrete, ModelSamplingContinuousEDM, StableCascadeSampling
def model_sampling(model_config, model_type):
@@ -27,6 +30,9 @@ def model_sampling(model_config, model_type):
elif model_type == ModelType.V_PREDICTION_EDM:
c = V_PREDICTION
s = ModelSamplingContinuousEDM
+ elif model_type == ModelType.STABLE_CASCADE:
+ c = EPS
+ s = StableCascadeSampling
class ModelSampling(s, c):
pass
@@ -35,7 +41,7 @@ def model_sampling(model_config, model_type):
class BaseModel(torch.nn.Module):
- def __init__(self, model_config, model_type=ModelType.EPS, device=None):
+ def __init__(self, model_config, model_type=ModelType.EPS, device=None, unet_model=UNetModel):
super().__init__()
unet_config = model_config.unet_config
@@ -48,7 +54,7 @@ class BaseModel(torch.nn.Module):
operations = comfy.ops.manual_cast
else:
operations = comfy.ops.disable_weight_init
- self.diffusion_model = UNetModel(**unet_config, device=device, operations=operations)
+ self.diffusion_model = unet_model(**unet_config, device=device, operations=operations)
self.model_type = model_type
self.model_sampling = model_sampling(model_config, model_type)
@@ -427,3 +433,52 @@ class SD_X4Upscaler(BaseModel):
out['c_concat'] = comfy.conds.CONDNoiseShape(image)
out['y'] = comfy.conds.CONDRegular(noise_level)
return out
+
+class StableCascade_C(BaseModel):
+ def __init__(self, model_config, model_type=ModelType.STABLE_CASCADE, device=None):
+ super().__init__(model_config, model_type, device=device, unet_model=StageC)
+ self.diffusion_model.eval().requires_grad_(False)
+
+ def extra_conds(self, **kwargs):
+ out = {}
+ clip_text_pooled = kwargs["pooled_output"]
+ if clip_text_pooled is not None:
+ out['clip_text_pooled'] = comfy.conds.CONDRegular(clip_text_pooled)
+
+ if "unclip_conditioning" in kwargs:
+ embeds = []
+ for unclip_cond in kwargs["unclip_conditioning"]:
+ weight = unclip_cond["strength"]
+ embeds.append(unclip_cond["clip_vision_output"].image_embeds.unsqueeze(0) * weight)
+ clip_img = torch.cat(embeds, dim=1)
+ else:
+ clip_img = torch.zeros((1, 1, 768))
+ out["clip_img"] = comfy.conds.CONDRegular(clip_img)
+ out["sca"] = comfy.conds.CONDRegular(torch.zeros((1,)))
+ out["crp"] = comfy.conds.CONDRegular(torch.zeros((1,)))
+
+ cross_attn = kwargs.get("cross_attn", None)
+ if cross_attn is not None:
+ out['clip_text'] = comfy.conds.CONDCrossAttn(cross_attn)
+ return out
+
+
+class StableCascade_B(BaseModel):
+ def __init__(self, model_config, model_type=ModelType.STABLE_CASCADE, device=None):
+ super().__init__(model_config, model_type, device=device, unet_model=StageB)
+ self.diffusion_model.eval().requires_grad_(False)
+
+ def extra_conds(self, **kwargs):
+ out = {}
+ noise = kwargs.get("noise", None)
+
+ clip_text_pooled = kwargs["pooled_output"]
+ if clip_text_pooled is not None:
+ out['clip'] = comfy.conds.CONDRegular(clip_text_pooled)
+
+ #size of prior doesn't really matter if zeros because it gets resized but I still want it to get batched
+ prior = kwargs.get("stable_cascade_prior", torch.zeros((1, 16, (noise.shape[2] * 4) // 42, (noise.shape[3] * 4) // 42), dtype=noise.dtype, layout=noise.layout, device=noise.device))
+
+ out["effnet"] = comfy.conds.CONDRegular(prior)
+ out["sca"] = comfy.conds.CONDRegular(torch.zeros((1,)))
+ return out
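Aside on the hunk above: when no `stable_cascade_prior` is supplied, Stage B gets a zero prior whose spatial size is derived from the Stage B noise. A minimal sketch of that arithmetic with a hypothetical helper name (Stage B latents are pixels/4, Stage C latents are pixels/compression, 42 by default):

```python
# Sketch only: how the zero Stage C prior ("effnet" cond) for Stage B is sized
# when none is provided. The helper name is illustrative, not ComfyUI API.
import torch

def zero_stage_c_prior(stage_b_noise: torch.Tensor, compression: int = 42) -> torch.Tensor:
    pixel_h = stage_b_noise.shape[2] * 4   # Stage B latent -> pixel height
    pixel_w = stage_b_noise.shape[3] * 4   # Stage B latent -> pixel width
    return torch.zeros((1, 16, pixel_h // compression, pixel_w // compression),
                       dtype=stage_b_noise.dtype, device=stage_b_noise.device)

noise = torch.zeros((1, 4, 256, 256))      # Stage B latent of a 1024x1024 image
print(zero_stage_c_prior(noise).shape)     # torch.Size([1, 16, 24, 24])
```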
diff --git a/comfy/model_detection.py b/comfy/model_detection.py
index ea824c44..8fca6d8c 100644
--- a/comfy/model_detection.py
+++ b/comfy/model_detection.py
@@ -28,9 +28,38 @@ def calculate_transformer_depth(prefix, state_dict_keys, state_dict):
return last_transformer_depth, context_dim, use_linear_in_transformer, time_stack
return None
-def detect_unet_config(state_dict, key_prefix, dtype):
+def detect_unet_config(state_dict, key_prefix):
state_dict_keys = list(state_dict.keys())
+ if '{}clf.1.weight'.format(key_prefix) in state_dict_keys: #stable cascade
+ unet_config = {}
+ text_mapper_name = '{}clip_txt_mapper.weight'.format(key_prefix)
+ if text_mapper_name in state_dict_keys:
+ unet_config['stable_cascade_stage'] = 'c'
+ w = state_dict[text_mapper_name]
+ if w.shape[0] == 1536: #stage c lite
+ unet_config['c_cond'] = 1536
+ unet_config['c_hidden'] = [1536, 1536]
+ unet_config['nhead'] = [24, 24]
+ unet_config['blocks'] = [[4, 12], [12, 4]]
+ elif w.shape[0] == 2048: #stage c full
+ unet_config['c_cond'] = 2048
+ elif '{}clip_mapper.weight'.format(key_prefix) in state_dict_keys:
+ unet_config['stable_cascade_stage'] = 'b'
+ w = state_dict['{}down_blocks.1.0.channelwise.0.weight'.format(key_prefix)]
+ if w.shape[-1] == 640:
+ unet_config['c_hidden'] = [320, 640, 1280, 1280]
+ unet_config['nhead'] = [-1, -1, 20, 20]
+ unet_config['blocks'] = [[2, 6, 28, 6], [6, 28, 6, 2]]
+ unet_config['block_repeat'] = [[1, 1, 1, 1], [3, 3, 2, 2]]
+ elif w.shape[-1] == 576: #stage b lite
+ unet_config['c_hidden'] = [320, 576, 1152, 1152]
+ unet_config['nhead'] = [-1, 9, 18, 18]
+ unet_config['blocks'] = [[2, 4, 14, 4], [4, 14, 4, 2]]
+ unet_config['block_repeat'] = [[1, 1, 1, 1], [2, 2, 2, 2]]
+
+ return unet_config
+
unet_config = {
"use_checkpoint": False,
"image_size": 32,
@@ -45,7 +74,6 @@ def detect_unet_config(state_dict, key_prefix, dtype):
else:
unet_config["adm_in_channels"] = None
- unet_config["dtype"] = dtype
model_channels = state_dict['{}input_blocks.0.0.weight'.format(key_prefix)].shape[0]
in_channels = state_dict['{}input_blocks.0.0.weight'.format(key_prefix)].shape[1]
@@ -159,8 +187,8 @@ def model_config_from_unet_config(unet_config):
print("no match", unet_config)
return None
-def model_config_from_unet(state_dict, unet_key_prefix, dtype, use_base_if_no_match=False):
- unet_config = detect_unet_config(state_dict, unet_key_prefix, dtype)
+def model_config_from_unet(state_dict, unet_key_prefix, use_base_if_no_match=False):
+ unet_config = detect_unet_config(state_dict, unet_key_prefix)
model_config = model_config_from_unet_config(unet_config)
if model_config is None and use_base_if_no_match:
return comfy.supported_models_base.BASE(unet_config)
@@ -206,7 +234,7 @@ def convert_config(unet_config):
return new_config
-def unet_config_from_diffusers_unet(state_dict, dtype):
+def unet_config_from_diffusers_unet(state_dict, dtype=None):
match = {}
transformer_depth = []
@@ -313,8 +341,8 @@ def unet_config_from_diffusers_unet(state_dict, dtype):
return convert_config(unet_config)
return None
-def model_config_from_diffusers_unet(state_dict, dtype):
- unet_config = unet_config_from_diffusers_unet(state_dict, dtype)
+def model_config_from_diffusers_unet(state_dict):
+ unet_config = unet_config_from_diffusers_unet(state_dict)
if unet_config is not None:
return model_config_from_unet_config(unet_config)
return None
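The detection above keys entirely off characteristic parameter names in the state dict. A condensed sketch of just the stage decision (the function name is illustrative):

```python
# Sketch: "clf.1.weight" marks a Stable Cascade UNet; the mapper key then tells
# Stage C ("clip_txt_mapper.weight") apart from Stage B ("clip_mapper.weight").
def cascade_stage(state_dict_keys, key_prefix=""):
    if "{}clf.1.weight".format(key_prefix) not in state_dict_keys:
        return None  # not a Stable Cascade checkpoint
    if "{}clip_txt_mapper.weight".format(key_prefix) in state_dict_keys:
        return "c"
    if "{}clip_mapper.weight".format(key_prefix) in state_dict_keys:
        return "b"
    return None

print(cascade_stage(["clf.1.weight", "clip_txt_mapper.weight"]))  # c
```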
diff --git a/comfy/model_management.py b/comfy/model_management.py
index a8dc91b9..adcc0e8a 100644
--- a/comfy/model_management.py
+++ b/comfy/model_management.py
@@ -487,7 +487,7 @@ def unet_inital_load_device(parameters, dtype):
else:
return cpu_dev
-def unet_dtype(device=None, model_params=0):
+def unet_dtype(device=None, model_params=0, supported_dtypes=[torch.float16, torch.bfloat16, torch.float32]):
if args.bf16_unet:
return torch.bfloat16
if args.fp16_unet:
@@ -497,20 +497,31 @@ def unet_dtype(device=None, model_params=0):
if args.fp8_e5m2_unet:
return torch.float8_e5m2
if should_use_fp16(device=device, model_params=model_params, manual_cast=True):
- return torch.float16
+ if torch.float16 in supported_dtypes:
+ return torch.float16
+ if should_use_bf16(device, model_params=model_params, manual_cast=True):
+ if torch.bfloat16 in supported_dtypes:
+ return torch.bfloat16
return torch.float32
# None means no manual cast
-def unet_manual_cast(weight_dtype, inference_device):
+def unet_manual_cast(weight_dtype, inference_device, supported_dtypes=[torch.float16, torch.bfloat16, torch.float32]):
if weight_dtype == torch.float32:
return None
- fp16_supported = comfy.model_management.should_use_fp16(inference_device, prioritize_performance=False)
+ fp16_supported = should_use_fp16(inference_device, prioritize_performance=False)
if fp16_supported and weight_dtype == torch.float16:
return None
- if fp16_supported:
+ bf16_supported = should_use_bf16(inference_device)
+ if bf16_supported and weight_dtype == torch.bfloat16:
+ return None
+
+ if fp16_supported and torch.float16 in supported_dtypes:
return torch.float16
+
+ elif bf16_supported and torch.bfloat16 in supported_dtypes:
+ return torch.bfloat16
else:
return torch.float32
@@ -684,17 +695,20 @@ def mps_mode():
global cpu_state
return cpu_state == CPUState.MPS
-def is_device_cpu(device):
+def is_device_type(device, type):
if hasattr(device, 'type'):
- if (device.type == 'cpu'):
+ if (device.type == type):
return True
return False
+def is_device_cpu(device):
+ return is_device_type(device, 'cpu')
+
def is_device_mps(device):
- if hasattr(device, 'type'):
- if (device.type == 'mps'):
- return True
- return False
+ return is_device_type(device, 'mps')
+
+def is_device_cuda(device):
+ return is_device_type(device, 'cuda')
def should_use_fp16(device=None, model_params=0, prioritize_performance=True, manual_cast=False):
global directml_enabled
@@ -706,9 +720,9 @@ def should_use_fp16(device=None, model_params=0, prioritize_performance=True, ma
if FORCE_FP16:
return True
- if device is not None: #TODO
+ if device is not None:
if is_device_mps(device):
- return False
+ return True
if FORCE_FP32:
return False
@@ -716,8 +730,11 @@ def should_use_fp16(device=None, model_params=0, prioritize_performance=True, ma
if directml_enabled:
return False
- if cpu_mode() or mps_mode():
- return False #TODO ?
+ if mps_mode():
+ return True
+
+ if cpu_mode():
+ return False
if is_intel_xpu():
return True
@@ -757,6 +774,43 @@ def should_use_fp16(device=None, model_params=0, prioritize_performance=True, ma
return True
+def should_use_bf16(device=None, model_params=0, prioritize_performance=True, manual_cast=False):
+ if device is not None:
+ if is_device_cpu(device): #TODO ? bf16 works on CPU but is extremely slow
+ return False
+
+ if device is not None: #TODO not sure about mps bf16 support
+ if is_device_mps(device):
+ return False
+
+ if FORCE_FP32:
+ return False
+
+ if directml_enabled:
+ return False
+
+ if cpu_mode() or mps_mode():
+ return False
+
+ if is_intel_xpu():
+ return True
+
+ if device is None:
+ device = torch.device("cuda")
+
+ props = torch.cuda.get_device_properties(device)
+ if props.major >= 8:
+ return True
+
+ bf16_works = torch.cuda.is_bf16_supported()
+
+ if bf16_works or manual_cast:
+ free_model_memory = (get_free_memory() * 0.9 - minimum_inference_memory())
+ if (not prioritize_performance) or model_params * 4 > free_model_memory:
+ return True
+
+ return False
+
def soft_empty_cache(force=False):
global cpu_state
if cpu_state == CPUState.MPS:
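The net effect of the `supported_dtypes` plumbing above, stripped of the command line overrides and device probing, is a simple preference order. A minimal sketch (not the real signatures):

```python
# Sketch: fp16 wins if the device can use it AND the model config lists it,
# then bf16 under the same rule, otherwise fp32.
import torch

def pick_unet_dtype(fp16_ok, bf16_ok,
                    supported_dtypes=(torch.float16, torch.bfloat16, torch.float32)):
    if fp16_ok and torch.float16 in supported_dtypes:
        return torch.float16
    if bf16_ok and torch.bfloat16 in supported_dtypes:
        return torch.bfloat16
    return torch.float32

# Stage C only advertises bf16/fp32, so an fp16-capable GPU still gets bf16:
print(pick_unet_dtype(True, True, (torch.bfloat16, torch.float32)))  # torch.bfloat16
```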
diff --git a/comfy/model_sampling.py b/comfy/model_sampling.py
index d5870027..97e91a01 100644
--- a/comfy/model_sampling.py
+++ b/comfy/model_sampling.py
@@ -132,3 +132,56 @@ class ModelSamplingContinuousEDM(torch.nn.Module):
log_sigma_min = math.log(self.sigma_min)
return math.exp((math.log(self.sigma_max) - log_sigma_min) * percent + log_sigma_min)
+
+class StableCascadeSampling(ModelSamplingDiscrete):
+ def __init__(self, model_config=None):
+ super().__init__()
+
+ if model_config is not None:
+ sampling_settings = model_config.sampling_settings
+ else:
+ sampling_settings = {}
+
+ self.set_parameters(sampling_settings.get("shift", 1.0))
+
+ def set_parameters(self, shift=1.0, cosine_s=8e-3):
+ self.shift = shift
+ self.cosine_s = torch.tensor(cosine_s)
+ self._init_alpha_cumprod = torch.cos(self.cosine_s / (1 + self.cosine_s) * torch.pi * 0.5) ** 2
+
+ #This part is just for compatibility with some schedulers in the codebase
+ self.num_timesteps = 10000
+ sigmas = torch.empty((self.num_timesteps), dtype=torch.float32)
+ for x in range(self.num_timesteps):
+ t = (x + 1) / self.num_timesteps
+ sigmas[x] = self.sigma(t)
+
+ self.set_sigmas(sigmas)
+
+ def sigma(self, timestep):
+ alpha_cumprod = (torch.cos((timestep + self.cosine_s) / (1 + self.cosine_s) * torch.pi * 0.5) ** 2 / self._init_alpha_cumprod)
+
+ if self.shift != 1.0:
+ var = alpha_cumprod
+ logSNR = (var/(1-var)).log()
+ logSNR += 2 * torch.log(1.0 / torch.tensor(self.shift))
+ alpha_cumprod = logSNR.sigmoid()
+
+ alpha_cumprod = alpha_cumprod.clamp(0.0001, 0.9999)
+ return ((1 - alpha_cumprod) / alpha_cumprod) ** 0.5
+
+ def timestep(self, sigma):
+ var = 1 / ((sigma * sigma) + 1)
+ var = var.clamp(0, 1.0)
+ s, min_var = self.cosine_s.to(var.device), self._init_alpha_cumprod.to(var.device)
+ t = (((var * min_var) ** 0.5).acos() / (torch.pi * 0.5)) * (1 + s) - s
+ return t
+
+ def percent_to_sigma(self, percent):
+ if percent <= 0.0:
+ return 999999999.9
+ if percent >= 1.0:
+ return 0.0
+
+ percent = 1.0 - percent
+ return self.sigma(torch.tensor(percent))
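StableCascadeSampling above implements a cosine alpha-cumprod schedule with an optional log-SNR shift. The same math as standalone functions, useful as a round-trip sanity check (this mirrors the class, with shift left at 1.0):

```python
# alpha_cumprod(t) = cos((t + s)/(1 + s) * pi/2)^2 / cos(s/(1 + s) * pi/2)^2
# sigma(t) = sqrt((1 - alpha_cumprod) / alpha_cumprod); timestep() inverts it.
import torch

s = torch.tensor(8e-3)
init_alpha = torch.cos(s / (1 + s) * torch.pi * 0.5) ** 2

def sigma(t, shift=1.0):
    a = torch.cos((t + s) / (1 + s) * torch.pi * 0.5) ** 2 / init_alpha
    if shift != 1.0:  # shift rescales the log-SNR by -2*log(shift)
        a = ((a / (1 - a)).log() + 2 * torch.log(1.0 / torch.tensor(shift))).sigmoid()
    a = a.clamp(0.0001, 0.9999)
    return ((1 - a) / a) ** 0.5

def timestep(sig):
    var = (1 / (sig * sig + 1)).clamp(0, 1.0)
    return (((var * init_alpha) ** 0.5).acos() / (torch.pi * 0.5)) * (1 + s) - s

print(timestep(sigma(torch.tensor(0.5))))  # ~0.5: sigma() and timestep() invert each other
```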
diff --git a/comfy/ops.py b/comfy/ops.py
index f674b47f..517688e8 100644
--- a/comfy/ops.py
+++ b/comfy/ops.py
@@ -1,3 +1,21 @@
+"""
+ This file is part of ComfyUI.
+ Copyright (C) 2024 Stability AI
+
+ This program is free software: you can redistribute it and/or modify
+ it under the terms of the GNU General Public License as published by
+ the Free Software Foundation, either version 3 of the License, or
+ (at your option) any later version.
+
+ This program is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU General Public License for more details.
+
+ You should have received a copy of the GNU General Public License
+    along with this program. If not, see <https://www.gnu.org/licenses/>.
+"""
+
import torch
import comfy.model_management
@@ -78,7 +96,11 @@ class disable_weight_init:
return None
def forward_comfy_cast_weights(self, input):
- weight, bias = cast_bias_weight(self, input)
+ if self.weight is not None:
+ weight, bias = cast_bias_weight(self, input)
+ else:
+ weight = None
+ bias = None
return torch.nn.functional.layer_norm(input, self.normalized_shape, weight, bias, self.eps)
def forward(self, *args, **kwargs):
@@ -87,6 +109,28 @@ class disable_weight_init:
else:
return super().forward(*args, **kwargs)
+ class ConvTranspose2d(torch.nn.ConvTranspose2d):
+ comfy_cast_weights = False
+ def reset_parameters(self):
+ return None
+
+ def forward_comfy_cast_weights(self, input, output_size=None):
+ num_spatial_dims = 2
+ output_padding = self._output_padding(
+ input, output_size, self.stride, self.padding, self.kernel_size,
+ num_spatial_dims, self.dilation)
+
+ weight, bias = cast_bias_weight(self, input)
+ return torch.nn.functional.conv_transpose2d(
+ input, weight, bias, self.stride, self.padding,
+ output_padding, self.groups, self.dilation)
+
+ def forward(self, *args, **kwargs):
+ if self.comfy_cast_weights:
+ return self.forward_comfy_cast_weights(*args, **kwargs)
+ else:
+ return super().forward(*args, **kwargs)
+
@classmethod
def conv_nd(s, dims, *args, **kwargs):
if dims == 2:
@@ -112,3 +156,6 @@ class manual_cast(disable_weight_init):
class LayerNorm(disable_weight_init.LayerNorm):
comfy_cast_weights = True
+
+ class ConvTranspose2d(disable_weight_init.ConvTranspose2d):
+ comfy_cast_weights = True
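The new `ConvTranspose2d` follows the same cast-on-forward pattern as the existing ops: weights stay in their storage dtype and are cast per call. A generic sketch of the pattern (not the ComfyUI `cast_bias_weight` implementation):

```python
# Sketch: parameters live in one dtype; forward() casts them to the input's
# dtype/device so compute precision follows the activations.
import torch

class ManualCastLinear(torch.nn.Linear):
    def forward(self, x):
        w = self.weight.to(dtype=x.dtype, device=x.device)
        b = self.bias.to(dtype=x.dtype, device=x.device) if self.bias is not None else None
        return torch.nn.functional.linear(x, w, b)

layer = ManualCastLinear(4, 2).to(torch.bfloat16)   # weights stored in bf16
out = layer(torch.randn(1, 4))                      # fp32 input
print(out.dtype)                                    # torch.float32
```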
diff --git a/comfy/samplers.py b/comfy/samplers.py
index f2ac3c5d..c795f208 100644
--- a/comfy/samplers.py
+++ b/comfy/samplers.py
@@ -652,6 +652,7 @@ def sampler_object(name):
class KSampler:
SCHEDULERS = SCHEDULER_NAMES
SAMPLERS = SAMPLER_NAMES
+ DISCARD_PENULTIMATE_SIGMA_SAMPLERS = set(('dpm_2', 'dpm_2_ancestral', 'uni_pc', 'uni_pc_bh2'))
def __init__(self, model, steps, device, sampler=None, scheduler=None, denoise=None, model_options={}):
self.model = model
@@ -670,7 +671,7 @@ class KSampler:
sigmas = None
discard_penultimate_sigma = False
- if self.sampler in ['dpm_2', 'dpm_2_ancestral', 'uni_pc', 'uni_pc_bh2']:
+ if self.sampler in self.DISCARD_PENULTIMATE_SIGMA_SAMPLERS:
steps += 1
discard_penultimate_sigma = True
diff --git a/comfy/sd.py b/comfy/sd.py
index c15d73fe..7a77bb17 100644
--- a/comfy/sd.py
+++ b/comfy/sd.py
@@ -1,7 +1,11 @@
import torch
+from enum import Enum
from comfy import model_management
from .ldm.models.autoencoder import AutoencoderKL, AutoencodingEngine
+from .ldm.cascade.stage_a import StageA
+from .ldm.cascade.stage_c_coder import StageC_coder
+
import yaml
import comfy.utils
@@ -134,8 +138,11 @@ class CLIP:
tokens = self.tokenize(text)
return self.encode_from_tokens(tokens)
- def load_sd(self, sd):
- return self.cond_stage_model.load_sd(sd)
+ def load_sd(self, sd, full_model=False):
+ if full_model:
+ return self.cond_stage_model.load_state_dict(sd, strict=False)
+ else:
+ return self.cond_stage_model.load_sd(sd)
def get_sd(self):
return self.cond_stage_model.state_dict()
@@ -155,7 +162,10 @@ class VAE:
self.memory_used_encode = lambda shape, dtype: (1767 * shape[2] * shape[3]) * model_management.dtype_size(dtype) #These are for AutoencoderKL and need tweaking (should be lower)
self.memory_used_decode = lambda shape, dtype: (2178 * shape[2] * shape[3] * 64) * model_management.dtype_size(dtype)
self.downscale_ratio = 8
+ self.upscale_ratio = 8
self.latent_channels = 4
+ self.process_input = lambda image: image * 2.0 - 1.0
+ self.process_output = lambda image: torch.clamp((image + 1.0) / 2.0, min=0.0, max=1.0)
if config is None:
if "decoder.mid.block_1.mix_factor" in sd:
@@ -168,6 +178,34 @@ class VAE:
decoder_config={'target': "comfy.ldm.modules.temporal_ae.VideoDecoder", 'params': decoder_config})
elif "taesd_decoder.1.weight" in sd:
self.first_stage_model = comfy.taesd.taesd.TAESD()
+ elif "vquantizer.codebook.weight" in sd: #VQGan: stage a of stable cascade
+ self.first_stage_model = StageA()
+ self.downscale_ratio = 4
+ self.upscale_ratio = 4
+ #TODO
+ #self.memory_used_encode
+ #self.memory_used_decode
+ self.process_input = lambda image: image
+ self.process_output = lambda image: image
+ elif "backbone.1.0.block.0.1.num_batches_tracked" in sd: #effnet: encoder for stage c latent of stable cascade
+ self.first_stage_model = StageC_coder()
+ self.downscale_ratio = 32
+ self.latent_channels = 16
+ new_sd = {}
+ for k in sd:
+ new_sd["encoder.{}".format(k)] = sd[k]
+ sd = new_sd
+ elif "blocks.11.num_batches_tracked" in sd: #previewer: decoder for stage c latent of stable cascade
+ self.first_stage_model = StageC_coder()
+ self.latent_channels = 16
+ new_sd = {}
+ for k in sd:
+ new_sd["previewer.{}".format(k)] = sd[k]
+ sd = new_sd
+ elif "encoder.backbone.1.0.block.0.1.num_batches_tracked" in sd: #combined effnet and previewer for stable cascade
+ self.first_stage_model = StageC_coder()
+ self.downscale_ratio = 32
+ self.latent_channels = 16
else:
#default SD1.x/SD2.x VAE parameters
ddconfig = {'double_z': True, 'z_channels': 4, 'resolution': 256, 'in_channels': 3, 'out_ch': 3, 'ch': 128, 'ch_mult': [1, 2, 4, 4], 'num_res_blocks': 2, 'attn_resolutions': [], 'dropout': 0.0}
@@ -175,6 +213,7 @@ class VAE:
if 'encoder.down.2.downsample.conv.weight' not in sd: #Stable diffusion x4 upscaler VAE
ddconfig['ch_mult'] = [1, 2, 4]
self.downscale_ratio = 4
+ self.upscale_ratio = 4
self.first_stage_model = AutoencoderKL(ddconfig=ddconfig, embed_dim=4)
else:
@@ -200,18 +239,27 @@ class VAE:
self.patcher = comfy.model_patcher.ModelPatcher(self.first_stage_model, load_device=self.device, offload_device=offload_device)
+ def vae_encode_crop_pixels(self, pixels):
+ x = (pixels.shape[1] // self.downscale_ratio) * self.downscale_ratio
+ y = (pixels.shape[2] // self.downscale_ratio) * self.downscale_ratio
+ if pixels.shape[1] != x or pixels.shape[2] != y:
+ x_offset = (pixels.shape[1] % self.downscale_ratio) // 2
+ y_offset = (pixels.shape[2] % self.downscale_ratio) // 2
+ pixels = pixels[:, x_offset:x + x_offset, y_offset:y + y_offset, :]
+ return pixels
+
def decode_tiled_(self, samples, tile_x=64, tile_y=64, overlap = 16):
steps = samples.shape[0] * comfy.utils.get_tiled_scale_steps(samples.shape[3], samples.shape[2], tile_x, tile_y, overlap)
steps += samples.shape[0] * comfy.utils.get_tiled_scale_steps(samples.shape[3], samples.shape[2], tile_x // 2, tile_y * 2, overlap)
steps += samples.shape[0] * comfy.utils.get_tiled_scale_steps(samples.shape[3], samples.shape[2], tile_x * 2, tile_y // 2, overlap)
pbar = comfy.utils.ProgressBar(steps)
- decode_fn = lambda a: (self.first_stage_model.decode(a.to(self.vae_dtype).to(self.device)) + 1.0).float()
- output = torch.clamp((
- (comfy.utils.tiled_scale(samples, decode_fn, tile_x // 2, tile_y * 2, overlap, upscale_amount = self.downscale_ratio, output_device=self.output_device, pbar = pbar) +
- comfy.utils.tiled_scale(samples, decode_fn, tile_x * 2, tile_y // 2, overlap, upscale_amount = self.downscale_ratio, output_device=self.output_device, pbar = pbar) +
- comfy.utils.tiled_scale(samples, decode_fn, tile_x, tile_y, overlap, upscale_amount = self.downscale_ratio, output_device=self.output_device, pbar = pbar))
- / 3.0) / 2.0, min=0.0, max=1.0)
+ decode_fn = lambda a: self.first_stage_model.decode(a.to(self.vae_dtype).to(self.device)).float()
+ output = self.process_output(
+ (comfy.utils.tiled_scale(samples, decode_fn, tile_x // 2, tile_y * 2, overlap, upscale_amount = self.upscale_ratio, output_device=self.output_device, pbar = pbar) +
+ comfy.utils.tiled_scale(samples, decode_fn, tile_x * 2, tile_y // 2, overlap, upscale_amount = self.upscale_ratio, output_device=self.output_device, pbar = pbar) +
+ comfy.utils.tiled_scale(samples, decode_fn, tile_x, tile_y, overlap, upscale_amount = self.upscale_ratio, output_device=self.output_device, pbar = pbar))
+ / 3.0)
return output
def encode_tiled_(self, pixel_samples, tile_x=512, tile_y=512, overlap = 64):
@@ -220,7 +268,7 @@ class VAE:
steps += pixel_samples.shape[0] * comfy.utils.get_tiled_scale_steps(pixel_samples.shape[3], pixel_samples.shape[2], tile_x * 2, tile_y // 2, overlap)
pbar = comfy.utils.ProgressBar(steps)
- encode_fn = lambda a: self.first_stage_model.encode((2. * a - 1.).to(self.vae_dtype).to(self.device)).float()
+ encode_fn = lambda a: self.first_stage_model.encode((self.process_input(a)).to(self.vae_dtype).to(self.device)).float()
samples = comfy.utils.tiled_scale(pixel_samples, encode_fn, tile_x, tile_y, overlap, upscale_amount = (1/self.downscale_ratio), out_channels=self.latent_channels, output_device=self.output_device, pbar=pbar)
samples += comfy.utils.tiled_scale(pixel_samples, encode_fn, tile_x * 2, tile_y // 2, overlap, upscale_amount = (1/self.downscale_ratio), out_channels=self.latent_channels, output_device=self.output_device, pbar=pbar)
samples += comfy.utils.tiled_scale(pixel_samples, encode_fn, tile_x // 2, tile_y * 2, overlap, upscale_amount = (1/self.downscale_ratio), out_channels=self.latent_channels, output_device=self.output_device, pbar=pbar)
@@ -235,10 +283,10 @@ class VAE:
batch_number = int(free_memory / memory_used)
batch_number = max(1, batch_number)
- pixel_samples = torch.empty((samples_in.shape[0], 3, round(samples_in.shape[2] * self.downscale_ratio), round(samples_in.shape[3] * self.downscale_ratio)), device=self.output_device)
+ pixel_samples = torch.empty((samples_in.shape[0], 3, round(samples_in.shape[2] * self.upscale_ratio), round(samples_in.shape[3] * self.upscale_ratio)), device=self.output_device)
for x in range(0, samples_in.shape[0], batch_number):
samples = samples_in[x:x+batch_number].to(self.vae_dtype).to(self.device)
- pixel_samples[x:x+batch_number] = torch.clamp((self.first_stage_model.decode(samples).to(self.output_device).float() + 1.0) / 2.0, min=0.0, max=1.0)
+ pixel_samples[x:x+batch_number] = self.process_output(self.first_stage_model.decode(samples).to(self.output_device).float())
except model_management.OOM_EXCEPTION as e:
print("Warning: Ran out of memory when regular VAE decoding, retrying with tiled VAE decoding.")
pixel_samples = self.decode_tiled_(samples_in)
@@ -252,6 +300,7 @@ class VAE:
return output.movedim(1,-1)
def encode(self, pixel_samples):
+ pixel_samples = self.vae_encode_crop_pixels(pixel_samples)
pixel_samples = pixel_samples.movedim(-1,1)
try:
memory_used = self.memory_used_encode(pixel_samples.shape, self.vae_dtype)
@@ -261,7 +310,7 @@ class VAE:
batch_number = max(1, batch_number)
samples = torch.empty((pixel_samples.shape[0], self.latent_channels, round(pixel_samples.shape[2] // self.downscale_ratio), round(pixel_samples.shape[3] // self.downscale_ratio)), device=self.output_device)
for x in range(0, pixel_samples.shape[0], batch_number):
- pixels_in = (2. * pixel_samples[x:x+batch_number] - 1.).to(self.vae_dtype).to(self.device)
+ pixels_in = self.process_input(pixel_samples[x:x+batch_number]).to(self.vae_dtype).to(self.device)
samples[x:x+batch_number] = self.first_stage_model.encode(pixels_in).to(self.output_device).float()
except model_management.OOM_EXCEPTION as e:
@@ -271,6 +320,7 @@ class VAE:
return samples
def encode_tiled(self, pixel_samples, tile_x=512, tile_y=512, overlap = 64):
+ pixel_samples = self.vae_encode_crop_pixels(pixel_samples)
model_management.load_model_gpu(self.patcher)
pixel_samples = pixel_samples.movedim(-1,1)
samples = self.encode_tiled_(pixel_samples, tile_x=tile_x, tile_y=tile_y, overlap=overlap)
@@ -297,8 +347,11 @@ def load_style_model(ckpt_path):
model.load_state_dict(model_data)
return StyleModel(model)
+class CLIPType(Enum):
+ STABLE_DIFFUSION = 1
+ STABLE_CASCADE = 2
-def load_clip(ckpt_paths, embedding_directory=None):
+def load_clip(ckpt_paths, embedding_directory=None, clip_type=CLIPType.STABLE_DIFFUSION):
clip_data = []
for p in ckpt_paths:
clip_data.append(comfy.utils.load_torch_file(p, safe_load=True))
@@ -314,8 +367,12 @@ def load_clip(ckpt_paths, embedding_directory=None):
clip_target.params = {}
if len(clip_data) == 1:
if "text_model.encoder.layers.30.mlp.fc1.weight" in clip_data[0]:
- clip_target.clip = sdxl_clip.SDXLRefinerClipModel
- clip_target.tokenizer = sdxl_clip.SDXLTokenizer
+ if clip_type == CLIPType.STABLE_CASCADE:
+ clip_target.clip = sdxl_clip.StableCascadeClipModel
+ clip_target.tokenizer = sdxl_clip.StableCascadeTokenizer
+ else:
+ clip_target.clip = sdxl_clip.SDXLRefinerClipModel
+ clip_target.tokenizer = sdxl_clip.SDXLTokenizer
elif "text_model.encoder.layers.22.mlp.fc1.weight" in clip_data[0]:
clip_target.clip = sd2_clip.SD2ClipModel
clip_target.tokenizer = sd2_clip.SD2Tokenizer
@@ -438,15 +495,12 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
clip_target = None
parameters = comfy.utils.calculate_parameters(sd, "model.diffusion_model.")
- unet_dtype = model_management.unet_dtype(model_params=parameters)
load_device = model_management.get_torch_device()
- manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device)
-
- class WeightsLoader(torch.nn.Module):
- pass
- model_config = model_detection.model_config_from_unet(sd, "model.diffusion_model.", unet_dtype)
- model_config.set_manual_cast(manual_cast_dtype)
+ model_config = model_detection.model_config_from_unet(sd, "model.diffusion_model.")
+ unet_dtype = model_management.unet_dtype(model_params=parameters, supported_dtypes=model_config.supported_inference_dtypes)
+ manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device, model_config.supported_inference_dtypes)
+ model_config.set_inference_dtype(unet_dtype, manual_cast_dtype)
if model_config is None:
raise RuntimeError("ERROR: Could not detect model type of: {}".format(ckpt_path))
@@ -467,13 +521,19 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
vae = VAE(sd=vae_sd)
if output_clip:
- w = WeightsLoader()
clip_target = model_config.clip_target()
if clip_target is not None:
- clip = CLIP(clip_target, embedding_directory=embedding_directory)
- w.cond_stage_model = clip.cond_stage_model
- sd = model_config.process_clip_state_dict(sd)
- load_model_weights(w, sd)
+ clip_sd = model_config.process_clip_state_dict(sd)
+ if len(clip_sd) > 0:
+ clip = CLIP(clip_target, embedding_directory=embedding_directory)
+ m, u = clip.load_sd(clip_sd, full_model=True)
+ if len(m) > 0:
+ print("clip missing:", m)
+
+ if len(u) > 0:
+ print("clip unexpected:", u)
+ else:
+ print("no CLIP/text encoder weights in checkpoint, the text encoder model will not be loaded.")
left_over = sd.keys()
if len(left_over) > 0:
@@ -492,16 +552,15 @@ def load_unet_state_dict(sd): #load unet in diffusers format
parameters = comfy.utils.calculate_parameters(sd)
unet_dtype = model_management.unet_dtype(model_params=parameters)
load_device = model_management.get_torch_device()
- manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device)
- if "input_blocks.0.0.weight" in sd: #ldm
- model_config = model_detection.model_config_from_unet(sd, "", unet_dtype)
+ if "input_blocks.0.0.weight" in sd or 'clf.1.weight' in sd: #ldm or stable cascade
+ model_config = model_detection.model_config_from_unet(sd, "")
if model_config is None:
return None
new_sd = sd
else: #diffusers
- model_config = model_detection.model_config_from_diffusers_unet(sd, unet_dtype)
+ model_config = model_detection.model_config_from_diffusers_unet(sd)
if model_config is None:
return None
@@ -513,8 +572,11 @@ def load_unet_state_dict(sd): #load unet in diffusers format
new_sd[diffusers_keys[k]] = sd.pop(k)
else:
print(diffusers_keys[k], k)
+
offload_device = model_management.unet_offload_device()
- model_config.set_manual_cast(manual_cast_dtype)
+ unet_dtype = model_management.unet_dtype(model_params=parameters, supported_dtypes=model_config.supported_inference_dtypes)
+ manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device, model_config.supported_inference_dtypes)
+ model_config.set_inference_dtype(unet_dtype, manual_cast_dtype)
model = model_config.get_model(new_sd, "")
model = model.to(offload_device)
model.load_model_weights(new_sd, "")
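`vae_encode_crop_pixels` was moved onto the VAE (see the hunk above and the matching removal in nodes.py further down) so the crop respects each model's `downscale_ratio` instead of a hard-coded 8. A standalone sketch of the centered crop:

```python
# Sketch: trim (batch, height, width, channels) images to a multiple of the
# VAE downscale ratio, splitting the excess between both sides.
import torch

def crop_to_multiple(pixels, ratio):
    x = (pixels.shape[1] // ratio) * ratio
    y = (pixels.shape[2] // ratio) * ratio
    if pixels.shape[1] != x or pixels.shape[2] != y:
        x_off = (pixels.shape[1] % ratio) // 2
        y_off = (pixels.shape[2] % ratio) // 2
        pixels = pixels[:, x_off:x + x_off, y_off:y + y_off, :]
    return pixels

img = torch.zeros((1, 517, 781, 3))
print(crop_to_multiple(img, 8).shape)   # torch.Size([1, 512, 776, 3])
print(crop_to_multiple(img, 32).shape)  # torch.Size([1, 512, 768, 3]) -- effnet encoder
```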
diff --git a/comfy/sd1_clip.py b/comfy/sd1_clip.py
index 65ea909f..8287ad2e 100644
--- a/comfy/sd1_clip.py
+++ b/comfy/sd1_clip.py
@@ -67,7 +67,7 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
]
def __init__(self, version="openai/clip-vit-large-patch14", device="cpu", max_length=77,
freeze=True, layer="last", layer_idx=None, textmodel_json_config=None, dtype=None, model_class=comfy.clip_model.CLIPTextModel,
- special_tokens={"start": 49406, "end": 49407, "pad": 49407}, layer_norm_hidden_state=True): # clip-vit-base-patch32
+ special_tokens={"start": 49406, "end": 49407, "pad": 49407}, layer_norm_hidden_state=True, enable_attention_masks=False): # clip-vit-base-patch32
super().__init__()
assert layer in self.LAYERS
@@ -88,7 +88,7 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
self.special_tokens = special_tokens
self.text_projection = torch.nn.Parameter(torch.eye(self.transformer.get_input_embeddings().weight.shape[1]))
self.logit_scale = torch.nn.Parameter(torch.tensor(4.6055))
- self.enable_attention_masks = False
+ self.enable_attention_masks = enable_attention_masks
self.layer_norm_hidden_state = layer_norm_hidden_state
if layer == "hidden":
diff --git a/comfy/sdxl_clip.py b/comfy/sdxl_clip.py
index b35056bb..3ce5c7e0 100644
--- a/comfy/sdxl_clip.py
+++ b/comfy/sdxl_clip.py
@@ -64,3 +64,25 @@ class SDXLClipModel(torch.nn.Module):
class SDXLRefinerClipModel(sd1_clip.SD1ClipModel):
def __init__(self, device="cpu", dtype=None):
super().__init__(device=device, dtype=dtype, clip_name="g", clip_model=SDXLClipG)
+
+
+class StableCascadeClipGTokenizer(sd1_clip.SDTokenizer):
+ def __init__(self, tokenizer_path=None, embedding_directory=None):
+ super().__init__(tokenizer_path, pad_with_end=True, embedding_directory=embedding_directory, embedding_size=1280, embedding_key='clip_g')
+
+class StableCascadeTokenizer(sd1_clip.SD1Tokenizer):
+ def __init__(self, embedding_directory=None):
+ super().__init__(embedding_directory=embedding_directory, clip_name="g", tokenizer=StableCascadeClipGTokenizer)
+
+class StableCascadeClipG(sd1_clip.SDClipModel):
+ def __init__(self, device="cpu", max_length=77, freeze=True, layer="hidden", layer_idx=-1, dtype=None):
+ textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_config_bigg.json")
+ super().__init__(device=device, freeze=freeze, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype,
+ special_tokens={"start": 49406, "end": 49407, "pad": 49407}, layer_norm_hidden_state=False, enable_attention_masks=True)
+
+ def load_sd(self, sd):
+ return super().load_sd(sd)
+
+class StableCascadeClipModel(sd1_clip.SD1ClipModel):
+ def __init__(self, device="cpu", dtype=None):
+ super().__init__(device=device, dtype=dtype, clip_name="g", clip_model=StableCascadeClipG)
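With the classes above registered, the `clip_type` switch in `comfy/sd.py` routes a lone ClipG checkpoint to the Stable Cascade text encoder (hidden-layer output, attention masks enabled) instead of the SDXL refiner one. A usage sketch; the path is a placeholder:

```python
# Sketch: loading a ClipG text encoder for Stable Cascade via the new clip_type
# argument (this is what the updated CLIPLoader node does with type="stable_cascade").
import comfy.sd

clip = comfy.sd.load_clip(
    ckpt_paths=["models/clip/clip_g_cascade.safetensors"],  # placeholder path
    embedding_directory=None,
    clip_type=comfy.sd.CLIPType.STABLE_CASCADE,
)
cond = clip.encode("a photo of a cat")
```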
diff --git a/comfy/supported_models.py b/comfy/supported_models.py
index 1d442d4d..5bb98d88 100644
--- a/comfy/supported_models.py
+++ b/comfy/supported_models.py
@@ -40,8 +40,8 @@ class SD15(supported_models_base.BASE):
state_dict['cond_stage_model.transformer.text_model.embeddings.position_ids'] = ids.round()
replace_prefix = {}
- replace_prefix["cond_stage_model."] = "cond_stage_model.clip_l."
- state_dict = utils.state_dict_prefix_replace(state_dict, replace_prefix)
+ replace_prefix["cond_stage_model."] = "clip_l."
+ state_dict = utils.state_dict_prefix_replace(state_dict, replace_prefix, filter_keys=True)
return state_dict
def process_clip_state_dict_for_saving(self, state_dict):
@@ -72,10 +72,10 @@ class SD20(supported_models_base.BASE):
def process_clip_state_dict(self, state_dict):
replace_prefix = {}
- replace_prefix["conditioner.embedders.0.model."] = "cond_stage_model.model." #SD2 in sgm format
- state_dict = utils.state_dict_prefix_replace(state_dict, replace_prefix)
-
- state_dict = utils.transformers_convert(state_dict, "cond_stage_model.model.", "cond_stage_model.clip_h.transformer.text_model.", 24)
+ replace_prefix["conditioner.embedders.0.model."] = "clip_h." #SD2 in sgm format
+ replace_prefix["cond_stage_model.model."] = "clip_h."
+ state_dict = utils.state_dict_prefix_replace(state_dict, replace_prefix, filter_keys=True)
+ state_dict = utils.transformers_convert(state_dict, "clip_h.", "clip_h.transformer.text_model.", 24)
return state_dict
def process_clip_state_dict_for_saving(self, state_dict):
@@ -131,11 +131,10 @@ class SDXLRefiner(supported_models_base.BASE):
def process_clip_state_dict(self, state_dict):
keys_to_replace = {}
replace_prefix = {}
+ replace_prefix["conditioner.embedders.0.model."] = "clip_g."
+ state_dict = utils.state_dict_prefix_replace(state_dict, replace_prefix, filter_keys=True)
- state_dict = utils.transformers_convert(state_dict, "conditioner.embedders.0.model.", "cond_stage_model.clip_g.transformer.text_model.", 32)
- keys_to_replace["conditioner.embedders.0.model.text_projection"] = "cond_stage_model.clip_g.text_projection"
- keys_to_replace["conditioner.embedders.0.model.logit_scale"] = "cond_stage_model.clip_g.logit_scale"
-
+ state_dict = utils.transformers_convert(state_dict, "clip_g.", "clip_g.transformer.text_model.", 32)
state_dict = utils.state_dict_key_replace(state_dict, keys_to_replace)
return state_dict
@@ -179,13 +178,13 @@ class SDXL(supported_models_base.BASE):
keys_to_replace = {}
replace_prefix = {}
- replace_prefix["conditioner.embedders.0.transformer.text_model"] = "cond_stage_model.clip_l.transformer.text_model"
- state_dict = utils.transformers_convert(state_dict, "conditioner.embedders.1.model.", "cond_stage_model.clip_g.transformer.text_model.", 32)
- keys_to_replace["conditioner.embedders.1.model.text_projection"] = "cond_stage_model.clip_g.text_projection"
- keys_to_replace["conditioner.embedders.1.model.text_projection.weight"] = "cond_stage_model.clip_g.text_projection"
- keys_to_replace["conditioner.embedders.1.model.logit_scale"] = "cond_stage_model.clip_g.logit_scale"
+ replace_prefix["conditioner.embedders.0.transformer.text_model"] = "clip_l.transformer.text_model"
+ replace_prefix["conditioner.embedders.1.model."] = "clip_g."
+ state_dict = utils.state_dict_prefix_replace(state_dict, replace_prefix, filter_keys=True)
+
+ state_dict = utils.transformers_convert(state_dict, "clip_g.", "clip_g.transformer.text_model.", 32)
+ keys_to_replace["clip_g.text_projection.weight"] = "clip_g.text_projection"
- state_dict = utils.state_dict_prefix_replace(state_dict, replace_prefix)
state_dict = utils.state_dict_key_replace(state_dict, keys_to_replace)
return state_dict
@@ -306,5 +305,66 @@ class SD_X4Upscaler(SD20):
out = model_base.SD_X4Upscaler(self, device=device)
return out
-models = [Stable_Zero123, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXLRefiner, SDXL, SSD1B, Segmind_Vega, SD_X4Upscaler]
+class Stable_Cascade_C(supported_models_base.BASE):
+ unet_config = {
+ "stable_cascade_stage": 'c',
+ }
+
+ unet_extra_config = {}
+
+ latent_format = latent_formats.SC_Prior
+ supported_inference_dtypes = [torch.bfloat16, torch.float32]
+
+ sampling_settings = {
+ "shift": 2.0,
+ }
+
+ vae_key_prefix = ["vae."]
+ text_encoder_key_prefix = ["text_encoder."]
+ clip_vision_prefix = "clip_l_vision."
+
+ def process_unet_state_dict(self, state_dict):
+ key_list = list(state_dict.keys())
+ for y in ["weight", "bias"]:
+ suffix = "in_proj_{}".format(y)
+ keys = filter(lambda a: a.endswith(suffix), key_list)
+ for k_from in keys:
+ weights = state_dict.pop(k_from)
+ prefix = k_from[:-(len(suffix) + 1)]
+ shape_from = weights.shape[0] // 3
+ for x in range(3):
+ p = ["to_q", "to_k", "to_v"]
+ k_to = "{}.{}.{}".format(prefix, p[x], y)
+ state_dict[k_to] = weights[shape_from*x:shape_from*(x + 1)]
+ return state_dict
+
+ def get_model(self, state_dict, prefix="", device=None):
+ out = model_base.StableCascade_C(self, device=device)
+ return out
+
+ def clip_target(self):
+ return supported_models_base.ClipTarget(sdxl_clip.StableCascadeTokenizer, sdxl_clip.StableCascadeClipModel)
+
+class Stable_Cascade_B(Stable_Cascade_C):
+ unet_config = {
+ "stable_cascade_stage": 'b',
+ }
+
+ unet_extra_config = {}
+
+ latent_format = latent_formats.SC_B
+ supported_inference_dtypes = [torch.float16, torch.bfloat16, torch.float32]
+
+ sampling_settings = {
+ "shift": 1.0,
+ }
+
+ clip_vision_prefix = None
+
+ def get_model(self, state_dict, prefix="", device=None):
+ out = model_base.StableCascade_B(self, device=device)
+ return out
+
+
+models = [Stable_Zero123, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXLRefiner, SDXL, SSD1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B]
models += [SVD_img2vid]
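`process_unet_state_dict` above un-fuses the attention projections: the checkpoint stores a single `in_proj_weight`/`in_proj_bias` per attention block, which is split into three equal chunks and re-keyed as `to_q`/`to_k`/`to_v`. A compact sketch of the same transform:

```python
# Sketch: split fused QKV tensors into separate to_q/to_k/to_v entries.
import torch

def split_in_proj(state_dict):
    for part in ("weight", "bias"):
        suffix = "in_proj_{}".format(part)
        for k in [k for k in state_dict if k.endswith(suffix)]:
            fused = state_dict.pop(k)
            prefix = k[:-(len(suffix) + 1)]
            chunk = fused.shape[0] // 3
            for i, name in enumerate(("to_q", "to_k", "to_v")):
                state_dict["{}.{}.{}".format(prefix, name, part)] = fused[chunk * i:chunk * (i + 1)]
    return state_dict

sd = {"attn.in_proj_weight": torch.randn(3 * 64, 64)}
print(sorted(split_in_proj(sd)))  # ['attn.to_k.weight', 'attn.to_q.weight', 'attn.to_v.weight']
```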
diff --git a/comfy/supported_models_base.py b/comfy/supported_models_base.py
index 58535a9f..4d7e2593 100644
--- a/comfy/supported_models_base.py
+++ b/comfy/supported_models_base.py
@@ -22,13 +22,15 @@ class BASE:
sampling_settings = {}
latent_format = latent_formats.LatentFormat
vae_key_prefix = ["first_stage_model."]
+ text_encoder_key_prefix = ["cond_stage_model."]
+ supported_inference_dtypes = [torch.float16, torch.bfloat16, torch.float32]
manual_cast_dtype = None
@classmethod
def matches(s, unet_config):
for k in s.unet_config:
- if s.unet_config[k] != unet_config[k]:
+ if k not in unet_config or s.unet_config[k] != unet_config[k]:
return False
return True
@@ -54,6 +56,7 @@ class BASE:
return out
def process_clip_state_dict(self, state_dict):
+ state_dict = utils.state_dict_prefix_replace(state_dict, {k: "" for k in self.text_encoder_key_prefix}, filter_keys=True)
return state_dict
def process_unet_state_dict(self, state_dict):
@@ -63,7 +66,7 @@ class BASE:
return state_dict
def process_clip_state_dict_for_saving(self, state_dict):
- replace_prefix = {"": "cond_stage_model."}
+ replace_prefix = {"": self.text_encoder_key_prefix[0]}
return utils.state_dict_prefix_replace(state_dict, replace_prefix)
def process_clip_vision_state_dict_for_saving(self, state_dict):
@@ -77,8 +80,9 @@ class BASE:
return utils.state_dict_prefix_replace(state_dict, replace_prefix)
def process_vae_state_dict_for_saving(self, state_dict):
- replace_prefix = {"": "first_stage_model."}
+ replace_prefix = {"": self.vae_key_prefix[0]}
return utils.state_dict_prefix_replace(state_dict, replace_prefix)
- def set_manual_cast(self, manual_cast_dtype):
+ def set_inference_dtype(self, dtype, manual_cast_dtype):
+ self.unet_config['dtype'] = dtype
self.manual_cast_dtype = manual_cast_dtype
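The new `text_encoder_key_prefix` plus the `filter_keys=True` calls let `process_clip_state_dict` pull just the text encoder weights out of a combined checkpoint. A sketch of the assumed behaviour of that filtering (the real `utils.state_dict_prefix_replace` is not shown in this patch, so treat this as an approximation):

```python
# Assumed semantics: keep only keys that start with one of the prefixes and
# strip that prefix; everything else is dropped.
def strip_prefixes(state_dict, prefixes):
    out = {}
    for k, v in state_dict.items():
        for p in prefixes:
            if k.startswith(p):
                out[k[len(p):]] = v
                break
    return out

sd = {"cond_stage_model.clip_l.text_projection": 1, "model.diffusion_model.x": 2}
print(strip_prefixes(sd, ["cond_stage_model."]))  # {'clip_l.text_projection': 1}
```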
diff --git a/comfy/utils.py b/comfy/utils.py
index 1113bf0f..04cf76ed 100644
--- a/comfy/utils.py
+++ b/comfy/utils.py
@@ -169,6 +169,8 @@ UNET_MAP_BASIC = {
}
def unet_to_diffusers(unet_config):
+ if "num_res_blocks" not in unet_config:
+ return {}
num_res_blocks = unet_config["num_res_blocks"]
channel_mult = unet_config["channel_mult"]
transformer_depth = unet_config["transformer_depth"][:]
diff --git a/comfy_extras/nodes_images.py b/comfy_extras/nodes_images.py
index aa80f526..8f638bf8 100644
--- a/comfy_extras/nodes_images.py
+++ b/comfy_extras/nodes_images.py
@@ -48,6 +48,25 @@ class RepeatImageBatch:
s = image.repeat((amount, 1,1,1))
return (s,)
+class ImageFromBatch:
+ @classmethod
+ def INPUT_TYPES(s):
+ return {"required": { "image": ("IMAGE",),
+ "batch_index": ("INT", {"default": 0, "min": 0, "max": 63}),
+ "length": ("INT", {"default": 1, "min": 1, "max": 64}),
+ }}
+ RETURN_TYPES = ("IMAGE",)
+ FUNCTION = "frombatch"
+
+ CATEGORY = "image/batch"
+
+ def frombatch(self, image, batch_index, length):
+ s_in = image
+ batch_index = min(s_in.shape[0] - 1, batch_index)
+ length = min(s_in.shape[0] - batch_index, length)
+ s = s_in[batch_index:batch_index + length].clone()
+ return (s,)
+
class SaveAnimatedWEBP:
def __init__(self):
self.output_dir = folder_paths.get_output_directory()
@@ -170,6 +189,7 @@ class SaveAnimatedPNG:
NODE_CLASS_MAPPINGS = {
"ImageCrop": ImageCrop,
"RepeatImageBatch": RepeatImageBatch,
+ "ImageFromBatch": ImageFromBatch,
"SaveAnimatedWEBP": SaveAnimatedWEBP,
"SaveAnimatedPNG": SaveAnimatedPNG,
}
diff --git a/comfy_extras/nodes_model_advanced.py b/comfy_extras/nodes_model_advanced.py
index 541ce8fa..1b3f3945 100644
--- a/comfy_extras/nodes_model_advanced.py
+++ b/comfy_extras/nodes_model_advanced.py
@@ -17,6 +17,10 @@ class LCM(comfy.model_sampling.EPS):
return c_out * x0 + c_skip * model_input
+class X0(comfy.model_sampling.EPS):
+ def calculate_denoised(self, sigma, model_output, model_input):
+ return model_output
+
class ModelSamplingDiscreteDistilled(comfy.model_sampling.ModelSamplingDiscrete):
original_timesteps = 50
@@ -68,7 +72,7 @@ class ModelSamplingDiscrete:
@classmethod
def INPUT_TYPES(s):
return {"required": { "model": ("MODEL",),
- "sampling": (["eps", "v_prediction", "lcm"],),
+ "sampling": (["eps", "v_prediction", "lcm", "x0"],),
"zsnr": ("BOOLEAN", {"default": False}),
}}
@@ -88,6 +92,8 @@ class ModelSamplingDiscrete:
elif sampling == "lcm":
sampling_type = LCM
sampling_base = ModelSamplingDiscreteDistilled
+ elif sampling == "x0":
+ sampling_type = X0
class ModelSamplingAdvanced(sampling_base, sampling_type):
pass
@@ -99,6 +105,32 @@ class ModelSamplingDiscrete:
m.add_object_patch("model_sampling", model_sampling)
return (m, )
+class ModelSamplingStableCascade:
+ @classmethod
+ def INPUT_TYPES(s):
+ return {"required": { "model": ("MODEL",),
+ "shift": ("FLOAT", {"default": 2.0, "min": 0.0, "max": 100.0, "step":0.01}),
+ }}
+
+ RETURN_TYPES = ("MODEL",)
+ FUNCTION = "patch"
+
+ CATEGORY = "advanced/model"
+
+ def patch(self, model, shift):
+ m = model.clone()
+
+ sampling_base = comfy.model_sampling.StableCascadeSampling
+ sampling_type = comfy.model_sampling.EPS
+
+ class ModelSamplingAdvanced(sampling_base, sampling_type):
+ pass
+
+ model_sampling = ModelSamplingAdvanced(model.model.model_config)
+ model_sampling.set_parameters(shift)
+ m.add_object_patch("model_sampling", model_sampling)
+ return (m, )
+
class ModelSamplingContinuousEDM:
@classmethod
def INPUT_TYPES(s):
@@ -171,5 +203,6 @@ class RescaleCFG:
NODE_CLASS_MAPPINGS = {
"ModelSamplingDiscrete": ModelSamplingDiscrete,
"ModelSamplingContinuousEDM": ModelSamplingContinuousEDM,
+ "ModelSamplingStableCascade": ModelSamplingStableCascade,
"RescaleCFG": RescaleCFG,
}
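For contrast with the `X0` class above, written as plain functions: an eps model predicts the added noise, so the denoised image is recovered as `x0 = x_t - sigma * eps`, while an x0 model returns the denoised image directly (the eps formula here is the standard one from `ModelSamplingDiscrete`'s `EPS`, which is not shown in this hunk):

```python
# Sketch: the two prediction conventions selectable via the "sampling" widget.
def denoised_from_eps(sigma, model_output, model_input):
    return model_input - sigma * model_output  # model_output is predicted noise

def denoised_from_x0(sigma, model_output, model_input):
    return model_output                        # model_output already is the denoised image
```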
diff --git a/comfy_extras/nodes_stable_cascade.py b/comfy_extras/nodes_stable_cascade.py
new file mode 100644
index 00000000..b795d008
--- /dev/null
+++ b/comfy_extras/nodes_stable_cascade.py
@@ -0,0 +1,109 @@
+"""
+ This file is part of ComfyUI.
+ Copyright (C) 2024 Stability AI
+
+ This program is free software: you can redistribute it and/or modify
+ it under the terms of the GNU General Public License as published by
+ the Free Software Foundation, either version 3 of the License, or
+ (at your option) any later version.
+
+ This program is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU General Public License for more details.
+
+ You should have received a copy of the GNU General Public License
+    along with this program. If not, see <https://www.gnu.org/licenses/>.
+"""
+
+import torch
+import nodes
+import comfy.utils
+
+
+class StableCascade_EmptyLatentImage:
+ def __init__(self, device="cpu"):
+ self.device = device
+
+ @classmethod
+ def INPUT_TYPES(s):
+ return {"required": {
+ "width": ("INT", {"default": 1024, "min": 256, "max": nodes.MAX_RESOLUTION, "step": 8}),
+ "height": ("INT", {"default": 1024, "min": 256, "max": nodes.MAX_RESOLUTION, "step": 8}),
+ "compression": ("INT", {"default": 42, "min": 4, "max": 128, "step": 1}),
+ "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096})
+ }}
+ RETURN_TYPES = ("LATENT", "LATENT")
+ RETURN_NAMES = ("stage_c", "stage_b")
+ FUNCTION = "generate"
+
+ CATEGORY = "_for_testing/stable_cascade"
+
+ def generate(self, width, height, compression, batch_size=1):
+ c_latent = torch.zeros([batch_size, 16, height // compression, width // compression])
+ b_latent = torch.zeros([batch_size, 4, height // 4, width // 4])
+ return ({
+ "samples": c_latent,
+ }, {
+ "samples": b_latent,
+ })
+
+class StableCascade_StageC_VAEEncode:
+ def __init__(self, device="cpu"):
+ self.device = device
+
+ @classmethod
+ def INPUT_TYPES(s):
+ return {"required": {
+ "image": ("IMAGE",),
+ "vae": ("VAE", ),
+ "compression": ("INT", {"default": 42, "min": 4, "max": 128, "step": 1}),
+ }}
+ RETURN_TYPES = ("LATENT", "LATENT")
+ RETURN_NAMES = ("stage_c", "stage_b")
+ FUNCTION = "generate"
+
+ CATEGORY = "_for_testing/stable_cascade"
+
+ def generate(self, image, vae, compression):
+ width = image.shape[-2]
+ height = image.shape[-3]
+ out_width = (width // compression) * vae.downscale_ratio
+ out_height = (height // compression) * vae.downscale_ratio
+
+ s = comfy.utils.common_upscale(image.movedim(-1,1), out_width, out_height, "bicubic", "center").movedim(1,-1)
+
+ c_latent = vae.encode(s[:,:,:,:3])
+ b_latent = torch.zeros([c_latent.shape[0], 4, height // 4, width // 4])
+ return ({
+ "samples": c_latent,
+ }, {
+ "samples": b_latent,
+ })
+
+class StableCascade_StageB_Conditioning:
+ @classmethod
+ def INPUT_TYPES(s):
+ return {"required": { "conditioning": ("CONDITIONING",),
+ "stage_c": ("LATENT",),
+ }}
+ RETURN_TYPES = ("CONDITIONING",)
+
+ FUNCTION = "set_prior"
+
+ CATEGORY = "_for_testing/stable_cascade"
+
+ def set_prior(self, conditioning, stage_c):
+ c = []
+ for t in conditioning:
+ d = t[1].copy()
+ d['stable_cascade_prior'] = stage_c['samples']
+ n = [t[0], d]
+ c.append(n)
+ return (c, )
+
+NODE_CLASS_MAPPINGS = {
+ "StableCascade_EmptyLatentImage": StableCascade_EmptyLatentImage,
+ "StableCascade_StageB_Conditioning": StableCascade_StageB_Conditioning,
+ "StableCascade_StageC_VAEEncode": StableCascade_StageC_VAEEncode,
+}
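Worked example of the shapes produced by `StableCascade_EmptyLatentImage` above: Stage C latents have 16 channels at `pixels // compression`, Stage B latents have 4 channels at `pixels // 4`.

```python
# 1024x1024 at the default compression of 42:
width, height, compression = 1024, 1024, 42
stage_c = (16, height // compression, width // compression)
stage_b = (4, height // 4, width // 4)
print(stage_c)  # (16, 24, 24)
print(stage_b)  # (4, 256, 256)
```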
diff --git a/custom_nodes/websocket_image_save.py.disabled b/custom_nodes/websocket_image_save.py.disabled
new file mode 100644
index 00000000..b85a5de8
--- /dev/null
+++ b/custom_nodes/websocket_image_save.py.disabled
@@ -0,0 +1,49 @@
+from PIL import Image, ImageOps
+from io import BytesIO
+import numpy as np
+import struct
+import comfy.utils
+import time
+
+#You can use this node to save full size images through the websocket. The
+#images will be sent in exactly the same format as the image previews: as
+#binary images on the websocket with an 8 byte header indicating the type
+#of binary message (first 4 bytes) and the image format (next 4 bytes).
+
+#The reason this node is disabled by default is because there is a small
+#issue when using it with the default ComfyUI web interface: When generating
+#batches only the last image will be shown in the UI.
+
+#Note that no metadata will be put in the images saved with this node.
+
+class SaveImageWebsocket:
+ @classmethod
+ def INPUT_TYPES(s):
+ return {"required":
+ {"images": ("IMAGE", ),}
+ }
+
+ RETURN_TYPES = ()
+ FUNCTION = "save_images"
+
+ OUTPUT_NODE = True
+
+ CATEGORY = "image"
+
+ def save_images(self, images):
+ pbar = comfy.utils.ProgressBar(images.shape[0])
+ step = 0
+ for image in images:
+ i = 255. * image.cpu().numpy()
+ img = Image.fromarray(np.clip(i, 0, 255).astype(np.uint8))
+ pbar.update_absolute(step, images.shape[0], ("PNG", img, None))
+ step += 1
+
+ return {}
+
+ def IS_CHANGED(s, images):
+ return time.time()
+
+NODE_CLASS_MAPPINGS = {
+ "SaveImageWebsocket": SaveImageWebsocket,
+}
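The comments above describe the binary frame layout: an 8 byte header (4 bytes message type, 4 bytes image format) followed by the encoded image. A client-side parsing sketch, assuming big-endian 32-bit integers; the actual type/format codes come from the server and are not shown in this patch:

```python
# Sketch of decoding one binary websocket message from SaveImageWebsocket.
import struct

def parse_binary_preview(message: bytes):
    event_type, image_format = struct.unpack(">II", message[:8])  # assumed big-endian
    return event_type, image_format, message[8:]                  # rest: the image bytes

header = struct.pack(">II", 1, 2)
print(parse_binary_preview(header + b"\x89PNG...")[:2])  # (1, 2)
```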
diff --git a/execution.py b/execution.py
index 00908ead..3e9d53b0 100644
--- a/execution.py
+++ b/execution.py
@@ -194,8 +194,12 @@ def recursive_execute(server, prompt, outputs, current_item, extra_data, execute
return (True, None, None)
-def recursive_will_execute(prompt, outputs, current_item):
+def recursive_will_execute(prompt, outputs, current_item, memo={}):
unique_id = current_item
+
+ if unique_id in memo:
+ return memo[unique_id]
+
inputs = prompt[unique_id]['inputs']
will_execute = []
if unique_id in outputs:
@@ -207,9 +211,10 @@ def recursive_will_execute(prompt, outputs, current_item):
input_unique_id = input_data[0]
output_index = input_data[1]
if input_unique_id not in outputs:
- will_execute += recursive_will_execute(prompt, outputs, input_unique_id)
+ will_execute += recursive_will_execute(prompt, outputs, input_unique_id, memo)
- return will_execute + [unique_id]
+ memo[unique_id] = will_execute + [unique_id]
+ return memo[unique_id]
def recursive_output_delete_if_changed(prompt, old_prompt, outputs, current_item):
unique_id = current_item
@@ -377,7 +382,8 @@ class PromptExecutor:
while len(to_execute) > 0:
#always execute the output that depends on the least amount of unexecuted nodes first
- to_execute = sorted(list(map(lambda a: (len(recursive_will_execute(prompt, self.outputs, a[-1])), a[-1]), to_execute)))
+ memo = {}
+ to_execute = sorted(list(map(lambda a: (len(recursive_will_execute(prompt, self.outputs, a[-1], memo)), a[-1]), to_execute)))
output_node_id = to_execute.pop(0)[-1]
# This call shouldn't raise anything if there's an error deep in
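The memo added to `recursive_will_execute` avoids re-walking shared upstream nodes once per output, which otherwise blows up on graphs where many outputs share the same ancestors. A generic sketch of the same memoized traversal:

```python
# Sketch: depth-first "what will execute" walk with a per-call memo dict.
def will_execute(graph, node, memo=None):
    memo = {} if memo is None else memo
    if node in memo:
        return memo[node]
    result = []
    for dep in graph.get(node, []):
        result += will_execute(graph, dep, memo)
    memo[node] = result + [node]
    return memo[node]

graph = {"out": ["a", "b"], "a": ["shared"], "b": ["shared"], "shared": []}
print(will_execute(graph, "out"))  # ['shared', 'a', 'shared', 'b', 'out']
```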
diff --git a/nodes.py b/nodes.py
index d9bc4884..a577c212 100644
--- a/nodes.py
+++ b/nodes.py
@@ -309,18 +309,7 @@ class VAEEncode:
CATEGORY = "latent"
- @staticmethod
- def vae_encode_crop_pixels(pixels):
- x = (pixels.shape[1] // 8) * 8
- y = (pixels.shape[2] // 8) * 8
- if pixels.shape[1] != x or pixels.shape[2] != y:
- x_offset = (pixels.shape[1] % 8) // 2
- y_offset = (pixels.shape[2] % 8) // 2
- pixels = pixels[:, x_offset:x + x_offset, y_offset:y + y_offset, :]
- return pixels
-
def encode(self, vae, pixels):
- pixels = self.vae_encode_crop_pixels(pixels)
t = vae.encode(pixels[:,:,:,:3])
return ({"samples":t}, )
@@ -336,7 +325,6 @@ class VAEEncodeTiled:
CATEGORY = "_for_testing"
def encode(self, vae, pixels, tile_size):
- pixels = VAEEncode.vae_encode_crop_pixels(pixels)
t = vae.encode_tiled(pixels[:,:,:,:3], tile_x=tile_size, tile_y=tile_size, )
return ({"samples":t}, )
@@ -350,14 +338,14 @@ class VAEEncodeForInpaint:
CATEGORY = "latent/inpaint"
def encode(self, vae, pixels, mask, grow_mask_by=6):
- x = (pixels.shape[1] // 8) * 8
- y = (pixels.shape[2] // 8) * 8
+ x = (pixels.shape[1] // vae.downscale_ratio) * vae.downscale_ratio
+ y = (pixels.shape[2] // vae.downscale_ratio) * vae.downscale_ratio
mask = torch.nn.functional.interpolate(mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1])), size=(pixels.shape[1], pixels.shape[2]), mode="bilinear")
pixels = pixels.clone()
if pixels.shape[1] != x or pixels.shape[2] != y:
- x_offset = (pixels.shape[1] % 8) // 2
- y_offset = (pixels.shape[2] % 8) // 2
+ x_offset = (pixels.shape[1] % vae.downscale_ratio) // 2
+ y_offset = (pixels.shape[2] % vae.downscale_ratio) // 2
pixels = pixels[:,x_offset:x + x_offset, y_offset:y + y_offset,:]
mask = mask[:,:,x_offset:x + x_offset, y_offset:y + y_offset]
@@ -854,15 +842,20 @@ class CLIPLoader:
@classmethod
def INPUT_TYPES(s):
return {"required": { "clip_name": (folder_paths.get_filename_list("clip"), ),
+ "type": (["stable_diffusion", "stable_cascade"], ),
}}
RETURN_TYPES = ("CLIP",)
FUNCTION = "load_clip"
CATEGORY = "advanced/loaders"
- def load_clip(self, clip_name):
+ def load_clip(self, clip_name, type="stable_diffusion"):
+ clip_type = comfy.sd.CLIPType.STABLE_DIFFUSION
+ if type == "stable_cascade":
+ clip_type = comfy.sd.CLIPType.STABLE_CASCADE
+
clip_path = folder_paths.get_full_path("clip", clip_name)
- clip = comfy.sd.load_clip(ckpt_paths=[clip_path], embedding_directory=folder_paths.get_folder_paths("embeddings"))
+ clip = comfy.sd.load_clip(ckpt_paths=[clip_path], embedding_directory=folder_paths.get_folder_paths("embeddings"), clip_type=clip_type)
return (clip,)
class DualCLIPLoader:
@@ -1434,7 +1427,7 @@ class SaveImage:
filename_prefix += self.prefix_append
full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, self.output_dir, images[0].shape[1], images[0].shape[0])
results = list()
- for image in images:
+ for (batch_number, image) in enumerate(images):
i = 255. * image.cpu().numpy()
img = Image.fromarray(np.clip(i, 0, 255).astype(np.uint8))
metadata = None
@@ -1446,7 +1439,8 @@ class SaveImage:
for x in extra_pnginfo:
metadata.add_text(x, json.dumps(extra_pnginfo[x]))
- file = f"{filename}_{counter:05}_.png"
+ filename_with_batch_num = filename.replace("%batch_num%", str(batch_number))
+ file = f"{filename_with_batch_num}_{counter:05}_.png"
img.save(os.path.join(full_output_folder, file), pnginfo=metadata, compress_level=self.compress_level)
results.append({
"filename": file,
@@ -1966,6 +1960,7 @@ def init_custom_nodes():
"nodes_sdupscale.py",
"nodes_photomaker.py",
"nodes_cond.py",
+ "nodes_stable_cascade.py",
]
for node_file in extras_files:
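Usage note for the SaveImage change above: a `%batch_num%` placeholder in the filename prefix is replaced with each image's index within the batch before the counter is appended.

```python
# Sketch of the substitution performed per image:
filename, batch_number, counter = "ComfyUI_%batch_num%", 2, 7
print(f"{filename.replace('%batch_num%', str(batch_number))}_{counter:05}_.png")
# ComfyUI_2_00007_.png
```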
diff --git a/web/extensions/core/widgetInputs.js b/web/extensions/core/widgetInputs.js
index b12ad968..23f51d81 100644
--- a/web/extensions/core/widgetInputs.js
+++ b/web/extensions/core/widgetInputs.js
@@ -22,6 +22,7 @@ function isConvertableWidget(widget, config) {
}
function hideWidget(node, widget, suffix = "") {
+ if (widget.type?.startsWith(CONVERTED_TYPE)) return;
widget.origType = widget.type;
widget.origComputeSize = widget.computeSize;
widget.origSerializeValue = widget.serializeValue;
diff --git a/web/scripts/pnginfo.js b/web/scripts/pnginfo.js
index 83a4ebc8..16960920 100644
--- a/web/scripts/pnginfo.js
+++ b/web/scripts/pnginfo.js
@@ -24,7 +24,7 @@ export function getPngMetadata(file) {
const length = dataView.getUint32(offset);
// Get the chunk type
const type = String.fromCharCode(...pngData.slice(offset + 4, offset + 8));
- if (type === "tEXt" || type == "comf") {
+ if (type === "tEXt" || type == "comf" || type === "iTXt") {
// Get the keyword
let keyword_end = offset + 8;
while (pngData[keyword_end] !== 0) {
@@ -33,7 +33,7 @@ export function getPngMetadata(file) {
const keyword = String.fromCharCode(...pngData.slice(offset + 8, keyword_end));
// Get the text
const contentArraySegment = pngData.slice(keyword_end + 1, offset + 8 + length);
- const contentJson = Array.from(contentArraySegment).map(s=>String.fromCharCode(s)).join('')
+ const contentJson = new TextDecoder("utf-8").decode(contentArraySegment);
txt_chunks[keyword] = contentJson;
}
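For reference, the same chunk walk as the updated `getPngMetadata` in Python: tEXt, iTXt and the custom "comf" chunks are collected, and everything after the keyword's NUL terminator is decoded as UTF-8, mirroring the simplified handling in the JS above (no CRC check; iTXt's extra header fields are left in the decoded text, just as the JS leaves them):

```python
# Sketch: extract text chunks from a PNG byte string.
import struct

def png_text_chunks(data: bytes):
    chunks, offset = {}, 8  # skip the 8-byte PNG signature
    while offset + 8 <= len(data):
        length, ctype = struct.unpack(">I4s", data[offset:offset + 8])
        if ctype in (b"tEXt", b"iTXt", b"comf"):
            body = data[offset + 8:offset + 8 + length]
            keyword, _, text = body.partition(b"\x00")
            chunks[keyword.decode("latin-1")] = text.decode("utf-8", errors="replace")
        offset += 12 + length  # 4 length + 4 type + data + 4 CRC
    return chunks
```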
diff --git a/web/style.css b/web/style.css
index 863840b2..cf7a8b9e 100644
--- a/web/style.css
+++ b/web/style.css
@@ -197,6 +197,7 @@ button.comfy-close-menu-btn {
.comfy-modal button:hover,
.comfy-menu-actions button:hover {
filter: brightness(1.2);
+ will-change: transform;
cursor: pointer;
}
@@ -462,11 +463,13 @@ dialog::backdrop {
z-index: 9999 !important;
background-color: var(--comfy-menu-bg) !important;
filter: brightness(95%);
+ will-change: transform;
}
.litegraph.litecontextmenu .litemenu-entry:hover:not(.disabled):not(.separator) {
background-color: var(--comfy-menu-bg) !important;
filter: brightness(155%);
+ will-change: transform;
color: var(--input-text);
}
@@ -527,12 +530,14 @@ dialog::backdrop {
color: var(--input-text);
background-color: var(--comfy-input-bg);
filter: brightness(80%);
+ will-change: transform;
padding-left: 0.2em;
}
.litegraph.lite-search-item.generic_type {
color: var(--input-text);
filter: brightness(50%);
+ will-change: transform;
}
@media only screen and (max-width: 450px) {
@@ -551,4 +556,4 @@ dialog::backdrop {
text-align: center;
border-top: none;
}
-}
\ No newline at end of file
+}