
Support people putting commas after the embedding name in the prompt.

comfyanonymous committed 2 years ago · commit 137ae2606c
comfy/sd1_clip.py · 13 changed lines (+11 −2)

@@ -178,7 +178,6 @@ def load_embed(embedding_name, embedding_directory):
                 valid_file = t
                 break
         if valid_file is None:
-            print("warning, embedding {} does not exist, ignoring".format(embed_path))
             return None
         else:
             embed_path = valid_file
@@ -218,18 +217,28 @@ class SD1Tokenizer:
         tokens = []
         for t in parsed_weights:
             to_tokenize = unescape_important(t[0]).replace("\n", " ").split(' ')
-            for word in to_tokenize:
+            while len(to_tokenize) > 0:
+                word = to_tokenize.pop(0)
                 temp_tokens = []
                 embedding_identifier = "embedding:"
                 if word.startswith(embedding_identifier) and self.embedding_directory is not None:
                     embedding_name = word[len(embedding_identifier):].strip('\n')
                     embed = load_embed(embedding_name, self.embedding_directory)
+                    if embed is None:
+                        stripped = embedding_name.strip(',')
+                        if len(stripped) < len(embedding_name):
+                            embed = load_embed(stripped, self.embedding_directory)
+                            if embed is not None:
+                                to_tokenize.insert(0, embedding_name[len(stripped):])
                     if embed is not None:
                         if len(embed.shape) == 1:
                             temp_tokens += [(embed, t[1])]
                         else:
                             for x in range(embed.shape[0]):
                                 temp_tokens += [(embed[x], t[1])]
+                    else:
+                        print("warning, embedding:{} does not exist, ignoring".format(embedding_name))
                 elif len(word) > 0:
                     tt = self.tokenizer(word)["input_ids"][1:-1]
                     for x in tt:
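
For illustration, a minimal standalone sketch of the new comma handling (not part of the commit): load_embed is stubbed out, and resolve_words and the embedding name "myembed" are hypothetical stand-ins for the real tokenizer context in comfy/sd1_clip.py. When an embedding lookup fails, the loop now retries with commas stripped and re-queues the stripped characters so they get tokenized as ordinary text:

# Sketch of the comma handling introduced by this commit.
# load_embed is a stub for the real loader in comfy/sd1_clip.py, which
# returns the embedding tensor, or None when no matching file exists.
def load_embed(embedding_name, embedding_directory):
    known = {"myembed"}  # hypothetical embedding file name
    if embedding_name in known:
        return "<tensor:{}>".format(embedding_name)
    return None

def resolve_words(prompt, embedding_directory="embeddings"):
    to_tokenize = prompt.replace("\n", " ").split(' ')
    resolved = []
    while len(to_tokenize) > 0:
        word = to_tokenize.pop(0)
        embedding_identifier = "embedding:"
        if word.startswith(embedding_identifier):
            embedding_name = word[len(embedding_identifier):].strip('\n')
            embed = load_embed(embedding_name, embedding_directory)
            if embed is None:
                # New in this commit: strip commas and retry; on success,
                # push the stripped characters back onto the queue so the
                # next iteration tokenizes them as plain text.
                stripped = embedding_name.strip(',')
                if len(stripped) < len(embedding_name):
                    embed = load_embed(stripped, embedding_directory)
                    if embed is not None:
                        to_tokenize.insert(0, embedding_name[len(stripped):])
            if embed is not None:
                resolved.append(embed)
            else:
                print("warning, embedding:{} does not exist, ignoring".format(embedding_name))
        elif len(word) > 0:
            resolved.append(word)
    return resolved

print(resolve_words("a photo of embedding:myembed, on a beach"))
# -> ['a', 'photo', 'of', '<tensor:myembed>', ',', 'on', 'a', 'beach']

Before this change, "embedding:myembed," failed to match any embedding file and was dropped with the missing-embedding warning; now the embedding resolves and the trailing comma is tokenized on its own.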
