diff --git a/src/backend/langflow/utils/graph.py b/src/backend/langflow/utils/graph.py index 0b6f09dfd..55d9bc0d3 100644 --- a/src/backend/langflow/utils/graph.py +++ b/src/backend/langflow/utils/graph.py @@ -53,13 +53,10 @@ class Graph: def get_node(self, node_id: str) -> Union[None, Node]: return next((node for node in self.nodes if node.id == node_id), None) - def get_connected_nodes(self, node_id: str) -> List[Node]: - connected_nodes: List[Node] = [] - for edge in self.edges: - if edge.source.id == node_id: - connected_nodes.append(edge.target) - elif edge.target.id == node_id: - connected_nodes.append(edge.source) + def get_nodes_with_target(self, node: Node) -> List[Node]: + connected_nodes: List[Node] = [ + edge.source for edge in self.edges if edge.target == node + ] return connected_nodes def get_node_neighbors(self, node: Node) -> Dict[str, int]: diff --git a/src/backend/langflow/utils/payload.py b/src/backend/langflow/utils/payload.py index 3b98dfedb..902dd9af9 100644 --- a/src/backend/langflow/utils/payload.py +++ b/src/backend/langflow/utils/payload.py @@ -1,5 +1,8 @@ import contextlib import re +from typing import Dict + +from langflow.utils.graph import Graph, Node def extract_input_variables(nodes): @@ -35,14 +38,23 @@ def get_root_node(graph): return next((node for node in graph.nodes if node not in incoming_edges), None) -def build_json(root, graph): - edge_ids = [edge.source for edge in graph.edges if edge.target == root] - local_nodes = [node for node in graph.nodes if node in edge_ids] - +def build_json(root: Node, graph: Graph) -> Dict: if "node" not in root.data: - return build_json(local_nodes[0], graph) + # If the root node has no "node" key, then it has only one child, + # which is the target of the single outgoing edge + edge = root.edges[0] + local_nodes = [edge.target] + else: + # Otherwise, find all children whose type matches the type + # specified in the template + module_type = root.data["node"]["template"]["_type"] + local_nodes = graph.get_nodes_with_target(root) - final_dict = root.data["node"]["template"].copy() + if len(local_nodes) == 1: + return build_json(local_nodes[0], graph) + # Build a dictionary from the template + template = root.data["node"]["template"] + final_dict = template.copy() for key, value in final_dict.items(): if key == "_type": @@ -51,10 +63,13 @@ def build_json(root, graph): module_type = value["type"] if "value" in value and value["value"] is not None: + # If the value is specified, use it value = value["value"] elif "dict" in module_type: + # If the value is a dictionary, create an empty dictionary value = {} else: + # Otherwise, recursively build the child nodes children = [] for local_node in local_nodes: module_types = [local_node.data["type"]] @@ -68,4 +83,5 @@ def build_json(root, graph): values = [build_json(child, graph) for child in children] value = list(values) if value["list"] else next(iter(values), None) final_dict[key] = value + return final_dict diff --git a/tests/test_graph.py b/tests/test_graph.py index c1b6f0c56..4580809a9 100644 --- a/tests/test_graph.py +++ b/tests/test_graph.py @@ -17,14 +17,14 @@ def get_graph(basic=True): return Graph(nodes, edges) -def test_get_connected_nodes(): +def test_get_nodes_with_target(): """Test getting connected nodes""" graph = get_graph() assert isinstance(graph, Graph) # Get root node root = get_root_node(graph) assert root is not None - connected_nodes = graph.get_connected_nodes(root) + connected_nodes = graph.get_nodes_with_target(root) assert connected_nodes is not None