From 80aec70ac456a8000b4703838f99f537fa812137 Mon Sep 17 00:00:00 2001 From: Gabriel Luiz Freitas Almeida Date: Fri, 22 Mar 2024 12:38:37 -0300 Subject: [PATCH] Add tests to run endpoint --- src/backend/langflow/api/v1/endpoints.py | 51 +++---- src/backend/langflow/api/v1/schemas.py | 29 ++-- src/backend/langflow/graph/graph/base.py | 29 ++-- src/backend/langflow/processing/process.py | 11 +- .../langflow/services/session/service.py | 18 ++- tests/conftest.py | 11 +- tests/test_endpoints.py | 135 +++++++++++++----- 7 files changed, 175 insertions(+), 109 deletions(-) diff --git a/src/backend/langflow/api/v1/endpoints.py b/src/backend/langflow/api/v1/endpoints.py index cea43c496..6bae2fe79 100644 --- a/src/backend/langflow/api/v1/endpoints.py +++ b/src/backend/langflow/api/v1/endpoints.py @@ -17,6 +17,7 @@ from langflow.api.v1.schemas import ( UpdateCustomComponentRequest, UploadFileResponse, ) +from langflow.graph.graph.base import Graph from langflow.graph.schema import RunOutputs from langflow.interface.custom.custom_component import CustomComponent from langflow.interface.custom.directory_reader import DirectoryReader @@ -53,7 +54,7 @@ def get_all( async def run_flow_with_caching( session: Annotated[Session, Depends(get_session)], flow_id: str, - inputs: Optional[List[InputValueRequest]] = [], + inputs: Optional[List[InputValueRequest]] = [InputValueRequest(components=[], input_value="")], outputs: Optional[List[str]] = [], tweaks: Annotated[Optional[Tweaks], Body(embed=True)] = None, # noqa: F821 stream: Annotated[bool, Body(embed=True)] = False, # noqa: F821 @@ -102,23 +103,13 @@ async def run_flow_with_caching( if outputs is None: outputs = [] + task_result: List[RunOutputs] = [] + artifacts = {} if session_id: session_data = await session_service.load_session(session_id, flow_id=flow_id) graph, artifacts = session_data if session_data else (None, None) - task_result: List[RunOutputs] = [] - if not graph: - raise ValueError("Graph not found in the session") - task_result, session_id = await run_graph( - graph=graph, - flow_id=flow_id, - session_id=session_id, - inputs=inputs, - outputs=outputs, - artifacts=artifacts, - session_service=session_service, - stream=stream, - ) - + if graph is None: + raise ValueError(f"Session {session_id} not found") else: # Get the flow that matches the flow_id and belongs to the user # flow = session.query(Flow).filter(Flow.id == flow_id).filter(Flow.user_id == api_key_user.id).first() @@ -130,28 +121,38 @@ async def run_flow_with_caching( raise ValueError(f"Flow {flow_id} has no data") graph_data = flow.data graph_data = process_tweaks(graph_data, tweaks or {}) - task_result, session_id = await run_graph( - graph=graph_data, - flow_id=flow_id, - session_id=session_id, - inputs=inputs, - outputs=outputs, - artifacts={}, - session_service=session_service, - stream=stream, - ) + graph = Graph.from_payload(graph_data, flow_id=flow_id) + task_result, session_id = await run_graph( + graph=graph, + flow_id=flow_id, + session_id=session_id, + inputs=inputs, + outputs=outputs, + artifacts=artifacts, + session_service=session_service, + stream=stream, + ) return RunResponse(outputs=task_result, session_id=session_id) except sa.exc.StatementError as exc: # StatementError('(builtins.ValueError) badly formed hexadecimal UUID string') if "badly formed hexadecimal UUID string" in str(exc): + logger.error(f"Flow ID {flow_id} is not a valid UUID") # This means the Flow ID is not a valid UUID which means it can't find the flow raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=str(exc)) from exc except ValueError as exc: if f"Flow {flow_id} not found" in str(exc): + logger.error(f"Flow {flow_id} not found") + raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=str(exc)) from exc + elif f"Session {session_id} not found" in str(exc): + logger.error(f"Session {session_id} not found") raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=str(exc)) from exc else: + logger.exception(exc) raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(exc)) from exc + except Exception as exc: + logger.exception(exc) + raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(exc)) from exc @router.post( diff --git a/src/backend/langflow/api/v1/schemas.py b/src/backend/langflow/api/v1/schemas.py index db522cc2a..e575745aa 100644 --- a/src/backend/langflow/api/v1/schemas.py +++ b/src/backend/langflow/api/v1/schemas.py @@ -4,14 +4,7 @@ from pathlib import Path from typing import Any, Dict, List, Optional, Union from uuid import UUID -from pydantic import ( - BaseModel, - ConfigDict, - Field, - RootModel, - field_validator, - model_serializer, -) +from pydantic import BaseModel, ConfigDict, Field, RootModel, field_validator, model_serializer from langflow.graph.schema import RunOutputs from langflow.schema import dotdict @@ -61,18 +54,19 @@ class RunResponse(BaseModel): outputs: Optional[List[RunOutputs]] = [] session_id: Optional[str] = None - @model_serializer(mode="wrap") - def serialize(self, handler): + @model_serializer(mode="plain") + def serialize(self): # Serialize all the outputs if they are base models + serialized = {"session_id": self.session_id, "outputs": []} if self.outputs: serialized_outputs = [] for output in self.outputs: - if isinstance(output, BaseModel): + if isinstance(output, BaseModel) and not isinstance(output, RunOutputs): serialized_outputs.append(output.model_dump(exclude_none=True)) else: serialized_outputs.append(output) - self.outputs = serialized_outputs - return handler(self) + serialized["outputs"] = serialized_outputs + return serialized class PreloadResponse(BaseModel): @@ -266,8 +260,8 @@ class InputValueRequest(BaseModel): input_value: Optional[str] = None # add an example - model_config = { - "json_schema_extra": { + model_config = ConfigDict( + json_schema_extra={ "examples": [ { "components": ["components_id", "Component Name"], @@ -276,8 +270,9 @@ class InputValueRequest(BaseModel): {"components": ["Component Name"], "input_value": "input_value"}, {"input_value": "input_value"}, ] - } - } + }, + extra="forbid", + ) class Tweaks(RootModel): diff --git a/src/backend/langflow/graph/graph/base.py b/src/backend/langflow/graph/graph/base.py index 1d68c5e6b..f0cdff6ac 100644 --- a/src/backend/langflow/graph/graph/base.py +++ b/src/backend/langflow/graph/graph/base.py @@ -11,14 +11,7 @@ from langflow.graph.graph.state_manager import GraphStateManager from langflow.graph.graph.utils import process_flow from langflow.graph.schema import INPUT_FIELD_NAME, InterfaceComponentTypes, RunOutputs from langflow.graph.vertex.base import Vertex -from langflow.graph.vertex.types import ( - ChatVertex, - FileToolVertex, - LLMVertex, - RoutingVertex, - StateVertex, - ToolkitVertex, -) +from langflow.graph.vertex.types import ChatVertex, FileToolVertex, LLMVertex, RoutingVertex, StateVertex, ToolkitVertex from langflow.interface.tools.constants import FILE_TOOLS from langflow.schema import Record @@ -222,6 +215,13 @@ class Graph: Returns: List[Optional["ResultData"]]: The outputs of the graph. """ + if input_components and not isinstance(input_components, list): + raise ValueError(f"Invalid components value: {input_components}. Expected list") + elif input_components is None: + input_components = [] + + if not isinstance(inputs.get(INPUT_FIELD_NAME, ""), str): + raise ValueError(f"Invalid input value: {inputs.get(INPUT_FIELD_NAME)}. Expected string") for vertex_id in self._is_input_vertices: vertex = self.get_vertex(vertex_id) if input_components and (vertex_id not in input_components or vertex.display_name not in input_components): @@ -250,7 +250,7 @@ class Graph: if not vertex.result and not stream and hasattr(vertex, "consume_async_generator"): await vertex.consume_async_generator() - if not outputs or (vertex.display_name in outputs or vertex.id in outputs): + if (not outputs and vertex.is_output) or (vertex.display_name in outputs or vertex.id in outputs): vertex_outputs.append(vertex.result) return vertex_outputs @@ -283,14 +283,9 @@ class Graph: vertex_outputs = [] if not isinstance(inputs, list): inputs = [inputs] - for run_inputs, components in zip(inputs, inputs_components or []): - if components and not isinstance(components, list): - raise ValueError(f"Invalid components value: {components}. Expected list") - elif components is None: - components = [] - - if not isinstance(run_inputs.get(INPUT_FIELD_NAME, ""), str): - raise ValueError(f"Invalid input value: {run_inputs.get(INPUT_FIELD_NAME)}. Expected string") + elif not inputs: + inputs = [{}] + for run_inputs, components in zip(inputs, inputs_components): run_outputs = await self._run( inputs=run_inputs, input_components=components, diff --git a/src/backend/langflow/processing/process.py b/src/backend/langflow/processing/process.py index 6063afbbe..5d498bb0f 100644 --- a/src/backend/langflow/processing/process.py +++ b/src/backend/langflow/processing/process.py @@ -208,11 +208,7 @@ async def run_graph( ) -> tuple[List[RunOutputs], str]: """Run the graph and generate the result""" inputs = inputs or [] - if isinstance(graph, dict): - graph_data = graph - graph = Graph.from_payload(graph, flow_id=flow_id) - else: - graph_data = graph._graph_data + graph_data = graph._graph_data if session_id is None and session_service is not None: session_id_str = session_service.generate_key(session_id=flow_id, data_graph=graph_data) elif session_id is not None: @@ -236,7 +232,7 @@ async def run_graph( session_id=session_id_str or "", ) if session_id_str and session_service: - session_service.update_session(session_id_str, (graph, artifacts)) + await session_service.update_session(session_id_str, (graph, artifacts)) return run_outputs, session_id_str @@ -262,6 +258,9 @@ def apply_tweaks(node: Dict[str, Any], node_tweaks: Dict[str, Any]) -> None: return for tweak_name, tweak_value in node_tweaks.items(): + if tweak_name not in template_data: + logger.warning(f"Node {node.get('id')} does not have a tweak named {tweak_name}") + continue if tweak_name and tweak_value and tweak_name in template_data: key = tweak_name if tweak_name == "file_path" else "value" template_data[tweak_name][key] = tweak_value diff --git a/src/backend/langflow/services/session/service.py b/src/backend/langflow/services/session/service.py index 68fec0430..ae137f2af 100644 --- a/src/backend/langflow/services/session/service.py +++ b/src/backend/langflow/services/session/service.py @@ -1,4 +1,4 @@ -from typing import TYPE_CHECKING, Optional +from typing import TYPE_CHECKING, Coroutine, Optional from langflow.interface.run import build_sorted_vertices from langflow.services.base import Service @@ -26,7 +26,7 @@ class SessionService(Service): # If not cached, build the graph and cache it graph, artifacts = await build_sorted_vertices(data_graph, flow_id) - self.cache_service.set(key, (graph, artifacts)) + await self.cache_service.set(key, (graph, artifacts)) return graph, artifacts @@ -41,8 +41,14 @@ class SessionService(Service): session_id = session_id_generator() return self.build_key(session_id, data_graph=data_graph) - def update_session(self, session_id, value): - self.cache_service.set(session_id, value) + async def update_session(self, session_id, value): + result = self.cache_service.set(session_id, value) + # if it is a coroutine, await it + if isinstance(result, Coroutine): + await result - def clear_session(self, session_id): - self.cache_service.delete(session_id) + async def clear_session(self, session_id): + result = self.cache_service.delete(session_id) + # if it is a coroutine, await it + if isinstance(result, Coroutine): + await result diff --git a/tests/conftest.py b/tests/conftest.py index 7708340c9..83c7d031d 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -10,10 +10,6 @@ import orjson import pytest from fastapi.testclient import TestClient from httpx import AsyncClient -from sqlmodel import Session, SQLModel, create_engine, select -from sqlmodel.pool import StaticPool -from typer.testing import CliRunner - from langflow.graph.graph.base import Graph from langflow.initial_setup.setup import STARTER_FOLDER_NAME from langflow.services.auth.utils import get_password_hash @@ -22,6 +18,9 @@ from langflow.services.database.models.flow.model import Flow, FlowCreate from langflow.services.database.models.user.model import User, UserCreate from langflow.services.database.utils import session_getter from langflow.services.deps import get_db_service +from sqlmodel import Session, SQLModel, create_engine, select +from sqlmodel.pool import StaticPool +from typer.testing import CliRunner if TYPE_CHECKING: from langflow.services.database.service import DatabaseService @@ -263,7 +262,7 @@ def active_user(client): is_superuser=False, ) # check if user exists - if active_user := session.query(User).filter(User.username == user.username).first(): + if active_user := session.exec(select(User).where(User.username == user.username)).first(): return active_user session.add(user) session.commit() @@ -368,7 +367,7 @@ def created_api_key(active_user): ) db_manager = get_db_service() with session_getter(db_manager) as session: - if existing_api_key := session.query(ApiKey).filter(ApiKey.api_key == api_key.api_key).first(): + if existing_api_key := session.exec(select(ApiKey).where(ApiKey.api_key == api_key.api_key)).first(): return existing_api_key session.add(api_key) session.commit() diff --git a/tests/test_endpoints.py b/tests/test_endpoints.py index 4494934fd..7b1ccd2ef 100644 --- a/tests/test_endpoints.py +++ b/tests/test_endpoints.py @@ -1,14 +1,12 @@ import time +from uuid import uuid4 import pytest from fastapi import status from fastapi.testclient import TestClient from langflow.interface.custom.directory_reader.directory_reader import DirectoryReader -from langflow.services.auth.utils import get_password_hash -from langflow.services.database.models.api_key.model import ApiKey -from langflow.services.database.utils import session_getter -from langflow.services.deps import get_db_service, get_settings_service +from langflow.services.deps import get_settings_service from langflow.template.frontend_node.chains import TimeTravelGuideChainNode @@ -112,25 +110,6 @@ PROMPT_REQUEST = { } -@pytest.fixture -def created_api_key(active_user): - hashed = get_password_hash("random_key") - api_key = ApiKey( - name="test_api_key", - user_id=active_user.id, - api_key="random_key", - hashed_api_key=hashed, - ) - db_manager = get_db_service() - with session_getter(db_manager) as session: - if existing_api_key := session.query(ApiKey).filter(ApiKey.api_key == api_key.api_key).first(): - return existing_api_key - session.add(api_key) - session.commit() - session.refresh(api_key) - return api_key - - # def test_process_flow_invalid_api_key(client, flow, monkeypatch): # # Mock de process_graph_cached # from langflow.api.v1 import endpoints @@ -452,18 +431,24 @@ def test_successful_run(client, starter_project, created_api_key): assert response.status_code == status.HTTP_200_OK, response.text # Add more assertions here to validate the response content json_response = response.json() + assert "session_id" in json_response assert "outputs" in json_response outer_outputs = json_response["outputs"] assert len(outer_outputs) == 1 - outputs = outer_outputs[0] - assert len(outputs) == 2 - keys = ["results", "artifacts", "messages"] - for output in outputs: - assert all(key in output for key in keys) - output = outputs[0] - result = output["results"]["result"] - assert result == "Write a press release \n\n- Cars\n- Bottle\n\n\nAnswer:\n\n" - assert "session_id" in json_response + outputs_dict = outer_outputs[0] + assert len(outputs_dict) == 2 + assert "inputs" in outputs_dict + assert "outputs" in outputs_dict + assert outputs_dict.get("inputs") == {"input_value": ""} + assert isinstance(outputs_dict.get("outputs"), list) + assert len(outputs_dict.get("outputs")) == 2 + ids = [output.get("component_id") for output in outputs_dict.get("outputs")] + assert all([id in ids for id in ["TextOutput-fTp5e", "ChatOutput-AVN8s"]]) + display_names = [output.get("component_display_name") for output in outputs_dict.get("outputs")] + assert all([name in display_names for name in ["Prompt Output", "Chat Output"]]) + inner_results = [output.get("results").get("result") for output in outputs_dict.get("outputs")] + expected_results = ["Write a press release \n\n- Cars\n- Bottle\n\n\nAnswer:\n\n", ""] + assert all([result in inner_results for result in expected_results]) def test_run_with_inputs_and_outputs(client, starter_project, created_api_key): @@ -484,3 +469,89 @@ def test_invalid_flow_id(client, created_api_key): response = client.post(f"/api/v1/run/{flow_id}", headers=headers) assert response.status_code == status.HTTP_404_NOT_FOUND # Check if the error detail is as expected + + +def test_run_flow_with_caching_success(client: TestClient, starter_project, created_api_key): + flow_id = starter_project["id"] + headers = {"x-api-key": created_api_key.api_key} + payload = { + "inputs": [ + {"components": ["component1"], "input_value": "value1"}, + {"components": ["component3"], "input_value": "value2"}, + ], + "outputs": ["Component Name", "component_id"], + "tweaks": { + "parameter_name": "value", + "Component Name": {"parameter_name": "value"}, + "component_id": {"parameter_name": "value"}, + }, + "stream": False, + } + response = client.post(f"/api/v1/run/{flow_id}", json=payload, headers=headers) + assert response.status_code == status.HTTP_200_OK + data = response.json() + assert "outputs" in data + assert "session_id" in data + + +def test_run_flow_with_caching_invalid_flow_id(client: TestClient, created_api_key): + invalid_flow_id = uuid4() + headers = {"x-api-key": created_api_key.api_key} + payload = {"inputs": [], "outputs": [], "tweaks": {}, "stream": False} + response = client.post(f"/api/v1/run/{invalid_flow_id}", json=payload, headers=headers) + assert response.status_code == status.HTTP_404_NOT_FOUND + data = response.json() + assert "detail" in data + assert f"Flow {invalid_flow_id} not found" in data["detail"] + + +def test_run_flow_with_caching_invalid_input_format(client: TestClient, starter_project, created_api_key): + flow_id = starter_project["id"] + headers = {"x-api-key": created_api_key.api_key} + payload = {"inputs": [{"invalid_key": "value"}], "outputs": [], "tweaks": {}, "stream": False} + # This should raise an http 422 error not validation error + response = client.post(f"/api/v1/run/{flow_id}", json=payload, headers=headers) + assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY + + +def test_run_flow_with_session_id(client, starter_project, created_api_key): + headers = {"x-api-key": created_api_key.api_key} + flow_id = starter_project["id"] + payload = { + "inputs": [{"components": ["component1"], "input_value": "value1"}], + "outputs": ["Component Name", "component_id"], + "session_id": "test-session-id", + } + response = client.post(f"/api/v1/run/{flow_id}", json=payload, headers=headers) + assert response.status_code == status.HTTP_200_OK + data = response.json() + assert "outputs" in data + assert "session_id" in data + assert data["session_id"] == "test-session-id" + + +def test_run_flow_with_invalid_session_id(client, starter_project, created_api_key): + headers = {"x-api-key": created_api_key.api_key} + flow_id = starter_project["id"] + payload = { + "inputs": [{"components": ["component1"], "input_value": "value1"}], + "outputs": ["Component Name", "component_id"], + "session_id": "invalid-session-id", + } + response = client.post(f"/api/v1/run/{flow_id}", json=payload, headers=headers) + assert response.status_code == status.HTTP_404_NOT_FOUND + data = response.json() + assert "detail" in data + assert f"Session {payload['session_id']} not found" in data["detail"] + + +def test_run_flow_with_invalid_tweaks(client, starter_project, created_api_key): + headers = {"x-api-key": created_api_key.api_key} + flow_id = starter_project["id"] + payload = { + "inputs": [{"components": ["component1"], "input_value": "value1"}], + "outputs": ["Component Name", "component_id"], + "tweaks": {"invalid_tweak": "value"}, + } + response = client.post(f"/api/v1/run/{flow_id}", json=payload, headers=headers) + assert response.status_code == status.HTTP_200_OK