fix(api/chat.py): catch and log exceptions in websocket endpoint

fix(api/chat_manager.py): remove async from ChatHistory.add_message and on_chat_history_update
fix(interface/run.py): remove async from async_get_result_and_steps
refactor(utils/util.py): remove unused code and simplify sync_to_async decorator
This commit is contained in:
Gabriel Almeida 2023-04-25 20:26:17 -03:00
commit 57826f1248
4 changed files with 34 additions and 26 deletions

View file

@ -1,6 +1,7 @@
from fastapi import APIRouter, WebSocket
from langflow.api.chat_manager import ChatManager
from langflow.utils.logger import logger
router = APIRouter()
chat_manager = ChatManager()
@ -9,4 +10,9 @@ chat_manager = ChatManager()
@router.websocket("/chat/{client_id}")
async def websocket_endpoint(client_id: str, websocket: WebSocket):
"""Websocket endpoint for chat."""
await chat_manager.handle_websocket(client_id, websocket)
try:
await chat_manager.handle_websocket(client_id, websocket)
except Exception as e:
# Log stack trace
logger.exception(e)
raise e

View file

@ -8,11 +8,12 @@ import json
from langchain.llms import OpenAI, AzureOpenAI
from langchain.chat_models import ChatOpenAI, AzureChatOpenAI
from langflow.api.schemas import ChatMessage, ChatResponse, FileResponse
from langflow.cache.manager import AsyncSubject
from langflow.cache.manager import AsyncSubject, Subject
from langchain.callbacks.base import AsyncCallbackManager
from langflow.api.callback import StreamingLLMCallbackHandler
from langflow.interface.run import (
async_get_result_and_steps,
get_result_and_steps,
load_or_build_langchain_object,
)
from langflow.utils.logger import logger
@ -20,21 +21,23 @@ from langflow.cache import cache_manager
from PIL.Image import Image
class ChatHistory(AsyncSubject):
class ChatHistory(Subject):
def __init__(self):
super().__init__()
self.history: Dict[str, List[ChatMessage]] = defaultdict(list)
async def add_message(self, client_id: str, message: ChatMessage):
def add_message(self, client_id: str, message: ChatMessage):
"""Add a message to the chat history."""
self.history[client_id].append(message)
await self.notify()
self.notify()
def get_history(self, client_id: str) -> List[ChatMessage]:
def get_history(self, client_id: str, filter=True) -> List[ChatMessage]:
"""Get the chat history for a client."""
if history := self.history.get(client_id, []):
return [msg for msg in history if msg.type not in ["start", "stream"]]
if filter:
return [msg for msg in history if msg.type not in ["start", "stream"]]
return history
else:
return []
@ -47,11 +50,11 @@ class ChatManager:
self.cache_manager = cache_manager
self.cache_manager.attach(self.update)
async def on_chat_history_update(self):
def on_chat_history_update(self):
"""Send the last chat message to the client."""
client_id = self.cache_manager.current_client_id
if client_id in self.active_connections:
chat_response = self.chat_history.get_history(client_id)[-1]
chat_response = self.chat_history.get_history(client_id, filter=False)[-1]
if chat_response.is_bot:
# Process FileResponse
if isinstance(chat_response, FileResponse):
@ -61,8 +64,11 @@ class ChatManager:
elif chat_response.data_type == "image":
# Base64 encode the image
chat_response.data = pil_to_base64(chat_response.data)
# get event loop
loop = asyncio.get_event_loop()
await self.send_json(client_id, chat_response)
coroutine = self.send_json(client_id, chat_response)
asyncio.run_coroutine_threadsafe(coroutine, loop)
def update(self):
if self.cache_manager.current_client_id in self.active_connections:
@ -75,10 +81,8 @@ class ChatManager:
data_type=self.last_cached_object_dict["type"],
)
asyncio.create_task(
self.chat_history.add_message(
self.cache_manager.current_client_id, chat_response
)
self.chat_history.add_message(
self.cache_manager.current_client_id, chat_response
)
async def connect(self, client_id: str, websocket: WebSocket):
@ -100,11 +104,11 @@ class ChatManager:
# Process the graph data and chat message
chat_message = payload.pop("message", "")
chat_message = ChatMessage(message=chat_message)
await self.chat_history.add_message(client_id, chat_message)
self.chat_history.add_message(client_id, chat_message)
graph_data = payload
start_resp = ChatResponse(message=None, type="start", intermediate_steps="")
await self.chat_history.add_message(client_id, start_resp)
self.chat_history.add_message(client_id, start_resp)
is_first_message = len(self.chat_history.get_history(client_id=client_id)) == 0
# Generate result and thought
@ -127,7 +131,7 @@ class ChatManager:
intermediate_steps=intermediate_steps or "",
type="end",
)
await self.chat_history.add_message(client_id, response)
self.chat_history.add_message(client_id, response)
async def handle_websocket(self, client_id: str, websocket: WebSocket):
await self.connect(client_id, websocket)
@ -173,7 +177,7 @@ async def process_graph(
# Generate result and thought
try:
logger.debug("Generating result and thought")
result, intermediate_steps = await async_get_result_and_steps(
result, intermediate_steps = get_result_and_steps(
langchain_object, chat_message.message or ""
)
logger.debug("Generated result and intermediate_steps")

View file

@ -240,7 +240,7 @@ def get_result_and_steps(langchain_object, message: str):
return result, thought
async def async_get_result_and_steps(langchain_object, message: str):
def async_get_result_and_steps(langchain_object, message: str):
"""Get result and thought from extracted json"""
try:
if hasattr(langchain_object, "verbose"):
@ -267,10 +267,10 @@ async def async_get_result_and_steps(langchain_object, message: str):
with io.StringIO() as output_buffer, contextlib.redirect_stdout(output_buffer):
try:
if hasattr(langchain_object, "acall"):
output = await langchain_object.acall(chat_input)
else:
output = langchain_object(chat_input)
# if hasattr(langchain_object, "acall"):
# output = await langchain_object.acall(chat_input)
# else:
output = langchain_object(chat_input)
except ValueError as exc:
# make the error message more informative
logger.debug(f"Error: {str(exc)}")

View file

@ -312,8 +312,6 @@ def sync_to_async(func):
@wraps(func)
async def async_wrapper(*args, **kwargs):
loop = asyncio.get_event_loop()
func_call = partial(func, *args, **kwargs)
return await loop.run_in_executor(None, func_call)
return func(*args, **kwargs)
return async_wrapper