|
|
|
@ -147,7 +147,7 @@ class PromptExecutor:
|
|
|
|
|
self.old_prompt = {} |
|
|
|
|
self.server = server |
|
|
|
|
|
|
|
|
|
def execute(self, prompt, extra_data={}): |
|
|
|
|
def execute(self, prompt, extra_data={}, execute_outputs=[]): |
|
|
|
|
nodes.interrupt_processing(False) |
|
|
|
|
|
|
|
|
|
if "client_id" in extra_data: |
|
|
|
@ -172,27 +172,15 @@ class PromptExecutor:
|
|
|
|
|
executed = set() |
|
|
|
|
try: |
|
|
|
|
to_execute = [] |
|
|
|
|
for x in prompt: |
|
|
|
|
class_ = nodes.NODE_CLASS_MAPPINGS[prompt[x]['class_type']] |
|
|
|
|
if hasattr(class_, 'OUTPUT_NODE'): |
|
|
|
|
to_execute += [(0, x)] |
|
|
|
|
for x in list(execute_outputs): |
|
|
|
|
to_execute += [(0, x)] |
|
|
|
|
|
|
|
|
|
while len(to_execute) > 0: |
|
|
|
|
#always execute the output that depends on the least amount of unexecuted nodes first |
|
|
|
|
to_execute = sorted(list(map(lambda a: (len(recursive_will_execute(prompt, self.outputs, a[-1])), a[-1]), to_execute))) |
|
|
|
|
x = to_execute.pop(0)[-1] |
|
|
|
|
|
|
|
|
|
class_ = nodes.NODE_CLASS_MAPPINGS[prompt[x]['class_type']] |
|
|
|
|
if hasattr(class_, 'OUTPUT_NODE'): |
|
|
|
|
if class_.OUTPUT_NODE == True: |
|
|
|
|
valid = False |
|
|
|
|
try: |
|
|
|
|
m = validate_inputs(prompt, x) |
|
|
|
|
valid = m[0] |
|
|
|
|
except: |
|
|
|
|
valid = False |
|
|
|
|
if valid: |
|
|
|
|
recursive_execute(self.server, prompt, self.outputs, x, extra_data, executed) |
|
|
|
|
recursive_execute(self.server, prompt, self.outputs, x, extra_data, executed) |
|
|
|
|
except Exception as e: |
|
|
|
|
if isinstance(e, comfy.model_management.InterruptProcessingException): |
|
|
|
|
print("Processing interrupted") |
|
|
|
@ -219,8 +207,11 @@ class PromptExecutor:
|
|
|
|
|
comfy.model_management.soft_empty_cache() |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def validate_inputs(prompt, item): |
|
|
|
|
def validate_inputs(prompt, item, validated): |
|
|
|
|
unique_id = item |
|
|
|
|
if unique_id in validated: |
|
|
|
|
return validated[unique_id] |
|
|
|
|
|
|
|
|
|
inputs = prompt[unique_id]['inputs'] |
|
|
|
|
class_type = prompt[unique_id]['class_type'] |
|
|
|
|
obj_class = nodes.NODE_CLASS_MAPPINGS[class_type] |
|
|
|
@ -241,8 +232,9 @@ def validate_inputs(prompt, item):
|
|
|
|
|
r = nodes.NODE_CLASS_MAPPINGS[o_class_type].RETURN_TYPES |
|
|
|
|
if r[val[1]] != type_input: |
|
|
|
|
return (False, "Return type mismatch. {}, {}, {} != {}".format(class_type, x, r[val[1]], type_input)) |
|
|
|
|
r = validate_inputs(prompt, o_id) |
|
|
|
|
r = validate_inputs(prompt, o_id, validated) |
|
|
|
|
if r[0] == False: |
|
|
|
|
validated[o_id] = r |
|
|
|
|
return r |
|
|
|
|
else: |
|
|
|
|
if type_input == "INT": |
|
|
|
@ -270,7 +262,10 @@ def validate_inputs(prompt, item):
|
|
|
|
|
if isinstance(type_input, list): |
|
|
|
|
if val not in type_input: |
|
|
|
|
return (False, "Value not in list. {}, {}: {} not in {}".format(class_type, x, val, type_input)) |
|
|
|
|
return (True, "") |
|
|
|
|
|
|
|
|
|
ret = (True, "") |
|
|
|
|
validated[unique_id] = ret |
|
|
|
|
return ret |
|
|
|
|
|
|
|
|
|
def validate_prompt(prompt): |
|
|
|
|
outputs = set() |
|
|
|
@ -284,11 +279,12 @@ def validate_prompt(prompt):
|
|
|
|
|
|
|
|
|
|
good_outputs = set() |
|
|
|
|
errors = [] |
|
|
|
|
validated = {} |
|
|
|
|
for o in outputs: |
|
|
|
|
valid = False |
|
|
|
|
reason = "" |
|
|
|
|
try: |
|
|
|
|
m = validate_inputs(prompt, o) |
|
|
|
|
m = validate_inputs(prompt, o, validated) |
|
|
|
|
valid = m[0] |
|
|
|
|
reason = m[1] |
|
|
|
|
except Exception as e: |
|
|
|
@ -297,7 +293,7 @@ def validate_prompt(prompt):
|
|
|
|
|
reason = "Parsing error" |
|
|
|
|
|
|
|
|
|
if valid == True: |
|
|
|
|
good_outputs.add(x) |
|
|
|
|
good_outputs.add(o) |
|
|
|
|
else: |
|
|
|
|
print("Failed to validate prompt for output {} {}".format(o, reason)) |
|
|
|
|
print("output will be ignored") |
|
|
|
@ -307,7 +303,7 @@ def validate_prompt(prompt):
|
|
|
|
|
errors_list = "\n".join(set(map(lambda a: "{}".format(a[1]), errors))) |
|
|
|
|
return (False, "Prompt has no properly connected outputs\n {}".format(errors_list)) |
|
|
|
|
|
|
|
|
|
return (True, "") |
|
|
|
|
return (True, "", list(good_outputs)) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class PromptQueue: |
|
|
|
|