Refactor graph utils module and add

raw_topological_sort function
This commit is contained in:
Gabriel Luiz Freitas Almeida 2023-11-22 21:12:32 -03:00
commit 5bd379e714

View file

@ -1,5 +1,6 @@
from collections import deque
import copy
from collections import deque
from typing import Dict, List
def find_last_node(nodes, edges):
@ -46,6 +47,38 @@ def ungroup_node(group_node_data, base_flow):
return nodes
def raw_topological_sort(nodes, edges) -> List[Dict]:
# Redefine the above function but using the nodes and self._edges
# which are dicts instead of Vertex and Edge objects
# nodes have an id, edges have a source and target keys
# return a list of node ids in topological order
# States: 0 = unvisited, 1 = visiting, 2 = visited
state = {node["id"]: 0 for node in nodes}
nodes_dict = {node["id"]: node for node in nodes}
sorted_vertices = []
def dfs(node):
if state[node] == 1:
# We have a cycle
raise ValueError("Graph contains a cycle, cannot perform topological sort")
if state[node] == 0:
state[node] = 1
for edge in edges:
if edge["source"] == node:
dfs(edge["target"])
state[node] = 2
sorted_vertices.append(node)
# Visit each node
for node in nodes:
if state[node["id"]] == 0:
dfs(node["id"])
reverse_sorted = list(reversed(sorted_vertices))
return [nodes_dict[node_id] for node_id in reverse_sorted]
def process_flow(flow_object):
cloned_flow = copy.deepcopy(flow_object)
processed_nodes = set() # To keep track of processed nodes
@ -66,7 +99,8 @@ def process_flow(flow_object):
# Mark node as processed
processed_nodes.add(node_id)
nodes_to_process = deque(cloned_flow["nodes"])
sorted_nodes_list = raw_topological_sort(cloned_flow["nodes"], cloned_flow["edges"])
nodes_to_process = deque(sorted_nodes_list)
while nodes_to_process:
node = nodes_to_process.popleft()
@ -107,7 +141,11 @@ def update_template(template, g_nodes):
g_nodes[node_index]["data"]["node"]["template"][field]["display_name"] = display_name
def update_target_handle(new_edge, g_nodes, group_node_id):
def update_target_handle(
new_edge,
g_nodes,
group_node_id,
):
"""
Updates the target handle of a given edge if it is a proxy node.
@ -124,6 +162,8 @@ def update_target_handle(new_edge, g_nodes, group_node_id):
proxy_id = target_handle["proxy"]["id"]
if node := next((n for n in g_nodes if n["id"] == proxy_id), None):
set_new_target_handle(proxy_id, new_edge, target_handle, node)
else:
raise ValueError(f"Group node {group_node_id} has an invalid target proxy node {proxy_id}")
return new_edge