Browse Source

Slightly cleaner code.

pull/3/head
comfyanonymous 2 years ago
parent
commit
73f60740c8
  1. 9
      comfy/sd.py

9
comfy/sd.py

@ -32,11 +32,10 @@ def load_model_from_config(config, ckpt, verbose=False, load_state_dict_to=[]):
y = x.replace("cond_stage_model.transformer.", "cond_stage_model.transformer.text_model.") y = x.replace("cond_stage_model.transformer.", "cond_stage_model.transformer.text_model.")
sd[y] = sd.pop(x) sd[y] = sd.pop(x)
try: if 'cond_stage_model.transformer.text_model.embeddings.position_ids' in sd:
if 'cond_stage_model.transformer.text_model.embeddings.position_ids' in sd: ids = sd['cond_stage_model.transformer.text_model.embeddings.position_ids']
sd['cond_stage_model.transformer.text_model.embeddings.position_ids'] = sd['cond_stage_model.transformer.text_model.embeddings.position_ids'].round() if ids.dtype == torch.float32:
except: sd['cond_stage_model.transformer.text_model.embeddings.position_ids'] = ids.round()
pass
for x in load_state_dict_to: for x in load_state_dict_to:
x.load_state_dict(sd, strict=False) x.load_state_dict(sd, strict=False)

Loading…
Cancel
Save