test(websocket.py): add tests for websocket connection, chat history and sending message
This commit is contained in:
parent
e4d0a39b0b
commit
18b3583850
9 changed files with 226 additions and 37 deletions
13
src/backend/langflow/api/chat.py
Normal file
13
src/backend/langflow/api/chat.py
Normal file
|
|
@ -0,0 +1,13 @@
|
|||
from fastapi import APIRouter, WebSocket
|
||||
from uuid import uuid4
|
||||
|
||||
from langflow.api.chat_manager import ChatManager
|
||||
|
||||
router = APIRouter()
|
||||
chat_manager = ChatManager()
|
||||
|
||||
|
||||
@router.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket):
    """Entry point for the chat WebSocket.

    Every incoming socket is assigned a fresh random client id; the
    module-level ChatManager owns the rest of the connection lifecycle.
    """
    new_client_id = str(uuid4())
    await chat_manager.handle_websocket(new_client_id, websocket)
|
||||
99
src/backend/langflow/api/chat_manager.py
Normal file
99
src/backend/langflow/api/chat_manager.py
Normal file
|
|
@ -0,0 +1,99 @@
|
|||
from typing import Dict, List
|
||||
from collections import defaultdict
|
||||
from fastapi import WebSocket
|
||||
import json
|
||||
from langflow.api.schemas import ChatMessage, ChatResponse
|
||||
|
||||
from langflow.interface.run import (
|
||||
get_result_and_steps,
|
||||
load_or_build_langchain_object,
|
||||
)
|
||||
from langflow.utils.logger import logger
|
||||
|
||||
|
||||
class ChatHistory:
    """Per-client, in-memory store of exchanged chat messages."""

    def __init__(self):
        # Maps a client id to the ordered list of its messages; unknown
        # clients transparently start out with an empty list.
        self.history: Dict[str, List[ChatMessage]] = defaultdict(list)

    def add_message(self, client_id: str, message: ChatMessage):
        """Append *message* to the given client's conversation."""
        self.history[client_id].append(message)

    def get_history(self, client_id: str) -> List[ChatMessage]:
        """Return every message recorded so far for *client_id*."""
        return self.history[client_id]
|
||||
|
||||
|
||||
class ChatManager:
    """Tracks active WebSocket connections and routes chat messages.

    Owns one ChatHistory instance so each client's conversation is kept
    (in memory) across messages for the lifetime of the process.
    """

    def __init__(self):
        # client_id -> live WebSocket for that client.
        self.active_connections: Dict[str, WebSocket] = {}
        self.chat_history = ChatHistory()

    async def connect(self, client_id: str, websocket: WebSocket):
        """Accept the handshake and register the connection."""
        await websocket.accept()
        self.active_connections[client_id] = websocket

    def disconnect(self, client_id: str):
        """Forget the client's connection; safe to call if already gone.

        Uses pop() instead of del: connect() may have failed before the
        client was registered, and the finally-block in handle_websocket
        calls this unconditionally.
        """
        self.active_connections.pop(client_id, None)

    async def send_message(self, client_id: str, message: str):
        """Send a plain-text frame to the given client."""
        websocket = self.active_connections[client_id]
        await websocket.send_text(message)

    async def send_json(self, client_id: str, message: Dict):
        """Send a JSON frame to the given client."""
        websocket = self.active_connections[client_id]
        await websocket.send_json(message)

    async def process_message(self, client_id: str, payload: Dict):
        """Run one chat turn: build/load the langchain object and reply.

        ``payload`` carries the graph data plus a ``message`` key holding
        the user's chat text.  A ``start`` frame is sent immediately and
        an ``end`` frame with the result and intermediate steps when done.

        Raises:
            ValueError: if the langchain object could not be loaded/built.
        """
        # Process the graph data and chat message
        chat_message = payload.pop("message", "")
        chat_message = ChatMessage(sender="user", message=chat_message)
        graph_data = payload  # everything except "message" is graph data

        start_resp = ChatResponse(
            sender="bot", message="", type="start", intermediate_steps=""
        )
        await self.send_json(client_id, start_resp.dict())

        # An empty chatHistory marks the first message, which forces a
        # fresh build instead of a cached load.
        is_first_message = len(graph_data.get("chatHistory", [])) == 0
        langchain_object = load_or_build_langchain_object(graph_data, is_first_message)
        logger.debug("Loaded langchain object")

        if langchain_object is None:
            # Raise user facing error
            raise ValueError(
                "There was an error loading the langchain_object. Please, check all the nodes and try again."
            )

        # Generate result and thought
        logger.debug("Generating result and thought")
        result, intermediate_steps = get_result_and_steps(
            langchain_object, chat_message.message
        )

        logger.debug("Generated result and intermediate_steps")
        # Save the message to chat history
        self.chat_history.add_message(client_id, chat_message)

        # Send a response back to the frontend, if needed
        response = ChatResponse(
            sender="bot",
            message=result or "",
            intermediate_steps=intermediate_steps or "",
            type="end",
        )
        await self.send_json(client_id, response.dict())

    async def handle_websocket(self, client_id: str, websocket: WebSocket):
        """Per-connection loop: send stored history, then process messages."""
        await self.connect(client_id, websocket)
        try:
            chat_history = self.chat_history.get_history(client_id)
            # ChatMessage is a pydantic model; json.dumps() cannot
            # serialize it directly, so dump each message to a dict first
            # (the original code raised TypeError on any non-empty history).
            serialized_history = [message.dict() for message in chat_history]
            await websocket.send_text(json.dumps(serialized_history))

            while True:
                json_payload = await websocket.receive_text()
                payload = json.loads(json_payload)
                await self.process_message(client_id, payload)
        except Exception as exc:
            # Log via the app logger (not print) so the error reaches the
            # logs; the broad catch keeps one bad client from killing the
            # server loop.
            logger.error(f"Error in websocket handler: {exc}")
        finally:
            self.disconnect(client_id)
|
||||
29
src/backend/langflow/api/schemas.py
Normal file
29
src/backend/langflow/api/schemas.py
Normal file
|
|
@ -0,0 +1,29 @@
|
|||
from typing import Any
|
||||
from pydantic import BaseModel, validator
|
||||
|
||||
|
||||
class ChatMessage(BaseModel):
    """Chat message schema.

    A single conversation message tagged with who sent it.
    """

    sender: str
    message: str

    @validator("sender")
    def sender_must_be_bot_or_you(cls, v):
        # "user" is accepted as well: ChatManager.process_message builds
        # messages with sender="user", which the original whitelist
        # ("bot"/"you" only) rejected at runtime for every user message.
        if v not in ["bot", "you", "user"]:
            raise ValueError("sender must be bot, you or user")
        return v
|
||||
|
||||
|
||||
class ChatResponse(ChatMessage):
    """Chat response schema.

    Extends ChatMessage with the agent's intermediate steps, a frame
    type, and an optional free-form data payload.
    """

    intermediate_steps: str
    type: str
    data: Any = None

    @validator("type")
    def validate_message_type(cls, v):
        # Only the five known frame types are allowed through.
        allowed = ("start", "stream", "end", "error", "info")
        if v in allowed:
            return v
        raise ValueError("type must be start, stream, end, error or info")
