Browse Source

Fix Pyright warnings

pull/2666/head
Jacob Segal 9 months ago
parent
commit
508d286b8f
  1. 23
      comfy/caching.py
  2. 10
      execution.py

23
comfy/caching.py

@ -1,5 +1,6 @@
import itertools import itertools
from typing import Sequence, Mapping from typing import Sequence, Mapping
from comfy.graph import DynamicPrompt
import nodes import nodes
@ -10,7 +11,7 @@ class CacheKeySet:
self.keys = {} self.keys = {}
self.subcache_keys = {} self.subcache_keys = {}
def add_keys(node_ids): def add_keys(self, node_ids):
raise NotImplementedError() raise NotImplementedError()
def all_node_ids(self): def all_node_ids(self):
@ -66,7 +67,7 @@ class CacheKeySetInputSignature(CacheKeySet):
self.is_changed_cache = is_changed_cache self.is_changed_cache = is_changed_cache
self.add_keys(node_ids) self.add_keys(node_ids)
def include_node_id_in_input(self): def include_node_id_in_input(self) -> bool:
return False return False
def add_keys(self, node_ids): def add_keys(self, node_ids):
@ -131,8 +132,9 @@ class CacheKeySetInputSignatureWithID(CacheKeySetInputSignature):
class BasicCache: class BasicCache:
def __init__(self, key_class): def __init__(self, key_class):
self.key_class = key_class self.key_class = key_class
self.dynprompt = None self.initialized = False
self.cache_key_set = None self.dynprompt: DynamicPrompt
self.cache_key_set: CacheKeySet
self.cache = {} self.cache = {}
self.subcaches = {} self.subcaches = {}
@ -140,16 +142,17 @@ class BasicCache:
self.dynprompt = dynprompt self.dynprompt = dynprompt
self.cache_key_set = self.key_class(dynprompt, node_ids, is_changed_cache) self.cache_key_set = self.key_class(dynprompt, node_ids, is_changed_cache)
self.is_changed_cache = is_changed_cache self.is_changed_cache = is_changed_cache
self.initialized = True
def all_node_ids(self): def all_node_ids(self):
assert self.cache_key_set is not None assert self.initialized
node_ids = self.cache_key_set.all_node_ids() node_ids = self.cache_key_set.all_node_ids()
for subcache in self.subcaches.values(): for subcache in self.subcaches.values():
node_ids = node_ids.union(subcache.all_node_ids()) node_ids = node_ids.union(subcache.all_node_ids())
return node_ids return node_ids
def clean_unused(self): def clean_unused(self):
assert self.cache_key_set is not None assert self.initialized
preserve_keys = set(self.cache_key_set.get_used_keys()) preserve_keys = set(self.cache_key_set.get_used_keys())
preserve_subcaches = set(self.cache_key_set.get_used_subcache_keys()) preserve_subcaches = set(self.cache_key_set.get_used_subcache_keys())
to_remove = [] to_remove = []
@ -167,12 +170,12 @@ class BasicCache:
del self.subcaches[key] del self.subcaches[key]
def _set_immediate(self, node_id, value): def _set_immediate(self, node_id, value):
assert self.cache_key_set is not None assert self.initialized
cache_key = self.cache_key_set.get_data_key(node_id) cache_key = self.cache_key_set.get_data_key(node_id)
self.cache[cache_key] = value self.cache[cache_key] = value
def _get_immediate(self, node_id): def _get_immediate(self, node_id):
if self.cache_key_set is None: if not self.initialized:
return None return None
cache_key = self.cache_key_set.get_data_key(node_id) cache_key = self.cache_key_set.get_data_key(node_id)
if cache_key in self.cache: if cache_key in self.cache:
@ -181,7 +184,6 @@ class BasicCache:
return None return None
def _ensure_subcache(self, node_id, children_ids): def _ensure_subcache(self, node_id, children_ids):
assert self.cache_key_set is not None
subcache_key = self.cache_key_set.get_subcache_key(node_id) subcache_key = self.cache_key_set.get_subcache_key(node_id)
subcache = self.subcaches.get(subcache_key, None) subcache = self.subcaches.get(subcache_key, None)
if subcache is None: if subcache is None:
@ -191,7 +193,7 @@ class BasicCache:
return subcache return subcache
def _get_subcache(self, node_id): def _get_subcache(self, node_id):
assert self.cache_key_set is not None assert self.initialized
subcache_key = self.cache_key_set.get_subcache_key(node_id) subcache_key = self.cache_key_set.get_subcache_key(node_id)
if subcache_key in self.subcaches: if subcache_key in self.subcaches:
return self.subcaches[subcache_key] return self.subcaches[subcache_key]
@ -211,6 +213,7 @@ class HierarchicalCache(BasicCache):
super().__init__(key_class) super().__init__(key_class)
def _get_cache_for(self, node_id): def _get_cache_for(self, node_id):
assert self.dynprompt is not None
parent_id = self.dynprompt.get_parent_node_id(node_id) parent_id = self.dynprompt.get_parent_node_id(node_id)
if parent_id is None: if parent_id is None:
return self return self

10
execution.py

@ -84,7 +84,7 @@ def get_input_data(inputs, class_def, unique_id, outputs=None, prompt={}, dynpro
for x in inputs: for x in inputs:
input_data = inputs[x] input_data = inputs[x]
input_type, input_category, input_info = get_input_info(class_def, x) input_type, input_category, input_info = get_input_info(class_def, x)
if is_link(input_data) and not input_info.get("rawLink", False): if is_link(input_data) and (not input_info or not input_info.get("rawLink", False)):
input_unique_id = input_data[0] input_unique_id = input_data[0]
output_index = input_data[1] output_index = input_data[1]
if outputs is None: if outputs is None:
@ -94,7 +94,7 @@ def get_input_data(inputs, class_def, unique_id, outputs=None, prompt={}, dynpro
continue continue
obj = cached_output[output_index] obj = cached_output[output_index]
input_data_all[x] = obj input_data_all[x] = obj
elif input_category is not None: else:
input_data_all[x] = [input_data] input_data_all[x] = [input_data]
if "hidden" in valid_inputs: if "hidden" in valid_inputs:
@ -336,8 +336,7 @@ def execute(server, dynprompt, caches, current_item, extra_data, executed, promp
# Check for conflicts # Check for conflicts
for node_id in new_graph.keys(): for node_id in new_graph.keys():
if dynprompt.get_node(node_id) is not None: if dynprompt.get_node(node_id) is not None:
raise Exception("Attempt to add duplicate node %s" % node_id) raise Exception("Attempt to add duplicate node %s. Ensure node ids are unique and deterministic or use graph_utils.GraphBuilder." % node_id)
break
for node_id, node_info in new_graph.items(): for node_id, node_info in new_graph.items():
new_node_ids.append(node_id) new_node_ids.append(node_id)
display_id = node_info.get("override_display_id", unique_id) display_id = node_info.get("override_display_id", unique_id)
@ -518,6 +517,7 @@ def validate_inputs(prompt, item, validated):
for x in valid_inputs: for x in valid_inputs:
type_input, input_category, extra_info = get_input_info(obj_class, x) type_input, input_category, extra_info = get_input_info(obj_class, x)
assert extra_info is not None
if x not in inputs: if x not in inputs:
if input_category == "required": if input_category == "required":
error = { error = {
@ -698,8 +698,6 @@ def validate_inputs(prompt, item, validated):
"details": details, "details": details,
"extra_info": { "extra_info": {
"input_name": x, "input_name": x,
"input_config": info,
"received_value": val,
} }
} }
errors.append(error) errors.append(error)

Loading…
Cancel
Save