Browse Source

Make custom VALIDATE_INPUTS skip normal validation

Additionally, if `VALIDATE_INPUTS` takes an argument named `input_types`,
that variable will be a dictionary of the socket type of all incoming
connections. If that argument exists, normal socket type validation will
not occur. This removes the last hurdle for enabling variant types
entirely from custom nodes, so I've removed that command-line option.

I've added appropriate unit tests for these changes.
pull/2666/head
Jacob Segal 9 months ago
parent
commit
6d09dd70f8
  1. 1
      comfy/cli_args.py
  2. 61
      execution.py
  3. 63
      tests/inference/test_execution.py
  4. 4
      tests/inference/testing_nodes/testing-pack/flow_control.py
  5. 112
      tests/inference/testing_nodes/testing-pack/specific_tests.py
  6. 44
      tests/inference/testing_nodes/testing-pack/stubs.py
  7. 48
      tests/inference/testing_nodes/testing-pack/tools.py
  8. 11
      tests/inference/testing_nodes/testing-pack/util.py

1
comfy/cli_args.py

@ -117,7 +117,6 @@ parser.add_argument("--windows-standalone-build", action="store_true", help="Win
parser.add_argument("--disable-metadata", action="store_true", help="Disable saving prompt metadata in files.")
parser.add_argument("--multi-user", action="store_true", help="Enables per-user storage.")
parser.add_argument("--enable-variants", action="store_true", help="Enables '*' type nodes.")
if comfy.options.args_parsing:
args = parser.parse_args()

61
execution.py

@ -92,6 +92,8 @@ def get_input_data(inputs, class_def, unique_id, outputs=None, prompt={}, dynpro
cached_output = outputs.get(input_unique_id)
if cached_output is None:
continue
if output_index >= len(cached_output):
continue
obj = cached_output[output_index]
input_data_all[x] = obj
elif input_category is not None:
@ -514,6 +516,7 @@ def validate_inputs(prompt, item, validated):
validate_function_inputs = []
if hasattr(obj_class, "VALIDATE_INPUTS"):
validate_function_inputs = inspect.getfullargspec(obj_class.VALIDATE_INPUTS).args
received_types = {}
for x in valid_inputs:
type_input, input_category, extra_info = get_input_info(obj_class, x)
@ -551,9 +554,9 @@ def validate_inputs(prompt, item, validated):
o_id = val[0]
o_class_type = prompt[o_id]['class_type']
r = nodes.NODE_CLASS_MAPPINGS[o_class_type].RETURN_TYPES
is_variant = args.enable_variants and (r[val[1]] == "*" or type_input == "*")
if r[val[1]] != type_input and not is_variant:
received_type = r[val[1]]
received_type = r[val[1]]
received_types[x] = received_type
if 'input_types' not in validate_function_inputs and received_type != type_input:
details = f"{x}, {received_type} != {type_input}"
error = {
"type": "return_type_mismatch",
@ -622,34 +625,34 @@ def validate_inputs(prompt, item, validated):
errors.append(error)
continue
if "min" in extra_info and val < extra_info["min"]:
error = {
"type": "value_smaller_than_min",
"message": "Value {} smaller than min of {}".format(val, extra_info["min"]),
"details": f"{x}",
"extra_info": {
"input_name": x,
"input_config": info,
"received_value": val,
if x not in validate_function_inputs:
if "min" in extra_info and val < extra_info["min"]:
error = {
"type": "value_smaller_than_min",
"message": "Value {} smaller than min of {}".format(val, extra_info["min"]),
"details": f"{x}",
"extra_info": {
"input_name": x,
"input_config": info,
"received_value": val,
}
}
}
errors.append(error)
continue
if "max" in extra_info and val > extra_info["max"]:
error = {
"type": "value_bigger_than_max",
"message": "Value {} bigger than max of {}".format(val, extra_info["max"]),
"details": f"{x}",
"extra_info": {
"input_name": x,
"input_config": info,
"received_value": val,
errors.append(error)
continue
if "max" in extra_info and val > extra_info["max"]:
error = {
"type": "value_bigger_than_max",
"message": "Value {} bigger than max of {}".format(val, extra_info["max"]),
"details": f"{x}",
"extra_info": {
"input_name": x,
"input_config": info,
"received_value": val,
}
}
}
errors.append(error)
continue
errors.append(error)
continue
if x not in validate_function_inputs:
if isinstance(type_input, list):
if val not in type_input:
input_config = info
@ -682,6 +685,8 @@ def validate_inputs(prompt, item, validated):
for x in input_data_all:
if x in validate_function_inputs:
input_filtered[x] = input_data_all[x]
if 'input_types' in validate_function_inputs:
input_filtered['input_types'] = [received_types]
#ret = obj_class.VALIDATE_INPUTS(**input_filtered)
ret = map_node_over_list(obj_class, input_filtered, "VALIDATE_INPUTS")

63
tests/inference/test_execution.py

@ -12,6 +12,7 @@ import websocket #NOTE: websocket-client (https://github.com/websocket-client/we
import uuid
import urllib.request
import urllib.parse
import urllib.error
from comfy.graph_utils import GraphBuilder, Node
class RunResult:
@ -125,7 +126,6 @@ class TestExecution:
'--listen', args_pytest["listen"],
'--port', str(args_pytest["port"]),
'--extra-model-paths-config', 'tests/inference/extra_model_paths.yaml',
'--enable-variants',
])
yield
p.kill()
@ -237,6 +237,67 @@ class TestExecution:
except Exception as e:
assert 'prompt_id' in e.args[0], f"Did not get back a proper error message: {e}"
@pytest.mark.parametrize("test_value, expect_error", [
(5, True),
("foo", True),
(5.0, False),
])
def test_validation_error_literal(self, test_value, expect_error, client: ComfyClient, builder: GraphBuilder):
g = builder
validation1 = g.node("TestCustomValidation1", input1=test_value, input2=3.0)
g.node("SaveImage", images=validation1.out(0))
if expect_error:
with pytest.raises(urllib.error.HTTPError):
client.run(g)
else:
client.run(g)
@pytest.mark.parametrize("test_type, test_value", [
("StubInt", 5),
("StubFloat", 5.0)
])
def test_validation_error_edge1(self, test_type, test_value, client: ComfyClient, builder: GraphBuilder):
g = builder
stub = g.node(test_type, value=test_value)
validation1 = g.node("TestCustomValidation1", input1=stub.out(0), input2=3.0)
g.node("SaveImage", images=validation1.out(0))
with pytest.raises(urllib.error.HTTPError):
client.run(g)
@pytest.mark.parametrize("test_type, test_value, expect_error", [
("StubInt", 5, True),
("StubFloat", 5.0, False)
])
def test_validation_error_edge2(self, test_type, test_value, expect_error, client: ComfyClient, builder: GraphBuilder):
g = builder
stub = g.node(test_type, value=test_value)
validation2 = g.node("TestCustomValidation2", input1=stub.out(0), input2=3.0)
g.node("SaveImage", images=validation2.out(0))
if expect_error:
with pytest.raises(urllib.error.HTTPError):
client.run(g)
else:
client.run(g)
@pytest.mark.parametrize("test_type, test_value, expect_error", [
("StubInt", 5, True),
("StubFloat", 5.0, False)
])
def test_validation_error_edge3(self, test_type, test_value, expect_error, client: ComfyClient, builder: GraphBuilder):
g = builder
stub = g.node(test_type, value=test_value)
validation3 = g.node("TestCustomValidation3", input1=stub.out(0), input2=3.0)
g.node("SaveImage", images=validation3.out(0))
if expect_error:
with pytest.raises(urllib.error.HTTPError):
client.run(g)
else:
client.run(g)
def test_custom_is_changed(self, client: ComfyClient, builder: GraphBuilder):
g = builder
# Creating the nodes in this specific order previously caused a bug

4
tests/inference/testing_nodes/testing-pack/flow_control.py

@ -1,7 +1,9 @@
from comfy.graph_utils import GraphBuilder, is_link
from comfy.graph import ExecutionBlocker
from .tools import VariantSupport
NUM_FLOW_SOCKETS = 5
@VariantSupport()
class TestWhileLoopOpen:
def __init__(self):
pass
@ -31,6 +33,7 @@ class TestWhileLoopOpen:
values.append(kwargs.get("initial_value%d" % i, None))
return tuple(["stub"] + values)
@VariantSupport()
class TestWhileLoopClose:
def __init__(self):
pass
@ -131,6 +134,7 @@ class TestWhileLoopClose:
"expand": graph.finalize(),
}
@VariantSupport()
class TestExecutionBlockerNode:
def __init__(self):
pass

112
tests/inference/testing_nodes/testing-pack/specific_tests.py

@ -1,9 +1,7 @@
import torch
from .tools import VariantSupport
class TestLazyMixImages:
def __init__(self):
pass
@classmethod
def INPUT_TYPES(cls):
return {
@ -50,9 +48,6 @@ class TestLazyMixImages:
return (result[0],)
class TestVariadicAverage:
def __init__(self):
pass
@classmethod
def INPUT_TYPES(cls):
return {
@ -74,9 +69,6 @@ class TestVariadicAverage:
class TestCustomIsChanged:
def __init__(self):
pass
@classmethod
def INPUT_TYPES(cls):
return {
@ -103,14 +95,116 @@ class TestCustomIsChanged:
else:
return False
class TestCustomValidation1:
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"input1": ("IMAGE,FLOAT",),
"input2": ("IMAGE,FLOAT",),
},
}
RETURN_TYPES = ("IMAGE",)
FUNCTION = "custom_validation1"
CATEGORY = "Testing/Nodes"
def custom_validation1(self, input1, input2):
if isinstance(input1, float) and isinstance(input2, float):
result = torch.ones([1, 512, 512, 3]) * input1 * input2
else:
result = input1 * input2
return (result,)
@classmethod
def VALIDATE_INPUTS(cls, input1=None, input2=None):
if input1 is not None:
if not isinstance(input1, (torch.Tensor, float)):
return f"Invalid type of input1: {type(input1)}"
if input2 is not None:
if not isinstance(input2, (torch.Tensor, float)):
return f"Invalid type of input2: {type(input2)}"
return True
class TestCustomValidation2:
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"input1": ("IMAGE,FLOAT",),
"input2": ("IMAGE,FLOAT",),
},
}
RETURN_TYPES = ("IMAGE",)
FUNCTION = "custom_validation2"
CATEGORY = "Testing/Nodes"
def custom_validation2(self, input1, input2):
if isinstance(input1, float) and isinstance(input2, float):
result = torch.ones([1, 512, 512, 3]) * input1 * input2
else:
result = input1 * input2
return (result,)
@classmethod
def VALIDATE_INPUTS(cls, input_types, input1=None, input2=None):
if input1 is not None:
if not isinstance(input1, (torch.Tensor, float)):
return f"Invalid type of input1: {type(input1)}"
if input2 is not None:
if not isinstance(input2, (torch.Tensor, float)):
return f"Invalid type of input2: {type(input2)}"
if 'input1' in input_types:
if input_types['input1'] not in ["IMAGE", "FLOAT"]:
return f"Invalid type of input1: {input_types['input1']}"
if 'input2' in input_types:
if input_types['input2'] not in ["IMAGE", "FLOAT"]:
return f"Invalid type of input2: {input_types['input2']}"
return True
@VariantSupport()
class TestCustomValidation3:
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"input1": ("IMAGE,FLOAT",),
"input2": ("IMAGE,FLOAT",),
},
}
RETURN_TYPES = ("IMAGE",)
FUNCTION = "custom_validation3"
CATEGORY = "Testing/Nodes"
def custom_validation3(self, input1, input2):
if isinstance(input1, float) and isinstance(input2, float):
result = torch.ones([1, 512, 512, 3]) * input1 * input2
else:
result = input1 * input2
return (result,)
TEST_NODE_CLASS_MAPPINGS = {
"TestLazyMixImages": TestLazyMixImages,
"TestVariadicAverage": TestVariadicAverage,
"TestCustomIsChanged": TestCustomIsChanged,
"TestCustomValidation1": TestCustomValidation1,
"TestCustomValidation2": TestCustomValidation2,
"TestCustomValidation3": TestCustomValidation3,
}
TEST_NODE_DISPLAY_NAME_MAPPINGS = {
"TestLazyMixImages": "Lazy Mix Images",
"TestVariadicAverage": "Variadic Average",
"TestCustomIsChanged": "Custom IsChanged",
"TestCustomValidation1": "Custom Validation 1",
"TestCustomValidation2": "Custom Validation 2",
"TestCustomValidation3": "Custom Validation 3",
}

44
tests/inference/testing_nodes/testing-pack/stubs.py

@ -51,11 +51,55 @@ class StubMask:
def stub_mask(self, value, height, width, batch_size):
return (torch.ones(batch_size, height, width) * value,)
class StubInt:
def __init__(self):
pass
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"value": ("INT", {"default": 0, "min": -0xffffffff, "max": 0xffffffff, "step": 1}),
},
}
RETURN_TYPES = ("INT",)
FUNCTION = "stub_int"
CATEGORY = "Testing/Stub Nodes"
def stub_int(self, value):
return (value,)
class StubFloat:
def __init__(self):
pass
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"value": ("FLOAT", {"default": 0.0, "min": -1.0e38, "max": 1.0e38, "step": 0.01}),
},
}
RETURN_TYPES = ("FLOAT",)
FUNCTION = "stub_float"
CATEGORY = "Testing/Stub Nodes"
def stub_float(self, value):
return (value,)
TEST_STUB_NODE_CLASS_MAPPINGS = {
"StubImage": StubImage,
"StubMask": StubMask,
"StubInt": StubInt,
"StubFloat": StubFloat,
}
TEST_STUB_NODE_DISPLAY_NAME_MAPPINGS = {
"StubImage": "Stub Image",
"StubMask": "Stub Mask",
"StubInt": "Stub Int",
"StubFloat": "Stub Float",
}

