@@ -8,6 +8,7 @@ import heapq
 import traceback
 import gc
 import inspect
+from typing import List, Literal, NamedTuple, Optional
 
 import torch
 import nodes
@@ -275,8 +276,15 @@ class PromptExecutor:
         self.outputs = {}
         self.object_storage = {}
         self.outputs_ui = {}
+        self.status_notes = []
+        self.success = True
         self.old_prompt = {}
 
+    def add_note(self, event, data, broadcast: bool):
+        self.status_notes.append((event, data))
+        if self.server.client_id is not None or broadcast:
+            self.server.send_sync(event, data, self.server.client_id)
+
     def handle_execution_error(self, prompt_id, prompt, current_outputs, executed, error, ex):
         node_id = error["node_id"]
         class_type = prompt[node_id]["class_type"]
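
Aside (not part of the diff): `add_note` both records the event in `self.status_notes` and, when a client is attached or `broadcast` is set, pushes it over the existing `send_sync` channel. A minimal sketch of the intended call pattern, assuming an already-constructed `PromptExecutor` named `executor`:

# Illustrative sketch only; `executor` and `prompt_id` are assumed names, not part of the diff.
# Sent to the requesting client (if one is attached) and always recorded locally:
executor.add_note("execution_start", {"prompt_id": prompt_id}, broadcast=False)
# Sent to every connected client, and recorded locally as well:
executor.add_note("execution_interrupted", {"prompt_id": prompt_id}, broadcast=True)
# Either way the (event, data) tuple is appended to executor.status_notes,
# which the PromptQueue hunk below turns into the history's "status" field.
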
@@ -290,23 +298,22 @@ class PromptExecutor:
                 "node_type": class_type,
                 "executed": list(executed),
             }
-            self.server.send_sync("execution_interrupted", mes, self.server.client_id)
+            self.add_note("execution_interrupted", mes, broadcast=True)
         else:
-            if self.server.client_id is not None:
-                mes = {
-                    "prompt_id": prompt_id,
-                    "node_id": node_id,
-                    "node_type": class_type,
-                    "executed": list(executed),
-
-                    "exception_message": error["exception_message"],
-                    "exception_type": error["exception_type"],
-                    "traceback": error["traceback"],
-                    "current_inputs": error["current_inputs"],
-                    "current_outputs": error["current_outputs"],
-                }
-                self.server.send_sync("execution_error", mes, self.server.client_id)
+            mes = {
+                "prompt_id": prompt_id,
+                "node_id": node_id,
+                "node_type": class_type,
+                "executed": list(executed),
+
+                "exception_message": error["exception_message"],
+                "exception_type": error["exception_type"],
+                "traceback": error["traceback"],
+                "current_inputs": error["current_inputs"],
+                "current_outputs": error["current_outputs"],
+            }
+            self.add_note("execution_error", mes, broadcast=False)
 
         # Next, remove the subsequent outputs since they will not be executed
         to_delete = []
         for o in self.outputs:
@@ -327,8 +334,7 @@ class PromptExecutor:
         else:
             self.server.client_id = None
 
-        if self.server.client_id is not None:
-            self.server.send_sync("execution_start", { "prompt_id": prompt_id}, self.server.client_id)
+        self.add_note("execution_start", { "prompt_id": prompt_id}, broadcast=False)
 
         with torch.inference_mode():
             #delete cached outputs if nodes don't exist for them
@@ -361,8 +367,9 @@ class PromptExecutor:
                     del d
 
             comfy.model_management.cleanup_models()
-            if self.server.client_id is not None:
-                self.server.send_sync("execution_cached", { "nodes": list(current_outputs) , "prompt_id": prompt_id}, self.server.client_id)
+            self.add_note("execution_cached",
+                          { "nodes": list(current_outputs) , "prompt_id": prompt_id},
+                          broadcast=False)
             executed = set()
             output_node_id = None
             to_execute = []
@@ -378,8 +385,8 @@ class PromptExecutor:
                 # This call shouldn't raise anything if there's an error deep in
                 # the actual SD code, instead it will report the node where the
                 # error was raised
-                success, error, ex = recursive_execute(self.server, prompt, self.outputs, output_node_id, extra_data, executed, prompt_id, self.outputs_ui, self.object_storage)
-                if success is not True:
+                self.success, error, ex = recursive_execute(self.server, prompt, self.outputs, output_node_id, extra_data, executed, prompt_id, self.outputs_ui, self.object_storage)
+                if self.success is not True:
                     self.handle_execution_error(prompt_id, prompt, current_outputs, executed, error, ex)
                     break
 
@@ -731,14 +738,27 @@ class PromptQueue:
             self.server.queue_updated()
             return (item, i)
 
-    def task_done(self, item_id, outputs):
+    class ExecutionStatus(NamedTuple):
+        status_str: Literal['success', 'error']
+        completed: bool
+        notes: List[str]
+
+    def task_done(self, item_id, outputs,
+                  status: Optional['PromptQueue.ExecutionStatus']):
         with self.mutex:
             prompt = self.currently_running.pop(item_id)
             if len(self.history) > MAXIMUM_HISTORY_SIZE:
                 self.history.pop(next(iter(self.history)))
-            self.history[prompt[1]] = { "prompt": prompt, "outputs": {} }
-            for o in outputs:
-                self.history[prompt[1]]["outputs"][o] = outputs[o]
+
+            status_dict: dict|None = None
+            if status is not None:
+                status_dict = copy.deepcopy(status._asdict())
+
+            self.history[prompt[1]] = {
+                "prompt": prompt,
+                "outputs": copy.deepcopy(outputs),
+                'status': status_dict,
+            }
             self.server.queue_updated()
 
     def get_current_queue(self):
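
Aside (not part of the diff): `task_done` now expects callers to build a `PromptQueue.ExecutionStatus` so the outcome is persisted alongside the outputs in `self.history`. A hedged sketch of what a calling worker loop might do, where `q`, `e`, and `item_id` are assumed names rather than identifiers shown in this excerpt:

# Illustrative sketch only; `q` (a PromptQueue), `e` (a PromptExecutor) and
# `item_id` are assumptions, not names introduced by this diff.
status = PromptQueue.ExecutionStatus(
    status_str="success" if e.success else "error",
    completed=e.success,
    notes=e.status_notes,
)
q.task_done(item_id, e.outputs_ui, status=status)
# The matching history entry then carries a deep-copied status dict, roughly:
# {"prompt": ..., "outputs": {...},
#  "status": {"status_str": "success", "completed": True, "notes": [...]}}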