feat: refactor graph vertex sorting (#2583)

* refactor: extract method from class to new func

* test: add new tests

* refactor: simplify funcs to improve readability

* refactor: extract new func from larger func

* refactor: remove recursion from func

* refactor: remove coupling with graph and vertex

* refactor: create adapter funcs to use new code

* refactor: add test for sorting up to vertex N with is_start=True

---------

Co-authored-by: Gabriel Luiz Freitas Almeida <gabriel@langflow.org>
This commit is contained in:
Ítalo Johnny 2024-07-10 11:37:39 -03:00 committed by GitHub
commit aa1958a4ad
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 190 additions and 71 deletions

View file

@ -14,7 +14,7 @@ from langflow.graph.edge.base import ContractEdge
from langflow.graph.graph.constants import lazy_load_vertex_dict
from langflow.graph.graph.runnable_vertices_manager import RunnableVerticesManager
from langflow.graph.graph.state_manager import GraphStateManager
from langflow.graph.graph.utils import find_start_component_id, process_flow
from langflow.graph.graph.utils import find_start_component_id, process_flow, sort_up_to_vertex
from langflow.graph.schema import InterfaceComponentTypes, RunOutputs
from langflow.graph.vertex.base import Vertex, VertexStates
from langflow.graph.vertex.types import InterfaceVertex, StateVertex
@ -1197,74 +1197,6 @@ class Graph:
edges_repr = "\n".join([f"{edge.source_id} --> {edge.target_id}" for edge in self.edges])
return f"Graph:\nNodes: {vertex_ids}\nConnections:\n{edges_repr}"
def sort_up_to_vertex(self, vertex_id: str, is_start: bool = False) -> List[Vertex]:
    """Collect the vertices that belong to the sub-graph around ``vertex_id``.

    The walk always follows predecessor links. What happens with successors
    depends on ``is_start``:

    * ``is_start=False`` (stop vertex): the successors of the target vertex,
      and the successors of those successors, are excluded from the result.
    * ``is_start=True`` (start vertex): those same vertices are visited, and
      every other visited vertex that is not a direct predecessor of the
      target also contributes its successors.

    Args:
        vertex_id: Id of the stop/start vertex. When ``self.get_vertex``
            raises ``ValueError`` for it, the vertex returned by
            ``self.get_root_of_group_node(vertex_id)`` is used instead.
        is_start: Treat ``vertex_id`` as a starting point instead of a
            stopping point.

    Returns:
        The kept vertices, in no particular order (they come out of a set).
    """
    # Only the lookup itself can raise, so nothing else lives in the ``try``.
    try:
        stop_or_start_vertex = self.get_vertex(vertex_id)
    except ValueError:
        stop_or_start_vertex = self.get_root_of_group_node(vertex_id)
        vertex_id = stop_or_start_vertex.id

    visited = set()  # ids that will be returned
    excluded = set()  # ids that must never be visited
    stack = [vertex_id]  # DFS work list
    stop_predecessors = {pre.id for pre in stop_or_start_vertex.predecessors}

    while stack:
        current_id = stack.pop()
        if current_id in visited or current_id in excluded:
            continue
        visited.add(current_id)
        current_vertex = self.get_vertex(current_id)
        stack.extend(predecessor.id for predecessor in current_vertex.predecessors)

        if current_id == vertex_id:
            # Each direct successor of the target plus that successor's own
            # direct successors (two levels deep, no further).
            for successor in current_vertex.successors:
                nearby_ids = [successor.id]
                nearby_ids.extend(child.id for child in (successor.successors or ()))
                if is_start:
                    stack.extend(nearby_ids)
                else:
                    excluded.update(nearby_ids)
        elif is_start and current_id not in stop_predecessors:
            # Starting from the beginning: keep walking forward from every
            # vertex that is not a direct predecessor of the start vertex.
            stack.extend(
                successor.id for successor in current_vertex.successors if successor.id not in visited
            )

    return [self.get_vertex(vid) for vid in visited]
def layered_topological_sort(
self,
vertices: List[Vertex],
@ -1395,6 +1327,21 @@ class Graph:
max_index = max(max_index, index_map[successor.id])
return max_index
def __to_dict(self) -> Dict[str, Dict[str, List[str]]]:
    """Convert the graph into a plain adjacency dictionary.

    Returns:
        A mapping from each vertex id in ``self.vertices`` to a dict with two
        keys: ``"successors"`` (ids of ``self.get_all_successors(vertex)``)
        and ``"predecessors"`` (ids of ``self.get_predecessors(vertex)``).
    """
    return {
        vertex.id: {
            "successors": [successor.id for successor in self.get_all_successors(vertex)],
            "predecessors": [predecessor.id for predecessor in self.get_predecessors(vertex)],
        }
        for vertex in self.vertices
    }
def __filter_vertices(self, vertex_id: str, is_start: bool = False) -> List[Vertex]:
    """Return the vertices selected by ``sort_up_to_vertex`` for ``vertex_id``.

    The graph is first flattened with ``self.__to_dict()`` so the id-based
    ``sort_up_to_vertex`` helper can work on it; the ids it returns are then
    resolved back to vertices through ``self.get_vertex``.

    Args:
        vertex_id: Id of the stop (or start) vertex.
        is_start: Forwarded unchanged to ``sort_up_to_vertex``.
    """
    dictionaryized_graph = self.__to_dict()
    kept_ids = sort_up_to_vertex(dictionaryized_graph, vertex_id, is_start)
    # A distinct loop name keeps the ``vertex_id`` parameter from being shadowed.
    return [self.get_vertex(kept_id) for kept_id in kept_ids]
def sort_vertices(
self,
stop_component_id: Optional[str] = None,
@ -1404,9 +1351,11 @@ class Graph:
self.mark_all_vertices("ACTIVE")
if stop_component_id is not None:
self.stop_vertex = stop_component_id
vertices = self.sort_up_to_vertex(stop_component_id)
vertices = self.__filter_vertices(stop_component_id)
elif start_component_id:
vertices = self.sort_up_to_vertex(start_component_id, is_start=True)
vertices = self.__filter_vertices(start_component_id, is_start=True)
else:
vertices = self.vertices
# without component_id we are probably running in the chat

View file

@ -1,6 +1,8 @@
from typing import List, Dict
import copy
from collections import deque
PRIORITY_LIST_OF_INPUTS = ["webhook", "chat"]
@ -224,3 +226,49 @@ def get_updated_edges(base_flow, g_nodes, g_edges, group_node_id):
if edge["target"] == group_node_id or edge["source"] == group_node_id:
updated_edges.append(new_edge)
return updated_edges
def get_successors(graph: Dict[str, Dict[str, List[str]]], vertex_id: str) -> List[str]:
    """Return ``vertex_id`` together with every vertex reachable downstream of it.

    Args:
        graph: Mapping of vertex id to a dict whose ``"successors"`` entry
            lists the ids of its direct successors.
        vertex_id: Id the traversal starts from; it is the first element of
            the result.

    Returns:
        Each reachable id exactly once, in depth-first discovery order.

    Raises:
        KeyError: If ``vertex_id`` or any id listed as a successor is missing
            from ``graph``.
    """
    successors_result: List[str] = []
    seen = set()
    stack = [vertex_id]
    while stack:
        current_id = stack.pop()
        # Skipping ids already handled keeps the walk finite on cyclic graphs
        # and stops shared sub-graphs from being walked again.
        if current_id in seen:
            continue
        seen.add(current_id)
        successors_result.append(current_id)
        stack.extend(graph[current_id]["successors"])
    return successors_result
def sort_up_to_vertex(graph: Dict[str, Dict[str, List[str]]], vertex_id: str, is_start: bool = False) -> List[str]:
    """Cut the graph around ``vertex_id`` and return the ids that remain.

    Predecessor links are always followed. Successors are handled according
    to ``is_start``:

    * ``is_start=False``: every vertex downstream of ``vertex_id`` is
      excluded, so only the target and what feeds it are kept.
    * ``is_start=True``: the target and every visited vertex that is not one
      of its direct predecessors also contribute everything downstream of
      them.

    Args:
        graph: Mapping of vertex id to a dict with ``"successors"`` and
            ``"predecessors"`` lists of vertex ids.
        vertex_id: Id of the stop (or start) vertex.
        is_start: Treat ``vertex_id`` as a starting point instead of a
            stopping point.

    Returns:
        The kept vertex ids, in no particular order (built from a set).

    Raises:
        ValueError: If ``vertex_id`` is not a key of ``graph``.
    """
    try:
        stop_or_start_vertex = graph[vertex_id]
    except KeyError as exc:
        # Chain the lookup failure so the original cause stays in the traceback.
        raise ValueError(f"Vertex {vertex_id} not found into graph") from exc

    visited = set()
    excluded = set()
    stack = [vertex_id]
    stop_predecessors = set(stop_or_start_vertex["predecessors"])

    while stack:
        current_id = stack.pop()
        if current_id in visited or current_id in excluded:
            continue

        visited.add(current_id)
        current_vertex = graph[current_id]
        stack.extend(current_vertex["predecessors"])

        # Only the target itself, or (in start mode) vertices that are not
        # direct predecessors of the target, touch their successors.
        expands = current_id == vertex_id or (is_start and current_id not in stop_predecessors)
        if not expands:
            continue

        for successor_id in current_vertex["successors"]:
            downstream = [successor_id]
            downstream.extend(get_successors(graph, successor_id))
            if is_start:
                stack.extend(downstream)
            else:
                excluded.update(downstream)

    return list(visited)

View file

@ -0,0 +1,122 @@
import pytest
from langflow.graph.graph import utils
@pytest.fixture
def graph():
    """Adjacency dictionary for a 26-vertex graph ("A" to "Z") shared by the tests.

    Only the successor lists are written out; each predecessor list is
    derived from them so the two directions cannot drift apart.
    """
    successors = {
        "A": ["B"],
        "B": ["D"],
        "C": ["B", "I"],
        "D": ["E", "F"],
        "E": ["G"],
        "F": ["G", "H"],
        "G": [],
        "H": [],
        "I": ["M"],
        "J": ["I", "K"],
        "K": ["Q", "P", "O"],
        "L": ["K"],
        "M": [],
        "N": ["C", "J"],
        "O": ["R"],
        "P": ["U"],
        "Q": ["V"],
        "R": ["S"],
        "S": ["T"],
        "T": [],
        "U": ["W"],
        "V": ["Y"],
        "W": ["X"],
        "X": [],
        "Y": ["Z"],
        "Z": [],
    }
    predecessors = {vertex_id: [] for vertex_id in successors}
    for source_id, target_ids in successors.items():
        for target_id in target_ids:
            predecessors[target_id].append(source_id)
    return {
        vertex_id: {"successors": target_ids, "predecessors": predecessors[vertex_id]}
        for vertex_id, target_ids in successors.items()
    }
def test_get_successors_a(graph):
    """Everything downstream of "A" is reported, "A" itself included."""
    expected = {"A", "B", "D", "E", "F", "H", "G"}
    assert set(utils.get_successors(graph, "A")) == expected
def test_get_successors_z(graph):
    """A vertex with no successors yields only itself."""
    reached = utils.get_successors(graph, "Z")
    assert set(reached) == {"Z"}
def test_sort_up_to_vertex_n_is_start(graph):
    """Starting from "N" with is_start=True keeps every vertex of the graph."""
    kept = utils.sort_up_to_vertex(graph, "N", is_start=True)
    # The result should be all the vertices.
    assert set(kept) == set(graph.keys())
def test_sort_up_to_vertex_z(graph):
    """Stopping at "Z" keeps "Z" and the chain of vertices feeding it."""
    expected = {"L", "N", "J", "K", "Q", "V", "Y", "Z"}
    assert set(utils.sort_up_to_vertex(graph, "Z")) == expected
def test_sort_up_to_vertex_x(graph):
    """Stopping at "X" keeps "X" and the chain of vertices feeding it."""
    expected = {"L", "N", "J", "K", "P", "U", "W", "X"}
    assert set(utils.sort_up_to_vertex(graph, "X")) == expected
def test_sort_up_to_vertex_t(graph):
    """Stopping at "T" keeps "T" and the chain of vertices feeding it."""
    expected = {"L", "N", "J", "K", "O", "R", "S", "T"}
    assert set(utils.sort_up_to_vertex(graph, "T")) == expected
def test_sort_up_to_vertex_m(graph):
    """Stopping at "M" keeps "M" and the vertices feeding it."""
    expected = {"N", "C", "J", "I", "M"}
    assert set(utils.sort_up_to_vertex(graph, "M")) == expected
def test_sort_up_to_vertex_h(graph):
    """Stopping at "H" keeps "H" and the vertices feeding it."""
    expected = {"N", "C", "A", "B", "D", "F", "H"}
    assert set(utils.sort_up_to_vertex(graph, "H")) == expected
def test_sort_up_to_vertex_g(graph):
    """Stopping at "G" keeps both of its incoming branches ("E" and "F")."""
    expected = {"N", "C", "A", "B", "D", "F", "E", "G"}
    assert set(utils.sort_up_to_vertex(graph, "G")) == expected
def test_sort_up_to_vertex_a(graph):
    """Stopping at "A", which has no predecessors, keeps only "A"."""
    kept = utils.sort_up_to_vertex(graph, "A")
    assert set(kept) == {"A"}
def test_sort_up_to_vertex_invalid_vertex(graph):
    """An id that is not in the graph is rejected with ValueError."""
    missing_id = "7"
    with pytest.raises(ValueError):
        utils.sort_up_to_vertex(graph, missing_id)