Browse Source

Support base SDXL and SDXL refiner models.

Large refactor of the model detection and loading code.
pull/785/head
comfyanonymous 1 year ago
parent
commit
f87ec10a97
  1. 42
      comfy/cldm/cldm.py
  2. 23
      comfy/clip_config_bigg.json
  3. 28
      comfy/clip_vision.py
  4. 4
      comfy/ldm/modules/attention.py
  5. 11
      comfy/ldm/modules/diffusionmodules/openaimodel.py
  6. 78
      comfy/model_base.py
  7. 120
      comfy/model_detection.py
  8. 13
      comfy/samplers.py
  9. 368
      comfy/sd.py
  10. 40
      comfy/sd1_clip.py
  11. 2
      comfy/sd2_clip.py
  12. 83
      comfy/sdxl_clip.py
  13. 148
      comfy/supported_models.py
  14. 65
      comfy/supported_models_base.py
  15. 16
      comfy/utils.py
  16. 6
      nodes.py

42
comfy/cldm/cldm.py

@ -34,8 +34,10 @@ class ControlNet(nn.Module):
channel_mult=(1, 2, 4, 8),
conv_resample=True,
dims=2,
num_classes=None,
use_checkpoint=False,
use_fp16=False,
use_bf16=False,
num_heads=-1,
num_head_channels=-1,
num_heads_upsample=-1,
@ -51,6 +53,8 @@ class ControlNet(nn.Module):
num_attention_blocks=None,
disable_middle_self_attn=False,
use_linear_in_transformer=False,
adm_in_channels=None,
transformer_depth_middle=None,
):
super().__init__()
if use_spatial_transformer:
@ -75,6 +79,10 @@ class ControlNet(nn.Module):
self.image_size = image_size
self.in_channels = in_channels
self.model_channels = model_channels
if isinstance(transformer_depth, int):
transformer_depth = len(channel_mult) * [transformer_depth]
if transformer_depth_middle is None:
transformer_depth_middle = transformer_depth[-1]
if isinstance(num_res_blocks, int):
self.num_res_blocks = len(channel_mult) * [num_res_blocks]
else:
@ -97,8 +105,10 @@ class ControlNet(nn.Module):
self.dropout = dropout
self.channel_mult = channel_mult
self.conv_resample = conv_resample
self.num_classes = num_classes
self.use_checkpoint = use_checkpoint
self.dtype = th.float16 if use_fp16 else th.float32
self.dtype = th.bfloat16 if use_bf16 else self.dtype
self.num_heads = num_heads
self.num_head_channels = num_head_channels
self.num_heads_upsample = num_heads_upsample
@ -111,6 +121,24 @@ class ControlNet(nn.Module):
linear(time_embed_dim, time_embed_dim),
)
if self.num_classes is not None:
if isinstance(self.num_classes, int):
self.label_emb = nn.Embedding(num_classes, time_embed_dim)
elif self.num_classes == "continuous":
print("setting up linear c_adm embedding layer")
self.label_emb = nn.Linear(1, time_embed_dim)
elif self.num_classes == "sequential":
assert adm_in_channels is not None
self.label_emb = nn.Sequential(
nn.Sequential(
linear(adm_in_channels, time_embed_dim),
nn.SiLU(),
linear(time_embed_dim, time_embed_dim),
)
)
else:
raise ValueError()
self.input_blocks = nn.ModuleList(
[
TimestepEmbedSequential(
@ -179,7 +207,7 @@ class ControlNet(nn.Module):
num_head_channels=dim_head,
use_new_attention_order=use_new_attention_order,
) if not use_spatial_transformer else SpatialTransformer(
ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim,
ch, num_heads, dim_head, depth=transformer_depth[level], context_dim=context_dim,
disable_self_attn=disabled_sa, use_linear=use_linear_in_transformer,
use_checkpoint=use_checkpoint
)
@ -238,7 +266,7 @@ class ControlNet(nn.Module):
num_head_channels=dim_head,
use_new_attention_order=use_new_attention_order,
) if not use_spatial_transformer else SpatialTransformer( # always uses a self-attn
ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim,
ch, num_heads, dim_head, depth=transformer_depth_middle, context_dim=context_dim,
disable_self_attn=disable_middle_self_attn, use_linear=use_linear_in_transformer,
use_checkpoint=use_checkpoint
),
@ -257,7 +285,7 @@ class ControlNet(nn.Module):
def make_zero_conv(self, channels):
return TimestepEmbedSequential(zero_module(conv_nd(self.dims, channels, channels, 1, padding=0)))
def forward(self, x, hint, timesteps, context, **kwargs):
def forward(self, x, hint, timesteps, context, y=None, **kwargs):
t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
emb = self.time_embed(t_emb)
@ -265,6 +293,14 @@ class ControlNet(nn.Module):
outs = []
hs = []
t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
emb = self.time_embed(t_emb)
if self.num_classes is not None:
assert y.shape[0] == x.shape[0]
emb = emb + self.label_emb(y)
h = x.type(self.dtype)
for module, zero_conv in zip(self.input_blocks, self.zero_convs):
if guided_hint is not None:

23
comfy/clip_config_bigg.json

@ -0,0 +1,23 @@
{
"architectures": [
"CLIPTextModel"
],
"attention_dropout": 0.0,
"bos_token_id": 0,
"dropout": 0.0,
"eos_token_id": 2,
"hidden_act": "gelu",
"hidden_size": 1280,
"initializer_factor": 1.0,
"initializer_range": 0.02,
"intermediate_size": 5120,
"layer_norm_eps": 1e-05,
"max_position_embeddings": 77,
"model_type": "clip_text_model",
"num_attention_heads": 20,
"num_hidden_layers": 32,
"pad_token_id": 1,
"projection_dim": 512,
"torch_dtype": "float32",
"vocab_size": 49408
}

28
comfy/clip_vision.py

@ -29,31 +29,31 @@ class ClipVisionModel():
outputs = self.model(**inputs)
return outputs
def convert_to_transformers(sd):
def convert_to_transformers(sd, prefix):
sd_k = sd.keys()
if "embedder.model.visual.transformer.resblocks.0.attn.in_proj_weight" in sd_k:
if "{}transformer.resblocks.0.attn.in_proj_weight".format(prefix) in sd_k:
keys_to_replace = {
"embedder.model.visual.class_embedding": "vision_model.embeddings.class_embedding",
"embedder.model.visual.conv1.weight": "vision_model.embeddings.patch_embedding.weight",
"embedder.model.visual.positional_embedding": "vision_model.embeddings.position_embedding.weight",
"embedder.model.visual.ln_post.bias": "vision_model.post_layernorm.bias",
"embedder.model.visual.ln_post.weight": "vision_model.post_layernorm.weight",
"embedder.model.visual.ln_pre.bias": "vision_model.pre_layrnorm.bias",
"embedder.model.visual.ln_pre.weight": "vision_model.pre_layrnorm.weight",
"{}class_embedding".format(prefix): "vision_model.embeddings.class_embedding",
"{}conv1.weight".format(prefix): "vision_model.embeddings.patch_embedding.weight",
"{}positional_embedding".format(prefix): "vision_model.embeddings.position_embedding.weight",
"{}ln_post.bias".format(prefix): "vision_model.post_layernorm.bias",
"{}ln_post.weight".format(prefix): "vision_model.post_layernorm.weight",
"{}ln_pre.bias".format(prefix): "vision_model.pre_layrnorm.bias",
"{}ln_pre.weight".format(prefix): "vision_model.pre_layrnorm.weight",
}
for x in keys_to_replace:
if x in sd_k:
sd[keys_to_replace[x]] = sd.pop(x)
if "embedder.model.visual.proj" in sd_k:
sd['visual_projection.weight'] = sd.pop("embedder.model.visual.proj").transpose(0, 1)
if "{}proj".format(prefix) in sd_k:
sd['visual_projection.weight'] = sd.pop("{}proj".format(prefix)).transpose(0, 1)
sd = transformers_convert(sd, "embedder.model.visual", "vision_model", 32)
sd = transformers_convert(sd, prefix, "vision_model.", 32)
return sd
def load_clipvision_from_sd(sd):
sd = convert_to_transformers(sd)
def load_clipvision_from_sd(sd, prefix):
sd = convert_to_transformers(sd, prefix)
if "vision_model.encoder.layers.30.layer_norm1.weight" in sd:
json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config_h.json")
else:

4
comfy/ldm/modules/attention.py

@ -600,7 +600,7 @@ class SpatialTransformer(nn.Module):
use_checkpoint=True, dtype=None):
super().__init__()
if exists(context_dim) and not isinstance(context_dim, list):
context_dim = [context_dim]
context_dim = [context_dim] * depth
self.in_channels = in_channels
inner_dim = n_heads * d_head
self.norm = Normalize(in_channels, dtype=dtype)
@ -630,7 +630,7 @@ class SpatialTransformer(nn.Module):
def forward(self, x, context=None, transformer_options={}):
# note: if no context is given, cross-attention defaults to self-attention
if not isinstance(context, list):
context = [context]
context = [context] * len(self.transformer_blocks)
b, c, h, w = x.shape
x_in = x
x = self.norm(x)

