|
|
|
@ -40,6 +40,42 @@ def load_model_from_config(config, ckpt, verbose=False, load_state_dict_to=[]):
|
|
|
|
|
if ids.dtype == torch.float32: |
|
|
|
|
sd['cond_stage_model.transformer.text_model.embeddings.position_ids'] = ids.round() |
|
|
|
|
|
|
|
|
|
keys_to_replace = { |
|
|
|
|
"cond_stage_model.model.positional_embedding": "cond_stage_model.transformer.text_model.embeddings.position_embedding.weight", |
|
|
|
|
"cond_stage_model.model.token_embedding.weight": "cond_stage_model.transformer.text_model.embeddings.token_embedding.weight", |
|
|
|
|
"cond_stage_model.model.ln_final.weight": "cond_stage_model.transformer.text_model.final_layer_norm.weight", |
|
|
|
|
"cond_stage_model.model.ln_final.bias": "cond_stage_model.transformer.text_model.final_layer_norm.bias", |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
for x in keys_to_replace: |
|
|
|
|
if x in sd: |
|
|
|
|
sd[keys_to_replace[x]] = sd.pop(x) |
|
|
|
|
|
|
|
|
|
resblock_to_replace = { |
|
|
|
|
"ln_1": "layer_norm1", |
|
|
|
|
"ln_2": "layer_norm2", |
|
|
|
|
"mlp.c_fc": "mlp.fc1", |
|
|
|
|
"mlp.c_proj": "mlp.fc2", |
|
|
|
|
"attn.out_proj": "self_attn.out_proj", |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
for resblock in range(24): |
|
|
|
|
for x in resblock_to_replace: |
|
|
|
|
for y in ["weight", "bias"]: |
|
|
|
|
k = "cond_stage_model.model.transformer.resblocks.{}.{}.{}".format(resblock, x, y) |
|
|
|
|
k_to = "cond_stage_model.transformer.text_model.encoder.layers.{}.{}.{}".format(resblock, resblock_to_replace[x], y) |
|
|
|
|
if k in sd: |
|
|
|
|
sd[k_to] = sd.pop(k) |
|
|
|
|
|
|
|
|
|
for y in ["weight", "bias"]: |
|
|
|
|
k_from = "cond_stage_model.model.transformer.resblocks.{}.attn.in_proj_{}".format(resblock, y) |
|
|
|
|
if k_from in sd: |
|
|
|
|
weights = sd.pop(k_from) |
|
|
|
|
for x in range(3): |
|
|
|
|
p = ["self_attn.q_proj", "self_attn.k_proj", "self_attn.v_proj"] |
|
|
|
|
k_to = "cond_stage_model.transformer.text_model.encoder.layers.{}.{}.{}".format(resblock, p[x], y) |
|
|
|
|
sd[k_to] = weights[1024*x:1024*(x + 1)] |
|
|
|
|
|
|
|
|
|
for x in load_state_dict_to: |
|
|
|
|
x.load_state_dict(sd, strict=False) |
|
|
|
|
|
|
|
|
@ -62,12 +98,6 @@ LORA_CLIP_MAP = {
|
|
|
|
|
"self_attn.out_proj": "self_attn_out_proj", |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
LORA_CLIP2_MAP = { |
|
|
|
|
"mlp.c_fc": "mlp_fc1", |
|
|
|
|
"mlp.c_proj": "mlp_fc2", |
|
|
|
|
"attn.out_proj": "self_attn_out_proj", |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
LORA_UNET_MAP = { |
|
|
|
|
"proj_in": "proj_in", |
|
|
|
|
"proj_out": "proj_out", |
|
|
|
@ -116,7 +146,7 @@ def model_lora_keys(model, key_map={}):
|
|
|
|
|
k = "{}.{}.weight".format(tk, c) |
|
|
|
|
if k in sdk: |
|
|
|
|
lora_key = "lora_unet_down_blocks_{}_attentions_{}_{}".format(counter // 2, counter % 2, LORA_UNET_MAP[c]) |
|
|
|
|
key_map[lora_key] = (k, 0) |
|
|
|
|
key_map[lora_key] = k |
|
|
|
|
up_counter += 1 |
|
|
|
|
if up_counter >= 4: |
|
|
|
|
counter += 1 |
|
|
|
@ -124,7 +154,7 @@ def model_lora_keys(model, key_map={}):
|
|
|
|
|
k = "model.diffusion_model.middle_block.1.{}.weight".format(c) |
|
|
|
|
if k in sdk: |
|
|
|
|
lora_key = "lora_unet_mid_block_attentions_0_{}".format(LORA_UNET_MAP[c]) |
|
|
|
|
key_map[lora_key] = (k, 0) |
|
|
|
|
key_map[lora_key] = k |
|
|
|
|
counter = 3 |
|
|
|
|
for b in range(12): |
|
|
|
|
tk = "model.diffusion_model.output_blocks.{}.1".format(b) |
|
|
|
@ -133,29 +163,18 @@ def model_lora_keys(model, key_map={}):
|
|
|
|
|
k = "{}.{}.weight".format(tk, c) |
|
|
|
|
if k in sdk: |
|
|
|
|
lora_key = "lora_unet_up_blocks_{}_attentions_{}_{}".format(counter // 3, counter % 3, LORA_UNET_MAP[c]) |
|
|
|
|
key_map[lora_key] = (k, 0) |
|
|
|
|
key_map[lora_key] = k |
|
|
|
|
up_counter += 1 |
|
|
|
|
if up_counter >= 4: |
|
|
|
|
counter += 1 |
|
|
|
|
counter = 0 |
|
|
|
|
text_model_lora_key = "lora_te_text_model_encoder_layers_{}_{}" |
|
|
|
|
for b in range(12): |
|
|
|
|
for b in range(24): |
|
|
|
|
for c in LORA_CLIP_MAP: |
|
|
|
|
k = "transformer.text_model.encoder.layers.{}.{}.weight".format(b, c) |
|
|
|
|
if k in sdk: |
|
|
|
|
lora_key = text_model_lora_key.format(b, LORA_CLIP_MAP[c]) |
|
|
|
|
key_map[lora_key] = (k, 0) |
|
|
|
|
for b in range(24): |
|
|
|
|
for c in LORA_CLIP2_MAP: |
|
|
|
|
k = "model.transformer.resblocks.{}.{}.weight".format(b, c) |
|
|
|
|
if k in sdk: |
|
|
|
|
lora_key = text_model_lora_key.format(b, LORA_CLIP2_MAP[c]) |
|
|
|
|
key_map[lora_key] = (k, 0) |
|
|
|
|
k = "model.transformer.resblocks.{}.attn.in_proj_weight".format(b) |
|
|
|
|
if k in sdk: |
|
|
|
|
key_map[text_model_lora_key.format(b, "self_attn_q_proj")] = (k, 0) |
|
|
|
|
key_map[text_model_lora_key.format(b, "self_attn_k_proj")] = (k, 1) |
|
|
|
|
key_map[text_model_lora_key.format(b, "self_attn_v_proj")] = (k, 2) |
|
|
|
|
key_map[lora_key] = k |
|
|
|
|
|
|
|
|
|
return key_map |
|
|
|
|
|
|
|
|
@ -174,7 +193,7 @@ class ModelPatcher:
|
|
|
|
|
p = {} |
|
|
|
|
model_sd = self.model.state_dict() |
|
|
|
|
for k in patches: |
|
|
|
|
if k[0] in model_sd: |
|
|
|
|
if k in model_sd: |
|
|
|
|
p[k] = patches[k] |
|
|
|
|
self.patches += [(strength, p)] |
|
|
|
|
return p.keys() |
|
|
|
@ -184,8 +203,7 @@ class ModelPatcher:
|
|
|
|
|
for p in self.patches: |
|
|
|
|
for k in p[1]: |
|
|
|
|
v = p[1][k] |
|
|
|
|
key = k[0] |
|
|
|
|
index = k[1] |
|
|
|
|
key = k |
|
|
|
|
if key not in model_sd: |
|
|
|
|
print("could not patch. key doesn't exist in model:", k) |
|
|
|
|
continue |
|
|
|
@ -199,10 +217,7 @@ class ModelPatcher:
|
|
|
|
|
mat2 = v[1] |
|
|
|
|
if v[2] is not None: |
|
|
|
|
alpha *= v[2] / mat2.shape[0] |
|
|
|
|
calc = (alpha * torch.mm(mat1.flatten(start_dim=1).float(), mat2.flatten(start_dim=1).float())) |
|
|
|
|
if len(weight.shape) > 2: |
|
|
|
|
calc = calc.reshape(weight.shape) |
|
|
|
|
weight[index * mat1.shape[0]:(index + 1) * mat1.shape[0]] += calc.type(weight.dtype).to(weight.device) |
|
|
|
|
weight += (alpha * torch.mm(mat1.flatten(start_dim=1).float(), mat2.flatten(start_dim=1).float())).reshape(weight.shape).type(weight.dtype).to(weight.device) |
|
|
|
|
return self.model |
|
|
|
|
def unpatch_model(self): |
|
|
|
|
model_sd = self.model.state_dict() |
|
|
|
|