diff --git a/comfy/sd1_clip.py b/comfy/sd1_clip.py index fdaa1e6c..4761230a 100644 --- a/comfy/sd1_clip.py +++ b/comfy/sd1_clip.py @@ -343,17 +343,24 @@ def load_embed(embedding_name, embedding_directory, embedding_size, embed_key=No return embed_out class SDTokenizer: - def __init__(self, tokenizer_path=None, max_length=77, pad_with_end=True, embedding_directory=None, embedding_size=768, embedding_key='clip_l'): + def __init__(self, tokenizer_path=None, max_length=77, pad_with_end=True, embedding_directory=None, embedding_size=768, embedding_key='clip_l', tokenizer_class=CLIPTokenizer, has_start_token=True, pad_to_max_length=True): if tokenizer_path is None: tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "sd1_tokenizer") - self.tokenizer = CLIPTokenizer.from_pretrained(tokenizer_path) + self.tokenizer = tokenizer_class.from_pretrained(tokenizer_path) self.max_length = max_length - self.max_tokens_per_section = self.max_length - 2 empty = self.tokenizer('')["input_ids"] - self.start_token = empty[0] - self.end_token = empty[1] + if has_start_token: + self.tokens_start = 1 + self.start_token = empty[0] + self.end_token = empty[1] + else: + self.tokens_start = 0 + self.start_token = None + self.end_token = empty[0] self.pad_with_end = pad_with_end + self.pad_to_max_length = pad_to_max_length + vocab = self.tokenizer.get_vocab() self.inv_vocab = {v: k for k, v in vocab.items()} self.embedding_directory = embedding_directory @@ -414,11 +421,13 @@ class SDTokenizer: else: continue #parse word - tokens.append([(t, weight) for t in self.tokenizer(word)["input_ids"][1:-1]]) + tokens.append([(t, weight) for t in self.tokenizer(word)["input_ids"][self.tokens_start:-1]]) #reshape token array to CLIP input size batched_tokens = [] - batch = [(self.start_token, 1.0, 0)] + batch = [] + if self.start_token is not None: + batch.append((self.start_token, 1.0, 0)) batched_tokens.append(batch) for i, t_group in enumerate(tokens): #determine if we're going to try and keep the tokens in a single batch @@ -435,16 +444,21 @@ class SDTokenizer: #add end token and pad else: batch.append((self.end_token, 1.0, 0)) - batch.extend([(pad_token, 1.0, 0)] * (remaining_length)) + if self.pad_to_max_length: + batch.extend([(pad_token, 1.0, 0)] * (remaining_length)) #start new batch - batch = [(self.start_token, 1.0, 0)] + batch = [] + if self.start_token is not None: + batch.append((self.start_token, 1.0, 0)) batched_tokens.append(batch) else: batch.extend([(t,w,i+1) for t,w in t_group]) t_group = [] #fill last batch - batch.extend([(self.end_token, 1.0, 0)] + [(pad_token, 1.0, 0)] * (self.max_length - len(batch) - 1)) + batch.append((self.end_token, 1.0, 0)) + if self.pad_to_max_length: + batch.extend([(pad_token, 1.0, 0)] * (self.max_length - len(batch))) if not return_word_ids: batched_tokens = [[(t, w) for t, w,_ in x] for x in batched_tokens]