feat(chat_manager.py): add support for sending file responses

fix(schemas.py): add validation for file response type and data type
test(test_websocket.py): remove data and data_type fields from ChatResponse messages in tests
This commit is contained in:
Gabriel Almeida 2023-04-20 11:09:42 -03:00
commit 5169c0bc27
3 changed files with 77 additions and 16 deletions

View file

@ -1,22 +1,30 @@
import asyncio
import base64
from io import BytesIO
from typing import Dict, List
from collections import defaultdict
from fastapi import WebSocket
import json
from langflow.api.schemas import ChatMessage, ChatResponse
from langflow.api.schemas import ChatMessage, ChatResponse, FileResponse
from langflow.cache.manager import AsyncSubject
from langflow.interface.run import (
async_get_result_and_steps,
load_or_build_langchain_object,
)
from langflow.utils.logger import logger
from langflow.cache import cache_manager
from PIL.Image import Image
class ChatHistory:
class ChatHistory(AsyncSubject):
def __init__(self):
super().__init__()
self.history: Dict[str, List[ChatMessage]] = defaultdict(list)
def add_message(self, client_id: str, message: ChatMessage):
async def add_message(self, client_id: str, message: ChatMessage):
self.history[client_id].append(message)
await self.notify()
def get_history(self, client_id: str) -> List[ChatMessage]:
return self.history[client_id]
@ -26,6 +34,44 @@ class ChatManager:
def __init__(self):
self.active_connections: Dict[str, WebSocket] = {}
self.chat_history = ChatHistory()
self.chat_history.attach(self.on_chat_history_update)
self.cache_manager = cache_manager
self.cache_manager.attach(self.update)
async def on_chat_history_update(self):
"""Send the last chat message to the client."""
client_id = self.cache_manager.current_client_id
if client_id in self.active_connections:
chat_response = self.chat_history.get_history(client_id)[-1]
if chat_response.sender == "bot":
# Process FileResponse
if isinstance(chat_response, FileResponse):
# If data_type is pandas, convert to csv
if chat_response.data_type == "pandas":
chat_response.data = chat_response.data.to_csv()
elif chat_response.data_type == "image":
# Base64 encode the image
chat_response.data = pil_to_base64(chat_response.data)
await self.send_json(client_id, chat_response)
def update(self):
if self.cache_manager.current_client_id in self.active_connections:
self.last_cached_object_dict = self.cache_manager.get_last()
# Add a new ChatResponse with the data
chat_response = FileResponse(
sender="bot",
message=None,
type="file",
data=self.last_cached_object_dict["obj"],
data_type=self.last_cached_object_dict["type"],
)
asyncio.create_task(
self.chat_history.add_message(
self.cache_manager.current_client_id, chat_response
)
)
async def connect(self, client_id: str, websocket: WebSocket):
await websocket.accept()
@ -40,7 +86,6 @@ class ChatManager:
async def send_json(self, client_id: str, message: ChatMessage):
websocket = self.active_connections[client_id]
self.chat_history.add_message(client_id, message)
await websocket.send_json(json.dumps(message.dict()))
async def process_message(self, client_id: str, payload: Dict):
@ -48,13 +93,13 @@ class ChatManager:
chat_message = payload.pop("message", "")
chat_message = ChatMessage(sender="you", message=chat_message)
self.chat_history.add_message(client_id, chat_message)
await self.chat_history.add_message(client_id, chat_message)
graph_data = payload
start_resp = ChatResponse(
sender="bot", message=None, type="start", intermediate_steps=""
)
await self.send_json(client_id, start_resp)
await self.chat_history.add_message(client_id, start_resp)
is_first_message = len(self.chat_history.get_history(client_id=client_id)) == 0
# Generate result and thought
@ -80,7 +125,7 @@ class ChatManager:
intermediate_steps=intermediate_steps or "",
type="end",
)
await self.send_json(client_id, response)
await self.chat_history.add_message(client_id, response)
async def handle_websocket(self, client_id: str, websocket: WebSocket):
await self.connect(client_id, websocket)
@ -91,7 +136,8 @@ class ChatManager:
while True:
json_payload = await websocket.receive_json()
payload = json.loads(json_payload)
await self.process_message(client_id, payload)
with self.cache_manager.set_client_id(client_id):
await self.process_message(client_id, payload)
except Exception as e:
# Handle any exceptions that might occur
print(f"Error: {e}")
@ -123,3 +169,10 @@ async def process_graph(
# Log stack trace
logger.exception(e)
raise e
def pil_to_base64(image: Image) -> str:
buffered = BytesIO()
image.save(buffered, format="PNG")
img_str = base64.b64encode(buffered.getvalue())
return img_str.decode("utf-8")

View file

@ -20,11 +20,23 @@ class ChatResponse(ChatMessage):
intermediate_steps: str
type: str
data: Any = None
data_type: str = ""
@validator("type")
def validate_message_type(cls, v):
if v not in ["start", "stream", "end", "error", "info"]:
raise ValueError("type must be start, stream, end, error or info")
if v not in ["start", "stream", "end", "error", "info", "file"]:
raise ValueError("type must be start, stream, end, error, info, or file")
return v
class FileResponse(ChatMessage):
"""File response schema."""
data: Any
data_type: str
type: str = "file"
@validator("data_type")
def validate_data_type(cls, v):
if v not in ["image", "csv"]:
raise ValueError("data_type must be image or csv")
return v

View file

@ -32,8 +32,6 @@ def test_chat_history(client: TestClient):
"message": None,
"intermediate_steps": "",
"type": "start",
"data": None,
"data_type": "",
}
# Send another message
payload = {"message": "How are you?"}
@ -46,6 +44,4 @@ def test_chat_history(client: TestClient):
"message": "Hello, I'm a mock response!",
"intermediate_steps": "",
"type": "end",
"data": None,
"data_type": "",
}