🔧 fix(chat.py): remove unused imports and type hints to improve code readability
✨ feat(chat.py): add dependency injection for ChatManager in chat and init_build routes to improve modularity and testability 🔧 fix(chat.py): remove duplicate instantiation of ChatManager in chat and init_build routes to improve efficiency 🔧 fix(chat.py): remove duplicate instantiation of ChatManager in stream_build route to improve efficiency 🔧 fix(utils.py): add missing import for ChatManager in get_chat_manager function
This commit is contained in:
parent
2bbbf44b39
commit
84a0d3acb3
2 changed files with 17 additions and 11 deletions
|
|
@ -11,17 +11,14 @@ from fastapi.responses import StreamingResponse
|
|||
from langflow.api.utils import build_input_keys_response
|
||||
from langflow.api.v1.schemas import BuildStatus, BuiltResponse, InitResponse, StreamData
|
||||
|
||||
from langflow.services import service_manager, ServiceType
|
||||
from langflow.graph.graph.base import Graph
|
||||
from langflow.services.auth.utils import get_current_active_user, get_current_user
|
||||
from langflow.services.utils import get_session
|
||||
from langflow.services.utils import get_chat_manager, get_session
|
||||
from langflow.utils.logger import logger
|
||||
from cachetools import LRUCache
|
||||
from sqlmodel import Session
|
||||
from typing import TYPE_CHECKING
|
||||
from langflow.services.chat.manager import ChatManager
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from langflow.services.chat.manager import ChatManager
|
||||
|
||||
router = APIRouter(tags=["Chat"])
|
||||
|
||||
|
|
@ -34,6 +31,7 @@ async def chat(
|
|||
websocket: WebSocket,
|
||||
token: str = Query(...),
|
||||
db: Session = Depends(get_session),
|
||||
chat_manager: "ChatManager" = Depends(get_chat_manager),
|
||||
):
|
||||
"""Websocket endpoint for chat."""
|
||||
try:
|
||||
|
|
@ -48,7 +46,6 @@ async def chat(
|
|||
code=status.WS_1008_POLICY_VIOLATION, reason="Unauthorized"
|
||||
)
|
||||
|
||||
chat_manager: "ChatManager" = service_manager.get(ServiceType.CHAT_MANAGER)
|
||||
if client_id in chat_manager.in_memory_cache:
|
||||
await chat_manager.handle_websocket(client_id, websocket)
|
||||
else:
|
||||
|
|
@ -72,7 +69,10 @@ async def chat(
|
|||
|
||||
@router.post("/build/init/{flow_id}", response_model=InitResponse, status_code=201)
|
||||
async def init_build(
|
||||
graph_data: dict, flow_id: str, current_user=Depends(get_current_active_user)
|
||||
graph_data: dict,
|
||||
flow_id: str,
|
||||
current_user=Depends(get_current_active_user),
|
||||
chat_manager: "ChatManager" = Depends(get_chat_manager),
|
||||
):
|
||||
"""Initialize the build by storing graph data and returning a unique session ID."""
|
||||
|
||||
|
|
@ -87,7 +87,6 @@ async def init_build(
|
|||
return InitResponse(flowId=flow_id)
|
||||
|
||||
# Delete from cache if already exists
|
||||
chat_manager = service_manager.get(ServiceType.CHAT_MANAGER)
|
||||
if flow_id in chat_manager.in_memory_cache:
|
||||
with chat_manager.in_memory_cache._lock:
|
||||
chat_manager.in_memory_cache.delete(flow_id)
|
||||
|
|
@ -123,7 +122,9 @@ async def build_status(flow_id: str):
|
|||
|
||||
|
||||
@router.get("/build/stream/{flow_id}", response_class=StreamingResponse)
|
||||
async def stream_build(flow_id: str):
|
||||
async def stream_build(
|
||||
flow_id: str, chat_manager: "ChatManager" = Depends(get_chat_manager)
|
||||
):
|
||||
"""Stream the build process based on stored flow data."""
|
||||
|
||||
async def event_stream(flow_id):
|
||||
|
|
@ -202,7 +203,6 @@ async def stream_build(flow_id: str):
|
|||
"handle_keys": [],
|
||||
}
|
||||
yield str(StreamData(event="message", data=input_keys_response))
|
||||
chat_manager = service_manager.get(ServiceType.CHAT_MANAGER)
|
||||
chat_manager.set_cache(flow_id, langchain_object)
|
||||
# We need to reset the chat history
|
||||
chat_manager.chat_history.empty_history(flow_id)
|
||||
|
|
|
|||
|
|
@ -5,6 +5,8 @@ from typing import TYPE_CHECKING
|
|||
if TYPE_CHECKING:
|
||||
from langflow.services.database.manager import DatabaseManager
|
||||
from langflow.services.settings.manager import SettingsManager
|
||||
from langflow.services.chat.manager import ChatManager
|
||||
from sqlmodel import Session
|
||||
|
||||
|
||||
def get_settings_manager() -> "SettingsManager":
|
||||
|
|
@ -15,6 +17,10 @@ def get_db_manager() -> "DatabaseManager":
|
|||
return service_manager.get(ServiceType.DATABASE_MANAGER)
|
||||
|
||||
|
||||
def get_session():
|
||||
def get_session() -> "Session":
|
||||
db_manager = service_manager.get(ServiceType.DATABASE_MANAGER)
|
||||
yield from db_manager.get_session()
|
||||
|
||||
|
||||
def get_chat_manager() -> "ChatManager":
|
||||
return service_manager.get(ServiceType.CHAT_MANAGER)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue