feat: implementation of Graph objects

This commit is contained in:
Gabriel Almeida 2023-03-24 10:49:19 -03:00
commit 616d015c5a
4 changed files with 120 additions and 29 deletions

View file

@ -17,6 +17,7 @@ from langchain.agents.load_tools import (
_EXTRA_LLM_TOOLS,
_EXTRA_OPTIONAL_TOOLS,
)
from langflow.utils.graph import Graph
def load_flow_from_json(path: str):
@ -36,8 +37,9 @@ def extract_json(data_graph):
nodes = payload.extract_input_variables(nodes)
# Nodes, edges and root node
edges = data_graph["edges"]
root = payload.get_root_node(nodes, edges)
return payload.build_json(root, nodes, edges)
graph = Graph(nodes, edges)
root = payload.get_root_node(graph)
return payload.build_json(root, graph)
def replace_zero_shot_prompt_with_prompt_template(nodes):

View file

@ -0,0 +1,87 @@
from typing import Dict, List, Union
class Node:
def __init__(self, data: Dict[str, Union[str, Dict[str, Union[str, List[str]]]]]):
self.id: str = data["id"]
self._data = data
self.edges: List[Edge] = []
self._parse_data()
def _parse_data(self) -> None:
self.data = self._data["data"]
def add_edge(self, edge: "Edge") -> None:
self.edges.append(edge)
def __repr__(self) -> str:
return f"Node(id={self.id}, data={self.data})"
def __eq__(self, __o: object) -> bool:
return self.id == __o.id if isinstance(__o, Node) else False
def __hash__(self) -> int:
return id(self)
class Edge:
def __init__(self, source: "Node", target: "Node"):
self.source: "Node" = source
self.target: "Node" = target
def __repr__(self) -> str:
return f"Edge(source={self.source.id}, target={self.target.id})"
class Graph:
def __init__(
self,
nodes: List[Dict[str, Union[str, Dict[str, Union[str, List[str]]]]]],
edges: List[Dict[str, str]],
) -> None:
self._nodes = nodes
self._edges = edges
self._build_graph()
def _build_graph(self) -> None:
self.nodes = self._build_nodes()
self.edges = self._build_edges()
for edge in self.edges:
edge.source.add_edge(edge)
edge.target.add_edge(edge)
def get_node(self, node_id: str) -> Union[None, Node]:
return next((node for node in self.nodes if node.id == node_id), None)
def get_connected_nodes(self, node_id: str) -> List[Node]:
connected_nodes: List[Node] = []
for edge in self.edges:
if edge.source.id == node_id:
connected_nodes.append(edge.target)
elif edge.target.id == node_id:
connected_nodes.append(edge.source)
return connected_nodes
def get_node_neighbors(self, node_id: str) -> Dict[str, int]:
neighbors: Dict[str, int] = {}
for edge in self.edges:
if edge.source.id == node_id:
neighbor_id = edge.target.id
if neighbor_id not in neighbors:
neighbors[neighbor_id] = 0
neighbors[neighbor_id] += 1
elif edge.target.id == node_id:
neighbor_id = edge.source.id
if neighbor_id not in neighbors:
neighbors[neighbor_id] = 0
neighbors[neighbor_id] += 1
return neighbors
def _build_edges(self) -> List[Edge]:
return [
Edge(self.get_node(edge["source"]), self.get_node(edge["target"]))
for edge in self._edges
]
def _build_nodes(self) -> List[Node]:
return [Node(node) for node in self._nodes]

View file

@ -27,25 +27,22 @@ def extract_input_variables(nodes):
return nodes
def get_root_node(nodes, edges):
def get_root_node(graph):
"""
Returns the root node of the template.
"""
incoming_edges = {edge["source"] for edge in edges}
return next((node for node in nodes if node["id"] not in incoming_edges), None)
incoming_edges = {edge.source for edge in graph.edges}
return next((node for node in graph.nodes if node not in incoming_edges), None)
def build_json(root, nodes, edges):
"""
Builds a json from the nodes and edges
"""
edge_ids = [edge["source"] for edge in edges if edge["target"] == root["id"]]
local_nodes = [node for node in nodes if node["id"] in edge_ids]
def build_json(root, graph):
edge_ids = [edge.source for edge in graph.edges if edge.target == root]
local_nodes = [node for node in graph.nodes if node in edge_ids]
if "node" not in root["data"]:
return build_json(local_nodes[0], nodes, edges)
if "node" not in root.data:
return build_json(local_nodes[0], graph)
final_dict = root["data"]["node"]["template"].copy()
final_dict = root.data["node"]["template"].copy()
for key, value in final_dict.items():
if key == "_type":
@ -59,16 +56,16 @@ def build_json(root, nodes, edges):
value = {}
else:
children = []
for c in local_nodes:
module_types = [c["data"]["type"]]
if "node" in c["data"]:
module_types += c["data"]["node"]["base_classes"]
for local_node in local_nodes:
module_types = [local_node.data["type"]]
if "node" in local_node.data:
module_types += local_node.data["node"]["base_classes"]
if module_type in module_types:
children.append(c)
children.append(local_node)
if value["required"] and not children:
raise ValueError(f"No child with type {module_type} found")
values = [build_json(child, nodes, edges) for child in children]
values = [build_json(child, graph) for child in children]
value = list(values) if value["list"] else next(iter(values), None)
final_dict[key] = value
return final_dict

View file

@ -1,5 +1,6 @@
import json
from langchain import LLMChain, OpenAI
from langflow.utils.graph import Graph
import pytest
from pathlib import Path
from langflow import load_flow_from_json
@ -31,10 +32,11 @@ def test_get_root_node():
data_graph = flow_graph["data"]
nodes = data_graph["nodes"]
edges = data_graph["edges"]
root = get_root_node(nodes, edges)
graph = Graph(nodes, edges)
root = get_root_node(graph)
assert root is not None
assert "id" in root
assert "data" in root
assert hasattr(root, "id")
assert hasattr(root, "data")
def test_build_json():
@ -43,8 +45,9 @@ def test_build_json():
data_graph = flow_graph["data"]
nodes = data_graph["nodes"]
edges = data_graph["edges"]
root = get_root_node(nodes, edges)
built_json = build_json(root, nodes, edges)
graph = Graph(nodes, edges)
root = get_root_node(graph)
built_json = build_json(root, graph)
assert built_json is not None
assert isinstance(built_json, dict)
@ -63,9 +66,10 @@ def test_build_json_missing_child():
if isinstance(value, dict) and "required" in value:
value["required"] = True
root = get_root_node(nodes, edges)
graph = Graph(nodes, edges)
root = get_root_node(graph)
with pytest.raises(ValueError):
build_json(root, nodes, edges)
build_json(root, graph)
def test_build_json_no_nodes():
@ -83,8 +87,9 @@ def test_build_json_invalid_edge():
for edge in edges:
edge["source"] = "invalid_id"
root = get_root_node(nodes, edges)
with pytest.raises(ValueError):
with pytest.raises(AttributeError):
graph = Graph(nodes, edges)
root = get_root_node(graph)
build_json(root, nodes, edges)