refactor: removes most of the circular dependencies in the Graph

This commit is contained in:
Gabriel Luiz Freitas Almeida 2023-11-27 21:50:38 -03:00
commit 3d20c8dc38
3 changed files with 170 additions and 178 deletions

View file

@ -1,7 +1,7 @@
from typing import TYPE_CHECKING, List, Optional
from loguru import logger
from typing import TYPE_CHECKING
from pydantic import BaseModel, Field
from typing import List, Optional
if TYPE_CHECKING:
from langflow.graph.vertex.base import Vertex
@ -22,8 +22,8 @@ class TargetHandle(BaseModel):
class Edge:
def __init__(self, source: "Vertex", target: "Vertex", edge: dict):
self.source: "Vertex" = source
self.target: "Vertex" = target
self.source_id: str = source.id
self.target_id: str = target.id
if data := edge.get("data", {}):
self._source_handle = data.get("sourceHandle", {})
self._target_handle = data.get("targetHandle", {})
@ -31,7 +31,7 @@ class Edge:
self.target_handle: TargetHandle = TargetHandle(**self._target_handle)
self.target_param = self.target_handle.fieldName
# validate handles
self.validate_handles()
self.validate_handles(source, target)
else:
# Logging here because this is a breaking change
logger.error("Edge data is empty")
@ -41,9 +41,9 @@ class Edge:
# target_param is documents
self.target_param = self._target_handle.split("|")[1]
# Validate in __init__ to fail fast
self.validate_edge()
self.validate_edge(source, target)
def validate_handles(self) -> None:
def validate_handles(self, source, target) -> None:
if self.target_handle.inputTypes is None:
self.valid_handles = self.target_handle.type in self.source_handle.baseClasses
else:
@ -54,26 +54,20 @@ class Edge:
if not self.valid_handles:
logger.debug(self.source_handle)
logger.debug(self.target_handle)
raise ValueError(
f"Edge between {self.source.vertex_type} and {self.target.vertex_type} " f"has invalid handles"
)
raise ValueError(f"Edge between {source.vertex_type} and {target.vertex_type} " f"has invalid handles")
def __setstate__(self, state):
self.source = state["source"]
self.target = state["target"]
self.source_id = state["source_id"]
self.target_id = state["target_id"]
self.target_param = state["target_param"]
self.source_handle = state.get("source_handle")
self.target_handle = state.get("target_handle")
def reset(self) -> None:
self.source._build_params()
self.target._build_params()
def validate_edge(self) -> None:
def validate_edge(self, source, target) -> None:
# Validate that the outputs of the source node are valid inputs
# for the target node
self.source_types = self.source.output
self.target_reqs = self.target.required_inputs + self.target.optional_inputs
self.source_types = source.output
self.target_reqs = target.required_inputs + target.optional_inputs
# Both lists contain strings and sometimes a string contains the value we are
# looking for e.g. comgin_out=["Chain"] and target_reqs=["LLMChain"]
# so we need to check if any of the strings in source_types is in target_reqs
@ -88,13 +82,11 @@ class Edge:
if no_matched_type:
logger.debug(self.source_types)
logger.debug(self.target_reqs)
raise ValueError(
f"Edge between {self.source.vertex_type} and {self.target.vertex_type} " f"has no matched type"
)
raise ValueError(f"Edge between {source.vertex_type} and {target.vertex_type} " f"has no matched type")
def __repr__(self) -> str:
return (
f"Edge(source={self.source.id}, target={self.target.id}, target_param={self.target_param}"
f"Edge(source={self.source_id}, target={self.target_id}, target_param={self.target_param}"
f", matched_type={self.matched_type})"
)

View file

