diff --git a/README.md b/README.md index ff3ab642..a94a212a 100644 --- a/README.md +++ b/README.md @@ -11,7 +11,7 @@ This ui will let you design and execute advanced stable diffusion pipelines usin ## Features - Nodes/graph/flowchart interface to experiment and create complex Stable Diffusion workflows without needing to code anything. -- Fully supports SD1.x, SD2.x, [SDXL](https://comfyanonymous.github.io/ComfyUI_examples/sdxl/) and [Stable Video Diffusion](https://comfyanonymous.github.io/ComfyUI_examples/video/) +- Fully supports SD1.x, SD2.x, [SDXL](https://comfyanonymous.github.io/ComfyUI_examples/sdxl/), [Stable Video Diffusion](https://comfyanonymous.github.io/ComfyUI_examples/video/) and [Stable Cascade](https://comfyanonymous.github.io/ComfyUI_examples/stable_cascade/) - Asynchronous Queue system - Many optimizations: Only re-executes the parts of the workflow that changes between executions. - Command line option: ```--lowvram``` to make it work on GPUs with less than 3GB vram (enabled automatically on GPUs with low vram) diff --git a/comfy/clip_model.py b/comfy/clip_model.py index 09e7bbca..9b82a246 100644 --- a/comfy/clip_model.py +++ b/comfy/clip_model.py @@ -97,7 +97,7 @@ class CLIPTextModel_(torch.nn.Module): x = self.embeddings(input_tokens) mask = None if attention_mask is not None: - mask = 1.0 - attention_mask.to(x.dtype).unsqueeze(1).unsqueeze(1).expand(attention_mask.shape[0], 1, attention_mask.shape[-1], attention_mask.shape[-1]) + mask = 1.0 - attention_mask.to(x.dtype).reshape((attention_mask.shape[0], 1, -1, attention_mask.shape[-1])).expand(attention_mask.shape[0], 1, attention_mask.shape[-1], attention_mask.shape[-1]) mask = mask.masked_fill(mask.to(torch.bool), float("-inf")) causal_mask = torch.empty(x.shape[1], x.shape[1], dtype=x.dtype, device=x.device).fill_(float("-inf")).triu_(1) diff --git a/comfy/controlnet.py b/comfy/controlnet.py index d9d990a7..41619758 100644 --- a/comfy/controlnet.py +++ b/comfy/controlnet.py @@ -318,9 +318,10 @@ def load_controlnet(ckpt_path, model=None): return ControlLora(controlnet_data) controlnet_config = None + supported_inference_dtypes = None + if "controlnet_cond_embedding.conv_in.weight" in controlnet_data: #diffusers format - unet_dtype = comfy.model_management.unet_dtype() - controlnet_config = comfy.model_detection.unet_config_from_diffusers_unet(controlnet_data, unet_dtype) + controlnet_config = comfy.model_detection.unet_config_from_diffusers_unet(controlnet_data) diffusers_keys = comfy.utils.unet_to_diffusers(controlnet_config) diffusers_keys["controlnet_mid_block.weight"] = "middle_block_out.0.weight" diffusers_keys["controlnet_mid_block.bias"] = "middle_block_out.0.bias" @@ -380,12 +381,20 @@ def load_controlnet(ckpt_path, model=None): return net if controlnet_config is None: - unet_dtype = comfy.model_management.unet_dtype() - controlnet_config = comfy.model_detection.model_config_from_unet(controlnet_data, prefix, unet_dtype, True).unet_config + model_config = comfy.model_detection.model_config_from_unet(controlnet_data, prefix, True) + supported_inference_dtypes = model_config.supported_inference_dtypes + controlnet_config = model_config.unet_config + load_device = comfy.model_management.get_torch_device() + if supported_inference_dtypes is None: + unet_dtype = comfy.model_management.unet_dtype() + else: + unet_dtype = comfy.model_management.unet_dtype(supported_dtypes=supported_inference_dtypes) + manual_cast_dtype = comfy.model_management.unet_manual_cast(unet_dtype, load_device) if manual_cast_dtype is not None: controlnet_config["operations"] = comfy.ops.manual_cast + controlnet_config["dtype"] = unet_dtype controlnet_config.pop("out_channels") controlnet_config["hint_channels"] = controlnet_data["{}input_hint_block.0.weight".format(prefix)].shape[1] control_model = comfy.cldm.cldm.ControlNet(**controlnet_config) diff --git a/comfy/gligen.py b/comfy/gligen.py index 71892dfb..59252276 100644 --- a/comfy/gligen.py +++ b/comfy/gligen.py @@ -2,7 +2,8 @@ import torch from torch import nn from .ldm.modules.attention import CrossAttention from inspect import isfunction - +import comfy.ops +ops = comfy.ops.manual_cast def exists(val): return val is not None @@ -22,7 +23,7 @@ def default(val, d): class GEGLU(nn.Module): def __init__(self, dim_in, dim_out): super().__init__() - self.proj = nn.Linear(dim_in, dim_out * 2) + self.proj = ops.Linear(dim_in, dim_out * 2) def forward(self, x): x, gate = self.proj(x).chunk(2, dim=-1) @@ -35,14 +36,14 @@ class FeedForward(nn.Module): inner_dim = int(dim * mult) dim_out = default(dim_out, dim) project_in = nn.Sequential( - nn.Linear(dim, inner_dim), + ops.Linear(dim, inner_dim), nn.GELU() ) if not glu else GEGLU(dim, inner_dim) self.net = nn.Sequential( project_in, nn.Dropout(dropout), - nn.Linear(inner_dim, dim_out) + ops.Linear(inner_dim, dim_out) ) def forward(self, x): @@ -57,11 +58,12 @@ class GatedCrossAttentionDense(nn.Module): query_dim=query_dim, context_dim=context_dim, heads=n_heads, - dim_head=d_head) + dim_head=d_head, + operations=ops) self.ff = FeedForward(query_dim, glu=True) - self.norm1 = nn.LayerNorm(query_dim) - self.norm2 = nn.LayerNorm(query_dim) + self.norm1 = ops.LayerNorm(query_dim) + self.norm2 = ops.LayerNorm(query_dim) self.register_parameter('alpha_attn', nn.Parameter(torch.tensor(0.))) self.register_parameter('alpha_dense', nn.Parameter(torch.tensor(0.))) @@ -87,17 +89,18 @@ class GatedSelfAttentionDense(nn.Module): # we need a linear projection since we need cat visual feature and obj # feature - self.linear = nn.Linear(context_dim, query_dim) + self.linear = ops.Linear(context_dim, query_dim) self.attn = CrossAttention( query_dim=query_dim, context_dim=query_dim, heads=n_heads, - dim_head=d_head) + dim_head=d_head, + operations=ops) self.ff = FeedForward(query_dim, glu=True) - self.norm1 = nn.LayerNorm(query_dim) - self.norm2 = nn.LayerNorm(query_dim) + self.norm1 = ops.LayerNorm(query_dim) + self.norm2 = ops.LayerNorm(query_dim) self.register_parameter('alpha_attn', nn.Parameter(torch.tensor(0.))) self.register_parameter('alpha_dense', nn.Parameter(torch.tensor(0.))) @@ -126,14 +129,14 @@ class GatedSelfAttentionDense2(nn.Module): # we need a linear projection since we need cat visual feature and obj # feature - self.linear = nn.Linear(context_dim, query_dim) + self.linear = ops.Linear(context_dim, query_dim) self.attn = CrossAttention( - query_dim=query_dim, context_dim=query_dim, dim_head=d_head) + query_dim=query_dim, context_dim=query_dim, dim_head=d_head, operations=ops) self.ff = FeedForward(query_dim, glu=True) - self.norm1 = nn.LayerNorm(query_dim) - self.norm2 = nn.LayerNorm(query_dim) + self.norm1 = ops.LayerNorm(query_dim) + self.norm2 = ops.LayerNorm(query_dim) self.register_parameter('alpha_attn', nn.Parameter(torch.tensor(0.))) self.register_parameter('alpha_dense', nn.Parameter(torch.tensor(0.))) @@ -201,11 +204,11 @@ class PositionNet(nn.Module): self.position_dim = fourier_freqs * 2 * 4 # 2 is sin&cos, 4 is xyxy self.linears = nn.Sequential( - nn.Linear(self.in_dim + self.position_dim, 512), + ops.Linear(self.in_dim + self.position_dim, 512), nn.SiLU(), - nn.Linear(512, 512), + ops.Linear(512, 512), nn.SiLU(), - nn.Linear(512, out_dim), + ops.Linear(512, out_dim), ) self.null_positive_feature = torch.nn.Parameter( @@ -215,16 +218,15 @@ class PositionNet(nn.Module): def forward(self, boxes, masks, positive_embeddings): B, N, _ = boxes.shape - dtype = self.linears[0].weight.dtype - masks = masks.unsqueeze(-1).to(dtype) - positive_embeddings = positive_embeddings.to(dtype) + masks = masks.unsqueeze(-1) + positive_embeddings = positive_embeddings # embedding position (it may includes padding as placeholder) - xyxy_embedding = self.fourier_embedder(boxes.to(dtype)) # B*N*4 --> B*N*C + xyxy_embedding = self.fourier_embedder(boxes) # B*N*4 --> B*N*C # learnable null embedding - positive_null = self.null_positive_feature.view(1, 1, -1) - xyxy_null = self.null_position_feature.view(1, 1, -1) + positive_null = self.null_positive_feature.to(device=boxes.device, dtype=boxes.dtype).view(1, 1, -1) + xyxy_null = self.null_position_feature.to(device=boxes.device, dtype=boxes.dtype).view(1, 1, -1) # replace padding with learnable null embedding positive_embeddings = positive_embeddings * \ @@ -251,7 +253,7 @@ class Gligen(nn.Module): def func(x, extra_options): key = extra_options["transformer_index"] module = self.module_list[key] - return module(x, objs) + return module(x, objs.to(device=x.device, dtype=x.dtype)) return func def set_position(self, latent_image_shape, position_params, device): diff --git a/comfy/latent_formats.py b/comfy/latent_formats.py index 2252a075..03fd59e3 100644 --- a/comfy/latent_formats.py +++ b/comfy/latent_formats.py @@ -37,3 +37,41 @@ class SDXL(LatentFormat): class SD_X4(LatentFormat): def __init__(self): self.scale_factor = 0.08333 + self.latent_rgb_factors = [ + [-0.2340, -0.3863, -0.3257], + [ 0.0994, 0.0885, -0.0908], + [-0.2833, -0.2349, -0.3741], + [ 0.2523, -0.0055, -0.1651] + ] + +class SC_Prior(LatentFormat): + def __init__(self): + self.scale_factor = 1.0 + self.latent_rgb_factors = [ + [-0.0326, -0.0204, -0.0127], + [-0.1592, -0.0427, 0.0216], + [ 0.0873, 0.0638, -0.0020], + [-0.0602, 0.0442, 0.1304], + [ 0.0800, -0.0313, -0.1796], + [-0.0810, -0.0638, -0.1581], + [ 0.1791, 0.1180, 0.0967], + [ 0.0740, 0.1416, 0.0432], + [-0.1745, -0.1888, -0.1373], + [ 0.2412, 0.1577, 0.0928], + [ 0.1908, 0.0998, 0.0682], + [ 0.0209, 0.0365, -0.0092], + [ 0.0448, -0.0650, -0.1728], + [-0.1658, -0.1045, -0.1308], + [ 0.0542, 0.1545, 0.1325], + [-0.0352, -0.1672, -0.2541] + ] + +class SC_B(LatentFormat): + def __init__(self): + self.scale_factor = 1.0 + self.latent_rgb_factors = [ + [ 0.1121, 0.2006, 0.1023], + [-0.2093, -0.0222, -0.0195], + [-0.3087, -0.1535, 0.0366], + [ 0.0290, -0.1574, -0.4078] + ] diff --git a/comfy/ldm/cascade/common.py b/comfy/ldm/cascade/common.py new file mode 100644 index 00000000..124902c0 --- /dev/null +++ b/comfy/ldm/cascade/common.py @@ -0,0 +1,161 @@ +""" + This file is part of ComfyUI. + Copyright (C) 2024 Stability AI + + This program is free software: you can redistribute it and/or modify + it under the terms of the GNU General Public License as published by + the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + This program is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU General Public License for more details. + + You should have received a copy of the GNU General Public License + along with this program. If not, see . +""" + +import torch +import torch.nn as nn +from comfy.ldm.modules.attention import optimized_attention + +class Linear(torch.nn.Linear): + def reset_parameters(self): + return None + +class Conv2d(torch.nn.Conv2d): + def reset_parameters(self): + return None + +class OptimizedAttention(nn.Module): + def __init__(self, c, nhead, dropout=0.0, dtype=None, device=None, operations=None): + super().__init__() + self.heads = nhead + + self.to_q = operations.Linear(c, c, bias=True, dtype=dtype, device=device) + self.to_k = operations.Linear(c, c, bias=True, dtype=dtype, device=device) + self.to_v = operations.Linear(c, c, bias=True, dtype=dtype, device=device) + + self.out_proj = operations.Linear(c, c, bias=True, dtype=dtype, device=device) + + def forward(self, q, k, v): + q = self.to_q(q) + k = self.to_k(k) + v = self.to_v(v) + + out = optimized_attention(q, k, v, self.heads) + + return self.out_proj(out) + +class Attention2D(nn.Module): + def __init__(self, c, nhead, dropout=0.0, dtype=None, device=None, operations=None): + super().__init__() + self.attn = OptimizedAttention(c, nhead, dtype=dtype, device=device, operations=operations) + # self.attn = nn.MultiheadAttention(c, nhead, dropout=dropout, bias=True, batch_first=True, dtype=dtype, device=device) + + def forward(self, x, kv, self_attn=False): + orig_shape = x.shape + x = x.view(x.size(0), x.size(1), -1).permute(0, 2, 1) # Bx4xHxW -> Bx(HxW)x4 + if self_attn: + kv = torch.cat([x, kv], dim=1) + # x = self.attn(x, kv, kv, need_weights=False)[0] + x = self.attn(x, kv, kv) + x = x.permute(0, 2, 1).view(*orig_shape) + return x + + +def LayerNorm2d_op(operations): + class LayerNorm2d(operations.LayerNorm): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def forward(self, x): + return super().forward(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) + return LayerNorm2d + +class GlobalResponseNorm(nn.Module): + "from https://github.com/facebookresearch/ConvNeXt-V2/blob/3608f67cc1dae164790c5d0aead7bf2d73d9719b/models/utils.py#L105" + def __init__(self, dim, dtype=None, device=None): + super().__init__() + self.gamma = nn.Parameter(torch.zeros(1, 1, 1, dim, dtype=dtype, device=device)) + self.beta = nn.Parameter(torch.zeros(1, 1, 1, dim, dtype=dtype, device=device)) + + def forward(self, x): + Gx = torch.norm(x, p=2, dim=(1, 2), keepdim=True) + Nx = Gx / (Gx.mean(dim=-1, keepdim=True) + 1e-6) + return self.gamma.to(device=x.device, dtype=x.dtype) * (x * Nx) + self.beta.to(device=x.device, dtype=x.dtype) + x + + +class ResBlock(nn.Module): + def __init__(self, c, c_skip=0, kernel_size=3, dropout=0.0, dtype=None, device=None, operations=None): # , num_heads=4, expansion=2): + super().__init__() + self.depthwise = operations.Conv2d(c, c, kernel_size=kernel_size, padding=kernel_size // 2, groups=c, dtype=dtype, device=device) + # self.depthwise = SAMBlock(c, num_heads, expansion) + self.norm = LayerNorm2d_op(operations)(c, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device) + self.channelwise = nn.Sequential( + operations.Linear(c + c_skip, c * 4, dtype=dtype, device=device), + nn.GELU(), + GlobalResponseNorm(c * 4, dtype=dtype, device=device), + nn.Dropout(dropout), + operations.Linear(c * 4, c, dtype=dtype, device=device) + ) + + def forward(self, x, x_skip=None): + x_res = x + x = self.norm(self.depthwise(x)) + if x_skip is not None: + x = torch.cat([x, x_skip], dim=1) + x = self.channelwise(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) + return x + x_res + + +class AttnBlock(nn.Module): + def __init__(self, c, c_cond, nhead, self_attn=True, dropout=0.0, dtype=None, device=None, operations=None): + super().__init__() + self.self_attn = self_attn + self.norm = LayerNorm2d_op(operations)(c, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device) + self.attention = Attention2D(c, nhead, dropout, dtype=dtype, device=device, operations=operations) + self.kv_mapper = nn.Sequential( + nn.SiLU(), + operations.Linear(c_cond, c, dtype=dtype, device=device) + ) + + def forward(self, x, kv): + kv = self.kv_mapper(kv) + x = x + self.attention(self.norm(x), kv, self_attn=self.self_attn) + return x + + +class FeedForwardBlock(nn.Module): + def __init__(self, c, dropout=0.0, dtype=None, device=None, operations=None): + super().__init__() + self.norm = LayerNorm2d_op(operations)(c, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device) + self.channelwise = nn.Sequential( + operations.Linear(c, c * 4, dtype=dtype, device=device), + nn.GELU(), + GlobalResponseNorm(c * 4, dtype=dtype, device=device), + nn.Dropout(dropout), + operations.Linear(c * 4, c, dtype=dtype, device=device) + ) + + def forward(self, x): + x = x + self.channelwise(self.norm(x).permute(0, 2, 3, 1)).permute(0, 3, 1, 2) + return x + + +class TimestepBlock(nn.Module): + def __init__(self, c, c_timestep, conds=['sca'], dtype=None, device=None, operations=None): + super().__init__() + self.mapper = operations.Linear(c_timestep, c * 2, dtype=dtype, device=device) + self.conds = conds + for cname in conds: + setattr(self, f"mapper_{cname}", operations.Linear(c_timestep, c * 2, dtype=dtype, device=device)) + + def forward(self, x, t): + t = t.chunk(len(self.conds) + 1, dim=1) + a, b = self.mapper(t[0])[:, :, None, None].chunk(2, dim=1) + for i, c in enumerate(self.conds): + ac, bc = getattr(self, f"mapper_{c}")(t[i + 1])[:, :, None, None].chunk(2, dim=1) + a, b = a + ac, b + bc + return x * (1 + a) + b diff --git a/comfy/ldm/cascade/stage_a.py b/comfy/ldm/cascade/stage_a.py new file mode 100644 index 00000000..260ccfc0 --- /dev/null +++ b/comfy/ldm/cascade/stage_a.py @@ -0,0 +1,258 @@ +""" + This file is part of ComfyUI. + Copyright (C) 2024 Stability AI + + This program is free software: you can redistribute it and/or modify + it under the terms of the GNU General Public License as published by + the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + This program is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU General Public License for more details. + + You should have received a copy of the GNU General Public License + along with this program. If not, see . +""" + +import torch +from torch import nn +from torch.autograd import Function + +class vector_quantize(Function): + @staticmethod + def forward(ctx, x, codebook): + with torch.no_grad(): + codebook_sqr = torch.sum(codebook ** 2, dim=1) + x_sqr = torch.sum(x ** 2, dim=1, keepdim=True) + + dist = torch.addmm(codebook_sqr + x_sqr, x, codebook.t(), alpha=-2.0, beta=1.0) + _, indices = dist.min(dim=1) + + ctx.save_for_backward(indices, codebook) + ctx.mark_non_differentiable(indices) + + nn = torch.index_select(codebook, 0, indices) + return nn, indices + + @staticmethod + def backward(ctx, grad_output, grad_indices): + grad_inputs, grad_codebook = None, None + + if ctx.needs_input_grad[0]: + grad_inputs = grad_output.clone() + if ctx.needs_input_grad[1]: + # Gradient wrt. the codebook + indices, codebook = ctx.saved_tensors + + grad_codebook = torch.zeros_like(codebook) + grad_codebook.index_add_(0, indices, grad_output) + + return (grad_inputs, grad_codebook) + + +class VectorQuantize(nn.Module): + def __init__(self, embedding_size, k, ema_decay=0.99, ema_loss=False): + """ + Takes an input of variable size (as long as the last dimension matches the embedding size). + Returns one tensor containing the nearest neigbour embeddings to each of the inputs, + with the same size as the input, vq and commitment components for the loss as a touple + in the second output and the indices of the quantized vectors in the third: + quantized, (vq_loss, commit_loss), indices + """ + super(VectorQuantize, self).__init__() + + self.codebook = nn.Embedding(k, embedding_size) + self.codebook.weight.data.uniform_(-1./k, 1./k) + self.vq = vector_quantize.apply + + self.ema_decay = ema_decay + self.ema_loss = ema_loss + if ema_loss: + self.register_buffer('ema_element_count', torch.ones(k)) + self.register_buffer('ema_weight_sum', torch.zeros_like(self.codebook.weight)) + + def _laplace_smoothing(self, x, epsilon): + n = torch.sum(x) + return ((x + epsilon) / (n + x.size(0) * epsilon) * n) + + def _updateEMA(self, z_e_x, indices): + mask = nn.functional.one_hot(indices, self.ema_element_count.size(0)).float() + elem_count = mask.sum(dim=0) + weight_sum = torch.mm(mask.t(), z_e_x) + + self.ema_element_count = (self.ema_decay * self.ema_element_count) + ((1-self.ema_decay) * elem_count) + self.ema_element_count = self._laplace_smoothing(self.ema_element_count, 1e-5) + self.ema_weight_sum = (self.ema_decay * self.ema_weight_sum) + ((1-self.ema_decay) * weight_sum) + + self.codebook.weight.data = self.ema_weight_sum / self.ema_element_count.unsqueeze(-1) + + def idx2vq(self, idx, dim=-1): + q_idx = self.codebook(idx) + if dim != -1: + q_idx = q_idx.movedim(-1, dim) + return q_idx + + def forward(self, x, get_losses=True, dim=-1): + if dim != -1: + x = x.movedim(dim, -1) + z_e_x = x.contiguous().view(-1, x.size(-1)) if len(x.shape) > 2 else x + z_q_x, indices = self.vq(z_e_x, self.codebook.weight.detach()) + vq_loss, commit_loss = None, None + if self.ema_loss and self.training: + self._updateEMA(z_e_x.detach(), indices.detach()) + # pick the graded embeddings after updating the codebook in order to have a more accurate commitment loss + z_q_x_grd = torch.index_select(self.codebook.weight, dim=0, index=indices) + if get_losses: + vq_loss = (z_q_x_grd - z_e_x.detach()).pow(2).mean() + commit_loss = (z_e_x - z_q_x_grd.detach()).pow(2).mean() + + z_q_x = z_q_x.view(x.shape) + if dim != -1: + z_q_x = z_q_x.movedim(-1, dim) + return z_q_x, (vq_loss, commit_loss), indices.view(x.shape[:-1]) + + +class ResBlock(nn.Module): + def __init__(self, c, c_hidden): + super().__init__() + # depthwise/attention + self.norm1 = nn.LayerNorm(c, elementwise_affine=False, eps=1e-6) + self.depthwise = nn.Sequential( + nn.ReplicationPad2d(1), + nn.Conv2d(c, c, kernel_size=3, groups=c) + ) + + # channelwise + self.norm2 = nn.LayerNorm(c, elementwise_affine=False, eps=1e-6) + self.channelwise = nn.Sequential( + nn.Linear(c, c_hidden), + nn.GELU(), + nn.Linear(c_hidden, c), + ) + + self.gammas = nn.Parameter(torch.zeros(6), requires_grad=True) + + # Init weights + def _basic_init(module): + if isinstance(module, nn.Linear) or isinstance(module, nn.Conv2d): + torch.nn.init.xavier_uniform_(module.weight) + if module.bias is not None: + nn.init.constant_(module.bias, 0) + + self.apply(_basic_init) + + def _norm(self, x, norm): + return norm(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) + + def forward(self, x): + mods = self.gammas + + x_temp = self._norm(x, self.norm1) * (1 + mods[0]) + mods[1] + try: + x = x + self.depthwise(x_temp) * mods[2] + except: #operation not implemented for bf16 + x_temp = self.depthwise[0](x_temp.float()).to(x.dtype) + x = x + self.depthwise[1](x_temp) * mods[2] + + x_temp = self._norm(x, self.norm2) * (1 + mods[3]) + mods[4] + x = x + self.channelwise(x_temp.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) * mods[5] + + return x + + +class StageA(nn.Module): + def __init__(self, levels=2, bottleneck_blocks=12, c_hidden=384, c_latent=4, codebook_size=8192, + scale_factor=0.43): # 0.3764 + super().__init__() + self.c_latent = c_latent + self.scale_factor = scale_factor + c_levels = [c_hidden // (2 ** i) for i in reversed(range(levels))] + + # Encoder blocks + self.in_block = nn.Sequential( + nn.PixelUnshuffle(2), + nn.Conv2d(3 * 4, c_levels[0], kernel_size=1) + ) + down_blocks = [] + for i in range(levels): + if i > 0: + down_blocks.append(nn.Conv2d(c_levels[i - 1], c_levels[i], kernel_size=4, stride=2, padding=1)) + block = ResBlock(c_levels[i], c_levels[i] * 4) + down_blocks.append(block) + down_blocks.append(nn.Sequential( + nn.Conv2d(c_levels[-1], c_latent, kernel_size=1, bias=False), + nn.BatchNorm2d(c_latent), # then normalize them to have mean 0 and std 1 + )) + self.down_blocks = nn.Sequential(*down_blocks) + self.down_blocks[0] + + self.codebook_size = codebook_size + self.vquantizer = VectorQuantize(c_latent, k=codebook_size) + + # Decoder blocks + up_blocks = [nn.Sequential( + nn.Conv2d(c_latent, c_levels[-1], kernel_size=1) + )] + for i in range(levels): + for j in range(bottleneck_blocks if i == 0 else 1): + block = ResBlock(c_levels[levels - 1 - i], c_levels[levels - 1 - i] * 4) + up_blocks.append(block) + if i < levels - 1: + up_blocks.append( + nn.ConvTranspose2d(c_levels[levels - 1 - i], c_levels[levels - 2 - i], kernel_size=4, stride=2, + padding=1)) + self.up_blocks = nn.Sequential(*up_blocks) + self.out_block = nn.Sequential( + nn.Conv2d(c_levels[0], 3 * 4, kernel_size=1), + nn.PixelShuffle(2), + ) + + def encode(self, x, quantize=False): + x = self.in_block(x) + x = self.down_blocks(x) + if quantize: + qe, (vq_loss, commit_loss), indices = self.vquantizer.forward(x, dim=1) + return qe / self.scale_factor, x / self.scale_factor, indices, vq_loss + commit_loss * 0.25 + else: + return x / self.scale_factor + + def decode(self, x): + x = x * self.scale_factor + x = self.up_blocks(x) + x = self.out_block(x) + return x + + def forward(self, x, quantize=False): + qe, x, _, vq_loss = self.encode(x, quantize) + x = self.decode(qe) + return x, vq_loss + + +class Discriminator(nn.Module): + def __init__(self, c_in=3, c_cond=0, c_hidden=512, depth=6): + super().__init__() + d = max(depth - 3, 3) + layers = [ + nn.utils.spectral_norm(nn.Conv2d(c_in, c_hidden // (2 ** d), kernel_size=3, stride=2, padding=1)), + nn.LeakyReLU(0.2), + ] + for i in range(depth - 1): + c_in = c_hidden // (2 ** max((d - i), 0)) + c_out = c_hidden // (2 ** max((d - 1 - i), 0)) + layers.append(nn.utils.spectral_norm(nn.Conv2d(c_in, c_out, kernel_size=3, stride=2, padding=1))) + layers.append(nn.InstanceNorm2d(c_out)) + layers.append(nn.LeakyReLU(0.2)) + self.encoder = nn.Sequential(*layers) + self.shuffle = nn.Conv2d((c_hidden + c_cond) if c_cond > 0 else c_hidden, 1, kernel_size=1) + self.logits = nn.Sigmoid() + + def forward(self, x, cond=None): + x = self.encoder(x) + if cond is not None: + cond = cond.view(cond.size(0), cond.size(1), 1, 1, ).expand(-1, -1, x.size(-2), x.size(-1)) + x = torch.cat([x, cond], dim=1) + x = self.shuffle(x) + x = self.logits(x) + return x diff --git a/comfy/ldm/cascade/stage_b.py b/comfy/ldm/cascade/stage_b.py new file mode 100644 index 00000000..6d2c2223 --- /dev/null +++ b/comfy/ldm/cascade/stage_b.py @@ -0,0 +1,257 @@ +""" + This file is part of ComfyUI. + Copyright (C) 2024 Stability AI + + This program is free software: you can redistribute it and/or modify + it under the terms of the GNU General Public License as published by + the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + This program is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU General Public License for more details. + + You should have received a copy of the GNU General Public License + along with this program. If not, see . +""" + +import math +import numpy as np +import torch +from torch import nn +from .common import AttnBlock, LayerNorm2d_op, ResBlock, FeedForwardBlock, TimestepBlock + +class StageB(nn.Module): + def __init__(self, c_in=4, c_out=4, c_r=64, patch_size=2, c_cond=1280, c_hidden=[320, 640, 1280, 1280], + nhead=[-1, -1, 20, 20], blocks=[[2, 6, 28, 6], [6, 28, 6, 2]], + block_repeat=[[1, 1, 1, 1], [3, 3, 2, 2]], level_config=['CT', 'CT', 'CTA', 'CTA'], c_clip=1280, + c_clip_seq=4, c_effnet=16, c_pixels=3, kernel_size=3, dropout=[0, 0, 0.0, 0.0], self_attn=True, + t_conds=['sca'], stable_cascade_stage=None, dtype=None, device=None, operations=None): + super().__init__() + self.dtype = dtype + self.c_r = c_r + self.t_conds = t_conds + self.c_clip_seq = c_clip_seq + if not isinstance(dropout, list): + dropout = [dropout] * len(c_hidden) + if not isinstance(self_attn, list): + self_attn = [self_attn] * len(c_hidden) + + # CONDITIONING + self.effnet_mapper = nn.Sequential( + operations.Conv2d(c_effnet, c_hidden[0] * 4, kernel_size=1, dtype=dtype, device=device), + nn.GELU(), + operations.Conv2d(c_hidden[0] * 4, c_hidden[0], kernel_size=1, dtype=dtype, device=device), + LayerNorm2d_op(operations)(c_hidden[0], elementwise_affine=False, eps=1e-6, dtype=dtype, device=device) + ) + self.pixels_mapper = nn.Sequential( + operations.Conv2d(c_pixels, c_hidden[0] * 4, kernel_size=1, dtype=dtype, device=device), + nn.GELU(), + operations.Conv2d(c_hidden[0] * 4, c_hidden[0], kernel_size=1, dtype=dtype, device=device), + LayerNorm2d_op(operations)(c_hidden[0], elementwise_affine=False, eps=1e-6, dtype=dtype, device=device) + ) + self.clip_mapper = operations.Linear(c_clip, c_cond * c_clip_seq, dtype=dtype, device=device) + self.clip_norm = operations.LayerNorm(c_cond, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device) + + self.embedding = nn.Sequential( + nn.PixelUnshuffle(patch_size), + operations.Conv2d(c_in * (patch_size ** 2), c_hidden[0], kernel_size=1, dtype=dtype, device=device), + LayerNorm2d_op(operations)(c_hidden[0], elementwise_affine=False, eps=1e-6, dtype=dtype, device=device) + ) + + def get_block(block_type, c_hidden, nhead, c_skip=0, dropout=0, self_attn=True): + if block_type == 'C': + return ResBlock(c_hidden, c_skip, kernel_size=kernel_size, dropout=dropout, dtype=dtype, device=device, operations=operations) + elif block_type == 'A': + return AttnBlock(c_hidden, c_cond, nhead, self_attn=self_attn, dropout=dropout, dtype=dtype, device=device, operations=operations) + elif block_type == 'F': + return FeedForwardBlock(c_hidden, dropout=dropout, dtype=dtype, device=device, operations=operations) + elif block_type == 'T': + return TimestepBlock(c_hidden, c_r, conds=t_conds, dtype=dtype, device=device, operations=operations) + else: + raise Exception(f'Block type {block_type} not supported') + + # BLOCKS + # -- down blocks + self.down_blocks = nn.ModuleList() + self.down_downscalers = nn.ModuleList() + self.down_repeat_mappers = nn.ModuleList() + for i in range(len(c_hidden)): + if i > 0: + self.down_downscalers.append(nn.Sequential( + LayerNorm2d_op(operations)(c_hidden[i - 1], elementwise_affine=False, eps=1e-6, dtype=dtype, device=device), + operations.Conv2d(c_hidden[i - 1], c_hidden[i], kernel_size=2, stride=2, dtype=dtype, device=device), + )) + else: + self.down_downscalers.append(nn.Identity()) + down_block = nn.ModuleList() + for _ in range(blocks[0][i]): + for block_type in level_config[i]: + block = get_block(block_type, c_hidden[i], nhead[i], dropout=dropout[i], self_attn=self_attn[i]) + down_block.append(block) + self.down_blocks.append(down_block) + if block_repeat is not None: + block_repeat_mappers = nn.ModuleList() + for _ in range(block_repeat[0][i] - 1): + block_repeat_mappers.append(operations.Conv2d(c_hidden[i], c_hidden[i], kernel_size=1, dtype=dtype, device=device)) + self.down_repeat_mappers.append(block_repeat_mappers) + + # -- up blocks + self.up_blocks = nn.ModuleList() + self.up_upscalers = nn.ModuleList() + self.up_repeat_mappers = nn.ModuleList() + for i in reversed(range(len(c_hidden))): + if i > 0: + self.up_upscalers.append(nn.Sequential( + LayerNorm2d_op(operations)(c_hidden[i], elementwise_affine=False, eps=1e-6, dtype=dtype, device=device), + operations.ConvTranspose2d(c_hidden[i], c_hidden[i - 1], kernel_size=2, stride=2, dtype=dtype, device=device), + )) + else: + self.up_upscalers.append(nn.Identity()) + up_block = nn.ModuleList() + for j in range(blocks[1][::-1][i]): + for k, block_type in enumerate(level_config[i]): + c_skip = c_hidden[i] if i < len(c_hidden) - 1 and j == k == 0 else 0 + block = get_block(block_type, c_hidden[i], nhead[i], c_skip=c_skip, dropout=dropout[i], + self_attn=self_attn[i]) + up_block.append(block) + self.up_blocks.append(up_block) + if block_repeat is not None: + block_repeat_mappers = nn.ModuleList() + for _ in range(block_repeat[1][::-1][i] - 1): + block_repeat_mappers.append(operations.Conv2d(c_hidden[i], c_hidden[i], kernel_size=1, dtype=dtype, device=device)) + self.up_repeat_mappers.append(block_repeat_mappers) + + # OUTPUT + self.clf = nn.Sequential( + LayerNorm2d_op(operations)(c_hidden[0], elementwise_affine=False, eps=1e-6, dtype=dtype, device=device), + operations.Conv2d(c_hidden[0], c_out * (patch_size ** 2), kernel_size=1, dtype=dtype, device=device), + nn.PixelShuffle(patch_size), + ) + + # --- WEIGHT INIT --- + # self.apply(self._init_weights) # General init + # nn.init.normal_(self.clip_mapper.weight, std=0.02) # conditionings + # nn.init.normal_(self.effnet_mapper[0].weight, std=0.02) # conditionings + # nn.init.normal_(self.effnet_mapper[2].weight, std=0.02) # conditionings + # nn.init.normal_(self.pixels_mapper[0].weight, std=0.02) # conditionings + # nn.init.normal_(self.pixels_mapper[2].weight, std=0.02) # conditionings + # torch.nn.init.xavier_uniform_(self.embedding[1].weight, 0.02) # inputs + # nn.init.constant_(self.clf[1].weight, 0) # outputs + # + # # blocks + # for level_block in self.down_blocks + self.up_blocks: + # for block in level_block: + # if isinstance(block, ResBlock) or isinstance(block, FeedForwardBlock): + # block.channelwise[-1].weight.data *= np.sqrt(1 / sum(blocks[0])) + # elif isinstance(block, TimestepBlock): + # for layer in block.modules(): + # if isinstance(layer, nn.Linear): + # nn.init.constant_(layer.weight, 0) + # + # def _init_weights(self, m): + # if isinstance(m, (nn.Conv2d, nn.Linear)): + # torch.nn.init.xavier_uniform_(m.weight) + # if m.bias is not None: + # nn.init.constant_(m.bias, 0) + + def gen_r_embedding(self, r, max_positions=10000): + r = r * max_positions + half_dim = self.c_r // 2 + emb = math.log(max_positions) / (half_dim - 1) + emb = torch.arange(half_dim, device=r.device).float().mul(-emb).exp() + emb = r[:, None] * emb[None, :] + emb = torch.cat([emb.sin(), emb.cos()], dim=1) + if self.c_r % 2 == 1: # zero pad + emb = nn.functional.pad(emb, (0, 1), mode='constant') + return emb + + def gen_c_embeddings(self, clip): + if len(clip.shape) == 2: + clip = clip.unsqueeze(1) + clip = self.clip_mapper(clip).view(clip.size(0), clip.size(1) * self.c_clip_seq, -1) + clip = self.clip_norm(clip) + return clip + + def _down_encode(self, x, r_embed, clip): + level_outputs = [] + block_group = zip(self.down_blocks, self.down_downscalers, self.down_repeat_mappers) + for down_block, downscaler, repmap in block_group: + x = downscaler(x) + for i in range(len(repmap) + 1): + for block in down_block: + if isinstance(block, ResBlock) or ( + hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module, + ResBlock)): + x = block(x) + elif isinstance(block, AttnBlock) or ( + hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module, + AttnBlock)): + x = block(x, clip) + elif isinstance(block, TimestepBlock) or ( + hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module, + TimestepBlock)): + x = block(x, r_embed) + else: + x = block(x) + if i < len(repmap): + x = repmap[i](x) + level_outputs.insert(0, x) + return level_outputs + + def _up_decode(self, level_outputs, r_embed, clip): + x = level_outputs[0] + block_group = zip(self.up_blocks, self.up_upscalers, self.up_repeat_mappers) + for i, (up_block, upscaler, repmap) in enumerate(block_group): + for j in range(len(repmap) + 1): + for k, block in enumerate(up_block): + if isinstance(block, ResBlock) or ( + hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module, + ResBlock)): + skip = level_outputs[i] if k == 0 and i > 0 else None + if skip is not None and (x.size(-1) != skip.size(-1) or x.size(-2) != skip.size(-2)): + x = torch.nn.functional.interpolate(x, skip.shape[-2:], mode='bilinear', + align_corners=True) + x = block(x, skip) + elif isinstance(block, AttnBlock) or ( + hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module, + AttnBlock)): + x = block(x, clip) + elif isinstance(block, TimestepBlock) or ( + hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module, + TimestepBlock)): + x = block(x, r_embed) + else: + x = block(x) + if j < len(repmap): + x = repmap[j](x) + x = upscaler(x) + return x + + def forward(self, x, r, effnet, clip, pixels=None, **kwargs): + if pixels is None: + pixels = x.new_zeros(x.size(0), 3, 8, 8) + + # Process the conditioning embeddings + r_embed = self.gen_r_embedding(r).to(dtype=x.dtype) + for c in self.t_conds: + t_cond = kwargs.get(c, torch.zeros_like(r)) + r_embed = torch.cat([r_embed, self.gen_r_embedding(t_cond).to(dtype=x.dtype)], dim=1) + clip = self.gen_c_embeddings(clip) + + # Model Blocks + x = self.embedding(x) + x = x + self.effnet_mapper( + nn.functional.interpolate(effnet, size=x.shape[-2:], mode='bilinear', align_corners=True)) + x = x + nn.functional.interpolate(self.pixels_mapper(pixels), size=x.shape[-2:], mode='bilinear', + align_corners=True) + level_outputs = self._down_encode(x, r_embed, clip) + x = self._up_decode(level_outputs, r_embed, clip) + return self.clf(x) + + def update_weights_ema(self, src_model, beta=0.999): + for self_params, src_params in zip(self.parameters(), src_model.parameters()): + self_params.data = self_params.data * beta + src_params.data.clone().to(self_params.device) * (1 - beta) + for self_buffers, src_buffers in zip(self.buffers(), src_model.buffers()): + self_buffers.data = self_buffers.data * beta + src_buffers.data.clone().to(self_buffers.device) * (1 - beta) diff --git a/comfy/ldm/cascade/stage_c.py b/comfy/ldm/cascade/stage_c.py new file mode 100644 index 00000000..08e33ade --- /dev/null +++ b/comfy/ldm/cascade/stage_c.py @@ -0,0 +1,271 @@ +""" + This file is part of ComfyUI. + Copyright (C) 2024 Stability AI + + This program is free software: you can redistribute it and/or modify + it under the terms of the GNU General Public License as published by + the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + This program is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU General Public License for more details. + + You should have received a copy of the GNU General Public License + along with this program. If not, see . +""" + +import torch +from torch import nn +import numpy as np +import math +from .common import AttnBlock, LayerNorm2d_op, ResBlock, FeedForwardBlock, TimestepBlock +# from .controlnet import ControlNetDeliverer + +class UpDownBlock2d(nn.Module): + def __init__(self, c_in, c_out, mode, enabled=True, dtype=None, device=None, operations=None): + super().__init__() + assert mode in ['up', 'down'] + interpolation = nn.Upsample(scale_factor=2 if mode == 'up' else 0.5, mode='bilinear', + align_corners=True) if enabled else nn.Identity() + mapping = operations.Conv2d(c_in, c_out, kernel_size=1, dtype=dtype, device=device) + self.blocks = nn.ModuleList([interpolation, mapping] if mode == 'up' else [mapping, interpolation]) + + def forward(self, x): + for block in self.blocks: + x = block(x) + return x + + +class StageC(nn.Module): + def __init__(self, c_in=16, c_out=16, c_r=64, patch_size=1, c_cond=2048, c_hidden=[2048, 2048], nhead=[32, 32], + blocks=[[8, 24], [24, 8]], block_repeat=[[1, 1], [1, 1]], level_config=['CTA', 'CTA'], + c_clip_text=1280, c_clip_text_pooled=1280, c_clip_img=768, c_clip_seq=4, kernel_size=3, + dropout=[0.0, 0.0], self_attn=True, t_conds=['sca', 'crp'], switch_level=[False], stable_cascade_stage=None, + dtype=None, device=None, operations=None): + super().__init__() + self.dtype = dtype + self.c_r = c_r + self.t_conds = t_conds + self.c_clip_seq = c_clip_seq + if not isinstance(dropout, list): + dropout = [dropout] * len(c_hidden) + if not isinstance(self_attn, list): + self_attn = [self_attn] * len(c_hidden) + + # CONDITIONING + self.clip_txt_mapper = operations.Linear(c_clip_text, c_cond, dtype=dtype, device=device) + self.clip_txt_pooled_mapper = operations.Linear(c_clip_text_pooled, c_cond * c_clip_seq, dtype=dtype, device=device) + self.clip_img_mapper = operations.Linear(c_clip_img, c_cond * c_clip_seq, dtype=dtype, device=device) + self.clip_norm = operations.LayerNorm(c_cond, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device) + + self.embedding = nn.Sequential( + nn.PixelUnshuffle(patch_size), + operations.Conv2d(c_in * (patch_size ** 2), c_hidden[0], kernel_size=1, dtype=dtype, device=device), + LayerNorm2d_op(operations)(c_hidden[0], elementwise_affine=False, eps=1e-6) + ) + + def get_block(block_type, c_hidden, nhead, c_skip=0, dropout=0, self_attn=True): + if block_type == 'C': + return ResBlock(c_hidden, c_skip, kernel_size=kernel_size, dropout=dropout, dtype=dtype, device=device, operations=operations) + elif block_type == 'A': + return AttnBlock(c_hidden, c_cond, nhead, self_attn=self_attn, dropout=dropout, dtype=dtype, device=device, operations=operations) + elif block_type == 'F': + return FeedForwardBlock(c_hidden, dropout=dropout, dtype=dtype, device=device, operations=operations) + elif block_type == 'T': + return TimestepBlock(c_hidden, c_r, conds=t_conds, dtype=dtype, device=device, operations=operations) + else: + raise Exception(f'Block type {block_type} not supported') + + # BLOCKS + # -- down blocks + self.down_blocks = nn.ModuleList() + self.down_downscalers = nn.ModuleList() + self.down_repeat_mappers = nn.ModuleList() + for i in range(len(c_hidden)): + if i > 0: + self.down_downscalers.append(nn.Sequential( + LayerNorm2d_op(operations)(c_hidden[i - 1], elementwise_affine=False, eps=1e-6), + UpDownBlock2d(c_hidden[i - 1], c_hidden[i], mode='down', enabled=switch_level[i - 1], dtype=dtype, device=device, operations=operations) + )) + else: + self.down_downscalers.append(nn.Identity()) + down_block = nn.ModuleList() + for _ in range(blocks[0][i]): + for block_type in level_config[i]: + block = get_block(block_type, c_hidden[i], nhead[i], dropout=dropout[i], self_attn=self_attn[i]) + down_block.append(block) + self.down_blocks.append(down_block) + if block_repeat is not None: + block_repeat_mappers = nn.ModuleList() + for _ in range(block_repeat[0][i] - 1): + block_repeat_mappers.append(operations.Conv2d(c_hidden[i], c_hidden[i], kernel_size=1, dtype=dtype, device=device)) + self.down_repeat_mappers.append(block_repeat_mappers) + + # -- up blocks + self.up_blocks = nn.ModuleList() + self.up_upscalers = nn.ModuleList() + self.up_repeat_mappers = nn.ModuleList() + for i in reversed(range(len(c_hidden))): + if i > 0: + self.up_upscalers.append(nn.Sequential( + LayerNorm2d_op(operations)(c_hidden[i], elementwise_affine=False, eps=1e-6), + UpDownBlock2d(c_hidden[i], c_hidden[i - 1], mode='up', enabled=switch_level[i - 1], dtype=dtype, device=device, operations=operations) + )) + else: + self.up_upscalers.append(nn.Identity()) + up_block = nn.ModuleList() + for j in range(blocks[1][::-1][i]): + for k, block_type in enumerate(level_config[i]): + c_skip = c_hidden[i] if i < len(c_hidden) - 1 and j == k == 0 else 0 + block = get_block(block_type, c_hidden[i], nhead[i], c_skip=c_skip, dropout=dropout[i], + self_attn=self_attn[i]) + up_block.append(block) + self.up_blocks.append(up_block) + if block_repeat is not None: + block_repeat_mappers = nn.ModuleList() + for _ in range(block_repeat[1][::-1][i] - 1): + block_repeat_mappers.append(operations.Conv2d(c_hidden[i], c_hidden[i], kernel_size=1, dtype=dtype, device=device)) + self.up_repeat_mappers.append(block_repeat_mappers) + + # OUTPUT + self.clf = nn.Sequential( + LayerNorm2d_op(operations)(c_hidden[0], elementwise_affine=False, eps=1e-6, dtype=dtype, device=device), + operations.Conv2d(c_hidden[0], c_out * (patch_size ** 2), kernel_size=1, dtype=dtype, device=device), + nn.PixelShuffle(patch_size), + ) + + # --- WEIGHT INIT --- + # self.apply(self._init_weights) # General init + # nn.init.normal_(self.clip_txt_mapper.weight, std=0.02) # conditionings + # nn.init.normal_(self.clip_txt_pooled_mapper.weight, std=0.02) # conditionings + # nn.init.normal_(self.clip_img_mapper.weight, std=0.02) # conditionings + # torch.nn.init.xavier_uniform_(self.embedding[1].weight, 0.02) # inputs + # nn.init.constant_(self.clf[1].weight, 0) # outputs + # + # # blocks + # for level_block in self.down_blocks + self.up_blocks: + # for block in level_block: + # if isinstance(block, ResBlock) or isinstance(block, FeedForwardBlock): + # block.channelwise[-1].weight.data *= np.sqrt(1 / sum(blocks[0])) + # elif isinstance(block, TimestepBlock): + # for layer in block.modules(): + # if isinstance(layer, nn.Linear): + # nn.init.constant_(layer.weight, 0) + # + # def _init_weights(self, m): + # if isinstance(m, (nn.Conv2d, nn.Linear)): + # torch.nn.init.xavier_uniform_(m.weight) + # if m.bias is not None: + # nn.init.constant_(m.bias, 0) + + def gen_r_embedding(self, r, max_positions=10000): + r = r * max_positions + half_dim = self.c_r // 2 + emb = math.log(max_positions) / (half_dim - 1) + emb = torch.arange(half_dim, device=r.device).float().mul(-emb).exp() + emb = r[:, None] * emb[None, :] + emb = torch.cat([emb.sin(), emb.cos()], dim=1) + if self.c_r % 2 == 1: # zero pad + emb = nn.functional.pad(emb, (0, 1), mode='constant') + return emb + + def gen_c_embeddings(self, clip_txt, clip_txt_pooled, clip_img): + clip_txt = self.clip_txt_mapper(clip_txt) + if len(clip_txt_pooled.shape) == 2: + clip_txt_pooled = clip_txt_pooled.unsqueeze(1) + if len(clip_img.shape) == 2: + clip_img = clip_img.unsqueeze(1) + clip_txt_pool = self.clip_txt_pooled_mapper(clip_txt_pooled).view(clip_txt_pooled.size(0), clip_txt_pooled.size(1) * self.c_clip_seq, -1) + clip_img = self.clip_img_mapper(clip_img).view(clip_img.size(0), clip_img.size(1) * self.c_clip_seq, -1) + clip = torch.cat([clip_txt, clip_txt_pool, clip_img], dim=1) + clip = self.clip_norm(clip) + return clip + + def _down_encode(self, x, r_embed, clip, cnet=None): + level_outputs = [] + block_group = zip(self.down_blocks, self.down_downscalers, self.down_repeat_mappers) + for down_block, downscaler, repmap in block_group: + x = downscaler(x) + for i in range(len(repmap) + 1): + for block in down_block: + if isinstance(block, ResBlock) or ( + hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module, + ResBlock)): + if cnet is not None: + next_cnet = cnet() + if next_cnet is not None: + x = x + nn.functional.interpolate(next_cnet, size=x.shape[-2:], mode='bilinear', + align_corners=True) + x = block(x) + elif isinstance(block, AttnBlock) or ( + hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module, + AttnBlock)): + x = block(x, clip) + elif isinstance(block, TimestepBlock) or ( + hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module, + TimestepBlock)): + x = block(x, r_embed) + else: + x = block(x) + if i < len(repmap): + x = repmap[i](x) + level_outputs.insert(0, x) + return level_outputs + + def _up_decode(self, level_outputs, r_embed, clip, cnet=None): + x = level_outputs[0] + block_group = zip(self.up_blocks, self.up_upscalers, self.up_repeat_mappers) + for i, (up_block, upscaler, repmap) in enumerate(block_group): + for j in range(len(repmap) + 1): + for k, block in enumerate(up_block): + if isinstance(block, ResBlock) or ( + hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module, + ResBlock)): + skip = level_outputs[i] if k == 0 and i > 0 else None + if skip is not None and (x.size(-1) != skip.size(-1) or x.size(-2) != skip.size(-2)): + x = torch.nn.functional.interpolate(x, skip.shape[-2:], mode='bilinear', + align_corners=True) + if cnet is not None: + next_cnet = cnet() + if next_cnet is not None: + x = x + nn.functional.interpolate(next_cnet, size=x.shape[-2:], mode='bilinear', + align_corners=True) + x = block(x, skip) + elif isinstance(block, AttnBlock) or ( + hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module, + AttnBlock)): + x = block(x, clip) + elif isinstance(block, TimestepBlock) or ( + hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module, + TimestepBlock)): + x = block(x, r_embed) + else: + x = block(x) + if j < len(repmap): + x = repmap[j](x) + x = upscaler(x) + return x + + def forward(self, x, r, clip_text, clip_text_pooled, clip_img, cnet=None, **kwargs): + # Process the conditioning embeddings + r_embed = self.gen_r_embedding(r).to(dtype=x.dtype) + for c in self.t_conds: + t_cond = kwargs.get(c, torch.zeros_like(r)) + r_embed = torch.cat([r_embed, self.gen_r_embedding(t_cond).to(dtype=x.dtype)], dim=1) + clip = self.gen_c_embeddings(clip_text, clip_text_pooled, clip_img) + + # Model Blocks + x = self.embedding(x) + if cnet is not None: + cnet = ControlNetDeliverer(cnet) + level_outputs = self._down_encode(x, r_embed, clip, cnet) + x = self._up_decode(level_outputs, r_embed, clip, cnet) + return self.clf(x) + + def update_weights_ema(self, src_model, beta=0.999): + for self_params, src_params in zip(self.parameters(), src_model.parameters()): + self_params.data = self_params.data * beta + src_params.data.clone().to(self_params.device) * (1 - beta) + for self_buffers, src_buffers in zip(self.buffers(), src_model.buffers()): + self_buffers.data = self_buffers.data * beta + src_buffers.data.clone().to(self_buffers.device) * (1 - beta) diff --git a/comfy/ldm/cascade/stage_c_coder.py b/comfy/ldm/cascade/stage_c_coder.py new file mode 100644 index 00000000..0cb7c49f --- /dev/null +++ b/comfy/ldm/cascade/stage_c_coder.py @@ -0,0 +1,95 @@ +""" + This file is part of ComfyUI. + Copyright (C) 2024 Stability AI + + This program is free software: you can redistribute it and/or modify + it under the terms of the GNU General Public License as published by + the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + This program is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU General Public License for more details. + + You should have received a copy of the GNU General Public License + along with this program. If not, see . +""" +import torch +import torchvision +from torch import nn + + +# EfficientNet +class EfficientNetEncoder(nn.Module): + def __init__(self, c_latent=16): + super().__init__() + self.backbone = torchvision.models.efficientnet_v2_s().features.eval() + self.mapper = nn.Sequential( + nn.Conv2d(1280, c_latent, kernel_size=1, bias=False), + nn.BatchNorm2d(c_latent, affine=False), # then normalize them to have mean 0 and std 1 + ) + self.mean = nn.Parameter(torch.tensor([0.485, 0.456, 0.406])) + self.std = nn.Parameter(torch.tensor([0.229, 0.224, 0.225])) + + def forward(self, x): + x = x * 0.5 + 0.5 + x = (x - self.mean.view([3,1,1])) / self.std.view([3,1,1]) + o = self.mapper(self.backbone(x)) + return o + + +# Fast Decoder for Stage C latents. E.g. 16 x 24 x 24 -> 3 x 192 x 192 +class Previewer(nn.Module): + def __init__(self, c_in=16, c_hidden=512, c_out=3): + super().__init__() + self.blocks = nn.Sequential( + nn.Conv2d(c_in, c_hidden, kernel_size=1), # 16 channels to 512 channels + nn.GELU(), + nn.BatchNorm2d(c_hidden), + + nn.Conv2d(c_hidden, c_hidden, kernel_size=3, padding=1), + nn.GELU(), + nn.BatchNorm2d(c_hidden), + + nn.ConvTranspose2d(c_hidden, c_hidden // 2, kernel_size=2, stride=2), # 16 -> 32 + nn.GELU(), + nn.BatchNorm2d(c_hidden // 2), + + nn.Conv2d(c_hidden // 2, c_hidden // 2, kernel_size=3, padding=1), + nn.GELU(), + nn.BatchNorm2d(c_hidden // 2), + + nn.ConvTranspose2d(c_hidden // 2, c_hidden // 4, kernel_size=2, stride=2), # 32 -> 64 + nn.GELU(), + nn.BatchNorm2d(c_hidden // 4), + + nn.Conv2d(c_hidden // 4, c_hidden // 4, kernel_size=3, padding=1), + nn.GELU(), + nn.BatchNorm2d(c_hidden // 4), + + nn.ConvTranspose2d(c_hidden // 4, c_hidden // 4, kernel_size=2, stride=2), # 64 -> 128 + nn.GELU(), + nn.BatchNorm2d(c_hidden // 4), + + nn.Conv2d(c_hidden // 4, c_hidden // 4, kernel_size=3, padding=1), + nn.GELU(), + nn.BatchNorm2d(c_hidden // 4), + + nn.Conv2d(c_hidden // 4, c_out, kernel_size=1), + ) + + def forward(self, x): + return (self.blocks(x) - 0.5) * 2.0 + +class StageC_coder(nn.Module): + def __init__(self): + super().__init__() + self.previewer = Previewer() + self.encoder = EfficientNetEncoder() + + def encode(self, x): + return self.encoder(x) + + def decode(self, x): + return self.previewer(x) diff --git a/comfy/ldm/modules/attention.py b/comfy/ldm/modules/attention.py index 9c9cb761..48399bc0 100644 --- a/comfy/ldm/modules/attention.py +++ b/comfy/ldm/modules/attention.py @@ -114,7 +114,12 @@ def attention_basic(q, k, v, heads, mask=None): mask = repeat(mask, 'b j -> (b h) () j', h=h) sim.masked_fill_(~mask, max_neg_value) else: - sim += mask + if len(mask.shape) == 2: + bs = 1 + else: + bs = mask.shape[0] + mask = mask.reshape(bs, -1, mask.shape[-2], mask.shape[-1]).expand(b, heads, -1, -1).reshape(-1, mask.shape[-2], mask.shape[-1]) + sim.add_(mask) # attention, what we cannot get enough of sim = sim.softmax(dim=-1) @@ -165,6 +170,13 @@ def attention_sub_quad(query, key, value, heads, mask=None): if query_chunk_size is None: query_chunk_size = 512 + if mask is not None: + if len(mask.shape) == 2: + bs = 1 + else: + bs = mask.shape[0] + mask = mask.reshape(bs, -1, mask.shape[-2], mask.shape[-1]).expand(b, heads, -1, -1).reshape(-1, mask.shape[-2], mask.shape[-1]) + hidden_states = efficient_dot_product_attention( query, key, @@ -223,6 +235,13 @@ def attention_split(q, k, v, heads, mask=None): raise RuntimeError(f'Not enough memory, use lower resolution (max approx. {max_res}x{max_res}). ' f'Need: {mem_required/64/gb:0.1f}GB free, Have:{mem_free_total/gb:0.1f}GB free') + if mask is not None: + if len(mask.shape) == 2: + bs = 1 + else: + bs = mask.shape[0] + mask = mask.reshape(bs, -1, mask.shape[-2], mask.shape[-1]).expand(b, heads, -1, -1).reshape(-1, mask.shape[-2], mask.shape[-1]) + # print("steps", steps, mem_required, mem_free_total, modifier, q.element_size(), tensor_size) first_op_done = False cleared_cache = False diff --git a/comfy/model_base.py b/comfy/model_base.py index aafb88e0..421f271b 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -1,5 +1,7 @@ import torch from comfy.ldm.modules.diffusionmodules.openaimodel import UNetModel, Timestep +from comfy.ldm.cascade.stage_c import StageC +from comfy.ldm.cascade.stage_b import StageB from comfy.ldm.modules.encoders.noise_aug_modules import CLIPEmbeddingNoiseAugmentation from comfy.ldm.modules.diffusionmodules.upscaling import ImageConcatWithNoiseAugmentation import comfy.model_management @@ -12,9 +14,10 @@ class ModelType(Enum): EPS = 1 V_PREDICTION = 2 V_PREDICTION_EDM = 3 + STABLE_CASCADE = 4 -from comfy.model_sampling import EPS, V_PREDICTION, ModelSamplingDiscrete, ModelSamplingContinuousEDM +from comfy.model_sampling import EPS, V_PREDICTION, ModelSamplingDiscrete, ModelSamplingContinuousEDM, StableCascadeSampling def model_sampling(model_config, model_type): @@ -27,6 +30,9 @@ def model_sampling(model_config, model_type): elif model_type == ModelType.V_PREDICTION_EDM: c = V_PREDICTION s = ModelSamplingContinuousEDM + elif model_type == ModelType.STABLE_CASCADE: + c = EPS + s = StableCascadeSampling class ModelSampling(s, c): pass @@ -35,7 +41,7 @@ def model_sampling(model_config, model_type): class BaseModel(torch.nn.Module): - def __init__(self, model_config, model_type=ModelType.EPS, device=None): + def __init__(self, model_config, model_type=ModelType.EPS, device=None, unet_model=UNetModel): super().__init__() unet_config = model_config.unet_config @@ -48,7 +54,7 @@ class BaseModel(torch.nn.Module): operations = comfy.ops.manual_cast else: operations = comfy.ops.disable_weight_init - self.diffusion_model = UNetModel(**unet_config, device=device, operations=operations) + self.diffusion_model = unet_model(**unet_config, device=device, operations=operations) self.model_type = model_type self.model_sampling = model_sampling(model_config, model_type) @@ -427,3 +433,52 @@ class SD_X4Upscaler(BaseModel): out['c_concat'] = comfy.conds.CONDNoiseShape(image) out['y'] = comfy.conds.CONDRegular(noise_level) return out + +class StableCascade_C(BaseModel): + def __init__(self, model_config, model_type=ModelType.STABLE_CASCADE, device=None): + super().__init__(model_config, model_type, device=device, unet_model=StageC) + self.diffusion_model.eval().requires_grad_(False) + + def extra_conds(self, **kwargs): + out = {} + clip_text_pooled = kwargs["pooled_output"] + if clip_text_pooled is not None: + out['clip_text_pooled'] = comfy.conds.CONDRegular(clip_text_pooled) + + if "unclip_conditioning" in kwargs: + embeds = [] + for unclip_cond in kwargs["unclip_conditioning"]: + weight = unclip_cond["strength"] + embeds.append(unclip_cond["clip_vision_output"].image_embeds.unsqueeze(0) * weight) + clip_img = torch.cat(embeds, dim=1) + else: + clip_img = torch.zeros((1, 1, 768)) + out["clip_img"] = comfy.conds.CONDRegular(clip_img) + out["sca"] = comfy.conds.CONDRegular(torch.zeros((1,))) + out["crp"] = comfy.conds.CONDRegular(torch.zeros((1,))) + + cross_attn = kwargs.get("cross_attn", None) + if cross_attn is not None: + out['clip_text'] = comfy.conds.CONDCrossAttn(cross_attn) + return out + + +class StableCascade_B(BaseModel): + def __init__(self, model_config, model_type=ModelType.STABLE_CASCADE, device=None): + super().__init__(model_config, model_type, device=device, unet_model=StageB) + self.diffusion_model.eval().requires_grad_(False) + + def extra_conds(self, **kwargs): + out = {} + noise = kwargs.get("noise", None) + + clip_text_pooled = kwargs["pooled_output"] + if clip_text_pooled is not None: + out['clip'] = comfy.conds.CONDRegular(clip_text_pooled) + + #size of prior doesn't really matter if zeros because it gets resized but I still want it to get batched + prior = kwargs.get("stable_cascade_prior", torch.zeros((1, 16, (noise.shape[2] * 4) // 42, (noise.shape[3] * 4) // 42), dtype=noise.dtype, layout=noise.layout, device=noise.device)) + + out["effnet"] = comfy.conds.CONDRegular(prior) + out["sca"] = comfy.conds.CONDRegular(torch.zeros((1,))) + return out diff --git a/comfy/model_detection.py b/comfy/model_detection.py index ea824c44..8fca6d8c 100644 --- a/comfy/model_detection.py +++ b/comfy/model_detection.py @@ -28,9 +28,38 @@ def calculate_transformer_depth(prefix, state_dict_keys, state_dict): return last_transformer_depth, context_dim, use_linear_in_transformer, time_stack return None -def detect_unet_config(state_dict, key_prefix, dtype): +def detect_unet_config(state_dict, key_prefix): state_dict_keys = list(state_dict.keys()) + if '{}clf.1.weight'.format(key_prefix) in state_dict_keys: #stable cascade + unet_config = {} + text_mapper_name = '{}clip_txt_mapper.weight'.format(key_prefix) + if text_mapper_name in state_dict_keys: + unet_config['stable_cascade_stage'] = 'c' + w = state_dict[text_mapper_name] + if w.shape[0] == 1536: #stage c lite + unet_config['c_cond'] = 1536 + unet_config['c_hidden'] = [1536, 1536] + unet_config['nhead'] = [24, 24] + unet_config['blocks'] = [[4, 12], [12, 4]] + elif w.shape[0] == 2048: #stage c full + unet_config['c_cond'] = 2048 + elif '{}clip_mapper.weight'.format(key_prefix) in state_dict_keys: + unet_config['stable_cascade_stage'] = 'b' + w = state_dict['{}down_blocks.1.0.channelwise.0.weight'.format(key_prefix)] + if w.shape[-1] == 640: + unet_config['c_hidden'] = [320, 640, 1280, 1280] + unet_config['nhead'] = [-1, -1, 20, 20] + unet_config['blocks'] = [[2, 6, 28, 6], [6, 28, 6, 2]] + unet_config['block_repeat'] = [[1, 1, 1, 1], [3, 3, 2, 2]] + elif w.shape[-1] == 576: #stage b lite + unet_config['c_hidden'] = [320, 576, 1152, 1152] + unet_config['nhead'] = [-1, 9, 18, 18] + unet_config['blocks'] = [[2, 4, 14, 4], [4, 14, 4, 2]] + unet_config['block_repeat'] = [[1, 1, 1, 1], [2, 2, 2, 2]] + + return unet_config + unet_config = { "use_checkpoint": False, "image_size": 32, @@ -45,7 +74,6 @@ def detect_unet_config(state_dict, key_prefix, dtype): else: unet_config["adm_in_channels"] = None - unet_config["dtype"] = dtype model_channels = state_dict['{}input_blocks.0.0.weight'.format(key_prefix)].shape[0] in_channels = state_dict['{}input_blocks.0.0.weight'.format(key_prefix)].shape[1] @@ -159,8 +187,8 @@ def model_config_from_unet_config(unet_config): print("no match", unet_config) return None -def model_config_from_unet(state_dict, unet_key_prefix, dtype, use_base_if_no_match=False): - unet_config = detect_unet_config(state_dict, unet_key_prefix, dtype) +def model_config_from_unet(state_dict, unet_key_prefix, use_base_if_no_match=False): + unet_config = detect_unet_config(state_dict, unet_key_prefix) model_config = model_config_from_unet_config(unet_config) if model_config is None and use_base_if_no_match: return comfy.supported_models_base.BASE(unet_config) @@ -206,7 +234,7 @@ def convert_config(unet_config): return new_config -def unet_config_from_diffusers_unet(state_dict, dtype): +def unet_config_from_diffusers_unet(state_dict, dtype=None): match = {} transformer_depth = [] @@ -313,8 +341,8 @@ def unet_config_from_diffusers_unet(state_dict, dtype): return convert_config(unet_config) return None -def model_config_from_diffusers_unet(state_dict, dtype): - unet_config = unet_config_from_diffusers_unet(state_dict, dtype) +def model_config_from_diffusers_unet(state_dict): + unet_config = unet_config_from_diffusers_unet(state_dict) if unet_config is not None: return model_config_from_unet_config(unet_config) return None diff --git a/comfy/model_management.py b/comfy/model_management.py index a8dc91b9..adcc0e8a 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -487,7 +487,7 @@ def unet_inital_load_device(parameters, dtype): else: return cpu_dev -def unet_dtype(device=None, model_params=0): +def unet_dtype(device=None, model_params=0, supported_dtypes=[torch.float16, torch.bfloat16, torch.float32]): if args.bf16_unet: return torch.bfloat16 if args.fp16_unet: @@ -497,20 +497,31 @@ def unet_dtype(device=None, model_params=0): if args.fp8_e5m2_unet: return torch.float8_e5m2 if should_use_fp16(device=device, model_params=model_params, manual_cast=True): - return torch.float16 + if torch.float16 in supported_dtypes: + return torch.float16 + if should_use_bf16(device, model_params=model_params, manual_cast=True): + if torch.bfloat16 in supported_dtypes: + return torch.bfloat16 return torch.float32 # None means no manual cast -def unet_manual_cast(weight_dtype, inference_device): +def unet_manual_cast(weight_dtype, inference_device, supported_dtypes=[torch.float16, torch.bfloat16, torch.float32]): if weight_dtype == torch.float32: return None - fp16_supported = comfy.model_management.should_use_fp16(inference_device, prioritize_performance=False) + fp16_supported = should_use_fp16(inference_device, prioritize_performance=False) if fp16_supported and weight_dtype == torch.float16: return None - if fp16_supported: + bf16_supported = should_use_bf16(inference_device) + if bf16_supported and weight_dtype == torch.bfloat16: + return None + + if fp16_supported and torch.float16 in supported_dtypes: return torch.float16 + + elif bf16_supported and torch.bfloat16 in supported_dtypes: + return torch.bfloat16 else: return torch.float32 @@ -684,17 +695,20 @@ def mps_mode(): global cpu_state return cpu_state == CPUState.MPS -def is_device_cpu(device): +def is_device_type(device, type): if hasattr(device, 'type'): - if (device.type == 'cpu'): + if (device.type == type): return True return False +def is_device_cpu(device): + return is_device_type(device, 'cpu') + def is_device_mps(device): - if hasattr(device, 'type'): - if (device.type == 'mps'): - return True - return False + return is_device_type(device, 'mps') + +def is_device_cuda(device): + return is_device_type(device, 'cuda') def should_use_fp16(device=None, model_params=0, prioritize_performance=True, manual_cast=False): global directml_enabled @@ -706,9 +720,9 @@ def should_use_fp16(device=None, model_params=0, prioritize_performance=True, ma if FORCE_FP16: return True - if device is not None: #TODO + if device is not None: if is_device_mps(device): - return False + return True if FORCE_FP32: return False @@ -716,8 +730,11 @@ def should_use_fp16(device=None, model_params=0, prioritize_performance=True, ma if directml_enabled: return False - if cpu_mode() or mps_mode(): - return False #TODO ? + if mps_mode(): + return True + + if cpu_mode(): + return False if is_intel_xpu(): return True @@ -757,6 +774,43 @@ def should_use_fp16(device=None, model_params=0, prioritize_performance=True, ma return True +def should_use_bf16(device=None, model_params=0, prioritize_performance=True, manual_cast=False): + if device is not None: + if is_device_cpu(device): #TODO ? bf16 works on CPU but is extremely slow + return False + + if device is not None: #TODO not sure about mps bf16 support + if is_device_mps(device): + return False + + if FORCE_FP32: + return False + + if directml_enabled: + return False + + if cpu_mode() or mps_mode(): + return False + + if is_intel_xpu(): + return True + + if device is None: + device = torch.device("cuda") + + props = torch.cuda.get_device_properties(device) + if props.major >= 8: + return True + + bf16_works = torch.cuda.is_bf16_supported() + + if bf16_works or manual_cast: + free_model_memory = (get_free_memory() * 0.9 - minimum_inference_memory()) + if (not prioritize_performance) or model_params * 4 > free_model_memory: + return True + + return False + def soft_empty_cache(force=False): global cpu_state if cpu_state == CPUState.MPS: diff --git a/comfy/model_sampling.py b/comfy/model_sampling.py index d5870027..97e91a01 100644 --- a/comfy/model_sampling.py +++ b/comfy/model_sampling.py @@ -132,3 +132,56 @@ class ModelSamplingContinuousEDM(torch.nn.Module): log_sigma_min = math.log(self.sigma_min) return math.exp((math.log(self.sigma_max) - log_sigma_min) * percent + log_sigma_min) + +class StableCascadeSampling(ModelSamplingDiscrete): + def __init__(self, model_config=None): + super().__init__() + + if model_config is not None: + sampling_settings = model_config.sampling_settings + else: + sampling_settings = {} + + self.set_parameters(sampling_settings.get("shift", 1.0)) + + def set_parameters(self, shift=1.0, cosine_s=8e-3): + self.shift = shift + self.cosine_s = torch.tensor(cosine_s) + self._init_alpha_cumprod = torch.cos(self.cosine_s / (1 + self.cosine_s) * torch.pi * 0.5) ** 2 + + #This part is just for compatibility with some schedulers in the codebase + self.num_timesteps = 10000 + sigmas = torch.empty((self.num_timesteps), dtype=torch.float32) + for x in range(self.num_timesteps): + t = (x + 1) / self.num_timesteps + sigmas[x] = self.sigma(t) + + self.set_sigmas(sigmas) + + def sigma(self, timestep): + alpha_cumprod = (torch.cos((timestep + self.cosine_s) / (1 + self.cosine_s) * torch.pi * 0.5) ** 2 / self._init_alpha_cumprod) + + if self.shift != 1.0: + var = alpha_cumprod + logSNR = (var/(1-var)).log() + logSNR += 2 * torch.log(1.0 / torch.tensor(self.shift)) + alpha_cumprod = logSNR.sigmoid() + + alpha_cumprod = alpha_cumprod.clamp(0.0001, 0.9999) + return ((1 - alpha_cumprod) / alpha_cumprod) ** 0.5 + + def timestep(self, sigma): + var = 1 / ((sigma * sigma) + 1) + var = var.clamp(0, 1.0) + s, min_var = self.cosine_s.to(var.device), self._init_alpha_cumprod.to(var.device) + t = (((var * min_var) ** 0.5).acos() / (torch.pi * 0.5)) * (1 + s) - s + return t + + def percent_to_sigma(self, percent): + if percent <= 0.0: + return 999999999.9 + if percent >= 1.0: + return 0.0 + + percent = 1.0 - percent + return self.sigma(torch.tensor(percent)) diff --git a/comfy/ops.py b/comfy/ops.py index f674b47f..517688e8 100644 --- a/comfy/ops.py +++ b/comfy/ops.py @@ -1,3 +1,21 @@ +""" + This file is part of ComfyUI. + Copyright (C) 2024 Stability AI + + This program is free software: you can redistribute it and/or modify + it under the terms of the GNU General Public License as published by + the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + This program is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU General Public License for more details. + + You should have received a copy of the GNU General Public License + along with this program. If not, see . +""" + import torch import comfy.model_management @@ -78,7 +96,11 @@ class disable_weight_init: return None def forward_comfy_cast_weights(self, input): - weight, bias = cast_bias_weight(self, input) + if self.weight is not None: + weight, bias = cast_bias_weight(self, input) + else: + weight = None + bias = None return torch.nn.functional.layer_norm(input, self.normalized_shape, weight, bias, self.eps) def forward(self, *args, **kwargs): @@ -87,6 +109,28 @@ class disable_weight_init: else: return super().forward(*args, **kwargs) + class ConvTranspose2d(torch.nn.ConvTranspose2d): + comfy_cast_weights = False + def reset_parameters(self): + return None + + def forward_comfy_cast_weights(self, input, output_size=None): + num_spatial_dims = 2 + output_padding = self._output_padding( + input, output_size, self.stride, self.padding, self.kernel_size, + num_spatial_dims, self.dilation) + + weight, bias = cast_bias_weight(self, input) + return torch.nn.functional.conv_transpose2d( + input, weight, bias, self.stride, self.padding, + output_padding, self.groups, self.dilation) + + def forward(self, *args, **kwargs): + if self.comfy_cast_weights: + return self.forward_comfy_cast_weights(*args, **kwargs) + else: + return super().forward(*args, **kwargs) + @classmethod def conv_nd(s, dims, *args, **kwargs): if dims == 2: @@ -112,3 +156,6 @@ class manual_cast(disable_weight_init): class LayerNorm(disable_weight_init.LayerNorm): comfy_cast_weights = True + + class ConvTranspose2d(disable_weight_init.ConvTranspose2d): + comfy_cast_weights = True diff --git a/comfy/samplers.py b/comfy/samplers.py index f2ac3c5d..c795f208 100644 --- a/comfy/samplers.py +++ b/comfy/samplers.py @@ -652,6 +652,7 @@ def sampler_object(name): class KSampler: SCHEDULERS = SCHEDULER_NAMES SAMPLERS = SAMPLER_NAMES + DISCARD_PENULTIMATE_SIGMA_SAMPLERS = set(('dpm_2', 'dpm_2_ancestral', 'uni_pc', 'uni_pc_bh2')) def __init__(self, model, steps, device, sampler=None, scheduler=None, denoise=None, model_options={}): self.model = model @@ -670,7 +671,7 @@ class KSampler: sigmas = None discard_penultimate_sigma = False - if self.sampler in ['dpm_2', 'dpm_2_ancestral', 'uni_pc', 'uni_pc_bh2']: + if self.sampler in self.DISCARD_PENULTIMATE_SIGMA_SAMPLERS: steps += 1 discard_penultimate_sigma = True diff --git a/comfy/sd.py b/comfy/sd.py index c15d73fe..7a77bb17 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -1,7 +1,11 @@ import torch +from enum import Enum from comfy import model_management from .ldm.models.autoencoder import AutoencoderKL, AutoencodingEngine +from .ldm.cascade.stage_a import StageA +from .ldm.cascade.stage_c_coder import StageC_coder + import yaml import comfy.utils @@ -134,8 +138,11 @@ class CLIP: tokens = self.tokenize(text) return self.encode_from_tokens(tokens) - def load_sd(self, sd): - return self.cond_stage_model.load_sd(sd) + def load_sd(self, sd, full_model=False): + if full_model: + return self.cond_stage_model.load_state_dict(sd, strict=False) + else: + return self.cond_stage_model.load_sd(sd) def get_sd(self): return self.cond_stage_model.state_dict() @@ -155,7 +162,10 @@ class VAE: self.memory_used_encode = lambda shape, dtype: (1767 * shape[2] * shape[3]) * model_management.dtype_size(dtype) #These are for AutoencoderKL and need tweaking (should be lower) self.memory_used_decode = lambda shape, dtype: (2178 * shape[2] * shape[3] * 64) * model_management.dtype_size(dtype) self.downscale_ratio = 8 + self.upscale_ratio = 8 self.latent_channels = 4 + self.process_input = lambda image: image * 2.0 - 1.0 + self.process_output = lambda image: torch.clamp((image + 1.0) / 2.0, min=0.0, max=1.0) if config is None: if "decoder.mid.block_1.mix_factor" in sd: @@ -168,6 +178,34 @@ class VAE: decoder_config={'target': "comfy.ldm.modules.temporal_ae.VideoDecoder", 'params': decoder_config}) elif "taesd_decoder.1.weight" in sd: self.first_stage_model = comfy.taesd.taesd.TAESD() + elif "vquantizer.codebook.weight" in sd: #VQGan: stage a of stable cascade + self.first_stage_model = StageA() + self.downscale_ratio = 4 + self.upscale_ratio = 4 + #TODO + #self.memory_used_encode + #self.memory_used_decode + self.process_input = lambda image: image + self.process_output = lambda image: image + elif "backbone.1.0.block.0.1.num_batches_tracked" in sd: #effnet: encoder for stage c latent of stable cascade + self.first_stage_model = StageC_coder() + self.downscale_ratio = 32 + self.latent_channels = 16 + new_sd = {} + for k in sd: + new_sd["encoder.{}".format(k)] = sd[k] + sd = new_sd + elif "blocks.11.num_batches_tracked" in sd: #previewer: decoder for stage c latent of stable cascade + self.first_stage_model = StageC_coder() + self.latent_channels = 16 + new_sd = {} + for k in sd: + new_sd["previewer.{}".format(k)] = sd[k] + sd = new_sd + elif "encoder.backbone.1.0.block.0.1.num_batches_tracked" in sd: #combined effnet and previewer for stable cascade + self.first_stage_model = StageC_coder() + self.downscale_ratio = 32 + self.latent_channels = 16 else: #default SD1.x/SD2.x VAE parameters ddconfig = {'double_z': True, 'z_channels': 4, 'resolution': 256, 'in_channels': 3, 'out_ch': 3, 'ch': 128, 'ch_mult': [1, 2, 4, 4], 'num_res_blocks': 2, 'attn_resolutions': [], 'dropout': 0.0} @@ -175,6 +213,7 @@ class VAE: if 'encoder.down.2.downsample.conv.weight' not in sd: #Stable diffusion x4 upscaler VAE ddconfig['ch_mult'] = [1, 2, 4] self.downscale_ratio = 4 + self.upscale_ratio = 4 self.first_stage_model = AutoencoderKL(ddconfig=ddconfig, embed_dim=4) else: @@ -200,18 +239,27 @@ class VAE: self.patcher = comfy.model_patcher.ModelPatcher(self.first_stage_model, load_device=self.device, offload_device=offload_device) + def vae_encode_crop_pixels(self, pixels): + x = (pixels.shape[1] // self.downscale_ratio) * self.downscale_ratio + y = (pixels.shape[2] // self.downscale_ratio) * self.downscale_ratio + if pixels.shape[1] != x or pixels.shape[2] != y: + x_offset = (pixels.shape[1] % self.downscale_ratio) // 2 + y_offset = (pixels.shape[2] % self.downscale_ratio) // 2 + pixels = pixels[:, x_offset:x + x_offset, y_offset:y + y_offset, :] + return pixels + def decode_tiled_(self, samples, tile_x=64, tile_y=64, overlap = 16): steps = samples.shape[0] * comfy.utils.get_tiled_scale_steps(samples.shape[3], samples.shape[2], tile_x, tile_y, overlap) steps += samples.shape[0] * comfy.utils.get_tiled_scale_steps(samples.shape[3], samples.shape[2], tile_x // 2, tile_y * 2, overlap) steps += samples.shape[0] * comfy.utils.get_tiled_scale_steps(samples.shape[3], samples.shape[2], tile_x * 2, tile_y // 2, overlap) pbar = comfy.utils.ProgressBar(steps) - decode_fn = lambda a: (self.first_stage_model.decode(a.to(self.vae_dtype).to(self.device)) + 1.0).float() - output = torch.clamp(( - (comfy.utils.tiled_scale(samples, decode_fn, tile_x // 2, tile_y * 2, overlap, upscale_amount = self.downscale_ratio, output_device=self.output_device, pbar = pbar) + - comfy.utils.tiled_scale(samples, decode_fn, tile_x * 2, tile_y // 2, overlap, upscale_amount = self.downscale_ratio, output_device=self.output_device, pbar = pbar) + - comfy.utils.tiled_scale(samples, decode_fn, tile_x, tile_y, overlap, upscale_amount = self.downscale_ratio, output_device=self.output_device, pbar = pbar)) - / 3.0) / 2.0, min=0.0, max=1.0) + decode_fn = lambda a: self.first_stage_model.decode(a.to(self.vae_dtype).to(self.device)).float() + output = self.process_output( + (comfy.utils.tiled_scale(samples, decode_fn, tile_x // 2, tile_y * 2, overlap, upscale_amount = self.upscale_ratio, output_device=self.output_device, pbar = pbar) + + comfy.utils.tiled_scale(samples, decode_fn, tile_x * 2, tile_y // 2, overlap, upscale_amount = self.upscale_ratio, output_device=self.output_device, pbar = pbar) + + comfy.utils.tiled_scale(samples, decode_fn, tile_x, tile_y, overlap, upscale_amount = self.upscale_ratio, output_device=self.output_device, pbar = pbar)) + / 3.0) return output def encode_tiled_(self, pixel_samples, tile_x=512, tile_y=512, overlap = 64): @@ -220,7 +268,7 @@ class VAE: steps += pixel_samples.shape[0] * comfy.utils.get_tiled_scale_steps(pixel_samples.shape[3], pixel_samples.shape[2], tile_x * 2, tile_y // 2, overlap) pbar = comfy.utils.ProgressBar(steps) - encode_fn = lambda a: self.first_stage_model.encode((2. * a - 1.).to(self.vae_dtype).to(self.device)).float() + encode_fn = lambda a: self.first_stage_model.encode((self.process_input(a)).to(self.vae_dtype).to(self.device)).float() samples = comfy.utils.tiled_scale(pixel_samples, encode_fn, tile_x, tile_y, overlap, upscale_amount = (1/self.downscale_ratio), out_channels=self.latent_channels, output_device=self.output_device, pbar=pbar) samples += comfy.utils.tiled_scale(pixel_samples, encode_fn, tile_x * 2, tile_y // 2, overlap, upscale_amount = (1/self.downscale_ratio), out_channels=self.latent_channels, output_device=self.output_device, pbar=pbar) samples += comfy.utils.tiled_scale(pixel_samples, encode_fn, tile_x // 2, tile_y * 2, overlap, upscale_amount = (1/self.downscale_ratio), out_channels=self.latent_channels, output_device=self.output_device, pbar=pbar) @@ -235,10 +283,10 @@ class VAE: batch_number = int(free_memory / memory_used) batch_number = max(1, batch_number) - pixel_samples = torch.empty((samples_in.shape[0], 3, round(samples_in.shape[2] * self.downscale_ratio), round(samples_in.shape[3] * self.downscale_ratio)), device=self.output_device) + pixel_samples = torch.empty((samples_in.shape[0], 3, round(samples_in.shape[2] * self.upscale_ratio), round(samples_in.shape[3] * self.upscale_ratio)), device=self.output_device) for x in range(0, samples_in.shape[0], batch_number): samples = samples_in[x:x+batch_number].to(self.vae_dtype).to(self.device) - pixel_samples[x:x+batch_number] = torch.clamp((self.first_stage_model.decode(samples).to(self.output_device).float() + 1.0) / 2.0, min=0.0, max=1.0) + pixel_samples[x:x+batch_number] = self.process_output(self.first_stage_model.decode(samples).to(self.output_device).float()) except model_management.OOM_EXCEPTION as e: print("Warning: Ran out of memory when regular VAE decoding, retrying with tiled VAE decoding.") pixel_samples = self.decode_tiled_(samples_in) @@ -252,6 +300,7 @@ class VAE: return output.movedim(1,-1) def encode(self, pixel_samples): + pixel_samples = self.vae_encode_crop_pixels(pixel_samples) pixel_samples = pixel_samples.movedim(-1,1) try: memory_used = self.memory_used_encode(pixel_samples.shape, self.vae_dtype) @@ -261,7 +310,7 @@ class VAE: batch_number = max(1, batch_number) samples = torch.empty((pixel_samples.shape[0], self.latent_channels, round(pixel_samples.shape[2] // self.downscale_ratio), round(pixel_samples.shape[3] // self.downscale_ratio)), device=self.output_device) for x in range(0, pixel_samples.shape[0], batch_number): - pixels_in = (2. * pixel_samples[x:x+batch_number] - 1.).to(self.vae_dtype).to(self.device) + pixels_in = self.process_input(pixel_samples[x:x+batch_number]).to(self.vae_dtype).to(self.device) samples[x:x+batch_number] = self.first_stage_model.encode(pixels_in).to(self.output_device).float() except model_management.OOM_EXCEPTION as e: @@ -271,6 +320,7 @@ class VAE: return samples def encode_tiled(self, pixel_samples, tile_x=512, tile_y=512, overlap = 64): + pixel_samples = self.vae_encode_crop_pixels(pixel_samples) model_management.load_model_gpu(self.patcher) pixel_samples = pixel_samples.movedim(-1,1) samples = self.encode_tiled_(pixel_samples, tile_x=tile_x, tile_y=tile_y, overlap=overlap) @@ -297,8 +347,11 @@ def load_style_model(ckpt_path): model.load_state_dict(model_data) return StyleModel(model) +class CLIPType(Enum): + STABLE_DIFFUSION = 1 + STABLE_CASCADE = 2 -def load_clip(ckpt_paths, embedding_directory=None): +def load_clip(ckpt_paths, embedding_directory=None, clip_type=CLIPType.STABLE_DIFFUSION): clip_data = [] for p in ckpt_paths: clip_data.append(comfy.utils.load_torch_file(p, safe_load=True)) @@ -314,8 +367,12 @@ def load_clip(ckpt_paths, embedding_directory=None): clip_target.params = {} if len(clip_data) == 1: if "text_model.encoder.layers.30.mlp.fc1.weight" in clip_data[0]: - clip_target.clip = sdxl_clip.SDXLRefinerClipModel - clip_target.tokenizer = sdxl_clip.SDXLTokenizer + if clip_type == CLIPType.STABLE_CASCADE: + clip_target.clip = sdxl_clip.StableCascadeClipModel + clip_target.tokenizer = sdxl_clip.StableCascadeTokenizer + else: + clip_target.clip = sdxl_clip.SDXLRefinerClipModel + clip_target.tokenizer = sdxl_clip.SDXLTokenizer elif "text_model.encoder.layers.22.mlp.fc1.weight" in clip_data[0]: clip_target.clip = sd2_clip.SD2ClipModel clip_target.tokenizer = sd2_clip.SD2Tokenizer @@ -438,15 +495,12 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o clip_target = None parameters = comfy.utils.calculate_parameters(sd, "model.diffusion_model.") - unet_dtype = model_management.unet_dtype(model_params=parameters) load_device = model_management.get_torch_device() - manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device) - - class WeightsLoader(torch.nn.Module): - pass - model_config = model_detection.model_config_from_unet(sd, "model.diffusion_model.", unet_dtype) - model_config.set_manual_cast(manual_cast_dtype) + model_config = model_detection.model_config_from_unet(sd, "model.diffusion_model.") + unet_dtype = model_management.unet_dtype(model_params=parameters, supported_dtypes=model_config.supported_inference_dtypes) + manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device, model_config.supported_inference_dtypes) + model_config.set_inference_dtype(unet_dtype, manual_cast_dtype) if model_config is None: raise RuntimeError("ERROR: Could not detect model type of: {}".format(ckpt_path)) @@ -467,13 +521,19 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o vae = VAE(sd=vae_sd) if output_clip: - w = WeightsLoader() clip_target = model_config.clip_target() if clip_target is not None: - clip = CLIP(clip_target, embedding_directory=embedding_directory) - w.cond_stage_model = clip.cond_stage_model - sd = model_config.process_clip_state_dict(sd) - load_model_weights(w, sd) + clip_sd = model_config.process_clip_state_dict(sd) + if len(clip_sd) > 0: + clip = CLIP(clip_target, embedding_directory=embedding_directory) + m, u = clip.load_sd(clip_sd, full_model=True) + if len(m) > 0: + print("clip missing:", m) + + if len(u) > 0: + print("clip unexpected:", u) + else: + print("no CLIP/text encoder weights in checkpoint, the text encoder model will not be loaded.") left_over = sd.keys() if len(left_over) > 0: @@ -492,16 +552,15 @@ def load_unet_state_dict(sd): #load unet in diffusers format parameters = comfy.utils.calculate_parameters(sd) unet_dtype = model_management.unet_dtype(model_params=parameters) load_device = model_management.get_torch_device() - manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device) - if "input_blocks.0.0.weight" in sd: #ldm - model_config = model_detection.model_config_from_unet(sd, "", unet_dtype) + if "input_blocks.0.0.weight" in sd or 'clf.1.weight' in sd: #ldm or stable cascade + model_config = model_detection.model_config_from_unet(sd, "") if model_config is None: return None new_sd = sd else: #diffusers - model_config = model_detection.model_config_from_diffusers_unet(sd, unet_dtype) + model_config = model_detection.model_config_from_diffusers_unet(sd) if model_config is None: return None @@ -513,8 +572,11 @@ def load_unet_state_dict(sd): #load unet in diffusers format new_sd[diffusers_keys[k]] = sd.pop(k) else: print(diffusers_keys[k], k) + offload_device = model_management.unet_offload_device() - model_config.set_manual_cast(manual_cast_dtype) + unet_dtype = model_management.unet_dtype(model_params=parameters, supported_dtypes=model_config.supported_inference_dtypes) + manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device, model_config.supported_inference_dtypes) + model_config.set_inference_dtype(unet_dtype, manual_cast_dtype) model = model_config.get_model(new_sd, "") model = model.to(offload_device) model.load_model_weights(new_sd, "") diff --git a/comfy/sd1_clip.py b/comfy/sd1_clip.py index 65ea909f..8287ad2e 100644 --- a/comfy/sd1_clip.py +++ b/comfy/sd1_clip.py @@ -67,7 +67,7 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder): ] def __init__(self, version="openai/clip-vit-large-patch14", device="cpu", max_length=77, freeze=True, layer="last", layer_idx=None, textmodel_json_config=None, dtype=None, model_class=comfy.clip_model.CLIPTextModel, - special_tokens={"start": 49406, "end": 49407, "pad": 49407}, layer_norm_hidden_state=True): # clip-vit-base-patch32 + special_tokens={"start": 49406, "end": 49407, "pad": 49407}, layer_norm_hidden_state=True, enable_attention_masks=False): # clip-vit-base-patch32 super().__init__() assert layer in self.LAYERS @@ -88,7 +88,7 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder): self.special_tokens = special_tokens self.text_projection = torch.nn.Parameter(torch.eye(self.transformer.get_input_embeddings().weight.shape[1])) self.logit_scale = torch.nn.Parameter(torch.tensor(4.6055)) - self.enable_attention_masks = False + self.enable_attention_masks = enable_attention_masks self.layer_norm_hidden_state = layer_norm_hidden_state if layer == "hidden": diff --git a/comfy/sdxl_clip.py b/comfy/sdxl_clip.py index b35056bb..3ce5c7e0 100644 --- a/comfy/sdxl_clip.py +++ b/comfy/sdxl_clip.py @@ -64,3 +64,25 @@ class SDXLClipModel(torch.nn.Module): class SDXLRefinerClipModel(sd1_clip.SD1ClipModel): def __init__(self, device="cpu", dtype=None): super().__init__(device=device, dtype=dtype, clip_name="g", clip_model=SDXLClipG) + + +class StableCascadeClipGTokenizer(sd1_clip.SDTokenizer): + def __init__(self, tokenizer_path=None, embedding_directory=None): + super().__init__(tokenizer_path, pad_with_end=True, embedding_directory=embedding_directory, embedding_size=1280, embedding_key='clip_g') + +class StableCascadeTokenizer(sd1_clip.SD1Tokenizer): + def __init__(self, embedding_directory=None): + super().__init__(embedding_directory=embedding_directory, clip_name="g", tokenizer=StableCascadeClipGTokenizer) + +class StableCascadeClipG(sd1_clip.SDClipModel): + def __init__(self, device="cpu", max_length=77, freeze=True, layer="hidden", layer_idx=-1, dtype=None): + textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_config_bigg.json") + super().__init__(device=device, freeze=freeze, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, + special_tokens={"start": 49406, "end": 49407, "pad": 49407}, layer_norm_hidden_state=False, enable_attention_masks=True) + + def load_sd(self, sd): + return super().load_sd(sd) + +class StableCascadeClipModel(sd1_clip.SD1ClipModel): + def __init__(self, device="cpu", dtype=None): + super().__init__(device=device, dtype=dtype, clip_name="g", clip_model=StableCascadeClipG) diff --git a/comfy/supported_models.py b/comfy/supported_models.py index 1d442d4d..5bb98d88 100644 --- a/comfy/supported_models.py +++ b/comfy/supported_models.py @@ -40,8 +40,8 @@ class SD15(supported_models_base.BASE): state_dict['cond_stage_model.transformer.text_model.embeddings.position_ids'] = ids.round() replace_prefix = {} - replace_prefix["cond_stage_model."] = "cond_stage_model.clip_l." - state_dict = utils.state_dict_prefix_replace(state_dict, replace_prefix) + replace_prefix["cond_stage_model."] = "clip_l." + state_dict = utils.state_dict_prefix_replace(state_dict, replace_prefix, filter_keys=True) return state_dict def process_clip_state_dict_for_saving(self, state_dict): @@ -72,10 +72,10 @@ class SD20(supported_models_base.BASE): def process_clip_state_dict(self, state_dict): replace_prefix = {} - replace_prefix["conditioner.embedders.0.model."] = "cond_stage_model.model." #SD2 in sgm format - state_dict = utils.state_dict_prefix_replace(state_dict, replace_prefix) - - state_dict = utils.transformers_convert(state_dict, "cond_stage_model.model.", "cond_stage_model.clip_h.transformer.text_model.", 24) + replace_prefix["conditioner.embedders.0.model."] = "clip_h." #SD2 in sgm format + replace_prefix["cond_stage_model.model."] = "clip_h." + state_dict = utils.state_dict_prefix_replace(state_dict, replace_prefix, filter_keys=True) + state_dict = utils.transformers_convert(state_dict, "clip_h.", "clip_h.transformer.text_model.", 24) return state_dict def process_clip_state_dict_for_saving(self, state_dict): @@ -131,11 +131,10 @@ class SDXLRefiner(supported_models_base.BASE): def process_clip_state_dict(self, state_dict): keys_to_replace = {} replace_prefix = {} + replace_prefix["conditioner.embedders.0.model."] = "clip_g." + state_dict = utils.state_dict_prefix_replace(state_dict, replace_prefix, filter_keys=True) - state_dict = utils.transformers_convert(state_dict, "conditioner.embedders.0.model.", "cond_stage_model.clip_g.transformer.text_model.", 32) - keys_to_replace["conditioner.embedders.0.model.text_projection"] = "cond_stage_model.clip_g.text_projection" - keys_to_replace["conditioner.embedders.0.model.logit_scale"] = "cond_stage_model.clip_g.logit_scale" - + state_dict = utils.transformers_convert(state_dict, "clip_g.", "clip_g.transformer.text_model.", 32) state_dict = utils.state_dict_key_replace(state_dict, keys_to_replace) return state_dict @@ -179,13 +178,13 @@ class SDXL(supported_models_base.BASE): keys_to_replace = {} replace_prefix = {} - replace_prefix["conditioner.embedders.0.transformer.text_model"] = "cond_stage_model.clip_l.transformer.text_model" - state_dict = utils.transformers_convert(state_dict, "conditioner.embedders.1.model.", "cond_stage_model.clip_g.transformer.text_model.", 32) - keys_to_replace["conditioner.embedders.1.model.text_projection"] = "cond_stage_model.clip_g.text_projection" - keys_to_replace["conditioner.embedders.1.model.text_projection.weight"] = "cond_stage_model.clip_g.text_projection" - keys_to_replace["conditioner.embedders.1.model.logit_scale"] = "cond_stage_model.clip_g.logit_scale" + replace_prefix["conditioner.embedders.0.transformer.text_model"] = "clip_l.transformer.text_model" + replace_prefix["conditioner.embedders.1.model."] = "clip_g." + state_dict = utils.state_dict_prefix_replace(state_dict, replace_prefix, filter_keys=True) + + state_dict = utils.transformers_convert(state_dict, "clip_g.", "clip_g.transformer.text_model.", 32) + keys_to_replace["clip_g.text_projection.weight"] = "clip_g.text_projection" - state_dict = utils.state_dict_prefix_replace(state_dict, replace_prefix) state_dict = utils.state_dict_key_replace(state_dict, keys_to_replace) return state_dict @@ -306,5 +305,66 @@ class SD_X4Upscaler(SD20): out = model_base.SD_X4Upscaler(self, device=device) return out -models = [Stable_Zero123, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXLRefiner, SDXL, SSD1B, Segmind_Vega, SD_X4Upscaler] +class Stable_Cascade_C(supported_models_base.BASE): + unet_config = { + "stable_cascade_stage": 'c', + } + + unet_extra_config = {} + + latent_format = latent_formats.SC_Prior + supported_inference_dtypes = [torch.bfloat16, torch.float32] + + sampling_settings = { + "shift": 2.0, + } + + vae_key_prefix = ["vae."] + text_encoder_key_prefix = ["text_encoder."] + clip_vision_prefix = "clip_l_vision." + + def process_unet_state_dict(self, state_dict): + key_list = list(state_dict.keys()) + for y in ["weight", "bias"]: + suffix = "in_proj_{}".format(y) + keys = filter(lambda a: a.endswith(suffix), key_list) + for k_from in keys: + weights = state_dict.pop(k_from) + prefix = k_from[:-(len(suffix) + 1)] + shape_from = weights.shape[0] // 3 + for x in range(3): + p = ["to_q", "to_k", "to_v"] + k_to = "{}.{}.{}".format(prefix, p[x], y) + state_dict[k_to] = weights[shape_from*x:shape_from*(x + 1)] + return state_dict + + def get_model(self, state_dict, prefix="", device=None): + out = model_base.StableCascade_C(self, device=device) + return out + + def clip_target(self): + return supported_models_base.ClipTarget(sdxl_clip.StableCascadeTokenizer, sdxl_clip.StableCascadeClipModel) + +class Stable_Cascade_B(Stable_Cascade_C): + unet_config = { + "stable_cascade_stage": 'b', + } + + unet_extra_config = {} + + latent_format = latent_formats.SC_B + supported_inference_dtypes = [torch.float16, torch.bfloat16, torch.float32] + + sampling_settings = { + "shift": 1.0, + } + + clip_vision_prefix = None + + def get_model(self, state_dict, prefix="", device=None): + out = model_base.StableCascade_B(self, device=device) + return out + + +models = [Stable_Zero123, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXLRefiner, SDXL, SSD1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B] models += [SVD_img2vid] diff --git a/comfy/supported_models_base.py b/comfy/supported_models_base.py index 58535a9f..4d7e2593 100644 --- a/comfy/supported_models_base.py +++ b/comfy/supported_models_base.py @@ -22,13 +22,15 @@ class BASE: sampling_settings = {} latent_format = latent_formats.LatentFormat vae_key_prefix = ["first_stage_model."] + text_encoder_key_prefix = ["cond_stage_model."] + supported_inference_dtypes = [torch.float16, torch.bfloat16, torch.float32] manual_cast_dtype = None @classmethod def matches(s, unet_config): for k in s.unet_config: - if s.unet_config[k] != unet_config[k]: + if k not in unet_config or s.unet_config[k] != unet_config[k]: return False return True @@ -54,6 +56,7 @@ class BASE: return out def process_clip_state_dict(self, state_dict): + state_dict = utils.state_dict_prefix_replace(state_dict, {k: "" for k in self.text_encoder_key_prefix}, filter_keys=True) return state_dict def process_unet_state_dict(self, state_dict): @@ -63,7 +66,7 @@ class BASE: return state_dict def process_clip_state_dict_for_saving(self, state_dict): - replace_prefix = {"": "cond_stage_model."} + replace_prefix = {"": self.text_encoder_key_prefix[0]} return utils.state_dict_prefix_replace(state_dict, replace_prefix) def process_clip_vision_state_dict_for_saving(self, state_dict): @@ -77,8 +80,9 @@ class BASE: return utils.state_dict_prefix_replace(state_dict, replace_prefix) def process_vae_state_dict_for_saving(self, state_dict): - replace_prefix = {"": "first_stage_model."} + replace_prefix = {"": self.vae_key_prefix[0]} return utils.state_dict_prefix_replace(state_dict, replace_prefix) - def set_manual_cast(self, manual_cast_dtype): + def set_inference_dtype(self, dtype, manual_cast_dtype): + self.unet_config['dtype'] = dtype self.manual_cast_dtype = manual_cast_dtype diff --git a/comfy/utils.py b/comfy/utils.py index 1113bf0f..04cf76ed 100644 --- a/comfy/utils.py +++ b/comfy/utils.py @@ -169,6 +169,8 @@ UNET_MAP_BASIC = { } def unet_to_diffusers(unet_config): + if "num_res_blocks" not in unet_config: + return {} num_res_blocks = unet_config["num_res_blocks"] channel_mult = unet_config["channel_mult"] transformer_depth = unet_config["transformer_depth"][:] diff --git a/comfy_extras/nodes_images.py b/comfy_extras/nodes_images.py index aa80f526..8f638bf8 100644 --- a/comfy_extras/nodes_images.py +++ b/comfy_extras/nodes_images.py @@ -48,6 +48,25 @@ class RepeatImageBatch: s = image.repeat((amount, 1,1,1)) return (s,) +class ImageFromBatch: + @classmethod + def INPUT_TYPES(s): + return {"required": { "image": ("IMAGE",), + "batch_index": ("INT", {"default": 0, "min": 0, "max": 63}), + "length": ("INT", {"default": 1, "min": 1, "max": 64}), + }} + RETURN_TYPES = ("IMAGE",) + FUNCTION = "frombatch" + + CATEGORY = "image/batch" + + def frombatch(self, image, batch_index, length): + s_in = image + batch_index = min(s_in.shape[0] - 1, batch_index) + length = min(s_in.shape[0] - batch_index, length) + s = s_in[batch_index:batch_index + length].clone() + return (s,) + class SaveAnimatedWEBP: def __init__(self): self.output_dir = folder_paths.get_output_directory() @@ -170,6 +189,7 @@ class SaveAnimatedPNG: NODE_CLASS_MAPPINGS = { "ImageCrop": ImageCrop, "RepeatImageBatch": RepeatImageBatch, + "ImageFromBatch": ImageFromBatch, "SaveAnimatedWEBP": SaveAnimatedWEBP, "SaveAnimatedPNG": SaveAnimatedPNG, } diff --git a/comfy_extras/nodes_model_advanced.py b/comfy_extras/nodes_model_advanced.py index 541ce8fa..1b3f3945 100644 --- a/comfy_extras/nodes_model_advanced.py +++ b/comfy_extras/nodes_model_advanced.py @@ -17,6 +17,10 @@ class LCM(comfy.model_sampling.EPS): return c_out * x0 + c_skip * model_input +class X0(comfy.model_sampling.EPS): + def calculate_denoised(self, sigma, model_output, model_input): + return model_output + class ModelSamplingDiscreteDistilled(comfy.model_sampling.ModelSamplingDiscrete): original_timesteps = 50 @@ -68,7 +72,7 @@ class ModelSamplingDiscrete: @classmethod def INPUT_TYPES(s): return {"required": { "model": ("MODEL",), - "sampling": (["eps", "v_prediction", "lcm"],), + "sampling": (["eps", "v_prediction", "lcm", "x0"],), "zsnr": ("BOOLEAN", {"default": False}), }} @@ -88,6 +92,8 @@ class ModelSamplingDiscrete: elif sampling == "lcm": sampling_type = LCM sampling_base = ModelSamplingDiscreteDistilled + elif sampling == "x0": + sampling_type = X0 class ModelSamplingAdvanced(sampling_base, sampling_type): pass @@ -99,6 +105,32 @@ class ModelSamplingDiscrete: m.add_object_patch("model_sampling", model_sampling) return (m, ) +class ModelSamplingStableCascade: + @classmethod + def INPUT_TYPES(s): + return {"required": { "model": ("MODEL",), + "shift": ("FLOAT", {"default": 2.0, "min": 0.0, "max": 100.0, "step":0.01}), + }} + + RETURN_TYPES = ("MODEL",) + FUNCTION = "patch" + + CATEGORY = "advanced/model" + + def patch(self, model, shift): + m = model.clone() + + sampling_base = comfy.model_sampling.StableCascadeSampling + sampling_type = comfy.model_sampling.EPS + + class ModelSamplingAdvanced(sampling_base, sampling_type): + pass + + model_sampling = ModelSamplingAdvanced(model.model.model_config) + model_sampling.set_parameters(shift) + m.add_object_patch("model_sampling", model_sampling) + return (m, ) + class ModelSamplingContinuousEDM: @classmethod def INPUT_TYPES(s): @@ -171,5 +203,6 @@ class RescaleCFG: NODE_CLASS_MAPPINGS = { "ModelSamplingDiscrete": ModelSamplingDiscrete, "ModelSamplingContinuousEDM": ModelSamplingContinuousEDM, + "ModelSamplingStableCascade": ModelSamplingStableCascade, "RescaleCFG": RescaleCFG, } diff --git a/comfy_extras/nodes_stable_cascade.py b/comfy_extras/nodes_stable_cascade.py new file mode 100644 index 00000000..b795d008 --- /dev/null +++ b/comfy_extras/nodes_stable_cascade.py @@ -0,0 +1,109 @@ +""" + This file is part of ComfyUI. + Copyright (C) 2024 Stability AI + + This program is free software: you can redistribute it and/or modify + it under the terms of the GNU General Public License as published by + the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + This program is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU General Public License for more details. + + You should have received a copy of the GNU General Public License + along with this program. If not, see . +""" + +import torch +import nodes +import comfy.utils + + +class StableCascade_EmptyLatentImage: + def __init__(self, device="cpu"): + self.device = device + + @classmethod + def INPUT_TYPES(s): + return {"required": { + "width": ("INT", {"default": 1024, "min": 256, "max": nodes.MAX_RESOLUTION, "step": 8}), + "height": ("INT", {"default": 1024, "min": 256, "max": nodes.MAX_RESOLUTION, "step": 8}), + "compression": ("INT", {"default": 42, "min": 4, "max": 128, "step": 1}), + "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}) + }} + RETURN_TYPES = ("LATENT", "LATENT") + RETURN_NAMES = ("stage_c", "stage_b") + FUNCTION = "generate" + + CATEGORY = "_for_testing/stable_cascade" + + def generate(self, width, height, compression, batch_size=1): + c_latent = torch.zeros([batch_size, 16, height // compression, width // compression]) + b_latent = torch.zeros([batch_size, 4, height // 4, width // 4]) + return ({ + "samples": c_latent, + }, { + "samples": b_latent, + }) + +class StableCascade_StageC_VAEEncode: + def __init__(self, device="cpu"): + self.device = device + + @classmethod + def INPUT_TYPES(s): + return {"required": { + "image": ("IMAGE",), + "vae": ("VAE", ), + "compression": ("INT", {"default": 42, "min": 4, "max": 128, "step": 1}), + }} + RETURN_TYPES = ("LATENT", "LATENT") + RETURN_NAMES = ("stage_c", "stage_b") + FUNCTION = "generate" + + CATEGORY = "_for_testing/stable_cascade" + + def generate(self, image, vae, compression): + width = image.shape[-2] + height = image.shape[-3] + out_width = (width // compression) * vae.downscale_ratio + out_height = (height // compression) * vae.downscale_ratio + + s = comfy.utils.common_upscale(image.movedim(-1,1), out_width, out_height, "bicubic", "center").movedim(1,-1) + + c_latent = vae.encode(s[:,:,:,:3]) + b_latent = torch.zeros([c_latent.shape[0], 4, height // 4, width // 4]) + return ({ + "samples": c_latent, + }, { + "samples": b_latent, + }) + +class StableCascade_StageB_Conditioning: + @classmethod + def INPUT_TYPES(s): + return {"required": { "conditioning": ("CONDITIONING",), + "stage_c": ("LATENT",), + }} + RETURN_TYPES = ("CONDITIONING",) + + FUNCTION = "set_prior" + + CATEGORY = "_for_testing/stable_cascade" + + def set_prior(self, conditioning, stage_c): + c = [] + for t in conditioning: + d = t[1].copy() + d['stable_cascade_prior'] = stage_c['samples'] + n = [t[0], d] + c.append(n) + return (c, ) + +NODE_CLASS_MAPPINGS = { + "StableCascade_EmptyLatentImage": StableCascade_EmptyLatentImage, + "StableCascade_StageB_Conditioning": StableCascade_StageB_Conditioning, + "StableCascade_StageC_VAEEncode": StableCascade_StageC_VAEEncode, +} diff --git a/custom_nodes/websocket_image_save.py.disabled b/custom_nodes/websocket_image_save.py.disabled new file mode 100644 index 00000000..b85a5de8 --- /dev/null +++ b/custom_nodes/websocket_image_save.py.disabled @@ -0,0 +1,49 @@ +from PIL import Image, ImageOps +from io import BytesIO +import numpy as np +import struct +import comfy.utils +import time + +#You can use this node to save full size images through the websocket, the +#images will be sent in exactly the same format as the image previews: as +#binary images on the websocket with a 8 byte header indicating the type +#of binary message (first 4 bytes) and the image format (next 4 bytes). + +#The reason this node is disabled by default is because there is a small +#issue when using it with the default ComfyUI web interface: When generating +#batches only the last image will be shown in the UI. + +#Note that no metadata will be put in the images saved with this node. + +class SaveImageWebsocket: + @classmethod + def INPUT_TYPES(s): + return {"required": + {"images": ("IMAGE", ),} + } + + RETURN_TYPES = () + FUNCTION = "save_images" + + OUTPUT_NODE = True + + CATEGORY = "image" + + def save_images(self, images): + pbar = comfy.utils.ProgressBar(images.shape[0]) + step = 0 + for image in images: + i = 255. * image.cpu().numpy() + img = Image.fromarray(np.clip(i, 0, 255).astype(np.uint8)) + pbar.update_absolute(step, images.shape[0], ("PNG", img, None)) + step += 1 + + return {} + + def IS_CHANGED(s, images): + return time.time() + +NODE_CLASS_MAPPINGS = { + "SaveImageWebsocket": SaveImageWebsocket, +} diff --git a/execution.py b/execution.py index 00908ead..3e9d53b0 100644 --- a/execution.py +++ b/execution.py @@ -194,8 +194,12 @@ def recursive_execute(server, prompt, outputs, current_item, extra_data, execute return (True, None, None) -def recursive_will_execute(prompt, outputs, current_item): +def recursive_will_execute(prompt, outputs, current_item, memo={}): unique_id = current_item + + if unique_id in memo: + return memo[unique_id] + inputs = prompt[unique_id]['inputs'] will_execute = [] if unique_id in outputs: @@ -207,9 +211,10 @@ def recursive_will_execute(prompt, outputs, current_item): input_unique_id = input_data[0] output_index = input_data[1] if input_unique_id not in outputs: - will_execute += recursive_will_execute(prompt, outputs, input_unique_id) + will_execute += recursive_will_execute(prompt, outputs, input_unique_id, memo) - return will_execute + [unique_id] + memo[unique_id] = will_execute + [unique_id] + return memo[unique_id] def recursive_output_delete_if_changed(prompt, old_prompt, outputs, current_item): unique_id = current_item @@ -377,7 +382,8 @@ class PromptExecutor: while len(to_execute) > 0: #always execute the output that depends on the least amount of unexecuted nodes first - to_execute = sorted(list(map(lambda a: (len(recursive_will_execute(prompt, self.outputs, a[-1])), a[-1]), to_execute))) + memo = {} + to_execute = sorted(list(map(lambda a: (len(recursive_will_execute(prompt, self.outputs, a[-1], memo)), a[-1]), to_execute))) output_node_id = to_execute.pop(0)[-1] # This call shouldn't raise anything if there's an error deep in diff --git a/nodes.py b/nodes.py index d9bc4884..a577c212 100644 --- a/nodes.py +++ b/nodes.py @@ -309,18 +309,7 @@ class VAEEncode: CATEGORY = "latent" - @staticmethod - def vae_encode_crop_pixels(pixels): - x = (pixels.shape[1] // 8) * 8 - y = (pixels.shape[2] // 8) * 8 - if pixels.shape[1] != x or pixels.shape[2] != y: - x_offset = (pixels.shape[1] % 8) // 2 - y_offset = (pixels.shape[2] % 8) // 2 - pixels = pixels[:, x_offset:x + x_offset, y_offset:y + y_offset, :] - return pixels - def encode(self, vae, pixels): - pixels = self.vae_encode_crop_pixels(pixels) t = vae.encode(pixels[:,:,:,:3]) return ({"samples":t}, ) @@ -336,7 +325,6 @@ class VAEEncodeTiled: CATEGORY = "_for_testing" def encode(self, vae, pixels, tile_size): - pixels = VAEEncode.vae_encode_crop_pixels(pixels) t = vae.encode_tiled(pixels[:,:,:,:3], tile_x=tile_size, tile_y=tile_size, ) return ({"samples":t}, ) @@ -350,14 +338,14 @@ class VAEEncodeForInpaint: CATEGORY = "latent/inpaint" def encode(self, vae, pixels, mask, grow_mask_by=6): - x = (pixels.shape[1] // 8) * 8 - y = (pixels.shape[2] // 8) * 8 + x = (pixels.shape[1] // vae.downscale_ratio) * vae.downscale_ratio + y = (pixels.shape[2] // vae.downscale_ratio) * vae.downscale_ratio mask = torch.nn.functional.interpolate(mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1])), size=(pixels.shape[1], pixels.shape[2]), mode="bilinear") pixels = pixels.clone() if pixels.shape[1] != x or pixels.shape[2] != y: - x_offset = (pixels.shape[1] % 8) // 2 - y_offset = (pixels.shape[2] % 8) // 2 + x_offset = (pixels.shape[1] % vae.downscale_ratio) // 2 + y_offset = (pixels.shape[2] % vae.downscale_ratio) // 2 pixels = pixels[:,x_offset:x + x_offset, y_offset:y + y_offset,:] mask = mask[:,:,x_offset:x + x_offset, y_offset:y + y_offset] @@ -854,15 +842,20 @@ class CLIPLoader: @classmethod def INPUT_TYPES(s): return {"required": { "clip_name": (folder_paths.get_filename_list("clip"), ), + "type": (["stable_diffusion", "stable_cascade"], ), }} RETURN_TYPES = ("CLIP",) FUNCTION = "load_clip" CATEGORY = "advanced/loaders" - def load_clip(self, clip_name): + def load_clip(self, clip_name, type="stable_diffusion"): + clip_type = comfy.sd.CLIPType.STABLE_DIFFUSION + if type == "stable_cascade": + clip_type = comfy.sd.CLIPType.STABLE_CASCADE + clip_path = folder_paths.get_full_path("clip", clip_name) - clip = comfy.sd.load_clip(ckpt_paths=[clip_path], embedding_directory=folder_paths.get_folder_paths("embeddings")) + clip = comfy.sd.load_clip(ckpt_paths=[clip_path], embedding_directory=folder_paths.get_folder_paths("embeddings"), clip_type=clip_type) return (clip,) class DualCLIPLoader: @@ -1434,7 +1427,7 @@ class SaveImage: filename_prefix += self.prefix_append full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, self.output_dir, images[0].shape[1], images[0].shape[0]) results = list() - for image in images: + for (batch_number, image) in enumerate(images): i = 255. * image.cpu().numpy() img = Image.fromarray(np.clip(i, 0, 255).astype(np.uint8)) metadata = None @@ -1446,7 +1439,8 @@ class SaveImage: for x in extra_pnginfo: metadata.add_text(x, json.dumps(extra_pnginfo[x])) - file = f"{filename}_{counter:05}_.png" + filename_with_batch_num = filename.replace("%batch_num%", str(batch_number)) + file = f"{filename_with_batch_num}_{counter:05}_.png" img.save(os.path.join(full_output_folder, file), pnginfo=metadata, compress_level=self.compress_level) results.append({ "filename": file, @@ -1966,6 +1960,7 @@ def init_custom_nodes(): "nodes_sdupscale.py", "nodes_photomaker.py", "nodes_cond.py", + "nodes_stable_cascade.py", ] for node_file in extras_files: diff --git a/web/extensions/core/widgetInputs.js b/web/extensions/core/widgetInputs.js index b12ad968..23f51d81 100644 --- a/web/extensions/core/widgetInputs.js +++ b/web/extensions/core/widgetInputs.js @@ -22,6 +22,7 @@ function isConvertableWidget(widget, config) { } function hideWidget(node, widget, suffix = "") { + if (widget.type?.startsWith(CONVERTED_TYPE)) return; widget.origType = widget.type; widget.origComputeSize = widget.computeSize; widget.origSerializeValue = widget.serializeValue; diff --git a/web/scripts/pnginfo.js b/web/scripts/pnginfo.js index 83a4ebc8..16960920 100644 --- a/web/scripts/pnginfo.js +++ b/web/scripts/pnginfo.js @@ -24,7 +24,7 @@ export function getPngMetadata(file) { const length = dataView.getUint32(offset); // Get the chunk type const type = String.fromCharCode(...pngData.slice(offset + 4, offset + 8)); - if (type === "tEXt" || type == "comf") { + if (type === "tEXt" || type == "comf" || type === "iTXt") { // Get the keyword let keyword_end = offset + 8; while (pngData[keyword_end] !== 0) { @@ -33,7 +33,7 @@ export function getPngMetadata(file) { const keyword = String.fromCharCode(...pngData.slice(offset + 8, keyword_end)); // Get the text const contentArraySegment = pngData.slice(keyword_end + 1, offset + 8 + length); - const contentJson = Array.from(contentArraySegment).map(s=>String.fromCharCode(s)).join('') + const contentJson = new TextDecoder("utf-8").decode(contentArraySegment); txt_chunks[keyword] = contentJson; } diff --git a/web/style.css b/web/style.css index 863840b2..cf7a8b9e 100644 --- a/web/style.css +++ b/web/style.css @@ -197,6 +197,7 @@ button.comfy-close-menu-btn { .comfy-modal button:hover, .comfy-menu-actions button:hover { filter: brightness(1.2); + will-change: transform; cursor: pointer; } @@ -462,11 +463,13 @@ dialog::backdrop { z-index: 9999 !important; background-color: var(--comfy-menu-bg) !important; filter: brightness(95%); + will-change: transform; } .litegraph.litecontextmenu .litemenu-entry:hover:not(.disabled):not(.separator) { background-color: var(--comfy-menu-bg) !important; filter: brightness(155%); + will-change: transform; color: var(--input-text); } @@ -527,12 +530,14 @@ dialog::backdrop { color: var(--input-text); background-color: var(--comfy-input-bg); filter: brightness(80%); + will-change: transform; padding-left: 0.2em; } .litegraph.lite-search-item.generic_type { color: var(--input-text); filter: brightness(50%); + will-change: transform; } @media only screen and (max-width: 450px) { @@ -551,4 +556,4 @@ dialog::backdrop { text-align: center; border-top: none; } -} \ No newline at end of file +}