|
|
|
@ -32,11 +32,10 @@ def load_model_from_config(config, ckpt, verbose=False, load_state_dict_to=[]):
|
|
|
|
|
y = x.replace("cond_stage_model.transformer.", "cond_stage_model.transformer.text_model.") |
|
|
|
|
sd[y] = sd.pop(x) |
|
|
|
|
|
|
|
|
|
try: |
|
|
|
|
if 'cond_stage_model.transformer.text_model.embeddings.position_ids' in sd: |
|
|
|
|
sd['cond_stage_model.transformer.text_model.embeddings.position_ids'] = sd['cond_stage_model.transformer.text_model.embeddings.position_ids'].round() |
|
|
|
|
except: |
|
|
|
|
pass |
|
|
|
|
if 'cond_stage_model.transformer.text_model.embeddings.position_ids' in sd: |
|
|
|
|
ids = sd['cond_stage_model.transformer.text_model.embeddings.position_ids'] |
|
|
|
|
if ids.dtype == torch.float32: |
|
|
|
|
sd['cond_stage_model.transformer.text_model.embeddings.position_ids'] = ids.round() |
|
|
|
|
|
|
|
|
|
for x in load_state_dict_to: |
|
|
|
|
x.load_state_dict(sd, strict=False) |
|
|
|
|