@ -13,24 +13,24 @@ from langflow.utils import payload
class Graph:
"""A class representing a graph of nodes and edges."""
"""A class representing a graph of vertices and edges."""
def __init__(
self,
nodes: List[Dict],
edges: List[Dict[str, str]],
) -> None:
self._nodes = nodes
self._vertices = nodes
self._edges = edges
self.raw_graph_data = {"nodes": nodes, "edges": edges}
self.top_level_nodes = []
for node in self._nodes:
if node_id := node.get("id"):
self.top_level_nodes.append(node_id)
self.top_level_vertices = []
for vertex in self._vertices:
if vertex_id := vertex.get("id"):
self.top_level_vertices.append(vertex_id)
self._graph_data = process_flow(self.raw_graph_data)
self._nodes = self._graph_data["nodes"]
self._vertices = self._graph_data["nodes"]
self._edges = self._graph_data["edges"]
self._build_graph()
@ -54,9 +54,9 @@ class Graph:
if "data" in payload:
payload = payload["data"]
try:
nodes = payload["nodes"]
vertices = payload["nodes"]
edges = payload["edges"]
return cls(nodes, edges)
return cls(vertices, edges)
except KeyError as exc:
logger.exception(exc)
raise ValueError(
@ -69,61 +69,69 @@ class Graph:
return self.__repr__() == other.__repr__()
def _build_graph(self) -> None:
"""Builds the graph from the nodes and edges."""
self.nodes = self._build_vertices()
"""Builds the graph from the vertices and edges."""
self.vertices = self._build_vertices()
self.vertex_ids = [vertex.id for vertex in self.vertices]
self.edges = self._build_edges()
for edge in self.edges:
edge.source.add_edge(edge)
edge.target.add_edge(edge)
# This is a hack to make sure that the LLM node is sent to
# the toolkit node
self._build_node_params()
# remove invalid nodes
self._validate_nodes()
# This is a hack to make sure that the LLM vertex is sent to
# the toolkit vertex
self._build_vertex_params()
# remove invalid vertices
self._validate_vertices()
def _build_node_params(self) -> None:
"""Identifies and handles the LLM node within the graph."""
llm_node = None
for node in self.nodes:
node._build_params()
if isinstance(node, LLMVertex):
llm_node = node
def _build_vertex_params(self) -> None:
"""Identifies and handles the LLM vertex within the graph."""
llm_vertex = None
for vertex in self.vertices:
vertex._build_params()
if isinstance(vertex, LLMVertex):
llm_vertex = vertex
if llm_node:
for node in self.nodes:
if isinstance(node, ToolkitVertex):
node.params["llm"] = llm_node
if llm_vertex:
for vertex in self.vertices:
if isinstance(vertex, ToolkitVertex):
vertex.params["llm"] = llm_vertex
def _validate_nodes(self) -> None:
"""Check that all nodes have edges"""
if len(self.nodes) == 1:
def _validate_vertices(self) -> None:
"""Check that all vertices have edges"""
if len(self.vertices) == 1:
return
for node in self.nodes:
if not self._validate_node(node):
raise ValueError(f"{node.vertex_type} is not connected to any other components")
for vertex in self.vertices:
if not self._validate_vertex(vertex):
raise ValueError(f"{vertex.vertex_type} is not connected to any other components")
def _validate_node(self, node: Vertex) -> bool:
"""Validates a node."""
# All nodes that do not have edges are invalid
return len(node.edges) > 0
def _validate_vertex(self, vertex: Vertex) -> bool:
"""Validates a vertex."""
# All vertices that do not have edges are invalid
return len(self.get_vertex_edges(vertex.id)) > 0
def get_node(self, node_id: str) -> Union[None, Vertex]:
"""Returns a node by id."""
return next((node for node in self.nodes if node.id == node_id), None)
def get_vertex(self, vertex_id: str) -> Union[None, Vertex]:
"""Returns a vertex by id."""
return next((vertex for vertex in self.vertices if vertex.id == vertex_id), None)
def get_nodes_with_target(self, node: Vertex) -> List[Vertex]:
"""Returns the nodes connected to a node."""
connected_nodes: List[Vertex] = [edge.source for edge in self.edges if edge.target == node]
return connected_nodes
def get_vertex_edges(self, vertex_id: str) -> List[Edge]:
"""Returns a list of edges for a given vertex."""
return [edge for edge in self.edges if edge.source_id == vertex_id or edge.target_id == vertex_id]
def get_vertices_with_target(self, vertex_id: str) -> List[Vertex]:
"""Returns the vertices connected to a vertex."""
vertices: List[Vertex] = []
for edge in self.edges:
if edge.target_id == vertex_id:
vertex = self.get_vertex(edge.source_id)
if vertex is None:
continue
vertices.append(vertex)
return vertices
async def build(self) -> Chain:
"""Builds the graph."""
# Get root node
root_node = payload.get_root_node(self)
if root_node is None:
raise ValueError("No root node found")
return await root_node.build()
# Get root vertex
root_vertex = payload.get_root_vertex(self)
if root_vertex is None:
raise ValueError("No root vertex found")
return await root_vertex.build()
def topological_sort(self) -> List[Vertex]:
"""
@ -136,25 +144,25 @@ class Graph:
ValueError: If the graph contains a cycle.
"""
# States: 0 = unvisited, 1 = visiting, 2 = visited
state = {node: 0 for node in self.nodes}
state = {vertex: 0 for vertex in self.vertices}
sorted_vertices = []
def dfs(node):
if state[node] == 1:
def dfs(vertex):
if state[vertex] == 1:
# We have a cycle
raise ValueError("Graph contains a cycle, cannot perform topological sort")
if state[node] == 0:
state[node] = 1
for edge in node.edges:
if edge.source == node:
if state[vertex] == 0:
state[vertex] = 1
for edge in vertex.edges:
if edge.source == vertex:
dfs(edge.target)
state[node] = 2
sorted_vertices.append(node)
state[vertex] = 2
sorted_vertices.append(vertex)
# Visit each node
for node in self.nodes:
if state[node] == 0:
dfs(node)
# Visit each vertex
for vertex in self.vertices:
if state[vertex] == 0:
dfs(vertex)
return list(reversed(sorted_vertices))
@ -164,17 +172,21 @@ class Graph:
logger.debug("There are %s vertices in the graph", len(sorted_vertices))
yield from sorted_vertices
def get_node_neighbors(self, node: Vertex) -> Dict[Vertex, int]:
"""Returns the neighbors of a node."""
def get_vertex_neighbors(self, vertex: Vertex) -> Dict[Vertex, int]:
"""Returns the neighbors of a vertex."""
neighbors: Dict[Vertex, int] = {}
for edge in self.edges:
if edge.source == node:
neighbor = edge.target
if edge.source_id == vertex.id:
neighbor = self.get_vertex(edge.target_id)
if neighbor is None:
continue
if neighbor not in neighbors:
neighbors[neighbor] = 0
neighbors[neighbor] += 1
elif edge.target == node:
neighbor = edge.source
elif edge.target_id == vertex.id:
neighbor = self.get_vertex(edge.source_id)
if neighbor is None:
continue
if neighbor not in neighbors:
neighbors[neighbor] = 0
neighbors[neighbor] += 1
@ -182,59 +194,59 @@ class Graph:
def _build_edges(self) -> List[Edge]:
"""Builds the edges of the graph."""
# Edge takes two nodes as arguments, so we need to build the nodes first
# Edge takes two vertices as arguments, so we need to build the vertices first
# and then build the edges
# if we can't find a node, we raise an error
# if we can't find a vertex, we raise an error
edges: List[Edge] = []
for edge in self._edges:
source = self.get_node(edge["source"])
target = self.get_node(edge["target"])
source = self.get_vertex(edge["source"])
target = self.get_vertex(edge["target"])
if source is None:
raise ValueError(f"Source node {edge['source']} not found")
raise ValueError(f"Source vertex {edge['source']} not found")
if target is None:
raise ValueError(f"Target node {edge['target']} not found")
raise ValueError(f"Target vertex {edge['target']} not found")
edges.append(Edge(source, target, edge))
return edges
def _get_vertex_class(self, node_type: str, node_lc_type: str) -> Type[Vertex]:
"""Returns the node class based on the node type."""
if node_type in FILE_TOOLS:
def _get_vertex_class(self, vertex_type: str, vertex_lc_type: str) -> Type[Vertex]:
"""Returns the vertex class based on the vertex type."""
if vertex_type in FILE_TOOLS:
return FileToolVertex
if node_type in lazy_load_vertex_dict.VERTEX_TYPE_MAP:
return lazy_load_vertex_dict.VERTEX_TYPE_MAP[node_type]
if vertex_type in lazy_load_vertex_dict.VERTEX_TYPE_MAP:
return lazy_load_vertex_dict.VERTEX_TYPE_MAP[vertex_type]
return (
lazy_load_vertex_dict.VERTEX_TYPE_MAP[node_lc_type]
if node_lc_type in lazy_load_vertex_dict.VERTEX_TYPE_MAP
lazy_load_vertex_dict.VERTEX_TYPE_MAP[vertex_lc_type]
if vertex_lc_type in lazy_load_vertex_dict.VERTEX_TYPE_MAP
else Vertex
)
def _build_vertices(self) -> List[Vertex]:
"""Builds the vertices of the graph."""
nodes: List[Vertex] = []
for node in self._nodes:
node_data = node["data"]
node_type: str = node_data["type"] # type: ignore
node_lc_type: str = node_data["node"]["template"]["_type"] # type: ignore
vertices: List[Vertex] = []
for vertex in self._vertices:
vertex_data = vertex["data"]
vertex_type: str = vertex_data["type"] # type: ignore
vertex_lc_type: str = vertex_data["node"]["template"]["_type"] # type: ignore
VertexClass = self._get_vertex_class(node_type, node_lc_type)
vertex = VertexClass(node)
vertex.set_top_level(self.top_level_nodes)
nodes.append(vertex)
VertexClass = self._get_vertex_class(vertex_type, vertex_lc_type)
vertex = VertexClass(vertex, graph=self)
vertex.set_top_level(self.top_level_vertices)
vertices.append(vertex)
return nodes
return vertices
def get_children_by_node_type(self, node: Vertex, node_type: str) -> List[Vertex]:
"""Returns the children of a node based on the node type."""
def get_children_by_vertex_type(self, vertex: Vertex, vertex_type: str) -> List[Vertex]:
"""Returns the children of a vertex based on the vertex type."""
children = []
node_types = [node.data["type"]]
if "node" in node.data:
node_types += node.data["node"]["base_classes"]
if node_type in node_types:
children.append(node)
vertex_types = [vertex.data["type"]]
if "node" in vertex.data:
vertex_types += vertex.data["node"]["base_classes"]
if vertex_type in vertex_types:
children.append(vertex)
return children
def __repr__(self):
node_ids = [node.id for node in self.nodes]
edges_repr = "\n".join([f"{edge.source.id} --> {edge.target.id}" for edge in self.edges])
return f"Graph:\nNodes: {node_ids}\nConnections:\n{edges_repr}"
vertex_ids = [vertex.id for vertex in self.vertices]
edges_repr = "\n".join([f"{edge.source_id} --> {edge.target_id}" for edge in self.edges])
return f"Graph:\nNodes: {vertex_ids}\nConnections:\n{edges_repr}"

View file

@ -1,33 +1,32 @@
import ast
import inspect
import pickle
import types
from typing import TYPE_CHECKING, Any, Dict, List, Optional
from loguru import logger
from langflow.graph.utils import UnbuiltObject
from langflow.graph.vertex.utils import is_basic_type
from langflow.interface.initialize import loading
from langflow.interface.listing import lazy_load_dict
from langflow.utils.constants import DIRECT_TYPES
from langflow.utils.util import sync_to_async
from loguru import logger
if TYPE_CHECKING:
from langflow.graph.edge.base import Edge
from langflow.graph.graph.base import Graph
class Vertex:
def __init__(
self,
data: Dict,
graph: "Graph",
base_type: Optional[str] = None,
is_task: bool = False,
params: Optional[Dict] = None,
) -> None:
self.graph = graph
self.id: str = data["id"]
self._data = data
self.edges: List["Edge"] = []
self.base_type: Optional[str] = base_type
self._parse_data()
self._built_object = UnbuiltObject()
@ -39,43 +38,28 @@ class Vertex:
self.parent_node_id: Optional[str] = self._data.get("parent_node_id")
self.parent_is_top_level = False
def reset_params(self):
for edge in self.edges:
if edge.source != self:
target_param = edge.target_param
if target_param in ["document", "texts"]:
# this means they got data and have already ingested it
# so we continue after removing the param
self.params.pop(target_param, None)
continue
if target_param in self.params and not is_basic_type(self.params[target_param]):
# edge.source.params = {}
edge.source._build_params()
edge.source._built_object = UnbuiltObject()
edge.source._built = False
self.params[target_param] = edge.source
@property
def edges(self) -> List["Edge"]:
return self.graph.get_vertex_edges(self.id)
def __getstate__(self):
state_dict = self.__dict__.copy()
try:
# try pickling the built object
# if it fails, then we need to delete it
# and build it again
pickle.dumps(state_dict["_built_object"])
except Exception:
self.reset_params()
del state_dict["_built_object"]
del state_dict["_built"]
return state_dict
return {
"_data": self._data,
"params": {},
"base_type": self.base_type,
"is_task": self.is_task,
"id": self.id,
"_built_object": UnbuiltObject(),
"_built": False,
"parent_node_id": self.parent_node_id,
"parent_is_top_level": self.parent_is_top_level,
}
def __setstate__(self, state):
self._data = state["_data"]
self.params = state["params"]
self.base_type = state["base_type"]
self.is_task = state["is_task"]
self.edges = state["edges"]
self.id = state["id"]
self._parse_data()
if "_built_object" in state:
@ -144,6 +128,10 @@ class Vertex:
# and use that as the value for the param
# If the type is "str", then we need to get the value of the "value" key
# and use that as the value for the param
if self.graph is None:
raise ValueError("Graph not found")
template_dict = {key: value for key, value in self.data["node"]["template"].items() if isinstance(value, dict)}
params = self.params.copy() if self.params else {}
@ -155,9 +143,9 @@ class Vertex:
if template_dict[param_key]["list"]:
if param_key not in params:
params[param_key] = []
params[param_key].append(edge.source)
elif edge.target.id == self.id:
params[param_key] = edge.source
params[param_key].append(self.graph.get_vertex(edge.source_id))
elif edge.target_id == self.id:
params[param_key] = self.graph.get_vertex(edge.source_id)
for key, value in template_dict.items():
if key in params:
@ -177,33 +165,33 @@ class Vertex:
else:
raise ValueError(f"File path not found for {self.vertex_type}")
elif value.get("type") in DIRECT_TYPES and params.get(key) is None:
val = value.get("value")
if value.get("type") == "code":
try:
params[key] = ast.literal_eval(value.get("value"))
params[key] = ast.literal_eval(val) if val else None
except Exception as exc:
logger.debug(f"Error parsing code: {exc}")
params[key] = value.get("value")
params[key] = val
elif value.get("type") in ["dict", "NestedDict"]:
# When dict comes from the frontend it comes as a
# list of dicts, so we need to convert it to a dict
# before passing it to the build method
_value = value.get("value")
if isinstance(_value, list):
if isinstance(val, list):
params[key] = {k: v for item in value.get("value", []) for k, v in item.items()}
elif isinstance(_value, dict):
params[key] = _value
elif value.get("type") == "int" and value.get("value") is not None:
elif isinstance(val, dict):
params[key] = val
elif value.get("type") == "int" and val is not None:
try:
params[key] = int(value.get("value"))
params[key] = int(val)
except ValueError:
params[key] = value.get("value")
elif value.get("type") == "float" and value.get("value") is not None:
params[key] = val
elif value.get("type") == "float" and val is not None:
try:
params[key] = float(value.get("value"))
params[key] = float(val)
except ValueError:
params[key] = value.get("value")
params[key] = val
else:
params[key] = value.get("value")
params[key] = val
if not value.get("required") and params.get(key) is None:
if value.get("default"):
@ -266,7 +254,7 @@ class Vertex:
pass
# If there's no task_id, build the vertex locally
await self.build(user_id)
await self.build(user_id=user_id)
return self._built_object
async def _build_node_and_update_params(self, key, node, user_id=None):