
Support loading unet files in diffusers format.

Branch: pull/840/head
comfyanonymous committed 1 year ago
Commit: af7a49916b
9 files changed:

  1. comfy/diffusers_load.py (3)
  2. comfy/model_detection.py (8)
  3. comfy/sd.py (69)
  4. comfy/supported_models.py (8)
  5. comfy/supported_models_base.py (10)
  6. comfy/utils.py (21)
  7. folder_paths.py (1)
  8. models/unet/put_unet_files_here (0)
  9. nodes.py (18)

comfy/diffusers_load.py (3)

@@ -8,7 +8,8 @@ import os.path as osp
 import re
 import torch
 from safetensors.torch import load_file, save_file
-import diffusers_convert
+from . import diffusers_convert

 def load_diffusers(model_path, fp16=True, output_vae=True, output_clip=True, embedding_directory=None):
     diffusers_unet_conf = json.load(open(osp.join(model_path, "unet/config.json")))

comfy/model_detection.py (8)

@@ -108,11 +108,13 @@ def detect_unet_config(state_dict, key_prefix, use_fp16):
     unet_config["context_dim"] = context_dim
     return unet_config

-def model_config_from_unet(state_dict, unet_key_prefix, use_fp16):
-    unet_config = detect_unet_config(state_dict, unet_key_prefix, use_fp16)
+def model_config_from_unet_config(unet_config):
     for model_config in supported_models.models:
         if model_config.matches(unet_config):
             return model_config(unet_config)
     return None
+
+def model_config_from_unet(state_dict, unet_key_prefix, use_fp16):
+    unet_config = detect_unet_config(state_dict, unet_key_prefix, use_fp16)
+    return model_config_from_unet_config(unet_config)

comfy/sd.py (69)

@@ -1049,7 +1049,7 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
         clipvision = clip_vision.load_clipvision_from_sd(sd, model_config.clip_vision_prefix, True)

     offload_device = model_management.unet_offload_device()
-    model = model_config.get_model(sd)
+    model = model_config.get_model(sd, "model.diffusion_model.")
     model = model.to(offload_device)
     model.load_model_weights(sd, "model.diffusion_model.")
@@ -1073,6 +1073,73 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
     return (ModelPatcher(model, load_device=model_management.get_torch_device(), offload_device=offload_device), clip, vae, clipvision)
+def load_unet(unet_path): #load unet in diffusers format
+    sd = utils.load_torch_file(unet_path)
+    parameters = calculate_parameters(sd, "")
+    fp16 = model_management.should_use_fp16(model_params=parameters)
+
+    match = {}
+    match["context_dim"] = sd["down_blocks.0.attentions.1.transformer_blocks.0.attn2.to_k.weight"].shape[1]
+    match["model_channels"] = sd["conv_in.weight"].shape[0]
+    match["in_channels"] = sd["conv_in.weight"].shape[1]
+    match["adm_in_channels"] = None
+    if "class_embedding.linear_1.weight" in sd:
+        match["adm_in_channels"] = sd["class_embedding.linear_1.weight"].shape[1]
+
+    SDXL = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
+            'num_classes': 'sequential', 'adm_in_channels': 2816, 'use_fp16': fp16, 'in_channels': 4, 'model_channels': 320,
+            'num_res_blocks': 2, 'attention_resolutions': [2, 4], 'transformer_depth': [0, 2, 10], 'channel_mult': [1, 2, 4],
+            'transformer_depth_middle': 10, 'use_linear_in_transformer': True, 'context_dim': 2048}
+
+    SDXL_refiner = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
+                    'num_classes': 'sequential', 'adm_in_channels': 2560, 'use_fp16': fp16, 'in_channels': 4, 'model_channels': 384,
+                    'num_res_blocks': 2, 'attention_resolutions': [2, 4], 'transformer_depth': [0, 4, 4, 0], 'channel_mult': [1, 2, 4, 4],
+                    'transformer_depth_middle': 4, 'use_linear_in_transformer': True, 'context_dim': 1280}
+
+    SD21 = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
+            'adm_in_channels': None, 'use_fp16': fp16, 'in_channels': 4, 'model_channels': 320, 'num_res_blocks': 2,
+            'attention_resolutions': [1, 2, 4], 'transformer_depth': [1, 1, 1, 0], 'channel_mult': [1, 2, 4, 4],
+            'transformer_depth_middle': 1, 'use_linear_in_transformer': True, 'context_dim': 1024}
+
+    SD21_uncliph = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
+                    'num_classes': 'sequential', 'adm_in_channels': 2048, 'use_fp16': True, 'in_channels': 4, 'model_channels': 320,
+                    'num_res_blocks': 2, 'attention_resolutions': [1, 2, 4], 'transformer_depth': [1, 1, 1, 0], 'channel_mult': [1, 2, 4, 4],
+                    'transformer_depth_middle': 1, 'use_linear_in_transformer': True, 'context_dim': 1024}
+
+    SD21_unclipl = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
+                    'num_classes': 'sequential', 'adm_in_channels': 1536, 'use_fp16': True, 'in_channels': 4, 'model_channels': 320,
+                    'num_res_blocks': 2, 'attention_resolutions': [1, 2, 4], 'transformer_depth': [1, 1, 1, 0], 'channel_mult': [1, 2, 4, 4],
+                    'transformer_depth_middle': 1, 'use_linear_in_transformer': True, 'context_dim': 1024}
+
+    SD15 = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
+            'adm_in_channels': None, 'use_fp16': True, 'in_channels': 4, 'model_channels': 320, 'num_res_blocks': 2,
+            'attention_resolutions': [1, 2, 4], 'transformer_depth': [1, 1, 1, 0], 'channel_mult': [1, 2, 4, 4],
+            'transformer_depth_middle': 1, 'use_linear_in_transformer': False, 'context_dim': 768}
+
+    supported_models = [SDXL, SDXL_refiner, SD21, SD15, SD21_uncliph, SD21_unclipl]
+
+    print("match", match)
+    for unet_config in supported_models:
+        matches = True
+        for k in match:
+            if match[k] != unet_config[k]:
+                matches = False
+                break
+        if matches:
+            diffusers_keys = utils.unet_to_diffusers(unet_config)
+            new_sd = {}
+            for k in diffusers_keys:
+                if k in sd:
+                    new_sd[diffusers_keys[k]] = sd.pop(k)
+                else:
+                    print(diffusers_keys[k], k)
+            offload_device = model_management.unet_offload_device()
+            model_config = model_detection.model_config_from_unet_config(unet_config)
+            model = model_config.get_model(new_sd, "")
+            model = model.to(offload_device)
+            model.load_model_weights(new_sd, "")
+            return ModelPatcher(model, load_device=model_management.get_torch_device(), offload_device=offload_device)
+
 def save_checkpoint(output_path, model, clip, vae, metadata=None):
     try:
         model.patch_model()
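A minimal usage sketch of the new loader (not part of the diff), assuming a UNet weights file saved with diffusers key names has been placed under models/unet/; the file name below is a placeholder:

    import comfy.sd

    # hypothetical path; any UNet weights file in diffusers key layout
    model_patcher = comfy.sd.load_unet("models/unet/diffusion_pytorch_model.safetensors")
    # load_unet() returns a ModelPatcher, the same object type the checkpoint loaders
    # produce for their MODEL output, so it plugs into the usual sampling path.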

comfy/supported_models.py (8)

@@ -53,9 +53,9 @@ class SD20(supported_models_base.BASE):
     latent_format = latent_formats.SD15

-    def v_prediction(self, state_dict):
+    def v_prediction(self, state_dict, prefix=""):
         if self.unet_config["in_channels"] == 4: #SD2.0 inpainting models are not v prediction
-            k = "model.diffusion_model.output_blocks.11.1.transformer_blocks.0.norm1.bias"
+            k = "{}output_blocks.11.1.transformer_blocks.0.norm1.bias".format(prefix)
             out = state_dict[k]
             if torch.std(out, unbiased=False) > 0.09: # not sure how well this will actually work. I guess we will find out.
                 return True
@@ -109,7 +109,7 @@ class SDXLRefiner(supported_models_base.BASE):
     latent_format = latent_formats.SDXL

-    def get_model(self, state_dict):
+    def get_model(self, state_dict, prefix=""):
         return model_base.SDXLRefiner(self)

     def process_clip_state_dict(self, state_dict):
@@ -144,7 +144,7 @@ class SDXL(supported_models_base.BASE):
     latent_format = latent_formats.SDXL

-    def get_model(self, state_dict):
+    def get_model(self, state_dict, prefix=""):
         return model_base.SDXL(self)

     def process_clip_state_dict(self, state_dict):

comfy/supported_models_base.py (10)

@@ -41,7 +41,7 @@ class BASE:
                 return False
         return True

-    def v_prediction(self, state_dict):
+    def v_prediction(self, state_dict, prefix=""):
         return False

     def inpaint_model(self):
@@ -53,13 +53,13 @@ class BASE:
         for x in self.unet_extra_config:
             self.unet_config[x] = self.unet_extra_config[x]

-    def get_model(self, state_dict):
+    def get_model(self, state_dict, prefix=""):
         if self.inpaint_model():
-            return model_base.SDInpaint(self, v_prediction=self.v_prediction(state_dict))
+            return model_base.SDInpaint(self, v_prediction=self.v_prediction(state_dict, prefix))
         elif self.noise_aug_config is not None:
-            return model_base.SD21UNCLIP(self, self.noise_aug_config, v_prediction=self.v_prediction(state_dict))
+            return model_base.SD21UNCLIP(self, self.noise_aug_config, v_prediction=self.v_prediction(state_dict, prefix))
         else:
-            return model_base.BaseModel(self, v_prediction=self.v_prediction(state_dict))
+            return model_base.BaseModel(self, v_prediction=self.v_prediction(state_dict, prefix))

     def process_clip_state_dict(self, state_dict):
         return state_dict
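For illustration (a sketch, not code from the commit): the new prefix argument exists because a full checkpoint stores its UNet weights under the "model.diffusion_model." prefix while a UNet loaded on its own has bare keys, so the key the v-prediction heuristic inspects is now built from whichever prefix the caller passes:

    prefix = "model.diffusion_model."    # full checkpoint, see get_model(sd, "model.diffusion_model.") above
    # prefix = ""                        # bare UNet loaded through load_unet()
    k = "{}output_blocks.11.1.transformer_blocks.0.norm1.bias".format(prefix)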

comfy/utils.py (21)

@@ -117,6 +117,23 @@ UNET_MAP_RESNET = {
     "out_layers.0.bias": "norm2.bias",
 }

+UNET_MAP_BASIC = {
+    "label_emb.0.0.weight": "class_embedding.linear_1.weight",
+    "label_emb.0.0.bias": "class_embedding.linear_1.bias",
+    "label_emb.0.2.weight": "class_embedding.linear_2.weight",
+    "label_emb.0.2.bias": "class_embedding.linear_2.bias",
+    "input_blocks.0.0.weight": "conv_in.weight",
+    "input_blocks.0.0.bias": "conv_in.bias",
+    "out.0.weight": "conv_norm_out.weight",
+    "out.0.bias": "conv_norm_out.bias",
+    "out.2.weight": "conv_out.weight",
+    "out.2.bias": "conv_out.bias",
+    "time_embed.0.weight": "time_embedding.linear_1.weight",
+    "time_embed.0.bias": "time_embedding.linear_1.bias",
+    "time_embed.2.weight": "time_embedding.linear_2.weight",
+    "time_embed.2.bias": "time_embedding.linear_2.bias"
+}
+
 def unet_to_diffusers(unet_config):
     num_res_blocks = unet_config["num_res_blocks"]
     attention_resolutions = unet_config["attention_resolutions"]
@@ -185,6 +202,10 @@ def unet_to_diffusers(unet_config):
                 for k in ["weight", "bias"]:
                     diffusers_unet_map["up_blocks.{}.upsamplers.0.conv.{}".format(x, k)] = "output_blocks.{}.{}.conv.{}".format(n, c, k)
             n += 1

+    for k in UNET_MAP_BASIC:
+        diffusers_unet_map[UNET_MAP_BASIC[k]] = k
+
     return diffusers_unet_map

 def convert_sd_to(state_dict, dtype):
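A short sketch of how the extended mapping is consumed (it mirrors the rename loop in load_unet above; sd and unet_config are hypothetical here): unet_to_diffusers() returns a dict whose keys are diffusers parameter names and whose values are the native UNet names, so converting a diffusers state dict is a single renaming pass:

    diffusers_keys = utils.unet_to_diffusers(unet_config)    # diffusers name -> native name
    new_sd = {}
    for diffusers_name, native_name in diffusers_keys.items():
        if diffusers_name in sd:                              # sd holds diffusers-format weights
            new_sd[native_name] = sd.pop(diffusers_name)
    # e.g. "conv_in.weight" (diffusers) is renamed to "input_blocks.0.0.weight" (native)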

folder_paths.py (1)

@@ -14,6 +14,7 @@ folder_names_and_paths["configs"] = ([os.path.join(models_dir, "configs")], [".y
 folder_names_and_paths["loras"] = ([os.path.join(models_dir, "loras")], supported_pt_extensions)
 folder_names_and_paths["vae"] = ([os.path.join(models_dir, "vae")], supported_pt_extensions)
 folder_names_and_paths["clip"] = ([os.path.join(models_dir, "clip")], supported_pt_extensions)
+folder_names_and_paths["unet"] = ([os.path.join(models_dir, "unet")], supported_pt_extensions)
 folder_names_and_paths["clip_vision"] = ([os.path.join(models_dir, "clip_vision")], supported_pt_extensions)
 folder_names_and_paths["style_models"] = ([os.path.join(models_dir, "style_models")], supported_pt_extensions)
 folder_names_and_paths["embeddings"] = ([os.path.join(models_dir, "embeddings")], supported_pt_extensions)

models/unet/put_unet_files_here (0, new empty placeholder file)

nodes.py (18)

@@ -397,7 +397,7 @@ class DiffusersLoader:
     RETURN_TYPES = ("MODEL", "CLIP", "VAE")
     FUNCTION = "load_checkpoint"

-    CATEGORY = "advanced/loaders"
+    CATEGORY = "advanced/loaders/deprecated"

     def load_checkpoint(self, model_path, output_vae=True, output_clip=True):
         for search_path in folder_paths.get_folder_paths("diffusers"):
@@ -552,6 +552,21 @@ class ControlNetApply:
             c.append(n)
         return (c, )

+class UNETLoader:
+    @classmethod
+    def INPUT_TYPES(s):
+        return {"required": { "unet_name": (folder_paths.get_filename_list("unet"), ),
+                             }}
+    RETURN_TYPES = ("MODEL",)
+    FUNCTION = "load_unet"
+
+    CATEGORY = "advanced/loaders"
+
+    def load_unet(self, unet_name):
+        unet_path = folder_paths.get_full_path("unet", unet_name)
+        model = comfy.sd.load_unet(unet_path)
+        return (model,)
+
 class CLIPLoader:
     @classmethod
     def INPUT_TYPES(s):
@@ -1371,6 +1386,7 @@ NODE_CLASS_MAPPINGS = {
     "LatentCrop": LatentCrop,
     "LoraLoader": LoraLoader,
     "CLIPLoader": CLIPLoader,
+    "UNETLoader": UNETLoader,
     "DualCLIPLoader": DualCLIPLoader,
     "CLIPVisionEncode": CLIPVisionEncode,
     "StyleModelApply": StyleModelApply,
