Browse Source

Fix lowvram model merging.

pull/1344/head
comfyanonymous 1 year ago
parent
commit
a57b0c797b
  1. 7
      comfy/controlnet.py
  2. 7
      comfy/model_base.py
  3. 8
      comfy/model_management.py

7
comfy/controlnet.py

@ -257,12 +257,7 @@ class ControlLora(ControlNet):
cm = self.control_model.state_dict()
for k in sd:
weight = sd[k]
if weight.device == torch.device("meta"): #lowvram NOTE: this depends on the inner working of the accelerate library so it might break.
key_split = k.split('.') # I have no idea why they don't just leave the weight there instead of using the meta device.
op = comfy.utils.get_attr(diffusion_model, '.'.join(key_split[:-1]))
weight = op._hf_hook.weights_map[key_split[-1]]
weight = comfy.model_management.resolve_lowvram_weight(sd[k], diffusion_model, k)
try:
comfy.utils.set_attr(self.control_model, k, weight)
except:

7
comfy/model_base.py

@ -3,6 +3,7 @@ from comfy.ldm.modules.diffusionmodules.openaimodel import UNetModel
from comfy.ldm.modules.encoders.noise_aug_modules import CLIPEmbeddingNoiseAugmentation
from comfy.ldm.modules.diffusionmodules.util import make_beta_schedule
from comfy.ldm.modules.diffusionmodules.openaimodel import Timestep
import comfy.model_management
import numpy as np
from enum import Enum
from . import utils
@ -93,7 +94,11 @@ class BaseModel(torch.nn.Module):
def state_dict_for_saving(self, clip_state_dict, vae_state_dict):
clip_state_dict = self.model_config.process_clip_state_dict_for_saving(clip_state_dict)
unet_state_dict = self.diffusion_model.state_dict()
unet_sd = self.diffusion_model.state_dict()
unet_state_dict = {}
for k in unet_sd:
unet_state_dict[k] = comfy.model_management.resolve_lowvram_weight(unet_sd[k], self.diffusion_model, k)
unet_state_dict = self.model_config.process_unet_state_dict_for_saving(unet_state_dict)
vae_state_dict = self.model_config.process_vae_state_dict_for_saving(vae_state_dict)
if self.get_dtype() == torch.float16:

8
comfy/model_management.py

@ -1,6 +1,7 @@
import psutil
from enum import Enum
from comfy.cli_args import args
import comfy.utils
import torch
import sys
@ -637,6 +638,13 @@ def soft_empty_cache():
torch.cuda.empty_cache()
torch.cuda.ipc_collect()
def resolve_lowvram_weight(weight, model, key):
if weight.device == torch.device("meta"): #lowvram NOTE: this depends on the inner working of the accelerate library so it might break.
key_split = key.split('.') # I have no idea why they don't just leave the weight there instead of using the meta device.
op = comfy.utils.get_attr(model, '.'.join(key_split[:-1]))
weight = op._hf_hook.weights_map[key_split[-1]]
return weight
#TODO: might be cleaner to put this somewhere else
import threading

Loading…
Cancel
Save