From 5cef5b868aff7be4b8a7d1e72683fdf148dd48a6 Mon Sep 17 00:00:00 2001 From: Gabriel Luiz Freitas Almeida Date: Tue, 20 Feb 2024 13:13:01 -0300 Subject: [PATCH] Refactor graph building and caching --- src/backend/langflow/api/v1/chat.py | 86 ++++++++++++++++++----------- 1 file changed, 54 insertions(+), 32 deletions(-) diff --git a/src/backend/langflow/api/v1/chat.py b/src/backend/langflow/api/v1/chat.py index c3d560cc3..6b942e081 100644 --- a/src/backend/langflow/api/v1/chat.py +++ b/src/backend/langflow/api/v1/chat.py @@ -22,7 +22,6 @@ from langflow.api.v1.schemas import ( VerticesOrderResponse, ) from langflow.graph.graph.base import Graph -from langflow.graph.vertex.base import StatelessVertex from langflow.processing.process import process_tweaks_on_graph from langflow.services.auth.utils import ( get_current_active_user, @@ -275,6 +274,25 @@ async def try_running_celery_task(vertex, user_id): return vertex +def build_and_cache_graph( + flow_id: str, + session: Session, + chat_service: "ChatService", + graph: Optional[Graph] = None, +): + """Build and cache the graph.""" + flow: Flow = session.get(Flow, flow_id) + if not flow or not flow.data: + raise ValueError("Invalid flow ID") + other_graph = Graph.from_payload(flow.data) + if graph is None: + graph = other_graph + else: + graph = graph.update(other_graph) + chat_service.set_cache(flow_id, graph) + return graph + + @router.get("/build/{flow_id}/vertices", response_model=VerticesOrderResponse) async def get_vertices( flow_id: str, @@ -288,17 +306,7 @@ async def get_vertices( graph = None if cache := chat_service.get_cache(flow_id): graph: Graph = cache.get("result") - - flow: Flow = session.get(Flow, flow_id) - if not flow or not flow.data: - raise ValueError("Invalid flow ID") - other_graph = Graph.from_payload(flow.data) - if graph is None: - graph = other_graph - else: - graph = graph.update(other_graph) - chat_service.set_cache(flow_id, graph) - + graph = build_and_cache_graph(flow_id, session, chat_service, graph) if component_id: try: vertices = graph.sort_up_to_vertex(component_id) @@ -332,34 +340,48 @@ async def build_vertex( """Build a vertex instead of the entire graph.""" try: cache = chat_service.get_cache(flow_id) - graph = cache.get("result") + if not cache: + # If there's no cache + logger.warning( + f"No cache found for {flow_id}. Building graph starting at {vertex_id}" + ) + graph = build_and_cache_graph( + flow_id=flow_id, session=get_session(), chat_service=chat_service + ) + else: + graph = cache.get("result") result_dict = {} duration = "" start_time = time.perf_counter() if tweaks: graph = process_tweaks_on_graph(graph, tweaks) - if not isinstance(graph, Graph): - raise ValueError("Invalid graph") if not (vertex := graph.get_vertex(vertex_id)): - raise ValueError("Invalid vertex") + raise ValueError(f"Invalid vertex {vertex_id}") try: - if isinstance(vertex, StatelessVertex) or not vertex._built: + if not vertex.pinned or not vertex._built: await vertex.build(user_id=current_user.id) - params = vertex._built_object_repr() - valid = True - result_dict = vertex.get_built_result() - # We need to set the artifacts to pass information - # to the frontend - vertex.set_artifacts() - artifacts = vertex.artifacts - timedelta = time.perf_counter() - start_time - duration = format_elapsed_time(timedelta) - result_dict = ResultDict( - results=result_dict, - artifacts=artifacts, - duration=duration, - timedelta=timedelta, - ) + params = vertex._built_object_repr() + valid = True + result_dict = vertex.get_built_result() + # We need to set the artifacts to pass information + # to the frontend + vertex.set_artifacts() + artifacts = vertex.artifacts + timedelta = time.perf_counter() - start_time + duration = format_elapsed_time(timedelta) + result_dict = ResultDict( + results=result_dict, + artifacts=artifacts, + duration=duration, + timedelta=timedelta, + ) + vertex.set_result(result_dict) + elif vertex.result is not None: + params = vertex._built_object_repr() + valid = True + result_dict = vertex.result + else: + raise ValueError(f"No result found for vertex {vertex_id}") chat_service.set_cache(flow_id, graph) except Exception as exc: params = str(exc)