feat(chat_manager.py): add support for sending file responses
fix(schemas.py): add validation for file response type and data type test(test_websocket.py): remove data and data_type fields from ChatResponse messages in tests
This commit is contained in:
parent
3da30cc5bf
commit
5169c0bc27
3 changed files with 77 additions and 16 deletions
|
|
@ -1,22 +1,30 @@
|
|||
import asyncio
|
||||
import base64
|
||||
from io import BytesIO
|
||||
from typing import Dict, List
|
||||
from collections import defaultdict
|
||||
from fastapi import WebSocket
|
||||
import json
|
||||
from langflow.api.schemas import ChatMessage, ChatResponse
|
||||
from langflow.api.schemas import ChatMessage, ChatResponse, FileResponse
|
||||
from langflow.cache.manager import AsyncSubject
|
||||
|
||||
from langflow.interface.run import (
|
||||
async_get_result_and_steps,
|
||||
load_or_build_langchain_object,
|
||||
)
|
||||
from langflow.utils.logger import logger
|
||||
from langflow.cache import cache_manager
|
||||
from PIL.Image import Image
|
||||
|
||||
|
||||
class ChatHistory:
|
||||
class ChatHistory(AsyncSubject):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.history: Dict[str, List[ChatMessage]] = defaultdict(list)
|
||||
|
||||
def add_message(self, client_id: str, message: ChatMessage):
|
||||
async def add_message(self, client_id: str, message: ChatMessage):
|
||||
self.history[client_id].append(message)
|
||||
await self.notify()
|
||||
|
||||
def get_history(self, client_id: str) -> List[ChatMessage]:
|
||||
return self.history[client_id]
|
||||
|
|
@ -26,6 +34,44 @@ class ChatManager:
|
|||
def __init__(self):
|
||||
self.active_connections: Dict[str, WebSocket] = {}
|
||||
self.chat_history = ChatHistory()
|
||||
self.chat_history.attach(self.on_chat_history_update)
|
||||
self.cache_manager = cache_manager
|
||||
self.cache_manager.attach(self.update)
|
||||
|
||||
async def on_chat_history_update(self):
|
||||
"""Send the last chat message to the client."""
|
||||
client_id = self.cache_manager.current_client_id
|
||||
if client_id in self.active_connections:
|
||||
chat_response = self.chat_history.get_history(client_id)[-1]
|
||||
if chat_response.sender == "bot":
|
||||
# Process FileResponse
|
||||
if isinstance(chat_response, FileResponse):
|
||||
# If data_type is pandas, convert to csv
|
||||
if chat_response.data_type == "pandas":
|
||||
chat_response.data = chat_response.data.to_csv()
|
||||
elif chat_response.data_type == "image":
|
||||
# Base64 encode the image
|
||||
chat_response.data = pil_to_base64(chat_response.data)
|
||||
|
||||
await self.send_json(client_id, chat_response)
|
||||
|
||||
def update(self):
|
||||
if self.cache_manager.current_client_id in self.active_connections:
|
||||
self.last_cached_object_dict = self.cache_manager.get_last()
|
||||
# Add a new ChatResponse with the data
|
||||
chat_response = FileResponse(
|
||||
sender="bot",
|
||||
message=None,
|
||||
type="file",
|
||||
data=self.last_cached_object_dict["obj"],
|
||||
data_type=self.last_cached_object_dict["type"],
|
||||
)
|
||||
|
||||
asyncio.create_task(
|
||||
self.chat_history.add_message(
|
||||
self.cache_manager.current_client_id, chat_response
|
||||
)
|
||||
)
|
||||
|
||||
async def connect(self, client_id: str, websocket: WebSocket):
|
||||
await websocket.accept()
|
||||
|
|
@ -40,7 +86,6 @@ class ChatManager:
|
|||
|
||||
async def send_json(self, client_id: str, message: ChatMessage):
|
||||
websocket = self.active_connections[client_id]
|
||||
self.chat_history.add_message(client_id, message)
|
||||
await websocket.send_json(json.dumps(message.dict()))
|
||||
|
||||
async def process_message(self, client_id: str, payload: Dict):
|
||||
|
|
@ -48,13 +93,13 @@ class ChatManager:
|
|||
|
||||
chat_message = payload.pop("message", "")
|
||||
chat_message = ChatMessage(sender="you", message=chat_message)
|
||||
self.chat_history.add_message(client_id, chat_message)
|
||||
await self.chat_history.add_message(client_id, chat_message)
|
||||
|
||||
graph_data = payload
|
||||
start_resp = ChatResponse(
|
||||
sender="bot", message=None, type="start", intermediate_steps=""
|
||||
)
|
||||
await self.send_json(client_id, start_resp)
|
||||
await self.chat_history.add_message(client_id, start_resp)
|
||||
|
||||
is_first_message = len(self.chat_history.get_history(client_id=client_id)) == 0
|
||||
# Generate result and thought
|
||||
|
|
@ -80,7 +125,7 @@ class ChatManager:
|
|||
intermediate_steps=intermediate_steps or "",
|
||||
type="end",
|
||||
)
|
||||
await self.send_json(client_id, response)
|
||||
await self.chat_history.add_message(client_id, response)
|
||||
|
||||
async def handle_websocket(self, client_id: str, websocket: WebSocket):
|
||||
await self.connect(client_id, websocket)
|
||||
|
|
@ -91,7 +136,8 @@ class ChatManager:
|
|||
while True:
|
||||
json_payload = await websocket.receive_json()
|
||||
payload = json.loads(json_payload)
|
||||
await self.process_message(client_id, payload)
|
||||
with self.cache_manager.set_client_id(client_id):
|
||||
await self.process_message(client_id, payload)
|
||||
except Exception as e:
|
||||
# Handle any exceptions that might occur
|
||||
print(f"Error: {e}")
|
||||
|
|
@ -123,3 +169,10 @@ async def process_graph(
|
|||
# Log stack trace
|
||||
logger.exception(e)
|
||||
raise e
|
||||
|
||||
|
||||
def pil_to_base64(image: Image) -> str:
|
||||
buffered = BytesIO()
|
||||
image.save(buffered, format="PNG")
|
||||
img_str = base64.b64encode(buffered.getvalue())
|
||||
return img_str.decode("utf-8")
|
||||
|
|
|
|||
|
|
@ -20,11 +20,23 @@ class ChatResponse(ChatMessage):
|
|||
|
||||
intermediate_steps: str
|
||||
type: str
|
||||
data: Any = None
|
||||
data_type: str = ""
|
||||
|
||||
@validator("type")
|
||||
def validate_message_type(cls, v):
|
||||
if v not in ["start", "stream", "end", "error", "info"]:
|
||||
raise ValueError("type must be start, stream, end, error or info")
|
||||
if v not in ["start", "stream", "end", "error", "info", "file"]:
|
||||
raise ValueError("type must be start, stream, end, error, info, or file")
|
||||
return v
|
||||
|
||||
|
||||
class FileResponse(ChatMessage):
|
||||
"""File response schema."""
|
||||
|
||||
data: Any
|
||||
data_type: str
|
||||
type: str = "file"
|
||||
|
||||
@validator("data_type")
|
||||
def validate_data_type(cls, v):
|
||||
if v not in ["image", "csv"]:
|
||||
raise ValueError("data_type must be image or csv")
|
||||
return v
|
||||
|
|
|
|||
|
|
@ -32,8 +32,6 @@ def test_chat_history(client: TestClient):
|
|||
"message": None,
|
||||
"intermediate_steps": "",
|
||||
"type": "start",
|
||||
"data": None,
|
||||
"data_type": "",
|
||||
}
|
||||
# Send another message
|
||||
payload = {"message": "How are you?"}
|
||||
|
|
@ -46,6 +44,4 @@ def test_chat_history(client: TestClient):
|
|||
"message": "Hello, I'm a mock response!",
|
||||
"intermediate_steps": "",
|
||||
"type": "end",
|
||||
"data": None,
|
||||
"data_type": "",
|
||||
}
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue