From 5169c0bc27960b678961fa94fed066a648c8efa4 Mon Sep 17 00:00:00 2001 From: Gabriel Almeida Date: Thu, 20 Apr 2023 11:09:42 -0300 Subject: [PATCH] feat(chat_manager.py): add support for sending file responses fix(schemas.py): add validation for file response type and data type test(test_websocket.py): remove data and data_type fields from ChatResponse messages in tests --- src/backend/langflow/api/chat_manager.py | 69 +++++++++++++++++++++--- src/backend/langflow/api/schemas.py | 20 +++++-- tests/test_websocket.py | 4 -- 3 files changed, 77 insertions(+), 16 deletions(-) diff --git a/src/backend/langflow/api/chat_manager.py b/src/backend/langflow/api/chat_manager.py index 7ca04abf7..8dcaf05ac 100644 --- a/src/backend/langflow/api/chat_manager.py +++ b/src/backend/langflow/api/chat_manager.py @@ -1,22 +1,30 @@ +import asyncio +import base64 +from io import BytesIO from typing import Dict, List from collections import defaultdict from fastapi import WebSocket import json -from langflow.api.schemas import ChatMessage, ChatResponse +from langflow.api.schemas import ChatMessage, ChatResponse, FileResponse +from langflow.cache.manager import AsyncSubject from langflow.interface.run import ( async_get_result_and_steps, load_or_build_langchain_object, ) from langflow.utils.logger import logger +from langflow.cache import cache_manager +from PIL.Image import Image -class ChatHistory: +class ChatHistory(AsyncSubject): def __init__(self): + super().__init__() self.history: Dict[str, List[ChatMessage]] = defaultdict(list) - def add_message(self, client_id: str, message: ChatMessage): + async def add_message(self, client_id: str, message: ChatMessage): self.history[client_id].append(message) + await self.notify() def get_history(self, client_id: str) -> List[ChatMessage]: return self.history[client_id] @@ -26,6 +34,44 @@ class ChatManager: def __init__(self): self.active_connections: Dict[str, WebSocket] = {} self.chat_history = ChatHistory() + self.chat_history.attach(self.on_chat_history_update) + self.cache_manager = cache_manager + self.cache_manager.attach(self.update) + + async def on_chat_history_update(self): + """Send the last chat message to the client.""" + client_id = self.cache_manager.current_client_id + if client_id in self.active_connections: + chat_response = self.chat_history.get_history(client_id)[-1] + if chat_response.sender == "bot": + # Process FileResponse + if isinstance(chat_response, FileResponse): + # If data_type is pandas, convert to csv + if chat_response.data_type == "pandas": + chat_response.data = chat_response.data.to_csv() + elif chat_response.data_type == "image": + # Base64 encode the image + chat_response.data = pil_to_base64(chat_response.data) + + await self.send_json(client_id, chat_response) + + def update(self): + if self.cache_manager.current_client_id in self.active_connections: + self.last_cached_object_dict = self.cache_manager.get_last() + # Add a new ChatResponse with the data + chat_response = FileResponse( + sender="bot", + message=None, + type="file", + data=self.last_cached_object_dict["obj"], + data_type=self.last_cached_object_dict["type"], + ) + + asyncio.create_task( + self.chat_history.add_message( + self.cache_manager.current_client_id, chat_response + ) + ) async def connect(self, client_id: str, websocket: WebSocket): await websocket.accept() @@ -40,7 +86,6 @@ class ChatManager: async def send_json(self, client_id: str, message: ChatMessage): websocket = self.active_connections[client_id] - self.chat_history.add_message(client_id, message) await websocket.send_json(json.dumps(message.dict())) async def process_message(self, client_id: str, payload: Dict): @@ -48,13 +93,13 @@ class ChatManager: chat_message = payload.pop("message", "") chat_message = ChatMessage(sender="you", message=chat_message) - self.chat_history.add_message(client_id, chat_message) + await self.chat_history.add_message(client_id, chat_message) graph_data = payload start_resp = ChatResponse( sender="bot", message=None, type="start", intermediate_steps="" ) - await self.send_json(client_id, start_resp) + await self.chat_history.add_message(client_id, start_resp) is_first_message = len(self.chat_history.get_history(client_id=client_id)) == 0 # Generate result and thought @@ -80,7 +125,7 @@ class ChatManager: intermediate_steps=intermediate_steps or "", type="end", ) - await self.send_json(client_id, response) + await self.chat_history.add_message(client_id, response) async def handle_websocket(self, client_id: str, websocket: WebSocket): await self.connect(client_id, websocket) @@ -91,7 +136,8 @@ class ChatManager: while True: json_payload = await websocket.receive_json() payload = json.loads(json_payload) - await self.process_message(client_id, payload) + with self.cache_manager.set_client_id(client_id): + await self.process_message(client_id, payload) except Exception as e: # Handle any exceptions that might occur print(f"Error: {e}") @@ -123,3 +169,10 @@ async def process_graph( # Log stack trace logger.exception(e) raise e + + +def pil_to_base64(image: Image) -> str: + buffered = BytesIO() + image.save(buffered, format="PNG") + img_str = base64.b64encode(buffered.getvalue()) + return img_str.decode("utf-8") diff --git a/src/backend/langflow/api/schemas.py b/src/backend/langflow/api/schemas.py index fd9ef0816..1aefe5c8e 100644 --- a/src/backend/langflow/api/schemas.py +++ b/src/backend/langflow/api/schemas.py @@ -20,11 +20,23 @@ class ChatResponse(ChatMessage): intermediate_steps: str type: str - data: Any = None - data_type: str = "" @validator("type") def validate_message_type(cls, v): - if v not in ["start", "stream", "end", "error", "info"]: - raise ValueError("type must be start, stream, end, error or info") + if v not in ["start", "stream", "end", "error", "info", "file"]: + raise ValueError("type must be start, stream, end, error, info, or file") + return v + + +class FileResponse(ChatMessage): + """File response schema.""" + + data: Any + data_type: str + type: str = "file" + + @validator("data_type") + def validate_data_type(cls, v): + if v not in ["image", "csv"]: + raise ValueError("data_type must be image or csv") return v diff --git a/tests/test_websocket.py b/tests/test_websocket.py index 41405867f..74e147075 100644 --- a/tests/test_websocket.py +++ b/tests/test_websocket.py @@ -32,8 +32,6 @@ def test_chat_history(client: TestClient): "message": None, "intermediate_steps": "", "type": "start", - "data": None, - "data_type": "", } # Send another message payload = {"message": "How are you?"} @@ -46,6 +44,4 @@ def test_chat_history(client: TestClient): "message": "Hello, I'm a mock response!", "intermediate_steps": "", "type": "end", - "data": None, - "data_type": "", }