From 508d286b8fbd4cfa0281ee6c0d25aab95a7c014c Mon Sep 17 00:00:00 2001 From: Jacob Segal Date: Sat, 17 Feb 2024 21:56:46 -0800 Subject: [PATCH] Fix Pyright warnings --- comfy/caching.py | 23 +++++++++++++---------- execution.py | 10 ++++------ 2 files changed, 17 insertions(+), 16 deletions(-) diff --git a/comfy/caching.py b/comfy/caching.py index 7730a371..936e2e6d 100644 --- a/comfy/caching.py +++ b/comfy/caching.py @@ -1,5 +1,6 @@ import itertools from typing import Sequence, Mapping +from comfy.graph import DynamicPrompt import nodes @@ -10,7 +11,7 @@ class CacheKeySet: self.keys = {} self.subcache_keys = {} - def add_keys(node_ids): + def add_keys(self, node_ids): raise NotImplementedError() def all_node_ids(self): @@ -66,7 +67,7 @@ class CacheKeySetInputSignature(CacheKeySet): self.is_changed_cache = is_changed_cache self.add_keys(node_ids) - def include_node_id_in_input(self): + def include_node_id_in_input(self) -> bool: return False def add_keys(self, node_ids): @@ -131,8 +132,9 @@ class CacheKeySetInputSignatureWithID(CacheKeySetInputSignature): class BasicCache: def __init__(self, key_class): self.key_class = key_class - self.dynprompt = None - self.cache_key_set = None + self.initialized = False + self.dynprompt: DynamicPrompt + self.cache_key_set: CacheKeySet self.cache = {} self.subcaches = {} @@ -140,16 +142,17 @@ class BasicCache: self.dynprompt = dynprompt self.cache_key_set = self.key_class(dynprompt, node_ids, is_changed_cache) self.is_changed_cache = is_changed_cache + self.initialized = True def all_node_ids(self): - assert self.cache_key_set is not None + assert self.initialized node_ids = self.cache_key_set.all_node_ids() for subcache in self.subcaches.values(): node_ids = node_ids.union(subcache.all_node_ids()) return node_ids def clean_unused(self): - assert self.cache_key_set is not None + assert self.initialized preserve_keys = set(self.cache_key_set.get_used_keys()) preserve_subcaches = set(self.cache_key_set.get_used_subcache_keys()) to_remove = [] @@ -167,12 +170,12 @@ class BasicCache: del self.subcaches[key] def _set_immediate(self, node_id, value): - assert self.cache_key_set is not None + assert self.initialized cache_key = self.cache_key_set.get_data_key(node_id) self.cache[cache_key] = value def _get_immediate(self, node_id): - if self.cache_key_set is None: + if not self.initialized: return None cache_key = self.cache_key_set.get_data_key(node_id) if cache_key in self.cache: @@ -181,7 +184,6 @@ class BasicCache: return None def _ensure_subcache(self, node_id, children_ids): - assert self.cache_key_set is not None subcache_key = self.cache_key_set.get_subcache_key(node_id) subcache = self.subcaches.get(subcache_key, None) if subcache is None: @@ -191,7 +193,7 @@ class BasicCache: return subcache def _get_subcache(self, node_id): - assert self.cache_key_set is not None + assert self.initialized subcache_key = self.cache_key_set.get_subcache_key(node_id) if subcache_key in self.subcaches: return self.subcaches[subcache_key] @@ -211,6 +213,7 @@ class HierarchicalCache(BasicCache): super().__init__(key_class) def _get_cache_for(self, node_id): + assert self.dynprompt is not None parent_id = self.dynprompt.get_parent_node_id(node_id) if parent_id is None: return self diff --git a/execution.py b/execution.py index 57c9cbf7..4d9a4b98 100644 --- a/execution.py +++ b/execution.py @@ -84,7 +84,7 @@ def get_input_data(inputs, class_def, unique_id, outputs=None, prompt={}, dynpro for x in inputs: input_data = inputs[x] input_type, input_category, input_info = get_input_info(class_def, x) - if is_link(input_data) and not input_info.get("rawLink", False): + if is_link(input_data) and (not input_info or not input_info.get("rawLink", False)): input_unique_id = input_data[0] output_index = input_data[1] if outputs is None: @@ -94,7 +94,7 @@ def get_input_data(inputs, class_def, unique_id, outputs=None, prompt={}, dynpro continue obj = cached_output[output_index] input_data_all[x] = obj - elif input_category is not None: + else: input_data_all[x] = [input_data] if "hidden" in valid_inputs: @@ -336,8 +336,7 @@ def execute(server, dynprompt, caches, current_item, extra_data, executed, promp # Check for conflicts for node_id in new_graph.keys(): if dynprompt.get_node(node_id) is not None: - raise Exception("Attempt to add duplicate node %s" % node_id) - break + raise Exception("Attempt to add duplicate node %s. Ensure node ids are unique and deterministic or use graph_utils.GraphBuilder." % node_id) for node_id, node_info in new_graph.items(): new_node_ids.append(node_id) display_id = node_info.get("override_display_id", unique_id) @@ -518,6 +517,7 @@ def validate_inputs(prompt, item, validated): for x in valid_inputs: type_input, input_category, extra_info = get_input_info(obj_class, x) + assert extra_info is not None if x not in inputs: if input_category == "required": error = { @@ -698,8 +698,6 @@ def validate_inputs(prompt, item, validated): "details": details, "extra_info": { "input_name": x, - "input_config": info, - "received_value": val, } } errors.append(error)