|
|
|
@ -91,13 +91,15 @@ class SD1ClipModel(torch.nn.Module, ClipTokenWeightEncoder):
|
|
|
|
|
|
|
|
|
|
def set_up_textual_embeddings(self, tokens, current_embeds): |
|
|
|
|
out_tokens = [] |
|
|
|
|
next_new_token = token_dict_size = current_embeds.weight.shape[0] |
|
|
|
|
next_new_token = token_dict_size = current_embeds.weight.shape[0] - 1 |
|
|
|
|
embedding_weights = [] |
|
|
|
|
|
|
|
|
|
for x in tokens: |
|
|
|
|
tokens_temp = [] |
|
|
|
|
for y in x: |
|
|
|
|
if isinstance(y, int): |
|
|
|
|
if y == token_dict_size: #EOS token |
|
|
|
|
y = -1 |
|
|
|
|
tokens_temp += [y] |
|
|
|
|
else: |
|
|
|
|
if y.shape[0] == current_embeds.weight.shape[1]: |
|
|
|
@ -110,15 +112,21 @@ class SD1ClipModel(torch.nn.Module, ClipTokenWeightEncoder):
|
|
|
|
|
tokens_temp += [self.empty_tokens[0][-1]] |
|
|
|
|
out_tokens += [tokens_temp] |
|
|
|
|
|
|
|
|
|
if len(embedding_weights) > 0: |
|
|
|
|
new_embedding = torch.nn.Embedding(next_new_token, current_embeds.weight.shape[1], device=current_embeds.weight.device, dtype=current_embeds.weight.dtype) |
|
|
|
|
new_embedding.weight[:token_dict_size] = current_embeds.weight[:] |
|
|
|
|
n = token_dict_size |
|
|
|
|
if len(embedding_weights) > 0: |
|
|
|
|
new_embedding = torch.nn.Embedding(next_new_token + 1, current_embeds.weight.shape[1], device=current_embeds.weight.device, dtype=current_embeds.weight.dtype) |
|
|
|
|
new_embedding.weight[:token_dict_size] = current_embeds.weight[:-1] |
|
|
|
|
for x in embedding_weights: |
|
|
|
|
new_embedding.weight[n] = x |
|
|
|
|
n += 1 |
|
|
|
|
new_embedding.weight[n] = current_embeds.weight[-1] #EOS embedding |
|
|
|
|
self.transformer.set_input_embeddings(new_embedding) |
|
|
|
|
return out_tokens |
|
|
|
|
|
|
|
|
|
processed_tokens = [] |
|
|
|
|
for x in out_tokens: |
|
|
|
|
processed_tokens += [list(map(lambda a: n if a == -1 else a, x))] #The EOS token should always be the largest one |
|
|
|
|
|
|
|
|
|
return processed_tokens |
|
|
|
|
|
|
|
|
|
def forward(self, tokens): |
|
|
|
|
backup_embeds = self.transformer.get_input_embeddings() |
|
|
|
|