Refactor graph class and add layer refinement algorithm

This commit is contained in:
Gabriel Luiz Freitas Almeida 2024-02-16 13:03:35 -03:00
commit b7427d08d3

View file

@ -6,7 +6,12 @@ from langflow.graph.edge.base import ContractEdge
from langflow.graph.graph.constants import lazy_load_vertex_dict
from langflow.graph.graph.utils import process_flow
from langflow.graph.vertex.base import Vertex
from langflow.graph.vertex.types import ChatVertex, FileToolVertex, LLMVertex, ToolkitVertex
from langflow.graph.vertex.types import (
ChatVertex,
FileToolVertex,
LLMVertex,
ToolkitVertex,
)
from langflow.interface.tools.constants import FILE_TOOLS
from langflow.utils import payload
from loguru import logger
@ -127,7 +132,9 @@ class Graph:
return
for vertex in self.vertices:
if not self._validate_vertex(vertex):
raise ValueError(f"{vertex.vertex_type} is not connected to any other components")
raise ValueError(
f"{vertex.vertex_type} is not connected to any other components"
)
def _validate_vertex(self, vertex: Vertex) -> bool:
"""Validates a vertex."""
@ -140,7 +147,11 @@ class Graph:
def get_vertex_edges(self, vertex_id: str) -> List[ContractEdge]:
"""Returns a list of edges for a given vertex."""
return [edge for edge in self.edges if edge.source_id == vertex_id or edge.target_id == vertex_id]
return [
edge
for edge in self.edges
if edge.source_id == vertex_id or edge.target_id == vertex_id
]
def get_vertices_with_target(self, vertex_id: str) -> List[Vertex]:
"""Returns the vertices connected to a vertex."""
@ -178,7 +189,9 @@ class Graph:
def dfs(vertex):
if state[vertex] == 1:
# We have a cycle
raise ValueError("Graph contains a cycle, cannot perform topological sort")
raise ValueError(
"Graph contains a cycle, cannot perform topological sort"
)
if state[vertex] == 0:
state[vertex] = 1
for edge in vertex.edges:
@ -237,7 +250,9 @@ class Graph:
edges.append(ContractEdge(source, target, edge))
return edges
def _get_vertex_class(self, node_type: str, node_base_type: str, node_id: str) -> Type[Vertex]:
def _get_vertex_class(
self, node_type: str, node_base_type: str, node_id: str
) -> Type[Vertex]:
"""Returns the node class based on the node type."""
# First we check for the node_base_type
node_name = node_id.split("-")[0]
@ -267,14 +282,18 @@ class Graph:
vertex_type: str = vertex_data["type"] # type: ignore
vertex_base_type: str = vertex_data["node"]["template"]["_type"] # type: ignore
VertexClass = self._get_vertex_class(vertex_type, vertex_base_type, vertex_data["id"])
VertexClass = self._get_vertex_class(
vertex_type, vertex_base_type, vertex_data["id"]
)
vertex_instance = VertexClass(vertex, graph=self)
vertex_instance.set_top_level(self.top_level_vertices)
vertices.append(vertex_instance)
return vertices
def get_children_by_vertex_type(self, vertex: Vertex, vertex_type: str) -> List[Vertex]:
def get_children_by_vertex_type(
self, vertex: Vertex, vertex_type: str
) -> List[Vertex]:
"""Returns the children of a vertex based on the vertex type."""
children = []
vertex_types = [vertex.data["type"]]
@ -286,7 +305,9 @@ class Graph:
def __repr__(self):
vertex_ids = [vertex.id for vertex in self.vertices]
edges_repr = "\n".join([f"{edge.source_id} --> {edge.target_id}" for edge in self.edges])
edges_repr = "\n".join(
[f"{edge.source_id} --> {edge.target_id}" for edge in self.edges]
)
return f"Graph:\nNodes: {vertex_ids}\nConnections:\n{edges_repr}"
def layered_topological_sort(self):
@ -299,7 +320,9 @@ class Graph:
in_degree[edge.target_id] += 1
# Queue for vertices with no incoming edges
queue = deque(vertex.id for vertex in self.vertices if in_degree[vertex.id] == 0)
queue = deque(
vertex.id for vertex in self.vertices if in_degree[vertex.id] == 0
)
layers = []
current_layer = 0
@ -314,9 +337,40 @@ class Graph:
if in_degree[neighbor] == 0:
queue.append(neighbor)
current_layer += 1 # Next layer
new_layers = self.refine_layers(graph, layers)
return new_layers
return layers
return layers
return layers
return layers
return layers
def refine_layers(self, graph, initial_layers):
    """Push vertices into later layers based on the layers of their ``graph`` entries.

    Each vertex that has an entry in ``graph`` gets a candidate layer equal
    to ``max(layer of each listed vertex) - 1``. A vertex is moved only when
    that candidate is strictly later than its current layer; otherwise it
    stays where it is. Layers left empty by the moves are dropped.

    Args:
        graph: Mapping of vertex id -> iterable of related vertex ids.
            NOTE(review): the caller appears to pass an adjacency mapping
            built during the topological sort; if the values are
            *successors*, placing a vertex just before its latest successor
            can put it after an earlier one - confirm the intended rule
            (max vs. min) against the caller.
        initial_layers: List of layers, each a list of vertex ids.

    Returns:
        A new list of non-empty layers. ``initial_layers`` is not mutated.
    """
    # Map each vertex to its current layer.
    vertex_to_layer = {
        vertex: layer_index
        for layer_index, layer in enumerate(initial_layers)
        for vertex in layer
    }

    # Candidate layer per vertex. The maximum is taken over ALL listed
    # vertices first and the offset is applied exactly once, so the result
    # does not depend on the order in which the ids are listed.
    new_layer_index_map = {}
    for vertex_id, deps in graph.items():
        # Ids that were never assigned a layer cannot constrain placement;
        # skip them instead of failing with a bare KeyError.
        dep_layers = [vertex_to_layer[dep] for dep in deps if dep in vertex_to_layer]
        if dep_layers:
            new_layer_index_map[vertex_id] = max(dep_layers) - 1

    refined_layers = [[] for _ in initial_layers]  # Start with empty layers
    for layer_index, layer in enumerate(initial_layers):
        for vertex_id in layer:
            # Vertices without a candidate default to 0, i.e. they never move.
            candidate = new_layer_index_map.get(vertex_id, 0)
            # Only ever move a vertex to a later layer, never an earlier one.
            refined_layers[max(layer_index, candidate)].append(vertex_id)

    # Remove empty layers if any
    return [layer for layer in refined_layers if layer]