|
|
|
@ -7,6 +7,7 @@ import threading
|
|
|
|
|
import heapq |
|
|
|
|
import traceback |
|
|
|
|
import gc |
|
|
|
|
import inspect |
|
|
|
|
|
|
|
|
|
import torch |
|
|
|
|
import nodes |
|
|
|
@ -402,6 +403,10 @@ def validate_inputs(prompt, item, validated):
|
|
|
|
|
errors = [] |
|
|
|
|
valid = True |
|
|
|
|
|
|
|
|
|
validate_function_inputs = [] |
|
|
|
|
if hasattr(obj_class, "VALIDATE_INPUTS"): |
|
|
|
|
validate_function_inputs = inspect.getfullargspec(obj_class.VALIDATE_INPUTS).args |
|
|
|
|
|
|
|
|
|
for x in required_inputs: |
|
|
|
|
if x not in inputs: |
|
|
|
|
error = { |
|
|
|
@ -531,29 +536,7 @@ def validate_inputs(prompt, item, validated):
|
|
|
|
|
errors.append(error) |
|
|
|
|
continue |
|
|
|
|
|
|
|
|
|
if hasattr(obj_class, "VALIDATE_INPUTS"): |
|
|
|
|
input_data_all = get_input_data(inputs, obj_class, unique_id) |
|
|
|
|
#ret = obj_class.VALIDATE_INPUTS(**input_data_all) |
|
|
|
|
ret = map_node_over_list(obj_class, input_data_all, "VALIDATE_INPUTS") |
|
|
|
|
for i, r in enumerate(ret): |
|
|
|
|
if r is not True: |
|
|
|
|
details = f"{x}" |
|
|
|
|
if r is not False: |
|
|
|
|
details += f" - {str(r)}" |
|
|
|
|
|
|
|
|
|
error = { |
|
|
|
|
"type": "custom_validation_failed", |
|
|
|
|
"message": "Custom validation failed for node", |
|
|
|
|
"details": details, |
|
|
|
|
"extra_info": { |
|
|
|
|
"input_name": x, |
|
|
|
|
"input_config": info, |
|
|
|
|
"received_value": val, |
|
|
|
|
} |
|
|
|
|
} |
|
|
|
|
errors.append(error) |
|
|
|
|
continue |
|
|
|
|
else: |
|
|
|
|
if x not in validate_function_inputs: |
|
|
|
|
if isinstance(type_input, list): |
|
|
|
|
if val not in type_input: |
|
|
|
|
input_config = info |
|
|
|
@ -580,6 +563,35 @@ def validate_inputs(prompt, item, validated):
|
|
|
|
|
errors.append(error) |
|
|
|
|
continue |
|
|
|
|
|
|
|
|
|
if len(validate_function_inputs) > 0: |
|
|
|
|
input_data_all = get_input_data(inputs, obj_class, unique_id) |
|
|
|
|
input_filtered = {} |
|
|
|
|
for x in input_data_all: |
|
|
|
|
if x in validate_function_inputs: |
|
|
|
|
input_filtered[x] = input_data_all[x] |
|
|
|
|
|
|
|
|
|
#ret = obj_class.VALIDATE_INPUTS(**input_filtered) |
|
|
|
|
ret = map_node_over_list(obj_class, input_filtered, "VALIDATE_INPUTS") |
|
|
|
|
for x in input_filtered: |
|
|
|
|
for i, r in enumerate(ret): |
|
|
|
|
if r is not True: |
|
|
|
|
details = f"{x}" |
|
|
|
|
if r is not False: |
|
|
|
|
details += f" - {str(r)}" |
|
|
|
|
|
|
|
|
|
error = { |
|
|
|
|
"type": "custom_validation_failed", |
|
|
|
|
"message": "Custom validation failed for node", |
|
|
|
|
"details": details, |
|
|
|
|
"extra_info": { |
|
|
|
|
"input_name": x, |
|
|
|
|
"input_config": info, |
|
|
|
|
"received_value": val, |
|
|
|
|
} |
|
|
|
|
} |
|
|
|
|
errors.append(error) |
|
|
|
|
continue |
|
|
|
|
|
|
|
|
|
if len(errors) > 0 or valid is not True: |
|
|
|
|
ret = (False, errors, unique_id) |
|
|
|
|
else: |
|
|
|
|