Refactor vertex build process and add new response field

This commit is contained in:
Gabriel Luiz Freitas Almeida 2024-03-20 16:12:12 -03:00
commit cb98118ce8
3 changed files with 29 additions and 23 deletions

View file

@ -10,6 +10,8 @@ from langflow.api.utils import (
build_and_cache_graph,
format_elapsed_time,
format_exception_message,
get_next_runnable_vertices,
get_top_level_vertices,
)
from langflow.api.v1.schemas import (
InputValueRequest,
@ -95,7 +97,8 @@ async def build_vertex(
"""Build a vertex instead of the entire graph."""
start_time = time.perf_counter()
next_vertices_ids = []
next_runnable_vertices = []
top_level_vertices = []
try:
start_time = time.perf_counter()
cache = await chat_service.get_cache(flow_id)
@ -121,12 +124,9 @@ async def build_vertex(
artifacts = vertex.artifacts
else:
raise ValueError(f"No result found for vertex {vertex_id}")
async with chat_service._cache_locks[flow_id] as lock:
graph.remove_from_predecessors(vertex_id)
next_vertices_ids = vertex.successors_ids
next_vertices_ids = [v for v in next_vertices_ids if graph.should_run_vertex(v)]
await chat_service.set_cache(flow_id=flow_id, data=graph, lock=lock)
next_runnable_vertices = await get_next_runnable_vertices(graph, vertex, vertex_id, chat_service, flow_id)
top_level_vertices = get_top_level_vertices(graph, next_runnable_vertices)
result_data_response = ResultDataResponse(**result_dict.model_dump())
except Exception as exc:
@ -166,12 +166,13 @@ async def build_vertex(
# to stop the build of the graph at a certain vertex
# if it is in next_vertices_ids, we need to remove other
# vertices from next_vertices_ids
if graph.stop_vertex and graph.stop_vertex in next_vertices_ids:
next_vertices_ids = [graph.stop_vertex]
if graph.stop_vertex and graph.stop_vertex in next_runnable_vertices:
next_runnable_vertices = [graph.stop_vertex]
build_response = VertexBuildResponse(
inactivated_vertices=inactivated_vertices,
next_vertices_ids=next_vertices_ids,
next_vertices_ids=next_runnable_vertices,
top_level_vertices=top_level_vertices,
valid=valid,
params=params,
id=vertex.id,

View file

@ -247,6 +247,7 @@ class VertexBuildResponse(BaseModel):
id: Optional[str] = None
inactivated_vertices: Optional[List[str]] = None
next_vertices_ids: Optional[List[str]] = None
top_level_vertices: Optional[List[str]] = None
valid: bool
params: Optional[Any] = Field(default_factory=dict)
"""JSON string of the params."""

View file

@ -953,22 +953,26 @@ class Graph:
# Return just the first layer
return first_layer
def vertex_has_no_more_predecessors(self, vertex_id: str) -> bool:
"""Returns whether a vertex has no more predecessors."""
return not self.run_predecessors.get(vertex_id)
def is_vertex_runnable(self, vertex_id: str) -> bool:
"""Returns whether a vertex is runnable."""
return vertex_id in self.vertices_to_run and not self.run_predecessors.get(vertex_id)
def should_run_vertex(self, vertex_id: str) -> bool:
"""Returns whether a component should be run."""
# the self.run_map is a map of vertex_id to a list of predecessors
# each time a vertex is run, we remove it from the list of predecessors
# if a vertex has no more predecessors, it should be run
should_run = vertex_id in self.vertices_to_run and self.vertex_has_no_more_predecessors(vertex_id)
def find_runnable_predecessors_for_successors(self, vertex_id: str) -> List[str]:
"""
For each successor of the current vertex, find runnable predecessors if any.
This checks the direct predecessors of each successor to identify any that are
immediately runnable, expanding the search to ensure progress can be made.
"""
runnable_vertices = []
visited = set()
if should_run:
self.vertices_to_run.remove(vertex_id)
# remove the vertex from the run_map
self.remove_from_predecessors(vertex_id)
return should_run
for successor_id in self.run_map.get(vertex_id, []):
for predecessor_id in self.run_predecessors.get(successor_id, []):
if predecessor_id not in visited and self.is_vertex_runnable(predecessor_id):
runnable_vertices.append(predecessor_id)
visited.add(predecessor_id)
return runnable_vertices
def remove_from_predecessors(self, vertex_id: str):
predecessors = self.run_map.get(vertex_id, [])