Refactor graph utils module and add
raw_topological_sort function
This commit is contained in:
parent
98bacf5f74
commit
5bd379e714
1 changed files with 43 additions and 3 deletions
|
|
@ -1,5 +1,6 @@
|
|||
from collections import deque
|
||||
import copy
|
||||
from collections import deque
|
||||
from typing import Dict, List
|
||||
|
||||
|
||||
def find_last_node(nodes, edges):
|
||||
|
|
@ -46,6 +47,38 @@ def ungroup_node(group_node_data, base_flow):
|
|||
return nodes
|
||||
|
||||
|
||||
def raw_topological_sort(nodes, edges) -> List[Dict]:
|
||||
# Redefine the above function but using the nodes and self._edges
|
||||
# which are dicts instead of Vertex and Edge objects
|
||||
# nodes have an id, edges have a source and target keys
|
||||
# return a list of node ids in topological order
|
||||
|
||||
# States: 0 = unvisited, 1 = visiting, 2 = visited
|
||||
state = {node["id"]: 0 for node in nodes}
|
||||
nodes_dict = {node["id"]: node for node in nodes}
|
||||
sorted_vertices = []
|
||||
|
||||
def dfs(node):
|
||||
if state[node] == 1:
|
||||
# We have a cycle
|
||||
raise ValueError("Graph contains a cycle, cannot perform topological sort")
|
||||
if state[node] == 0:
|
||||
state[node] = 1
|
||||
for edge in edges:
|
||||
if edge["source"] == node:
|
||||
dfs(edge["target"])
|
||||
state[node] = 2
|
||||
sorted_vertices.append(node)
|
||||
|
||||
# Visit each node
|
||||
for node in nodes:
|
||||
if state[node["id"]] == 0:
|
||||
dfs(node["id"])
|
||||
|
||||
reverse_sorted = list(reversed(sorted_vertices))
|
||||
return [nodes_dict[node_id] for node_id in reverse_sorted]
|
||||
|
||||
|
||||
def process_flow(flow_object):
|
||||
cloned_flow = copy.deepcopy(flow_object)
|
||||
processed_nodes = set() # To keep track of processed nodes
|
||||
|
|
@ -66,7 +99,8 @@ def process_flow(flow_object):
|
|||
# Mark node as processed
|
||||
processed_nodes.add(node_id)
|
||||
|
||||
nodes_to_process = deque(cloned_flow["nodes"])
|
||||
sorted_nodes_list = raw_topological_sort(cloned_flow["nodes"], cloned_flow["edges"])
|
||||
nodes_to_process = deque(sorted_nodes_list)
|
||||
|
||||
while nodes_to_process:
|
||||
node = nodes_to_process.popleft()
|
||||
|
|
@ -107,7 +141,11 @@ def update_template(template, g_nodes):
|
|||
g_nodes[node_index]["data"]["node"]["template"][field]["display_name"] = display_name
|
||||
|
||||
|
||||
def update_target_handle(new_edge, g_nodes, group_node_id):
|
||||
def update_target_handle(
|
||||
new_edge,
|
||||
g_nodes,
|
||||
group_node_id,
|
||||
):
|
||||
"""
|
||||
Updates the target handle of a given edge if it is a proxy node.
|
||||
|
||||
|
|
@ -124,6 +162,8 @@ def update_target_handle(new_edge, g_nodes, group_node_id):
|
|||
proxy_id = target_handle["proxy"]["id"]
|
||||
if node := next((n for n in g_nodes if n["id"] == proxy_id), None):
|
||||
set_new_target_handle(proxy_id, new_edge, target_handle, node)
|
||||
else:
|
||||
raise ValueError(f"Group node {group_node_id} has an invalid target proxy node {proxy_id}")
|
||||
return new_edge
|
||||
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue