From af7a49916b9a302c4ae0bf52838d6b20575378db Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Wed, 5 Jul 2023 17:34:45 -0400 Subject: [PATCH] Support loading unet files in diffusers format. --- comfy/diffusers_load.py | 3 +- comfy/model_detection.py | 8 ++-- comfy/sd.py | 69 ++++++++++++++++++++++++++++++++- comfy/supported_models.py | 8 ++-- comfy/supported_models_base.py | 10 ++--- comfy/utils.py | 21 ++++++++++ folder_paths.py | 1 + models/unet/put_unet_files_here | 0 nodes.py | 18 ++++++++- 9 files changed, 123 insertions(+), 15 deletions(-) create mode 100644 models/unet/put_unet_files_here diff --git a/comfy/diffusers_load.py b/comfy/diffusers_load.py index ba04b981..11d94c34 100644 --- a/comfy/diffusers_load.py +++ b/comfy/diffusers_load.py @@ -8,7 +8,8 @@ import os.path as osp import re import torch from safetensors.torch import load_file, save_file -import diffusers_convert +from . import diffusers_convert + def load_diffusers(model_path, fp16=True, output_vae=True, output_clip=True, embedding_directory=None): diffusers_unet_conf = json.load(open(osp.join(model_path, "unet/config.json"))) diff --git a/comfy/model_detection.py b/comfy/model_detection.py index edad48b1..cf764e0b 100644 --- a/comfy/model_detection.py +++ b/comfy/model_detection.py @@ -108,11 +108,13 @@ def detect_unet_config(state_dict, key_prefix, use_fp16): unet_config["context_dim"] = context_dim return unet_config - -def model_config_from_unet(state_dict, unet_key_prefix, use_fp16): - unet_config = detect_unet_config(state_dict, unet_key_prefix, use_fp16) +def model_config_from_unet_config(unet_config): for model_config in supported_models.models: if model_config.matches(unet_config): return model_config(unet_config) return None + +def model_config_from_unet(state_dict, unet_key_prefix, use_fp16): + unet_config = detect_unet_config(state_dict, unet_key_prefix, use_fp16) + return model_config_from_unet_config(unet_config) diff --git a/comfy/sd.py b/comfy/sd.py index 360f2962..d3bcf22a 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -1049,7 +1049,7 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o clipvision = clip_vision.load_clipvision_from_sd(sd, model_config.clip_vision_prefix, True) offload_device = model_management.unet_offload_device() - model = model_config.get_model(sd) + model = model_config.get_model(sd, "model.diffusion_model.") model = model.to(offload_device) model.load_model_weights(sd, "model.diffusion_model.") @@ -1073,6 +1073,73 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o return (ModelPatcher(model, load_device=model_management.get_torch_device(), offload_device=offload_device), clip, vae, clipvision) + +def load_unet(unet_path): #load unet in diffusers format + sd = utils.load_torch_file(unet_path) + parameters = calculate_parameters(sd, "") + fp16 = model_management.should_use_fp16(model_params=parameters) + + match = {} + match["context_dim"] = sd["down_blocks.0.attentions.1.transformer_blocks.0.attn2.to_k.weight"].shape[1] + match["model_channels"] = sd["conv_in.weight"].shape[0] + match["in_channels"] = sd["conv_in.weight"].shape[1] + match["adm_in_channels"] = None + if "class_embedding.linear_1.weight" in sd: + match["adm_in_channels"] = sd["class_embedding.linear_1.weight"].shape[1] + + SDXL = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False, + 'num_classes': 'sequential', 'adm_in_channels': 2816, 'use_fp16': fp16, 'in_channels': 4, 'model_channels': 
320, + 'num_res_blocks': 2, 'attention_resolutions': [2, 4], 'transformer_depth': [0, 2, 10], 'channel_mult': [1, 2, 4], + 'transformer_depth_middle': 10, 'use_linear_in_transformer': True, 'context_dim': 2048} + + SDXL_refiner = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False, + 'num_classes': 'sequential', 'adm_in_channels': 2560, 'use_fp16': fp16, 'in_channels': 4, 'model_channels': 384, + 'num_res_blocks': 2, 'attention_resolutions': [2, 4], 'transformer_depth': [0, 4, 4, 0], 'channel_mult': [1, 2, 4, 4], + 'transformer_depth_middle': 4, 'use_linear_in_transformer': True, 'context_dim': 1280} + + SD21 = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False, + 'adm_in_channels': None, 'use_fp16': fp16, 'in_channels': 4, 'model_channels': 320, 'num_res_blocks': 2, + 'attention_resolutions': [1, 2, 4], 'transformer_depth': [1, 1, 1, 0], 'channel_mult': [1, 2, 4, 4], + 'transformer_depth_middle': 1, 'use_linear_in_transformer': True, 'context_dim': 1024} + + SD21_uncliph = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False, + 'num_classes': 'sequential', 'adm_in_channels': 2048, 'use_fp16': True, 'in_channels': 4, 'model_channels': 320, + 'num_res_blocks': 2, 'attention_resolutions': [1, 2, 4], 'transformer_depth': [1, 1, 1, 0], 'channel_mult': [1, 2, 4, 4], + 'transformer_depth_middle': 1, 'use_linear_in_transformer': True, 'context_dim': 1024} + + SD21_unclipl = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False, + 'num_classes': 'sequential', 'adm_in_channels': 1536, 'use_fp16': True, 'in_channels': 4, 'model_channels': 320, + 'num_res_blocks': 2, 'attention_resolutions': [1, 2, 4], 'transformer_depth': [1, 1, 1, 0], 'channel_mult': [1, 2, 4, 4], + 'transformer_depth_middle': 1, 'use_linear_in_transformer': True, 'context_dim': 1024} + + SD15 = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False, + 'adm_in_channels': None, 'use_fp16': True, 'in_channels': 4, 'model_channels': 320, 'num_res_blocks': 2, + 'attention_resolutions': [1, 2, 4], 'transformer_depth': [1, 1, 1, 0], 'channel_mult': [1, 2, 4, 4], + 'transformer_depth_middle': 1, 'use_linear_in_transformer': False, 'context_dim': 768} + + supported_models = [SDXL, SDXL_refiner, SD21, SD15, SD21_uncliph, SD21_unclipl] + print("match", match) + for unet_config in supported_models: + matches = True + for k in match: + if match[k] != unet_config[k]: + matches = False + break + if matches: + diffusers_keys = utils.unet_to_diffusers(unet_config) + new_sd = {} + for k in diffusers_keys: + if k in sd: + new_sd[diffusers_keys[k]] = sd.pop(k) + else: + print(diffusers_keys[k], k) + offload_device = model_management.unet_offload_device() + model_config = model_detection.model_config_from_unet_config(unet_config) + model = model_config.get_model(new_sd, "") + model = model.to(offload_device) + model.load_model_weights(new_sd, "") + return ModelPatcher(model, load_device=model_management.get_torch_device(), offload_device=offload_device) + def save_checkpoint(output_path, model, clip, vae, metadata=None): try: model.patch_model() diff --git a/comfy/supported_models.py b/comfy/supported_models.py index 38a53ca7..b1beee8c 100644 --- a/comfy/supported_models.py +++ b/comfy/supported_models.py @@ -53,9 +53,9 @@ class SD20(supported_models_base.BASE): 
latent_format = latent_formats.SD15 - def v_prediction(self, state_dict): + def v_prediction(self, state_dict, prefix=""): if self.unet_config["in_channels"] == 4: #SD2.0 inpainting models are not v prediction - k = "model.diffusion_model.output_blocks.11.1.transformer_blocks.0.norm1.bias" + k = "{}output_blocks.11.1.transformer_blocks.0.norm1.bias".format(prefix) out = state_dict[k] if torch.std(out, unbiased=False) > 0.09: # not sure how well this will actually work. I guess we will find out. return True @@ -109,7 +109,7 @@ class SDXLRefiner(supported_models_base.BASE): latent_format = latent_formats.SDXL - def get_model(self, state_dict): + def get_model(self, state_dict, prefix=""): return model_base.SDXLRefiner(self) def process_clip_state_dict(self, state_dict): @@ -144,7 +144,7 @@ class SDXL(supported_models_base.BASE): latent_format = latent_formats.SDXL - def get_model(self, state_dict): + def get_model(self, state_dict, prefix=""): return model_base.SDXL(self) def process_clip_state_dict(self, state_dict): diff --git a/comfy/supported_models_base.py b/comfy/supported_models_base.py index 0b0235ca..86dc6706 100644 --- a/comfy/supported_models_base.py +++ b/comfy/supported_models_base.py @@ -41,7 +41,7 @@ class BASE: return False return True - def v_prediction(self, state_dict): + def v_prediction(self, state_dict, prefix=""): return False def inpaint_model(self): @@ -53,13 +53,13 @@ class BASE: for x in self.unet_extra_config: self.unet_config[x] = self.unet_extra_config[x] - def get_model(self, state_dict): + def get_model(self, state_dict, prefix=""): if self.inpaint_model(): - return model_base.SDInpaint(self, v_prediction=self.v_prediction(state_dict)) + return model_base.SDInpaint(self, v_prediction=self.v_prediction(state_dict, prefix)) elif self.noise_aug_config is not None: - return model_base.SD21UNCLIP(self, self.noise_aug_config, v_prediction=self.v_prediction(state_dict)) + return model_base.SD21UNCLIP(self, self.noise_aug_config, v_prediction=self.v_prediction(state_dict, prefix)) else: - return model_base.BaseModel(self, v_prediction=self.v_prediction(state_dict)) + return model_base.BaseModel(self, v_prediction=self.v_prediction(state_dict, prefix)) def process_clip_state_dict(self, state_dict): return state_dict diff --git a/comfy/utils.py b/comfy/utils.py index 25ccd944..c4c15def 100644 --- a/comfy/utils.py +++ b/comfy/utils.py @@ -117,6 +117,23 @@ UNET_MAP_RESNET = { "out_layers.0.bias": "norm2.bias", } +UNET_MAP_BASIC = { + "label_emb.0.0.weight": "class_embedding.linear_1.weight", + "label_emb.0.0.bias": "class_embedding.linear_1.bias", + "label_emb.0.2.weight": "class_embedding.linear_2.weight", + "label_emb.0.2.bias": "class_embedding.linear_2.bias", + "input_blocks.0.0.weight": "conv_in.weight", + "input_blocks.0.0.bias": "conv_in.bias", + "out.0.weight": "conv_norm_out.weight", + "out.0.bias": "conv_norm_out.bias", + "out.2.weight": "conv_out.weight", + "out.2.bias": "conv_out.bias", + "time_embed.0.weight": "time_embedding.linear_1.weight", + "time_embed.0.bias": "time_embedding.linear_1.bias", + "time_embed.2.weight": "time_embedding.linear_2.weight", + "time_embed.2.bias": "time_embedding.linear_2.bias" +} + def unet_to_diffusers(unet_config): num_res_blocks = unet_config["num_res_blocks"] attention_resolutions = unet_config["attention_resolutions"] @@ -185,6 +202,10 @@ def unet_to_diffusers(unet_config): for k in ["weight", "bias"]: diffusers_unet_map["up_blocks.{}.upsamplers.0.conv.{}".format(x, k)] = "output_blocks.{}.{}.conv.{}".format(n, c, k) n += 
1 + + for k in UNET_MAP_BASIC: + diffusers_unet_map[UNET_MAP_BASIC[k]] = k + return diffusers_unet_map def convert_sd_to(state_dict, dtype): diff --git a/folder_paths.py b/folder_paths.py index 2ad1b171..eb7d39b8 100644 --- a/folder_paths.py +++ b/folder_paths.py @@ -14,6 +14,7 @@ folder_names_and_paths["configs"] = ([os.path.join(models_dir, "configs")], [".y folder_names_and_paths["loras"] = ([os.path.join(models_dir, "loras")], supported_pt_extensions) folder_names_and_paths["vae"] = ([os.path.join(models_dir, "vae")], supported_pt_extensions) folder_names_and_paths["clip"] = ([os.path.join(models_dir, "clip")], supported_pt_extensions) +folder_names_and_paths["unet"] = ([os.path.join(models_dir, "unet")], supported_pt_extensions) folder_names_and_paths["clip_vision"] = ([os.path.join(models_dir, "clip_vision")], supported_pt_extensions) folder_names_and_paths["style_models"] = ([os.path.join(models_dir, "style_models")], supported_pt_extensions) folder_names_and_paths["embeddings"] = ([os.path.join(models_dir, "embeddings")], supported_pt_extensions) diff --git a/models/unet/put_unet_files_here b/models/unet/put_unet_files_here new file mode 100644 index 00000000..e69de29b diff --git a/nodes.py b/nodes.py index b22cdfc9..490389ca 100644 --- a/nodes.py +++ b/nodes.py @@ -397,7 +397,7 @@ class DiffusersLoader: RETURN_TYPES = ("MODEL", "CLIP", "VAE") FUNCTION = "load_checkpoint" - CATEGORY = "advanced/loaders" + CATEGORY = "advanced/loaders/deprecated" def load_checkpoint(self, model_path, output_vae=True, output_clip=True): for search_path in folder_paths.get_folder_paths("diffusers"): @@ -552,6 +552,21 @@ class ControlNetApply: c.append(n) return (c, ) +class UNETLoader: + @classmethod + def INPUT_TYPES(s): + return {"required": { "unet_name": (folder_paths.get_filename_list("unet"), ), + }} + RETURN_TYPES = ("MODEL",) + FUNCTION = "load_unet" + + CATEGORY = "advanced/loaders" + + def load_unet(self, unet_name): + unet_path = folder_paths.get_full_path("unet", unet_name) + model = comfy.sd.load_unet(unet_path) + return (model,) + class CLIPLoader: @classmethod def INPUT_TYPES(s): @@ -1371,6 +1386,7 @@ NODE_CLASS_MAPPINGS = { "LatentCrop": LatentCrop, "LoraLoader": LoraLoader, "CLIPLoader": CLIPLoader, + "UNETLoader": UNETLoader, "DualCLIPLoader": DualCLIPLoader, "CLIPVisionEncode": CLIPVisionEncode, "StyleModelApply": StyleModelApply,
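
Usage sketch (illustrative only, not part of the patch): with this change applied, a
diffusers-format unet state dict (e.g. a diffusion_pytorch_model.safetensors) dropped
into models/unet/ can be loaded through the new UNETLoader node or directly via
comfy.sd.load_unet(). The snippet below is a minimal sketch; the file name is a
placeholder, and load_unet() returns only a ModelPatcher for the diffusion model, so
CLIP and VAE still have to come from a regular checkpoint or separate loaders.

    import comfy.sd
    import folder_paths

    # Hypothetical file placed in models/unet/ by the user.
    unet_path = folder_paths.get_full_path("unet", "my_diffusers_unet.safetensors")

    # load_unet() matches the state dict against the known unet configs
    # (SDXL, SDXL refiner, SD2.1, SD2.1 unCLIP variants, SD1.5), remaps the
    # diffusers key names with utils.unet_to_diffusers(), and returns a
    # ModelPatcher (the MODEL object consumed by the sampler nodes); it falls
    # through and returns None if no known config matches.
    model = comfy.sd.load_unet(unet_path)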