tests: add tests for streaming calls on api/v1/run (#7385)
* test: add concurrent streaming request tests for chat input type Implemented a new test for concurrent streaming requests to the run endpoint with chat input type. Added a helper coroutine to validate the streaming response, ensuring proper event handling and result verification. This enhances the test coverage for the streaming functionality. * refactor: replace session_getter with session_scope in API key CRUD operations Updated the API key CRUD operations to utilize session_scope instead of session_getter for better session management. This change enhances the clarity and robustness of the database interactions. * test: enhance assertions and error handling in streaming tests Refactored assertions in the streaming tests to provide clearer error messages and improve robustness. Added error handling for JSON parsing in the stream response and ensured that all expected fields are validated with informative messages. Updated the test for concurrent streaming requests to use the correct project ID and modified input values for better clarity. * test: refactor get_starter_project fixture for improved session management and data handling Updated the `get_starter_project` fixture to use `session_scope` for better session management. Enhanced the flow data processing by replacing the OpenAI API key and ensuring the `load_from_db` flag is set to false, improving robustness and clarity in test setup. * [autofix.ci] apply automated fixes --------- Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
This commit is contained in:
parent
9660d50867
commit
7a87880e49
3 changed files with 127 additions and 8 deletions
|
|
@ -10,8 +10,7 @@ from sqlmodel.ext.asyncio.session import AsyncSession
|
|||
|
||||
from langflow.services.database.models import User
|
||||
from langflow.services.database.models.api_key import ApiKey, ApiKeyCreate, ApiKeyRead, UnmaskedApiKeyRead
|
||||
from langflow.services.database.utils import session_getter
|
||||
from langflow.services.deps import get_db_service
|
||||
from langflow.services.deps import session_scope
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from sqlmodel.sql.expression import SelectOfScalar
|
||||
|
|
@ -68,7 +67,7 @@ async def check_key(session: AsyncSession, api_key: str) -> User | None:
|
|||
|
||||
async def update_total_uses(api_key_id: UUID):
|
||||
"""Update the total uses and last used at."""
|
||||
async with session_getter(get_db_service()) as session:
|
||||
async with session_scope() as session:
|
||||
new_api_key = await session.get(ApiKey, api_key_id)
|
||||
if new_api_key is None:
|
||||
msg = "API Key not found"
|
||||
|
|
|
|||
|
|
@ -29,7 +29,7 @@ from langflow.services.database.models.transactions.model import TransactionTabl
|
|||
from langflow.services.database.models.user.model import User, UserCreate, UserRead
|
||||
from langflow.services.database.models.vertex_builds.crud import delete_vertex_builds_by_flow_id
|
||||
from langflow.services.database.utils import session_getter
|
||||
from langflow.services.deps import get_db_service
|
||||
from langflow.services.deps import get_db_service, session_scope
|
||||
from loguru import logger
|
||||
from sqlalchemy.ext.asyncio import create_async_engine
|
||||
from sqlalchemy.orm import selectinload
|
||||
|
|
@ -655,13 +655,13 @@ async def get_simple_api_test(client, logged_in_headers, json_simple_api_test):
|
|||
|
||||
|
||||
@pytest.fixture(name="starter_project")
|
||||
async def get_starter_project(active_user):
|
||||
async def get_starter_project(client, active_user): # noqa: ARG001
|
||||
# once the client is created, we can get the starter project
|
||||
async with session_getter(get_db_service()) as session:
|
||||
async with session_scope() as session:
|
||||
stmt = (
|
||||
select(Flow)
|
||||
.where(Flow.folder.has(Folder.name == STARTER_FOLDER_NAME))
|
||||
.where(Flow.name == "Basic Prompting (Hello, World)")
|
||||
.where(Flow.name == "Basic Prompting")
|
||||
)
|
||||
flow = (await session.exec(stmt)).first()
|
||||
if not flow:
|
||||
|
|
@ -669,7 +669,17 @@ async def get_starter_project(active_user):
|
|||
raise ValueError(msg)
|
||||
|
||||
# ensure openai api key is set
|
||||
get_openai_api_key()
|
||||
openai_api_key = get_openai_api_key()
|
||||
data_as_json = json.dumps(flow.data)
|
||||
data_as_json = data_as_json.replace("OPENAI_API_KEY", openai_api_key)
|
||||
# also replace `"load_from_db": true` with `"load_from_db": false`
|
||||
if '"load_from_db": true' in data_as_json:
|
||||
data_as_json = data_as_json.replace('"load_from_db": true', '"load_from_db": false')
|
||||
if '"load_from_db": true' in data_as_json:
|
||||
msg = "load_from_db should be false"
|
||||
raise ValueError(msg)
|
||||
flow.data = json.loads(data_as_json)
|
||||
|
||||
new_flow_create = FlowCreate(
|
||||
name=flow.name,
|
||||
description=flow.description,
|
||||
|
|
|
|||
|
|
@ -1,4 +1,5 @@
|
|||
import asyncio
|
||||
import json
|
||||
from uuid import UUID, uuid4
|
||||
|
||||
import pytest
|
||||
|
|
@ -526,3 +527,112 @@ async def test_starter_projects(client, created_api_key):
|
|||
headers = {"x-api-key": created_api_key.api_key}
|
||||
response = await client.get("api/v1/starter-projects/", headers=headers)
|
||||
assert response.status_code == status.HTTP_200_OK, response.text
|
||||
|
||||
|
||||
async def _run_single_stream_test(client: AsyncClient, flow_id: str, headers: dict, payload: dict):
|
||||
"""Helper coroutine to run and validate a single streaming request."""
|
||||
received_events = [] # Track all event types in sequence
|
||||
got_end_event = False
|
||||
final_result = None
|
||||
|
||||
async with client.stream("POST", f"/api/v1/run/{flow_id}?stream=true", headers=headers, json=payload) as response:
|
||||
assert response.status_code == status.HTTP_200_OK, (
|
||||
f"Request failed with status {response.status_code}: {response.text}"
|
||||
)
|
||||
assert response.headers["content-type"].startswith("text/event-stream"), (
|
||||
f"Expected event stream content type, got: {response.headers['content-type']}"
|
||||
)
|
||||
|
||||
async for line in response.aiter_lines():
|
||||
if not line or line.strip() == "":
|
||||
continue
|
||||
|
||||
try:
|
||||
event_data = json.loads(line)
|
||||
except json.JSONDecodeError:
|
||||
pytest.fail(f"Failed to parse JSON from stream line: {line}")
|
||||
|
||||
assert "event" in event_data, f"Event type missing in response line: {line}"
|
||||
event_type = event_data["event"]
|
||||
received_events.append(event_type)
|
||||
|
||||
if event_type == "add_message":
|
||||
message_data = event_data["data"]
|
||||
assert "sender_name" in message_data, f"Missing 'sender_name' in add_message event: {message_data}"
|
||||
assert "sender" in message_data, f"Missing 'sender' in add_message event: {message_data}"
|
||||
assert "session_id" in message_data, f"Missing 'session_id' in add_message event: {message_data}"
|
||||
assert "text" in message_data, f"Missing 'text' in add_message event: {message_data}"
|
||||
|
||||
elif event_type == "token":
|
||||
token_data = event_data["data"]
|
||||
assert "chunk" in token_data, f"Missing 'chunk' in token event: {token_data}"
|
||||
|
||||
elif event_type == "end":
|
||||
got_end_event = True
|
||||
final_result = event_data["data"].get("result")
|
||||
assert final_result is not None, "End event should contain result data but was None"
|
||||
break # Exit loop after end event
|
||||
|
||||
elif event_type == "error":
|
||||
pytest.fail(f"Received error event in stream: {event_data['data']}")
|
||||
|
||||
# Assert we got the end event
|
||||
assert got_end_event, f"Stream did not receive an end event. Received events: {received_events}"
|
||||
|
||||
# Verify event sequence
|
||||
assert "end" in received_events, f"End event missing from event sequence. Received: {received_events}"
|
||||
assert received_events[-1] == "end", f"Last event should be 'end', but was '{received_events[-1]}'"
|
||||
|
||||
# Verify we got at least one message or token event before end
|
||||
assert len(received_events) > 2, f"Should receive multiple events before the end event. Got: {received_events}"
|
||||
assert any(event == "add_message" for event in received_events), (
|
||||
f"Should receive at least one add_message event. Received events: {received_events}"
|
||||
)
|
||||
assert any(event == "token" for event in received_events), (
|
||||
f"Should receive at least one token event. Received events: {received_events}"
|
||||
)
|
||||
|
||||
# Verify the final result structure in the end event
|
||||
assert final_result is not None, "Final result should not be None"
|
||||
assert "outputs" in final_result, f"Missing 'outputs' in final result: {final_result}"
|
||||
assert "session_id" in final_result, f"Missing 'session_id' in final result: {final_result}"
|
||||
outputs = final_result["outputs"]
|
||||
assert len(outputs) == 1, f"Expected 1 output, got {len(outputs)}: {outputs}"
|
||||
outputs_dict = outputs[0]
|
||||
|
||||
# Verify the debug outputs in final result
|
||||
assert "inputs" in outputs_dict, f"Missing 'inputs' in outputs_dict: {outputs_dict}"
|
||||
assert "outputs" in outputs_dict, f"Missing 'outputs' in outputs_dict: {outputs_dict}"
|
||||
assert outputs_dict["inputs"] == {"input_value": payload["input_value"]}, (
|
||||
f"Input value mismatch. Expected: {{'input_value': {payload['input_value']}}}, Got: {outputs_dict['inputs']}"
|
||||
)
|
||||
assert isinstance(outputs_dict.get("outputs"), list), (
|
||||
f"Expected outputs to be a list, got: {type(outputs_dict.get('outputs'))}"
|
||||
)
|
||||
|
||||
chat_input_outputs = [output for output in outputs_dict.get("outputs") if "ChatInput" in output.get("component_id")]
|
||||
assert len(chat_input_outputs) == 1, (
|
||||
f"Expected 1 ChatInput output, got {len(chat_input_outputs)}: {chat_input_outputs}"
|
||||
)
|
||||
assert all(
|
||||
output.get("results").get("message").get("text") == payload["input_value"] for output in chat_input_outputs
|
||||
), f"Message text mismatch. Expected: {payload['input_value']}, Got: {chat_input_outputs}"
|
||||
|
||||
|
||||
@pytest.mark.api_key_required
|
||||
@pytest.mark.benchmark
|
||||
async def test_concurrent_stream_run_with_input_type_chat(client: AsyncClient, starter_project, created_api_key):
|
||||
"""Test concurrent streaming requests to the run endpoint with chat input type."""
|
||||
headers = {"x-api-key": created_api_key.api_key, "Accept": "text/event-stream", "Content-Type": "application/json"}
|
||||
flow_id = starter_project["id"]
|
||||
payload = {
|
||||
"input_type": "chat",
|
||||
"output_type": "debug",
|
||||
"input_value": "How are you?",
|
||||
}
|
||||
num_concurrent_requests = 5 # Number of concurrent requests to test
|
||||
|
||||
tasks = [_run_single_stream_test(client, flow_id, headers, payload) for _ in range(num_concurrent_requests)]
|
||||
|
||||
# Run all streaming tests concurrently
|
||||
await asyncio.gather(*tasks)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue