From 7d183ff57ed1fdc23ee9993a8d563925c6ad52c0 Mon Sep 17 00:00:00 2001 From: Gabriel Almeida Date: Wed, 19 Apr 2023 21:28:05 -0300 Subject: [PATCH] refactor(chat.py, chat_manager.py, schemas.py, run.py): add chat history to ChatManager and ChatMessage schema feat(chat.py, chat_manager.py): add error handling for async_get_result_and_steps feat(chat.py): add client_id to websocket endpoint feat(schemas.py): add data_type field to ChatResponse schema refactor(run.py): memoize build_langchain_object_with_caching function with maxsize of 10 --- src/backend/langflow/api/chat.py | 5 +-- src/backend/langflow/api/chat_manager.py | 44 ++++++++++-------- src/backend/langflow/api/schemas.py | 5 ++- src/backend/langflow/interface/run.py | 57 +++++++++++++++++++++++- 4 files changed, 87 insertions(+), 24 deletions(-) diff --git a/src/backend/langflow/api/chat.py b/src/backend/langflow/api/chat.py index e40ac34ca..11b861c77 100644 --- a/src/backend/langflow/api/chat.py +++ b/src/backend/langflow/api/chat.py @@ -7,7 +7,6 @@ router = APIRouter() chat_manager = ChatManager() -@router.websocket("/ws") -async def websocket_endpoint(websocket: WebSocket): - client_id = str(uuid4()) +@router.websocket("/ws/{client_id}") +async def websocket_endpoint(client_id: str, websocket: WebSocket): await chat_manager.handle_websocket(client_id, websocket) diff --git a/src/backend/langflow/api/chat_manager.py b/src/backend/langflow/api/chat_manager.py index 17902247b..384de4d12 100644 --- a/src/backend/langflow/api/chat_manager.py +++ b/src/backend/langflow/api/chat_manager.py @@ -5,7 +5,7 @@ import json from langflow.api.schemas import ChatMessage, ChatResponse from langflow.interface.run import ( - get_result_and_steps, + async_get_result_and_steps, load_or_build_langchain_object, ) from langflow.utils.logger import logger @@ -38,22 +38,25 @@ class ChatManager: websocket = self.active_connections[client_id] await websocket.send_text(message) - async def send_json(self, client_id: str, message: Dict): + async def send_json(self, client_id: str, message: ChatMessage): websocket = self.active_connections[client_id] - await websocket.send_json(message) + self.chat_history.add_message(client_id, message) + await websocket.send_json(message.dict()) async def process_message(self, client_id: str, payload: Dict): # Process the graph data and chat message chat_message = payload.pop("message", "") chat_message = ChatMessage(sender="user", message=chat_message) + self.chat_history.add_message(client_id, chat_message) + graph_data = payload start_resp = ChatResponse( - sender="bot", message="", type="start", intermediate_steps="" + sender="bot", message=None, type="start", intermediate_steps="" ) - await self.send_json(client_id, start_resp.dict()) + await self.send_json(client_id, start_resp) - is_first_message = len(graph_data.get("chatHistory", [])) == 0 + is_first_message = len(self.chat_history.get_history(client_id=client_id)) == 0 langchain_object = load_or_build_langchain_object(graph_data, is_first_message) logger.debug("Loaded langchain object") @@ -64,15 +67,20 @@ class ChatManager: ) # Generate result and thought - logger.debug("Generating result and thought") - result, intermediate_steps = get_result_and_steps( - langchain_object, chat_message.message - ) - - logger.debug("Generated result and intermediate_steps") - # Save the message to chat history - self.chat_history.add_message(client_id, chat_message) - + try: + logger.debug("Generating result and thought") + result, intermediate_steps = await async_get_result_and_steps( + langchain_object, chat_message.message or "" + ) + logger.debug("Generated result and intermediate_steps") + except Exception as e: + # Log stack trace + logger.exception(e) + error_resp = ChatResponse( + sender="bot", message=str(e), type="error", intermediate_steps="" + ) + await self.send_json(client_id, error_resp) + return # Send a response back to the frontend, if needed response = ChatResponse( sender="bot", @@ -80,16 +88,16 @@ class ChatManager: intermediate_steps=intermediate_steps or "", type="end", ) - await self.send_json(client_id, response.dict()) + await self.send_json(client_id, response) async def handle_websocket(self, client_id: str, websocket: WebSocket): await self.connect(client_id, websocket) try: chat_history = self.chat_history.get_history(client_id) - await websocket.send_text(json.dumps(chat_history)) + await websocket.send_json(json.dumps(chat_history)) while True: - json_payload = await websocket.receive_text() + json_payload = await websocket.receive_json() payload = json.loads(json_payload) await self.process_message(client_id, payload) except Exception as e: diff --git a/src/backend/langflow/api/schemas.py b/src/backend/langflow/api/schemas.py index 588c35287..fd9ef0816 100644 --- a/src/backend/langflow/api/schemas.py +++ b/src/backend/langflow/api/schemas.py @@ -1,4 +1,4 @@ -from typing import Any +from typing import Any, Union from pydantic import BaseModel, validator @@ -6,7 +6,7 @@ class ChatMessage(BaseModel): """Chat message schema.""" sender: str - message: str + message: Union[str, None] = None @validator("sender") def sender_must_be_bot_or_you(cls, v): @@ -21,6 +21,7 @@ class ChatResponse(ChatMessage): intermediate_steps: str type: str data: Any = None + data_type: str = "" @validator("type") def validate_message_type(cls, v): diff --git a/src/backend/langflow/interface/run.py b/src/backend/langflow/interface/run.py index f8920724a..110d3827f 100644 --- a/src/backend/langflow/interface/run.py +++ b/src/backend/langflow/interface/run.py @@ -31,7 +31,7 @@ def load_or_build_langchain_object(data_graph, is_first_message=False): return build_langchain_object_with_caching(data_graph) -@memoize_dict(maxsize=1) +@memoize_dict(maxsize=10) def build_langchain_object_with_caching(data_graph): """ Build langchain object from data_graph. @@ -235,6 +235,61 @@ def get_result_and_steps(langchain_object, message: str): return result, thought +async def async_get_result_and_steps(langchain_object, message: str): + """Get result and thought from extracted json""" + try: + if hasattr(langchain_object, "verbose"): + langchain_object.verbose = True + chat_input = None + memory_key = "" + if hasattr(langchain_object, "memory") and langchain_object.memory is not None: + memory_key = langchain_object.memory.memory_key + + if hasattr(langchain_object, "input_keys"): + for key in langchain_object.input_keys: + if key not in [memory_key, "chat_history"]: + chat_input = {key: message} + else: + chat_input = message # type: ignore + + if hasattr(langchain_object, "return_intermediate_steps"): + # https://github.com/hwchase17/langchain/issues/2068 + # Deactivating until we have a frontend solution + # to display intermediate steps + langchain_object.return_intermediate_steps = False + + fix_memory_inputs(langchain_object) + + with io.StringIO() as output_buffer, contextlib.redirect_stdout(output_buffer): + try: + if hasattr(langchain_object, "acall"): + output = await langchain_object.acall(chat_input) + else: + output = langchain_object(chat_input) + except ValueError as exc: + # make the error message more informative + logger.debug(f"Error: {str(exc)}") + output = langchain_object.run(chat_input) + + intermediate_steps = ( + output.get("intermediate_steps", []) if isinstance(output, dict) else [] + ) + + result = ( + output.get(langchain_object.output_keys[0]) + if isinstance(output, dict) + else output + ) + if intermediate_steps: + thought = format_intermediate_steps(intermediate_steps) + else: + thought = output_buffer.getvalue() + + except Exception as exc: + raise ValueError(f"Error: {str(exc)}") from exc + return result, thought + + def get_result_and_thought(extracted_json: Dict[str, Any], message: str): """Get result and thought from extracted json""" try: