Browse Source

Make custom VALIDATE_INPUTS skip normal validation

Additionally, if `VALIDATE_INPUTS` takes an argument named `input_types`,
that variable will be a dictionary of the socket type of all incoming
connections. If that argument exists, normal socket type validation will
not occur. This removes the last hurdle for enabling variant types
entirely from custom nodes, so I've removed that command-line option.

I've added appropriate unit tests for these changes.
pull/2666/head
Jacob Segal 9 months ago
parent
commit
6d09dd70f8
  1. 1
      comfy/cli_args.py
  2. 61
      execution.py
  3. 63
      tests/inference/test_execution.py
  4. 4
      tests/inference/testing_nodes/testing-pack/flow_control.py
  5. 112
      tests/inference/testing_nodes/testing-pack/specific_tests.py
  6. 44
      tests/inference/testing_nodes/testing-pack/stubs.py
  7. 48
      tests/inference/testing_nodes/testing-pack/tools.py
  8. 11
      tests/inference/testing_nodes/testing-pack/util.py

1
comfy/cli_args.py

@ -117,7 +117,6 @@ parser.add_argument("--windows-standalone-build", action="store_true", help="Win
parser.add_argument("--disable-metadata", action="store_true", help="Disable saving prompt metadata in files.") parser.add_argument("--disable-metadata", action="store_true", help="Disable saving prompt metadata in files.")
parser.add_argument("--multi-user", action="store_true", help="Enables per-user storage.") parser.add_argument("--multi-user", action="store_true", help="Enables per-user storage.")
parser.add_argument("--enable-variants", action="store_true", help="Enables '*' type nodes.")
if comfy.options.args_parsing: if comfy.options.args_parsing:
args = parser.parse_args() args = parser.parse_args()

61
execution.py

@ -92,6 +92,8 @@ def get_input_data(inputs, class_def, unique_id, outputs=None, prompt={}, dynpro
cached_output = outputs.get(input_unique_id) cached_output = outputs.get(input_unique_id)
if cached_output is None: if cached_output is None:
continue continue
if output_index >= len(cached_output):
continue
obj = cached_output[output_index] obj = cached_output[output_index]
input_data_all[x] = obj input_data_all[x] = obj
elif input_category is not None: elif input_category is not None:
@ -514,6 +516,7 @@ def validate_inputs(prompt, item, validated):
validate_function_inputs = [] validate_function_inputs = []
if hasattr(obj_class, "VALIDATE_INPUTS"): if hasattr(obj_class, "VALIDATE_INPUTS"):
validate_function_inputs = inspect.getfullargspec(obj_class.VALIDATE_INPUTS).args validate_function_inputs = inspect.getfullargspec(obj_class.VALIDATE_INPUTS).args
received_types = {}
for x in valid_inputs: for x in valid_inputs:
type_input, input_category, extra_info = get_input_info(obj_class, x) type_input, input_category, extra_info = get_input_info(obj_class, x)
@ -551,9 +554,9 @@ def validate_inputs(prompt, item, validated):
o_id = val[0] o_id = val[0]
o_class_type = prompt[o_id]['class_type'] o_class_type = prompt[o_id]['class_type']
r = nodes.NODE_CLASS_MAPPINGS[o_class_type].RETURN_TYPES r = nodes.NODE_CLASS_MAPPINGS[o_class_type].RETURN_TYPES
is_variant = args.enable_variants and (r[val[1]] == "*" or type_input == "*") received_type = r[val[1]]
if r[val[1]] != type_input and not is_variant: received_types[x] = received_type
received_type = r[val[1]] if 'input_types' not in validate_function_inputs and received_type != type_input:
details = f"{x}, {received_type} != {type_input}" details = f"{x}, {received_type} != {type_input}"
error = { error = {
"type": "return_type_mismatch", "type": "return_type_mismatch",
@ -622,34 +625,34 @@ def validate_inputs(prompt, item, validated):
errors.append(error) errors.append(error)
continue continue
if "min" in extra_info and val < extra_info["min"]: if x not in validate_function_inputs:
error = { if "min" in extra_info and val < extra_info["min"]:
"type": "value_smaller_than_min", error = {
"message": "Value {} smaller than min of {}".format(val, extra_info["min"]), "type": "value_smaller_than_min",
"details": f"{x}", "message": "Value {} smaller than min of {}".format(val, extra_info["min"]),
"extra_info": { "details": f"{x}",
"input_name": x, "extra_info": {
"input_config": info, "input_name": x,
"received_value": val, "input_config": info,
"received_value": val,
}
} }
} errors.append(error)
errors.append(error) continue
continue if "max" in extra_info and val > extra_info["max"]:
if "max" in extra_info and val > extra_info["max"]: error = {
error = { "type": "value_bigger_than_max",
"type": "value_bigger_than_max", "message": "Value {} bigger than max of {}".format(val, extra_info["max"]),
"message": "Value {} bigger than max of {}".format(val, extra_info["max"]), "details": f"{x}",
"details": f"{x}", "extra_info": {
"extra_info": { "input_name": x,
"input_name": x, "input_config": info,
"input_config": info, "received_value": val,
"received_value": val, }
} }
} errors.append(error)
errors.append(error) continue
continue
if x not in validate_function_inputs:
if isinstance(type_input, list): if isinstance(type_input, list):
if val not in type_input: if val not in type_input:
input_config = info input_config = info
@ -682,6 +685,8 @@ def validate_inputs(prompt, item, validated):
for x in input_data_all: for x in input_data_all:
if x in validate_function_inputs: if x in validate_function_inputs:
input_filtered[x] = input_data_all[x] input_filtered[x] = input_data_all[x]
if 'input_types' in validate_function_inputs:
input_filtered['input_types'] = [received_types]
#ret = obj_class.VALIDATE_INPUTS(**input_filtered) #ret = obj_class.VALIDATE_INPUTS(**input_filtered)
ret = map_node_over_list(obj_class, input_filtered, "VALIDATE_INPUTS") ret = map_node_over_list(obj_class, input_filtered, "VALIDATE_INPUTS")

63
tests/inference/test_execution.py

@ -12,6 +12,7 @@ import websocket #NOTE: websocket-client (https://github.com/websocket-client/we
import uuid import uuid
import urllib.request import urllib.request
import urllib.parse import urllib.parse
import urllib.error
from comfy.graph_utils import GraphBuilder, Node from comfy.graph_utils import GraphBuilder, Node
class RunResult: class RunResult:
@ -125,7 +126,6 @@ class TestExecution:
'--listen', args_pytest["listen"], '--listen', args_pytest["listen"],
'--port', str(args_pytest["port"]), '--port', str(args_pytest["port"]),
'--extra-model-paths-config', 'tests/inference/extra_model_paths.yaml', '--extra-model-paths-config', 'tests/inference/extra_model_paths.yaml',
'--enable-variants',
]) ])
yield yield
p.kill() p.kill()
@ -237,6 +237,67 @@ class TestExecution:
except Exception as e: except Exception as e:
assert 'prompt_id' in e.args[0], f"Did not get back a proper error message: {e}" assert 'prompt_id' in e.args[0], f"Did not get back a proper error message: {e}"
@pytest.mark.parametrize("test_value, expect_error", [
(5, True),
("foo", True),
(5.0, False),
])
def test_validation_error_literal(self, test_value, expect_error, client: ComfyClient, builder: GraphBuilder):
g = builder
validation1 = g.node("TestCustomValidation1", input1=test_value, input2=3.0)
g.node("SaveImage", images=validation1.out(0))
if expect_error:
with pytest.raises(urllib.error.HTTPError):
client.run(g)
else:
client.run(g)
@pytest.mark.parametrize("test_type, test_value", [
("StubInt", 5),
("StubFloat", 5.0)
])
def test_validation_error_edge1(self, test_type, test_value, client: ComfyClient, builder: GraphBuilder):
g = builder
stub = g.node(test_type, value=test_value)
validation1 = g.node("TestCustomValidation1", input1=stub.out(0), input2=3.0)
g.node("SaveImage", images=validation1.out(0))
with pytest.raises(urllib.error.HTTPError):
client.run(g)
@pytest.mark.parametrize("test_type, test_value, expect_error", [
("StubInt", 5, True),
("StubFloat", 5.0, False)
])
def test_validation_error_edge2(self, test_type, test_value, expect_error, client: ComfyClient, builder: GraphBuilder):
g = builder
stub = g.node(test_type, value=test_value)
validation2 = g.node("TestCustomValidation2", input1=stub.out(0), input2=3.0)
g.node("SaveImage", images=validation2.out(0))
if expect_error:
with pytest.raises(urllib.error.HTTPError):
client.run(g)
else:
client.run(g)
@pytest.mark.parametrize("test_type, test_value, expect_error", [
("StubInt", 5, True),
("StubFloat", 5.0, False)
])
def test_validation_error_edge3(self, test_type, test_value, expect_error, client: ComfyClient, builder: GraphBuilder):
g = builder
stub = g.node(test_type, value=test_value)
validation3 = g.node("TestCustomValidation3", input1=stub.out(0), input2=3.0)
g.node("SaveImage", images=validation3.out(0))
if expect_error:
with pytest.raises(urllib.error.HTTPError):
client.run(g)
else:
client.run(g)
def test_custom_is_changed(self, client: ComfyClient, builder: GraphBuilder): def test_custom_is_changed(self, client: ComfyClient, builder: GraphBuilder):
g = builder g = builder
# Creating the nodes in this specific order previously caused a bug # Creating the nodes in this specific order previously caused a bug

4
tests/inference/testing_nodes/testing-pack/flow_control.py

@ -1,7 +1,9 @@
from comfy.graph_utils import GraphBuilder, is_link from comfy.graph_utils import GraphBuilder, is_link
from comfy.graph import ExecutionBlocker from comfy.graph import ExecutionBlocker
from .tools import VariantSupport
NUM_FLOW_SOCKETS = 5 NUM_FLOW_SOCKETS = 5
@VariantSupport()
class TestWhileLoopOpen: class TestWhileLoopOpen:
def __init__(self): def __init__(self):
pass pass
@ -31,6 +33,7 @@ class TestWhileLoopOpen:
values.append(kwargs.get("initial_value%d" % i, None)) values.append(kwargs.get("initial_value%d" % i, None))
return tuple(["stub"] + values) return tuple(["stub"] + values)
@VariantSupport()
class TestWhileLoopClose: class TestWhileLoopClose:
def __init__(self): def __init__(self):
pass pass
@ -131,6 +134,7 @@ class TestWhileLoopClose:
"expand": graph.finalize(), "expand": graph.finalize(),
} }
@VariantSupport()
class TestExecutionBlockerNode: class TestExecutionBlockerNode:
def __init__(self): def __init__(self):
pass pass

112
tests/inference/testing_nodes/testing-pack/specific_tests.py

