diff --git a/comfy/sd1_clip.py b/comfy/sd1_clip.py index 998babe8..93036b1a 100644 --- a/comfy/sd1_clip.py +++ b/comfy/sd1_clip.py @@ -168,19 +168,28 @@ def unescape_important(text): return text def load_embed(embedding_name, embedding_directory): - embed_path = os.path.join(embedding_directory, embedding_name) - if not os.path.isfile(embed_path): - extensions = ['.safetensors', '.pt', '.bin'] - valid_file = None - for x in extensions: - t = embed_path + x - if os.path.isfile(t): - valid_file = t - break - if valid_file is None: - return None + if isinstance(embedding_directory, str): + embedding_directory = [embedding_directory] + + valid_file = None + for embed_dir in embedding_directory: + embed_path = os.path.join(embed_dir, embedding_name) + if not os.path.isfile(embed_path): + extensions = ['.safetensors', '.pt', '.bin'] + for x in extensions: + t = embed_path + x + if os.path.isfile(t): + valid_file = t + break else: - embed_path = valid_file + valid_file = embed_path + if valid_file is not None: + break + + if valid_file is None: + return None + + embed_path = valid_file if embed_path.lower().endswith(".safetensors"): import safetensors.torch diff --git a/folder_paths.py b/folder_paths.py index ba506744..af56a6da 100644 --- a/folder_paths.py +++ b/folder_paths.py @@ -22,7 +22,7 @@ folder_names_and_paths["vae"] = ([os.path.join(models_dir, "vae")], supported_pt folder_names_and_paths["clip"] = ([os.path.join(models_dir, "clip")], supported_pt_extensions) folder_names_and_paths["clip_vision"] = ([os.path.join(models_dir, "clip_vision")], supported_pt_extensions) folder_names_and_paths["style_models"] = ([os.path.join(models_dir, "style_models")], supported_pt_extensions) -# folder_names_and_paths["embeddings"] = ([os.path.join(models_dir, "embeddings")], supported_pt_extensions) +folder_names_and_paths["embeddings"] = ([os.path.join(models_dir, "embeddings")], supported_pt_extensions) folder_names_and_paths["controlnet"] = ([os.path.join(models_dir, "controlnet"), os.path.join(models_dir, "t2i_adapter")], supported_pt_extensions) folder_names_and_paths["upscale_models"] = ([os.path.join(models_dir, "upscale_models")], supported_pt_extensions) @@ -33,6 +33,8 @@ def add_model_folder_path(folder_name, full_folder_path): if folder_name in folder_names_and_paths: folder_names_and_paths[folder_name][0].append(full_folder_path) +def get_folder_paths(folder_name): + return folder_names_and_paths[folder_name][0][:] def recursive_search(directory): result = [] diff --git a/nodes.py b/nodes.py index 0beacee1..7589a0ab 100644 --- a/nodes.py +++ b/nodes.py @@ -188,9 +188,6 @@ class VAEEncodeForInpaint: return ({"samples":t, "noise_mask": (mask_erosion[0][:x,:y].round())}, ) class CheckpointLoader: - models_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "models") - embedding_directory = os.path.join(models_dir, "embeddings") - @classmethod def INPUT_TYPES(s): return {"required": { "config_name": (folder_paths.get_filename_list("configs"), ), @@ -203,7 +200,7 @@ class CheckpointLoader: def load_checkpoint(self, config_name, ckpt_name, output_vae=True, output_clip=True): config_path = folder_paths.get_full_path("configs", config_name) ckpt_path = folder_paths.get_full_path("checkpoints", ckpt_name) - return comfy.sd.load_checkpoint(config_path, ckpt_path, output_vae=True, output_clip=True, embedding_directory=self.embedding_directory) + return comfy.sd.load_checkpoint(config_path, ckpt_path, output_vae=True, output_clip=True, embedding_directory=folder_paths.get_folder_paths("embeddings")) class CheckpointLoaderSimple: @classmethod @@ -217,7 +214,7 @@ class CheckpointLoaderSimple: def load_checkpoint(self, ckpt_name, output_vae=True, output_clip=True): ckpt_path = folder_paths.get_full_path("checkpoints", ckpt_name) - out = comfy.sd.load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, embedding_directory=CheckpointLoader.embedding_directory) + out = comfy.sd.load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, embedding_directory=folder_paths.get_folder_paths("embeddings")) return out class CLIPSetLastLayer: