From 5bd379e7145da32330c0ae3a88f73e4cce7839f8 Mon Sep 17 00:00:00 2001 From: Gabriel Luiz Freitas Almeida Date: Wed, 22 Nov 2023 21:12:32 -0300 Subject: [PATCH] Refactor graph utils module and add raw_topological_sort function --- src/backend/langflow/graph/graph/utils.py | 46 +++++++++++++++++++++-- 1 file changed, 43 insertions(+), 3 deletions(-) diff --git a/src/backend/langflow/graph/graph/utils.py b/src/backend/langflow/graph/graph/utils.py index a3299739e..d43fc4d84 100644 --- a/src/backend/langflow/graph/graph/utils.py +++ b/src/backend/langflow/graph/graph/utils.py @@ -1,5 +1,6 @@ -from collections import deque import copy +from collections import deque +from typing import Dict, List def find_last_node(nodes, edges): @@ -46,6 +47,38 @@ def ungroup_node(group_node_data, base_flow): return nodes +def raw_topological_sort(nodes, edges) -> List[Dict]: + # Redefine the above function but using the nodes and self._edges + # which are dicts instead of Vertex and Edge objects + # nodes have an id, edges have a source and target keys + # return a list of node ids in topological order + + # States: 0 = unvisited, 1 = visiting, 2 = visited + state = {node["id"]: 0 for node in nodes} + nodes_dict = {node["id"]: node for node in nodes} + sorted_vertices = [] + + def dfs(node): + if state[node] == 1: + # We have a cycle + raise ValueError("Graph contains a cycle, cannot perform topological sort") + if state[node] == 0: + state[node] = 1 + for edge in edges: + if edge["source"] == node: + dfs(edge["target"]) + state[node] = 2 + sorted_vertices.append(node) + + # Visit each node + for node in nodes: + if state[node["id"]] == 0: + dfs(node["id"]) + + reverse_sorted = list(reversed(sorted_vertices)) + return [nodes_dict[node_id] for node_id in reverse_sorted] + + def process_flow(flow_object): cloned_flow = copy.deepcopy(flow_object) processed_nodes = set() # To keep track of processed nodes @@ -66,7 +99,8 @@ def process_flow(flow_object): # Mark node as processed processed_nodes.add(node_id) - nodes_to_process = deque(cloned_flow["nodes"]) + sorted_nodes_list = raw_topological_sort(cloned_flow["nodes"], cloned_flow["edges"]) + nodes_to_process = deque(sorted_nodes_list) while nodes_to_process: node = nodes_to_process.popleft() @@ -107,7 +141,11 @@ def update_template(template, g_nodes): g_nodes[node_index]["data"]["node"]["template"][field]["display_name"] = display_name -def update_target_handle(new_edge, g_nodes, group_node_id): +def update_target_handle( + new_edge, + g_nodes, + group_node_id, +): """ Updates the target handle of a given edge if it is a proxy node. @@ -124,6 +162,8 @@ def update_target_handle(new_edge, g_nodes, group_node_id): proxy_id = target_handle["proxy"]["id"] if node := next((n for n in g_nodes if n["id"] == proxy_id), None): set_new_target_handle(proxy_id, new_edge, target_handle, node) + else: + raise ValueError(f"Group node {group_node_id} has an invalid target proxy node {proxy_id}") return new_edge