diff --git a/folder_paths.py b/folder_paths.py index af56a6da..f13e4895 100644 --- a/folder_paths.py +++ b/folder_paths.py @@ -27,6 +27,40 @@ folder_names_and_paths["embeddings"] = ([os.path.join(models_dir, "embeddings")] folder_names_and_paths["controlnet"] = ([os.path.join(models_dir, "controlnet"), os.path.join(models_dir, "t2i_adapter")], supported_pt_extensions) folder_names_and_paths["upscale_models"] = ([os.path.join(models_dir, "upscale_models")], supported_pt_extensions) +output_directory = os.path.join(os.path.dirname(os.path.realpath(__file__)), "output") +temp_directory = os.path.join(os.path.dirname(os.path.realpath(__file__)), "temp") +input_directory = os.path.join(os.path.dirname(os.path.realpath(__file__)), "input") + +if not os.path.exists(input_directory): + os.makedirs(input_directory) + +def set_output_directory(output_dir): + global output_directory + output_directory = output_dir + +def get_output_directory(): + global output_directory + return output_directory + +def get_temp_directory(): + global temp_directory + return temp_directory + +def get_input_directory(): + global input_directory + return input_directory + + +#NOTE: used in http server so don't put folders that should not be accessed remotely +def get_directory_by_type(type_name): + if type_name == "output": + return get_output_directory() + if type_name == "temp": + return get_temp_directory() + if type_name == "input": + return get_input_directory() + return None + def add_model_folder_path(folder_name, full_folder_path): global folder_names_and_paths diff --git a/main.py b/main.py index fbfaf6be..a3549b86 100644 --- a/main.py +++ b/main.py @@ -17,6 +17,7 @@ if __name__ == "__main__": print("\t--port 8188\t\t\tSet the listen port.") print() print("\t--extra-model-paths-config file.yaml\tload an extra_model_paths.yaml file.") + print("\t--output-directory path/to/output\tSet the ComfyUI output directory.") print() print() print("\t--dont-upcast-attention\t\tDisable upcasting of attention \n\t\t\t\t\tcan boost speed but increase the chances of black images.\n") @@ -134,6 +135,14 @@ if __name__ == "__main__": for i in indices: load_extra_path_config(sys.argv[i]) + try: + output_dir = sys.argv[sys.argv.index('--output-directory') + 1] + output_dir = os.path.abspath(output_dir) + print("setting output directory to:", output_dir) + folder_paths.set_output_directory(output_dir) + except: + pass + port = 8188 try: p_index = sys.argv.index('--port') diff --git a/nodes.py b/nodes.py index 935e28b8..187d54a1 100644 --- a/nodes.py +++ b/nodes.py @@ -777,7 +777,7 @@ class KSamplerAdvanced: class SaveImage: def __init__(self): - self.output_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "output") + self.output_dir = folder_paths.get_output_directory() self.type = "output" @classmethod @@ -829,9 +829,6 @@ class SaveImage: os.makedirs(full_output_folder, exist_ok=True) counter = 1 - if not os.path.exists(self.output_dir): - os.makedirs(self.output_dir) - results = list() for image in images: i = 255. * image.cpu().numpy() @@ -856,7 +853,7 @@ class SaveImage: class PreviewImage(SaveImage): def __init__(self): - self.output_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "temp") + self.output_dir = folder_paths.get_temp_directory() self.type = "temp" @classmethod @@ -867,13 +864,11 @@ class PreviewImage(SaveImage): } class LoadImage: - input_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "input") @classmethod def INPUT_TYPES(s): - if not os.path.exists(s.input_dir): - os.makedirs(s.input_dir) + input_dir = folder_paths.get_input_directory() return {"required": - {"image": (sorted(os.listdir(s.input_dir)), )}, + {"image": (sorted(os.listdir(input_dir)), )}, } CATEGORY = "image" @@ -881,7 +876,8 @@ class LoadImage: RETURN_TYPES = ("IMAGE", "MASK") FUNCTION = "load_image" def load_image(self, image): - image_path = os.path.join(self.input_dir, image) + input_dir = folder_paths.get_input_directory() + image_path = os.path.join(input_dir, image) i = Image.open(image_path) image = i.convert("RGB") image = np.array(image).astype(np.float32) / 255.0 @@ -895,18 +891,19 @@ class LoadImage: @classmethod def IS_CHANGED(s, image): - image_path = os.path.join(s.input_dir, image) + input_dir = folder_paths.get_input_directory() + image_path = os.path.join(input_dir, image) m = hashlib.sha256() with open(image_path, 'rb') as f: m.update(f.read()) return m.digest().hex() class LoadImageMask: - input_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "input") @classmethod def INPUT_TYPES(s): + input_dir = folder_paths.get_input_directory() return {"required": - {"image": (sorted(os.listdir(s.input_dir)), ), + {"image": (sorted(os.listdir(input_dir)), ), "channel": (["alpha", "red", "green", "blue"], ),} } @@ -915,7 +912,8 @@ class LoadImageMask: RETURN_TYPES = ("MASK",) FUNCTION = "load_image" def load_image(self, image, channel): - image_path = os.path.join(self.input_dir, image) + input_dir = folder_paths.get_input_directory() + image_path = os.path.join(input_dir, image) i = Image.open(image_path) mask = None c = channel[0].upper() @@ -930,7 +928,8 @@ class LoadImageMask: @classmethod def IS_CHANGED(s, image, channel): - image_path = os.path.join(s.input_dir, image) + input_dir = folder_paths.get_input_directory() + image_path = os.path.join(input_dir, image) m = hashlib.sha256() with open(image_path, 'rb') as f: m.update(f.read()) diff --git a/server.py b/server.py index 963daeff..840d9a4e 100644 --- a/server.py +++ b/server.py @@ -89,7 +89,7 @@ class PromptServer(): @routes.post("/upload/image") async def upload_image(request): - upload_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "input") + upload_dir = folder_paths.get_input_directory() if not os.path.exists(upload_dir): os.makedirs(upload_dir) @@ -122,10 +122,10 @@ class PromptServer(): async def view_image(request): if "filename" in request.rel_url.query: type = request.rel_url.query.get("type", "output") - if type not in ["output", "input", "temp"]: + output_dir = folder_paths.get_directory_by_type(type) + if output_dir is None: return web.Response(status=400) - output_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), type) if "subfolder" in request.rel_url.query: full_output_dir = os.path.join(output_dir, request.rel_url.query["subfolder"]) if os.path.commonpath((os.path.abspath(full_output_dir), output_dir)) != output_dir: