From 0a630cd70daa593f74ea3dbe700058dc811d906d Mon Sep 17 00:00:00 2001 From: Gabriel Almeida Date: Wed, 19 Apr 2023 22:23:31 -0300 Subject: [PATCH] refactor(chat_manager.py): move process_graph function outside of ChatManager class test(websocket.py): add tests for websocket connection, chat history, and sending messages --- src/backend/langflow/api/chat_manager.py | 46 +++++++++++------ tests/test_websocket.py | 63 ++++++++++++++++-------- 2 files changed, 74 insertions(+), 35 deletions(-) diff --git a/src/backend/langflow/api/chat_manager.py b/src/backend/langflow/api/chat_manager.py index 384de4d12..7ca04abf7 100644 --- a/src/backend/langflow/api/chat_manager.py +++ b/src/backend/langflow/api/chat_manager.py @@ -41,13 +41,13 @@ class ChatManager: async def send_json(self, client_id: str, message: ChatMessage): websocket = self.active_connections[client_id] self.chat_history.add_message(client_id, message) - await websocket.send_json(message.dict()) + await websocket.send_json(json.dumps(message.dict())) async def process_message(self, client_id: str, payload: Dict): # Process the graph data and chat message chat_message = payload.pop("message", "") - chat_message = ChatMessage(sender="user", message=chat_message) + chat_message = ChatMessage(sender="you", message=chat_message) self.chat_history.add_message(client_id, chat_message) graph_data = payload @@ -57,22 +57,14 @@ class ChatManager: await self.send_json(client_id, start_resp) is_first_message = len(self.chat_history.get_history(client_id=client_id)) == 0 - langchain_object = load_or_build_langchain_object(graph_data, is_first_message) - logger.debug("Loaded langchain object") - - if langchain_object is None: - # Raise user facing error - raise ValueError( - "There was an error loading the langchain_object. Please, check all the nodes and try again." - ) - # Generate result and thought try: logger.debug("Generating result and thought") - result, intermediate_steps = await async_get_result_and_steps( - langchain_object, chat_message.message or "" + result, intermediate_steps = await process_graph( + graph_data=graph_data, + is_first_message=is_first_message, + chat_message=chat_message, ) - logger.debug("Generated result and intermediate_steps") except Exception as e: # Log stack trace logger.exception(e) @@ -105,3 +97,29 @@ class ChatManager: print(f"Error: {e}") finally: self.disconnect(client_id) + + +async def process_graph( + graph_data: Dict, is_first_message: bool, chat_message: ChatMessage +): + langchain_object = load_or_build_langchain_object(graph_data, is_first_message) + logger.debug("Loaded langchain object") + + if langchain_object is None: + # Raise user facing error + raise ValueError( + "There was an error loading the langchain_object. Please, check all the nodes and try again." + ) + + # Generate result and thought + try: + logger.debug("Generating result and thought") + result, intermediate_steps = await async_get_result_and_steps( + langchain_object, chat_message.message or "" + ) + logger.debug("Generated result and intermediate_steps") + return result, intermediate_steps + except Exception as e: + # Log stack trace + logger.exception(e) + raise e diff --git a/tests/test_websocket.py b/tests/test_websocket.py index 9ce20bc45..41405867f 100644 --- a/tests/test_websocket.py +++ b/tests/test_websocket.py @@ -1,30 +1,51 @@ import json +from unittest.mock import patch +from langflow.api.schemas import ChatMessage +from fastapi.testclient import TestClient -def test_websocket_connection(client): - with client.websocket_connect("/ws") as websocket: - assert websocket.client == client - assert websocket.url.path == "/ws" +def test_websocket_connection(client: TestClient): + with client.websocket_connect("/ws/test_client") as websocket: + assert websocket.scope["client"] == ["testclient", 50000] + assert websocket.scope["path"] == "/ws/test_client" -def test_chat_history(client): - chat_history = ["Test message 1", "Test message 2"] +def test_chat_history(client: TestClient): + chat_history = [] - with client.websocket_connect("/ws") as websocket: - received_history = websocket.receive_text() - received_history = json.loads(received_history) + # Mock the process_graph function to return a specific value + with patch("langflow.api.chat_manager.process_graph") as mock_process_graph: + mock_process_graph.return_value = ("Hello, I'm a mock response!", "") - assert received_history == chat_history + with client.websocket_connect("/ws/test_client") as websocket: + # First message should be the history + history = websocket.receive_json() + assert json.loads(history) == [] # Empty history + # Send a message + payload = {"message": "Hello"} + websocket.send_json(json.dumps(payload)) + # Receive the response from the server + response = websocket.receive_json() + assert json.loads(response) == { + "sender": "bot", + "message": None, + "intermediate_steps": "", + "type": "start", + "data": None, + "data_type": "", + } + # Send another message + payload = {"message": "How are you?"} + websocket.send_json(json.dumps(payload)) -def test_send_message(client, basic_graph): - with client.websocket_connect("/ws") as websocket: - # Send the JSON payload through the WebSocket connection - websocket.send_text(basic_graph) - - # Receive and parse the response from the server - response = websocket.receive_text() - response = json.loads(response) - - # Test that the response is as expected - assert response == "Your response message here" + # Receive the response from the server + response = websocket.receive_json() + assert json.loads(response) == { + "sender": "bot", + "message": "Hello, I'm a mock response!", + "intermediate_steps": "", + "type": "end", + "data": None, + "data_type": "", + }