|
|
|
@ -102,13 +102,19 @@ def get_output_data(obj, input_data_all):
|
|
|
|
|
ui = {k: [y for x in uis for y in x[k]] for k in uis[0].keys()} |
|
|
|
|
return output, ui |
|
|
|
|
|
|
|
|
|
def format_value(x): |
|
|
|
|
if isinstance(x, (int, float, bool, str)): |
|
|
|
|
return x |
|
|
|
|
else: |
|
|
|
|
return str(x) |
|
|
|
|
|
|
|
|
|
def recursive_execute(server, prompt, outputs, current_item, extra_data, executed, prompt_id, outputs_ui): |
|
|
|
|
unique_id = current_item |
|
|
|
|
inputs = prompt[unique_id]['inputs'] |
|
|
|
|
class_type = prompt[unique_id]['class_type'] |
|
|
|
|
class_def = nodes.NODE_CLASS_MAPPINGS[class_type] |
|
|
|
|
if unique_id in outputs: |
|
|
|
|
return |
|
|
|
|
return (True, None, None) |
|
|
|
|
|
|
|
|
|
for x in inputs: |
|
|
|
|
input_data = inputs[x] |
|
|
|
@ -117,22 +123,64 @@ def recursive_execute(server, prompt, outputs, current_item, extra_data, execute
|
|
|
|
|
input_unique_id = input_data[0] |
|
|
|
|
output_index = input_data[1] |
|
|
|
|
if input_unique_id not in outputs: |
|
|
|
|
recursive_execute(server, prompt, outputs, input_unique_id, extra_data, executed, prompt_id, outputs_ui) |
|
|
|
|
|
|
|
|
|
input_data_all = get_input_data(inputs, class_def, unique_id, outputs, prompt, extra_data) |
|
|
|
|
if server.client_id is not None: |
|
|
|
|
server.last_node_id = unique_id |
|
|
|
|
server.send_sync("executing", { "node": unique_id, "prompt_id": prompt_id }, server.client_id) |
|
|
|
|
obj = class_def() |
|
|
|
|
|
|
|
|
|
output_data, output_ui = get_output_data(obj, input_data_all) |
|
|
|
|
outputs[unique_id] = output_data |
|
|
|
|
if len(output_ui) > 0: |
|
|
|
|
outputs_ui[unique_id] = output_ui |
|
|
|
|
result = recursive_execute(server, prompt, outputs, input_unique_id, extra_data, executed, prompt_id, outputs_ui) |
|
|
|
|
if result[0] is not True: |
|
|
|
|
# Another node failed further upstream |
|
|
|
|
return result |
|
|
|
|
|
|
|
|
|
input_data_all = None |
|
|
|
|
try: |
|
|
|
|
input_data_all = get_input_data(inputs, class_def, unique_id, outputs, prompt, extra_data) |
|
|
|
|
if server.client_id is not None: |
|
|
|
|
server.send_sync("executed", { "node": unique_id, "output": output_ui, "prompt_id": prompt_id }, server.client_id) |
|
|
|
|
server.last_node_id = unique_id |
|
|
|
|
server.send_sync("executing", { "node": unique_id, "prompt_id": prompt_id }, server.client_id) |
|
|
|
|
obj = class_def() |
|
|
|
|
|
|
|
|
|
output_data, output_ui = get_output_data(obj, input_data_all) |
|
|
|
|
outputs[unique_id] = output_data |
|
|
|
|
if len(output_ui) > 0: |
|
|
|
|
outputs_ui[unique_id] = output_ui |
|
|
|
|
if server.client_id is not None: |
|
|
|
|
server.send_sync("executed", { "node": unique_id, "output": output_ui, "prompt_id": prompt_id }, server.client_id) |
|
|
|
|
except comfy.model_management.InterruptProcessingException as iex: |
|
|
|
|
print("Processing interrupted") |
|
|
|
|
|
|
|
|
|
# skip formatting inputs/outputs |
|
|
|
|
error_details = { |
|
|
|
|
"node_id": unique_id, |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
return (False, error_details, iex) |
|
|
|
|
except Exception as ex: |
|
|
|
|
typ, _, tb = sys.exc_info() |
|
|
|
|
exception_type = full_type_name(typ) |
|
|
|
|
input_data_formatted = {} |
|
|
|
|
if input_data_all is not None: |
|
|
|
|
input_data_formatted = {} |
|
|
|
|
for name, inputs in input_data_all.items(): |
|
|
|
|
input_data_formatted[name] = [format_value(x) for x in inputs] |
|
|
|
|
|
|
|
|
|
output_data_formatted = {} |
|
|
|
|
for node_id, node_outputs in outputs.items(): |
|
|
|
|
output_data_formatted[node_id] = [[format_value(x) for x in l] for l in node_outputs] |
|
|
|
|
|
|
|
|
|
print("!!! Exception during processing !!!") |
|
|
|
|
print(traceback.format_exc()) |
|
|
|
|
|
|
|
|
|
error_details = { |
|
|
|
|
"node_id": unique_id, |
|
|
|
|
"message": str(ex), |
|
|
|
|
"exception_type": exception_type, |
|
|
|
|
"traceback": traceback.format_tb(tb), |
|
|
|
|
"current_inputs": input_data_formatted, |
|
|
|
|
"current_outputs": output_data_formatted |
|
|
|
|
} |
|
|
|
|
return (False, error_details, ex) |
|
|
|
|
|
|
|
|
|
executed.add(unique_id) |
|
|
|
|
|
|
|
|
|
return (True, None, None) |
|
|
|
|
|
|
|
|
|
def recursive_will_execute(prompt, outputs, current_item): |
|
|
|
|
unique_id = current_item |
|
|
|
|
inputs = prompt[unique_id]['inputs'] |
|
|
|
@ -210,6 +258,44 @@ class PromptExecutor:
|
|
|
|
|
self.old_prompt = {} |
|
|
|
|
self.server = server |
|
|
|
|
|
|
|
|
|
def handle_execution_error(self, prompt_id, current_outputs, executed, error, ex): |
|
|
|
|
# First, send back the status to the frontend depending |
|
|
|
|
# on the exception type |
|
|
|
|
if isinstance(ex, comfy.model_management.InterruptProcessingException): |
|
|
|
|
mes = { |
|
|
|
|
"prompt_id": prompt_id, |
|
|
|
|
"executed": list(executed), |
|
|
|
|
|
|
|
|
|
"node_id": error["node_id"], |
|
|
|
|
} |
|
|
|
|
self.server.send_sync("execution_interrupted", mes, self.server.client_id) |
|
|
|
|
else: |
|
|
|
|
if self.server.client_id is not None: |
|
|
|
|
mes = { |
|
|
|
|
"prompt_id": prompt_id, |
|
|
|
|
"executed": list(executed), |
|
|
|
|
|
|
|
|
|
"message": error["message"], |
|
|
|
|
"exception_type": error["exception_type"], |
|
|
|
|
"traceback": error["traceback"], |
|
|
|
|
"node_id": error["node_id"], |
|
|
|
|
"current_inputs": error["current_inputs"], |
|
|
|
|
"current_outputs": error["current_outputs"], |
|
|
|
|
} |
|
|
|
|
self.server.send_sync("execution_error", mes, self.server.client_id) |
|
|
|
|
|
|
|
|
|
# Next, remove the subsequent outputs since they will not be executed |
|
|
|
|
to_delete = [] |
|
|
|
|
for o in self.outputs: |
|
|
|
|
if (o not in current_outputs) and (o not in executed): |
|
|
|
|
to_delete += [o] |
|
|
|
|
if o in self.old_prompt: |
|
|
|
|
d = self.old_prompt.pop(o) |
|
|
|
|
del d |
|
|
|
|
for o in to_delete: |
|
|
|
|
d = self.outputs.pop(o) |
|
|
|
|
del d |
|
|
|
|
|
|
|
|
|
def execute(self, prompt, prompt_id, extra_data={}, execute_outputs=[]): |
|
|
|
|
nodes.interrupt_processing(False) |
|
|
|
|
|
|
|
|
@ -244,42 +330,29 @@ class PromptExecutor:
|
|
|
|
|
if self.server.client_id is not None: |
|
|
|
|
self.server.send_sync("execution_cached", { "nodes": list(current_outputs) , "prompt_id": prompt_id}, self.server.client_id) |
|
|
|
|
executed = set() |
|
|
|
|
try: |
|
|
|
|
to_execute = [] |
|
|
|
|
for x in list(execute_outputs): |
|
|
|
|
to_execute += [(0, x)] |
|
|
|
|
|
|
|
|
|
while len(to_execute) > 0: |
|
|
|
|
#always execute the output that depends on the least amount of unexecuted nodes first |
|
|
|
|
to_execute = sorted(list(map(lambda a: (len(recursive_will_execute(prompt, self.outputs, a[-1])), a[-1]), to_execute))) |
|
|
|
|
x = to_execute.pop(0)[-1] |
|
|
|
|
|
|
|
|
|
recursive_execute(self.server, prompt, self.outputs, x, extra_data, executed, prompt_id, self.outputs_ui) |
|
|
|
|
except Exception as e: |
|
|
|
|
if isinstance(e, comfy.model_management.InterruptProcessingException): |
|
|
|
|
print("Processing interrupted") |
|
|
|
|
else: |
|
|
|
|
message = str(traceback.format_exc()) |
|
|
|
|
print(message) |
|
|
|
|
if self.server.client_id is not None: |
|
|
|
|
self.server.send_sync("execution_error", { "message": message, "prompt_id": prompt_id }, self.server.client_id) |
|
|
|
|
|
|
|
|
|
to_delete = [] |
|
|
|
|
for o in self.outputs: |
|
|
|
|
if (o not in current_outputs) and (o not in executed): |
|
|
|
|
to_delete += [o] |
|
|
|
|
if o in self.old_prompt: |
|
|
|
|
d = self.old_prompt.pop(o) |
|
|
|
|
del d |
|
|
|
|
for o in to_delete: |
|
|
|
|
d = self.outputs.pop(o) |
|
|
|
|
del d |
|
|
|
|
finally: |
|
|
|
|
for x in executed: |
|
|
|
|
self.old_prompt[x] = copy.deepcopy(prompt[x]) |
|
|
|
|
self.server.last_node_id = None |
|
|
|
|
if self.server.client_id is not None: |
|
|
|
|
self.server.send_sync("executing", { "node": None, "prompt_id": prompt_id }, self.server.client_id) |
|
|
|
|
output_node_id = None |
|
|
|
|
to_execute = [] |
|
|
|
|
|
|
|
|
|
for node_id in list(execute_outputs): |
|
|
|
|
to_execute += [(0, node_id)] |
|
|
|
|
|
|
|
|
|
while len(to_execute) > 0: |
|
|
|
|
#always execute the output that depends on the least amount of unexecuted nodes first |
|
|
|
|
to_execute = sorted(list(map(lambda a: (len(recursive_will_execute(prompt, self.outputs, a[-1])), a[-1]), to_execute))) |
|
|
|
|
output_node_id = to_execute.pop(0)[-1] |
|
|
|
|
|
|
|
|
|
# This call shouldn't raise anything if there's an error deep in |
|
|
|
|
# the actual SD code, instead it will report the node where the |
|
|
|
|
# error was raised |
|
|
|
|
success, error, ex = recursive_execute(self.server, prompt, self.outputs, output_node_id, extra_data, executed, prompt_id, self.outputs_ui) |
|
|
|
|
if success is not True: |
|
|
|
|
self.handle_execution_error(prompt_id, current_outputs, executed, error, ex) |
|
|
|
|
|
|
|
|
|
for x in executed: |
|
|
|
|
self.old_prompt[x] = copy.deepcopy(prompt[x]) |
|
|
|
|
self.server.last_node_id = None |
|
|
|
|
if self.server.client_id is not None: |
|
|
|
|
self.server.send_sync("executing", { "node": None, "prompt_id": prompt_id }, self.server.client_id) |
|
|
|
|
|
|
|
|
|
print("Prompt executed in {:.2f} seconds".format(time.perf_counter() - execution_start_time)) |
|
|
|
|
gc.collect() |
|
|
|
@ -359,7 +432,7 @@ def validate_inputs(prompt, item, validated):
|
|
|
|
|
except Exception as ex: |
|
|
|
|
typ, _, tb = sys.exc_info() |
|
|
|
|
valid = False |
|
|
|
|
error_type = full_type_name(typ) |
|
|
|
|
exception_type = full_type_name(typ) |
|
|
|
|
reasons = [{ |
|
|
|
|
"type": "exception_during_validation", |
|
|
|
|
"message": "Exception when validating node", |
|
|
|
@ -367,7 +440,7 @@ def validate_inputs(prompt, item, validated):
|
|
|
|
|
"extra_info": { |
|
|
|
|
"input_name": x, |
|
|
|
|
"input_config": info, |
|
|
|
|
"error_type": error_type, |
|
|
|
|
"exception_type": exception_type, |
|
|
|
|
"traceback": traceback.format_tb(tb) |
|
|
|
|
} |
|
|
|
|
}] |
|
|
|
@ -507,13 +580,13 @@ def validate_prompt(prompt):
|
|
|
|
|
except Exception as ex: |
|
|
|
|
typ, _, tb = sys.exc_info() |
|
|
|
|
valid = False |
|
|
|
|
error_type = full_type_name(typ) |
|
|
|
|
exception_type = full_type_name(typ) |
|
|
|
|
reasons = [{ |
|
|
|
|
"type": "exception_during_validation", |
|
|
|
|
"message": "Exception when validating node", |
|
|
|
|
"details": str(ex), |
|
|
|
|
"extra_info": { |
|
|
|
|
"error_type": error_type, |
|
|
|
|
"exception_type": exception_type, |
|
|
|
|
"traceback": traceback.format_tb(tb) |
|
|
|
|
} |
|
|
|
|
}] |
|
|
|
|