Browse Source
See _for_testing/unclip in the UI for the new nodes. unCLIPCheckpointLoader is used to load them. unCLIPConditioning is used to add the image cond and takes as input a CLIPVisionEncode output which has been moved to the conditioning section.pull/364/head
comfyanonymous
2 years ago
17 changed files with 593 additions and 113 deletions
@ -0,0 +1,62 @@ |
|||||||
|
from transformers import CLIPVisionModelWithProjection, CLIPVisionConfig, CLIPImageProcessor |
||||||
|
from .utils import load_torch_file, transformers_convert |
||||||
|
import os |
||||||
|
|
||||||
|
class ClipVisionModel(): |
||||||
|
def __init__(self, json_config): |
||||||
|
config = CLIPVisionConfig.from_json_file(json_config) |
||||||
|
self.model = CLIPVisionModelWithProjection(config) |
||||||
|
self.processor = CLIPImageProcessor(crop_size=224, |
||||||
|
do_center_crop=True, |
||||||
|
do_convert_rgb=True, |
||||||
|
do_normalize=True, |
||||||
|
do_resize=True, |
||||||
|
image_mean=[ 0.48145466,0.4578275,0.40821073], |
||||||
|
image_std=[0.26862954,0.26130258,0.27577711], |
||||||
|
resample=3, #bicubic |
||||||
|
size=224) |
||||||
|
|
||||||
|
def load_sd(self, sd): |
||||||
|
self.model.load_state_dict(sd, strict=False) |
||||||
|
|
||||||
|
def encode_image(self, image): |
||||||
|
inputs = self.processor(images=[image[0]], return_tensors="pt") |
||||||
|
outputs = self.model(**inputs) |
||||||
|
return outputs |
||||||
|
|
||||||
|
def convert_to_transformers(sd): |
||||||
|
sd_k = sd.keys() |
||||||
|
if "embedder.model.visual.transformer.resblocks.0.attn.in_proj_weight" in sd_k: |
||||||
|
keys_to_replace = { |
||||||
|
"embedder.model.visual.class_embedding": "vision_model.embeddings.class_embedding", |
||||||
|
"embedder.model.visual.conv1.weight": "vision_model.embeddings.patch_embedding.weight", |
||||||
|
"embedder.model.visual.positional_embedding": "vision_model.embeddings.position_embedding.weight", |
||||||
|
"embedder.model.visual.ln_post.bias": "vision_model.post_layernorm.bias", |
||||||
|
"embedder.model.visual.ln_post.weight": "vision_model.post_layernorm.weight", |
||||||
|
"embedder.model.visual.ln_pre.bias": "vision_model.pre_layrnorm.bias", |
||||||
|
"embedder.model.visual.ln_pre.weight": "vision_model.pre_layrnorm.weight", |
||||||
|
} |
||||||
|
|
||||||
|
for x in keys_to_replace: |
||||||
|
if x in sd_k: |
||||||
|
sd[keys_to_replace[x]] = sd.pop(x) |
||||||
|
|
||||||
|
if "embedder.model.visual.proj" in sd_k: |
||||||
|
sd['visual_projection.weight'] = sd.pop("embedder.model.visual.proj").transpose(0, 1) |
||||||
|
|
||||||
|
sd = transformers_convert(sd, "embedder.model.visual", "vision_model", 32) |
||||||
|
return sd |
||||||
|
|
||||||
|
def load_clipvision_from_sd(sd): |
||||||
|
sd = convert_to_transformers(sd) |
||||||
|
if "vision_model.encoder.layers.30.layer_norm1.weight" in sd: |
||||||
|
json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config_h.json") |
||||||
|
else: |
||||||
|
json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config_vitl.json") |
||||||
|
clip = ClipVisionModel(json_config) |
||||||
|
clip.load_sd(sd) |
||||||
|
return clip |
||||||
|
|
||||||
|
def load(ckpt_path): |
||||||
|
sd = load_torch_file(ckpt_path) |
||||||
|
return load_clipvision_from_sd(sd) |
@ -0,0 +1,18 @@ |
|||||||
|
{ |
||||||
|
"attention_dropout": 0.0, |
||||||
|
"dropout": 0.0, |
||||||
|
"hidden_act": "gelu", |
||||||
|
"hidden_size": 1280, |
||||||
|
"image_size": 224, |
||||||
|
"initializer_factor": 1.0, |
||||||
|
"initializer_range": 0.02, |
||||||
|
"intermediate_size": 5120, |
||||||
|
"layer_norm_eps": 1e-05, |
||||||
|
"model_type": "clip_vision_model", |
||||||
|
"num_attention_heads": 16, |
||||||
|
"num_channels": 3, |
||||||
|
"num_hidden_layers": 32, |
||||||
|
"patch_size": 14, |
||||||
|
"projection_dim": 1024, |
||||||
|
"torch_dtype": "float32" |
||||||
|
} |
@ -0,0 +1,59 @@ |
|||||||
|
|
||||||
|
|
||||||
|
from typing import List, Tuple, Union |
||||||
|
|
||||||
|
import torch |
||||||
|
import torch.nn as nn |
||||||
|
|
||||||
|
#from: https://github.com/kornia/kornia/blob/master/kornia/enhance/normalize.py |
||||||
|
|
||||||
|
def enhance_normalize(data: torch.Tensor, mean: torch.Tensor, std: torch.Tensor) -> torch.Tensor: |
||||||
|
r"""Normalize an image/video tensor with mean and standard deviation. |
||||||
|
.. math:: |
||||||
|
\text{input[channel] = (input[channel] - mean[channel]) / std[channel]} |
||||||
|
Where `mean` is :math:`(M_1, ..., M_n)` and `std` :math:`(S_1, ..., S_n)` for `n` channels, |
||||||
|
Args: |
||||||
|
data: Image tensor of size :math:`(B, C, *)`. |
||||||
|
mean: Mean for each channel. |
||||||
|
std: Standard deviations for each channel. |
||||||
|
Return: |
||||||
|
Normalised tensor with same size as input :math:`(B, C, *)`. |
||||||
|
Examples: |
||||||
|
>>> x = torch.rand(1, 4, 3, 3) |
||||||
|
>>> out = normalize(x, torch.tensor([0.0]), torch.tensor([255.])) |
||||||
|
>>> out.shape |
||||||
|
torch.Size([1, 4, 3, 3]) |
||||||
|
>>> x = torch.rand(1, 4, 3, 3) |
||||||
|
>>> mean = torch.zeros(4) |
||||||
|
>>> std = 255. * torch.ones(4) |
||||||
|
>>> out = normalize(x, mean, std) |
||||||
|
>>> out.shape |
||||||
|
torch.Size([1, 4, 3, 3]) |
||||||
|
""" |
||||||
|
shape = data.shape |
||||||
|
if len(mean.shape) == 0 or mean.shape[0] == 1: |
||||||
|
mean = mean.expand(shape[1]) |
||||||
|
if len(std.shape) == 0 or std.shape[0] == 1: |
||||||
|
std = std.expand(shape[1]) |
||||||
|
|
||||||
|
# Allow broadcast on channel dimension |
||||||
|
if mean.shape and mean.shape[0] != 1: |
||||||
|
if mean.shape[0] != data.shape[1] and mean.shape[:2] != data.shape[:2]: |
||||||
|
raise ValueError(f"mean length and number of channels do not match. Got {mean.shape} and {data.shape}.") |
||||||
|
|
||||||
|
# Allow broadcast on channel dimension |
||||||
|
if std.shape and std.shape[0] != 1: |
||||||
|
if std.shape[0] != data.shape[1] and std.shape[:2] != data.shape[:2]: |
||||||
|
raise ValueError(f"std length and number of channels do not match. Got {std.shape} and {data.shape}.") |
||||||
|
|
||||||
|
mean = torch.as_tensor(mean, device=data.device, dtype=data.dtype) |
||||||
|
std = torch.as_tensor(std, device=data.device, dtype=data.dtype) |
||||||
|
|
||||||
|
if mean.shape: |
||||||
|
mean = mean[..., :, None] |
||||||
|
if std.shape: |
||||||
|
std = std[..., :, None] |
||||||
|
|
||||||
|
out: torch.Tensor = (data.view(shape[0], shape[1], -1) - mean) / std |
||||||
|
|
||||||
|
return out.view(shape) |
@ -0,0 +1,35 @@ |
|||||||
|
from ldm.modules.diffusionmodules.upscaling import ImageConcatWithNoiseAugmentation |
||||||
|
from ldm.modules.diffusionmodules.openaimodel import Timestep |
||||||
|
import torch |
||||||
|
|
||||||
|
class CLIPEmbeddingNoiseAugmentation(ImageConcatWithNoiseAugmentation): |
||||||
|
def __init__(self, *args, clip_stats_path=None, timestep_dim=256, **kwargs): |
||||||
|
super().__init__(*args, **kwargs) |
||||||
|
if clip_stats_path is None: |
||||||
|
clip_mean, clip_std = torch.zeros(timestep_dim), torch.ones(timestep_dim) |
||||||
|
else: |
||||||
|
clip_mean, clip_std = torch.load(clip_stats_path, map_location="cpu") |
||||||
|
self.register_buffer("data_mean", clip_mean[None, :], persistent=False) |
||||||
|
self.register_buffer("data_std", clip_std[None, :], persistent=False) |
||||||
|
self.time_embed = Timestep(timestep_dim) |
||||||
|
|
||||||
|
def scale(self, x): |
||||||
|
# re-normalize to centered mean and unit variance |
||||||
|
x = (x - self.data_mean) * 1. / self.data_std |
||||||
|
return x |
||||||
|
|
||||||
|
def unscale(self, x): |
||||||
|
# back to original data stats |
||||||
|
x = (x * self.data_std) + self.data_mean |
||||||
|
return x |
||||||
|
|
||||||
|
def forward(self, x, noise_level=None): |
||||||
|
if noise_level is None: |
||||||
|
noise_level = torch.randint(0, self.max_noise_level, (x.shape[0],), device=x.device).long() |
||||||
|
else: |
||||||
|
assert isinstance(noise_level, torch.Tensor) |
||||||
|
x = self.scale(x) |
||||||
|
z = self.q_sample(x, noise_level) |
||||||
|
z = self.unscale(z) |
||||||
|
noise_level = self.time_embed(noise_level) |
||||||
|
return z, noise_level |
@ -1,32 +0,0 @@ |
|||||||
from transformers import CLIPVisionModel, CLIPVisionConfig, CLIPImageProcessor |
|
||||||
from comfy.sd import load_torch_file |
|
||||||
import os |
|
||||||
|
|
||||||
class ClipVisionModel(): |
|
||||||
def __init__(self): |
|
||||||
json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config.json") |
|
||||||
config = CLIPVisionConfig.from_json_file(json_config) |
|
||||||
self.model = CLIPVisionModel(config) |
|
||||||
self.processor = CLIPImageProcessor(crop_size=224, |
|
||||||
do_center_crop=True, |
|
||||||
do_convert_rgb=True, |
|
||||||
do_normalize=True, |
|
||||||
do_resize=True, |
|
||||||
image_mean=[ 0.48145466,0.4578275,0.40821073], |
|
||||||
image_std=[0.26862954,0.26130258,0.27577711], |
|
||||||
resample=3, #bicubic |
|
||||||
size=224) |
|
||||||
|
|
||||||
def load_sd(self, sd): |
|
||||||
self.model.load_state_dict(sd, strict=False) |
|
||||||
|
|
||||||
def encode_image(self, image): |
|
||||||
inputs = self.processor(images=[image[0]], return_tensors="pt") |
|
||||||
outputs = self.model(**inputs) |
|
||||||
return outputs |
|
||||||
|
|
||||||
def load(ckpt_path): |
|
||||||
clip_data = load_torch_file(ckpt_path) |
|
||||||
clip = ClipVisionModel() |
|
||||||
clip.load_sd(clip_data) |
|
||||||
return clip |
|
Loading…
Reference in new issue