fix: Use AsyncSession in build_graph_from_db (#4649)

Use AsyncSession in build_graph_from_db
This commit is contained in:
Christophe Bornet 2024-11-17 12:35:15 +01:00 committed by GitHub
commit da01f5c723
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 9 additions and 9 deletions

View file

@ -160,16 +160,16 @@ async def build_graph_from_data(flow_id: str, payload: dict, **kwargs):
return graph
async def build_graph_from_db_no_cache(flow_id: str, session: Session):
async def build_graph_from_db_no_cache(flow_id: str, session: AsyncSession):
"""Build and cache the graph."""
flow: Flow | None = session.get(Flow, flow_id)
flow: Flow | None = await session.get(Flow, flow_id)
if not flow or not flow.data:
msg = "Invalid flow ID"
raise ValueError(msg)
return await build_graph_from_data(flow_id, flow.data, flow_name=flow.name, user_id=str(flow.user_id))
async def build_graph_from_db(flow_id: str, session: Session, chat_service: ChatService):
async def build_graph_from_db(flow_id: str, session: AsyncSession, chat_service: ChatService):
graph = await build_graph_from_db_no_cache(flow_id, session)
await chat_service.set_cache(flow_id, graph)
return graph

View file

@ -16,8 +16,8 @@ from starlette.responses import ContentStream
from starlette.types import Receive
from langflow.api.utils import (
AsyncDbSession,
CurrentActiveUser,
DbSession,
build_and_cache_graph_from_data,
build_graph_from_data,
build_graph_from_db,
@ -42,7 +42,7 @@ from langflow.graph.utils import log_vertex_build
from langflow.schema.schema import OutputValue
from langflow.services.cache.utils import CacheMiss
from langflow.services.chat.service import ChatService
from langflow.services.deps import get_chat_service, get_session, get_telemetry_service
from langflow.services.deps import get_async_session, get_chat_service, get_telemetry_service
from langflow.services.telemetry.schema import ComponentPayload, PlaygroundPayload
if TYPE_CHECKING:
@ -75,7 +75,7 @@ async def retrieve_vertices_order(
data: Annotated[FlowDataRequest | None, Body(embed=True)] | None = None,
stop_component_id: str | None = None,
start_component_id: str | None = None,
session: DbSession,
session: AsyncDbSession,
) -> VerticesOrderResponse:
"""Retrieve the vertices order for a given flow.
@ -85,7 +85,7 @@ async def retrieve_vertices_order(
data (Optional[FlowDataRequest], optional): The flow data. Defaults to None.
stop_component_id (str, optional): The ID of the stop component. Defaults to None.
start_component_id (str, optional): The ID of the start component. Defaults to None.
session (Session, optional): The session dependency. Defaults to Depends(get_session).
session (AsyncSession, optional): The session dependency.
Returns:
VerticesOrderResponse: The response containing the ordered vertex IDs and the run ID.
@ -151,7 +151,7 @@ async def build_flow(
start_component_id: str | None = None,
log_builds: bool | None = True,
current_user: CurrentActiveUser,
session: DbSession,
session: AsyncDbSession,
):
chat_service = get_chat_service()
telemetry_service = get_telemetry_service()
@ -501,7 +501,7 @@ async def build_vertex(
# If there's no cache
logger.warning(f"No cache found for {flow_id_str}. Building graph starting at {vertex_id}")
graph: Graph = await build_graph_from_db(
flow_id=flow_id_str, session=next(get_session()), chat_service=chat_service
flow_id=flow_id_str, session=await anext(get_async_session()), chat_service=chat_service
)
else:
graph = cache.get("result")