@ -1,9 +1,7 @@
import torch import torch
from .tools import VariantSupport
class TestLazyMixImages: class TestLazyMixImages:
def __init__(self):
pass
@classmethod @classmethod
def INPUT_TYPES(cls): def INPUT_TYPES(cls):
return { return {
@ -50,9 +48,6 @@ class TestLazyMixImages:
return (result[0],) return (result[0],)
class TestVariadicAverage: class TestVariadicAverage:
def __init__(self):
pass
@classmethod @classmethod
def INPUT_TYPES(cls): def INPUT_TYPES(cls):
return { return {
@ -74,9 +69,6 @@ class TestVariadicAverage:
class TestCustomIsChanged: class TestCustomIsChanged:
def __init__(self):
pass
@classmethod @classmethod
def INPUT_TYPES(cls): def INPUT_TYPES(cls):
return { return {
@ -103,14 +95,116 @@ class TestCustomIsChanged:
else: else:
return False return False
class TestCustomValidation1:
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"input1": ("IMAGE,FLOAT",),
"input2": ("IMAGE,FLOAT",),
},
}
RETURN_TYPES = ("IMAGE",)
FUNCTION = "custom_validation1"
CATEGORY = "Testing/Nodes"
def custom_validation1(self, input1, input2):
if isinstance(input1, float) and isinstance(input2, float):
result = torch.ones([1, 512, 512, 3]) * input1 * input2
else:
result = input1 * input2
return (result,)
@classmethod
def VALIDATE_INPUTS(cls, input1=None, input2=None):
if input1 is not None:
if not isinstance(input1, (torch.Tensor, float)):
return f"Invalid type of input1: {type(input1)}"
if input2 is not None:
if not isinstance(input2, (torch.Tensor, float)):
return f"Invalid type of input2: {type(input2)}"
return True
class TestCustomValidation2:
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"input1": ("IMAGE,FLOAT",),
"input2": ("IMAGE,FLOAT",),
},
}
RETURN_TYPES = ("IMAGE",)
FUNCTION = "custom_validation2"
CATEGORY = "Testing/Nodes"
def custom_validation2(self, input1, input2):
if isinstance(input1, float) and isinstance(input2, float):
result = torch.ones([1, 512, 512, 3]) * input1 * input2
else:
result = input1 * input2
return (result,)
@classmethod
def VALIDATE_INPUTS(cls, input_types, input1=None, input2=None):
if input1 is not None:
if not isinstance(input1, (torch.Tensor, float)):
return f"Invalid type of input1: {type(input1)}"
if input2 is not None:
if not isinstance(input2, (torch.Tensor, float)):
return f"Invalid type of input2: {type(input2)}"
if 'input1' in input_types:
if input_types['input1'] not in ["IMAGE", "FLOAT"]:
return f"Invalid type of input1: {input_types['input1']}"
if 'input2' in input_types:
if input_types['input2'] not in ["IMAGE", "FLOAT"]:
return f"Invalid type of input2: {input_types['input2']}"
return True
@VariantSupport()
class TestCustomValidation3:
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"input1": ("IMAGE,FLOAT",),
"input2": ("IMAGE,FLOAT",),
},
}
RETURN_TYPES = ("IMAGE",)
FUNCTION = "custom_validation3"
CATEGORY = "Testing/Nodes"
def custom_validation3(self, input1, input2):
if isinstance(input1, float) and isinstance(input2, float):
result = torch.ones([1, 512, 512, 3]) * input1 * input2
else:
result = input1 * input2
return (result,)
TEST_NODE_CLASS_MAPPINGS = { TEST_NODE_CLASS_MAPPINGS = {
"TestLazyMixImages": TestLazyMixImages, "TestLazyMixImages": TestLazyMixImages,
"TestVariadicAverage": TestVariadicAverage, "TestVariadicAverage": TestVariadicAverage,
"TestCustomIsChanged": TestCustomIsChanged, "TestCustomIsChanged": TestCustomIsChanged,
"TestCustomValidation1": TestCustomValidation1,
"TestCustomValidation2": TestCustomValidation2,
"TestCustomValidation3": TestCustomValidation3,
} }
TEST_NODE_DISPLAY_NAME_MAPPINGS = { TEST_NODE_DISPLAY_NAME_MAPPINGS = {
"TestLazyMixImages": "Lazy Mix Images", "TestLazyMixImages": "Lazy Mix Images",
"TestVariadicAverage": "Variadic Average", "TestVariadicAverage": "Variadic Average",
"TestCustomIsChanged": "Custom IsChanged", "TestCustomIsChanged": "Custom IsChanged",
"TestCustomValidation1": "Custom Validation 1",
"TestCustomValidation2": "Custom Validation 2",
"TestCustomValidation3": "Custom Validation 3",
} }

44
tests/inference/testing_nodes/testing-pack/stubs.py

@ -51,11 +51,55 @@ class StubMask:
def stub_mask(self, value, height, width, batch_size): def stub_mask(self, value, height, width, batch_size):
return (torch.ones(batch_size, height, width) * value,) return (torch.ones(batch_size, height, width) * value,)
class StubInt:
def __init__(self):
pass
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"value": ("INT", {"default": 0, "min": -0xffffffff, "max": 0xffffffff, "step": 1}),
},
}
RETURN_TYPES = ("INT",)
FUNCTION = "stub_int"
CATEGORY = "Testing/Stub Nodes"
def stub_int(self, value):
return (value,)
class StubFloat:
def __init__(self):
pass
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"value": ("FLOAT", {"default": 0.0, "min": -1.0e38, "max": 1.0e38, "step": 0.01}),
},
}
RETURN_TYPES = ("FLOAT",)
FUNCTION = "stub_float"
CATEGORY = "Testing/Stub Nodes"
def stub_float(self, value):
return (value,)
TEST_STUB_NODE_CLASS_MAPPINGS = { TEST_STUB_NODE_CLASS_MAPPINGS = {
"StubImage": StubImage, "StubImage": StubImage,
"StubMask": StubMask, "StubMask": StubMask,
"StubInt": StubInt,
"StubFloat": StubFloat,
} }
TEST_STUB_NODE_DISPLAY_NAME_MAPPINGS = { TEST_STUB_NODE_DISPLAY_NAME_MAPPINGS = {
"StubImage": "Stub Image", "StubImage": "Stub Image",
"StubMask": "Stub Mask", "StubMask": "Stub Mask",
"StubInt": "Stub Int",
"StubFloat": "Stub Float",
} }

48
tests/inference/testing_nodes/testing-pack/tools.py

@ -0,0 +1,48 @@
class SmartType(str):
def __ne__(self, other):
if self == "*" or other == "*":
return False
selfset = set(self.split(','))
otherset = set(other.split(','))
return not selfset.issubset(otherset)
def VariantSupport():
def decorator(cls):
if hasattr(cls, "INPUT_TYPES"):
old_input_types = getattr(cls, "INPUT_TYPES")
def new_input_types(*args, **kwargs):
types = old_input_types(*args, **kwargs)
for category in ["required", "optional"]:
if category not in types:
continue
for key, value in types[category].items():
if isinstance(value, tuple):
types[category][key] = (SmartType(value[0]),) + value[1:]
return types
setattr(cls, "INPUT_TYPES", new_input_types)
if hasattr(cls, "RETURN_TYPES"):
old_return_types = cls.RETURN_TYPES
setattr(cls, "RETURN_TYPES", tuple(SmartType(x) for x in old_return_types))
if hasattr(cls, "VALIDATE_INPUTS"):
# Reflection is used to determine what the function signature is, so we can't just change the function signature
raise NotImplementedError("VariantSupport does not support VALIDATE_INPUTS yet")
else:
def validate_inputs(input_types):
inputs = cls.INPUT_TYPES()
for key, value in input_types.items():
if isinstance(value, SmartType):
continue
if "required" in inputs and key in inputs["required"]:
expected_type = inputs["required"][key][0]
elif "optional" in inputs and key in inputs["optional"]:
expected_type = inputs["optional"][key][0]
else:
expected_type = None
if expected_type is not None and SmartType(value) != expected_type:
return f"Invalid type of {key}: {value} (expected {expected_type})"
return True
setattr(cls, "VALIDATE_INPUTS", validate_inputs)
return cls
return decorator

11
tests/inference/testing_nodes/testing-pack/util.py

@ -1,5 +1,7 @@
from comfy.graph_utils import GraphBuilder from comfy.graph_utils import GraphBuilder
from .tools import VariantSupport
@VariantSupport()
class TestAccumulateNode: class TestAccumulateNode:
def __init__(self): def __init__(self):
pass pass
@ -27,6 +29,7 @@ class TestAccumulateNode:
value = accumulation["accum"] + [to_add] value = accumulation["accum"] + [to_add]
return ({"accum": value},) return ({"accum": value},)
@VariantSupport()
class TestAccumulationHeadNode: class TestAccumulationHeadNode:
def __init__(self): def __init__(self):
pass pass
@ -75,6 +78,7 @@ class TestAccumulationTailNode:
else: else:
return ({"accum": accum[:-1]}, accum[-1]) return ({"accum": accum[:-1]}, accum[-1])
@VariantSupport()
class TestAccumulationToListNode: class TestAccumulationToListNode:
def __init__(self): def __init__(self):
pass pass
@ -97,6 +101,7 @@ class TestAccumulationToListNode:
def accumulation_to_list(self, accumulation): def accumulation_to_list(self, accumulation):
return (accumulation["accum"],) return (accumulation["accum"],)
@VariantSupport()
class TestListToAccumulationNode: class TestListToAccumulationNode:
def __init__(self): def __init__(self):
pass pass
@ -119,6 +124,7 @@ class TestListToAccumulationNode:
def list_to_accumulation(self, list): def list_to_accumulation(self, list):
return ({"accum": list},) return ({"accum": list},)
@VariantSupport()
class TestAccumulationGetLengthNode: class TestAccumulationGetLengthNode:
def __init__(self): def __init__(self):
pass pass
@ -140,6 +146,7 @@ class TestAccumulationGetLengthNode:
def accumlength(self, accumulation): def accumlength(self, accumulation):
return (len(accumulation['accum']),) return (len(accumulation['accum']),)
@VariantSupport()
class TestAccumulationGetItemNode: class TestAccumulationGetItemNode:
def __init__(self): def __init__(self):
pass pass
@ -162,6 +169,7 @@ class TestAccumulationGetItemNode:
def get_item(self, accumulation, index): def get_item(self, accumulation, index):
return (accumulation['accum'][index],) return (accumulation['accum'][index],)
@VariantSupport()
class TestAccumulationSetItemNode: class TestAccumulationSetItemNode:
def __init__(self): def __init__(self):
pass pass
@ -222,6 +230,7 @@ class TestIntMathOperation:
from .flow_control import NUM_FLOW_SOCKETS from .flow_control import NUM_FLOW_SOCKETS
@VariantSupport()
class TestForLoopOpen: class TestForLoopOpen:
def __init__(self): def __init__(self):
pass pass
@ -257,6 +266,7 @@ class TestForLoopOpen:
"expand": graph.finalize(), "expand": graph.finalize(),
} }
@VariantSupport()
class TestForLoopClose: class TestForLoopClose:
def __init__(self): def __init__(self):
pass pass
@ -295,6 +305,7 @@ class TestForLoopClose:
} }
NUM_LIST_SOCKETS = 10 NUM_LIST_SOCKETS = 10
@VariantSupport()
class TestMakeListNode: class TestMakeListNode:
def __init__(self): def __init__(self):
pass pass

Loading…
Cancel
Save