11
comfy/ldm/modules/diffusionmodules/openaimodel.py

@ -502,6 +502,7 @@ class UNetModel(nn.Module):
disable_middle_self_attn=False,
use_linear_in_transformer=False,
adm_in_channels=None,
transformer_depth_middle=None,
):
super().__init__()
if use_spatial_transformer:
@ -526,6 +527,10 @@ class UNetModel(nn.Module):
self.in_channels = in_channels
self.model_channels = model_channels
self.out_channels = out_channels
if isinstance(transformer_depth, int):
transformer_depth = len(channel_mult) * [transformer_depth]
if transformer_depth_middle is None:
transformer_depth_middle = transformer_depth[-1]
if isinstance(num_res_blocks, int):
self.num_res_blocks = len(channel_mult) * [num_res_blocks]
else:
@ -631,7 +636,7 @@ class UNetModel(nn.Module):
num_head_channels=dim_head,
use_new_attention_order=use_new_attention_order,
) if not use_spatial_transformer else SpatialTransformer(
ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim,
ch, num_heads, dim_head, depth=transformer_depth[level], context_dim=context_dim,
disable_self_attn=disabled_sa, use_linear=use_linear_in_transformer,
use_checkpoint=use_checkpoint, dtype=self.dtype
)
@ -690,7 +695,7 @@ class UNetModel(nn.Module):
num_head_channels=dim_head,
use_new_attention_order=use_new_attention_order,
) if not use_spatial_transformer else SpatialTransformer( # always uses a self-attn
ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim,
ch, num_heads, dim_head, depth=transformer_depth_middle, context_dim=context_dim,
disable_self_attn=disable_middle_self_attn, use_linear=use_linear_in_transformer,
use_checkpoint=use_checkpoint, dtype=self.dtype
),
@ -746,7 +751,7 @@ class UNetModel(nn.Module):
num_head_channels=dim_head,
use_new_attention_order=use_new_attention_order,
) if not use_spatial_transformer else SpatialTransformer(
ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim,
ch, num_heads, dim_head, depth=transformer_depth[level], context_dim=context_dim,
disable_self_attn=disabled_sa, use_linear=use_linear_in_transformer,
use_checkpoint=use_checkpoint, dtype=self.dtype
)

78
comfy/model_base.py

@ -2,6 +2,7 @@ import torch
from comfy.ldm.modules.diffusionmodules.openaimodel import UNetModel
from comfy.ldm.modules.encoders.noise_aug_modules import CLIPEmbeddingNoiseAugmentation
from comfy.ldm.modules.diffusionmodules.util import make_beta_schedule
from comfy.ldm.modules.diffusionmodules.openaimodel import Timestep
import numpy as np
class BaseModel(torch.nn.Module):
@ -15,9 +16,9 @@ class BaseModel(torch.nn.Module):
self.parameterization = "v"
else:
self.parameterization = "eps"
if "adm_in_channels" in unet_config:
self.adm_channels = unet_config["adm_in_channels"]
else:
self.adm_channels = unet_config.get("adm_in_channels", None)
if self.adm_channels is None:
self.adm_channels = 0
print("v_prediction", v_prediction)
print("adm", self.adm_channels)
@ -55,6 +56,25 @@ class BaseModel(torch.nn.Module):
def is_adm(self):
return self.adm_channels > 0
def encode_adm(self, **kwargs):
return None
def load_model_weights(self, sd, unet_prefix=""):
to_load = {}
keys = list(sd.keys())
for k in keys:
if k.startswith(unet_prefix):
to_load[k[len(unet_prefix):]] = sd.pop(k)
m, u = self.diffusion_model.load_state_dict(to_load, strict=False)
if len(m) > 0:
print("unet missing:", m)
if len(u) > 0:
print("unet unexpected:", u)
del to_load
return self
class SD21UNCLIP(BaseModel):
def __init__(self, unet_config, noise_aug_config, v_prediction=True):
super().__init__(unet_config, v_prediction)
@ -95,3 +115,55 @@ class SDInpaint(BaseModel):
def __init__(self, unet_config, v_prediction=False):
super().__init__(unet_config, v_prediction)
self.concat_keys = ("mask", "masked_image")
class SDXLRefiner(BaseModel):
def __init__(self, unet_config, v_prediction=False):
super().__init__(unet_config, v_prediction)
self.embedder = Timestep(256)
def encode_adm(self, **kwargs):
clip_pooled = kwargs["pooled_output"]
width = kwargs.get("width", 768)
height = kwargs.get("height", 768)
crop_w = kwargs.get("crop_w", 0)
crop_h = kwargs.get("crop_h", 0)
if kwargs.get("prompt_type", "") == "negative":
aesthetic_score = kwargs.get("aesthetic_score", 2.5)
else:
aesthetic_score = kwargs.get("aesthetic_score", 6)
print(clip_pooled.shape, width, height, crop_w, crop_h, aesthetic_score)
out = []
out.append(self.embedder(torch.Tensor([width])))
out.append(self.embedder(torch.Tensor([height])))
out.append(self.embedder(torch.Tensor([crop_w])))
out.append(self.embedder(torch.Tensor([crop_h])))
out.append(self.embedder(torch.Tensor([aesthetic_score])))
flat = torch.flatten(torch.cat(out))[None, ]
return torch.cat((clip_pooled.to(flat.device), flat), dim=1)
class SDXL(BaseModel):
def __init__(self, unet_config, v_prediction=False):
super().__init__(unet_config, v_prediction)
self.embedder = Timestep(256)
def encode_adm(self, **kwargs):
clip_pooled = kwargs["pooled_output"]
width = kwargs.get("width", 768)
height = kwargs.get("height", 768)
crop_w = kwargs.get("crop_w", 0)
crop_h = kwargs.get("crop_h", 0)
target_width = kwargs.get("target_width", width)
target_height = kwargs.get("target_height", height)
print(clip_pooled.shape, width, height, crop_w, crop_h, target_width, target_height)
out = []
out.append(self.embedder(torch.Tensor([width])))
out.append(self.embedder(torch.Tensor([height])))
out.append(self.embedder(torch.Tensor([crop_w])))
out.append(self.embedder(torch.Tensor([crop_h])))
out.append(self.embedder(torch.Tensor([target_width])))
out.append(self.embedder(torch.Tensor([target_height])))
flat = torch.flatten(torch.cat(out))[None, ]
return torch.cat((clip_pooled.to(flat.device), flat), dim=1)

120
comfy/model_detection.py

@ -0,0 +1,120 @@
from . import supported_models
def count_blocks(state_dict_keys, prefix_string):
count = 0
while True:
c = False
for k in state_dict_keys:
if k.startswith(prefix_string.format(count)):
c = True
break
if c == False:
break
count += 1
return count
def detect_unet_config(state_dict, key_prefix, use_fp16):
state_dict_keys = list(state_dict.keys())
num_res_blocks = 2
unet_config = {
"use_checkpoint": False,
"image_size": 32,
"out_channels": 4,
"num_res_blocks": num_res_blocks,
"use_spatial_transformer": True,
"legacy": False
}
y_input = '{}label_emb.0.0.weight'.format(key_prefix)
if y_input in state_dict_keys:
unet_config["num_classes"] = "sequential"
unet_config["adm_in_channels"] = state_dict[y_input].shape[1]
else:
unet_config["adm_in_channels"] = None
unet_config["use_fp16"] = use_fp16
model_channels = state_dict['{}input_blocks.0.0.weight'.format(key_prefix)].shape[0]
in_channels = state_dict['{}input_blocks.0.0.weight'.format(key_prefix)].shape[1]
num_res_blocks = []
channel_mult = []
attention_resolutions = []
transformer_depth = []
context_dim = None
use_linear_in_transformer = False
current_res = 1
count = 0
last_res_blocks = 0
last_transformer_depth = 0
last_channel_mult = 0
while True:
prefix = '{}input_blocks.{}.'.format(key_prefix, count)
block_keys = sorted(list(filter(lambda a: a.startswith(prefix), state_dict_keys)))
if len(block_keys) == 0:
break
if "{}0.op.weight".format(prefix) in block_keys: #new layer
if last_transformer_depth > 0:
attention_resolutions.append(current_res)
transformer_depth.append(last_transformer_depth)
num_res_blocks.append(last_res_blocks)
channel_mult.append(last_channel_mult)
current_res *= 2
last_res_blocks = 0
last_transformer_depth = 0
last_channel_mult = 0
else:
res_block_prefix = "{}0.in_layers.0.weight".format(prefix)
if res_block_prefix in block_keys:
last_res_blocks += 1
last_channel_mult = state_dict["{}0.out_layers.3.weight".format(prefix)].shape[0] // model_channels
transformer_prefix = prefix + "1.transformer_blocks."
transformer_keys = sorted(list(filter(lambda a: a.startswith(transformer_prefix), state_dict_keys)))
if len(transformer_keys) > 0:
last_transformer_depth = count_blocks(state_dict_keys, transformer_prefix + '{}')
if context_dim is None:
context_dim = state_dict['{}0.attn2.to_k.weight'.format(transformer_prefix)].shape[1]
use_linear_in_transformer = len(state_dict['{}1.proj_in.weight'.format(prefix)].shape) == 2
count += 1
if last_transformer_depth > 0:
attention_resolutions.append(current_res)
transformer_depth.append(last_transformer_depth)
num_res_blocks.append(last_res_blocks)
channel_mult.append(last_channel_mult)
transformer_depth_middle = count_blocks(state_dict_keys, '{}middle_block.1.transformer_blocks.'.format(key_prefix) + '{}')
if len(set(num_res_blocks)) == 1:
num_res_blocks = num_res_blocks[0]
if len(set(transformer_depth)) == 1:
transformer_depth = transformer_depth[0]
unet_config["in_channels"] = in_channels
unet_config["model_channels"] = model_channels
unet_config["num_res_blocks"] = num_res_blocks
unet_config["attention_resolutions"] = attention_resolutions
unet_config["transformer_depth"] = transformer_depth
unet_config["channel_mult"] = channel_mult
unet_config["transformer_depth_middle"] = transformer_depth_middle
unet_config['use_linear_in_transformer'] = use_linear_in_transformer
unet_config["context_dim"] = context_dim
return unet_config
def model_config_from_unet(state_dict, unet_key_prefix, use_fp16):
unet_config = detect_unet_config(state_dict, unet_key_prefix, use_fp16)
for model_config in supported_models.models:
if model_config.matches(unet_config):
return model_config(unet_config)
return None

13
comfy/samplers.py

@ -229,7 +229,7 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, con
timestep_ = torch.cat([timestep] * batch_chunks)
if control is not None:
c['control'] = control.get_control(input_x, timestep_, c['c_crossattn'], len(cond_or_uncond))
c['control'] = control.get_control(input_x, timestep_, c, len(cond_or_uncond))
transformer_options = {}
if 'transformer_options' in model_options:
@ -460,8 +460,7 @@ def apply_empty_x_to_equal_area(conds, uncond, name, uncond_fill_func):
n[name] = uncond_fill_func(cond_cnets, x)
uncond[temp[1]] = [o[0], n]
def encode_adm(model, conds, batch_size, device):
def encode_adm(model, conds, batch_size, width, height, device, prompt_type):
for t in range(len(conds)):
x = conds[t]
adm_out = None
@ -469,7 +468,11 @@ def encode_adm(model, conds, batch_size, device):
adm_out = x[1]["adm"]
else:
params = x[1].copy()
params["width"] = params.get("width", width * 8)
params["height"] = params.get("height", height * 8)
params["prompt_type"] = params.get("prompt_type", prompt_type)
adm_out = model.encode_adm(device=device, **params)
if adm_out is not None:
x[1] = x[1].copy()
x[1]["adm_encoded"] = torch.cat([adm_out] * batch_size).to(device)
@ -580,8 +583,8 @@ class KSampler:
precision_scope = contextlib.nullcontext
if self.model.is_adm():
positive = encode_adm(self.model, positive, noise.shape[0], self.device)
negative = encode_adm(self.model, negative, noise.shape[0], self.device)
positive = encode_adm(self.model, positive, noise.shape[0], noise.shape[3], noise.shape[2], self.device, "positive")
negative = encode_adm(self.model, negative, noise.shape[0], noise.shape[3], noise.shape[2], self.device, "negative")
extra_args = {"cond":positive, "uncond":negative, "cond_scale": cfg, "model_options": self.model_options}

368
comfy/sd.py

@ -3,8 +3,6 @@ import contextlib
import copy
import inspect
from . import sd1_clip
from . import sd2_clip
from comfy import model_management
from .ldm.util import instantiate_from_config
from .ldm.models.autoencoder import AutoencoderKL
@ -17,19 +15,28 @@ from . import clip_vision
from . import gligen
from . import diffusers_convert
from . import model_base
from . import model_detection
def load_model_weights(model, sd, verbose=False, load_state_dict_to=[]):
replace_prefix = {"model.diffusion_model.": "diffusion_model."}
for rp in replace_prefix:
replace = list(map(lambda a: (a, "{}{}".format(replace_prefix[rp], a[len(rp):])), filter(lambda a: a.startswith(rp), sd.keys())))
for x in replace:
sd[x[1]] = sd.pop(x[0])
from . import sd1_clip
from . import sd2_clip
def load_model_weights(model, sd):
m, u = model.load_state_dict(sd, strict=False)
m = set(m)
unexpected_keys = set(u)
k = list(sd.keys())
for x in k:
# print(x)
if x not in unexpected_keys:
w = sd.pop(x)
del w
if len(m) > 0:
print("missing", m)
return model
def load_clip_weights(model, sd):
k = list(sd.keys())
for x in k:
if x.startswith("cond_stage_model.transformer.") and not x.startswith("cond_stage_model.transformer.text_model."):
y = x.replace("cond_stage_model.transformer.", "cond_stage_model.transformer.text_model.")
sd[y] = sd.pop(x)
@ -39,20 +46,8 @@ def load_model_weights(model, sd, verbose=False, load_state_dict_to=[]):
if ids.dtype == torch.float32:
sd['cond_stage_model.transformer.text_model.embeddings.position_ids'] = ids.round()
sd = utils.transformers_convert(sd, "cond_stage_model.model", "cond_stage_model.transformer.text_model", 24)
for x in load_state_dict_to:
x.load_state_dict(sd, strict=False)
if len(m) > 0 and verbose:
print("missing keys:")
print(m)
if len(u) > 0 and verbose:
print("unexpected keys:")
print(u)
model.eval()
return model
sd = utils.transformers_convert(sd, "cond_stage_model.model.", "cond_stage_model.transformer.text_model.", 24)
return load_model_weights(model, sd)
LORA_CLIP_MAP = {
"mlp.fc1": "mlp_fc1",
@ -66,18 +61,26 @@ LORA_CLIP_MAP = {
LORA_UNET_MAP_ATTENTIONS = {
"proj_in": "proj_in",
"proj_out": "proj_out",
"transformer_blocks.0.attn1.to_q": "transformer_blocks_0_attn1_to_q",
"transformer_blocks.0.attn1.to_k": "transformer_blocks_0_attn1_to_k",
"transformer_blocks.0.attn1.to_v": "transformer_blocks_0_attn1_to_v",
"transformer_blocks.0.attn1.to_out.0": "transformer_blocks_0_attn1_to_out_0",
"transformer_blocks.0.attn2.to_q": "transformer_blocks_0_attn2_to_q",
"transformer_blocks.0.attn2.to_k": "transformer_blocks_0_attn2_to_k",
"transformer_blocks.0.attn2.to_v": "transformer_blocks_0_attn2_to_v",
"transformer_blocks.0.attn2.to_out.0": "transformer_blocks_0_attn2_to_out_0",
"transformer_blocks.0.ff.net.0.proj": "transformer_blocks_0_ff_net_0_proj",
"transformer_blocks.0.ff.net.2": "transformer_blocks_0_ff_net_2",
}
transformer_lora_blocks = {
"transformer_blocks.{}.attn1.to_q": "transformer_blocks_{}_attn1_to_q",
"transformer_blocks.{}.attn1.to_k": "transformer_blocks_{}_attn1_to_k",
"transformer_blocks.{}.attn1.to_v": "transformer_blocks_{}_attn1_to_v",
"transformer_blocks.{}.attn1.to_out.0": "transformer_blocks_{}_attn1_to_out_0",
"transformer_blocks.{}.attn2.to_q": "transformer_blocks_{}_attn2_to_q",
"transformer_blocks.{}.attn2.to_k": "transformer_blocks_{}_attn2_to_k",
"transformer_blocks.{}.attn2.to_v": "transformer_blocks_{}_attn2_to_v",
"transformer_blocks.{}.attn2.to_out.0": "transformer_blocks_{}_attn2_to_out_0",
"transformer_blocks.{}.ff.net.0.proj": "transformer_blocks_{}_ff_net_0_proj",
"transformer_blocks.{}.ff.net.2": "transformer_blocks_{}_ff_net_2",
}
for i in range(10):
for k in transformer_lora_blocks:
LORA_UNET_MAP_ATTENTIONS[k.format(i)] = transformer_lora_blocks[k].format(i)
LORA_UNET_MAP_RESNET = {
"in_layers.2": "resnets_{}_conv1",
"emb_layers.1": "resnets_{}_time_emb_proj",
@ -470,21 +473,12 @@ def load_lora_for_models(model, clip, lora_path, strength_model, strength_clip):
class CLIP:
def __init__(self, config={}, embedding_directory=None, no_init=False):
def __init__(self, target=None, embedding_directory=None, no_init=False):
if no_init:
return
self.target_clip = config["target"]
if "params" in config:
params = config["params"]
else:
params = {}
if self.target_clip.endswith("FrozenOpenCLIPEmbedder"):
clip = sd2_clip.SD2ClipModel
tokenizer = sd2_clip.SD2Tokenizer
elif self.target_clip.endswith("FrozenCLIPEmbedder"):
clip = sd1_clip.SD1ClipModel
tokenizer = sd1_clip.SD1Tokenizer
params = target.params
clip = target.clip
tokenizer = target.tokenizer
self.device = model_management.text_encoder_device()
params["device"] = self.device
@ -497,11 +491,11 @@ class CLIP:
def clone(self):
n = CLIP(no_init=True)
n.target_clip = self.target_clip
n.patcher = self.patcher.clone()
n.cond_stage_model = self.cond_stage_model
n.tokenizer = self.tokenizer
n.layer_idx = self.layer_idx
n.device = self.device
return n
def load_from_state_dict(self, sd):
@ -521,21 +515,22 @@ class CLIP:
self.cond_stage_model.clip_layer(self.layer_idx)
try:
self.patcher.patch_model()
cond = self.cond_stage_model.encode_token_weights(tokens)
cond, pooled = self.cond_stage_model.encode_token_weights(tokens)
self.patcher.unpatch_model()
except Exception as e:
self.patcher.unpatch_model()
raise e
cond_out = cond
if return_pooled:
eos_token_index = max(range(len(tokens[0])), key=tokens[0].__getitem__)
pooled = cond[:, eos_token_index]
return cond, pooled
return cond
return cond_out, pooled
return cond_out
def encode(self, text):
tokens = self.tokenize(text)
return self.encode_from_tokens(tokens)
class VAE:
def __init__(self, ckpt_path=None, scale_factor=0.18215, device=None, config=None):
if config is None:
@ -668,10 +663,10 @@ class ControlNet:
self.previous_controlnet = None
self.global_average_pooling = global_average_pooling
def get_control(self, x_noisy, t, cond_txt, batched_number):
def get_control(self, x_noisy, t, cond, batched_number):
control_prev = None
if self.previous_controlnet is not None:
control_prev = self.previous_controlnet.get_control(x_noisy, t, cond_txt, batched_number)
control_prev = self.previous_controlnet.get_control(x_noisy, t, cond, batched_number)
output_dtype = x_noisy.dtype
if self.cond_hint is None or x_noisy.shape[2] * 8 != self.cond_hint.shape[2] or x_noisy.shape[3] * 8 != self.cond_hint.shape[3]:
@ -689,7 +684,9 @@ class ControlNet:
with precision_scope(model_management.get_autocast_device(self.device)):
self.control_model = model_management.load_if_low_vram(self.control_model)
control = self.control_model(x=x_noisy, hint=self.cond_hint, timesteps=t, context=cond_txt)
context = torch.cat(cond['c_crossattn'], 1)
y = cond.get('c_adm', None)
control = self.control_model(x=x_noisy, hint=self.cond_hint, timesteps=t, context=context, y=y)
self.control_model = model_management.unload_if_low_vram(self.control_model)
out = {'middle':[], 'output': []}
autocast_enabled = torch.is_autocast_enabled()
@ -749,60 +746,28 @@ class ControlNet:
def load_controlnet(ckpt_path, model=None):
controlnet_data = utils.load_torch_file(ckpt_path, safe_load=True)
pth_key = 'control_model.input_blocks.1.1.transformer_blocks.0.attn2.to_k.weight'
pth_key = 'control_model.zero_convs.0.0.weight'
pth = False
sd2 = False
key = 'input_blocks.1.1.transformer_blocks.0.attn2.to_k.weight'
key = 'zero_convs.0.0.weight'
if pth_key in controlnet_data:
pth = True
key = pth_key
prefix = "control_model."
elif key in controlnet_data:
pass
prefix = ""
else:
net = load_t2i_adapter(controlnet_data)
if net is None:
print("error checkpoint does not contain controlnet or t2i adapter data", ckpt_path)
return net
context_dim = controlnet_data[key].shape[1]
use_fp16 = False
if model_management.should_use_fp16() and controlnet_data[key].dtype == torch.float16:
use_fp16 = True
if context_dim == 768:
#SD1.x
control_model = cldm.ControlNet(image_size=32,
in_channels=4,
hint_channels=3,
model_channels=320,
attention_resolutions=[ 4, 2, 1 ],
num_res_blocks=2,
channel_mult=[ 1, 2, 4, 4 ],
num_heads=8,
use_spatial_transformer=True,
transformer_depth=1,
context_dim=context_dim,
use_checkpoint=False,
legacy=False,
use_fp16=use_fp16)
else:
#SD2.x
control_model = cldm.ControlNet(image_size=32,
in_channels=4,
hint_channels=3,
model_channels=320,
attention_resolutions=[ 4, 2, 1 ],
num_res_blocks=2,
channel_mult=[ 1, 2, 4, 4 ],
num_head_channels=64,
use_spatial_transformer=True,
use_linear_in_transformer=True,
transformer_depth=1,
context_dim=context_dim,
use_checkpoint=False,
legacy=False,
use_fp16=use_fp16)
use_fp16 = model_management.should_use_fp16()
controlnet_config = model_detection.model_config_from_unet(controlnet_data, prefix, use_fp16).unet_config
controlnet_config.pop("out_channels")
controlnet_config["hint_channels"] = 3
control_model = cldm.ControlNet(**controlnet_config)
if pth:
if 'difference' in controlnet_data:
if model is not None:
@ -823,9 +788,10 @@ def load_controlnet(ckpt_path, model=None):
pass
w = WeightsLoader()
w.control_model = control_model
w.load_state_dict(controlnet_data, strict=False)
missing, unexpected = w.load_state_dict(controlnet_data, strict=False)
else:
control_model.load_state_dict(controlnet_data, strict=False)
missing, unexpected = control_model.load_state_dict(controlnet_data, strict=False)
print(missing, unexpected)
if use_fp16:
control_model = control_model.half()
@ -850,10 +816,10 @@ class T2IAdapter:
self.cond_hint_original = None
self.cond_hint = None
def get_control(self, x_noisy, t, cond_txt, batched_number):
def get_control(self, x_noisy, t, cond, batched_number):
control_prev = None
if self.previous_controlnet is not None:
control_prev = self.previous_controlnet.get_control(x_noisy, t, cond_txt, batched_number)
control_prev = self.previous_controlnet.get_control(x_noisy, t, cond, batched_number)
if self.cond_hint is None or x_noisy.shape[2] * 8 != self.cond_hint.shape[2] or x_noisy.shape[3] * 8 != self.cond_hint.shape[3]:
if self.cond_hint is not None:
@ -929,12 +895,21 @@ class T2IAdapter:
def load_t2i_adapter(t2i_data):
keys = t2i_data.keys()
if 'adapter' in keys:
t2i_data = t2i_data['adapter']
keys = t2i_data.keys()
if "body.0.in_conv.weight" in keys:
cin = t2i_data['body.0.in_conv.weight'].shape[1]
model_ad = adapter.Adapter_light(cin=cin, channels=[320, 640, 1280, 1280], nums_rb=4)
elif 'conv_in.weight' in keys:
cin = t2i_data['conv_in.weight'].shape[1]
model_ad = adapter.Adapter(cin=cin, channels=[320, 640, 1280, 1280][:4], nums_rb=2, ksize=1, sk=True, use_conv=False)
channel = t2i_data['conv_in.weight'].shape[0]
ksize = t2i_data['body.0.block2.weight'].shape[2]
use_conv = False
down_opts = list(filter(lambda a: a.endswith("down_opt.op.weight"), keys))
if len(down_opts) > 0:
use_conv = True
model_ad = adapter.Adapter(cin=cin, channels=[channel, channel*2, channel*4, channel*4][:4], nums_rb=2, ksize=ksize, sk=True, use_conv=use_conv)
else:
return None
model_ad.load_state_dict(t2i_data)
@ -1010,17 +985,8 @@ def load_checkpoint(config_path=None, ckpt_path=None, output_vae=True, output_cl
class WeightsLoader(torch.nn.Module):
pass
w = WeightsLoader()
load_state_dict_to = []
if output_vae:
vae = VAE(scale_factor=scale_factor, config=vae_config)
w.first_stage_model = vae.first_stage_model
load_state_dict_to = [w]
if output_clip:
clip = CLIP(config=clip_config, embedding_directory=embedding_directory)
w.cond_stage_model = clip.cond_stage_model
load_state_dict_to = [w]
if state_dict is None:
state_dict = utils.load_torch_file(ckpt_path)
if config['model']["target"].endswith("LatentInpaintDiffusion"):
model = model_base.SDInpaint(unet_config, v_prediction=v_prediction)
@ -1029,13 +995,33 @@ def load_checkpoint(config_path=None, ckpt_path=None, output_vae=True, output_cl
else:
model = model_base.BaseModel(unet_config, v_prediction=v_prediction)
if state_dict is None:
state_dict = utils.load_torch_file(ckpt_path)
model = load_model_weights(model, state_dict, verbose=False, load_state_dict_to=load_state_dict_to)
if fp16:
model = model.half()
model.load_model_weights(state_dict, "model.diffusion_model.")
if output_vae:
w = WeightsLoader()
vae = VAE(scale_factor=scale_factor, config=vae_config)
w.first_stage_model = vae.first_stage_model
load_model_weights(w, state_dict)
if output_clip:
w = WeightsLoader()
class EmptyClass:
pass
clip_target = EmptyClass()
clip_target.params = clip_config["params"]
if clip_config["target"].endswith("FrozenOpenCLIPEmbedder"):
clip_target.clip = sd2_clip.SD2ClipModel
clip_target.tokenizer = sd2_clip.SD2Tokenizer
elif clip_config["target"].endswith("FrozenCLIPEmbedder"):
clip_target.clip = sd1_clip.SD1ClipModel
clip_target.tokenizer = sd1_clip.SD1Tokenizer
clip = CLIP(clip_target, embedding_directory=embedding_directory)
w.cond_stage_model = clip.cond_stage_model
load_clip_weights(w, state_dict)
return (ModelPatcher(model), clip, vae)
@ -1045,139 +1031,41 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
clip = None
clipvision = None
vae = None
model = None
clip_target = None
fp16 = model_management.should_use_fp16()
class WeightsLoader(torch.nn.Module):
pass
w = WeightsLoader()
load_state_dict_to = []
if output_vae:
vae = VAE()
w.first_stage_model = vae.first_stage_model
load_state_dict_to = [w]
if output_clip:
clip_config = {}
if "cond_stage_model.model.transformer.resblocks.22.attn.out_proj.weight" in sd_keys:
clip_config['target'] = 'comfy.ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder'
else:
clip_config['target'] = 'comfy.ldm.modules.encoders.modules.FrozenCLIPEmbedder'
clip = CLIP(config=clip_config, embedding_directory=embedding_directory)
w.cond_stage_model = clip.cond_stage_model
load_state_dict_to = [w]
clipvision_key = "embedder.model.visual.transformer.resblocks.0.attn.in_proj_weight"
noise_aug_config = None
if clipvision_key in sd_keys:
size = sd[clipvision_key].shape[1]
model_config = model_detection.model_config_from_unet(sd, "model.diffusion_model.", fp16)
if model_config is None:
raise RuntimeError("ERROR: Could not detect model type of: {}".format(ckpt_path))
if model_config.clip_vision_prefix is not None:
if output_clipvision:
clipvision = clip_vision.load_clipvision_from_sd(sd)
noise_aug_key = "noise_augmentor.betas"
if noise_aug_key in sd_keys:
noise_aug_config = {}
params = {}
noise_schedule_config = {}
noise_schedule_config["timesteps"] = sd[noise_aug_key].shape[0]
noise_schedule_config["beta_schedule"] = "squaredcos_cap_v2"
params["noise_schedule_config"] = noise_schedule_config
noise_aug_config['target'] = "comfy.ldm.modules.encoders.noise_aug_modules.CLIPEmbeddingNoiseAugmentation"
if size == 1280: #h
params["timestep_dim"] = 1024
elif size == 1024: #l
params["timestep_dim"] = 768
noise_aug_config['params'] = params
sd_config = {
"linear_start": 0.00085,
"linear_end": 0.012,
"num_timesteps_cond": 1,
"log_every_t": 200,
"timesteps": 1000,
"first_stage_key": "jpg",
"cond_stage_key": "txt",
"image_size": 64,
"channels": 4,
"cond_stage_trainable": False,
"monitor": "val/loss_simple_ema",
"scale_factor": 0.18215,
"use_ema": False,
}
unet_config = {
"use_checkpoint": False,
"image_size": 32,
"out_channels": 4,
"attention_resolutions": [
4,
2,
1
],
"num_res_blocks": 2,
"channel_mult": [
1,
2,
4,
4
],
"use_spatial_transformer": True,
"transformer_depth": 1,
"legacy": False
}
if len(sd['model.diffusion_model.input_blocks.4.1.proj_in.weight'].shape) == 2:
unet_config['use_linear_in_transformer'] = True
unet_config["use_fp16"] = fp16
unet_config["model_channels"] = sd['model.diffusion_model.input_blocks.0.0.weight'].shape[0]
unet_config["in_channels"] = sd['model.diffusion_model.input_blocks.0.0.weight'].shape[1]
unet_config["context_dim"] = sd['model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn2.to_k.weight'].shape[1]
sd_config["unet_config"] = {"target": "comfy.ldm.modules.diffusionmodules.openaimodel.UNetModel", "params": unet_config}
unclip_model = False
inpaint_model = False
if noise_aug_config is not None: #SD2.x unclip model
sd_config["noise_aug_config"] = noise_aug_config
sd_config["image_size"] = 96
sd_config["embedding_dropout"] = 0.25
sd_config["conditioning_key"] = 'crossattn-adm'
unclip_model = True
elif unet_config["in_channels"] > 4: #inpainting model
sd_config["conditioning_key"] = "hybrid"
sd_config["finetune_keys"] = None
inpaint_model = True
else:
sd_config["conditioning_key"] = "crossattn"
if unet_config["context_dim"] == 768:
unet_config["num_heads"] = 8 #SD1.x
else:
unet_config["num_head_channels"] = 64 #SD2.x
clipvision = clip_vision.load_clipvision_from_sd(sd, model_config.clip_vision_prefix)
unclip = 'model.diffusion_model.label_emb.0.0.weight'
if unclip in sd_keys:
unet_config["num_classes"] = "sequential"
unet_config["adm_in_channels"] = sd[unclip].shape[1]
model = model_config.get_model(sd)
model.load_model_weights(sd, "model.diffusion_model.")
v_prediction = False
if unet_config["context_dim"] == 1024 and unet_config["in_channels"] == 4: #only SD2.x non inpainting models are v prediction
k = "model.diffusion_model.output_blocks.11.1.transformer_blocks.0.norm1.bias"
out = sd[k]
if torch.std(out, unbiased=False) > 0.09: # not sure how well this will actually work. I guess we will find out.
v_prediction = True
sd_config["parameterization"] = 'v'
if output_vae:
vae = VAE(scale_factor=model_config.vae_scale_factor)
w = WeightsLoader()
w.first_stage_model = vae.first_stage_model
load_model_weights(w, sd)
if inpaint_model:
model = model_base.SDInpaint(unet_config, v_prediction=v_prediction)
elif unclip_model:
model = model_base.SD21UNCLIP(unet_config, noise_aug_config["params"], v_prediction=v_prediction)
else:
model = model_base.BaseModel(unet_config, v_prediction=v_prediction)
if output_clip:
w = WeightsLoader()
clip_target = model_config.clip_target()
clip = CLIP(clip_target, embedding_directory=embedding_directory)
w.cond_stage_model = clip.cond_stage_model
sd = model_config.process_clip_state_dict(sd)
load_model_weights(w, sd)
model = load_model_weights(model, sd, verbose=False, load_state_dict_to=load_state_dict_to)
left_over = sd.keys()
if len(left_over) > 0:
print("left over keys:", left_over)
return (ModelPatcher(model), clip, vae, clipvision)

40
comfy/sd1_clip.py

@ -8,11 +8,14 @@ import zipfile
class ClipTokenWeightEncoder:
def encode_token_weights(self, token_weight_pairs):
z_empty = self.encode(self.empty_tokens)
z_empty, _ = self.encode(self.empty_tokens)
output = []
first_pooled = None
for x in token_weight_pairs:
tokens = [list(map(lambda a: a[0], x))]
z = self.encode(tokens)
z, pooled = self.encode(tokens)
if first_pooled is None:
first_pooled = pooled
for i in range(len(z)):
for j in range(len(z[i])):
weight = x[j][1]
@ -20,7 +23,7 @@ class ClipTokenWeightEncoder:
output += [z]
if (len(output) == 0):
return self.encode(self.empty_tokens)
return torch.cat(output, dim=-2).cpu()
return torch.cat(output, dim=-2).cpu(), first_pooled.cpu()
class SD1ClipModel(torch.nn.Module, ClipTokenWeightEncoder):
"""Uses the CLIP transformer encoder for text (from huggingface)"""
@ -50,6 +53,8 @@ class SD1ClipModel(torch.nn.Module, ClipTokenWeightEncoder):
self.layer = layer
self.layer_idx = None
self.empty_tokens = [[49406] + [49407] * 76]
self.text_projection = None
self.layer_norm_hidden_state = True
if layer == "hidden":
assert layer_idx is not None
assert abs(layer_idx) <= 12
@ -112,9 +117,13 @@ class SD1ClipModel(torch.nn.Module, ClipTokenWeightEncoder):
z = outputs.pooler_output[:, None, :]
else:
z = outputs.hidden_states[self.layer_idx]
z = self.transformer.text_model.final_layer_norm(z)
if self.layer_norm_hidden_state:
z = self.transformer.text_model.final_layer_norm(z)
return z
pooled_output = outputs.pooler_output
if self.text_projection is not None:
pooled_output = pooled_output @ self.text_projection
return z, pooled_output
def encode(self, tokens):
return self(tokens)
@ -204,7 +213,7 @@ def expand_directory_list(directories):
dirs.add(root)
return list(dirs)
def load_embed(embedding_name, embedding_directory):
def load_embed(embedding_name, embedding_directory, embedding_size):
if isinstance(embedding_directory, str):
embedding_directory = [embedding_directory]
@ -253,13 +262,23 @@ def load_embed(embedding_name, embedding_directory):
if embed_out is None:
if 'string_to_param' in embed:
values = embed['string_to_param'].values()
embed_out = next(iter(values))
elif isinstance(embed, list):
out_list = []
for x in range(len(embed)):
for k in embed[x]:
t = embed[x][k]
if t.shape[-1] != embedding_size:
continue
out_list.append(t.reshape(-1, t.shape[-1]))
embed_out = torch.cat(out_list, dim=0)
else:
values = embed.values()
embed_out = next(iter(values))
embed_out = next(iter(values))
return embed_out
class SD1Tokenizer:
def __init__(self, tokenizer_path=None, max_length=77, pad_with_end=True, embedding_directory=None):
def __init__(self, tokenizer_path=None, max_length=77, pad_with_end=True, embedding_directory=None, embedding_size=768):
if tokenizer_path is None:
tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "sd1_tokenizer")
self.tokenizer = CLIPTokenizer.from_pretrained(tokenizer_path)
@ -275,17 +294,18 @@ class SD1Tokenizer:
self.embedding_directory = embedding_directory
self.max_word_length = 8
self.embedding_identifier = "embedding:"
self.embedding_size = embedding_size
def _try_get_embedding(self, embedding_name:str):
'''
Takes a potential embedding name and tries to retrieve it.
Returns a Tuple consisting of the embedding and any leftover string, embedding can be None.
'''
embed = load_embed(embedding_name, self.embedding_directory)
embed = load_embed(embedding_name, self.embedding_directory, self.embedding_size)
if embed is None:
stripped = embedding_name.strip(',')
if len(stripped) < len(embedding_name):
embed = load_embed(stripped, self.embedding_directory)
embed = load_embed(stripped, self.embedding_directory, self.embedding_size)
return (embed, embedding_name[len(stripped):])
return (embed, "")

2
comfy/sd2_clip.py

@ -31,4 +31,4 @@ class SD2ClipModel(sd1_clip.SD1ClipModel):
class SD2Tokenizer(sd1_clip.SD1Tokenizer):
def __init__(self, tokenizer_path=None, embedding_directory=None):
super().__init__(tokenizer_path, pad_with_end=False, embedding_directory=embedding_directory)
super().__init__(tokenizer_path, pad_with_end=False, embedding_directory=embedding_directory, embedding_size=1024)

83
comfy/sdxl_clip.py

@ -0,0 +1,83 @@
from comfy import sd1_clip
import torch
import os
class SDXLClipG(sd1_clip.SD1ClipModel):
def __init__(self, device="cpu", max_length=77, freeze=True, layer="penultimate", layer_idx=None):
textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_config_bigg.json")
super().__init__(device=device, freeze=freeze, textmodel_json_config=textmodel_json_config)
self.empty_tokens = [[49406] + [49407] + [0] * 75]
self.text_projection = torch.nn.Parameter(torch.empty(1280, 1280))
self.layer_norm_hidden_state = False
if layer == "last":
pass
elif layer == "penultimate":
layer_idx = -1
self.clip_layer(layer_idx)
elif self.layer == "hidden":
assert layer_idx is not None
assert abs(layer_idx) < 32
self.clip_layer(layer_idx)
else:
raise NotImplementedError()
def clip_layer(self, layer_idx):
if layer_idx < 0:
layer_idx -= 1 #The real last layer of SD2.x clip is the penultimate one. The last one might contain garbage.
if abs(layer_idx) >= 32:
self.layer = "hidden"
self.layer_idx = -2
else:
self.layer = "hidden"
self.layer_idx = layer_idx
class SDXLClipGTokenizer(sd1_clip.SD1Tokenizer):
def __init__(self, tokenizer_path=None, embedding_directory=None):
super().__init__(tokenizer_path, pad_with_end=False, embedding_directory=embedding_directory, embedding_size=1280)
class SDXLTokenizer(sd1_clip.SD1Tokenizer):
def __init__(self, embedding_directory=None):
self.clip_l = sd1_clip.SD1Tokenizer(embedding_directory=embedding_directory)
self.clip_g = SDXLClipGTokenizer(embedding_directory=embedding_directory)
def tokenize_with_weights(self, text:str, return_word_ids=False):
out = {}
out["g"] = self.clip_g.tokenize_with_weights(text, return_word_ids)
out["l"] = self.clip_l.tokenize_with_weights(text, return_word_ids)
return out
def untokenize(self, token_weight_pair):
return self.clip_g.untokenize(token_weight_pair)
class SDXLClipModel(torch.nn.Module):
def __init__(self, device="cpu"):
super().__init__()
self.clip_l = sd1_clip.SD1ClipModel(layer="hidden", layer_idx=11, device=device)
self.clip_l.layer_norm_hidden_state = False
self.clip_g = SDXLClipG(device=device)
def clip_layer(self, layer_idx):
self.clip_l.clip_layer(layer_idx)
self.clip_g.clip_layer(layer_idx)
def encode_token_weights(self, token_weight_pairs):
token_weight_pairs_g = token_weight_pairs["g"]
token_weight_pairs_l = token_weight_pairs["l"]
g_out, g_pooled = self.clip_g.encode_token_weights(token_weight_pairs_g)
l_out, l_pooled = self.clip_l.encode_token_weights(token_weight_pairs_l)
return torch.cat([l_out, g_out], dim=-1), g_pooled
class SDXLRefinerClipModel(torch.nn.Module):
def __init__(self, device="cpu"):
super().__init__()
self.clip_g = SDXLClipG(device=device)
def clip_layer(self, layer_idx):
self.clip_g.clip_layer(layer_idx)
def encode_token_weights(self, token_weight_pairs):
token_weight_pairs_g = token_weight_pairs["g"]
g_out, g_pooled = self.clip_g.encode_token_weights(token_weight_pairs_g)
return g_out, g_pooled

148
comfy/supported_models.py

@ -0,0 +1,148 @@
import torch
from . import model_base
from . import utils
from . import sd1_clip
from . import sd2_clip
from . import sdxl_clip
from . import supported_models_base
class SD15(supported_models_base.BASE):
unet_config = {
"context_dim": 768,
"model_channels": 320,
"use_linear_in_transformer": False,
"adm_in_channels": None,
}
unet_extra_config = {
"num_heads": 8,
"num_head_channels": -1,
}
vae_scale_factor = 0.18215
def process_clip_state_dict(self, state_dict):
k = list(state_dict.keys())
for x in k:
if x.startswith("cond_stage_model.transformer.") and not x.startswith("cond_stage_model.transformer.text_model."):
y = x.replace("cond_stage_model.transformer.", "cond_stage_model.transformer.text_model.")
state_dict[y] = state_dict.pop(x)
if 'cond_stage_model.transformer.text_model.embeddings.position_ids' in state_dict:
ids = state_dict['cond_stage_model.transformer.text_model.embeddings.position_ids']
if ids.dtype == torch.float32:
state_dict['cond_stage_model.transformer.text_model.embeddings.position_ids'] = ids.round()
return state_dict
def clip_target(self):
return supported_models_base.ClipTarget(sd1_clip.SD1Tokenizer, sd1_clip.SD1ClipModel)
class SD20(supported_models_base.BASE):
unet_config = {
"context_dim": 1024,
"model_channels": 320,
"use_linear_in_transformer": True,
"adm_in_channels": None,
}
vae_scale_factor = 0.18215
def v_prediction(self, state_dict):
if self.unet_config["in_channels"] == 4: #SD2.0 inpainting models are not v prediction
k = "model.diffusion_model.output_blocks.11.1.transformer_blocks.0.norm1.bias"
out = state_dict[k]
if torch.std(out, unbiased=False) > 0.09: # not sure how well this will actually work. I guess we will find out.
return True
return False
def process_clip_state_dict(self, state_dict):
state_dict = utils.transformers_convert(state_dict, "cond_stage_model.model.", "cond_stage_model.transformer.text_model.", 24)
return state_dict
def clip_target(self):
return supported_models_base.ClipTarget(sd2_clip.SD2Tokenizer, sd2_clip.SD2ClipModel)
class SD21UnclipL(SD20):
unet_config = {
"context_dim": 1024,
"model_channels": 320,
"use_linear_in_transformer": True,
"adm_in_channels": 1536,
}
clip_vision_prefix = "embedder.model.visual."
noise_aug_config = {"noise_schedule_config": {"timesteps": 1000, "beta_schedule": "squaredcos_cap_v2"}, "timestep_dim": 768}
class SD21UnclipH(SD20):
unet_config = {
"context_dim": 1024,
"model_channels": 320,
"use_linear_in_transformer": True,
"adm_in_channels": 2048,
}
clip_vision_prefix = "embedder.model.visual."
noise_aug_config = {"noise_schedule_config": {"timesteps": 1000, "beta_schedule": "squaredcos_cap_v2"}, "timestep_dim": 1024}
class SDXLRefiner(supported_models_base.BASE):
unet_config = {
"model_channels": 384,
"use_linear_in_transformer": True,
"context_dim": 1280,
"adm_in_channels": 2560,
"transformer_depth": [0, 4, 4, 0],
}
vae_scale_factor = 0.13025
def get_model(self, state_dict):
return model_base.SDXLRefiner(self.unet_config)
def process_clip_state_dict(self, state_dict):
keys_to_replace = {}
replace_prefix = {}
state_dict = utils.transformers_convert(state_dict, "conditioner.embedders.0.model.", "cond_stage_model.clip_g.transformer.text_model.", 32)
keys_to_replace["conditioner.embedders.0.model.text_projection"] = "cond_stage_model.clip_g.text_projection"
state_dict = supported_models_base.state_dict_key_replace(state_dict, keys_to_replace)
return state_dict
def clip_target(self):
return supported_models_base.ClipTarget(sdxl_clip.SDXLTokenizer, sdxl_clip.SDXLRefinerClipModel)
class SDXL(supported_models_base.BASE):
unet_config = {
"model_channels": 320,
"use_linear_in_transformer": True,
"transformer_depth": [0, 2, 10],
"context_dim": 2048,
"adm_in_channels": 2816
}
vae_scale_factor = 0.13025
def get_model(self, state_dict):
return model_base.SDXL(self.unet_config)
def process_clip_state_dict(self, state_dict):
keys_to_replace = {}
replace_prefix = {}
replace_prefix["conditioner.embedders.0.transformer.text_model"] = "cond_stage_model.clip_l.transformer.text_model"
state_dict = utils.transformers_convert(state_dict, "conditioner.embedders.1.model.", "cond_stage_model.clip_g.transformer.text_model.", 32)
keys_to_replace["conditioner.embedders.1.model.text_projection"] = "cond_stage_model.clip_g.text_projection"
state_dict = supported_models_base.state_dict_prefix_replace(state_dict, replace_prefix)
state_dict = supported_models_base.state_dict_key_replace(state_dict, keys_to_replace)
return state_dict
def clip_target(self):
return supported_models_base.ClipTarget(sdxl_clip.SDXLTokenizer, sdxl_clip.SDXLClipModel)
models = [SD15, SD20, SD21UnclipL, SD21UnclipH, SDXLRefiner, SDXL]

65
comfy/supported_models_base.py

@ -0,0 +1,65 @@
import torch
from . import model_base
from . import utils
def state_dict_key_replace(state_dict, keys_to_replace):
for x in keys_to_replace:
if x in state_dict:
state_dict[keys_to_replace[x]] = state_dict.pop(x)
return state_dict
def state_dict_prefix_replace(state_dict, replace_prefix):
for rp in replace_prefix:
replace = list(map(lambda a: (a, "{}{}".format(replace_prefix[rp], a[len(rp):])), filter(lambda a: a.startswith(rp), state_dict.keys())))
for x in replace:
state_dict[x[1]] = state_dict.pop(x[0])
return state_dict
class ClipTarget:
def __init__(self, tokenizer, clip):
self.clip = clip
self.tokenizer = tokenizer
self.params = {}
class BASE:
unet_config = {}
unet_extra_config = {
"num_heads": -1,
"num_head_channels": 64,
}
clip_prefix = []
clip_vision_prefix = None
noise_aug_config = None
@classmethod
def matches(s, unet_config):
for k in s.unet_config:
if s.unet_config[k] != unet_config[k]:
return False
return True
def v_prediction(self, state_dict):
return False
def inpaint_model(self):
return self.unet_config["in_channels"] > 4
def __init__(self, unet_config):
self.unet_config = unet_config
for x in self.unet_extra_config:
self.unet_config[x] = self.unet_extra_config[x]
def get_model(self, state_dict):
if self.inpaint_model():
return model_base.SDInpaint(self.unet_config, v_prediction=self.v_prediction(state_dict))
elif self.noise_aug_config is not None:
return model_base.SD21UNCLIP(self.unet_config, self.noise_aug_config, v_prediction=self.v_prediction(state_dict))
else:
return model_base.BaseModel(self.unet_config, v_prediction=self.v_prediction(state_dict))
def process_clip_state_dict(self, state_dict):
return state_dict

16
comfy/utils.py

@ -26,10 +26,10 @@ def load_torch_file(ckpt, safe_load=False):
def transformers_convert(sd, prefix_from, prefix_to, number):
keys_to_replace = {
"{}.positional_embedding": "{}.embeddings.position_embedding.weight",
"{}.token_embedding.weight": "{}.embeddings.token_embedding.weight",
"{}.ln_final.weight": "{}.final_layer_norm.weight",
"{}.ln_final.bias": "{}.final_layer_norm.bias",
"{}positional_embedding": "{}embeddings.position_embedding.weight",
"{}token_embedding.weight": "{}embeddings.token_embedding.weight",
"{}ln_final.weight": "{}final_layer_norm.weight",
"{}ln_final.bias": "{}final_layer_norm.bias",
}
for k in keys_to_replace:
@ -48,19 +48,19 @@ def transformers_convert(sd, prefix_from, prefix_to, number):
for resblock in range(number):
for x in resblock_to_replace:
for y in ["weight", "bias"]:
k = "{}.transformer.resblocks.{}.{}.{}".format(prefix_from, resblock, x, y)
k_to = "{}.encoder.layers.{}.{}.{}".format(prefix_to, resblock, resblock_to_replace[x], y)
k = "{}transformer.resblocks.{}.{}.{}".format(prefix_from, resblock, x, y)
k_to = "{}encoder.layers.{}.{}.{}".format(prefix_to, resblock, resblock_to_replace[x], y)
if k in sd:
sd[k_to] = sd.pop(k)
for y in ["weight", "bias"]:
k_from = "{}.transformer.resblocks.{}.attn.in_proj_{}".format(prefix_from, resblock, y)
k_from = "{}transformer.resblocks.{}.attn.in_proj_{}".format(prefix_from, resblock, y)
if k_from in sd:
weights = sd.pop(k_from)
shape_from = weights.shape[0] // 3
for x in range(3):
p = ["self_attn.q_proj", "self_attn.k_proj", "self_attn.v_proj"]
k_to = "{}.encoder.layers.{}.{}.{}".format(prefix_to, resblock, p[x], y)
k_to = "{}encoder.layers.{}.{}.{}".format(prefix_to, resblock, p[x], y)
sd[k_to] = weights[shape_from*x:shape_from*(x + 1)]
return sd

6
nodes.py

@ -48,7 +48,9 @@ class CLIPTextEncode:
CATEGORY = "conditioning"
def encode(self, clip, text):
return ([[clip.encode(text), {}]], )
tokens = clip.tokenize(text)
cond, pooled = clip.encode_from_tokens(tokens, return_pooled=True)
return ([[cond, {"pooled_output": pooled}]], )
class ConditioningCombine:
@classmethod
@ -1344,7 +1346,7 @@ NODE_CLASS_MAPPINGS = {
"DiffusersLoader": DiffusersLoader,
"LoadLatent": LoadLatent,
"SaveLatent": SaveLatent
"SaveLatent": SaveLatent,
}
NODE_DISPLAY_NAME_MAPPINGS = {

Loading…
Cancel
Save