diff --git a/src/backend/langflow/api/v1/chat.py b/src/backend/langflow/api/v1/chat.py index 9d322b03c..2c276c75d 100644 --- a/src/backend/langflow/api/v1/chat.py +++ b/src/backend/langflow/api/v1/chat.py @@ -11,17 +11,14 @@ from fastapi.responses import StreamingResponse from langflow.api.utils import build_input_keys_response from langflow.api.v1.schemas import BuildStatus, BuiltResponse, InitResponse, StreamData -from langflow.services import service_manager, ServiceType from langflow.graph.graph.base import Graph from langflow.services.auth.utils import get_current_active_user, get_current_user -from langflow.services.utils import get_session +from langflow.services.utils import get_chat_manager, get_session from langflow.utils.logger import logger from cachetools import LRUCache from sqlmodel import Session -from typing import TYPE_CHECKING +from langflow.services.chat.manager import ChatManager -if TYPE_CHECKING: - from langflow.services.chat.manager import ChatManager router = APIRouter(tags=["Chat"]) @@ -34,6 +31,7 @@ async def chat( websocket: WebSocket, token: str = Query(...), db: Session = Depends(get_session), + chat_manager: "ChatManager" = Depends(get_chat_manager), ): """Websocket endpoint for chat.""" try: @@ -48,7 +46,6 @@ async def chat( code=status.WS_1008_POLICY_VIOLATION, reason="Unauthorized" ) - chat_manager: "ChatManager" = service_manager.get(ServiceType.CHAT_MANAGER) if client_id in chat_manager.in_memory_cache: await chat_manager.handle_websocket(client_id, websocket) else: @@ -72,7 +69,10 @@ async def chat( @router.post("/build/init/{flow_id}", response_model=InitResponse, status_code=201) async def init_build( - graph_data: dict, flow_id: str, current_user=Depends(get_current_active_user) + graph_data: dict, + flow_id: str, + current_user=Depends(get_current_active_user), + chat_manager: "ChatManager" = Depends(get_chat_manager), ): """Initialize the build by storing graph data and returning a unique session ID.""" @@ -87,7 +87,6 @@ async def init_build( return InitResponse(flowId=flow_id) # Delete from cache if already exists - chat_manager = service_manager.get(ServiceType.CHAT_MANAGER) if flow_id in chat_manager.in_memory_cache: with chat_manager.in_memory_cache._lock: chat_manager.in_memory_cache.delete(flow_id) @@ -123,7 +122,9 @@ async def build_status(flow_id: str): @router.get("/build/stream/{flow_id}", response_class=StreamingResponse) -async def stream_build(flow_id: str): +async def stream_build( + flow_id: str, chat_manager: "ChatManager" = Depends(get_chat_manager) +): """Stream the build process based on stored flow data.""" async def event_stream(flow_id): @@ -202,7 +203,6 @@ async def stream_build(flow_id: str): "handle_keys": [], } yield str(StreamData(event="message", data=input_keys_response)) - chat_manager = service_manager.get(ServiceType.CHAT_MANAGER) chat_manager.set_cache(flow_id, langchain_object) # We need to reset the chat history chat_manager.chat_history.empty_history(flow_id) diff --git a/src/backend/langflow/services/utils.py b/src/backend/langflow/services/utils.py index 6860f8928..708377d14 100644 --- a/src/backend/langflow/services/utils.py +++ b/src/backend/langflow/services/utils.py @@ -5,6 +5,8 @@ from typing import TYPE_CHECKING if TYPE_CHECKING: from langflow.services.database.manager import DatabaseManager from langflow.services.settings.manager import SettingsManager + from langflow.services.chat.manager import ChatManager + from sqlmodel import Session def get_settings_manager() -> "SettingsManager": @@ -15,6 +17,10 @@ def get_db_manager() -> "DatabaseManager": return service_manager.get(ServiceType.DATABASE_MANAGER) -def get_session(): +def get_session() -> "Session": db_manager = service_manager.get(ServiceType.DATABASE_MANAGER) yield from db_manager.get_session() + + +def get_chat_manager() -> "ChatManager": + return service_manager.get(ServiceType.CHAT_MANAGER)