Browse Source

CLIP code refactor and improvements.

More generic clip model class that can be used on more types of text
encoders.

Don't apply weighting algorithm when weight is 1.0

Don't compute an empty token output when it's not needed.
pull/1932/head
comfyanonymous 1 year ago
parent
commit
656c0b5d90
  1. 80
      comfy/sd1_clip.py
  2. 3
      comfy/sd2_clip.py
  3. 8
      comfy/sdxl_clip.py

80
comfy/sd1_clip.py

@ -8,32 +8,54 @@ import zipfile
from . import model_management from . import model_management
import contextlib import contextlib
def gen_empty_tokens(special_tokens, length):
start_token = special_tokens.get("start", None)
end_token = special_tokens.get("end", None)
pad_token = special_tokens.get("pad")
output = []
if start_token is not None:
output.append(start_token)
if end_token is not None:
output.append(end_token)
output += [pad_token] * (length - len(output))
return output
class ClipTokenWeightEncoder: class ClipTokenWeightEncoder:
def encode_token_weights(self, token_weight_pairs): def encode_token_weights(self, token_weight_pairs):
to_encode = list(self.empty_tokens) to_encode = list()
max_token_len = 0
has_weights = False
for x in token_weight_pairs: for x in token_weight_pairs:
tokens = list(map(lambda a: a[0], x)) tokens = list(map(lambda a: a[0], x))
max_token_len = max(len(tokens), max_token_len)
has_weights = has_weights or not all(map(lambda a: a[1] == 1.0, x))
to_encode.append(tokens) to_encode.append(tokens)
sections = len(to_encode)
if has_weights or sections == 0:
to_encode.append(gen_empty_tokens(self.special_tokens, max_token_len))
out, pooled = self.encode(to_encode) out, pooled = self.encode(to_encode)
z_empty = out[0:1] if pooled is not None:
if pooled.shape[0] > 1: first_pooled = pooled[0:1].cpu()
first_pooled = pooled[1:2]
else: else:
first_pooled = pooled[0:1] first_pooled = pooled
output = [] output = []
for k in range(1, out.shape[0]): for k in range(0, sections):
z = out[k:k+1] z = out[k:k+1]
if has_weights:
z_empty = out[-1]
for i in range(len(z)): for i in range(len(z)):
for j in range(len(z[i])): for j in range(len(z[i])):
weight = token_weight_pairs[k - 1][j][1] weight = token_weight_pairs[k][j][1]
z[i][j] = (z[i][j] - z_empty[0][j]) * weight + z_empty[0][j] if weight != 1.0:
z[i][j] = (z[i][j] - z_empty[j]) * weight + z_empty[j]
output.append(z) output.append(z)
if (len(output) == 0): if (len(output) == 0):
return z_empty.cpu(), first_pooled.cpu() return out[-1:].cpu(), first_pooled
return torch.cat(output, dim=-2).cpu(), first_pooled.cpu() return torch.cat(output, dim=-2).cpu(), first_pooled
class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder): class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
"""Uses the CLIP transformer encoder for text (from huggingface)""" """Uses the CLIP transformer encoder for text (from huggingface)"""
@ -43,37 +65,43 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
"hidden" "hidden"
] ]
def __init__(self, version="openai/clip-vit-large-patch14", device="cpu", max_length=77, def __init__(self, version="openai/clip-vit-large-patch14", device="cpu", max_length=77,
freeze=True, layer="last", layer_idx=None, textmodel_json_config=None, textmodel_path=None, dtype=None): # clip-vit-base-patch32 freeze=True, layer="last", layer_idx=None, textmodel_json_config=None, textmodel_path=None, dtype=None,
special_tokens={"start": 49406, "end": 49407, "pad": 49407},layer_norm_hidden_state=True, config_class=CLIPTextConfig,
model_class=CLIPTextModel, inner_name="text_model"): # clip-vit-base-patch32
super().__init__() super().__init__()
assert layer in self.LAYERS assert layer in self.LAYERS
self.num_layers = 12 self.num_layers = 12
if textmodel_path is not None: if textmodel_path is not None:
self.transformer = CLIPTextModel.from_pretrained(textmodel_path) self.transformer = model_class.from_pretrained(textmodel_path)
else: else:
if textmodel_json_config is None: if textmodel_json_config is None:
textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "sd1_clip_config.json") textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "sd1_clip_config.json")
config = CLIPTextConfig.from_json_file(textmodel_json_config) config = config_class.from_json_file(textmodel_json_config)
self.num_layers = config.num_hidden_layers self.num_layers = config.num_hidden_layers
with comfy.ops.use_comfy_ops(device, dtype): with comfy.ops.use_comfy_ops(device, dtype):
with modeling_utils.no_init_weights(): with modeling_utils.no_init_weights():
self.transformer = CLIPTextModel(config) self.transformer = model_class(config)
self.inner_name = inner_name
if dtype is not None: if dtype is not None:
self.transformer.to(dtype) self.transformer.to(dtype)
self.transformer.text_model.embeddings.token_embedding.to(torch.float32) inner_model = getattr(self.transformer, self.inner_name)
self.transformer.text_model.embeddings.position_embedding.to(torch.float32) if hasattr(inner_model, "embeddings"):
inner_model.embeddings.to(torch.float32)
else:
self.transformer.set_input_embeddings(self.transformer.get_input_embeddings().to(torch.float32))
self.max_length = max_length self.max_length = max_length
if freeze: if freeze:
self.freeze() self.freeze()
self.layer = layer self.layer = layer
self.layer_idx = None self.layer_idx = None
self.empty_tokens = [[49406] + [49407] * 76] self.special_tokens = special_tokens
self.text_projection = torch.nn.Parameter(torch.eye(self.transformer.get_input_embeddings().weight.shape[1])) self.text_projection = torch.nn.Parameter(torch.eye(self.transformer.get_input_embeddings().weight.shape[1]))
self.logit_scale = torch.nn.Parameter(torch.tensor(4.6055)) self.logit_scale = torch.nn.Parameter(torch.tensor(4.6055))
self.enable_attention_masks = False self.enable_attention_masks = False
self.layer_norm_hidden_state = True self.layer_norm_hidden_state = layer_norm_hidden_state
if layer == "hidden": if layer == "hidden":
assert layer_idx is not None assert layer_idx is not None
assert abs(layer_idx) <= self.num_layers assert abs(layer_idx) <= self.num_layers
@ -117,7 +145,7 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
else: else:
print("WARNING: shape mismatch when trying to apply embedding, embedding will be ignored", y.shape[0], current_embeds.weight.shape[1]) print("WARNING: shape mismatch when trying to apply embedding, embedding will be ignored", y.shape[0], current_embeds.weight.shape[1])
while len(tokens_temp) < len(x): while len(tokens_temp) < len(x):
tokens_temp += [self.empty_tokens[0][-1]] tokens_temp += [self.special_tokens["pad"]]
out_tokens += [tokens_temp] out_tokens += [tokens_temp]
n = token_dict_size n = token_dict_size
@ -142,7 +170,7 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
tokens = self.set_up_textual_embeddings(tokens, backup_embeds) tokens = self.set_up_textual_embeddings(tokens, backup_embeds)
tokens = torch.LongTensor(tokens).to(device) tokens = torch.LongTensor(tokens).to(device)
if self.transformer.text_model.final_layer_norm.weight.dtype != torch.float32: if getattr(self.transformer, self.inner_name).final_layer_norm.weight.dtype != torch.float32:
precision_scope = torch.autocast precision_scope = torch.autocast
else: else:
precision_scope = lambda a, b: contextlib.nullcontext(a) precision_scope = lambda a, b: contextlib.nullcontext(a)
@ -168,12 +196,16 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
else: else:
z = outputs.hidden_states[self.layer_idx] z = outputs.hidden_states[self.layer_idx]
if self.layer_norm_hidden_state: if self.layer_norm_hidden_state:
z = self.transformer.text_model.final_layer_norm(z) z = getattr(self.transformer, self.inner_name).final_layer_norm(z)
if hasattr(outputs, "pooler_output"):
pooled_output = outputs.pooler_output.float()
else:
pooled_output = None
pooled_output = outputs.pooler_output if self.text_projection is not None and pooled_output is not None:
if self.text_projection is not None:
pooled_output = pooled_output.float().to(self.text_projection.device) @ self.text_projection.float() pooled_output = pooled_output.float().to(self.text_projection.device) @ self.text_projection.float()
return z.float(), pooled_output.float() return z.float(), pooled_output
def encode(self, tokens): def encode(self, tokens):
return self(tokens) return self(tokens)

3
comfy/sd2_clip.py

@ -9,8 +9,7 @@ class SD2ClipHModel(sd1_clip.SDClipModel):
layer_idx=23 layer_idx=23
textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "sd2_clip_config.json") textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "sd2_clip_config.json")
super().__init__(device=device, freeze=freeze, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, textmodel_path=textmodel_path, dtype=dtype) super().__init__(device=device, freeze=freeze, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, textmodel_path=textmodel_path, dtype=dtype, special_tokens={"start": 49406, "end": 49407, "pad": 0})
self.empty_tokens = [[49406] + [49407] + [0] * 75]
class SD2ClipHTokenizer(sd1_clip.SDTokenizer): class SD2ClipHTokenizer(sd1_clip.SDTokenizer):
def __init__(self, tokenizer_path=None, embedding_directory=None): def __init__(self, tokenizer_path=None, embedding_directory=None):

8
comfy/sdxl_clip.py

@ -9,9 +9,8 @@ class SDXLClipG(sd1_clip.SDClipModel):
layer_idx=-2 layer_idx=-2
textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_config_bigg.json") textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_config_bigg.json")
super().__init__(device=device, freeze=freeze, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, textmodel_path=textmodel_path, dtype=dtype) super().__init__(device=device, freeze=freeze, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, textmodel_path=textmodel_path, dtype=dtype,
self.empty_tokens = [[49406] + [49407] + [0] * 75] special_tokens={"start": 49406, "end": 49407, "pad": 0}, layer_norm_hidden_state=False)
self.layer_norm_hidden_state = False
def load_sd(self, sd): def load_sd(self, sd):
return super().load_sd(sd) return super().load_sd(sd)
@ -38,8 +37,7 @@ class SDXLTokenizer:
class SDXLClipModel(torch.nn.Module): class SDXLClipModel(torch.nn.Module):
def __init__(self, device="cpu", dtype=None): def __init__(self, device="cpu", dtype=None):
super().__init__() super().__init__()
self.clip_l = sd1_clip.SDClipModel(layer="hidden", layer_idx=11, device=device, dtype=dtype) self.clip_l = sd1_clip.SDClipModel(layer="hidden", layer_idx=11, device=device, dtype=dtype, layer_norm_hidden_state=False)
self.clip_l.layer_norm_hidden_state = False
self.clip_g = SDXLClipG(device=device, dtype=dtype) self.clip_g = SDXLClipG(device=device, dtype=dtype)
def clip_layer(self, layer_idx): def clip_layer(self, layer_idx):

Loading…
Cancel
Save