Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
69 changes: 40 additions & 29 deletions comfy_execution/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,30 +99,44 @@ def make_input_strong_link(self, to_node_id, to_input):
self.add_strong_link(from_node_id, from_socket, to_node_id)

def add_strong_link(self, from_node_id, from_socket, to_node_id):
self.add_node(from_node_id)
if to_node_id not in self.blocking[from_node_id]:
self.blocking[from_node_id][to_node_id] = {}
self.blockCount[to_node_id] += 1
self.blocking[from_node_id][to_node_id][from_socket] = True

def add_node(self, unique_id, include_lazy=False, subgraph_nodes=None):
if unique_id in self.pendingNodes:
return
self.pendingNodes[unique_id] = True
self.blockCount[unique_id] = 0
self.blocking[unique_id] = {}

inputs = self.dynprompt.get_node(unique_id)["inputs"]
for input_name in inputs:
value = inputs[input_name]
if is_link(value):
from_node_id, from_socket = value
if subgraph_nodes is not None and from_node_id not in subgraph_nodes:
continue
input_type, input_category, input_info = self.get_input_info(unique_id, input_name)
is_lazy = input_info is not None and "lazy" in input_info and input_info["lazy"]
if include_lazy or not is_lazy:
self.add_strong_link(from_node_id, from_socket, unique_id)
if not self.is_cached(from_node_id):
self.add_node(from_node_id)
if to_node_id not in self.blocking[from_node_id]:
self.blocking[from_node_id][to_node_id] = {}
self.blockCount[to_node_id] += 1
self.blocking[from_node_id][to_node_id][from_socket] = True

def add_node(self, node_unique_id, include_lazy=False, subgraph_nodes=None):
node_ids = [node_unique_id]
links = []

while len(node_ids) > 0:
unique_id = node_ids.pop()
if unique_id in self.pendingNodes:
continue

self.pendingNodes[unique_id] = True
self.blockCount[unique_id] = 0
self.blocking[unique_id] = {}

inputs = self.dynprompt.get_node(unique_id)["inputs"]
for input_name in inputs:
value = inputs[input_name]
if is_link(value):
from_node_id, from_socket = value
if subgraph_nodes is not None and from_node_id not in subgraph_nodes:
continue
input_type, input_category, input_info = self.get_input_info(unique_id, input_name)
is_lazy = input_info is not None and "lazy" in input_info and input_info["lazy"]
if (include_lazy or not is_lazy) and not self.is_cached(from_node_id):
node_ids.append(from_node_id)
links.append((from_node_id, from_socket, unique_id))

for link in links:
self.add_strong_link(*link)

def is_cached(self, node_id):
return False

def get_ready_nodes(self):
return [node_id for node_id in self.pendingNodes if self.blockCount[node_id] == 0]
Expand All @@ -146,11 +160,8 @@ def __init__(self, dynprompt, output_cache):
self.output_cache = output_cache
self.staged_node_id = None

def add_strong_link(self, from_node_id, from_socket, to_node_id):
if self.output_cache.get(from_node_id) is not None:
# Nothing to do
return
super().add_strong_link(from_node_id, from_socket, to_node_id)
def is_cached(self, node_id):
return self.output_cache.get(node_id) is not None

def stage_node_execution(self):
assert self.staged_node_id is None
Expand Down
Loading