From f73e57d881bfff3d85cd631c31dfd245e3dfa2f1 Mon Sep 17 00:00:00 2001
From: comfyanonymous
Date: Sun, 29 Jan 2023 18:46:44 -0500
Subject: [PATCH] Add support for textual inversion embedding for SD1.x CLIP.

---
 .gitignore        |  1 +
 README.md         |  4 +
 comfy/sd.py       | 22 +++--
 comfy/sd1_clip.py | 93 +++++++++++++++++--
 ...eddings_or_textual_inversion_concepts_here | 0
 nodes.py          |  3 +-
 6 files changed, 108 insertions(+), 15 deletions(-)
 create mode 100644 models/embeddings/put_embeddings_or_textual_inversion_concepts_here

diff --git a/.gitignore b/.gitignore
index a72d994a..7961356d 100644
--- a/.gitignore
+++ b/.gitignore
@@ -3,3 +3,4 @@ __pycache__/
 output/
 models/checkpoints
 models/vae
+models/embeddings
diff --git a/README.md b/README.md
index df834c1b..4dfebd54 100644
--- a/README.md
+++ b/README.md
@@ -66,6 +66,10 @@ Dragging a generated png on the webpage or loading one will give you the full wo
 
 You can use () to change emphasis of a word or phrase like: (good code:1.2) or (bad code:0.8). The default emphasis for () is 1.1. To use () characters in your actual prompt escape them like \\( or \\).
 
+To use textual inversion concepts/embeddings in a text prompt, put them in the models/embeddings directory and use them in the CLIPTextEncode node like this (you can omit the .pt extension):
+
+```embedding:embedding_filename.pt```
+
 ### Colab Notebook
 
 To run it on colab you can use my [Colab Notebook](notebooks/comfyui_colab.ipynb) here: [Link to open with google colab](https://colab.research.google.com/github/comfyanonymous/ComfyUI/blob/master/notebooks/comfyui_colab.ipynb)
diff --git a/comfy/sd.py b/comfy/sd.py
index 98bb4bdb..13776f1b 100644
--- a/comfy/sd.py
+++ b/comfy/sd.py
@@ -53,19 +53,25 @@ def load_model_from_config(config, ckpt, verbose=False, load_state_dict_to=[]):
 
 
 class CLIP:
-    def __init__(self, config):
+    def __init__(self, config, embedding_directory=None):
         self.target_clip = config["target"]
+        if "params" in config:
+            params = config["params"]
+        else:
+            params = {}
+
+        tokenizer_params = {}
+
         if self.target_clip == "ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder":
             clip = sd2_clip.SD2ClipModel
             tokenizer = sd2_clip.SD2Tokenizer
         elif self.target_clip == "ldm.modules.encoders.modules.FrozenCLIPEmbedder":
             clip = sd1_clip.SD1ClipModel
             tokenizer = sd1_clip.SD1Tokenizer
-        if "params" in config:
-            self.cond_stage_model = clip(**(config["params"]))
-        else:
-            self.cond_stage_model = clip()
-        self.tokenizer = tokenizer()
+            tokenizer_params['embedding_directory'] = embedding_directory
+
+        self.cond_stage_model = clip(**(params))
+        self.tokenizer = tokenizer(**(tokenizer_params))
 
     def encode(self, text):
         tokens = self.tokenizer.tokenize_with_weights(text)
@@ -103,7 +109,7 @@ class VAE:
         return samples
 
 
-def load_checkpoint(config_path, ckpt_path, output_vae=True, output_clip=True):
+def load_checkpoint(config_path, ckpt_path, output_vae=True, output_clip=True, embedding_directory=None):
     config = OmegaConf.load(config_path)
     model_config_params = config['model']['params']
     clip_config = model_config_params['cond_stage_config']
@@ -124,7 +130,7 @@ def load_checkpoint(config_path, ckpt_path, output_vae=True, output_clip=True):
         load_state_dict_to = [w]
 
     if output_clip:
-        clip = CLIP(config=clip_config)
+        clip = CLIP(config=clip_config, embedding_directory=embedding_directory)
         w.cond_stage_model = clip.cond_stage_model
         load_state_dict_to = [w]
 
diff --git a/comfy/sd1_clip.py b/comfy/sd1_clip.py
index 2a881832..4eccdc64 100644
--- a/comfy/sd1_clip.py
+++ b/comfy/sd1_clip.py
@@ -63,9 +63,38 @@ class SD1ClipModel(torch.nn.Module, ClipTokenWeightEncoder):
             self.layer = "hidden"
             self.layer_idx = layer_idx
 
+    def set_up_textual_embeddings(self, tokens, current_embeds):
+        out_tokens = []
+        next_new_token = token_dict_size = current_embeds.weight.shape[0]
+        embedding_weights = []
+
+        for x in tokens:
+            tokens_temp = []
+            for y in x:
+                if isinstance(y, int):
+                    tokens_temp += [y]
+                else:
+                    embedding_weights += [y]
+                    tokens_temp += [next_new_token]
+                    next_new_token += 1
+            out_tokens += [tokens_temp]
+
+        if len(embedding_weights) > 0:
+            new_embedding = torch.nn.Embedding(next_new_token, current_embeds.weight.shape[1])
+            new_embedding.weight[:token_dict_size] = current_embeds.weight[:]
+            n = token_dict_size
+            for x in embedding_weights:
+                new_embedding.weight[n] = x
+                n += 1
+            self.transformer.set_input_embeddings(new_embedding)
+        return out_tokens
+
     def forward(self, tokens):
+        backup_embeds = self.transformer.get_input_embeddings()
+        tokens = self.set_up_textual_embeddings(tokens, backup_embeds)
         tokens = torch.LongTensor(tokens).to(self.device)
         outputs = self.transformer(input_ids=tokens, output_hidden_states=self.layer=="hidden")
+        self.transformer.set_input_embeddings(backup_embeds)
 
         if self.layer == "last":
             z = outputs.last_hidden_state
@@ -138,18 +167,49 @@ def unescape_important(text):
     text = text.replace("\0\2", "(")
     return text
 
+def load_embed(embedding_name, embedding_directory):
+    embed_path = os.path.join(embedding_directory, embedding_name)
+    if not os.path.isfile(embed_path):
+        extensions = ['.safetensors', '.pt', '.bin']
+        valid_file = None
+        for x in extensions:
+            t = embed_path + x
+            if os.path.isfile(t):
+                valid_file = t
+                break
+        if valid_file is None:
+            print("warning, embedding {} does not exist, ignoring".format(embed_path))
+            return None
+        else:
+            embed_path = valid_file
+
+    if embed_path.lower().endswith(".safetensors"):
+        import safetensors.torch
+        embed = safetensors.torch.load_file(embed_path, device="cpu")
+    else:
+        embed = torch.load(embed_path, weights_only=True, map_location="cpu")
+    if 'string_to_param' in embed:
+        values = embed['string_to_param'].values()
+    else:
+        values = embed.values()
+    return next(iter(values))
+
 class SD1Tokenizer:
-    def __init__(self, tokenizer_path=None, max_length=77, pad_with_end=True):
+    def __init__(self, tokenizer_path=None, max_length=77, pad_with_end=True, embedding_directory=None):
         if tokenizer_path is None:
             tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "sd1_tokenizer")
         self.tokenizer = CLIPTokenizer.from_pretrained(tokenizer_path)
         self.max_length = max_length
+        self.max_tokens_per_section = self.max_length - 2
+
         empty = self.tokenizer('')["input_ids"]
         self.start_token = empty[0]
         self.end_token = empty[1]
         self.pad_with_end = pad_with_end
         vocab = self.tokenizer.get_vocab()
         self.inv_vocab = {v: k for k, v in vocab.items()}
+        self.embedding_directory = embedding_directory
+        self.max_word_length = 8
 
     def tokenize_with_weights(self, text):
         text = escape_important(text)
@@ -157,13 +217,34 @@ class SD1Tokenizer:
 
         tokens = []
         for t in parsed_weights:
-            tt = self.tokenizer(unescape_important(t[0]))["input_ids"][1:-1]
-            for x in tt:
-                tokens += [(x, t[1])]
+            to_tokenize = unescape_important(t[0]).split(' ')
+            for word in to_tokenize:
+                temp_tokens = []
+                embedding_identifier = "embedding:"
+                if word.startswith(embedding_identifier) and self.embedding_directory is not None:
+                    embedding_name = word[len(embedding_identifier):].strip('\n')
+                    embed = load_embed(embedding_name, self.embedding_directory)
+                    if embed is not None:
+                        if len(embed.shape) == 1:
+                            temp_tokens += [(embed, t[1])]
+                        else:
+                            for x in range(embed.shape[0]):
+                                temp_tokens += [(embed[x], t[1])]
+                elif len(word) > 0:
+                    tt = self.tokenizer(word)["input_ids"][1:-1]
+                    for x in tt:
+                        temp_tokens += [(x, t[1])]
+                tokens_left = self.max_tokens_per_section - (len(tokens) % self.max_tokens_per_section)
+
+                #try not to split words in different sections
+                if tokens_left < len(temp_tokens) and len(temp_tokens) < (self.max_word_length):
+                    for x in range(tokens_left):
+                        tokens += [(self.end_token, 1.0)]
+                tokens += temp_tokens
 
         out_tokens = []
-        for x in range(0, len(tokens), self.max_length - 2):
-            o_token = [(self.start_token, 1.0)] + tokens[x:min(self.max_length - 2 + x, len(tokens))]
+        for x in range(0, len(tokens), self.max_tokens_per_section):
+            o_token = [(self.start_token, 1.0)] + tokens[x:min(self.max_tokens_per_section + x, len(tokens))]
             o_token += [(self.end_token, 1.0)]
             if self.pad_with_end:
                 o_token +=[(self.end_token, 1.0)] * (self.max_length - len(o_token))
diff --git a/models/embeddings/put_embeddings_or_textual_inversion_concepts_here b/models/embeddings/put_embeddings_or_textual_inversion_concepts_here
new file mode 100644
index 00000000..e69de29b
diff --git a/nodes.py b/nodes.py
index 7e737d0c..585fa80b 100644
--- a/nodes.py
+++ b/nodes.py
@@ -127,7 +127,8 @@ class CheckpointLoader:
     def load_checkpoint(self, config_name, ckpt_name, output_vae=True, output_clip=True):
         config_path = os.path.join(self.config_dir, config_name)
         ckpt_path = os.path.join(self.ckpt_dir, ckpt_name)
-        return comfy.sd.load_checkpoint(config_path, ckpt_path, output_vae=True, output_clip=True)
+        embedding_directory = os.path.join(self.models_dir, "embeddings")
+        return comfy.sd.load_checkpoint(config_path, ckpt_path, output_vae=True, output_clip=True, embedding_directory=embedding_directory)
 
 class VAELoader:
     models_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "models")
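Note on the mechanism introduced in comfy/sd1_clip.py: tokenize_with_weights emits (embedding_tensor, weight) pairs in place of integer token ids whenever a word starts with "embedding:", and set_up_textual_embeddings then rewrites each such entry as a freshly allocated token id backed by a temporarily enlarged embedding table. The sketch below is a minimal standalone illustration of that idea, not code from this patch; the function name expand_embeddings, the toy vocabulary size, and the torch.no_grad() guard around the weight copy are illustrative assumptions.

```python
import torch

def expand_embeddings(token_rows, current_embeds):
    """Rewrite rows that mix int token ids and 1-D tensors so every tensor
    gets a brand-new token id in an enlarged copy of the embedding table."""
    out_rows = []
    next_id = vocab_size = current_embeds.weight.shape[0]
    extra_vectors = []
    for row in token_rows:
        new_row = []
        for item in row:
            if isinstance(item, int):
                new_row.append(item)          # ordinary vocabulary token
            else:
                extra_vectors.append(item)    # textual inversion vector
                new_row.append(next_id)       # give it its own temporary id
                next_id += 1
        out_rows.append(new_row)
    if not extra_vectors:
        return out_rows, current_embeds
    enlarged = torch.nn.Embedding(next_id, current_embeds.weight.shape[1])
    with torch.no_grad():                     # plain weight copy, no autograd needed
        enlarged.weight[:vocab_size] = current_embeds.weight
        for i, vec in enumerate(extra_vectors):
            enlarged.weight[vocab_size + i] = vec
    return out_rows, enlarged

# Toy usage: a 10-token vocabulary with 4-dim embeddings and one injected concept.
base = torch.nn.Embedding(10, 4)
concept = torch.randn(4)
rows, table = expand_embeddings([[1, concept, 2]], base)
print(rows)                                   # [[1, 10, 2]] -- the concept became id 10
print(table(torch.tensor(rows)).shape)        # torch.Size([1, 3, 4])
```

In the patch itself the enlarged table is built inside forward(), installed with self.transformer.set_input_embeddings(new_embedding) for that single pass, and the original table is restored immediately afterwards, so the vectors loaded by load_embed never persist in the model.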