feat: implementation of Graph objects
This commit is contained in:
parent
ac24b1b655
commit
616d015c5a
4 changed files with 120 additions and 29 deletions
|
|
@ -17,6 +17,7 @@ from langchain.agents.load_tools import (
|
|||
_EXTRA_LLM_TOOLS,
|
||||
_EXTRA_OPTIONAL_TOOLS,
|
||||
)
|
||||
from langflow.utils.graph import Graph
|
||||
|
||||
|
||||
def load_flow_from_json(path: str):
|
||||
|
|
@ -36,8 +37,9 @@ def extract_json(data_graph):
|
|||
nodes = payload.extract_input_variables(nodes)
|
||||
# Nodes, edges and root node
|
||||
edges = data_graph["edges"]
|
||||
root = payload.get_root_node(nodes, edges)
|
||||
return payload.build_json(root, nodes, edges)
|
||||
graph = Graph(nodes, edges)
|
||||
root = payload.get_root_node(graph)
|
||||
return payload.build_json(root, graph)
|
||||
|
||||
|
||||
def replace_zero_shot_prompt_with_prompt_template(nodes):
|
||||
|
|
|
|||
87
src/backend/langflow/utils/graph.py
Normal file
87
src/backend/langflow/utils/graph.py
Normal file
|
|
@ -0,0 +1,87 @@
|
|||
from typing import Dict, List, Union
|
||||
|
||||
|
||||
class Node:
|
||||
def __init__(self, data: Dict[str, Union[str, Dict[str, Union[str, List[str]]]]]):
|
||||
self.id: str = data["id"]
|
||||
self._data = data
|
||||
self.edges: List[Edge] = []
|
||||
self._parse_data()
|
||||
|
||||
def _parse_data(self) -> None:
|
||||
self.data = self._data["data"]
|
||||
|
||||
def add_edge(self, edge: "Edge") -> None:
|
||||
self.edges.append(edge)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"Node(id={self.id}, data={self.data})"
|
||||
|
||||
def __eq__(self, __o: object) -> bool:
|
||||
return self.id == __o.id if isinstance(__o, Node) else False
|
||||
|
||||
def __hash__(self) -> int:
|
||||
return id(self)
|
||||
|
||||
|
||||
class Edge:
|
||||
def __init__(self, source: "Node", target: "Node"):
|
||||
self.source: "Node" = source
|
||||
self.target: "Node" = target
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"Edge(source={self.source.id}, target={self.target.id})"
|
||||
|
||||
|
||||
class Graph:
|
||||
def __init__(
|
||||
self,
|
||||
nodes: List[Dict[str, Union[str, Dict[str, Union[str, List[str]]]]]],
|
||||
edges: List[Dict[str, str]],
|
||||
) -> None:
|
||||
self._nodes = nodes
|
||||
self._edges = edges
|
||||
self._build_graph()
|
||||
|
||||
def _build_graph(self) -> None:
|
||||
self.nodes = self._build_nodes()
|
||||
self.edges = self._build_edges()
|
||||
for edge in self.edges:
|
||||
edge.source.add_edge(edge)
|
||||
edge.target.add_edge(edge)
|
||||
|
||||
def get_node(self, node_id: str) -> Union[None, Node]:
|
||||
return next((node for node in self.nodes if node.id == node_id), None)
|
||||
|
||||
def get_connected_nodes(self, node_id: str) -> List[Node]:
|
||||
connected_nodes: List[Node] = []
|
||||
for edge in self.edges:
|
||||
if edge.source.id == node_id:
|
||||
connected_nodes.append(edge.target)
|
||||
elif edge.target.id == node_id:
|
||||
connected_nodes.append(edge.source)
|
||||
return connected_nodes
|
||||
|
||||
def get_node_neighbors(self, node_id: str) -> Dict[str, int]:
|
||||
neighbors: Dict[str, int] = {}
|
||||
for edge in self.edges:
|
||||
if edge.source.id == node_id:
|
||||
neighbor_id = edge.target.id
|
||||
if neighbor_id not in neighbors:
|
||||
neighbors[neighbor_id] = 0
|
||||
neighbors[neighbor_id] += 1
|
||||
elif edge.target.id == node_id:
|
||||
neighbor_id = edge.source.id
|
||||
if neighbor_id not in neighbors:
|
||||
neighbors[neighbor_id] = 0
|
||||
neighbors[neighbor_id] += 1
|
||||
return neighbors
|
||||
|
||||
def _build_edges(self) -> List[Edge]:
|
||||
return [
|
||||
Edge(self.get_node(edge["source"]), self.get_node(edge["target"]))
|
||||
for edge in self._edges
|
||||
]
|
||||
|
||||
def _build_nodes(self) -> List[Node]:
|
||||
return [Node(node) for node in self._nodes]
|
||||
|
|
@ -27,25 +27,22 @@ def extract_input_variables(nodes):
|
|||
return nodes
|
||||
|
||||
|
||||
def get_root_node(nodes, edges):
|
||||
def get_root_node(graph):
|
||||
"""
|
||||
Returns the root node of the template.
|
||||
"""
|
||||
incoming_edges = {edge["source"] for edge in edges}
|
||||
return next((node for node in nodes if node["id"] not in incoming_edges), None)
|
||||
incoming_edges = {edge.source for edge in graph.edges}
|
||||
return next((node for node in graph.nodes if node not in incoming_edges), None)
|
||||
|
||||
|
||||
def build_json(root, nodes, edges):
|
||||
"""
|
||||
Builds a json from the nodes and edges
|
||||
"""
|
||||
edge_ids = [edge["source"] for edge in edges if edge["target"] == root["id"]]
|
||||
local_nodes = [node for node in nodes if node["id"] in edge_ids]
|
||||
def build_json(root, graph):
|
||||
edge_ids = [edge.source for edge in graph.edges if edge.target == root]
|
||||
local_nodes = [node for node in graph.nodes if node in edge_ids]
|
||||
|
||||
if "node" not in root["data"]:
|
||||
return build_json(local_nodes[0], nodes, edges)
|
||||
if "node" not in root.data:
|
||||
return build_json(local_nodes[0], graph)
|
||||
|
||||
final_dict = root["data"]["node"]["template"].copy()
|
||||
final_dict = root.data["node"]["template"].copy()
|
||||
|
||||
for key, value in final_dict.items():
|
||||
if key == "_type":
|
||||
|
|
@ -59,16 +56,16 @@ def build_json(root, nodes, edges):
|
|||
value = {}
|
||||
else:
|
||||
children = []
|
||||
for c in local_nodes:
|
||||
module_types = [c["data"]["type"]]
|
||||
if "node" in c["data"]:
|
||||
module_types += c["data"]["node"]["base_classes"]
|
||||
for local_node in local_nodes:
|
||||
module_types = [local_node.data["type"]]
|
||||
if "node" in local_node.data:
|
||||
module_types += local_node.data["node"]["base_classes"]
|
||||
if module_type in module_types:
|
||||
children.append(c)
|
||||
children.append(local_node)
|
||||
|
||||
if value["required"] and not children:
|
||||
raise ValueError(f"No child with type {module_type} found")
|
||||
values = [build_json(child, nodes, edges) for child in children]
|
||||
values = [build_json(child, graph) for child in children]
|
||||
value = list(values) if value["list"] else next(iter(values), None)
|
||||
final_dict[key] = value
|
||||
return final_dict
|
||||
|
|
|
|||
|
|
@ -1,5 +1,6 @@
|
|||
import json
|
||||
from langchain import LLMChain, OpenAI
|
||||
from langflow.utils.graph import Graph
|
||||
import pytest
|
||||
from pathlib import Path
|
||||
from langflow import load_flow_from_json
|
||||
|
|
@ -31,10 +32,11 @@ def test_get_root_node():
|
|||
data_graph = flow_graph["data"]
|
||||
nodes = data_graph["nodes"]
|
||||
edges = data_graph["edges"]
|
||||
root = get_root_node(nodes, edges)
|
||||
graph = Graph(nodes, edges)
|
||||
root = get_root_node(graph)
|
||||
assert root is not None
|
||||
assert "id" in root
|
||||
assert "data" in root
|
||||
assert hasattr(root, "id")
|
||||
assert hasattr(root, "data")
|
||||
|
||||
|
||||
def test_build_json():
|
||||
|
|
@ -43,8 +45,9 @@ def test_build_json():
|
|||
data_graph = flow_graph["data"]
|
||||
nodes = data_graph["nodes"]
|
||||
edges = data_graph["edges"]
|
||||
root = get_root_node(nodes, edges)
|
||||
built_json = build_json(root, nodes, edges)
|
||||
graph = Graph(nodes, edges)
|
||||
root = get_root_node(graph)
|
||||
built_json = build_json(root, graph)
|
||||
assert built_json is not None
|
||||
assert isinstance(built_json, dict)
|
||||
|
||||
|
|
@ -63,9 +66,10 @@ def test_build_json_missing_child():
|
|||
if isinstance(value, dict) and "required" in value:
|
||||
value["required"] = True
|
||||
|
||||
root = get_root_node(nodes, edges)
|
||||
graph = Graph(nodes, edges)
|
||||
root = get_root_node(graph)
|
||||
with pytest.raises(ValueError):
|
||||
build_json(root, nodes, edges)
|
||||
build_json(root, graph)
|
||||
|
||||
|
||||
def test_build_json_no_nodes():
|
||||
|
|
@ -83,8 +87,9 @@ def test_build_json_invalid_edge():
|
|||
for edge in edges:
|
||||
edge["source"] = "invalid_id"
|
||||
|
||||
root = get_root_node(nodes, edges)
|
||||
with pytest.raises(ValueError):
|
||||
with pytest.raises(AttributeError):
|
||||
graph = Graph(nodes, edges)
|
||||
root = get_root_node(graph)
|
||||
build_json(root, nodes, edges)
|
||||
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue