diff --git a/execution.py b/execution.py index 53ba2e0f..260a0897 100644 --- a/execution.py +++ b/execution.py @@ -268,11 +268,14 @@ def recursive_output_delete_if_changed(prompt, old_prompt, outputs, current_item class PromptExecutor: def __init__(self, server): + self.server = server + self.reset() + + def reset(self): self.outputs = {} self.object_storage = {} self.outputs_ui = {} self.old_prompt = {} - self.server = server def handle_execution_error(self, prompt_id, prompt, current_outputs, executed, error, ex): node_id = error["node_id"] @@ -706,6 +709,7 @@ class PromptQueue: self.queue = [] self.currently_running = {} self.history = {} + self.flags = {} server.prompt_queue = self def put(self, item): @@ -792,3 +796,17 @@ class PromptQueue: def delete_history_item(self, id_to_delete): with self.mutex: self.history.pop(id_to_delete, None) + + def set_flag(self, name, data): + with self.mutex: + self.flags[name] = data + self.not_empty.notify() + + def get_flags(self, reset=True): + with self.mutex: + if reset: + ret = self.flags + self.flags = {} + return ret + else: + return self.flags.copy() diff --git a/main.py b/main.py index bcf73872..45d5d41b 100644 --- a/main.py +++ b/main.py @@ -97,7 +97,7 @@ def prompt_worker(q, server): gc_collect_interval = 10.0 while True: - timeout = None + timeout = 1000.0 if need_gc: timeout = max(gc_collect_interval - (current_time - last_gc_collect), 0.0) @@ -118,6 +118,19 @@ def prompt_worker(q, server): execution_time = current_time - execution_start_time print("Prompt executed in {:.2f} seconds".format(execution_time)) + flags = q.get_flags() + free_memory = flags.get("free_memory", False) + + if flags.get("unload_models", free_memory): + comfy.model_management.unload_all_models() + need_gc = True + last_gc_collect = 0 + + if free_memory: + e.reset() + need_gc = True + last_gc_collect = 0 + if need_gc: current_time = time.perf_counter() if (current_time - last_gc_collect) > gc_collect_interval: diff --git a/server.py b/server.py index bd6f026b..acb8875d 100644 --- a/server.py +++ b/server.py @@ -507,6 +507,17 @@ class PromptServer(): nodes.interrupt_processing() return web.Response(status=200) + @routes.post("/free") + async def post_interrupt(request): + json_data = await request.json() + unload_models = json_data.get("unload_models", False) + free_memory = json_data.get("free_memory", False) + if unload_models: + self.prompt_queue.set_flag("unload_models", unload_models) + if free_memory: + self.prompt_queue.set_flag("free_memory", free_memory) + return web.Response(status=200) + @routes.post("/history") async def post_history(request): json_data = await request.json()