
Move text projection into the CLIP model code.

Fix an issue where the SSD1B CLIP was not loaded correctly.
pull/2908/head
comfyanonymous committed 9 months ago
commit 1cb3f6a83b
5 changed files:

  1. comfy/clip_model.py (8 changes)
  2. comfy/sd.py (4 changes)
  3. comfy/sd1_clip.py (8 changes)
  4. comfy/supported_models.py (14 changes)
  5. comfy/utils.py (14 changes)

comfy/clip_model.py (8 changes)

@@ -119,6 +119,9 @@ class CLIPTextModel(torch.nn.Module):
         super().__init__()
         self.num_layers = config_dict["num_hidden_layers"]
         self.text_model = CLIPTextModel_(config_dict, dtype, device, operations)
+        embed_dim = config_dict["hidden_size"]
+        self.text_projection = operations.Linear(embed_dim, embed_dim, bias=False, dtype=dtype, device=device)
+        self.text_projection.weight.copy_(torch.eye(embed_dim))
         self.dtype = dtype

     def get_input_embeddings(self):
@@ -128,7 +131,10 @@ class CLIPTextModel(torch.nn.Module):
         self.text_model.embeddings.token_embedding = embeddings

     def forward(self, *args, **kwargs):
-        return self.text_model(*args, **kwargs)
+        x = self.text_model(*args, **kwargs)
+        out = self.text_projection(x[2])
+        return (x[0], x[1], out)

 class CLIPVisionEmbeddings(torch.nn.Module):
     def __init__(self, embed_dim, num_channels=3, patch_size=14, image_size=224, dtype=None, device=None, operations=None):
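Note: the projection weight starts out as the identity, so text encoders whose checkpoints carry no learned text_projection produce exactly the same pooled output as before this change, while checkpoints that do carry one simply overwrite the weight at load time. A minimal sketch of that property in plain torch (the embed_dim value and tensors are made up for illustration and do not use ComfyUI's operations wrapper):

import torch

embed_dim = 8  # illustrative size; CLIP-L uses 768, CLIP-G 1280
proj = torch.nn.Linear(embed_dim, embed_dim, bias=False)
with torch.no_grad():
    proj.weight.copy_(torch.eye(embed_dim))  # identity init, as in the diff above

pooled = torch.randn(2, embed_dim)           # stands in for x[2], the pooled output
assert torch.allclose(proj(pooled), pooled)  # identity projection changes nothing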

comfy/sd.py (4 changes)

@@ -52,7 +52,7 @@ def load_clip_weights(model, sd):
         if ids.dtype == torch.float32:
             sd['cond_stage_model.transformer.text_model.embeddings.position_ids'] = ids.round()

-    sd = comfy.utils.transformers_convert(sd, "cond_stage_model.model.", "cond_stage_model.transformer.text_model.", 24)
+    sd = comfy.utils.clip_text_transformers_convert(sd, "cond_stage_model.model.", "cond_stage_model.transformer.")
     return load_model_weights(model, sd)

@@ -361,7 +361,7 @@ def load_clip(ckpt_paths, embedding_directory=None, clip_type=CLIPType.STABLE_DIFFUSION):
     for i in range(len(clip_data)):
         if "transformer.resblocks.0.ln_1.weight" in clip_data[i]:
-            clip_data[i] = comfy.utils.transformers_convert(clip_data[i], "", "text_model.", 32)
+            clip_data[i] = comfy.utils.clip_text_transformers_convert(clip_data[i], "", "")

     clip_target = EmptyClass()
     clip_target.params = {}

comfy/sd1_clip.py (8 changes)

@@ -86,7 +86,7 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
         self.layer = layer
         self.layer_idx = None
         self.special_tokens = special_tokens
-        self.text_projection = torch.nn.Parameter(torch.eye(self.transformer.get_input_embeddings().weight.shape[1]))
         self.logit_scale = torch.nn.Parameter(torch.tensor(4.6055))
         self.enable_attention_masks = enable_attention_masks

@@ -182,18 +182,12 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
             else:
                 pooled_output = None

-        if self.text_projection is not None and pooled_output is not None:
-            pooled_output = pooled_output.float().to(self.text_projection.device) @ self.text_projection.float()
         return z.float(), pooled_output

     def encode(self, tokens):
         return self(tokens)

     def load_sd(self, sd):
-        if "text_projection" in sd:
-            self.text_projection[:] = sd.pop("text_projection")
-        if "text_projection.weight" in sd:
-            self.text_projection[:] = sd.pop("text_projection.weight").transpose(0, 1)
         return self.transformer.load_state_dict(sd, strict=False)


def parse_parentheses(string):
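For reference on the transposes that show up throughout this commit: the removed code applied the projection as pooled @ W, with W stored in that orientation, while torch.nn.Linear computes x @ weight.T, so the same matrix has to be stored transposed once it lives inside the Linear in CLIPTextModel. This is the same transpose the deleted load_sd branch performed and that the new conversion helper performs. A small sketch with made-up shapes:

import torch

dim = 8
pooled = torch.randn(3, dim)
text_projection = torch.randn(dim, dim)  # old-style parameter, applied as pooled @ W

linear = torch.nn.Linear(dim, dim, bias=False)
with torch.no_grad():
    linear.weight.copy_(text_projection.transpose(0, 1))  # store W transposed for nn.Linear

# The old path (removed above) and the new path (CLIPTextModel.forward) agree.
assert torch.allclose(pooled @ text_projection, linear(pooled), atol=1e-6)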

comfy/supported_models.py (14 changes)

@@ -75,7 +75,7 @@ class SD20(supported_models_base.BASE):
         replace_prefix["conditioner.embedders.0.model."] = "clip_h." #SD2 in sgm format
         replace_prefix["cond_stage_model.model."] = "clip_h."
         state_dict = utils.state_dict_prefix_replace(state_dict, replace_prefix, filter_keys=True)
-        state_dict = utils.transformers_convert(state_dict, "clip_h.", "clip_h.transformer.text_model.", 24)
+        state_dict = utils.clip_text_transformers_convert(state_dict, "clip_h.", "clip_h.transformer.")
         return state_dict

     def process_clip_state_dict_for_saving(self, state_dict):
@@ -134,7 +134,7 @@ class SDXLRefiner(supported_models_base.BASE):
         replace_prefix["conditioner.embedders.0.model."] = "clip_g."
         state_dict = utils.state_dict_prefix_replace(state_dict, replace_prefix, filter_keys=True)
-        state_dict = utils.transformers_convert(state_dict, "clip_g.", "clip_g.transformer.text_model.", 32)
+        state_dict = utils.clip_text_transformers_convert(state_dict, "clip_g.", "clip_g.transformer.")
         state_dict = utils.state_dict_key_replace(state_dict, keys_to_replace)
         return state_dict

@@ -182,10 +182,8 @@ class SDXL(supported_models_base.BASE):
         replace_prefix["conditioner.embedders.1.model."] = "clip_g."
         state_dict = utils.state_dict_prefix_replace(state_dict, replace_prefix, filter_keys=True)
-        state_dict = utils.transformers_convert(state_dict, "clip_g.", "clip_g.transformer.text_model.", 32)
-        keys_to_replace["clip_g.text_projection.weight"] = "clip_g.text_projection"
-        state_dict = utils.state_dict_key_replace(state_dict, keys_to_replace)
+        state_dict = utils.clip_text_transformers_convert(state_dict, "clip_g.", "clip_g.transformer.")
         return state_dict

     def process_clip_state_dict_for_saving(self, state_dict):
@@ -338,6 +336,12 @@ class Stable_Cascade_C(supported_models_base.BASE):
                     state_dict[k_to] = weights[shape_from*x:shape_from*(x + 1)]
         return state_dict

+    def process_clip_state_dict(self, state_dict):
+        state_dict = utils.state_dict_prefix_replace(state_dict, {k: "" for k in self.text_encoder_key_prefix}, filter_keys=True)
+        if "clip_g.text_projection" in state_dict:
+            state_dict["clip_g.transformer.text_projection.weight"] = state_dict.pop("clip_g.text_projection").transpose(0, 1)
+        return state_dict
+
     def get_model(self, state_dict, prefix="", device=None):
         out = model_base.StableCascade_C(self, device=device)
         return out

comfy/utils.py (14 changes)

@@ -98,8 +98,22 @@ def transformers_convert(sd, prefix_from, prefix_to, number):
                     p = ["self_attn.q_proj", "self_attn.k_proj", "self_attn.v_proj"]
                     k_to = "{}encoder.layers.{}.{}.{}".format(prefix_to, resblock, p[x], y)
                     sd[k_to] = weights[shape_from*x:shape_from*(x + 1)]
     return sd

+def clip_text_transformers_convert(sd, prefix_from, prefix_to):
+    sd = transformers_convert(sd, prefix_from, "{}text_model.".format(prefix_to), 32)
+
+    tp = "{}text_projection.weight".format(prefix_from)
+    if tp in sd:
+        sd["{}text_projection.weight".format(prefix_to)] = sd.pop(tp)
+
+    tp = "{}text_projection".format(prefix_from)
+    if tp in sd:
+        sd["{}text_projection.weight".format(prefix_to)] = sd.pop(tp).transpose(0, 1)
+    return sd
+
 UNET_MAP_ATTENTIONS = {
     "proj_in.weight",
     "proj_in.bias",
