diff --git a/comfy/clip_model.py b/comfy/clip_model.py index de52a118..14f43c56 100644 --- a/comfy/clip_model.py +++ b/comfy/clip_model.py @@ -133,7 +133,7 @@ class CLIPTextModel(torch.nn.Module): def forward(self, *args, **kwargs): x = self.text_model(*args, **kwargs) out = self.text_projection(x[2]) - return (x[0], x[1], out) + return (x[0], x[1], out, x[2]) class CLIPVisionEmbeddings(torch.nn.Module): diff --git a/comfy/lora.py b/comfy/lora.py index 715e0842..21b9897a 100644 --- a/comfy/lora.py +++ b/comfy/lora.py @@ -201,9 +201,9 @@ def model_lora_keys_clip(model, key_map={}): key_map[lora_key] = k - k = "clip_g.text_projection" + k = "clip_g.transformer.text_projection.weight" if k in sdk: - key_map["lora_prior_te_text_projection"] = k #cascade lora + key_map["lora_prior_te_text_projection"] = k #cascade lora? # key_map["text_encoder.text_projection"] = k #TODO: check if other lora have the text_projection too # key_map["lora_te_text_projection"] = k diff --git a/comfy/sd.py b/comfy/sd.py index 1be34ea3..d208abe6 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -123,10 +123,13 @@ class CLIP: return self.tokenizer.tokenize_with_weights(text, return_word_ids) def encode_from_tokens(self, tokens, return_pooled=False): + self.cond_stage_model.reset_clip_options() + if self.layer_idx is not None: - self.cond_stage_model.clip_layer(self.layer_idx) - else: - self.cond_stage_model.reset_clip_layer() + self.cond_stage_model.set_clip_options({"layer": self.layer_idx}) + + if return_pooled == "unprojected": + self.cond_stage_model.set_clip_options({"projected_pooled": False}) self.load_model() cond, pooled = self.cond_stage_model.encode_token_weights(tokens) diff --git a/comfy/sd1_clip.py b/comfy/sd1_clip.py index 6f0ed300..04f3dfd3 100644 --- a/comfy/sd1_clip.py +++ b/comfy/sd1_clip.py @@ -91,11 +91,13 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder): self.enable_attention_masks = enable_attention_masks self.layer_norm_hidden_state = layer_norm_hidden_state + self.return_projected_pooled = True + if layer == "hidden": assert layer_idx is not None assert abs(layer_idx) < self.num_layers - self.clip_layer(layer_idx) - self.layer_default = (self.layer, self.layer_idx) + self.set_clip_options({"layer": layer_idx}) + self.options_default = (self.layer, self.layer_idx, self.return_projected_pooled) def freeze(self): self.transformer = self.transformer.eval() @@ -103,16 +105,19 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder): for param in self.parameters(): param.requires_grad = False - def clip_layer(self, layer_idx): - if abs(layer_idx) > self.num_layers: + def set_clip_options(self, options): + layer_idx = options.get("layer", self.layer_idx) + self.return_projected_pooled = options.get("projected_pooled", self.return_projected_pooled) + if layer_idx is None or abs(layer_idx) > self.num_layers: self.layer = "last" else: self.layer = "hidden" self.layer_idx = layer_idx - def reset_clip_layer(self): - self.layer = self.layer_default[0] - self.layer_idx = self.layer_default[1] + def reset_clip_options(self): + self.layer = self.options_default[0] + self.layer_idx = self.options_default[1] + self.return_projected_pooled = self.options_default[2] def set_up_textual_embeddings(self, tokens, current_embeds): out_tokens = [] @@ -177,10 +182,12 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder): else: z = outputs[1] - if outputs[2] is not None: - pooled_output = outputs[2].float() - else: - pooled_output = None + pooled_output = None + if len(outputs) >= 3: + if not self.return_projected_pooled and len(outputs) >= 4 and outputs[3] is not None: + pooled_output = outputs[3].float() + elif outputs[2] is not None: + pooled_output = outputs[2].float() return z.float(), pooled_output @@ -497,11 +504,11 @@ class SD1ClipModel(torch.nn.Module): self.clip = "clip_{}".format(self.clip_name) setattr(self, self.clip, clip_model(device=device, dtype=dtype, **kwargs)) - def clip_layer(self, layer_idx): - getattr(self, self.clip).clip_layer(layer_idx) + def set_clip_options(self, options): + getattr(self, self.clip).set_clip_options(options) - def reset_clip_layer(self): - getattr(self, self.clip).reset_clip_layer() + def reset_clip_options(self): + getattr(self, self.clip).reset_clip_options() def encode_token_weights(self, token_weight_pairs): token_weight_pairs = token_weight_pairs[self.clip_name] diff --git a/comfy/sdxl_clip.py b/comfy/sdxl_clip.py index 3ce5c7e0..e62d1ed8 100644 --- a/comfy/sdxl_clip.py +++ b/comfy/sdxl_clip.py @@ -40,13 +40,13 @@ class SDXLClipModel(torch.nn.Module): self.clip_l = sd1_clip.SDClipModel(layer="hidden", layer_idx=-2, device=device, dtype=dtype, layer_norm_hidden_state=False) self.clip_g = SDXLClipG(device=device, dtype=dtype) - def clip_layer(self, layer_idx): - self.clip_l.clip_layer(layer_idx) - self.clip_g.clip_layer(layer_idx) + def set_clip_options(self, options): + self.clip_l.set_clip_options(options) + self.clip_g.set_clip_options(options) - def reset_clip_layer(self): - self.clip_g.reset_clip_layer() - self.clip_l.reset_clip_layer() + def reset_clip_options(self): + self.clip_g.reset_clip_options() + self.clip_l.reset_clip_options() def encode_token_weights(self, token_weight_pairs): token_weight_pairs_g = token_weight_pairs["g"] diff --git a/nodes.py b/nodes.py index a577c212..1846797b 100644 --- a/nodes.py +++ b/nodes.py @@ -1003,7 +1003,7 @@ class GLIGENTextBoxApply: def append(self, conditioning_to, clip, gligen_textbox_model, text, width, height, x, y): c = [] - cond, cond_pooled = clip.encode_from_tokens(clip.tokenize(text), return_pooled=True) + cond, cond_pooled = clip.encode_from_tokens(clip.tokenize(text), return_pooled="unprojected") for t in conditioning_to: n = [t[0], t[1].copy()] position_params = [(cond_pooled, height // 8, width // 8, y // 8, x // 8)]