48
tests/inference/testing_nodes/testing-pack/tools.py

@ -0,0 +1,48 @@
class SmartType(str):
def __ne__(self, other):
if self == "*" or other == "*":
return False
selfset = set(self.split(','))
otherset = set(other.split(','))
return not selfset.issubset(otherset)
def VariantSupport():
def decorator(cls):
if hasattr(cls, "INPUT_TYPES"):
old_input_types = getattr(cls, "INPUT_TYPES")
def new_input_types(*args, **kwargs):
types = old_input_types(*args, **kwargs)
for category in ["required", "optional"]:
if category not in types:
continue
for key, value in types[category].items():
if isinstance(value, tuple):
types[category][key] = (SmartType(value[0]),) + value[1:]
return types
setattr(cls, "INPUT_TYPES", new_input_types)
if hasattr(cls, "RETURN_TYPES"):
old_return_types = cls.RETURN_TYPES
setattr(cls, "RETURN_TYPES", tuple(SmartType(x) for x in old_return_types))
if hasattr(cls, "VALIDATE_INPUTS"):
# Reflection is used to determine what the function signature is, so we can't just change the function signature
raise NotImplementedError("VariantSupport does not support VALIDATE_INPUTS yet")
else:
def validate_inputs(input_types):
inputs = cls.INPUT_TYPES()
for key, value in input_types.items():
if isinstance(value, SmartType):
continue
if "required" in inputs and key in inputs["required"]:
expected_type = inputs["required"][key][0]
elif "optional" in inputs and key in inputs["optional"]:
expected_type = inputs["optional"][key][0]
else:
expected_type = None
if expected_type is not None and SmartType(value) != expected_type:
return f"Invalid type of {key}: {value} (expected {expected_type})"
return True
setattr(cls, "VALIDATE_INPUTS", validate_inputs)
return cls
return decorator

11
tests/inference/testing_nodes/testing-pack/util.py

@ -1,5 +1,7 @@
from comfy.graph_utils import GraphBuilder
from .tools import VariantSupport
@VariantSupport()
class TestAccumulateNode:
def __init__(self):
pass
@ -27,6 +29,7 @@ class TestAccumulateNode:
value = accumulation["accum"] + [to_add]
return ({"accum": value},)
@VariantSupport()
class TestAccumulationHeadNode:
def __init__(self):
pass
@ -75,6 +78,7 @@ class TestAccumulationTailNode:
else:
return ({"accum": accum[:-1]}, accum[-1])
@VariantSupport()
class TestAccumulationToListNode:
def __init__(self):
pass
@ -97,6 +101,7 @@ class TestAccumulationToListNode:
def accumulation_to_list(self, accumulation):
return (accumulation["accum"],)
@VariantSupport()
class TestListToAccumulationNode:
def __init__(self):
pass
@ -119,6 +124,7 @@ class TestListToAccumulationNode:
def list_to_accumulation(self, list):
return ({"accum": list},)
@VariantSupport()
class TestAccumulationGetLengthNode:
def __init__(self):
pass
@ -140,6 +146,7 @@ class TestAccumulationGetLengthNode:
def accumlength(self, accumulation):
return (len(accumulation['accum']),)
@VariantSupport()
class TestAccumulationGetItemNode:
def __init__(self):
pass
@ -162,6 +169,7 @@ class TestAccumulationGetItemNode:
def get_item(self, accumulation, index):
return (accumulation['accum'][index],)
@VariantSupport()
class TestAccumulationSetItemNode:
def __init__(self):
pass
@ -222,6 +230,7 @@ class TestIntMathOperation:
from .flow_control import NUM_FLOW_SOCKETS
@VariantSupport()
class TestForLoopOpen:
def __init__(self):
pass
@ -257,6 +266,7 @@ class TestForLoopOpen:
"expand": graph.finalize(),
}
@VariantSupport()
class TestForLoopClose:
def __init__(self):
pass
@ -295,6 +305,7 @@ class TestForLoopClose:
}
NUM_LIST_SOCKETS = 10
@VariantSupport()
class TestMakeListNode:
def __init__(self):
pass

Loading…
Cancel
Save