From d413abbf904e8ad45b7f450f07e2a135784e3c21 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C3=8Dtalo=20Johnny?= Date: Sun, 19 Jan 2025 22:25:54 -0300 Subject: [PATCH] refactor: enhance graph initialization and telemetry handling (#5721) * refactor: asyncdbsession handling in event_generator * refactor: simplify graph building logic in build_flow * refactor: simplify graph initialization and sorting * refactor: adjust graph initialization and telemetry --- src/backend/base/langflow/api/v1/chat.py | 116 ++++++++++------------- 1 file changed, 52 insertions(+), 64 deletions(-) diff --git a/src/backend/base/langflow/api/v1/chat.py b/src/backend/base/langflow/api/v1/chat.py index 1cbecb268..00204fcaf 100644 --- a/src/backend/base/langflow/api/v1/chat.py +++ b/src/backend/base/langflow/api/v1/chat.py @@ -160,32 +160,16 @@ async def build_flow( async def build_graph_and_get_order() -> tuple[list[str], list[str], Graph]: start_time = time.perf_counter() - components_count = None + components_count = 0 + graph = None try: flow_id_str = str(flow_id) # Create a fresh session for database operations async with session_scope() as fresh_session: - if not data: - graph = await build_graph_from_db(flow_id=flow_id, session=fresh_session, chat_service=chat_service) - else: - result = await fresh_session.exec(select(Flow.name).where(Flow.id == flow_id)) - flow_name = result.first() - graph = await build_graph_from_data( - flow_id=flow_id_str, - payload=data.model_dump(), - user_id=str(current_user.id), - flow_name=flow_name, - ) + graph = await create_graph(fresh_session, flow_id_str) graph.validate_stream() - if stop_component_id or start_component_id: - try: - first_layer = graph.sort_vertices(stop_component_id, start_component_id) - except Exception: # noqa: BLE001 - logger.exception("Error sorting vertices") - first_layer = graph.sort_vertices() - else: - first_layer = graph.sort_vertices() + first_layer = sort_vertices(graph) if inputs is not None and hasattr(inputs, "session") and inputs.session is not None: graph.session_id = inputs.session @@ -198,31 +182,53 @@ async def build_flow( # and return the same structure but only with the ids components_count = len(graph.vertices) vertices_to_run = list(graph.vertices_to_run.union(get_top_level_vertices(graph, graph.vertices_to_run))) + await chat_service.set_cache(flow_id_str, graph) - background_tasks.add_task( - telemetry_service.log_package_playground, - PlaygroundPayload( - playground_seconds=int(time.perf_counter() - start_time), - playground_component_count=components_count, - playground_success=True, - ), - ) + await log_telemetry(start_time, components_count, success=True) + except Exception as exc: - background_tasks.add_task( - telemetry_service.log_package_playground, - PlaygroundPayload( - playground_seconds=int(time.perf_counter() - start_time), - playground_component_count=components_count, - playground_success=False, - playground_error_message=str(exc), - ), - ) + await log_telemetry(start_time, components_count, success=False, error_message=str(exc)) + if "stream or streaming set to True" in str(exc): raise HTTPException(status_code=400, detail=str(exc)) from exc logger.exception("Error checking build status") raise HTTPException(status_code=500, detail=str(exc)) from exc return first_layer, vertices_to_run, graph + async def log_telemetry( + start_time: float, components_count: int, *, success: bool, error_message: str | None = None + ): + background_tasks.add_task( + telemetry_service.log_package_playground, + PlaygroundPayload( + playground_seconds=int(time.perf_counter() - start_time), + playground_component_count=components_count, + playground_success=success, + playground_error_message=str(error_message) if error_message else "", + ), + ) + + async def create_graph(fresh_session, flow_id_str: str) -> Graph: + if not data: + return await build_graph_from_db(flow_id=flow_id, session=fresh_session, chat_service=chat_service) + + result = await fresh_session.exec(select(Flow.name).where(Flow.id == flow_id)) + flow_name = result.first() + + return await build_graph_from_data( + flow_id=flow_id_str, + payload=data.model_dump(), + user_id=str(current_user.id), + flow_name=flow_name, + ) + + def sort_vertices(graph: Graph) -> list[str]: + try: + return graph.sort_vertices(stop_component_id, start_component_id) + except Exception: # noqa: BLE001 + logger.exception("Error sorting vertices") + return graph.sort_vertices() + async def _build_vertex(vertex_id: str, graph: Graph, event_manager: EventManager) -> VertexBuildResponse: flow_id_str = str(flow_id) next_runnable_vertices = [] @@ -372,33 +378,15 @@ async def build_flow( return async def event_generator(event_manager: EventManager, client_consumed_queue: asyncio.Queue) -> None: - if not data: - # using another task since the build_graph_and_get_order is now an async function - vertices_task = asyncio.create_task(build_graph_and_get_order()) - try: - await vertices_task - except asyncio.CancelledError: - vertices_task.cancel() - return - except Exception as e: - error_message = ErrorMessage( - flow_id=flow_id, - exception=e, - ) - event_manager.on_error(data=error_message.data) - raise - - ids, vertices_to_run, graph = vertices_task.result() - else: - try: - ids, vertices_to_run, graph = await build_graph_and_get_order() - except Exception as e: - error_message = ErrorMessage( - flow_id=flow_id, - exception=e, - ) - event_manager.on_error(data=error_message.data) - raise + try: + ids, vertices_to_run, graph = await build_graph_and_get_order() + except Exception as e: + error_message = ErrorMessage( + flow_id=flow_id, + exception=e, + ) + event_manager.on_error(data=error_message.data) + raise event_manager.on_vertices_sorted(data={"ids": ids, "to_run": vertices_to_run}) await client_consumed_queue.get()