🐛 fix(manager.py): fix import statements and variable names for better organization and readability

 feat(manager.py): add support for chat cache and attach it to the chat manager for better chat history management
🔧 chore(manager.py): change client_id format in set_cache method to avoid conflicts with existing cache keys
This commit is contained in:
Gabriel Luiz Freitas Almeida 2023-08-13 23:41:08 -03:00
commit 561e048278

View file

@ -2,19 +2,17 @@ from collections import defaultdict
from fastapi import WebSocket, status
from langflow.api.v1.schemas import ChatMessage, ChatResponse, FileResponse
from langflow.services.base import Service
from langflow.services import service_manager
from langflow.services.cache.manager import Subject
from langflow.services.chat.cache import Subject
from langflow.services.chat.utils import process_graph
from langflow.interface.utils import pil_to_base64
from langflow.services.schema import ServiceType
from langflow.utils.logger import logger
from .cache import cache_manager
import asyncio
import json
from typing import Any, Dict, List
from langflow.services.cache.flow import InMemoryCache
from langflow.services import service_manager, ServiceType
class ChatHistory(Subject):
@ -50,13 +48,13 @@ class ChatManager(Service):
def __init__(self):
self.active_connections: Dict[str, WebSocket] = {}
self.chat_history = ChatHistory()
self.chat_cache = cache_manager
self.chat_cache.attach(self.update)
self.cache_manager = service_manager.get(ServiceType.CACHE_MANAGER)
self.cache_manager.attach(self.update)
self.in_memory_cache = InMemoryCache()
def on_chat_history_update(self):
"""Send the last chat message to the client."""
client_id = self.cache_manager.current_client_id
client_id = self.chat_cache.current_client_id
if client_id in self.active_connections:
chat_response = self.chat_history.get_history(
client_id, filter_messages=False
@ -77,8 +75,8 @@ class ChatManager(Service):
asyncio.run_coroutine_threadsafe(coroutine, loop)
def update(self):
if self.cache_manager.current_client_id in self.active_connections:
self.last_cached_object_dict = self.cache_manager.get_last()
if self.chat_cache.current_client_id in self.active_connections:
self.last_cached_object_dict = self.chat_cache.get_last()
# Add a new ChatResponse with the data
chat_response = FileResponse(
message=None,
@ -88,7 +86,7 @@ class ChatManager(Service):
)
self.chat_history.add_message(
self.cache_manager.current_client_id, chat_response
self.chat_cache.current_client_id, chat_response
)
async def connect(self, client_id: str, websocket: WebSocket):
@ -174,9 +172,12 @@ class ChatManager(Service):
"""
Set the cache for a client.
"""
# client_id is the flow id but that already exists in the cache
# so we need to change it to something else
self.in_memory_cache.set(client_id, langchain_object)
return client_id in self.in_memory_cache
client_id = f"{client_id}_chat" if "_chat" not in client_id else client_id
self.cache_manager.set(client_id, langchain_object)
return client_id in self.cache_manager
async def handle_websocket(self, client_id: str, websocket: WebSocket):
await self.connect(client_id, websocket)
@ -197,8 +198,8 @@ class ChatManager(Service):
self.chat_history.history[client_id] = []
continue
with self.cache_manager.set_client_id(client_id):
langchain_object = self.in_memory_cache.get(client_id)
with self.chat_cache.set_client_id(client_id):
langchain_object = self.cache_manager.get(f"{client_id}_chat")
await self.process_message(client_id, payload, langchain_object)
except Exception as exc: