diff --git a/src/backend/langflow/interface/loading.py b/src/backend/langflow/interface/loading.py index a8ffe42df..3c2f0bd4f 100644 --- a/src/backend/langflow/interface/loading.py +++ b/src/backend/langflow/interface/loading.py @@ -17,6 +17,7 @@ from langchain.agents.load_tools import ( _EXTRA_LLM_TOOLS, _EXTRA_OPTIONAL_TOOLS, ) +from langflow.utils.graph import Graph def load_flow_from_json(path: str): @@ -36,8 +37,9 @@ def extract_json(data_graph): nodes = payload.extract_input_variables(nodes) # Nodes, edges and root node edges = data_graph["edges"] - root = payload.get_root_node(nodes, edges) - return payload.build_json(root, nodes, edges) + graph = Graph(nodes, edges) + root = payload.get_root_node(graph) + return payload.build_json(root, graph) def replace_zero_shot_prompt_with_prompt_template(nodes): diff --git a/src/backend/langflow/utils/graph.py b/src/backend/langflow/utils/graph.py new file mode 100644 index 000000000..a5b44cc17 --- /dev/null +++ b/src/backend/langflow/utils/graph.py @@ -0,0 +1,87 @@ +from typing import Dict, List, Union + + +class Node: + def __init__(self, data: Dict[str, Union[str, Dict[str, Union[str, List[str]]]]]): + self.id: str = data["id"] + self._data = data + self.edges: List[Edge] = [] + self._parse_data() + + def _parse_data(self) -> None: + self.data = self._data["data"] + + def add_edge(self, edge: "Edge") -> None: + self.edges.append(edge) + + def __repr__(self) -> str: + return f"Node(id={self.id}, data={self.data})" + + def __eq__(self, __o: object) -> bool: + return self.id == __o.id if isinstance(__o, Node) else False + + def __hash__(self) -> int: + return id(self) + + +class Edge: + def __init__(self, source: "Node", target: "Node"): + self.source: "Node" = source + self.target: "Node" = target + + def __repr__(self) -> str: + return f"Edge(source={self.source.id}, target={self.target.id})" + + +class Graph: + def __init__( + self, + nodes: List[Dict[str, Union[str, Dict[str, Union[str, List[str]]]]]], + edges: List[Dict[str, str]], + ) -> None: + self._nodes = nodes + self._edges = edges + self._build_graph() + + def _build_graph(self) -> None: + self.nodes = self._build_nodes() + self.edges = self._build_edges() + for edge in self.edges: + edge.source.add_edge(edge) + edge.target.add_edge(edge) + + def get_node(self, node_id: str) -> Union[None, Node]: + return next((node for node in self.nodes if node.id == node_id), None) + + def get_connected_nodes(self, node_id: str) -> List[Node]: + connected_nodes: List[Node] = [] + for edge in self.edges: + if edge.source.id == node_id: + connected_nodes.append(edge.target) + elif edge.target.id == node_id: + connected_nodes.append(edge.source) + return connected_nodes + + def get_node_neighbors(self, node_id: str) -> Dict[str, int]: + neighbors: Dict[str, int] = {} + for edge in self.edges: + if edge.source.id == node_id: + neighbor_id = edge.target.id + if neighbor_id not in neighbors: + neighbors[neighbor_id] = 0 + neighbors[neighbor_id] += 1 + elif edge.target.id == node_id: + neighbor_id = edge.source.id + if neighbor_id not in neighbors: + neighbors[neighbor_id] = 0 + neighbors[neighbor_id] += 1 + return neighbors + + def _build_edges(self) -> List[Edge]: + return [ + Edge(self.get_node(edge["source"]), self.get_node(edge["target"])) + for edge in self._edges + ] + + def _build_nodes(self) -> List[Node]: + return [Node(node) for node in self._nodes] diff --git a/src/backend/langflow/utils/payload.py b/src/backend/langflow/utils/payload.py index e7d38139e..3b98dfedb 100644 --- a/src/backend/langflow/utils/payload.py +++ b/src/backend/langflow/utils/payload.py @@ -27,25 +27,22 @@ def extract_input_variables(nodes): return nodes -def get_root_node(nodes, edges): +def get_root_node(graph): """ Returns the root node of the template. """ - incoming_edges = {edge["source"] for edge in edges} - return next((node for node in nodes if node["id"] not in incoming_edges), None) + incoming_edges = {edge.source for edge in graph.edges} + return next((node for node in graph.nodes if node not in incoming_edges), None) -def build_json(root, nodes, edges): - """ - Builds a json from the nodes and edges - """ - edge_ids = [edge["source"] for edge in edges if edge["target"] == root["id"]] - local_nodes = [node for node in nodes if node["id"] in edge_ids] +def build_json(root, graph): + edge_ids = [edge.source for edge in graph.edges if edge.target == root] + local_nodes = [node for node in graph.nodes if node in edge_ids] - if "node" not in root["data"]: - return build_json(local_nodes[0], nodes, edges) + if "node" not in root.data: + return build_json(local_nodes[0], graph) - final_dict = root["data"]["node"]["template"].copy() + final_dict = root.data["node"]["template"].copy() for key, value in final_dict.items(): if key == "_type": @@ -59,16 +56,16 @@ def build_json(root, nodes, edges): value = {} else: children = [] - for c in local_nodes: - module_types = [c["data"]["type"]] - if "node" in c["data"]: - module_types += c["data"]["node"]["base_classes"] + for local_node in local_nodes: + module_types = [local_node.data["type"]] + if "node" in local_node.data: + module_types += local_node.data["node"]["base_classes"] if module_type in module_types: - children.append(c) + children.append(local_node) if value["required"] and not children: raise ValueError(f"No child with type {module_type} found") - values = [build_json(child, nodes, edges) for child in children] + values = [build_json(child, graph) for child in children] value = list(values) if value["list"] else next(iter(values), None) final_dict[key] = value return final_dict diff --git a/tests/test_loading.py b/tests/test_loading.py index e94fd3e9f..5ee8383de 100644 --- a/tests/test_loading.py +++ b/tests/test_loading.py @@ -1,5 +1,6 @@ import json from langchain import LLMChain, OpenAI +from langflow.utils.graph import Graph import pytest from pathlib import Path from langflow import load_flow_from_json @@ -31,10 +32,11 @@ def test_get_root_node(): data_graph = flow_graph["data"] nodes = data_graph["nodes"] edges = data_graph["edges"] - root = get_root_node(nodes, edges) + graph = Graph(nodes, edges) + root = get_root_node(graph) assert root is not None - assert "id" in root - assert "data" in root + assert hasattr(root, "id") + assert hasattr(root, "data") def test_build_json(): @@ -43,8 +45,9 @@ def test_build_json(): data_graph = flow_graph["data"] nodes = data_graph["nodes"] edges = data_graph["edges"] - root = get_root_node(nodes, edges) - built_json = build_json(root, nodes, edges) + graph = Graph(nodes, edges) + root = get_root_node(graph) + built_json = build_json(root, graph) assert built_json is not None assert isinstance(built_json, dict) @@ -63,9 +66,10 @@ def test_build_json_missing_child(): if isinstance(value, dict) and "required" in value: value["required"] = True - root = get_root_node(nodes, edges) + graph = Graph(nodes, edges) + root = get_root_node(graph) with pytest.raises(ValueError): - build_json(root, nodes, edges) + build_json(root, graph) def test_build_json_no_nodes(): @@ -83,8 +87,9 @@ def test_build_json_invalid_edge(): for edge in edges: edge["source"] = "invalid_id" - root = get_root_node(nodes, edges) - with pytest.raises(ValueError): + with pytest.raises(AttributeError): + graph = Graph(nodes, edges) + root = get_root_node(graph) build_json(root, nodes, edges)