|
|
|
@ -260,60 +260,97 @@ class SD1Tokenizer:
|
|
|
|
|
self.inv_vocab = {v: k for k, v in vocab.items()} |
|
|
|
|
self.embedding_directory = embedding_directory |
|
|
|
|
self.max_word_length = 8 |
|
|
|
|
self.embedding_identifier = "embedding:" |
|
|
|
|
|
|
|
|
|
def _try_get_embedding(self, embedding_name:str): |
|
|
|
|
''' |
|
|
|
|
Takes a potential embedding name and tries to retrieve it. |
|
|
|
|
Returns a Tuple consisting of the embedding and any leftover string, embedding can be None. |
|
|
|
|
''' |
|
|
|
|
embed = load_embed(embedding_name, self.embedding_directory) |
|
|
|
|
if embed is None: |
|
|
|
|
stripped = embedding_name.strip(',') |
|
|
|
|
if len(stripped) < len(embedding_name): |
|
|
|
|
embed = load_embed(stripped, self.embedding_directory) |
|
|
|
|
return (embed, embedding_name[len(stripped):]) |
|
|
|
|
return (embed, "") |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def tokenize_with_weights(self, text:str, return_word_ids=False): |
|
|
|
|
''' |
|
|
|
|
Takes a prompt and converts it to a list of (token, weight, word id) elements. |
|
|
|
|
Tokens can both be integer tokens and pre computed CLIP tensors. |
|
|
|
|
Word id values are unique per word and embedding, where the id 0 is reserved for non word tokens. |
|
|
|
|
Returned list has the dimensions NxM where M is the input size of CLIP |
|
|
|
|
''' |
|
|
|
|
if self.pad_with_end: |
|
|
|
|
pad_token = self.end_token |
|
|
|
|
else: |
|
|
|
|
pad_token = 0 |
|
|
|
|
|
|
|
|
|
def tokenize_with_weights(self, text): |
|
|
|
|
text = escape_important(text) |
|
|
|
|
parsed_weights = token_weights(text, 1.0) |
|
|
|
|
|
|
|
|
|
#tokenize words |
|
|
|
|
tokens = [] |
|
|
|
|
for t in parsed_weights: |
|
|
|
|
to_tokenize = unescape_important(t[0]).replace("\n", " ").split(' ') |
|
|
|
|
while len(to_tokenize) > 0: |
|
|
|
|
word = to_tokenize.pop(0) |
|
|
|
|
temp_tokens = [] |
|
|
|
|
embedding_identifier = "embedding:" |
|
|
|
|
if word.startswith(embedding_identifier) and self.embedding_directory is not None: |
|
|
|
|
embedding_name = word[len(embedding_identifier):].strip('\n') |
|
|
|
|
embed = load_embed(embedding_name, self.embedding_directory) |
|
|
|
|
for weighted_segment, weight in parsed_weights: |
|
|
|
|
to_tokenize = unescape_important(weighted_segment).replace("\n", " ").split(' ') |
|
|
|
|
to_tokenize = [x for x in to_tokenize if x != ""] |
|
|
|
|
for word in to_tokenize: |
|
|
|
|
#if we find an embedding, deal with the embedding |
|
|
|
|
if word.startswith(self.embedding_identifier) and self.embedding_directory is not None: |
|
|
|
|
embedding_name = word[len(self.embedding_identifier):].strip('\n') |
|
|
|
|
embed, leftover = self._try_get_embedding(embedding_name) |
|
|
|
|
if embed is None: |
|
|
|
|
stripped = embedding_name.strip(',') |
|
|
|
|
if len(stripped) < len(embedding_name): |
|
|
|
|
embed = load_embed(stripped, self.embedding_directory) |
|
|
|
|
if embed is not None: |
|
|
|
|
to_tokenize.insert(0, embedding_name[len(stripped):]) |
|
|
|
|
|
|
|
|
|
if embed is not None: |
|
|
|
|
print(f"warning, embedding:{embedding_name} does not exist, ignoring") |
|
|
|
|
else: |
|
|
|
|
if len(embed.shape) == 1: |
|
|
|
|
temp_tokens += [(embed, t[1])] |
|
|
|
|
tokens.append([(embed, weight)]) |
|
|
|
|
else: |
|
|
|
|
for x in range(embed.shape[0]): |
|
|
|
|
temp_tokens += [(embed[x], t[1])] |
|
|
|
|
tokens.append([(embed[x], weight) for x in range(embed.shape[0])]) |
|
|
|
|
#if we accidentally have leftover text, continue parsing using leftover, else move on to next word |
|
|
|
|
if leftover != "": |
|
|
|
|
word = leftover |
|
|
|
|
else: |
|
|
|
|
continue |
|
|
|
|
#parse word |
|
|
|
|
tokens.append([(t, weight) for t in self.tokenizer(word)["input_ids"][1:-1]]) |
|
|
|
|
|
|
|
|
|
#reshape token array to CLIP input size |
|
|
|
|
batched_tokens = [] |
|
|
|
|
batch = [(self.start_token, 1.0, 0)] |
|
|
|
|
batched_tokens.append(batch) |
|
|
|
|
for i, t_group in enumerate(tokens): |
|
|
|
|
#determine if we're going to try and keep the tokens in a single batch |
|
|
|
|
is_large = len(t_group) >= self.max_word_length |
|
|
|
|
|
|
|
|
|
while len(t_group) > 0: |
|
|
|
|
if len(t_group) + len(batch) > self.max_length - 1: |
|
|
|
|
remaining_length = self.max_length - len(batch) - 1 |
|
|
|
|
#break word in two and add end token |
|
|
|
|
if is_large: |
|
|
|
|
batch.extend([(t,w,i+1) for t,w in t_group[:remaining_length]]) |
|
|
|
|
batch.append((self.end_token, 1.0, 0)) |
|
|
|
|
t_group = t_group[remaining_length:] |
|
|
|
|
#add end token and pad |
|
|
|
|
else: |
|
|
|
|
print("warning, embedding:{} does not exist, ignoring".format(embedding_name)) |
|
|
|
|
elif len(word) > 0: |
|
|
|
|
tt = self.tokenizer(word)["input_ids"][1:-1] |
|
|
|
|
for x in tt: |
|
|
|
|
temp_tokens += [(x, t[1])] |
|
|
|
|
tokens_left = self.max_tokens_per_section - (len(tokens) % self.max_tokens_per_section) |
|
|
|
|
|
|
|
|
|
#try not to split words in different sections |
|
|
|
|
if tokens_left < len(temp_tokens) and len(temp_tokens) < (self.max_word_length): |
|
|
|
|
for x in range(tokens_left): |
|
|
|
|
tokens += [(self.end_token, 1.0)] |
|
|
|
|
tokens += temp_tokens |
|
|
|
|
batch.append((self.end_token, 1.0, 0)) |
|
|
|
|
batch.extend([(pad_token, 1.0, 0)] * (remaining_length)) |
|
|
|
|
#start new batch |
|
|
|
|
batch = [(self.start_token, 1.0, 0)] |
|
|
|
|
batched_tokens.append(batch) |
|
|
|
|
else: |
|
|
|
|
batch.extend([(t,w,i+1) for t,w in t_group]) |
|
|
|
|
t_group = [] |
|
|
|
|
|
|
|
|
|
#fill last batch |
|
|
|
|
batch.extend([(self.end_token, 1.0, 0)] + [(pad_token, 1.0, 0)] * (self.max_length - len(batch) - 1)) |
|
|
|
|
|
|
|
|
|
out_tokens = [] |
|
|
|
|
for x in range(0, len(tokens), self.max_tokens_per_section): |
|
|
|
|
o_token = [(self.start_token, 1.0)] + tokens[x:min(self.max_tokens_per_section + x, len(tokens))] |
|
|
|
|
o_token += [(self.end_token, 1.0)] |
|
|
|
|
if self.pad_with_end: |
|
|
|
|
o_token +=[(self.end_token, 1.0)] * (self.max_length - len(o_token)) |
|
|
|
|
else: |
|
|
|
|
o_token +=[(0, 1.0)] * (self.max_length - len(o_token)) |
|
|
|
|
if not return_word_ids: |
|
|
|
|
batched_tokens = [[(t, w) for t, w,_ in x] for x in batched_tokens] |
|
|
|
|
|
|
|
|
|
out_tokens += [o_token] |
|
|
|
|
return batched_tokens |
|
|
|
|
|
|
|
|
|
return out_tokens |
|
|
|
|
|
|
|
|
|
def untokenize(self, token_weight_pair): |
|
|
|
|
return list(map(lambda a: (a, self.inv_vocab[a[0]]), token_weight_pair)) |
|
|
|
|