diff --git a/comfy/clip_vision.py b/comfy/clip_vision.py new file mode 100644 index 00000000..cb29df43 --- /dev/null +++ b/comfy/clip_vision.py @@ -0,0 +1,62 @@ +from transformers import CLIPVisionModelWithProjection, CLIPVisionConfig, CLIPImageProcessor +from .utils import load_torch_file, transformers_convert +import os + +class ClipVisionModel(): + def __init__(self, json_config): + config = CLIPVisionConfig.from_json_file(json_config) + self.model = CLIPVisionModelWithProjection(config) + self.processor = CLIPImageProcessor(crop_size=224, + do_center_crop=True, + do_convert_rgb=True, + do_normalize=True, + do_resize=True, + image_mean=[ 0.48145466,0.4578275,0.40821073], + image_std=[0.26862954,0.26130258,0.27577711], + resample=3, #bicubic + size=224) + + def load_sd(self, sd): + self.model.load_state_dict(sd, strict=False) + + def encode_image(self, image): + inputs = self.processor(images=[image[0]], return_tensors="pt") + outputs = self.model(**inputs) + return outputs + +def convert_to_transformers(sd): + sd_k = sd.keys() + if "embedder.model.visual.transformer.resblocks.0.attn.in_proj_weight" in sd_k: + keys_to_replace = { + "embedder.model.visual.class_embedding": "vision_model.embeddings.class_embedding", + "embedder.model.visual.conv1.weight": "vision_model.embeddings.patch_embedding.weight", + "embedder.model.visual.positional_embedding": "vision_model.embeddings.position_embedding.weight", + "embedder.model.visual.ln_post.bias": "vision_model.post_layernorm.bias", + "embedder.model.visual.ln_post.weight": "vision_model.post_layernorm.weight", + "embedder.model.visual.ln_pre.bias": "vision_model.pre_layrnorm.bias", + "embedder.model.visual.ln_pre.weight": "vision_model.pre_layrnorm.weight", + } + + for x in keys_to_replace: + if x in sd_k: + sd[keys_to_replace[x]] = sd.pop(x) + + if "embedder.model.visual.proj" in sd_k: + sd['visual_projection.weight'] = sd.pop("embedder.model.visual.proj").transpose(0, 1) + + sd = transformers_convert(sd, "embedder.model.visual", "vision_model", 32) + return sd + +def load_clipvision_from_sd(sd): + sd = convert_to_transformers(sd) + if "vision_model.encoder.layers.30.layer_norm1.weight" in sd: + json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config_h.json") + else: + json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config_vitl.json") + clip = ClipVisionModel(json_config) + clip.load_sd(sd) + return clip + +def load(ckpt_path): + sd = load_torch_file(ckpt_path) + return load_clipvision_from_sd(sd) diff --git a/comfy/clip_vision_config_h.json b/comfy/clip_vision_config_h.json new file mode 100644 index 00000000..bb71be41 --- /dev/null +++ b/comfy/clip_vision_config_h.json @@ -0,0 +1,18 @@ +{ + "attention_dropout": 0.0, + "dropout": 0.0, + "hidden_act": "gelu", + "hidden_size": 1280, + "image_size": 224, + "initializer_factor": 1.0, + "initializer_range": 0.02, + "intermediate_size": 5120, + "layer_norm_eps": 1e-05, + "model_type": "clip_vision_model", + "num_attention_heads": 16, + "num_channels": 3, + "num_hidden_layers": 32, + "patch_size": 14, + "projection_dim": 1024, + "torch_dtype": "float32" +} diff --git a/comfy_extras/clip_vision_config.json b/comfy/clip_vision_config_vitl.json similarity index 70% rename from comfy_extras/clip_vision_config.json rename to comfy/clip_vision_config_vitl.json index 0e4db13d..c59b8ed5 100644 --- a/comfy_extras/clip_vision_config.json +++ b/comfy/clip_vision_config_vitl.json @@ -1,8 +1,4 @@ { - "_name_or_path": "openai/clip-vit-large-patch14", - "architectures": [ - "CLIPVisionModel" - ], "attention_dropout": 0.0, "dropout": 0.0, "hidden_act": "quick_gelu", @@ -18,6 +14,5 @@ "num_hidden_layers": 24, "patch_size": 14, "projection_dim": 768, - "torch_dtype": "float32", - "transformers_version": "4.24.0" + "torch_dtype": "float32" } diff --git a/comfy/ldm/models/diffusion/ddpm.py b/comfy/ldm/models/diffusion/ddpm.py index 6af96124..d3f0eb2b 100644 --- a/comfy/ldm/models/diffusion/ddpm.py +++ b/comfy/ldm/models/diffusion/ddpm.py @@ -1801,3 +1801,75 @@ class LatentUpscaleFinetuneDiffusion(LatentFinetuneDiffusion): log = super().log_images(*args, **kwargs) log["lr"] = rearrange(args[0]["lr"], 'b h w c -> b c h w') return log + + +class ImageEmbeddingConditionedLatentDiffusion(LatentDiffusion): + def __init__(self, embedder_config=None, embedding_key="jpg", embedding_dropout=0.5, + freeze_embedder=True, noise_aug_config=None, *args, **kwargs): + super().__init__(*args, **kwargs) + self.embed_key = embedding_key + self.embedding_dropout = embedding_dropout + # self._init_embedder(embedder_config, freeze_embedder) + self._init_noise_aug(noise_aug_config) + + def _init_embedder(self, config, freeze=True): + embedder = instantiate_from_config(config) + if freeze: + self.embedder = embedder.eval() + self.embedder.train = disabled_train + for param in self.embedder.parameters(): + param.requires_grad = False + + def _init_noise_aug(self, config): + if config is not None: + # use the KARLO schedule for noise augmentation on CLIP image embeddings + noise_augmentor = instantiate_from_config(config) + assert isinstance(noise_augmentor, nn.Module) + noise_augmentor = noise_augmentor.eval() + noise_augmentor.train = disabled_train + self.noise_augmentor = noise_augmentor + else: + self.noise_augmentor = None + + def get_input(self, batch, k, cond_key=None, bs=None, **kwargs): + outputs = LatentDiffusion.get_input(self, batch, k, bs=bs, **kwargs) + z, c = outputs[0], outputs[1] + img = batch[self.embed_key][:bs] + img = rearrange(img, 'b h w c -> b c h w') + c_adm = self.embedder(img) + if self.noise_augmentor is not None: + c_adm, noise_level_emb = self.noise_augmentor(c_adm) + # assume this gives embeddings of noise levels + c_adm = torch.cat((c_adm, noise_level_emb), 1) + if self.training: + c_adm = torch.bernoulli((1. - self.embedding_dropout) * torch.ones(c_adm.shape[0], + device=c_adm.device)[:, None]) * c_adm + all_conds = {"c_crossattn": [c], "c_adm": c_adm} + noutputs = [z, all_conds] + noutputs.extend(outputs[2:]) + return noutputs + + @torch.no_grad() + def log_images(self, batch, N=8, n_row=4, **kwargs): + log = dict() + z, c, x, xrec, xc = self.get_input(batch, self.first_stage_key, bs=N, return_first_stage_outputs=True, + return_original_cond=True) + log["inputs"] = x + log["reconstruction"] = xrec + assert self.model.conditioning_key is not None + assert self.cond_stage_key in ["caption", "txt"] + xc = log_txt_as_img((x.shape[2], x.shape[3]), batch[self.cond_stage_key], size=x.shape[2] // 25) + log["conditioning"] = xc + uc = self.get_unconditional_conditioning(N, kwargs.get('unconditional_guidance_label', '')) + unconditional_guidance_scale = kwargs.get('unconditional_guidance_scale', 5.) + + uc_ = {"c_crossattn": [uc], "c_adm": c["c_adm"]} + ema_scope = self.ema_scope if kwargs.get('use_ema_scope', True) else nullcontext + with ema_scope(f"Sampling"): + samples_cfg, _ = self.sample_log(cond=c, batch_size=N, ddim=True, + ddim_steps=kwargs.get('ddim_steps', 50), eta=kwargs.get('ddim_eta', 0.), + unconditional_guidance_scale=unconditional_guidance_scale, + unconditional_conditioning=uc_, ) + x_samples_cfg = self.decode_first_stage(samples_cfg) + log[f"samplescfg_scale_{unconditional_guidance_scale:.2f}"] = x_samples_cfg + return log diff --git a/comfy/ldm/models/diffusion/dpm_solver/dpm_solver.py b/comfy/ldm/models/diffusion/dpm_solver/dpm_solver.py index 095e5ba3..da8d41f9 100644 --- a/comfy/ldm/models/diffusion/dpm_solver/dpm_solver.py +++ b/comfy/ldm/models/diffusion/dpm_solver/dpm_solver.py @@ -307,7 +307,16 @@ def model_wrapper( else: x_in = torch.cat([x] * 2) t_in = torch.cat([t_continuous] * 2) - c_in = torch.cat([unconditional_condition, condition]) + if isinstance(condition, dict): + assert isinstance(unconditional_condition, dict) + c_in = dict() + for k in condition: + if isinstance(condition[k], list): + c_in[k] = [torch.cat([unconditional_condition[k][i], condition[k][i]]) for i in range(len(condition[k]))] + else: + c_in[k] = torch.cat([unconditional_condition[k], condition[k]]) + else: + c_in = torch.cat([unconditional_condition, condition]) noise_uncond, noise = noise_pred_fn(x_in, t_in, cond=c_in).chunk(2) return noise_uncond + guidance_scale * (noise - noise_uncond) diff --git a/comfy/ldm/models/diffusion/dpm_solver/sampler.py b/comfy/ldm/models/diffusion/dpm_solver/sampler.py index 4270c618..e4d0d0a3 100644 --- a/comfy/ldm/models/diffusion/dpm_solver/sampler.py +++ b/comfy/ldm/models/diffusion/dpm_solver/sampler.py @@ -3,7 +3,6 @@ import torch from .dpm_solver import NoiseScheduleVP, model_wrapper, DPM_Solver - MODEL_TYPES = { "eps": "noise", "v": "v" @@ -51,12 +50,20 @@ class DPMSolverSampler(object): ): if conditioning is not None: if isinstance(conditioning, dict): - cbs = conditioning[list(conditioning.keys())[0]].shape[0] - if cbs != batch_size: - print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}") + ctmp = conditioning[list(conditioning.keys())[0]] + while isinstance(ctmp, list): ctmp = ctmp[0] + if isinstance(ctmp, torch.Tensor): + cbs = ctmp.shape[0] + if cbs != batch_size: + print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}") + elif isinstance(conditioning, list): + for ctmp in conditioning: + if ctmp.shape[0] != batch_size: + print(f"Warning: Got {ctmp.shape[0]} conditionings but batch-size is {batch_size}") else: - if conditioning.shape[0] != batch_size: - print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}") + if isinstance(conditioning, torch.Tensor): + if conditioning.shape[0] != batch_size: + print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}") # sampling C, H, W = shape @@ -83,6 +90,7 @@ class DPMSolverSampler(object): ) dpm_solver = DPM_Solver(model_fn, ns, predict_x0=True, thresholding=False) - x = dpm_solver.sample(img, steps=S, skip_type="time_uniform", method="multistep", order=2, lower_order_final=True) + x = dpm_solver.sample(img, steps=S, skip_type="time_uniform", method="multistep", order=2, + lower_order_final=True) - return x.to(device), None \ No newline at end of file + return x.to(device), None diff --git a/comfy/ldm/modules/diffusionmodules/openaimodel.py b/comfy/ldm/modules/diffusionmodules/openaimodel.py index 7b2f5b53..8a4e8b3e 100644 --- a/comfy/ldm/modules/diffusionmodules/openaimodel.py +++ b/comfy/ldm/modules/diffusionmodules/openaimodel.py @@ -409,6 +409,15 @@ class QKVAttention(nn.Module): return count_flops_attn(model, _x, y) +class Timestep(nn.Module): + def __init__(self, dim): + super().__init__() + self.dim = dim + + def forward(self, t): + return timestep_embedding(t, self.dim) + + class UNetModel(nn.Module): """ The full UNet model with attention and timestep embedding. @@ -470,6 +479,7 @@ class UNetModel(nn.Module): num_attention_blocks=None, disable_middle_self_attn=False, use_linear_in_transformer=False, + adm_in_channels=None, ): super().__init__() if use_spatial_transformer: @@ -538,6 +548,15 @@ class UNetModel(nn.Module): elif self.num_classes == "continuous": print("setting up linear c_adm embedding layer") self.label_emb = nn.Linear(1, time_embed_dim) + elif self.num_classes == "sequential": + assert adm_in_channels is not None + self.label_emb = nn.Sequential( + nn.Sequential( + linear(adm_in_channels, time_embed_dim), + nn.SiLU(), + linear(time_embed_dim, time_embed_dim), + ) + ) else: raise ValueError() diff --git a/comfy/ldm/modules/diffusionmodules/util.py b/comfy/ldm/modules/diffusionmodules/util.py index 637363df..daf35da7 100644 --- a/comfy/ldm/modules/diffusionmodules/util.py +++ b/comfy/ldm/modules/diffusionmodules/util.py @@ -34,6 +34,13 @@ def make_beta_schedule(schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, betas = 1 - alphas[1:] / alphas[:-1] betas = np.clip(betas, a_min=0, a_max=0.999) + elif schedule == "squaredcos_cap_v2": # used for karlo prior + # return early + return betas_for_alpha_bar( + n_timestep, + lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2, + ) + elif schedule == "sqrt_linear": betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) elif schedule == "sqrt": @@ -218,6 +225,7 @@ class GroupNorm32(nn.GroupNorm): def forward(self, x): return super().forward(x.float()).type(x.dtype) + def conv_nd(dims, *args, **kwargs): """ Create a 1D, 2D, or 3D convolution module. @@ -267,4 +275,4 @@ class HybridConditioner(nn.Module): def noise_like(shape, device, repeat=False): repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat(shape[0], *((1,) * (len(shape) - 1))) noise = lambda: torch.randn(shape, device=device) - return repeat_noise() if repeat else noise() \ No newline at end of file + return repeat_noise() if repeat else noise() diff --git a/comfy/ldm/modules/encoders/kornia_functions.py b/comfy/ldm/modules/encoders/kornia_functions.py new file mode 100644 index 00000000..912314cd --- /dev/null +++ b/comfy/ldm/modules/encoders/kornia_functions.py @@ -0,0 +1,59 @@ + + +from typing import List, Tuple, Union + +import torch +import torch.nn as nn + +#from: https://github.com/kornia/kornia/blob/master/kornia/enhance/normalize.py + +def enhance_normalize(data: torch.Tensor, mean: torch.Tensor, std: torch.Tensor) -> torch.Tensor: + r"""Normalize an image/video tensor with mean and standard deviation. + .. math:: + \text{input[channel] = (input[channel] - mean[channel]) / std[channel]} + Where `mean` is :math:`(M_1, ..., M_n)` and `std` :math:`(S_1, ..., S_n)` for `n` channels, + Args: + data: Image tensor of size :math:`(B, C, *)`. + mean: Mean for each channel. + std: Standard deviations for each channel. + Return: + Normalised tensor with same size as input :math:`(B, C, *)`. + Examples: + >>> x = torch.rand(1, 4, 3, 3) + >>> out = normalize(x, torch.tensor([0.0]), torch.tensor([255.])) + >>> out.shape + torch.Size([1, 4, 3, 3]) + >>> x = torch.rand(1, 4, 3, 3) + >>> mean = torch.zeros(4) + >>> std = 255. * torch.ones(4) + >>> out = normalize(x, mean, std) + >>> out.shape + torch.Size([1, 4, 3, 3]) + """ + shape = data.shape + if len(mean.shape) == 0 or mean.shape[0] == 1: + mean = mean.expand(shape[1]) + if len(std.shape) == 0 or std.shape[0] == 1: + std = std.expand(shape[1]) + + # Allow broadcast on channel dimension + if mean.shape and mean.shape[0] != 1: + if mean.shape[0] != data.shape[1] and mean.shape[:2] != data.shape[:2]: + raise ValueError(f"mean length and number of channels do not match. Got {mean.shape} and {data.shape}.") + + # Allow broadcast on channel dimension + if std.shape and std.shape[0] != 1: + if std.shape[0] != data.shape[1] and std.shape[:2] != data.shape[:2]: + raise ValueError(f"std length and number of channels do not match. Got {std.shape} and {data.shape}.") + + mean = torch.as_tensor(mean, device=data.device, dtype=data.dtype) + std = torch.as_tensor(std, device=data.device, dtype=data.dtype) + + if mean.shape: + mean = mean[..., :, None] + if std.shape: + std = std[..., :, None] + + out: torch.Tensor = (data.view(shape[0], shape[1], -1) - mean) / std + + return out.view(shape) diff --git a/comfy/ldm/modules/encoders/modules.py b/comfy/ldm/modules/encoders/modules.py index 4edd5496..bc9fde63 100644 --- a/comfy/ldm/modules/encoders/modules.py +++ b/comfy/ldm/modules/encoders/modules.py @@ -1,5 +1,6 @@ import torch import torch.nn as nn +from . import kornia_functions from torch.utils.checkpoint import checkpoint from transformers import T5Tokenizer, T5EncoderModel, CLIPTokenizer, CLIPTextModel @@ -37,7 +38,7 @@ class ClassEmbedder(nn.Module): c = batch[key][:, None] if self.ucg_rate > 0. and not disable_dropout: mask = 1. - torch.bernoulli(torch.ones_like(c) * self.ucg_rate) - c = mask * c + (1-mask) * torch.ones_like(c)*(self.n_classes-1) + c = mask * c + (1 - mask) * torch.ones_like(c) * (self.n_classes - 1) c = c.long() c = self.embedding(c) return c @@ -57,18 +58,20 @@ def disabled_train(self, mode=True): class FrozenT5Embedder(AbstractEncoder): """Uses the T5 transformer encoder for text""" - def __init__(self, version="google/t5-v1_1-large", device="cuda", max_length=77, freeze=True): # others are google/t5-v1_1-xl and google/t5-v1_1-xxl + + def __init__(self, version="google/t5-v1_1-large", device="cuda", max_length=77, + freeze=True): # others are google/t5-v1_1-xl and google/t5-v1_1-xxl super().__init__() self.tokenizer = T5Tokenizer.from_pretrained(version) self.transformer = T5EncoderModel.from_pretrained(version) self.device = device - self.max_length = max_length # TODO: typical value? + self.max_length = max_length # TODO: typical value? if freeze: self.freeze() def freeze(self): self.transformer = self.transformer.eval() - #self.train = disabled_train + # self.train = disabled_train for param in self.parameters(): param.requires_grad = False @@ -92,6 +95,7 @@ class FrozenCLIPEmbedder(AbstractEncoder): "pooled", "hidden" ] + def __init__(self, version="openai/clip-vit-large-patch14", device="cuda", max_length=77, freeze=True, layer="last", layer_idx=None): # clip-vit-base-patch32 super().__init__() @@ -110,7 +114,7 @@ class FrozenCLIPEmbedder(AbstractEncoder): def freeze(self): self.transformer = self.transformer.eval() - #self.train = disabled_train + # self.train = disabled_train for param in self.parameters(): param.requires_grad = False @@ -118,7 +122,7 @@ class FrozenCLIPEmbedder(AbstractEncoder): batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True, return_overflowing_tokens=False, padding="max_length", return_tensors="pt") tokens = batch_encoding["input_ids"].to(self.device) - outputs = self.transformer(input_ids=tokens, output_hidden_states=self.layer=="hidden") + outputs = self.transformer(input_ids=tokens, output_hidden_states=self.layer == "hidden") if self.layer == "last": z = outputs.last_hidden_state elif self.layer == "pooled": @@ -131,15 +135,55 @@ class FrozenCLIPEmbedder(AbstractEncoder): return self(text) +class ClipImageEmbedder(nn.Module): + def __init__( + self, + model, + jit=False, + device='cuda' if torch.cuda.is_available() else 'cpu', + antialias=True, + ucg_rate=0. + ): + super().__init__() + from clip import load as load_clip + self.model, _ = load_clip(name=model, device=device, jit=jit) + + self.antialias = antialias + + self.register_buffer('mean', torch.Tensor([0.48145466, 0.4578275, 0.40821073]), persistent=False) + self.register_buffer('std', torch.Tensor([0.26862954, 0.26130258, 0.27577711]), persistent=False) + self.ucg_rate = ucg_rate + + def preprocess(self, x): + # normalize to [0,1] + # x = kornia_functions.geometry_resize(x, (224, 224), + # interpolation='bicubic', align_corners=True, + # antialias=self.antialias) + x = torch.nn.functional.interpolate(x, size=(224, 224), mode='bicubic', align_corners=True, antialias=True) + x = (x + 1.) / 2. + # re-normalize according to clip + x = kornia_functions.enhance_normalize(x, self.mean, self.std) + return x + + def forward(self, x, no_dropout=False): + # x is assumed to be in range [-1,1] + out = self.model.encode_image(self.preprocess(x)) + out = out.to(x.dtype) + if self.ucg_rate > 0. and not no_dropout: + out = torch.bernoulli((1. - self.ucg_rate) * torch.ones(out.shape[0], device=out.device))[:, None] * out + return out + + class FrozenOpenCLIPEmbedder(AbstractEncoder): """ Uses the OpenCLIP transformer encoder for text """ LAYERS = [ - #"pooled", + # "pooled", "last", "penultimate" ] + def __init__(self, arch="ViT-H-14", version="laion2b_s32b_b79k", device="cuda", max_length=77, freeze=True, layer="last"): super().__init__() @@ -179,7 +223,7 @@ class FrozenOpenCLIPEmbedder(AbstractEncoder): x = self.model.ln_final(x) return x - def text_transformer_forward(self, x: torch.Tensor, attn_mask = None): + def text_transformer_forward(self, x: torch.Tensor, attn_mask=None): for i, r in enumerate(self.model.transformer.resblocks): if i == len(self.model.transformer.resblocks) - self.layer_idx: break @@ -193,14 +237,73 @@ class FrozenOpenCLIPEmbedder(AbstractEncoder): return self(text) +class FrozenOpenCLIPImageEmbedder(AbstractEncoder): + """ + Uses the OpenCLIP vision transformer encoder for images + """ + + def __init__(self, arch="ViT-H-14", version="laion2b_s32b_b79k", device="cuda", max_length=77, + freeze=True, layer="pooled", antialias=True, ucg_rate=0.): + super().__init__() + model, _, _ = open_clip.create_model_and_transforms(arch, device=torch.device('cpu'), + pretrained=version, ) + del model.transformer + self.model = model + + self.device = device + self.max_length = max_length + if freeze: + self.freeze() + self.layer = layer + if self.layer == "penultimate": + raise NotImplementedError() + self.layer_idx = 1 + + self.antialias = antialias + + self.register_buffer('mean', torch.Tensor([0.48145466, 0.4578275, 0.40821073]), persistent=False) + self.register_buffer('std', torch.Tensor([0.26862954, 0.26130258, 0.27577711]), persistent=False) + self.ucg_rate = ucg_rate + + def preprocess(self, x): + # normalize to [0,1] + # x = kornia.geometry.resize(x, (224, 224), + # interpolation='bicubic', align_corners=True, + # antialias=self.antialias) + x = torch.nn.functional.interpolate(x, size=(224, 224), mode='bicubic', align_corners=True, antialias=True) + x = (x + 1.) / 2. + # renormalize according to clip + x = kornia_functions.enhance_normalize(x, self.mean, self.std) + return x + + def freeze(self): + self.model = self.model.eval() + for param in self.parameters(): + param.requires_grad = False + + def forward(self, image, no_dropout=False): + z = self.encode_with_vision_transformer(image) + if self.ucg_rate > 0. and not no_dropout: + z = torch.bernoulli((1. - self.ucg_rate) * torch.ones(z.shape[0], device=z.device))[:, None] * z + return z + + def encode_with_vision_transformer(self, img): + img = self.preprocess(img) + x = self.model.visual(img) + return x + + def encode(self, text): + return self(text) + + class FrozenCLIPT5Encoder(AbstractEncoder): def __init__(self, clip_version="openai/clip-vit-large-patch14", t5_version="google/t5-v1_1-xl", device="cuda", clip_max_length=77, t5_max_length=77): super().__init__() self.clip_encoder = FrozenCLIPEmbedder(clip_version, device, max_length=clip_max_length) self.t5_encoder = FrozenT5Embedder(t5_version, device, max_length=t5_max_length) - print(f"{self.clip_encoder.__class__.__name__} has {count_params(self.clip_encoder)*1.e-6:.2f} M parameters, " - f"{self.t5_encoder.__class__.__name__} comes with {count_params(self.t5_encoder)*1.e-6:.2f} M params.") + print(f"{self.clip_encoder.__class__.__name__} has {count_params(self.clip_encoder) * 1.e-6:.2f} M parameters, " + f"{self.t5_encoder.__class__.__name__} comes with {count_params(self.t5_encoder) * 1.e-6:.2f} M params.") def encode(self, text): return self(text) @@ -209,5 +312,3 @@ class FrozenCLIPT5Encoder(AbstractEncoder): clip_z = self.clip_encoder.encode(text) t5_z = self.t5_encoder.encode(text) return [clip_z, t5_z] - - diff --git a/comfy/ldm/modules/encoders/noise_aug_modules.py b/comfy/ldm/modules/encoders/noise_aug_modules.py new file mode 100644 index 00000000..f99e7920 --- /dev/null +++ b/comfy/ldm/modules/encoders/noise_aug_modules.py @@ -0,0 +1,35 @@ +from ldm.modules.diffusionmodules.upscaling import ImageConcatWithNoiseAugmentation +from ldm.modules.diffusionmodules.openaimodel import Timestep +import torch + +class CLIPEmbeddingNoiseAugmentation(ImageConcatWithNoiseAugmentation): + def __init__(self, *args, clip_stats_path=None, timestep_dim=256, **kwargs): + super().__init__(*args, **kwargs) + if clip_stats_path is None: + clip_mean, clip_std = torch.zeros(timestep_dim), torch.ones(timestep_dim) + else: + clip_mean, clip_std = torch.load(clip_stats_path, map_location="cpu") + self.register_buffer("data_mean", clip_mean[None, :], persistent=False) + self.register_buffer("data_std", clip_std[None, :], persistent=False) + self.time_embed = Timestep(timestep_dim) + + def scale(self, x): + # re-normalize to centered mean and unit variance + x = (x - self.data_mean) * 1. / self.data_std + return x + + def unscale(self, x): + # back to original data stats + x = (x * self.data_std) + self.data_mean + return x + + def forward(self, x, noise_level=None): + if noise_level is None: + noise_level = torch.randint(0, self.max_noise_level, (x.shape[0],), device=x.device).long() + else: + assert isinstance(noise_level, torch.Tensor) + x = self.scale(x) + z = self.q_sample(x, noise_level) + z = self.unscale(z) + noise_level = self.time_embed(noise_level) + return z, noise_level diff --git a/comfy/samplers.py b/comfy/samplers.py index 15e78bbd..ddec9900 100644 --- a/comfy/samplers.py +++ b/comfy/samplers.py @@ -35,6 +35,10 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, con if 'strength' in cond[1]: strength = cond[1]['strength'] + adm_cond = None + if 'adm' in cond[1]: + adm_cond = cond[1]['adm'] + input_x = x_in[:,:,area[2]:area[0] + area[2],area[3]:area[1] + area[3]] mult = torch.ones_like(input_x) * strength @@ -60,6 +64,9 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, con cropped.append(cr) conditionning['c_concat'] = torch.cat(cropped, dim=1) + if adm_cond is not None: + conditionning['c_adm'] = adm_cond + control = None if 'control' in cond[1]: control = cond[1]['control'] @@ -76,6 +83,9 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, con if 'c_concat' in c1: if c1['c_concat'].shape != c2['c_concat'].shape: return False + if 'c_adm' in c1: + if c1['c_adm'].shape != c2['c_adm'].shape: + return False return True def can_concat_cond(c1, c2): @@ -92,16 +102,21 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, con def cond_cat(c_list): c_crossattn = [] c_concat = [] + c_adm = [] for x in c_list: if 'c_crossattn' in x: c_crossattn.append(x['c_crossattn']) if 'c_concat' in x: c_concat.append(x['c_concat']) + if 'c_adm' in x: + c_adm.append(x['c_adm']) out = {} if len(c_crossattn) > 0: out['c_crossattn'] = [torch.cat(c_crossattn)] if len(c_concat) > 0: out['c_concat'] = [torch.cat(c_concat)] + if len(c_adm) > 0: + out['c_adm'] = torch.cat(c_adm) return out def calc_cond_uncond_batch(model_function, cond, uncond, x_in, timestep, max_total_area, cond_concat_in, model_options): @@ -327,6 +342,30 @@ def apply_control_net_to_equal_area(conds, uncond): n['control'] = cond_cnets[x] uncond[temp[1]] = [o[0], n] +def encode_adm(noise_augmentor, conds, batch_size, device): + for t in range(len(conds)): + x = conds[t] + if 'adm' in x[1]: + adm_inputs = [] + weights = [] + adm_in = x[1]["adm"] + for adm_c in adm_in: + adm_cond = adm_c[0].image_embeds + weight = adm_c[1] + c_adm, noise_level_emb = noise_augmentor(adm_cond.to(device), noise_level=torch.tensor([0], device=device)) + adm_out = torch.cat((c_adm, noise_level_emb), 1) * weight + weights.append(weight) + adm_inputs.append(adm_out) + + adm_out = torch.stack(adm_inputs).sum(0) + #TODO: Apply Noise to Embedding Mix + else: + adm_out = torch.zeros((1, noise_augmentor.time_embed.dim * 2), device=device) + x[1] = x[1].copy() + x[1]["adm"] = torch.cat([adm_out] * batch_size) + + return conds + class KSampler: SCHEDULERS = ["karras", "normal", "simple", "ddim_uniform"] SAMPLERS = ["euler", "euler_ancestral", "heun", "dpm_2", "dpm_2_ancestral", @@ -422,10 +461,14 @@ class KSampler: else: precision_scope = contextlib.nullcontext + if hasattr(self.model, 'noise_augmentor'): #unclip + positive = encode_adm(self.model.noise_augmentor, positive, noise.shape[0], self.device) + negative = encode_adm(self.model.noise_augmentor, negative, noise.shape[0], self.device) + extra_args = {"cond":positive, "uncond":negative, "cond_scale": cfg, "model_options": self.model_options} cond_concat = None - if hasattr(self.model, 'concat_keys'): + if hasattr(self.model, 'concat_keys'): #inpaint cond_concat = [] for ck in self.model.concat_keys: if denoise_mask is not None: diff --git a/comfy/sd.py b/comfy/sd.py index 2a38ceb1..2d7ff5ab 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -12,20 +12,7 @@ from .cldm import cldm from .t2i_adapter import adapter from . import utils - -def load_torch_file(ckpt): - if ckpt.lower().endswith(".safetensors"): - import safetensors.torch - sd = safetensors.torch.load_file(ckpt, device="cpu") - else: - pl_sd = torch.load(ckpt, map_location="cpu") - if "global_step" in pl_sd: - print(f"Global Step: {pl_sd['global_step']}") - if "state_dict" in pl_sd: - sd = pl_sd["state_dict"] - else: - sd = pl_sd - return sd +from . import clip_vision def load_model_weights(model, sd, verbose=False, load_state_dict_to=[]): m, u = model.load_state_dict(sd, strict=False) @@ -53,30 +40,7 @@ def load_model_weights(model, sd, verbose=False, load_state_dict_to=[]): if x in sd: sd[keys_to_replace[x]] = sd.pop(x) - resblock_to_replace = { - "ln_1": "layer_norm1", - "ln_2": "layer_norm2", - "mlp.c_fc": "mlp.fc1", - "mlp.c_proj": "mlp.fc2", - "attn.out_proj": "self_attn.out_proj", - } - - for resblock in range(24): - for x in resblock_to_replace: - for y in ["weight", "bias"]: - k = "cond_stage_model.model.transformer.resblocks.{}.{}.{}".format(resblock, x, y) - k_to = "cond_stage_model.transformer.text_model.encoder.layers.{}.{}.{}".format(resblock, resblock_to_replace[x], y) - if k in sd: - sd[k_to] = sd.pop(k) - - for y in ["weight", "bias"]: - k_from = "cond_stage_model.model.transformer.resblocks.{}.attn.in_proj_{}".format(resblock, y) - if k_from in sd: - weights = sd.pop(k_from) - for x in range(3): - p = ["self_attn.q_proj", "self_attn.k_proj", "self_attn.v_proj"] - k_to = "cond_stage_model.transformer.text_model.encoder.layers.{}.{}.{}".format(resblock, p[x], y) - sd[k_to] = weights[1024*x:1024*(x + 1)] + sd = utils.transformers_convert(sd, "cond_stage_model.model", "cond_stage_model.transformer.text_model", 24) for x in load_state_dict_to: x.load_state_dict(sd, strict=False) @@ -123,7 +87,7 @@ LORA_UNET_MAP_RESNET = { } def load_lora(path, to_load): - lora = load_torch_file(path) + lora = utils.load_torch_file(path) patch_dict = {} loaded_keys = set() for x in to_load: @@ -599,7 +563,7 @@ class ControlNet: return out def load_controlnet(ckpt_path, model=None): - controlnet_data = load_torch_file(ckpt_path) + controlnet_data = utils.load_torch_file(ckpt_path) pth_key = 'control_model.input_blocks.1.1.transformer_blocks.0.attn2.to_k.weight' pth = False sd2 = False @@ -793,7 +757,7 @@ class StyleModel: def load_style_model(ckpt_path): - model_data = load_torch_file(ckpt_path) + model_data = utils.load_torch_file(ckpt_path) keys = model_data.keys() if "style_embedding" in keys: model = adapter.StyleAdapter(width=1024, context_dim=768, num_head=8, n_layes=3, num_token=8) @@ -804,7 +768,7 @@ def load_style_model(ckpt_path): def load_clip(ckpt_path, embedding_directory=None): - clip_data = load_torch_file(ckpt_path) + clip_data = utils.load_torch_file(ckpt_path) config = {} if "text_model.encoder.layers.22.mlp.fc1.weight" in clip_data: config['target'] = 'ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder' @@ -847,7 +811,7 @@ def load_checkpoint(config_path, ckpt_path, output_vae=True, output_clip=True, e load_state_dict_to = [w] model = instantiate_from_config(config["model"]) - sd = load_torch_file(ckpt_path) + sd = utils.load_torch_file(ckpt_path) model = load_model_weights(model, sd, verbose=False, load_state_dict_to=load_state_dict_to) if fp16: @@ -856,10 +820,11 @@ def load_checkpoint(config_path, ckpt_path, output_vae=True, output_clip=True, e return (ModelPatcher(model), clip, vae) -def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, embedding_directory=None): - sd = load_torch_file(ckpt_path) +def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, output_clipvision=False, embedding_directory=None): + sd = utils.load_torch_file(ckpt_path) sd_keys = sd.keys() clip = None + clipvision = None vae = None fp16 = model_management.should_use_fp16() @@ -884,6 +849,29 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, e w.cond_stage_model = clip.cond_stage_model load_state_dict_to = [w] + clipvision_key = "embedder.model.visual.transformer.resblocks.0.attn.in_proj_weight" + noise_aug_config = None + if clipvision_key in sd_keys: + size = sd[clipvision_key].shape[1] + + if output_clipvision: + clipvision = clip_vision.load_clipvision_from_sd(sd) + + noise_aug_key = "noise_augmentor.betas" + if noise_aug_key in sd_keys: + noise_aug_config = {} + params = {} + noise_schedule_config = {} + noise_schedule_config["timesteps"] = sd[noise_aug_key].shape[0] + noise_schedule_config["beta_schedule"] = "squaredcos_cap_v2" + params["noise_schedule_config"] = noise_schedule_config + noise_aug_config['target'] = "ldm.modules.encoders.noise_aug_modules.CLIPEmbeddingNoiseAugmentation" + if size == 1280: #h + params["timestep_dim"] = 1024 + elif size == 1024: #l + params["timestep_dim"] = 768 + noise_aug_config['params'] = params + sd_config = { "linear_start": 0.00085, "linear_end": 0.012, @@ -932,7 +920,13 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, e sd_config["unet_config"] = {"target": "ldm.modules.diffusionmodules.openaimodel.UNetModel", "params": unet_config} model_config = {"target": "ldm.models.diffusion.ddpm.LatentDiffusion", "params": sd_config} - if unet_config["in_channels"] > 4: #inpainting model + if noise_aug_config is not None: #SD2.x unclip model + sd_config["noise_aug_config"] = noise_aug_config + sd_config["image_size"] = 96 + sd_config["embedding_dropout"] = 0.25 + sd_config["conditioning_key"] = 'crossattn-adm' + model_config["target"] = "ldm.models.diffusion.ddpm.ImageEmbeddingConditionedLatentDiffusion" + elif unet_config["in_channels"] > 4: #inpainting model sd_config["conditioning_key"] = "hybrid" sd_config["finetune_keys"] = None model_config["target"] = "ldm.models.diffusion.ddpm.LatentInpaintDiffusion" @@ -944,6 +938,11 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, e else: unet_config["num_heads"] = 8 #SD1.x + unclip = 'model.diffusion_model.label_emb.0.0.weight' + if unclip in sd_keys: + unet_config["num_classes"] = "sequential" + unet_config["adm_in_channels"] = sd[unclip].shape[1] + if unet_config["context_dim"] == 1024 and unet_config["in_channels"] == 4: #only SD2.x non inpainting models are v prediction k = "model.diffusion_model.output_blocks.11.1.transformer_blocks.0.norm1.bias" out = sd[k] @@ -956,4 +955,4 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, e if fp16: model = model.half() - return (ModelPatcher(model), clip, vae) + return (ModelPatcher(model), clip, vae, clipvision) diff --git a/comfy/utils.py b/comfy/utils.py index 798ac1c4..0380b91d 100644 --- a/comfy/utils.py +++ b/comfy/utils.py @@ -1,5 +1,47 @@ import torch +def load_torch_file(ckpt): + if ckpt.lower().endswith(".safetensors"): + import safetensors.torch + sd = safetensors.torch.load_file(ckpt, device="cpu") + else: + pl_sd = torch.load(ckpt, map_location="cpu") + if "global_step" in pl_sd: + print(f"Global Step: {pl_sd['global_step']}") + if "state_dict" in pl_sd: + sd = pl_sd["state_dict"] + else: + sd = pl_sd + return sd + +def transformers_convert(sd, prefix_from, prefix_to, number): + resblock_to_replace = { + "ln_1": "layer_norm1", + "ln_2": "layer_norm2", + "mlp.c_fc": "mlp.fc1", + "mlp.c_proj": "mlp.fc2", + "attn.out_proj": "self_attn.out_proj", + } + + for resblock in range(number): + for x in resblock_to_replace: + for y in ["weight", "bias"]: + k = "{}.transformer.resblocks.{}.{}.{}".format(prefix_from, resblock, x, y) + k_to = "{}.encoder.layers.{}.{}.{}".format(prefix_to, resblock, resblock_to_replace[x], y) + if k in sd: + sd[k_to] = sd.pop(k) + + for y in ["weight", "bias"]: + k_from = "{}.transformer.resblocks.{}.attn.in_proj_{}".format(prefix_from, resblock, y) + if k_from in sd: + weights = sd.pop(k_from) + shape_from = weights.shape[0] // 3 + for x in range(3): + p = ["self_attn.q_proj", "self_attn.k_proj", "self_attn.v_proj"] + k_to = "{}.encoder.layers.{}.{}.{}".format(prefix_to, resblock, p[x], y) + sd[k_to] = weights[shape_from*x:shape_from*(x + 1)] + return sd + def common_upscale(samples, width, height, upscale_method, crop): if crop == "center": old_width = samples.shape[3] diff --git a/comfy_extras/clip_vision.py b/comfy_extras/clip_vision.py deleted file mode 100644 index 58d79a83..00000000 --- a/comfy_extras/clip_vision.py +++ /dev/null @@ -1,32 +0,0 @@ -from transformers import CLIPVisionModel, CLIPVisionConfig, CLIPImageProcessor -from comfy.sd import load_torch_file -import os - -class ClipVisionModel(): - def __init__(self): - json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config.json") - config = CLIPVisionConfig.from_json_file(json_config) - self.model = CLIPVisionModel(config) - self.processor = CLIPImageProcessor(crop_size=224, - do_center_crop=True, - do_convert_rgb=True, - do_normalize=True, - do_resize=True, - image_mean=[ 0.48145466,0.4578275,0.40821073], - image_std=[0.26862954,0.26130258,0.27577711], - resample=3, #bicubic - size=224) - - def load_sd(self, sd): - self.model.load_state_dict(sd, strict=False) - - def encode_image(self, image): - inputs = self.processor(images=[image[0]], return_tensors="pt") - outputs = self.model(**inputs) - return outputs - -def load(ckpt_path): - clip_data = load_torch_file(ckpt_path) - clip = ClipVisionModel() - clip.load_sd(clip_data) - return clip diff --git a/comfy_extras/nodes_upscale_model.py b/comfy_extras/nodes_upscale_model.py index b79b7851..6a7d0e51 100644 --- a/comfy_extras/nodes_upscale_model.py +++ b/comfy_extras/nodes_upscale_model.py @@ -1,6 +1,5 @@ import os from comfy_extras.chainner_models import model_loading -from comfy.sd import load_torch_file import model_management import torch import comfy.utils @@ -18,7 +17,7 @@ class UpscaleModelLoader: def load_model(self, model_name): model_path = folder_paths.get_full_path("upscale_models", model_name) - sd = load_torch_file(model_path) + sd = comfy.utils.load_torch_file(model_path) out = model_loading.load_state_dict(sd).eval() return (out, ) diff --git a/nodes.py b/nodes.py index e69832c5..1555c19c 100644 --- a/nodes.py +++ b/nodes.py @@ -18,7 +18,7 @@ import comfy.samplers import comfy.sd import comfy.utils -import comfy_extras.clip_vision +import comfy.clip_vision import model_management import importlib @@ -219,6 +219,21 @@ class CheckpointLoaderSimple: out = comfy.sd.load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, embedding_directory=folder_paths.get_folder_paths("embeddings")) return out +class unCLIPCheckpointLoader: + @classmethod + def INPUT_TYPES(s): + return {"required": { "ckpt_name": (folder_paths.get_filename_list("checkpoints"), ), + }} + RETURN_TYPES = ("MODEL", "CLIP", "VAE", "CLIP_VISION") + FUNCTION = "load_checkpoint" + + CATEGORY = "_for_testing/unclip" + + def load_checkpoint(self, ckpt_name, output_vae=True, output_clip=True): + ckpt_path = folder_paths.get_full_path("checkpoints", ckpt_name) + out = comfy.sd.load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, output_clipvision=True, embedding_directory=folder_paths.get_folder_paths("embeddings")) + return out + class CLIPSetLastLayer: @classmethod def INPUT_TYPES(s): @@ -370,7 +385,7 @@ class CLIPVisionLoader: def load_clip(self, clip_name): clip_path = folder_paths.get_full_path("clip_vision", clip_name) - clip_vision = comfy_extras.clip_vision.load(clip_path) + clip_vision = comfy.clip_vision.load(clip_path) return (clip_vision,) class CLIPVisionEncode: @@ -382,7 +397,7 @@ class CLIPVisionEncode: RETURN_TYPES = ("CLIP_VISION_OUTPUT",) FUNCTION = "encode" - CATEGORY = "conditioning/style_model" + CATEGORY = "conditioning" def encode(self, clip_vision, image): output = clip_vision.encode_image(image) @@ -424,6 +439,32 @@ class StyleModelApply: c.append(n) return (c, ) +class unCLIPConditioning: + @classmethod + def INPUT_TYPES(s): + return {"required": {"conditioning": ("CONDITIONING", ), + "clip_vision_output": ("CLIP_VISION_OUTPUT", ), + "strength": ("FLOAT", {"default": 1.0, "min": -10.0, "max": 10.0, "step": 0.01}), + }} + RETURN_TYPES = ("CONDITIONING",) + FUNCTION = "apply_adm" + + CATEGORY = "_for_testing/unclip" + + def apply_adm(self, conditioning, clip_vision_output, strength): + c = [] + for t in conditioning: + o = t[1].copy() + x = (clip_vision_output, strength) + if "adm" in o: + o["adm"] = o["adm"][:] + [x] + else: + o["adm"] = [x] + n = [t[0], o] + c.append(n) + return (c, ) + + class EmptyLatentImage: def __init__(self, device="cpu"): self.device = device @@ -1025,6 +1066,7 @@ NODE_CLASS_MAPPINGS = { "CLIPLoader": CLIPLoader, "CLIPVisionEncode": CLIPVisionEncode, "StyleModelApply": StyleModelApply, + "unCLIPConditioning": unCLIPConditioning, "ControlNetApply": ControlNetApply, "ControlNetLoader": ControlNetLoader, "DiffControlNetLoader": DiffControlNetLoader, @@ -1033,6 +1075,7 @@ NODE_CLASS_MAPPINGS = { "VAEDecodeTiled": VAEDecodeTiled, "VAEEncodeTiled": VAEEncodeTiled, "TomePatchModel": TomePatchModel, + "unCLIPCheckpointLoader": unCLIPCheckpointLoader, } def load_custom_node(module_path):