diff --git a/comfy/sd.py b/comfy/sd.py index fdb885bd..98bb4bdb 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -32,11 +32,10 @@ def load_model_from_config(config, ckpt, verbose=False, load_state_dict_to=[]): y = x.replace("cond_stage_model.transformer.", "cond_stage_model.transformer.text_model.") sd[y] = sd.pop(x) - try: - if 'cond_stage_model.transformer.text_model.embeddings.position_ids' in sd: - sd['cond_stage_model.transformer.text_model.embeddings.position_ids'] = sd['cond_stage_model.transformer.text_model.embeddings.position_ids'].round() - except: - pass + if 'cond_stage_model.transformer.text_model.embeddings.position_ids' in sd: + ids = sd['cond_stage_model.transformer.text_model.embeddings.position_ids'] + if ids.dtype == torch.float32: + sd['cond_stage_model.transformer.text_model.embeddings.position_ids'] = ids.round() for x in load_state_dict_to: x.load_state_dict(sd, strict=False)