|
|
|
@ -206,6 +206,21 @@ textenc_pattern = re.compile("|".join(protected.keys()))
|
|
|
|
|
# Ordering is from https://github.com/pytorch/pytorch/blob/master/test/cpp/api/modules.cpp |
|
|
|
|
code2idx = {"q": 0, "k": 1, "v": 2} |
|
|
|
|
|
|
|
|
|
# This function exists because at the time of writing torch.cat can't do fp8 with cuda |
|
|
|
|
def cat_tensors(tensors): |
|
|
|
|
x = 0 |
|
|
|
|
for t in tensors: |
|
|
|
|
x += t.shape[0] |
|
|
|
|
|
|
|
|
|
shape = [x] + list(tensors[0].shape)[1:] |
|
|
|
|
out = torch.empty(shape, device=tensors[0].device, dtype=tensors[0].dtype) |
|
|
|
|
|
|
|
|
|
x = 0 |
|
|
|
|
for t in tensors: |
|
|
|
|
out[x:x + t.shape[0]] = t |
|
|
|
|
x += t.shape[0] |
|
|
|
|
|
|
|
|
|
return out |
|
|
|
|
|
|
|
|
|
def convert_text_enc_state_dict_v20(text_enc_dict, prefix=""): |
|
|
|
|
new_state_dict = {} |
|
|
|
@ -249,13 +264,13 @@ def convert_text_enc_state_dict_v20(text_enc_dict, prefix=""):
|
|
|
|
|
if None in tensors: |
|
|
|
|
raise Exception("CORRUPTED MODEL: one of the q-k-v values for the text encoder was missing") |
|
|
|
|
relabelled_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], k_pre) |
|
|
|
|
new_state_dict[relabelled_key + ".in_proj_weight"] = torch.cat(tensors) |
|
|
|
|
new_state_dict[relabelled_key + ".in_proj_weight"] = cat_tensors(tensors) |
|
|
|
|
|
|
|
|
|
for k_pre, tensors in capture_qkv_bias.items(): |
|
|
|
|
if None in tensors: |
|
|
|
|
raise Exception("CORRUPTED MODEL: one of the q-k-v values for the text encoder was missing") |
|
|
|
|
relabelled_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], k_pre) |
|
|
|
|
new_state_dict[relabelled_key + ".in_proj_bias"] = torch.cat(tensors) |
|
|
|
|
new_state_dict[relabelled_key + ".in_proj_bias"] = cat_tensors(tensors) |
|
|
|
|
|
|
|
|
|
return new_state_dict |
|
|
|
|
|
|
|
|
|