feat: refactor graph vertex sorting (#2583)
* refactor: extract method from class to new func * test: add new tests * refactor: simplify funcs to improve readability * refactor: extract new func from larger func * refactor: remove recursion from func * refactor: remove coupling with graph and vertex * refactor: create adapter funcs to use new code * refactor: add test for sorting up to vertex N with is_start=True --------- Co-authored-by: Gabriel Luiz Freitas Almeida <gabriel@langflow.org>
This commit is contained in:
parent
3406575c67
commit
aa1958a4ad
3 changed files with 190 additions and 71 deletions
|
|
@ -14,7 +14,7 @@ from langflow.graph.edge.base import ContractEdge
|
|||
from langflow.graph.graph.constants import lazy_load_vertex_dict
|
||||
from langflow.graph.graph.runnable_vertices_manager import RunnableVerticesManager
|
||||
from langflow.graph.graph.state_manager import GraphStateManager
|
||||
from langflow.graph.graph.utils import find_start_component_id, process_flow
|
||||
from langflow.graph.graph.utils import find_start_component_id, process_flow, sort_up_to_vertex
|
||||
from langflow.graph.schema import InterfaceComponentTypes, RunOutputs
|
||||
from langflow.graph.vertex.base import Vertex, VertexStates
|
||||
from langflow.graph.vertex.types import InterfaceVertex, StateVertex
|
||||
|
|
@ -1197,74 +1197,6 @@ class Graph:
|
|||
edges_repr = "\n".join([f"{edge.source_id} --> {edge.target_id}" for edge in self.edges])
|
||||
return f"Graph:\nNodes: {vertex_ids}\nConnections:\n{edges_repr}"
|
||||
|
||||
def sort_up_to_vertex(self, vertex_id: str, is_start: bool = False) -> List[Vertex]:
|
||||
"""Cuts the graph up to a given vertex and sorts the resulting subgraph."""
|
||||
# Initial setup
|
||||
visited = set() # To keep track of visited vertices
|
||||
excluded = set() # To keep track of vertices that should be excluded
|
||||
|
||||
def get_successors(vertex, recursive=True):
|
||||
# Recursively get the successors of the current vertex
|
||||
successors = vertex.successors
|
||||
if not successors:
|
||||
return []
|
||||
successors_result = []
|
||||
for successor in successors:
|
||||
# Just return a list of successors
|
||||
if recursive:
|
||||
next_successors = get_successors(successor)
|
||||
successors_result.extend(next_successors)
|
||||
successors_result.append(successor)
|
||||
return successors_result
|
||||
|
||||
try:
|
||||
stop_or_start_vertex = self.get_vertex(vertex_id)
|
||||
stack = [vertex_id] # Use a list as a stack for DFS
|
||||
except ValueError:
|
||||
stop_or_start_vertex = self.get_root_of_group_node(vertex_id)
|
||||
stack = [stop_or_start_vertex.id]
|
||||
vertex_id = stop_or_start_vertex.id
|
||||
stop_predecessors = [pre.id for pre in stop_or_start_vertex.predecessors]
|
||||
# DFS to collect all vertices that can reach the specified vertex
|
||||
while stack:
|
||||
current_id = stack.pop()
|
||||
if current_id not in visited and current_id not in excluded:
|
||||
visited.add(current_id)
|
||||
current_vertex = self.get_vertex(current_id)
|
||||
# Assuming get_predecessors is a method that returns all vertices with edges to current_vertex
|
||||
for predecessor in current_vertex.predecessors:
|
||||
stack.append(predecessor.id)
|
||||
|
||||
if current_id == vertex_id:
|
||||
# We should add to visited all the vertices that are successors of the current vertex
|
||||
# and their successors and so on
|
||||
# if the vertex is a start, it means we are starting from the beginning
|
||||
# and getting successors
|
||||
for successor in current_vertex.successors:
|
||||
if is_start:
|
||||
stack.append(successor.id)
|
||||
else:
|
||||
excluded.add(successor.id)
|
||||
all_successors = get_successors(successor, recursive=False)
|
||||
for successor in all_successors:
|
||||
if is_start:
|
||||
stack.append(successor.id)
|
||||
else:
|
||||
excluded.add(successor.id)
|
||||
elif current_id not in stop_predecessors and is_start:
|
||||
# If the current vertex is not the target vertex, we should add all its successors
|
||||
# to the stack if they are not in visited
|
||||
|
||||
# If we are starting from the beginning, we should add all successors
|
||||
for successor in current_vertex.successors:
|
||||
if successor.id not in visited:
|
||||
stack.append(successor.id)
|
||||
|
||||
# Filter the original graph's vertices and edges to keep only those in `visited`
|
||||
vertices_to_keep = [self.get_vertex(vid) for vid in visited]
|
||||
|
||||
return vertices_to_keep
|
||||
|
||||
def layered_topological_sort(
|
||||
self,
|
||||
vertices: List[Vertex],
|
||||
|
|
@ -1395,6 +1327,21 @@ class Graph:
|
|||
max_index = max(max_index, index_map[successor.id])
|
||||
return max_index
|
||||
|
||||
def __to_dict(self) -> Dict[str, Dict[str, List[str]]]:
|
||||
"""Converts the graph to a dictionary."""
|
||||
result: Dict = dict()
|
||||
for vertex in self.vertices:
|
||||
vertex_id = vertex.id
|
||||
sucessors = [i.id for i in self.get_all_successors(vertex)]
|
||||
predecessors = [i.id for i in self.get_predecessors(vertex)]
|
||||
result |= {vertex_id: {"successors": sucessors, "predecessors": predecessors}}
|
||||
return result
|
||||
|
||||
def __filter_vertices(self, vertex_id: str, is_start: bool = False):
|
||||
dictionaryized_graph = self.__to_dict()
|
||||
vertex_ids = sort_up_to_vertex(dictionaryized_graph, vertex_id, is_start)
|
||||
return [self.get_vertex(vertex_id) for vertex_id in vertex_ids]
|
||||
|
||||
def sort_vertices(
|
||||
self,
|
||||
stop_component_id: Optional[str] = None,
|
||||
|
|
@ -1404,9 +1351,11 @@ class Graph:
|
|||
self.mark_all_vertices("ACTIVE")
|
||||
if stop_component_id is not None:
|
||||
self.stop_vertex = stop_component_id
|
||||
vertices = self.sort_up_to_vertex(stop_component_id)
|
||||
vertices = self.__filter_vertices(stop_component_id)
|
||||
|
||||
elif start_component_id:
|
||||
vertices = self.sort_up_to_vertex(start_component_id, is_start=True)
|
||||
vertices = self.__filter_vertices(start_component_id, is_start=True)
|
||||
|
||||
else:
|
||||
vertices = self.vertices
|
||||
# without component_id we are probably running in the chat
|
||||
|
|
|
|||
|
|
@ -1,6 +1,8 @@
|
|||
from typing import List, Dict
|
||||
import copy
|
||||
from collections import deque
|
||||
|
||||
|
||||
PRIORITY_LIST_OF_INPUTS = ["webhook", "chat"]
|
||||
|
||||
|
||||
|
|
@ -224,3 +226,49 @@ def get_updated_edges(base_flow, g_nodes, g_edges, group_node_id):
|
|||
if edge["target"] == group_node_id or edge["source"] == group_node_id:
|
||||
updated_edges.append(new_edge)
|
||||
return updated_edges
|
||||
|
||||
|
||||
def get_successors(graph: Dict[str, Dict[str, List[str]]], vertex_id: str) -> List[str]:
|
||||
successors_result = []
|
||||
stack = [vertex_id]
|
||||
while stack:
|
||||
current_id = stack.pop()
|
||||
successors_result.append(current_id)
|
||||
stack.extend(graph[current_id]["successors"])
|
||||
return successors_result
|
||||
|
||||
|
||||
def sort_up_to_vertex(graph: Dict[str, Dict[str, List[str]]], vertex_id: str, is_start: bool = False) -> List[str]:
|
||||
"""Cuts the graph up to a given vertex and sorts the resulting subgraph."""
|
||||
try:
|
||||
stop_or_start_vertex = graph[vertex_id]
|
||||
except KeyError:
|
||||
raise ValueError(f"Vertex {vertex_id} not found into graph")
|
||||
|
||||
visited, excluded = set(), set()
|
||||
stack = [vertex_id]
|
||||
stop_predecessors = set(stop_or_start_vertex["predecessors"])
|
||||
|
||||
while stack:
|
||||
current_id = stack.pop()
|
||||
if current_id in visited or current_id in excluded:
|
||||
continue
|
||||
|
||||
visited.add(current_id)
|
||||
current_vertex = graph[current_id]
|
||||
|
||||
stack.extend(current_vertex["predecessors"])
|
||||
|
||||
if current_id == vertex_id or (current_id not in stop_predecessors and is_start):
|
||||
for successor_id in current_vertex["successors"]:
|
||||
if is_start:
|
||||
stack.append(successor_id)
|
||||
else:
|
||||
excluded.add(successor_id)
|
||||
for succ_id in get_successors(graph, successor_id):
|
||||
if is_start:
|
||||
stack.append(succ_id)
|
||||
else:
|
||||
excluded.add(succ_id)
|
||||
|
||||
return list(visited)
|
||||
|
|
|
|||
122
tests/unit/graph/graph/test_utils.py
Normal file
122
tests/unit/graph/graph/test_utils.py
Normal file
|
|
@ -0,0 +1,122 @@
|
|||
import pytest
|
||||
|
||||
from langflow.graph.graph import utils
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def graph():
|
||||
return {
|
||||
"A": {"successors": ["B"], "predecessors": []},
|
||||
"B": {"successors": ["D"], "predecessors": ["A", "C"]},
|
||||
"C": {"successors": ["B", "I"], "predecessors": ["N"]},
|
||||
"D": {"successors": ["E", "F"], "predecessors": ["B"]},
|
||||
"E": {"successors": ["G"], "predecessors": ["D"]},
|
||||
"F": {"successors": ["G", "H"], "predecessors": ["D"]},
|
||||
"G": {"successors": [], "predecessors": ["E", "F"]},
|
||||
"H": {"successors": [], "predecessors": ["F"]},
|
||||
"I": {"successors": ["M"], "predecessors": ["C", "J"]},
|
||||
"J": {"successors": ["I", "K"], "predecessors": ["N"]},
|
||||
"K": {"successors": ["Q", "P", "O"], "predecessors": ["J", "L"]},
|
||||
"L": {"successors": ["K"], "predecessors": []},
|
||||
"M": {"successors": [], "predecessors": ["I"]},
|
||||
"N": {"successors": ["C", "J"], "predecessors": []},
|
||||
"O": {"successors": ["R"], "predecessors": ["K"]},
|
||||
"P": {"successors": ["U"], "predecessors": ["K"]},
|
||||
"Q": {"successors": ["V"], "predecessors": ["K"]},
|
||||
"R": {"successors": ["S"], "predecessors": ["O"]},
|
||||
"S": {"successors": ["T"], "predecessors": ["R"]},
|
||||
"T": {"successors": [], "predecessors": ["S"]},
|
||||
"U": {"successors": ["W"], "predecessors": ["P"]},
|
||||
"V": {"successors": ["Y"], "predecessors": ["Q"]},
|
||||
"W": {"successors": ["X"], "predecessors": ["U"]},
|
||||
"X": {"successors": [], "predecessors": ["W"]},
|
||||
"Y": {"successors": ["Z"], "predecessors": ["V"]},
|
||||
"Z": {"successors": [], "predecessors": ["Y"]},
|
||||
}
|
||||
|
||||
|
||||
def test_get_successors_a(graph):
|
||||
vertex_id = "A"
|
||||
|
||||
result = utils.get_successors(graph, vertex_id)
|
||||
|
||||
assert set(result) == {"A", "B", "D", "E", "F", "H", "G"}
|
||||
|
||||
|
||||
def test_get_successors_z(graph):
|
||||
vertex_id = "Z"
|
||||
|
||||
result = utils.get_successors(graph, vertex_id)
|
||||
|
||||
assert set(result) == {"Z"}
|
||||
|
||||
|
||||
def test_sort_up_to_vertex_n_is_start(graph):
|
||||
vertex_id = "N"
|
||||
|
||||
result = utils.sort_up_to_vertex(graph, vertex_id, is_start=True)
|
||||
# Result shoud be all the vertices
|
||||
assert set(result) == set(graph.keys())
|
||||
|
||||
|
||||
def test_sort_up_to_vertex_z(graph):
|
||||
vertex_id = "Z"
|
||||
|
||||
result = utils.sort_up_to_vertex(graph, vertex_id)
|
||||
|
||||
assert set(result) == {"L", "N", "J", "K", "Q", "V", "Y", "Z"}
|
||||
|
||||
|
||||
def test_sort_up_to_vertex_x(graph):
|
||||
vertex_id = "X"
|
||||
|
||||
result = utils.sort_up_to_vertex(graph, vertex_id)
|
||||
|
||||
assert set(result) == {"L", "N", "J", "K", "P", "U", "W", "X"}
|
||||
|
||||
|
||||
def test_sort_up_to_vertex_t(graph):
|
||||
vertex_id = "T"
|
||||
|
||||
result = utils.sort_up_to_vertex(graph, vertex_id)
|
||||
|
||||
assert set(result) == {"L", "N", "J", "K", "O", "R", "S", "T"}
|
||||
|
||||
|
||||
def test_sort_up_to_vertex_m(graph):
|
||||
vertex_id = "M"
|
||||
|
||||
result = utils.sort_up_to_vertex(graph, vertex_id)
|
||||
|
||||
assert set(result) == {"N", "C", "J", "I", "M"}
|
||||
|
||||
|
||||
def test_sort_up_to_vertex_h(graph):
|
||||
vertex_id = "H"
|
||||
|
||||
result = utils.sort_up_to_vertex(graph, vertex_id)
|
||||
|
||||
assert set(result) == {"N", "C", "A", "B", "D", "F", "H"}
|
||||
|
||||
|
||||
def test_sort_up_to_vertex_g(graph):
|
||||
vertex_id = "G"
|
||||
|
||||
result = utils.sort_up_to_vertex(graph, vertex_id)
|
||||
|
||||
assert set(result) == {"N", "C", "A", "B", "D", "F", "E", "G"}
|
||||
|
||||
|
||||
def test_sort_up_to_vertex_a(graph):
|
||||
vertex_id = "A"
|
||||
|
||||
result = utils.sort_up_to_vertex(graph, vertex_id)
|
||||
|
||||
assert set(result) == {"A"}
|
||||
|
||||
|
||||
def test_sort_up_to_vertex_invalid_vertex(graph):
|
||||
vertex_id = "7"
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
utils.sort_up_to_vertex(graph, vertex_id)
|
||||
Loading…
Add table
Add a link
Reference in a new issue