Add tests to run endpoint

This commit is contained in:
Gabriel Luiz Freitas Almeida 2024-03-22 12:38:37 -03:00
commit 80aec70ac4
7 changed files with 175 additions and 109 deletions

View file

@ -17,6 +17,7 @@ from langflow.api.v1.schemas import (
UpdateCustomComponentRequest,
UploadFileResponse,
)
from langflow.graph.graph.base import Graph
from langflow.graph.schema import RunOutputs
from langflow.interface.custom.custom_component import CustomComponent
from langflow.interface.custom.directory_reader import DirectoryReader
@ -53,7 +54,7 @@ def get_all(
async def run_flow_with_caching(
session: Annotated[Session, Depends(get_session)],
flow_id: str,
inputs: Optional[List[InputValueRequest]] = [],
inputs: Optional[List[InputValueRequest]] = [InputValueRequest(components=[], input_value="")],
outputs: Optional[List[str]] = [],
tweaks: Annotated[Optional[Tweaks], Body(embed=True)] = None, # noqa: F821
stream: Annotated[bool, Body(embed=True)] = False, # noqa: F821
@ -102,23 +103,13 @@ async def run_flow_with_caching(
if outputs is None:
outputs = []
task_result: List[RunOutputs] = []
artifacts = {}
if session_id:
session_data = await session_service.load_session(session_id, flow_id=flow_id)
graph, artifacts = session_data if session_data else (None, None)
task_result: List[RunOutputs] = []
if not graph:
raise ValueError("Graph not found in the session")
task_result, session_id = await run_graph(
graph=graph,
flow_id=flow_id,
session_id=session_id,
inputs=inputs,
outputs=outputs,
artifacts=artifacts,
session_service=session_service,
stream=stream,
)
if graph is None:
raise ValueError(f"Session {session_id} not found")
else:
# Get the flow that matches the flow_id and belongs to the user
# flow = session.query(Flow).filter(Flow.id == flow_id).filter(Flow.user_id == api_key_user.id).first()
@ -130,28 +121,38 @@ async def run_flow_with_caching(
raise ValueError(f"Flow {flow_id} has no data")
graph_data = flow.data
graph_data = process_tweaks(graph_data, tweaks or {})
task_result, session_id = await run_graph(
graph=graph_data,
flow_id=flow_id,
session_id=session_id,
inputs=inputs,
outputs=outputs,
artifacts={},
session_service=session_service,
stream=stream,
)
graph = Graph.from_payload(graph_data, flow_id=flow_id)
task_result, session_id = await run_graph(
graph=graph,
flow_id=flow_id,
session_id=session_id,
inputs=inputs,
outputs=outputs,
artifacts=artifacts,
session_service=session_service,
stream=stream,
)
return RunResponse(outputs=task_result, session_id=session_id)
except sa.exc.StatementError as exc:
# StatementError('(builtins.ValueError) badly formed hexadecimal UUID string')
if "badly formed hexadecimal UUID string" in str(exc):
logger.error(f"Flow ID {flow_id} is not a valid UUID")
# This means the Flow ID is not a valid UUID which means it can't find the flow
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=str(exc)) from exc
except ValueError as exc:
if f"Flow {flow_id} not found" in str(exc):
logger.error(f"Flow {flow_id} not found")
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=str(exc)) from exc
elif f"Session {session_id} not found" in str(exc):
logger.error(f"Session {session_id} not found")
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=str(exc)) from exc
else:
logger.exception(exc)
raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(exc)) from exc
except Exception as exc:
logger.exception(exc)
raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(exc)) from exc
@router.post(

View file

@ -4,14 +4,7 @@ from pathlib import Path
from typing import Any, Dict, List, Optional, Union
from uuid import UUID
from pydantic import (
BaseModel,
ConfigDict,
Field,
RootModel,
field_validator,
model_serializer,
)
from pydantic import BaseModel, ConfigDict, Field, RootModel, field_validator, model_serializer
from langflow.graph.schema import RunOutputs
from langflow.schema import dotdict
@ -61,18 +54,19 @@ class RunResponse(BaseModel):
outputs: Optional[List[RunOutputs]] = []
session_id: Optional[str] = None
@model_serializer(mode="wrap")
def serialize(self, handler):
@model_serializer(mode="plain")
def serialize(self):
# Serialize all the outputs if they are base models
serialized = {"session_id": self.session_id, "outputs": []}
if self.outputs:
serialized_outputs = []
for output in self.outputs:
if isinstance(output, BaseModel):
if isinstance(output, BaseModel) and not isinstance(output, RunOutputs):
serialized_outputs.append(output.model_dump(exclude_none=True))
else:
serialized_outputs.append(output)
self.outputs = serialized_outputs
return handler(self)
serialized["outputs"] = serialized_outputs
return serialized
class PreloadResponse(BaseModel):
@ -266,8 +260,8 @@ class InputValueRequest(BaseModel):
input_value: Optional[str] = None
# add an example
model_config = {
"json_schema_extra": {
model_config = ConfigDict(
json_schema_extra={
"examples": [
{
"components": ["components_id", "Component Name"],
@ -276,8 +270,9 @@ class InputValueRequest(BaseModel):
{"components": ["Component Name"], "input_value": "input_value"},
{"input_value": "input_value"},
]
}
}
},
extra="forbid",
)
class Tweaks(RootModel):

View file

@ -11,14 +11,7 @@ from langflow.graph.graph.state_manager import GraphStateManager
from langflow.graph.graph.utils import process_flow
from langflow.graph.schema import INPUT_FIELD_NAME, InterfaceComponentTypes, RunOutputs
from langflow.graph.vertex.base import Vertex
from langflow.graph.vertex.types import (
ChatVertex,
FileToolVertex,
LLMVertex,
RoutingVertex,
StateVertex,
ToolkitVertex,
)
from langflow.graph.vertex.types import ChatVertex, FileToolVertex, LLMVertex, RoutingVertex, StateVertex, ToolkitVertex
from langflow.interface.tools.constants import FILE_TOOLS
from langflow.schema import Record
@ -222,6 +215,13 @@ class Graph:
Returns:
List[Optional["ResultData"]]: The outputs of the graph.
"""
if input_components and not isinstance(input_components, list):
raise ValueError(f"Invalid components value: {input_components}. Expected list")
elif input_components is None:
input_components = []
if not isinstance(inputs.get(INPUT_FIELD_NAME, ""), str):
raise ValueError(f"Invalid input value: {inputs.get(INPUT_FIELD_NAME)}. Expected string")
for vertex_id in self._is_input_vertices:
vertex = self.get_vertex(vertex_id)
if input_components and (vertex_id not in input_components or vertex.display_name not in input_components):
@ -250,7 +250,7 @@ class Graph:
if not vertex.result and not stream and hasattr(vertex, "consume_async_generator"):
await vertex.consume_async_generator()
if not outputs or (vertex.display_name in outputs or vertex.id in outputs):
if (not outputs and vertex.is_output) or (vertex.display_name in outputs or vertex.id in outputs):
vertex_outputs.append(vertex.result)
return vertex_outputs
@ -283,14 +283,9 @@ class Graph:
vertex_outputs = []
if not isinstance(inputs, list):
inputs = [inputs]
for run_inputs, components in zip(inputs, inputs_components or []):
if components and not isinstance(components, list):
raise ValueError(f"Invalid components value: {components}. Expected list")
elif components is None:
components = []
if not isinstance(run_inputs.get(INPUT_FIELD_NAME, ""), str):
raise ValueError(f"Invalid input value: {run_inputs.get(INPUT_FIELD_NAME)}. Expected string")
elif not inputs:
inputs = [{}]
for run_inputs, components in zip(inputs, inputs_components):
run_outputs = await self._run(
inputs=run_inputs,
input_components=components,

View file

@ -208,11 +208,7 @@ async def run_graph(
) -> tuple[List[RunOutputs], str]:
"""Run the graph and generate the result"""
inputs = inputs or []
if isinstance(graph, dict):
graph_data = graph
graph = Graph.from_payload(graph, flow_id=flow_id)
else:
graph_data = graph._graph_data
graph_data = graph._graph_data
if session_id is None and session_service is not None:
session_id_str = session_service.generate_key(session_id=flow_id, data_graph=graph_data)
elif session_id is not None:
@ -236,7 +232,7 @@ async def run_graph(
session_id=session_id_str or "",
)
if session_id_str and session_service:
session_service.update_session(session_id_str, (graph, artifacts))
await session_service.update_session(session_id_str, (graph, artifacts))
return run_outputs, session_id_str
@ -262,6 +258,9 @@ def apply_tweaks(node: Dict[str, Any], node_tweaks: Dict[str, Any]) -> None:
return
for tweak_name, tweak_value in node_tweaks.items():
if tweak_name not in template_data:
logger.warning(f"Node {node.get('id')} does not have a tweak named {tweak_name}")
continue
if tweak_name and tweak_value and tweak_name in template_data:
key = tweak_name if tweak_name == "file_path" else "value"
template_data[tweak_name][key] = tweak_value

View file

@ -1,4 +1,4 @@
from typing import TYPE_CHECKING, Optional
from typing import TYPE_CHECKING, Coroutine, Optional
from langflow.interface.run import build_sorted_vertices
from langflow.services.base import Service
@ -26,7 +26,7 @@ class SessionService(Service):
# If not cached, build the graph and cache it
graph, artifacts = await build_sorted_vertices(data_graph, flow_id)
self.cache_service.set(key, (graph, artifacts))
await self.cache_service.set(key, (graph, artifacts))
return graph, artifacts
@ -41,8 +41,14 @@ class SessionService(Service):
session_id = session_id_generator()
return self.build_key(session_id, data_graph=data_graph)
def update_session(self, session_id, value):
self.cache_service.set(session_id, value)
async def update_session(self, session_id, value):
result = self.cache_service.set(session_id, value)
# if it is a coroutine, await it
if isinstance(result, Coroutine):
await result
def clear_session(self, session_id):
self.cache_service.delete(session_id)
async def clear_session(self, session_id):
result = self.cache_service.delete(session_id)
# if it is a coroutine, await it
if isinstance(result, Coroutine):
await result

View file

@ -10,10 +10,6 @@ import orjson
import pytest
from fastapi.testclient import TestClient
from httpx import AsyncClient
from sqlmodel import Session, SQLModel, create_engine, select
from sqlmodel.pool import StaticPool
from typer.testing import CliRunner
from langflow.graph.graph.base import Graph
from langflow.initial_setup.setup import STARTER_FOLDER_NAME
from langflow.services.auth.utils import get_password_hash
@ -22,6 +18,9 @@ from langflow.services.database.models.flow.model import Flow, FlowCreate
from langflow.services.database.models.user.model import User, UserCreate
from langflow.services.database.utils import session_getter
from langflow.services.deps import get_db_service
from sqlmodel import Session, SQLModel, create_engine, select
from sqlmodel.pool import StaticPool
from typer.testing import CliRunner
if TYPE_CHECKING:
from langflow.services.database.service import DatabaseService
@ -263,7 +262,7 @@ def active_user(client):
is_superuser=False,
)
# check if user exists
if active_user := session.query(User).filter(User.username == user.username).first():
if active_user := session.exec(select(User).where(User.username == user.username)).first():
return active_user
session.add(user)
session.commit()
@ -368,7 +367,7 @@ def created_api_key(active_user):
)
db_manager = get_db_service()
with session_getter(db_manager) as session:
if existing_api_key := session.query(ApiKey).filter(ApiKey.api_key == api_key.api_key).first():
if existing_api_key := session.exec(select(ApiKey).where(ApiKey.api_key == api_key.api_key)).first():
return existing_api_key
session.add(api_key)
session.commit()

View file

@ -1,14 +1,12 @@
import time
from uuid import uuid4
import pytest
from fastapi import status
from fastapi.testclient import TestClient
from langflow.interface.custom.directory_reader.directory_reader import DirectoryReader
from langflow.services.auth.utils import get_password_hash
from langflow.services.database.models.api_key.model import ApiKey
from langflow.services.database.utils import session_getter
from langflow.services.deps import get_db_service, get_settings_service
from langflow.services.deps import get_settings_service
from langflow.template.frontend_node.chains import TimeTravelGuideChainNode
@ -112,25 +110,6 @@ PROMPT_REQUEST = {
}
@pytest.fixture
def created_api_key(active_user):
hashed = get_password_hash("random_key")
api_key = ApiKey(
name="test_api_key",
user_id=active_user.id,
api_key="random_key",
hashed_api_key=hashed,
)
db_manager = get_db_service()
with session_getter(db_manager) as session:
if existing_api_key := session.query(ApiKey).filter(ApiKey.api_key == api_key.api_key).first():
return existing_api_key
session.add(api_key)
session.commit()
session.refresh(api_key)
return api_key
# def test_process_flow_invalid_api_key(client, flow, monkeypatch):
# # Mock de process_graph_cached
# from langflow.api.v1 import endpoints
@ -452,18 +431,24 @@ def test_successful_run(client, starter_project, created_api_key):
assert response.status_code == status.HTTP_200_OK, response.text
# Add more assertions here to validate the response content
json_response = response.json()
assert "session_id" in json_response
assert "outputs" in json_response
outer_outputs = json_response["outputs"]
assert len(outer_outputs) == 1
outputs = outer_outputs[0]
assert len(outputs) == 2
keys = ["results", "artifacts", "messages"]
for output in outputs:
assert all(key in output for key in keys)
output = outputs[0]
result = output["results"]["result"]
assert result == "Write a press release \n\n- Cars\n- Bottle\n\n\nAnswer:\n\n"
assert "session_id" in json_response
outputs_dict = outer_outputs[0]
assert len(outputs_dict) == 2
assert "inputs" in outputs_dict
assert "outputs" in outputs_dict
assert outputs_dict.get("inputs") == {"input_value": ""}
assert isinstance(outputs_dict.get("outputs"), list)
assert len(outputs_dict.get("outputs")) == 2
ids = [output.get("component_id") for output in outputs_dict.get("outputs")]
assert all([id in ids for id in ["TextOutput-fTp5e", "ChatOutput-AVN8s"]])
display_names = [output.get("component_display_name") for output in outputs_dict.get("outputs")]
assert all([name in display_names for name in ["Prompt Output", "Chat Output"]])
inner_results = [output.get("results").get("result") for output in outputs_dict.get("outputs")]
expected_results = ["Write a press release \n\n- Cars\n- Bottle\n\n\nAnswer:\n\n", ""]
assert all([result in inner_results for result in expected_results])
def test_run_with_inputs_and_outputs(client, starter_project, created_api_key):
@ -484,3 +469,89 @@ def test_invalid_flow_id(client, created_api_key):
response = client.post(f"/api/v1/run/{flow_id}", headers=headers)
assert response.status_code == status.HTTP_404_NOT_FOUND
# Check if the error detail is as expected
def test_run_flow_with_caching_success(client: TestClient, starter_project, created_api_key):
flow_id = starter_project["id"]
headers = {"x-api-key": created_api_key.api_key}
payload = {
"inputs": [
{"components": ["component1"], "input_value": "value1"},
{"components": ["component3"], "input_value": "value2"},
],
"outputs": ["Component Name", "component_id"],
"tweaks": {
"parameter_name": "value",
"Component Name": {"parameter_name": "value"},
"component_id": {"parameter_name": "value"},
},
"stream": False,
}
response = client.post(f"/api/v1/run/{flow_id}", json=payload, headers=headers)
assert response.status_code == status.HTTP_200_OK
data = response.json()
assert "outputs" in data
assert "session_id" in data
def test_run_flow_with_caching_invalid_flow_id(client: TestClient, created_api_key):
invalid_flow_id = uuid4()
headers = {"x-api-key": created_api_key.api_key}
payload = {"inputs": [], "outputs": [], "tweaks": {}, "stream": False}
response = client.post(f"/api/v1/run/{invalid_flow_id}", json=payload, headers=headers)
assert response.status_code == status.HTTP_404_NOT_FOUND
data = response.json()
assert "detail" in data
assert f"Flow {invalid_flow_id} not found" in data["detail"]
def test_run_flow_with_caching_invalid_input_format(client: TestClient, starter_project, created_api_key):
flow_id = starter_project["id"]
headers = {"x-api-key": created_api_key.api_key}
payload = {"inputs": [{"invalid_key": "value"}], "outputs": [], "tweaks": {}, "stream": False}
# This should raise an http 422 error not validation error
response = client.post(f"/api/v1/run/{flow_id}", json=payload, headers=headers)
assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY
def test_run_flow_with_session_id(client, starter_project, created_api_key):
headers = {"x-api-key": created_api_key.api_key}
flow_id = starter_project["id"]
payload = {
"inputs": [{"components": ["component1"], "input_value": "value1"}],
"outputs": ["Component Name", "component_id"],
"session_id": "test-session-id",
}
response = client.post(f"/api/v1/run/{flow_id}", json=payload, headers=headers)
assert response.status_code == status.HTTP_200_OK
data = response.json()
assert "outputs" in data
assert "session_id" in data
assert data["session_id"] == "test-session-id"
def test_run_flow_with_invalid_session_id(client, starter_project, created_api_key):
headers = {"x-api-key": created_api_key.api_key}
flow_id = starter_project["id"]
payload = {
"inputs": [{"components": ["component1"], "input_value": "value1"}],
"outputs": ["Component Name", "component_id"],
"session_id": "invalid-session-id",
}
response = client.post(f"/api/v1/run/{flow_id}", json=payload, headers=headers)
assert response.status_code == status.HTTP_404_NOT_FOUND
data = response.json()
assert "detail" in data
assert f"Session {payload['session_id']} not found" in data["detail"]
def test_run_flow_with_invalid_tweaks(client, starter_project, created_api_key):
headers = {"x-api-key": created_api_key.api_key}
flow_id = starter_project["id"]
payload = {
"inputs": [{"components": ["component1"], "input_value": "value1"}],
"outputs": ["Component Name", "component_id"],
"tweaks": {"invalid_tweak": "value"},
}
response = client.post(f"/api/v1/run/{flow_id}", json=payload, headers=headers)
assert response.status_code == status.HTTP_200_OK