From da01f5c723fe6466be01f3a34400b5bc06a6f4b7 Mon Sep 17 00:00:00 2001 From: Christophe Bornet Date: Sun, 17 Nov 2024 12:35:15 +0100 Subject: [PATCH] fix: Use AsyncSession in build_graph_from_db (#4649) Use AsyncSession in build_graph_from_db --- src/backend/base/langflow/api/utils.py | 6 +++--- src/backend/base/langflow/api/v1/chat.py | 12 ++++++------ 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/src/backend/base/langflow/api/utils.py b/src/backend/base/langflow/api/utils.py index f7c34661b..0f9f5c639 100644 --- a/src/backend/base/langflow/api/utils.py +++ b/src/backend/base/langflow/api/utils.py @@ -160,16 +160,16 @@ async def build_graph_from_data(flow_id: str, payload: dict, **kwargs): return graph -async def build_graph_from_db_no_cache(flow_id: str, session: Session): +async def build_graph_from_db_no_cache(flow_id: str, session: AsyncSession): """Build and cache the graph.""" - flow: Flow | None = session.get(Flow, flow_id) + flow: Flow | None = await session.get(Flow, flow_id) if not flow or not flow.data: msg = "Invalid flow ID" raise ValueError(msg) return await build_graph_from_data(flow_id, flow.data, flow_name=flow.name, user_id=str(flow.user_id)) -async def build_graph_from_db(flow_id: str, session: Session, chat_service: ChatService): +async def build_graph_from_db(flow_id: str, session: AsyncSession, chat_service: ChatService): graph = await build_graph_from_db_no_cache(flow_id, session) await chat_service.set_cache(flow_id, graph) return graph diff --git a/src/backend/base/langflow/api/v1/chat.py b/src/backend/base/langflow/api/v1/chat.py index ce7dfc579..5568decda 100644 --- a/src/backend/base/langflow/api/v1/chat.py +++ b/src/backend/base/langflow/api/v1/chat.py @@ -16,8 +16,8 @@ from starlette.responses import ContentStream from starlette.types import Receive from langflow.api.utils import ( + AsyncDbSession, CurrentActiveUser, - DbSession, build_and_cache_graph_from_data, build_graph_from_data, build_graph_from_db, @@ -42,7 +42,7 @@ from langflow.graph.utils import log_vertex_build from langflow.schema.schema import OutputValue from langflow.services.cache.utils import CacheMiss from langflow.services.chat.service import ChatService -from langflow.services.deps import get_chat_service, get_session, get_telemetry_service +from langflow.services.deps import get_async_session, get_chat_service, get_telemetry_service from langflow.services.telemetry.schema import ComponentPayload, PlaygroundPayload if TYPE_CHECKING: @@ -75,7 +75,7 @@ async def retrieve_vertices_order( data: Annotated[FlowDataRequest | None, Body(embed=True)] | None = None, stop_component_id: str | None = None, start_component_id: str | None = None, - session: DbSession, + session: AsyncDbSession, ) -> VerticesOrderResponse: """Retrieve the vertices order for a given flow. @@ -85,7 +85,7 @@ async def retrieve_vertices_order( data (Optional[FlowDataRequest], optional): The flow data. Defaults to None. stop_component_id (str, optional): The ID of the stop component. Defaults to None. start_component_id (str, optional): The ID of the start component. Defaults to None. - session (Session, optional): The session dependency. Defaults to Depends(get_session). + session (AsyncSession, optional): The session dependency. Returns: VerticesOrderResponse: The response containing the ordered vertex IDs and the run ID. @@ -151,7 +151,7 @@ async def build_flow( start_component_id: str | None = None, log_builds: bool | None = True, current_user: CurrentActiveUser, - session: DbSession, + session: AsyncDbSession, ): chat_service = get_chat_service() telemetry_service = get_telemetry_service() @@ -501,7 +501,7 @@ async def build_vertex( # If there's no cache logger.warning(f"No cache found for {flow_id_str}. Building graph starting at {vertex_id}") graph: Graph = await build_graph_from_db( - flow_id=flow_id_str, session=next(get_session()), chat_service=chat_service + flow_id=flow_id_str, session=await anext(get_async_session()), chat_service=chat_service ) else: graph = cache.get("result")