refactor(chat_manager.py): move process_graph function outside of ChatManager class

test(websocket.py): add tests for websocket connection, chat history, and sending messages
This commit is contained in:
Gabriel Almeida 2023-04-19 22:23:31 -03:00
commit 0a630cd70d
2 changed files with 74 additions and 35 deletions

View file

@ -41,13 +41,13 @@ class ChatManager:
async def send_json(self, client_id: str, message: ChatMessage):
websocket = self.active_connections[client_id]
self.chat_history.add_message(client_id, message)
await websocket.send_json(message.dict())
await websocket.send_json(json.dumps(message.dict()))
async def process_message(self, client_id: str, payload: Dict):
# Process the graph data and chat message
chat_message = payload.pop("message", "")
chat_message = ChatMessage(sender="user", message=chat_message)
chat_message = ChatMessage(sender="you", message=chat_message)
self.chat_history.add_message(client_id, chat_message)
graph_data = payload
@ -57,22 +57,14 @@ class ChatManager:
await self.send_json(client_id, start_resp)
is_first_message = len(self.chat_history.get_history(client_id=client_id)) == 0
langchain_object = load_or_build_langchain_object(graph_data, is_first_message)
logger.debug("Loaded langchain object")
if langchain_object is None:
# Raise user facing error
raise ValueError(
"There was an error loading the langchain_object. Please, check all the nodes and try again."
)
# Generate result and thought
try:
logger.debug("Generating result and thought")
result, intermediate_steps = await async_get_result_and_steps(
langchain_object, chat_message.message or ""
result, intermediate_steps = await process_graph(
graph_data=graph_data,
is_first_message=is_first_message,
chat_message=chat_message,
)
logger.debug("Generated result and intermediate_steps")
except Exception as e:
# Log stack trace
logger.exception(e)
@ -105,3 +97,29 @@ class ChatManager:
print(f"Error: {e}")
finally:
self.disconnect(client_id)
async def process_graph(
graph_data: Dict, is_first_message: bool, chat_message: ChatMessage
):
langchain_object = load_or_build_langchain_object(graph_data, is_first_message)
logger.debug("Loaded langchain object")
if langchain_object is None:
# Raise user facing error
raise ValueError(
"There was an error loading the langchain_object. Please, check all the nodes and try again."
)
# Generate result and thought
try:
logger.debug("Generating result and thought")
result, intermediate_steps = await async_get_result_and_steps(
langchain_object, chat_message.message or ""
)
logger.debug("Generated result and intermediate_steps")
return result, intermediate_steps
except Exception as e:
# Log stack trace
logger.exception(e)
raise e

View file

@ -1,30 +1,51 @@
import json
from unittest.mock import patch
from langflow.api.schemas import ChatMessage
from fastapi.testclient import TestClient
def test_websocket_connection(client):
with client.websocket_connect("/ws") as websocket:
assert websocket.client == client
assert websocket.url.path == "/ws"
def test_websocket_connection(client: TestClient):
with client.websocket_connect("/ws/test_client") as websocket:
assert websocket.scope["client"] == ["testclient", 50000]
assert websocket.scope["path"] == "/ws/test_client"
def test_chat_history(client):
chat_history = ["Test message 1", "Test message 2"]
def test_chat_history(client: TestClient):
chat_history = []
with client.websocket_connect("/ws") as websocket:
received_history = websocket.receive_text()
received_history = json.loads(received_history)
# Mock the process_graph function to return a specific value
with patch("langflow.api.chat_manager.process_graph") as mock_process_graph:
mock_process_graph.return_value = ("Hello, I'm a mock response!", "")
assert received_history == chat_history
with client.websocket_connect("/ws/test_client") as websocket:
# First message should be the history
history = websocket.receive_json()
assert json.loads(history) == [] # Empty history
# Send a message
payload = {"message": "Hello"}
websocket.send_json(json.dumps(payload))
# Receive the response from the server
response = websocket.receive_json()
assert json.loads(response) == {
"sender": "bot",
"message": None,
"intermediate_steps": "",
"type": "start",
"data": None,
"data_type": "",
}
# Send another message
payload = {"message": "How are you?"}
websocket.send_json(json.dumps(payload))
def test_send_message(client, basic_graph):
with client.websocket_connect("/ws") as websocket:
# Send the JSON payload through the WebSocket connection
websocket.send_text(basic_graph)
# Receive and parse the response from the server
response = websocket.receive_text()
response = json.loads(response)
# Test that the response is as expected
assert response == "Your response message here"
# Receive the response from the server
response = websocket.receive_json()
assert json.loads(response) == {
"sender": "bot",
"message": "Hello, I'm a mock response!",
"intermediate_steps": "",
"type": "end",
"data": None,
"data_type": "",
}