diff --git a/src/backend/base/langflow/services/database/models/api_key/crud.py b/src/backend/base/langflow/services/database/models/api_key/crud.py index dfc8fd6e7..22a52a561 100644 --- a/src/backend/base/langflow/services/database/models/api_key/crud.py +++ b/src/backend/base/langflow/services/database/models/api_key/crud.py @@ -10,8 +10,7 @@ from sqlmodel.ext.asyncio.session import AsyncSession from langflow.services.database.models import User from langflow.services.database.models.api_key import ApiKey, ApiKeyCreate, ApiKeyRead, UnmaskedApiKeyRead -from langflow.services.database.utils import session_getter -from langflow.services.deps import get_db_service +from langflow.services.deps import session_scope if TYPE_CHECKING: from sqlmodel.sql.expression import SelectOfScalar @@ -68,7 +67,7 @@ async def check_key(session: AsyncSession, api_key: str) -> User | None: async def update_total_uses(api_key_id: UUID): """Update the total uses and last used at.""" - async with session_getter(get_db_service()) as session: + async with session_scope() as session: new_api_key = await session.get(ApiKey, api_key_id) if new_api_key is None: msg = "API Key not found" diff --git a/src/backend/tests/conftest.py b/src/backend/tests/conftest.py index b7d115c08..ff3a8263e 100644 --- a/src/backend/tests/conftest.py +++ b/src/backend/tests/conftest.py @@ -29,7 +29,7 @@ from langflow.services.database.models.transactions.model import TransactionTabl from langflow.services.database.models.user.model import User, UserCreate, UserRead from langflow.services.database.models.vertex_builds.crud import delete_vertex_builds_by_flow_id from langflow.services.database.utils import session_getter -from langflow.services.deps import get_db_service +from langflow.services.deps import get_db_service, session_scope from loguru import logger from sqlalchemy.ext.asyncio import create_async_engine from sqlalchemy.orm import selectinload @@ -655,13 +655,13 @@ async def get_simple_api_test(client, logged_in_headers, json_simple_api_test): @pytest.fixture(name="starter_project") -async def get_starter_project(active_user): +async def get_starter_project(client, active_user): # noqa: ARG001 # once the client is created, we can get the starter project - async with session_getter(get_db_service()) as session: + async with session_scope() as session: stmt = ( select(Flow) .where(Flow.folder.has(Folder.name == STARTER_FOLDER_NAME)) - .where(Flow.name == "Basic Prompting (Hello, World)") + .where(Flow.name == "Basic Prompting") ) flow = (await session.exec(stmt)).first() if not flow: @@ -669,7 +669,17 @@ async def get_starter_project(active_user): raise ValueError(msg) # ensure openai api key is set - get_openai_api_key() + openai_api_key = get_openai_api_key() + data_as_json = json.dumps(flow.data) + data_as_json = data_as_json.replace("OPENAI_API_KEY", openai_api_key) + # also replace `"load_from_db": true` with `"load_from_db": false` + if '"load_from_db": true' in data_as_json: + data_as_json = data_as_json.replace('"load_from_db": true', '"load_from_db": false') + if '"load_from_db": true' in data_as_json: + msg = "load_from_db should be false" + raise ValueError(msg) + flow.data = json.loads(data_as_json) + new_flow_create = FlowCreate( name=flow.name, description=flow.description, diff --git a/src/backend/tests/unit/test_endpoints.py b/src/backend/tests/unit/test_endpoints.py index 81a4b13c8..f5b646e68 100644 --- a/src/backend/tests/unit/test_endpoints.py +++ b/src/backend/tests/unit/test_endpoints.py @@ -1,4 +1,5 @@ import asyncio +import json from uuid import UUID, uuid4 import pytest @@ -526,3 +527,112 @@ async def test_starter_projects(client, created_api_key): headers = {"x-api-key": created_api_key.api_key} response = await client.get("api/v1/starter-projects/", headers=headers) assert response.status_code == status.HTTP_200_OK, response.text + + +async def _run_single_stream_test(client: AsyncClient, flow_id: str, headers: dict, payload: dict): + """Helper coroutine to run and validate a single streaming request.""" + received_events = [] # Track all event types in sequence + got_end_event = False + final_result = None + + async with client.stream("POST", f"/api/v1/run/{flow_id}?stream=true", headers=headers, json=payload) as response: + assert response.status_code == status.HTTP_200_OK, ( + f"Request failed with status {response.status_code}: {response.text}" + ) + assert response.headers["content-type"].startswith("text/event-stream"), ( + f"Expected event stream content type, got: {response.headers['content-type']}" + ) + + async for line in response.aiter_lines(): + if not line or line.strip() == "": + continue + + try: + event_data = json.loads(line) + except json.JSONDecodeError: + pytest.fail(f"Failed to parse JSON from stream line: {line}") + + assert "event" in event_data, f"Event type missing in response line: {line}" + event_type = event_data["event"] + received_events.append(event_type) + + if event_type == "add_message": + message_data = event_data["data"] + assert "sender_name" in message_data, f"Missing 'sender_name' in add_message event: {message_data}" + assert "sender" in message_data, f"Missing 'sender' in add_message event: {message_data}" + assert "session_id" in message_data, f"Missing 'session_id' in add_message event: {message_data}" + assert "text" in message_data, f"Missing 'text' in add_message event: {message_data}" + + elif event_type == "token": + token_data = event_data["data"] + assert "chunk" in token_data, f"Missing 'chunk' in token event: {token_data}" + + elif event_type == "end": + got_end_event = True + final_result = event_data["data"].get("result") + assert final_result is not None, "End event should contain result data but was None" + break # Exit loop after end event + + elif event_type == "error": + pytest.fail(f"Received error event in stream: {event_data['data']}") + + # Assert we got the end event + assert got_end_event, f"Stream did not receive an end event. Received events: {received_events}" + + # Verify event sequence + assert "end" in received_events, f"End event missing from event sequence. Received: {received_events}" + assert received_events[-1] == "end", f"Last event should be 'end', but was '{received_events[-1]}'" + + # Verify we got at least one message or token event before end + assert len(received_events) > 2, f"Should receive multiple events before the end event. Got: {received_events}" + assert any(event == "add_message" for event in received_events), ( + f"Should receive at least one add_message event. Received events: {received_events}" + ) + assert any(event == "token" for event in received_events), ( + f"Should receive at least one token event. Received events: {received_events}" + ) + + # Verify the final result structure in the end event + assert final_result is not None, "Final result should not be None" + assert "outputs" in final_result, f"Missing 'outputs' in final result: {final_result}" + assert "session_id" in final_result, f"Missing 'session_id' in final result: {final_result}" + outputs = final_result["outputs"] + assert len(outputs) == 1, f"Expected 1 output, got {len(outputs)}: {outputs}" + outputs_dict = outputs[0] + + # Verify the debug outputs in final result + assert "inputs" in outputs_dict, f"Missing 'inputs' in outputs_dict: {outputs_dict}" + assert "outputs" in outputs_dict, f"Missing 'outputs' in outputs_dict: {outputs_dict}" + assert outputs_dict["inputs"] == {"input_value": payload["input_value"]}, ( + f"Input value mismatch. Expected: {{'input_value': {payload['input_value']}}}, Got: {outputs_dict['inputs']}" + ) + assert isinstance(outputs_dict.get("outputs"), list), ( + f"Expected outputs to be a list, got: {type(outputs_dict.get('outputs'))}" + ) + + chat_input_outputs = [output for output in outputs_dict.get("outputs") if "ChatInput" in output.get("component_id")] + assert len(chat_input_outputs) == 1, ( + f"Expected 1 ChatInput output, got {len(chat_input_outputs)}: {chat_input_outputs}" + ) + assert all( + output.get("results").get("message").get("text") == payload["input_value"] for output in chat_input_outputs + ), f"Message text mismatch. Expected: {payload['input_value']}, Got: {chat_input_outputs}" + + +@pytest.mark.api_key_required +@pytest.mark.benchmark +async def test_concurrent_stream_run_with_input_type_chat(client: AsyncClient, starter_project, created_api_key): + """Test concurrent streaming requests to the run endpoint with chat input type.""" + headers = {"x-api-key": created_api_key.api_key, "Accept": "text/event-stream", "Content-Type": "application/json"} + flow_id = starter_project["id"] + payload = { + "input_type": "chat", + "output_type": "debug", + "input_value": "How are you?", + } + num_concurrent_requests = 5 # Number of concurrent requests to test + + tasks = [_run_single_stream_test(client, flow_id, headers, payload) for _ in range(num_concurrent_requests)] + + # Run all streaming tests concurrently + await asyncio.gather(*tasks)