From 84a0d3acb39c14f8e0f51a2cd8f2dac11904fc89 Mon Sep 17 00:00:00 2001 From: Gabriel Luiz Freitas Almeida Date: Mon, 4 Sep 2023 06:42:01 -0300 Subject: [PATCH] =?UTF-8?q?=F0=9F=94=A7=20fix(chat.py):=20remove=20unused?= =?UTF-8?q?=20imports=20and=20type=20hints=20to=20improve=20code=20readabi?= =?UTF-8?q?lity=20=E2=9C=A8=20feat(chat.py):=20add=20dependency=20injectio?= =?UTF-8?q?n=20for=20ChatManager=20in=20chat=20and=20init=5Fbuild=20routes?= =?UTF-8?q?=20to=20improve=20modularity=20and=20testability=20=F0=9F=94=A7?= =?UTF-8?q?=20fix(chat.py):=20remove=20duplicate=20instantiation=20of=20Ch?= =?UTF-8?q?atManager=20in=20chat=20and=20init=5Fbuild=20routes=20to=20impr?= =?UTF-8?q?ove=20efficiency=20=F0=9F=94=A7=20fix(chat.py):=20remove=20dupl?= =?UTF-8?q?icate=20instantiation=20of=20ChatManager=20in=20stream=5Fbuild?= =?UTF-8?q?=20route=20to=20improve=20efficiency=20=F0=9F=94=A7=20fix(utils?= =?UTF-8?q?.py):=20add=20missing=20import=20for=20ChatManager=20in=20get?= =?UTF-8?q?=5Fchat=5Fmanager=20function?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/backend/langflow/api/v1/chat.py | 20 ++++++++++---------- src/backend/langflow/services/utils.py | 8 +++++++- 2 files changed, 17 insertions(+), 11 deletions(-) diff --git a/src/backend/langflow/api/v1/chat.py b/src/backend/langflow/api/v1/chat.py index 9d322b03c..2c276c75d 100644 --- a/src/backend/langflow/api/v1/chat.py +++ b/src/backend/langflow/api/v1/chat.py @@ -11,17 +11,14 @@ from fastapi.responses import StreamingResponse from langflow.api.utils import build_input_keys_response from langflow.api.v1.schemas import BuildStatus, BuiltResponse, InitResponse, StreamData -from langflow.services import service_manager, ServiceType from langflow.graph.graph.base import Graph from langflow.services.auth.utils import get_current_active_user, get_current_user -from langflow.services.utils import get_session +from langflow.services.utils import get_chat_manager, get_session from langflow.utils.logger import logger from cachetools import LRUCache from sqlmodel import Session -from typing import TYPE_CHECKING +from langflow.services.chat.manager import ChatManager -if TYPE_CHECKING: - from langflow.services.chat.manager import ChatManager router = APIRouter(tags=["Chat"]) @@ -34,6 +31,7 @@ async def chat( websocket: WebSocket, token: str = Query(...), db: Session = Depends(get_session), + chat_manager: "ChatManager" = Depends(get_chat_manager), ): """Websocket endpoint for chat.""" try: @@ -48,7 +46,6 @@ async def chat( code=status.WS_1008_POLICY_VIOLATION, reason="Unauthorized" ) - chat_manager: "ChatManager" = service_manager.get(ServiceType.CHAT_MANAGER) if client_id in chat_manager.in_memory_cache: await chat_manager.handle_websocket(client_id, websocket) else: @@ -72,7 +69,10 @@ async def chat( @router.post("/build/init/{flow_id}", response_model=InitResponse, status_code=201) async def init_build( - graph_data: dict, flow_id: str, current_user=Depends(get_current_active_user) + graph_data: dict, + flow_id: str, + current_user=Depends(get_current_active_user), + chat_manager: "ChatManager" = Depends(get_chat_manager), ): """Initialize the build by storing graph data and returning a unique session ID.""" @@ -87,7 +87,6 @@ async def init_build( return InitResponse(flowId=flow_id) # Delete from cache if already exists - chat_manager = service_manager.get(ServiceType.CHAT_MANAGER) if flow_id in chat_manager.in_memory_cache: with chat_manager.in_memory_cache._lock: chat_manager.in_memory_cache.delete(flow_id) @@ -123,7 +122,9 @@ async def build_status(flow_id: str): @router.get("/build/stream/{flow_id}", response_class=StreamingResponse) -async def stream_build(flow_id: str): +async def stream_build( + flow_id: str, chat_manager: "ChatManager" = Depends(get_chat_manager) +): """Stream the build process based on stored flow data.""" async def event_stream(flow_id): @@ -202,7 +203,6 @@ async def stream_build(flow_id: str): "handle_keys": [], } yield str(StreamData(event="message", data=input_keys_response)) - chat_manager = service_manager.get(ServiceType.CHAT_MANAGER) chat_manager.set_cache(flow_id, langchain_object) # We need to reset the chat history chat_manager.chat_history.empty_history(flow_id) diff --git a/src/backend/langflow/services/utils.py b/src/backend/langflow/services/utils.py index 6860f8928..708377d14 100644 --- a/src/backend/langflow/services/utils.py +++ b/src/backend/langflow/services/utils.py @@ -5,6 +5,8 @@ from typing import TYPE_CHECKING if TYPE_CHECKING: from langflow.services.database.manager import DatabaseManager from langflow.services.settings.manager import SettingsManager + from langflow.services.chat.manager import ChatManager + from sqlmodel import Session def get_settings_manager() -> "SettingsManager": @@ -15,6 +17,10 @@ def get_db_manager() -> "DatabaseManager": return service_manager.get(ServiceType.DATABASE_MANAGER) -def get_session(): +def get_session() -> "Session": db_manager = service_manager.get(ServiceType.DATABASE_MANAGER) yield from db_manager.get_session() + + +def get_chat_manager() -> "ChatManager": + return service_manager.get(ServiceType.CHAT_MANAGER)