|
||||
1
src/backend/langflow/cache/base.py
vendored
1
src/backend/langflow/cache/base.py
vendored
|
|
@ -1,3 +1,4 @@
|
|||
import base64
|
||||
import contextlib
|
||||
import functools
|
||||
import hashlib
|
||||
|
|
|
|||
|
|
@ -86,7 +86,7 @@ def process_graph(data_graph: Dict[str, Any]):
|
|||
|
||||
# Generate result and thought
|
||||
logger.debug("Generating result and thought")
|
||||
result, thought = get_result_and_thought_using_graph(langchain_object, message)
|
||||
result, thought = get_result_and_steps(langchain_object, message)
|
||||
logger.debug("Generated result and thought")
|
||||
|
||||
# Save langchain_object to cache
|
||||
|
|
@ -117,7 +117,7 @@ def process_graph_cached(data_graph: Dict[str, Any]):
|
|||
|
||||
# Generate result and thought
|
||||
logger.debug("Generating result and thought")
|
||||
result, thought = get_result_and_thought_using_graph(langchain_object, message)
|
||||
result, thought = get_result_and_steps(langchain_object, message)
|
||||
logger.debug("Generated result and thought")
|
||||
return {"result": str(result), "thought": thought.strip()}
|
||||
|
||||
|
|
@ -183,7 +183,7 @@ def fix_memory_inputs(langchain_object):
|
|||
update_memory_keys(langchain_object, possible_new_mem_key)
|
||||
|
||||
|
||||
def get_result_and_thought_using_graph(langchain_object, message: str):
|
||||
def get_result_and_steps(langchain_object, message: str):
|
||||
"""Get result and thought from extracted json"""
|
||||
try:
|
||||
if hasattr(langchain_object, "verbose"):
|
||||
|
|
|
|||
|
|
@ -3,6 +3,7 @@ from fastapi.middleware.cors import CORSMiddleware
|
|||
|
||||
from langflow.api.endpoints import router as endpoints_router
|
||||
from langflow.api.validate import router as validate_router
|
||||
from langflow.api.chat import router as chat_router
|
||||
|
||||
|
||||
def create_app():
|
||||
|
|
@ -23,6 +24,7 @@ def create_app():
|
|||
|
||||
app.include_router(endpoints_router)
|
||||
app.include_router(validate_router)
|
||||
app.include_router(chat_router)
|
||||
return app
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -1,4 +1,8 @@
|
|||
import json
|
||||
from pathlib import Path
|
||||
from typing import AsyncGenerator
|
||||
from httpx import AsyncClient
|
||||
|
||||
|
||||
import pytest
|
||||
from fastapi.testclient import TestClient
|
||||
|
|
@ -21,6 +25,15 @@ def get_text():
|
|||
"""
|
||||
|
||||
|
||||
@pytest.fixture()
async def async_client() -> AsyncGenerator:
    """Yield an httpx AsyncClient wired to a freshly created app."""
    from langflow.main import create_app

    async with AsyncClient(app=create_app(), base_url="http://testserver") as ac:
        yield ac
||||
|
||||
|
||||
# Create client fixture for FastAPI
|
||||
@pytest.fixture(scope="module")
|
||||
def client():
|
||||
|
|
@ -30,3 +43,37 @@ def client():
|
|||
|
||||
with TestClient(app) as client:
|
||||
yield client
|
||||
|
||||
|
||||
def get_graph(_type="basic"):
    """Build a Graph from one of the example flow JSON files.

    Args:
        _type: which example to load — "basic", "complex" or "openapi".

    Raises:
        ValueError: for an unrecognized ``_type``.  (The original code
            left ``path`` unbound and crashed with an UnboundLocalError.)
    """
    from langflow.graph.graph import Graph

    if _type == "basic":
        path = pytest.BASIC_EXAMPLE_PATH
    elif _type == "complex":
        path = pytest.COMPLEX_EXAMPLE_PATH
    elif _type == "openapi":
        path = pytest.OPENAPI_EXAMPLE_PATH
    else:
        raise ValueError(f"Unknown graph type: {_type!r}")

    with open(path, "r") as f:
        flow_graph = json.load(f)
    data_graph = flow_graph["data"]
    nodes = data_graph["nodes"]
    edges = data_graph["edges"]
    return Graph(nodes, edges)
|
||||
|
||||
|
||||
@pytest.fixture
def basic_graph():
    """Graph built from the basic example flow JSON."""
    return get_graph()
|
||||
|
||||
|
||||
@pytest.fixture
def complex_graph():
    """Graph built from the complex example flow JSON."""
    return get_graph("complex")
|
||||
|
||||
|
||||
@pytest.fixture
def openapi_graph():
    """Graph built from the OpenAPI example flow JSON."""
    return get_graph("openapi")
|
||||
|
|
|
|||
|
|
@ -15,7 +15,7 @@ from langflow.graph.nodes import (
|
|||
ToolNode,
|
||||
WrapperNode,
|
||||
)
|
||||
from langflow.interface.run import get_result_and_thought_using_graph
|
||||
from langflow.interface.run import get_result_and_steps
|
||||
from langflow.utils.payload import build_json, get_root_node
|
||||
|
||||
# Test cases for the graph module
|
||||
|
|
@ -24,38 +24,6 @@ from langflow.utils.payload import build_json, get_root_node
|
|||
# BASIC_EXAMPLE_PATH, COMPLEX_EXAMPLE_PATH, OPENAPI_EXAMPLE_PATH
|
||||
|
||||
|
||||
def get_graph(_type="basic"):
|
||||
"""Get a graph from a json file"""
|
||||
if _type == "basic":
|
||||
path = pytest.BASIC_EXAMPLE_PATH
|
||||
elif _type == "complex":
|
||||
path = pytest.COMPLEX_EXAMPLE_PATH
|
||||
elif _type == "openapi":
|
||||
path = pytest.OPENAPI_EXAMPLE_PATH
|
||||
|
||||
with open(path, "r") as f:
|
||||
flow_graph = json.load(f)
|
||||
data_graph = flow_graph["data"]
|
||||
nodes = data_graph["nodes"]
|
||||
edges = data_graph["edges"]
|
||||
return Graph(nodes, edges)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def basic_graph():
|
||||
return get_graph()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def complex_graph():
|
||||
return get_graph("complex")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def openapi_graph():
|
||||
return get_graph("openapi")
|
||||
|
||||
|
||||
def get_node_by_type(graph, node_type: Type[Node]) -> Union[Node, None]:
    """Get a node by type"""
    # First node in graph.nodes that is an instance of node_type, or None.
    return next((node for node in graph.nodes if isinstance(node, node_type)), None)
|
||||
|
|
@ -441,7 +409,7 @@ def test_get_result_and_thought(basic_graph):
|
|||
# now build again and check if FakeListLLM was used
|
||||
|
||||
# Get the result and thought
|
||||
result, thought = get_result_and_thought_using_graph(langchain_object, message)
|
||||
result, thought = get_result_and_steps(langchain_object, message)
|
||||
# The result should be a str
|
||||
assert isinstance(result, str)
|
||||
# The thought should be a Thought
|
||||
|
|
|
|||
30
tests/test_websocket.py
Normal file
30
tests/test_websocket.py
Normal file
|
|
@ -0,0 +1,30 @@
|
|||
import json
|
||||
|
||||
|
||||
def test_websocket_connection(client):
    """Connecting to /ws should be accepted by the server."""
    # NOTE(review): Starlette's WebSocketTestSession exposes no public
    # `client` or `url` attributes, so the original asserts would raise
    # AttributeError.  A successful connect (no WebSocketDisconnect) plus
    # the initial history frame is the real contract.
    with client.websocket_connect("/ws") as websocket:
        # The server immediately sends the client's chat history on connect.
        history = json.loads(websocket.receive_text())
        assert isinstance(history, list)
|
||||
|
||||
|
||||
def test_chat_history(client):
    """A brand-new client starts with an empty chat history."""
    # Each connection gets a fresh uuid client id, so no messages can have
    # been recorded for it yet.  (The original test compared against a
    # local two-message list that was never sent to the server.)
    with client.websocket_connect("/ws") as websocket:
        received_history = websocket.receive_text()
        received_history = json.loads(received_history)

        assert received_history == []
|
||||
|
||||
|
||||
def test_send_message(client, basic_graph):
    """Sending a chat payload yields a 'start' frame from the server."""
    with client.websocket_connect("/ws") as websocket:
        # Drain the initial chat-history frame sent on connect.
        websocket.receive_text()

        # The endpoint expects a JSON *string*; the original test passed
        # the raw Graph fixture object to send_text, which cannot work.
        # NOTE(review): a full round-trip needs real graph node/edge data
        # in the payload — confirm the expected schema with the frontend.
        payload = {"message": "Hello"}
        websocket.send_text(json.dumps(payload))

        # Receive and parse the response from the server
        response = json.loads(websocket.receive_text())

        # The first frame of a processed message is the "start" marker
        # emitted by ChatManager.process_message.
        assert response["type"] == "start"
        assert response["sender"] == "bot"
|
||||
Loading…
Add table
Add a link
Reference in a new issue