diff --git a/comfy/clip_model.py b/comfy/clip_model.py index 9b82a246..de52a118 100644 --- a/comfy/clip_model.py +++ b/comfy/clip_model.py @@ -119,6 +119,9 @@ class CLIPTextModel(torch.nn.Module): super().__init__() self.num_layers = config_dict["num_hidden_layers"] self.text_model = CLIPTextModel_(config_dict, dtype, device, operations) + embed_dim = config_dict["hidden_size"] + self.text_projection = operations.Linear(embed_dim, embed_dim, bias=False, dtype=dtype, device=device) + self.text_projection.weight.copy_(torch.eye(embed_dim)) self.dtype = dtype def get_input_embeddings(self): @@ -128,7 +131,10 @@ class CLIPTextModel(torch.nn.Module): self.text_model.embeddings.token_embedding = embeddings def forward(self, *args, **kwargs): - return self.text_model(*args, **kwargs) + x = self.text_model(*args, **kwargs) + out = self.text_projection(x[2]) + return (x[0], x[1], out) + class CLIPVisionEmbeddings(torch.nn.Module): def __init__(self, embed_dim, num_channels=3, patch_size=14, image_size=224, dtype=None, device=None, operations=None): diff --git a/comfy/sd.py b/comfy/sd.py index 7a77bb17..1be34ea3 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -52,7 +52,7 @@ def load_clip_weights(model, sd): if ids.dtype == torch.float32: sd['cond_stage_model.transformer.text_model.embeddings.position_ids'] = ids.round() - sd = comfy.utils.transformers_convert(sd, "cond_stage_model.model.", "cond_stage_model.transformer.text_model.", 24) + sd = comfy.utils.clip_text_transformers_convert(sd, "cond_stage_model.model.", "cond_stage_model.transformer.") return load_model_weights(model, sd) @@ -361,7 +361,7 @@ def load_clip(ckpt_paths, embedding_directory=None, clip_type=CLIPType.STABLE_DI for i in range(len(clip_data)): if "transformer.resblocks.0.ln_1.weight" in clip_data[i]: - clip_data[i] = comfy.utils.transformers_convert(clip_data[i], "", "text_model.", 32) + clip_data[i] = comfy.utils.clip_text_transformers_convert(clip_data[i], "", "") clip_target = EmptyClass() clip_target.params = {} diff --git a/comfy/sd1_clip.py b/comfy/sd1_clip.py index 8287ad2e..6f0ed300 100644 --- a/comfy/sd1_clip.py +++ b/comfy/sd1_clip.py @@ -86,7 +86,7 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder): self.layer = layer self.layer_idx = None self.special_tokens = special_tokens - self.text_projection = torch.nn.Parameter(torch.eye(self.transformer.get_input_embeddings().weight.shape[1])) + self.logit_scale = torch.nn.Parameter(torch.tensor(4.6055)) self.enable_attention_masks = enable_attention_masks @@ -182,18 +182,12 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder): else: pooled_output = None - if self.text_projection is not None and pooled_output is not None: - pooled_output = pooled_output.float().to(self.text_projection.device) @ self.text_projection.float() return z.float(), pooled_output def encode(self, tokens): return self(tokens) def load_sd(self, sd): - if "text_projection" in sd: - self.text_projection[:] = sd.pop("text_projection") - if "text_projection.weight" in sd: - self.text_projection[:] = sd.pop("text_projection.weight").transpose(0, 1) return self.transformer.load_state_dict(sd, strict=False) def parse_parentheses(string): diff --git a/comfy/supported_models.py b/comfy/supported_models.py index 5bb98d88..dbc3cf26 100644 --- a/comfy/supported_models.py +++ b/comfy/supported_models.py @@ -75,7 +75,7 @@ class SD20(supported_models_base.BASE): replace_prefix["conditioner.embedders.0.model."] = "clip_h." #SD2 in sgm format replace_prefix["cond_stage_model.model."] = "clip_h." state_dict = utils.state_dict_prefix_replace(state_dict, replace_prefix, filter_keys=True) - state_dict = utils.transformers_convert(state_dict, "clip_h.", "clip_h.transformer.text_model.", 24) + state_dict = utils.clip_text_transformers_convert(state_dict, "clip_h.", "clip_h.transformer.") return state_dict def process_clip_state_dict_for_saving(self, state_dict): @@ -134,7 +134,7 @@ class SDXLRefiner(supported_models_base.BASE): replace_prefix["conditioner.embedders.0.model."] = "clip_g." state_dict = utils.state_dict_prefix_replace(state_dict, replace_prefix, filter_keys=True) - state_dict = utils.transformers_convert(state_dict, "clip_g.", "clip_g.transformer.text_model.", 32) + state_dict = utils.clip_text_transformers_convert(state_dict, "clip_g.", "clip_g.transformer.") state_dict = utils.state_dict_key_replace(state_dict, keys_to_replace) return state_dict @@ -182,10 +182,8 @@ class SDXL(supported_models_base.BASE): replace_prefix["conditioner.embedders.1.model."] = "clip_g." state_dict = utils.state_dict_prefix_replace(state_dict, replace_prefix, filter_keys=True) - state_dict = utils.transformers_convert(state_dict, "clip_g.", "clip_g.transformer.text_model.", 32) - keys_to_replace["clip_g.text_projection.weight"] = "clip_g.text_projection" - state_dict = utils.state_dict_key_replace(state_dict, keys_to_replace) + state_dict = utils.clip_text_transformers_convert(state_dict, "clip_g.", "clip_g.transformer.") return state_dict def process_clip_state_dict_for_saving(self, state_dict): @@ -338,6 +336,12 @@ class Stable_Cascade_C(supported_models_base.BASE): state_dict[k_to] = weights[shape_from*x:shape_from*(x + 1)] return state_dict + def process_clip_state_dict(self, state_dict): + state_dict = utils.state_dict_prefix_replace(state_dict, {k: "" for k in self.text_encoder_key_prefix}, filter_keys=True) + if "clip_g.text_projection" in state_dict: + state_dict["clip_g.transformer.text_projection.weight"] = state_dict.pop("clip_g.text_projection").transpose(0, 1) + return state_dict + def get_model(self, state_dict, prefix="", device=None): out = model_base.StableCascade_C(self, device=device) return out diff --git a/comfy/utils.py b/comfy/utils.py index 04cf76ed..c471024d 100644 --- a/comfy/utils.py +++ b/comfy/utils.py @@ -98,8 +98,22 @@ def transformers_convert(sd, prefix_from, prefix_to, number): p = ["self_attn.q_proj", "self_attn.k_proj", "self_attn.v_proj"] k_to = "{}encoder.layers.{}.{}.{}".format(prefix_to, resblock, p[x], y) sd[k_to] = weights[shape_from*x:shape_from*(x + 1)] + + return sd + +def clip_text_transformers_convert(sd, prefix_from, prefix_to): + sd = transformers_convert(sd, prefix_from, "{}text_model.".format(prefix_to), 32) + + tp = "{}text_projection.weight".format(prefix_from) + if tp in sd: + sd["{}text_projection.weight".format(prefix_to)] = sd.pop(tp) + + tp = "{}text_projection".format(prefix_from) + if tp in sd: + sd["{}text_projection.weight".format(prefix_to)] = sd.pop(tp).transpose(0, 1) return sd + UNET_MAP_ATTENTIONS = { "proj_in.weight", "proj_in.bias",