refactor(chat_manager.py): move process_graph function outside of ChatManager class
test(websocket.py): add tests for websocket connection, chat history, and sending messages
This commit is contained in:
parent
7d183ff57e
commit
0a630cd70d
2 changed files with 74 additions and 35 deletions
|
|
@ -41,13 +41,13 @@ class ChatManager:
|
|||
async def send_json(self, client_id: str, message: ChatMessage):
|
||||
websocket = self.active_connections[client_id]
|
||||
self.chat_history.add_message(client_id, message)
|
||||
await websocket.send_json(message.dict())
|
||||
await websocket.send_json(json.dumps(message.dict()))
|
||||
|
||||
async def process_message(self, client_id: str, payload: Dict):
|
||||
# Process the graph data and chat message
|
||||
|
||||
chat_message = payload.pop("message", "")
|
||||
chat_message = ChatMessage(sender="user", message=chat_message)
|
||||
chat_message = ChatMessage(sender="you", message=chat_message)
|
||||
self.chat_history.add_message(client_id, chat_message)
|
||||
|
||||
graph_data = payload
|
||||
|
|
@ -57,22 +57,14 @@ class ChatManager:
|
|||
await self.send_json(client_id, start_resp)
|
||||
|
||||
is_first_message = len(self.chat_history.get_history(client_id=client_id)) == 0
|
||||
langchain_object = load_or_build_langchain_object(graph_data, is_first_message)
|
||||
logger.debug("Loaded langchain object")
|
||||
|
||||
if langchain_object is None:
|
||||
# Raise user facing error
|
||||
raise ValueError(
|
||||
"There was an error loading the langchain_object. Please, check all the nodes and try again."
|
||||
)
|
||||
|
||||
# Generate result and thought
|
||||
try:
|
||||
logger.debug("Generating result and thought")
|
||||
result, intermediate_steps = await async_get_result_and_steps(
|
||||
langchain_object, chat_message.message or ""
|
||||
result, intermediate_steps = await process_graph(
|
||||
graph_data=graph_data,
|
||||
is_first_message=is_first_message,
|
||||
chat_message=chat_message,
|
||||
)
|
||||
logger.debug("Generated result and intermediate_steps")
|
||||
except Exception as e:
|
||||
# Log stack trace
|
||||
logger.exception(e)
|
||||
|
|
@ -105,3 +97,29 @@ class ChatManager:
|
|||
print(f"Error: {e}")
|
||||
finally:
|
||||
self.disconnect(client_id)
|
||||
|
||||
|
||||
async def process_graph(
|
||||
graph_data: Dict, is_first_message: bool, chat_message: ChatMessage
|
||||
):
|
||||
langchain_object = load_or_build_langchain_object(graph_data, is_first_message)
|
||||
logger.debug("Loaded langchain object")
|
||||
|
||||
if langchain_object is None:
|
||||
# Raise user facing error
|
||||
raise ValueError(
|
||||
"There was an error loading the langchain_object. Please, check all the nodes and try again."
|
||||
)
|
||||
|
||||
# Generate result and thought
|
||||
try:
|
||||
logger.debug("Generating result and thought")
|
||||
result, intermediate_steps = await async_get_result_and_steps(
|
||||
langchain_object, chat_message.message or ""
|
||||
)
|
||||
logger.debug("Generated result and intermediate_steps")
|
||||
return result, intermediate_steps
|
||||
except Exception as e:
|
||||
# Log stack trace
|
||||
logger.exception(e)
|
||||
raise e
|
||||
|
|
|
|||
|
|
@ -1,30 +1,51 @@
|
|||
import json
|
||||
from unittest.mock import patch
|
||||
from langflow.api.schemas import ChatMessage
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
|
||||
def test_websocket_connection(client):
|
||||
with client.websocket_connect("/ws") as websocket:
|
||||
assert websocket.client == client
|
||||
assert websocket.url.path == "/ws"
|
||||
def test_websocket_connection(client: TestClient):
|
||||
with client.websocket_connect("/ws/test_client") as websocket:
|
||||
assert websocket.scope["client"] == ["testclient", 50000]
|
||||
assert websocket.scope["path"] == "/ws/test_client"
|
||||
|
||||
|
||||
def test_chat_history(client):
|
||||
chat_history = ["Test message 1", "Test message 2"]
|
||||
def test_chat_history(client: TestClient):
|
||||
chat_history = []
|
||||
|
||||
with client.websocket_connect("/ws") as websocket:
|
||||
received_history = websocket.receive_text()
|
||||
received_history = json.loads(received_history)
|
||||
# Mock the process_graph function to return a specific value
|
||||
with patch("langflow.api.chat_manager.process_graph") as mock_process_graph:
|
||||
mock_process_graph.return_value = ("Hello, I'm a mock response!", "")
|
||||
|
||||
assert received_history == chat_history
|
||||
with client.websocket_connect("/ws/test_client") as websocket:
|
||||
# First message should be the history
|
||||
history = websocket.receive_json()
|
||||
assert json.loads(history) == [] # Empty history
|
||||
# Send a message
|
||||
payload = {"message": "Hello"}
|
||||
websocket.send_json(json.dumps(payload))
|
||||
|
||||
# Receive the response from the server
|
||||
response = websocket.receive_json()
|
||||
assert json.loads(response) == {
|
||||
"sender": "bot",
|
||||
"message": None,
|
||||
"intermediate_steps": "",
|
||||
"type": "start",
|
||||
"data": None,
|
||||
"data_type": "",
|
||||
}
|
||||
# Send another message
|
||||
payload = {"message": "How are you?"}
|
||||
websocket.send_json(json.dumps(payload))
|
||||
|
||||
def test_send_message(client, basic_graph):
|
||||
with client.websocket_connect("/ws") as websocket:
|
||||
# Send the JSON payload through the WebSocket connection
|
||||
websocket.send_text(basic_graph)
|
||||
|
||||
# Receive and parse the response from the server
|
||||
response = websocket.receive_text()
|
||||
response = json.loads(response)
|
||||
|
||||
# Test that the response is as expected
|
||||
assert response == "Your response message here"
|
||||
# Receive the response from the server
|
||||
response = websocket.receive_json()
|
||||
assert json.loads(response) == {
|
||||
"sender": "bot",
|
||||
"message": "Hello, I'm a mock response!",
|
||||
"intermediate_steps": "",
|
||||
"type": "end",
|
||||
"data": None,
|
||||
"data_type": "",
|
||||
}
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue