|
|
|
@ -3,6 +3,7 @@ from comfy.ldm.modules.diffusionmodules.openaimodel import UNetModel
|
|
|
|
|
from comfy.ldm.modules.encoders.noise_aug_modules import CLIPEmbeddingNoiseAugmentation |
|
|
|
|
from comfy.ldm.modules.diffusionmodules.util import make_beta_schedule |
|
|
|
|
from comfy.ldm.modules.diffusionmodules.openaimodel import Timestep |
|
|
|
|
import comfy.model_management |
|
|
|
|
import numpy as np |
|
|
|
|
from enum import Enum |
|
|
|
|
from . import utils |
|
|
|
@ -93,7 +94,11 @@ class BaseModel(torch.nn.Module):
|
|
|
|
|
|
|
|
|
|
def state_dict_for_saving(self, clip_state_dict, vae_state_dict): |
|
|
|
|
clip_state_dict = self.model_config.process_clip_state_dict_for_saving(clip_state_dict) |
|
|
|
|
unet_state_dict = self.diffusion_model.state_dict() |
|
|
|
|
unet_sd = self.diffusion_model.state_dict() |
|
|
|
|
unet_state_dict = {} |
|
|
|
|
for k in unet_sd: |
|
|
|
|
unet_state_dict[k] = comfy.model_management.resolve_lowvram_weight(unet_sd[k], self.diffusion_model, k) |
|
|
|
|
|
|
|
|
|
unet_state_dict = self.model_config.process_unet_state_dict_for_saving(unet_state_dict) |
|
|
|
|
vae_state_dict = self.model_config.process_vae_state_dict_for_saving(vae_state_dict) |
|
|
|
|
if self.get_dtype() == torch.float16: |
|
|
|
|