diff --git a/src/backend/langflow/services/chat/manager.py b/src/backend/langflow/services/chat/manager.py index a49f48273..76790cbb4 100644 --- a/src/backend/langflow/services/chat/manager.py +++ b/src/backend/langflow/services/chat/manager.py @@ -2,19 +2,17 @@ from collections import defaultdict from fastapi import WebSocket, status from langflow.api.v1.schemas import ChatMessage, ChatResponse, FileResponse from langflow.services.base import Service -from langflow.services import service_manager -from langflow.services.cache.manager import Subject +from langflow.services.chat.cache import Subject from langflow.services.chat.utils import process_graph from langflow.interface.utils import pil_to_base64 -from langflow.services.schema import ServiceType from langflow.utils.logger import logger - +from .cache import cache_manager import asyncio import json from typing import Any, Dict, List -from langflow.services.cache.flow import InMemoryCache +from langflow.services import service_manager, ServiceType class ChatHistory(Subject): @@ -50,13 +48,13 @@ class ChatManager(Service): def __init__(self): self.active_connections: Dict[str, WebSocket] = {} self.chat_history = ChatHistory() + self.chat_cache = cache_manager + self.chat_cache.attach(self.update) self.cache_manager = service_manager.get(ServiceType.CACHE_MANAGER) - self.cache_manager.attach(self.update) - self.in_memory_cache = InMemoryCache() def on_chat_history_update(self): """Send the last chat message to the client.""" - client_id = self.cache_manager.current_client_id + client_id = self.chat_cache.current_client_id if client_id in self.active_connections: chat_response = self.chat_history.get_history( client_id, filter_messages=False @@ -77,8 +75,8 @@ class ChatManager(Service): asyncio.run_coroutine_threadsafe(coroutine, loop) def update(self): - if self.cache_manager.current_client_id in self.active_connections: - self.last_cached_object_dict = self.cache_manager.get_last() + if self.chat_cache.current_client_id in self.active_connections: + self.last_cached_object_dict = self.chat_cache.get_last() # Add a new ChatResponse with the data chat_response = FileResponse( message=None, @@ -88,7 +86,7 @@ class ChatManager(Service): ) self.chat_history.add_message( - self.cache_manager.current_client_id, chat_response + self.chat_cache.current_client_id, chat_response ) async def connect(self, client_id: str, websocket: WebSocket): @@ -174,9 +172,12 @@ class ChatManager(Service): """ Set the cache for a client. """ + # client_id is the flow id but that already exists in the cache + # so we need to change it to something else - self.in_memory_cache.set(client_id, langchain_object) - return client_id in self.in_memory_cache + client_id = f"{client_id}_chat" if "_chat" not in client_id else client_id + self.cache_manager.set(client_id, langchain_object) + return client_id in self.cache_manager async def handle_websocket(self, client_id: str, websocket: WebSocket): await self.connect(client_id, websocket) @@ -197,8 +198,8 @@ class ChatManager(Service): self.chat_history.history[client_id] = [] continue - with self.cache_manager.set_client_id(client_id): - langchain_object = self.in_memory_cache.get(client_id) + with self.chat_cache.set_client_id(client_id): + langchain_object = self.cache_manager.get(f"{client_id}_chat") await self.process_message(client_id, payload, langchain_object) except Exception as exc: