|
|
|
@ -10,21 +10,29 @@ import contextlib
|
|
|
|
|
|
|
|
|
|
class ClipTokenWeightEncoder: |
|
|
|
|
def encode_token_weights(self, token_weight_pairs): |
|
|
|
|
z_empty, _ = self.encode(self.empty_tokens) |
|
|
|
|
output = [] |
|
|
|
|
first_pooled = None |
|
|
|
|
to_encode = list(self.empty_tokens) |
|
|
|
|
for x in token_weight_pairs: |
|
|
|
|
tokens = [list(map(lambda a: a[0], x))] |
|
|
|
|
z, pooled = self.encode(tokens) |
|
|
|
|
if first_pooled is None: |
|
|
|
|
first_pooled = pooled |
|
|
|
|
tokens = list(map(lambda a: a[0], x)) |
|
|
|
|
to_encode.append(tokens) |
|
|
|
|
|
|
|
|
|
out, pooled = self.encode(to_encode) |
|
|
|
|
z_empty = out[0:1] |
|
|
|
|
if pooled.shape[0] > 1: |
|
|
|
|
first_pooled = pooled[1:2] |
|
|
|
|
else: |
|
|
|
|
first_pooled = pooled[0:1] |
|
|
|
|
|
|
|
|
|
output = [] |
|
|
|
|
for i in range(1, out.shape[0]): |
|
|
|
|
z = out[i:i+1] |
|
|
|
|
for i in range(len(z)): |
|
|
|
|
for j in range(len(z[i])): |
|
|
|
|
weight = x[j][1] |
|
|
|
|
weight = token_weight_pairs[i - 1][j][1] |
|
|
|
|
z[i][j] = (z[i][j] - z_empty[0][j]) * weight + z_empty[0][j] |
|
|
|
|
output += [z] |
|
|
|
|
output.append(z) |
|
|
|
|
|
|
|
|
|
if (len(output) == 0): |
|
|
|
|
return self.encode(self.empty_tokens) |
|
|
|
|
return z_empty, first_pooled |
|
|
|
|
return torch.cat(output, dim=-2).cpu(), first_pooled.cpu() |
|
|
|
|
|
|
|
|
|
class SD1ClipModel(torch.nn.Module, ClipTokenWeightEncoder): |
|
|
|
|