|
|
|
@ -99,6 +99,10 @@ class BaseModel(torch.nn.Module):
|
|
|
|
|
if self.get_dtype() == torch.float16: |
|
|
|
|
clip_state_dict = utils.convert_sd_to(clip_state_dict, torch.float16) |
|
|
|
|
vae_state_dict = utils.convert_sd_to(vae_state_dict, torch.float16) |
|
|
|
|
|
|
|
|
|
if self.model_type == ModelType.V_PREDICTION: |
|
|
|
|
unet_state_dict["v_pred"] = torch.tensor([]) |
|
|
|
|
|
|
|
|
|
return {**unet_state_dict, **vae_state_dict, **clip_state_dict} |
|
|
|
|
|
|
|
|
|
|
|
|
|
|