From 1b3d65bd84c8026dea234643861491279886218c Mon Sep 17 00:00:00 2001 From: realazthat Date: Thu, 11 Jan 2024 08:38:18 -0500 Subject: [PATCH] Add error, status to /history endpoint --- execution.py | 70 +++++++++++++++++++++++++++++++++------------------- main.py | 7 +++++- 2 files changed, 51 insertions(+), 26 deletions(-) diff --git a/execution.py b/execution.py index 260a0897..554e8dc1 100644 --- a/execution.py +++ b/execution.py @@ -8,6 +8,7 @@ import heapq import traceback import gc import inspect +from typing import List, Literal, NamedTuple, Optional import torch import nodes @@ -275,8 +276,15 @@ class PromptExecutor: self.outputs = {} self.object_storage = {} self.outputs_ui = {} + self.status_notes = [] + self.success = True self.old_prompt = {} + def add_note(self, event, data, broadcast: bool): + self.status_notes.append((event, data)) + if self.server.client_id is not None or broadcast: + self.server.send_sync(event, data, self.server.client_id) + def handle_execution_error(self, prompt_id, prompt, current_outputs, executed, error, ex): node_id = error["node_id"] class_type = prompt[node_id]["class_type"] @@ -290,23 +298,22 @@ class PromptExecutor: "node_type": class_type, "executed": list(executed), } - self.server.send_sync("execution_interrupted", mes, self.server.client_id) + self.add_note("execution_interrupted", mes, broadcast=True) else: - if self.server.client_id is not None: - mes = { - "prompt_id": prompt_id, - "node_id": node_id, - "node_type": class_type, - "executed": list(executed), - - "exception_message": error["exception_message"], - "exception_type": error["exception_type"], - "traceback": error["traceback"], - "current_inputs": error["current_inputs"], - "current_outputs": error["current_outputs"], - } - self.server.send_sync("execution_error", mes, self.server.client_id) + mes = { + "prompt_id": prompt_id, + "node_id": node_id, + "node_type": class_type, + "executed": list(executed), + "exception_message": error["exception_message"], + "exception_type": error["exception_type"], + "traceback": error["traceback"], + "current_inputs": error["current_inputs"], + "current_outputs": error["current_outputs"], + } + self.add_note("execution_error", mes, broadcast=False) + # Next, remove the subsequent outputs since they will not be executed to_delete = [] for o in self.outputs: @@ -327,8 +334,7 @@ class PromptExecutor: else: self.server.client_id = None - if self.server.client_id is not None: - self.server.send_sync("execution_start", { "prompt_id": prompt_id}, self.server.client_id) + self.add_note("execution_start", { "prompt_id": prompt_id}, broadcast=False) with torch.inference_mode(): #delete cached outputs if nodes don't exist for them @@ -361,8 +367,9 @@ class PromptExecutor: del d comfy.model_management.cleanup_models() - if self.server.client_id is not None: - self.server.send_sync("execution_cached", { "nodes": list(current_outputs) , "prompt_id": prompt_id}, self.server.client_id) + self.add_note("execution_cached", + { "nodes": list(current_outputs) , "prompt_id": prompt_id}, + broadcast=False) executed = set() output_node_id = None to_execute = [] @@ -378,8 +385,8 @@ class PromptExecutor: # This call shouldn't raise anything if there's an error deep in # the actual SD code, instead it will report the node where the # error was raised - success, error, ex = recursive_execute(self.server, prompt, self.outputs, output_node_id, extra_data, executed, prompt_id, self.outputs_ui, self.object_storage) - if success is not True: + self.success, error, ex = recursive_execute(self.server, prompt, self.outputs, output_node_id, extra_data, executed, prompt_id, self.outputs_ui, self.object_storage) + if self.success is not True: self.handle_execution_error(prompt_id, prompt, current_outputs, executed, error, ex) break @@ -731,14 +738,27 @@ class PromptQueue: self.server.queue_updated() return (item, i) - def task_done(self, item_id, outputs): + class ExecutionStatus(NamedTuple): + status_str: Literal['success', 'error'] + completed: bool + notes: List[str] + + def task_done(self, item_id, outputs, + status: Optional['PromptQueue.ExecutionStatus']): with self.mutex: prompt = self.currently_running.pop(item_id) if len(self.history) > MAXIMUM_HISTORY_SIZE: self.history.pop(next(iter(self.history))) - self.history[prompt[1]] = { "prompt": prompt, "outputs": {} } - for o in outputs: - self.history[prompt[1]]["outputs"][o] = outputs[o] + + status_dict: dict|None = None + if status is not None: + status_dict = copy.deepcopy(status._asdict()) + + self.history[prompt[1]] = { + "prompt": prompt, + "outputs": copy.deepcopy(outputs), + 'status': status_dict, + } self.server.queue_updated() def get_current_queue(self): diff --git a/main.py b/main.py index 45d5d41b..a40ad2a4 100644 --- a/main.py +++ b/main.py @@ -110,7 +110,12 @@ def prompt_worker(q, server): e.execute(item[2], prompt_id, item[3], item[4]) need_gc = True - q.task_done(item_id, e.outputs_ui) + q.task_done(item_id, + e.outputs_ui, + status=execution.PromptQueue.ExecutionStatus( + status_str='success' if e.success else 'error', + completed=e.success, + notes=e.status_notes)) if server.client_id is not None: server.send_sync("executing", { "node": None, "prompt_id": prompt_id }, server.client_id)