From 57826f12482c8cbd980e66469393d63e45584498 Mon Sep 17 00:00:00 2001 From: Gabriel Almeida Date: Tue, 25 Apr 2023 20:26:17 -0300 Subject: [PATCH] fix(api/chat.py): catch and log exceptions in websocket endpoint fix(api/chat_manager.py): remove async from ChatHistory.add_message and on_chat_history_update fix(interface/run.py): remove async from async_get_result_and_steps refactor(utils/util.py): remove unused code and simplify sync_to_async decorator --- src/backend/langflow/api/chat.py | 8 ++++- src/backend/langflow/api/chat_manager.py | 38 +++++++++++++----------- src/backend/langflow/interface/run.py | 10 +++---- src/backend/langflow/utils/util.py | 4 +-- 4 files changed, 34 insertions(+), 26 deletions(-) diff --git a/src/backend/langflow/api/chat.py b/src/backend/langflow/api/chat.py index d5c2dc879..e25d0d2f1 100644 --- a/src/backend/langflow/api/chat.py +++ b/src/backend/langflow/api/chat.py @@ -1,6 +1,7 @@ from fastapi import APIRouter, WebSocket from langflow.api.chat_manager import ChatManager +from langflow.utils.logger import logger router = APIRouter() chat_manager = ChatManager() @@ -9,4 +10,9 @@ chat_manager = ChatManager() @router.websocket("/chat/{client_id}") async def websocket_endpoint(client_id: str, websocket: WebSocket): """Websocket endpoint for chat.""" - await chat_manager.handle_websocket(client_id, websocket) + try: + await chat_manager.handle_websocket(client_id, websocket) + except Exception as e: + # Log stack trace + logger.exception(e) + raise e diff --git a/src/backend/langflow/api/chat_manager.py b/src/backend/langflow/api/chat_manager.py index 8f407a791..5b6f25eff 100644 --- a/src/backend/langflow/api/chat_manager.py +++ b/src/backend/langflow/api/chat_manager.py @@ -8,11 +8,12 @@ import json from langchain.llms import OpenAI, AzureOpenAI from langchain.chat_models import ChatOpenAI, AzureChatOpenAI from langflow.api.schemas import ChatMessage, ChatResponse, FileResponse -from langflow.cache.manager import AsyncSubject +from langflow.cache.manager import AsyncSubject, Subject from langchain.callbacks.base import AsyncCallbackManager from langflow.api.callback import StreamingLLMCallbackHandler from langflow.interface.run import ( async_get_result_and_steps, + get_result_and_steps, load_or_build_langchain_object, ) from langflow.utils.logger import logger @@ -20,21 +21,23 @@ from langflow.cache import cache_manager from PIL.Image import Image -class ChatHistory(AsyncSubject): +class ChatHistory(Subject): def __init__(self): super().__init__() self.history: Dict[str, List[ChatMessage]] = defaultdict(list) - async def add_message(self, client_id: str, message: ChatMessage): + def add_message(self, client_id: str, message: ChatMessage): """Add a message to the chat history.""" self.history[client_id].append(message) - await self.notify() + self.notify() - def get_history(self, client_id: str) -> List[ChatMessage]: + def get_history(self, client_id: str, filter=True) -> List[ChatMessage]: """Get the chat history for a client.""" if history := self.history.get(client_id, []): - return [msg for msg in history if msg.type not in ["start", "stream"]] + if filter: + return [msg for msg in history if msg.type not in ["start", "stream"]] + return history else: return [] @@ -47,11 +50,11 @@ class ChatManager: self.cache_manager = cache_manager self.cache_manager.attach(self.update) - async def on_chat_history_update(self): + def on_chat_history_update(self): """Send the last chat message to the client.""" client_id = self.cache_manager.current_client_id if client_id in self.active_connections: - chat_response = self.chat_history.get_history(client_id)[-1] + chat_response = self.chat_history.get_history(client_id, filter=False)[-1] if chat_response.is_bot: # Process FileResponse if isinstance(chat_response, FileResponse): @@ -61,8 +64,11 @@ class ChatManager: elif chat_response.data_type == "image": # Base64 encode the image chat_response.data = pil_to_base64(chat_response.data) + # get event loop + loop = asyncio.get_event_loop() - await self.send_json(client_id, chat_response) + coroutine = self.send_json(client_id, chat_response) + asyncio.run_coroutine_threadsafe(coroutine, loop) def update(self): if self.cache_manager.current_client_id in self.active_connections: @@ -75,10 +81,8 @@ class ChatManager: data_type=self.last_cached_object_dict["type"], ) - asyncio.create_task( - self.chat_history.add_message( - self.cache_manager.current_client_id, chat_response - ) + self.chat_history.add_message( + self.cache_manager.current_client_id, chat_response ) async def connect(self, client_id: str, websocket: WebSocket): @@ -100,11 +104,11 @@ class ChatManager: # Process the graph data and chat message chat_message = payload.pop("message", "") chat_message = ChatMessage(message=chat_message) - await self.chat_history.add_message(client_id, chat_message) + self.chat_history.add_message(client_id, chat_message) graph_data = payload start_resp = ChatResponse(message=None, type="start", intermediate_steps="") - await self.chat_history.add_message(client_id, start_resp) + self.chat_history.add_message(client_id, start_resp) is_first_message = len(self.chat_history.get_history(client_id=client_id)) == 0 # Generate result and thought @@ -127,7 +131,7 @@ class ChatManager: intermediate_steps=intermediate_steps or "", type="end", ) - await self.chat_history.add_message(client_id, response) + self.chat_history.add_message(client_id, response) async def handle_websocket(self, client_id: str, websocket: WebSocket): await self.connect(client_id, websocket) @@ -173,7 +177,7 @@ async def process_graph( # Generate result and thought try: logger.debug("Generating result and thought") - result, intermediate_steps = await async_get_result_and_steps( + result, intermediate_steps = get_result_and_steps( langchain_object, chat_message.message or "" ) logger.debug("Generated result and intermediate_steps") diff --git a/src/backend/langflow/interface/run.py b/src/backend/langflow/interface/run.py index 5fb4f0045..c823ba531 100644 --- a/src/backend/langflow/interface/run.py +++ b/src/backend/langflow/interface/run.py @@ -240,7 +240,7 @@ def get_result_and_steps(langchain_object, message: str): return result, thought -async def async_get_result_and_steps(langchain_object, message: str): +def async_get_result_and_steps(langchain_object, message: str): """Get result and thought from extracted json""" try: if hasattr(langchain_object, "verbose"): @@ -267,10 +267,10 @@ async def async_get_result_and_steps(langchain_object, message: str): with io.StringIO() as output_buffer, contextlib.redirect_stdout(output_buffer): try: - if hasattr(langchain_object, "acall"): - output = await langchain_object.acall(chat_input) - else: - output = langchain_object(chat_input) + # if hasattr(langchain_object, "acall"): + # output = await langchain_object.acall(chat_input) + # else: + output = langchain_object(chat_input) except ValueError as exc: # make the error message more informative logger.debug(f"Error: {str(exc)}") diff --git a/src/backend/langflow/utils/util.py b/src/backend/langflow/utils/util.py index eddd59ce1..080137c26 100644 --- a/src/backend/langflow/utils/util.py +++ b/src/backend/langflow/utils/util.py @@ -312,8 +312,6 @@ def sync_to_async(func): @wraps(func) async def async_wrapper(*args, **kwargs): - loop = asyncio.get_event_loop() - func_call = partial(func, *args, **kwargs) - return await loop.run_in_executor(None, func_call) + return func(*args, **kwargs) return async_wrapper