From 18b35838504a310c2f06a7c96b401eee854646bc Mon Sep 17 00:00:00 2001 From: Gabriel Almeida Date: Wed, 19 Apr 2023 13:13:58 -0300 Subject: [PATCH] test(websocket.py): add tests for websocket connection, chat history and sending message --- src/backend/langflow/api/chat.py | 13 ++++ src/backend/langflow/api/chat_manager.py | 99 ++++++++++++++++++++++++ src/backend/langflow/api/schemas.py | 29 +++++++ src/backend/langflow/cache/base.py | 1 + src/backend/langflow/interface/run.py | 6 +- src/backend/langflow/main.py | 2 + tests/conftest.py | 47 +++++++++++ tests/test_graph.py | 36 +-------- tests/test_websocket.py | 30 +++++++ 9 files changed, 226 insertions(+), 37 deletions(-) create mode 100644 src/backend/langflow/api/chat.py create mode 100644 src/backend/langflow/api/chat_manager.py create mode 100644 src/backend/langflow/api/schemas.py create mode 100644 tests/test_websocket.py diff --git a/src/backend/langflow/api/chat.py b/src/backend/langflow/api/chat.py new file mode 100644 index 000000000..e40ac34ca --- /dev/null +++ b/src/backend/langflow/api/chat.py @@ -0,0 +1,13 @@ +from fastapi import APIRouter, WebSocket +from uuid import uuid4 + +from langflow.api.chat_manager import ChatManager + +router = APIRouter() +chat_manager = ChatManager() + + +@router.websocket("/ws") +async def websocket_endpoint(websocket: WebSocket): + client_id = str(uuid4()) + await chat_manager.handle_websocket(client_id, websocket) diff --git a/src/backend/langflow/api/chat_manager.py b/src/backend/langflow/api/chat_manager.py new file mode 100644 index 000000000..17902247b --- /dev/null +++ b/src/backend/langflow/api/chat_manager.py @@ -0,0 +1,99 @@ +from typing import Dict, List +from collections import defaultdict +from fastapi import WebSocket +import json +from langflow.api.schemas import ChatMessage, ChatResponse + +from langflow.interface.run import ( + get_result_and_steps, + load_or_build_langchain_object, +) +from langflow.utils.logger import logger + + +class ChatHistory: + def __init__(self): + self.history: Dict[str, List[ChatMessage]] = defaultdict(list) + + def add_message(self, client_id: str, message: ChatMessage): + self.history[client_id].append(message) + + def get_history(self, client_id: str) -> List[ChatMessage]: + return self.history[client_id] + + +class ChatManager: + def __init__(self): + self.active_connections: Dict[str, WebSocket] = {} + self.chat_history = ChatHistory() + + async def connect(self, client_id: str, websocket: WebSocket): + await websocket.accept() + self.active_connections[client_id] = websocket + + def disconnect(self, client_id: str): + del self.active_connections[client_id] + + async def send_message(self, client_id: str, message: str): + websocket = self.active_connections[client_id] + await websocket.send_text(message) + + async def send_json(self, client_id: str, message: Dict): + websocket = self.active_connections[client_id] + await websocket.send_json(message) + + async def process_message(self, client_id: str, payload: Dict): + # Process the graph data and chat message + + chat_message = payload.pop("message", "") + chat_message = ChatMessage(sender="user", message=chat_message) + graph_data = payload + start_resp = ChatResponse( + sender="bot", message="", type="start", intermediate_steps="" + ) + await self.send_json(client_id, start_resp.dict()) + + is_first_message = len(graph_data.get("chatHistory", [])) == 0 + langchain_object = load_or_build_langchain_object(graph_data, is_first_message) + logger.debug("Loaded langchain object") + + if langchain_object is None: + # Raise user facing error + raise ValueError( + "There was an error loading the langchain_object. Please, check all the nodes and try again." + ) + + # Generate result and thought + logger.debug("Generating result and thought") + result, intermediate_steps = get_result_and_steps( + langchain_object, chat_message.message + ) + + logger.debug("Generated result and intermediate_steps") + # Save the message to chat history + self.chat_history.add_message(client_id, chat_message) + + # Send a response back to the frontend, if needed + response = ChatResponse( + sender="bot", + message=result or "", + intermediate_steps=intermediate_steps or "", + type="end", + ) + await self.send_json(client_id, response.dict()) + + async def handle_websocket(self, client_id: str, websocket: WebSocket): + await self.connect(client_id, websocket) + try: + chat_history = self.chat_history.get_history(client_id) + await websocket.send_text(json.dumps(chat_history)) + + while True: + json_payload = await websocket.receive_text() + payload = json.loads(json_payload) + await self.process_message(client_id, payload) + except Exception as e: + # Handle any exceptions that might occur + print(f"Error: {e}") + finally: + self.disconnect(client_id) diff --git a/src/backend/langflow/api/schemas.py b/src/backend/langflow/api/schemas.py new file mode 100644 index 000000000..588c35287 --- /dev/null +++ b/src/backend/langflow/api/schemas.py @@ -0,0 +1,29 @@ +from typing import Any +from pydantic import BaseModel, validator + + +class ChatMessage(BaseModel): + """Chat message schema.""" + + sender: str + message: str + + @validator("sender") + def sender_must_be_bot_or_you(cls, v): + if v not in ["bot", "you"]: + raise ValueError("sender must be bot or you") + return v + + +class ChatResponse(ChatMessage): + """Chat response schema.""" + + intermediate_steps: str + type: str + data: Any = None + + @validator("type") + def validate_message_type(cls, v): + if v not in ["start", "stream", "end", "error", "info"]: + raise ValueError("type must be start, stream, end, error or info") + return v diff --git a/src/backend/langflow/cache/base.py b/src/backend/langflow/cache/base.py index d96614218..9dd5c1780 100644 --- a/src/backend/langflow/cache/base.py +++ b/src/backend/langflow/cache/base.py @@ -1,3 +1,4 @@ +import base64 import contextlib import functools import hashlib diff --git a/src/backend/langflow/interface/run.py b/src/backend/langflow/interface/run.py index 89fd7d784..f8920724a 100644 --- a/src/backend/langflow/interface/run.py +++ b/src/backend/langflow/interface/run.py @@ -86,7 +86,7 @@ def process_graph(data_graph: Dict[str, Any]): # Generate result and thought logger.debug("Generating result and thought") - result, thought = get_result_and_thought_using_graph(langchain_object, message) + result, thought = get_result_and_steps(langchain_object, message) logger.debug("Generated result and thought") # Save langchain_object to cache @@ -117,7 +117,7 @@ def process_graph_cached(data_graph: Dict[str, Any]): # Generate result and thought logger.debug("Generating result and thought") - result, thought = get_result_and_thought_using_graph(langchain_object, message) + result, thought = get_result_and_steps(langchain_object, message) logger.debug("Generated result and thought") return {"result": str(result), "thought": thought.strip()} @@ -183,7 +183,7 @@ def fix_memory_inputs(langchain_object): update_memory_keys(langchain_object, possible_new_mem_key) -def get_result_and_thought_using_graph(langchain_object, message: str): +def get_result_and_steps(langchain_object, message: str): """Get result and thought from extracted json""" try: if hasattr(langchain_object, "verbose"): diff --git a/src/backend/langflow/main.py b/src/backend/langflow/main.py index 176e46236..c1f2decd5 100644 --- a/src/backend/langflow/main.py +++ b/src/backend/langflow/main.py @@ -3,6 +3,7 @@ from fastapi.middleware.cors import CORSMiddleware from langflow.api.endpoints import router as endpoints_router from langflow.api.validate import router as validate_router +from langflow.api.chat import router as chat_router def create_app(): @@ -23,6 +24,7 @@ def create_app(): app.include_router(endpoints_router) app.include_router(validate_router) + app.include_router(chat_router) return app diff --git a/tests/conftest.py b/tests/conftest.py index e6eb3562f..15da0d1ef 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,4 +1,8 @@ +import json from pathlib import Path +from typing import AsyncGenerator +from httpx import AsyncClient + import pytest from fastapi.testclient import TestClient @@ -21,6 +25,15 @@ def get_text(): """ +@pytest.fixture() +async def async_client() -> AsyncGenerator: + from langflow.main import create_app + + app = create_app() + async with AsyncClient(app=app, base_url="http://testserver") as client: + yield client + + # Create client fixture for FastAPI @pytest.fixture(scope="module") def client(): @@ -30,3 +43,37 @@ def client(): with TestClient(app) as client: yield client + + +def get_graph(_type="basic"): + """Get a graph from a json file""" + from langflow.graph.graph import Graph + + if _type == "basic": + path = pytest.BASIC_EXAMPLE_PATH + elif _type == "complex": + path = pytest.COMPLEX_EXAMPLE_PATH + elif _type == "openapi": + path = pytest.OPENAPI_EXAMPLE_PATH + + with open(path, "r") as f: + flow_graph = json.load(f) + data_graph = flow_graph["data"] + nodes = data_graph["nodes"] + edges = data_graph["edges"] + return Graph(nodes, edges) + + +@pytest.fixture +def basic_graph(): + return get_graph() + + +@pytest.fixture +def complex_graph(): + return get_graph("complex") + + +@pytest.fixture +def openapi_graph(): + return get_graph("openapi") diff --git a/tests/test_graph.py b/tests/test_graph.py index c007bdd78..76451fe6a 100644 --- a/tests/test_graph.py +++ b/tests/test_graph.py @@ -15,7 +15,7 @@ from langflow.graph.nodes import ( ToolNode, WrapperNode, ) -from langflow.interface.run import get_result_and_thought_using_graph +from langflow.interface.run import get_result_and_steps from langflow.utils.payload import build_json, get_root_node # Test cases for the graph module @@ -24,38 +24,6 @@ from langflow.utils.payload import build_json, get_root_node # BASIC_EXAMPLE_PATH, COMPLEX_EXAMPLE_PATH, OPENAPI_EXAMPLE_PATH -def get_graph(_type="basic"): - """Get a graph from a json file""" - if _type == "basic": - path = pytest.BASIC_EXAMPLE_PATH - elif _type == "complex": - path = pytest.COMPLEX_EXAMPLE_PATH - elif _type == "openapi": - path = pytest.OPENAPI_EXAMPLE_PATH - - with open(path, "r") as f: - flow_graph = json.load(f) - data_graph = flow_graph["data"] - nodes = data_graph["nodes"] - edges = data_graph["edges"] - return Graph(nodes, edges) - - -@pytest.fixture -def basic_graph(): - return get_graph() - - -@pytest.fixture -def complex_graph(): - return get_graph("complex") - - -@pytest.fixture -def openapi_graph(): - return get_graph("openapi") - - def get_node_by_type(graph, node_type: Type[Node]) -> Union[Node, None]: """Get a node by type""" return next((node for node in graph.nodes if isinstance(node, node_type)), None) @@ -441,7 +409,7 @@ def test_get_result_and_thought(basic_graph): # now build again and check if FakeListLLM was used # Get the result and thought - result, thought = get_result_and_thought_using_graph(langchain_object, message) + result, thought = get_result_and_steps(langchain_object, message) # The result should be a str assert isinstance(result, str) # The thought should be a Thought diff --git a/tests/test_websocket.py b/tests/test_websocket.py new file mode 100644 index 000000000..9ce20bc45 --- /dev/null +++ b/tests/test_websocket.py @@ -0,0 +1,30 @@ +import json + + +def test_websocket_connection(client): + with client.websocket_connect("/ws") as websocket: + assert websocket.client == client + assert websocket.url.path == "/ws" + + +def test_chat_history(client): + chat_history = ["Test message 1", "Test message 2"] + + with client.websocket_connect("/ws") as websocket: + received_history = websocket.receive_text() + received_history = json.loads(received_history) + + assert received_history == chat_history + + +def test_send_message(client, basic_graph): + with client.websocket_connect("/ws") as websocket: + # Send the JSON payload through the WebSocket connection + websocket.send_text(basic_graph) + + # Receive and parse the response from the server + response = websocket.receive_text() + response = json.loads(response) + + # Test that the response is as expected + assert response == "Your response message here"