diff --git a/comfy/sd.py b/comfy/sd.py index 3568a2aa..20d00952 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -564,9 +564,6 @@ class CLIP: n.layer_idx = self.layer_idx return n - def load_from_state_dict(self, sd): - self.cond_stage_model.load_sd(sd) - def add_patches(self, patches, strength_patch=1.0, strength_model=1.0): return self.patcher.add_patches(patches, strength_patch, strength_model) diff --git a/comfy/sd1_clip.py b/comfy/sd1_clip.py index 4616ca4e..477d5c30 100644 --- a/comfy/sd1_clip.py +++ b/comfy/sd1_clip.py @@ -66,7 +66,9 @@ class SD1ClipModel(torch.nn.Module, ClipTokenWeightEncoder): self.layer = layer self.layer_idx = None self.empty_tokens = [[49406] + [49407] * 76] - self.text_projection = None + self.text_projection = torch.nn.Parameter(torch.eye(self.transformer.get_input_embeddings().weight.shape[1])) + self.logit_scale = torch.nn.Parameter(torch.tensor(4.6055)) + self.layer_norm_hidden_state = True if layer == "hidden": assert layer_idx is not None @@ -163,6 +165,10 @@ class SD1ClipModel(torch.nn.Module, ClipTokenWeightEncoder): return self(tokens) def load_sd(self, sd): + if "text_projection" in sd: + self.text_projection[:] = sd.pop("text_projection") + if "text_projection.weight" in sd: + self.text_projection[:] = sd.pop("text_projection.weight").transpose(0, 1) return self.transformer.load_state_dict(sd, strict=False) def parse_parentheses(string): diff --git a/comfy/sd2_clip_config.json b/comfy/sd2_clip_config.json index ace6ef00..85cec832 100644 --- a/comfy/sd2_clip_config.json +++ b/comfy/sd2_clip_config.json @@ -17,7 +17,7 @@ "num_attention_heads": 16, "num_hidden_layers": 24, "pad_token_id": 1, - "projection_dim": 512, + "projection_dim": 1024, "torch_dtype": "float32", "vocab_size": 49408 } diff --git a/comfy/sdxl_clip.py b/comfy/sdxl_clip.py index d05c0a9b..e3ac2ee0 100644 --- a/comfy/sdxl_clip.py +++ b/comfy/sdxl_clip.py @@ -11,15 +11,9 @@ class SDXLClipG(sd1_clip.SD1ClipModel): textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_config_bigg.json") super().__init__(device=device, freeze=freeze, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, textmodel_path=textmodel_path, dtype=dtype) self.empty_tokens = [[49406] + [49407] + [0] * 75] - self.text_projection = torch.nn.Parameter(torch.empty(1280, 1280)) - self.logit_scale = torch.nn.Parameter(torch.tensor(4.6055)) self.layer_norm_hidden_state = False def load_sd(self, sd): - if "text_projection" in sd: - self.text_projection[:] = sd.pop("text_projection") - if "text_projection.weight" in sd: - self.text_projection[:] = sd.pop("text_projection.weight").transpose(0, 1) return super().load_sd(sd) class SDXLClipGTokenizer(sd1_clip.SD1Tokenizer):