Add import statement and update build_vertex function

This commit is contained in:
Gabriel Luiz Freitas Almeida 2024-03-25 12:02:21 -03:00
commit 97cbb20f90
3 changed files with 61 additions and 19 deletions

View file

@ -1,5 +1,6 @@
import time
import uuid
from functools import partial
from typing import TYPE_CHECKING, Annotated, Optional
from fastapi import APIRouter, BackgroundTasks, Body, Depends, HTTPException
@ -141,7 +142,10 @@ async def build_vertex(
graph = cache.get("result")
result_data_response = ResultDataResponse(results={})
duration = ""
vertex = graph.get_vertex(vertex_id)
try:
lock = chat_service._cache_locks[flow_id]
set_cache_coro = partial(chat_service.set_cache, flow_id=flow_id)
(
next_runnable_vertices,
top_level_vertices,
@ -151,7 +155,8 @@ async def build_vertex(
artifacts,
vertex,
) = await graph.build_vertex(
chat_service=chat_service,
lock=lock,
set_cache_coro=set_cache_coro,
vertex_id=vertex_id,
user_id=current_user.id,
inputs=inputs.model_dump() if inputs else {},

View file

@ -1,7 +1,7 @@
import asyncio
from collections import defaultdict, deque
from itertools import chain
from typing import TYPE_CHECKING, Dict, Generator, List, Optional, Type, Union
from typing import TYPE_CHECKING, Coroutine, Dict, Generator, List, Optional, Type, Union
from loguru import logger
@ -387,6 +387,7 @@ class Graph:
self.in_degree_map = self.build_in_degree()
self.parent_child_map = self.build_parent_child_map()
self.run_manager.build_run_map(self)
def reset_inactivated_vertices(self):
"""
@ -648,8 +649,30 @@ class Graph:
raise ValueError(f"Vertex {vertex_id} not found")
async def build_vertex(
self, chat_service, vertex_id: str, inputs: Optional[Dict[str, str]] = None, user_id: Optional[str] = None
self,
lock: asyncio.Lock,
set_cache_coro: Coroutine,
vertex_id: str,
inputs: Optional[Dict[str, str]] = None,
user_id: Optional[str] = None,
):
"""
Builds a vertex in the graph.
Args:
lock (asyncio.Lock): A lock to synchronize access to the graph.
set_cache_coro (Coroutine): A coroutine to set the cache.
vertex_id (str): The ID of the vertex to build.
inputs (Optional[Dict[str, str]]): Optional dictionary of inputs for the vertex. Defaults to None.
user_id (Optional[str]): Optional user ID. Defaults to None.
Returns:
Tuple: A tuple containing the next runnable vertices, top level vertices, result dictionary,
parameters, validity flag, artifacts, and the built vertex.
Raises:
ValueError: If no result is found for the vertex.
"""
vertex = self.get_vertex(vertex_id)
try:
if not vertex.frozen or not vertex._built:
@ -664,15 +687,30 @@ class Graph:
else:
raise ValueError(f"No result found for vertex {vertex_id}")
next_runnable_vertices = await self.run_manager.get_next_runnable_vertices(
self, vertex, vertex_id, chat_service, self.flow_id
next_runnable_vertices, top_level_vertices = await self.get_next_and_top_level_vertices(
lock, set_cache_coro, vertex
)
top_level_vertices = self.run_manager.get_top_level_vertices(self, next_runnable_vertices)
return next_runnable_vertices, top_level_vertices, result_dict, params, valid, artifacts, vertex
except Exception as exc:
logger.exception(f"Error building vertex: {exc}")
raise exc
async def get_next_and_top_level_vertices(self, lock: asyncio.Lock, set_cache_coro: Coroutine, vertex: Vertex):
"""
Retrieves the next runnable vertices and the top level vertices for a given vertex.
Args:
lock (asyncio.Lock): The lock used to synchronize access to the graph.
set_cache_coro (Coroutine): The coroutine used to set the cache for the graph.
vertex (Vertex): The vertex for which to retrieve the next runnable and top level vertices.
Returns:
Tuple[List[Vertex], List[Vertex]]: A tuple containing the next runnable vertices and the top level vertices.
"""
next_runnable_vertices = await self.run_manager.get_next_runnable_vertices(lock, set_cache_coro, self, vertex)
top_level_vertices = self.run_manager.get_top_level_vertices(self, next_runnable_vertices)
return next_runnable_vertices, top_level_vertices
def get_vertex_edges(
self,
vertex_id: str,
@ -1111,7 +1149,7 @@ class Graph:
# save the only the rest
self.vertices_layers = vertices_layers[1:]
self.vertices_to_run = {vertex_id for vertex_id in chain.from_iterable(vertices_layers)}
self.build_run_map()
self.build_graph_maps()
# Return just the first layer
return first_layer
@ -1167,7 +1205,7 @@ class Graph:
This checks the direct predecessors of each successor to identify any that are
immediately runnable, expanding the search to ensure progress can be made.
"""
self.run_manager.find_runnable_predecessors_for_successors(vertex_id)
return self.run_manager.find_runnable_predecessors_for_successors(vertex_id)
def remove_from_predecessors(self, vertex_id: str):
self.run_manager.remove_from_predecessors(vertex_id)

View file

@ -1,10 +1,10 @@
import asyncio
from collections import defaultdict
from typing import TYPE_CHECKING, List
from typing import TYPE_CHECKING, Coroutine, List
if TYPE_CHECKING:
from langflow.graph.graph.base import Graph
from langflow.graph.vertex.base import Vertex
from langflow.services.chat.service import ChatService
class RunnableVerticesManager:
@ -42,7 +42,7 @@ class RunnableVerticesManager:
for vertex_id, predecessors in graph.predecessor_map.items():
for predecessor in predecessors:
self.run_map[predecessor].append(vertex_id)
self.run_predecessors = {k: set(v) for k, v in self.run_map.items()}
self.run_predecessors = graph.predecessor_map.copy()
def update_vertex_run_state(self, vertex_id: str, is_runnable: bool):
"""Updates the runnable state of a vertex."""
@ -51,13 +51,12 @@ class RunnableVerticesManager:
else:
self.vertices_to_run.discard(vertex_id)
@staticmethod
async def get_next_runnable_vertices(
self,
lock: asyncio.Lock,
set_cache_coro: Coroutine,
graph: "Graph",
vertex: "Vertex",
vertex_id: str,
chat_service: "ChatService",
flow_id: str,
):
"""
Retrieves the next runnable vertices in the graph for a given vertex.
@ -73,19 +72,19 @@ class RunnableVerticesManager:
list: A list of IDs of the next runnable vertices.
"""
async with chat_service._cache_locks[flow_id] as lock:
graph.remove_from_predecessors(vertex_id)
async with lock:
graph.remove_from_predecessors(vertex.id)
direct_successors_ready = [v for v in vertex.successors_ids if graph.is_vertex_runnable(v)]
if not direct_successors_ready:
# No direct successors ready, look for runnable predecessors of successors
next_runnable_vertices = graph.find_runnable_predecessors_for_successors(vertex_id)
next_runnable_vertices = self.find_runnable_predecessors_for_successors(vertex.id)
else:
next_runnable_vertices = direct_successors_ready
for v_id in set(next_runnable_vertices): # Use set to avoid duplicates
graph.vertices_to_run.remove(v_id)
graph.remove_from_predecessors(v_id)
await chat_service.set_cache(flow_id=flow_id, data=graph, lock=lock)
await set_cache_coro(data=graph, lock=lock)
return next_runnable_vertices
@staticmethod