Add tests to run endpoint
This commit is contained in:
parent
496c2aae3e
commit
80aec70ac4
7 changed files with 175 additions and 109 deletions
|
|
@ -17,6 +17,7 @@ from langflow.api.v1.schemas import (
|
|||
UpdateCustomComponentRequest,
|
||||
UploadFileResponse,
|
||||
)
|
||||
from langflow.graph.graph.base import Graph
|
||||
from langflow.graph.schema import RunOutputs
|
||||
from langflow.interface.custom.custom_component import CustomComponent
|
||||
from langflow.interface.custom.directory_reader import DirectoryReader
|
||||
|
|
@ -53,7 +54,7 @@ def get_all(
|
|||
async def run_flow_with_caching(
|
||||
session: Annotated[Session, Depends(get_session)],
|
||||
flow_id: str,
|
||||
inputs: Optional[List[InputValueRequest]] = [],
|
||||
inputs: Optional[List[InputValueRequest]] = [InputValueRequest(components=[], input_value="")],
|
||||
outputs: Optional[List[str]] = [],
|
||||
tweaks: Annotated[Optional[Tweaks], Body(embed=True)] = None, # noqa: F821
|
||||
stream: Annotated[bool, Body(embed=True)] = False, # noqa: F821
|
||||
|
|
@ -102,23 +103,13 @@ async def run_flow_with_caching(
|
|||
if outputs is None:
|
||||
outputs = []
|
||||
|
||||
task_result: List[RunOutputs] = []
|
||||
artifacts = {}
|
||||
if session_id:
|
||||
session_data = await session_service.load_session(session_id, flow_id=flow_id)
|
||||
graph, artifacts = session_data if session_data else (None, None)
|
||||
task_result: List[RunOutputs] = []
|
||||
if not graph:
|
||||
raise ValueError("Graph not found in the session")
|
||||
task_result, session_id = await run_graph(
|
||||
graph=graph,
|
||||
flow_id=flow_id,
|
||||
session_id=session_id,
|
||||
inputs=inputs,
|
||||
outputs=outputs,
|
||||
artifacts=artifacts,
|
||||
session_service=session_service,
|
||||
stream=stream,
|
||||
)
|
||||
|
||||
if graph is None:
|
||||
raise ValueError(f"Session {session_id} not found")
|
||||
else:
|
||||
# Get the flow that matches the flow_id and belongs to the user
|
||||
# flow = session.query(Flow).filter(Flow.id == flow_id).filter(Flow.user_id == api_key_user.id).first()
|
||||
|
|
@ -130,28 +121,38 @@ async def run_flow_with_caching(
|
|||
raise ValueError(f"Flow {flow_id} has no data")
|
||||
graph_data = flow.data
|
||||
graph_data = process_tweaks(graph_data, tweaks or {})
|
||||
task_result, session_id = await run_graph(
|
||||
graph=graph_data,
|
||||
flow_id=flow_id,
|
||||
session_id=session_id,
|
||||
inputs=inputs,
|
||||
outputs=outputs,
|
||||
artifacts={},
|
||||
session_service=session_service,
|
||||
stream=stream,
|
||||
)
|
||||
graph = Graph.from_payload(graph_data, flow_id=flow_id)
|
||||
task_result, session_id = await run_graph(
|
||||
graph=graph,
|
||||
flow_id=flow_id,
|
||||
session_id=session_id,
|
||||
inputs=inputs,
|
||||
outputs=outputs,
|
||||
artifacts=artifacts,
|
||||
session_service=session_service,
|
||||
stream=stream,
|
||||
)
|
||||
|
||||
return RunResponse(outputs=task_result, session_id=session_id)
|
||||
except sa.exc.StatementError as exc:
|
||||
# StatementError('(builtins.ValueError) badly formed hexadecimal UUID string')
|
||||
if "badly formed hexadecimal UUID string" in str(exc):
|
||||
logger.error(f"Flow ID {flow_id} is not a valid UUID")
|
||||
# This means the Flow ID is not a valid UUID which means it can't find the flow
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=str(exc)) from exc
|
||||
except ValueError as exc:
|
||||
if f"Flow {flow_id} not found" in str(exc):
|
||||
logger.error(f"Flow {flow_id} not found")
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=str(exc)) from exc
|
||||
elif f"Session {session_id} not found" in str(exc):
|
||||
logger.error(f"Session {session_id} not found")
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=str(exc)) from exc
|
||||
else:
|
||||
logger.exception(exc)
|
||||
raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(exc)) from exc
|
||||
except Exception as exc:
|
||||
logger.exception(exc)
|
||||
raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(exc)) from exc
|
||||
|
||||
|
||||
@router.post(
|
||||
|
|
|
|||
|
|
@ -4,14 +4,7 @@ from pathlib import Path
|
|||
from typing import Any, Dict, List, Optional, Union
|
||||
from uuid import UUID
|
||||
|
||||
from pydantic import (
|
||||
BaseModel,
|
||||
ConfigDict,
|
||||
Field,
|
||||
RootModel,
|
||||
field_validator,
|
||||
model_serializer,
|
||||
)
|
||||
from pydantic import BaseModel, ConfigDict, Field, RootModel, field_validator, model_serializer
|
||||
|
||||
from langflow.graph.schema import RunOutputs
|
||||
from langflow.schema import dotdict
|
||||
|
|
@ -61,18 +54,19 @@ class RunResponse(BaseModel):
|
|||
outputs: Optional[List[RunOutputs]] = []
|
||||
session_id: Optional[str] = None
|
||||
|
||||
@model_serializer(mode="wrap")
|
||||
def serialize(self, handler):
|
||||
@model_serializer(mode="plain")
|
||||
def serialize(self):
|
||||
# Serialize all the outputs if they are base models
|
||||
serialized = {"session_id": self.session_id, "outputs": []}
|
||||
if self.outputs:
|
||||
serialized_outputs = []
|
||||
for output in self.outputs:
|
||||
if isinstance(output, BaseModel):
|
||||
if isinstance(output, BaseModel) and not isinstance(output, RunOutputs):
|
||||
serialized_outputs.append(output.model_dump(exclude_none=True))
|
||||
else:
|
||||
serialized_outputs.append(output)
|
||||
self.outputs = serialized_outputs
|
||||
return handler(self)
|
||||
serialized["outputs"] = serialized_outputs
|
||||
return serialized
|
||||
|
||||
|
||||
class PreloadResponse(BaseModel):
|
||||
|
|
@ -266,8 +260,8 @@ class InputValueRequest(BaseModel):
|
|||
input_value: Optional[str] = None
|
||||
|
||||
# add an example
|
||||
model_config = {
|
||||
"json_schema_extra": {
|
||||
model_config = ConfigDict(
|
||||
json_schema_extra={
|
||||
"examples": [
|
||||
{
|
||||
"components": ["components_id", "Component Name"],
|
||||
|
|
@ -276,8 +270,9 @@ class InputValueRequest(BaseModel):
|
|||
{"components": ["Component Name"], "input_value": "input_value"},
|
||||
{"input_value": "input_value"},
|
||||
]
|
||||
}
|
||||
}
|
||||
},
|
||||
extra="forbid",
|
||||
)
|
||||
|
||||
|
||||
class Tweaks(RootModel):
|
||||
|
|
|
|||
|
|
@ -11,14 +11,7 @@ from langflow.graph.graph.state_manager import GraphStateManager
|
|||
from langflow.graph.graph.utils import process_flow
|
||||
from langflow.graph.schema import INPUT_FIELD_NAME, InterfaceComponentTypes, RunOutputs
|
||||
from langflow.graph.vertex.base import Vertex
|
||||
from langflow.graph.vertex.types import (
|
||||
ChatVertex,
|
||||
FileToolVertex,
|
||||
LLMVertex,
|
||||
RoutingVertex,
|
||||
StateVertex,
|
||||
ToolkitVertex,
|
||||
)
|
||||
from langflow.graph.vertex.types import ChatVertex, FileToolVertex, LLMVertex, RoutingVertex, StateVertex, ToolkitVertex
|
||||
from langflow.interface.tools.constants import FILE_TOOLS
|
||||
from langflow.schema import Record
|
||||
|
||||
|
|
@ -222,6 +215,13 @@ class Graph:
|
|||
Returns:
|
||||
List[Optional["ResultData"]]: The outputs of the graph.
|
||||
"""
|
||||
if input_components and not isinstance(input_components, list):
|
||||
raise ValueError(f"Invalid components value: {input_components}. Expected list")
|
||||
elif input_components is None:
|
||||
input_components = []
|
||||
|
||||
if not isinstance(inputs.get(INPUT_FIELD_NAME, ""), str):
|
||||
raise ValueError(f"Invalid input value: {inputs.get(INPUT_FIELD_NAME)}. Expected string")
|
||||
for vertex_id in self._is_input_vertices:
|
||||
vertex = self.get_vertex(vertex_id)
|
||||
if input_components and (vertex_id not in input_components or vertex.display_name not in input_components):
|
||||
|
|
@ -250,7 +250,7 @@ class Graph:
|
|||
|
||||
if not vertex.result and not stream and hasattr(vertex, "consume_async_generator"):
|
||||
await vertex.consume_async_generator()
|
||||
if not outputs or (vertex.display_name in outputs or vertex.id in outputs):
|
||||
if (not outputs and vertex.is_output) or (vertex.display_name in outputs or vertex.id in outputs):
|
||||
vertex_outputs.append(vertex.result)
|
||||
|
||||
return vertex_outputs
|
||||
|
|
@ -283,14 +283,9 @@ class Graph:
|
|||
vertex_outputs = []
|
||||
if not isinstance(inputs, list):
|
||||
inputs = [inputs]
|
||||
for run_inputs, components in zip(inputs, inputs_components or []):
|
||||
if components and not isinstance(components, list):
|
||||
raise ValueError(f"Invalid components value: {components}. Expected list")
|
||||
elif components is None:
|
||||
components = []
|
||||
|
||||
if not isinstance(run_inputs.get(INPUT_FIELD_NAME, ""), str):
|
||||
raise ValueError(f"Invalid input value: {run_inputs.get(INPUT_FIELD_NAME)}. Expected string")
|
||||
elif not inputs:
|
||||
inputs = [{}]
|
||||
for run_inputs, components in zip(inputs, inputs_components):
|
||||
run_outputs = await self._run(
|
||||
inputs=run_inputs,
|
||||
input_components=components,
|
||||
|
|
|
|||
|
|
@ -208,11 +208,7 @@ async def run_graph(
|
|||
) -> tuple[List[RunOutputs], str]:
|
||||
"""Run the graph and generate the result"""
|
||||
inputs = inputs or []
|
||||
if isinstance(graph, dict):
|
||||
graph_data = graph
|
||||
graph = Graph.from_payload(graph, flow_id=flow_id)
|
||||
else:
|
||||
graph_data = graph._graph_data
|
||||
graph_data = graph._graph_data
|
||||
if session_id is None and session_service is not None:
|
||||
session_id_str = session_service.generate_key(session_id=flow_id, data_graph=graph_data)
|
||||
elif session_id is not None:
|
||||
|
|
@ -236,7 +232,7 @@ async def run_graph(
|
|||
session_id=session_id_str or "",
|
||||
)
|
||||
if session_id_str and session_service:
|
||||
session_service.update_session(session_id_str, (graph, artifacts))
|
||||
await session_service.update_session(session_id_str, (graph, artifacts))
|
||||
return run_outputs, session_id_str
|
||||
|
||||
|
||||
|
|
@ -262,6 +258,9 @@ def apply_tweaks(node: Dict[str, Any], node_tweaks: Dict[str, Any]) -> None:
|
|||
return
|
||||
|
||||
for tweak_name, tweak_value in node_tweaks.items():
|
||||
if tweak_name not in template_data:
|
||||
logger.warning(f"Node {node.get('id')} does not have a tweak named {tweak_name}")
|
||||
continue
|
||||
if tweak_name and tweak_value and tweak_name in template_data:
|
||||
key = tweak_name if tweak_name == "file_path" else "value"
|
||||
template_data[tweak_name][key] = tweak_value
|
||||
|
|
|
|||
|
|
@ -1,4 +1,4 @@
|
|||
from typing import TYPE_CHECKING, Optional
|
||||
from typing import TYPE_CHECKING, Coroutine, Optional
|
||||
|
||||
from langflow.interface.run import build_sorted_vertices
|
||||
from langflow.services.base import Service
|
||||
|
|
@ -26,7 +26,7 @@ class SessionService(Service):
|
|||
# If not cached, build the graph and cache it
|
||||
graph, artifacts = await build_sorted_vertices(data_graph, flow_id)
|
||||
|
||||
self.cache_service.set(key, (graph, artifacts))
|
||||
await self.cache_service.set(key, (graph, artifacts))
|
||||
|
||||
return graph, artifacts
|
||||
|
||||
|
|
@ -41,8 +41,14 @@ class SessionService(Service):
|
|||
session_id = session_id_generator()
|
||||
return self.build_key(session_id, data_graph=data_graph)
|
||||
|
||||
def update_session(self, session_id, value):
|
||||
self.cache_service.set(session_id, value)
|
||||
async def update_session(self, session_id, value):
|
||||
result = self.cache_service.set(session_id, value)
|
||||
# if it is a coroutine, await it
|
||||
if isinstance(result, Coroutine):
|
||||
await result
|
||||
|
||||
def clear_session(self, session_id):
|
||||
self.cache_service.delete(session_id)
|
||||
async def clear_session(self, session_id):
|
||||
result = self.cache_service.delete(session_id)
|
||||
# if it is a coroutine, await it
|
||||
if isinstance(result, Coroutine):
|
||||
await result
|
||||
|
|
|
|||
|
|
@ -10,10 +10,6 @@ import orjson
|
|||
import pytest
|
||||
from fastapi.testclient import TestClient
|
||||
from httpx import AsyncClient
|
||||
from sqlmodel import Session, SQLModel, create_engine, select
|
||||
from sqlmodel.pool import StaticPool
|
||||
from typer.testing import CliRunner
|
||||
|
||||
from langflow.graph.graph.base import Graph
|
||||
from langflow.initial_setup.setup import STARTER_FOLDER_NAME
|
||||
from langflow.services.auth.utils import get_password_hash
|
||||
|
|
@ -22,6 +18,9 @@ from langflow.services.database.models.flow.model import Flow, FlowCreate
|
|||
from langflow.services.database.models.user.model import User, UserCreate
|
||||
from langflow.services.database.utils import session_getter
|
||||
from langflow.services.deps import get_db_service
|
||||
from sqlmodel import Session, SQLModel, create_engine, select
|
||||
from sqlmodel.pool import StaticPool
|
||||
from typer.testing import CliRunner
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from langflow.services.database.service import DatabaseService
|
||||
|
|
@ -263,7 +262,7 @@ def active_user(client):
|
|||
is_superuser=False,
|
||||
)
|
||||
# check if user exists
|
||||
if active_user := session.query(User).filter(User.username == user.username).first():
|
||||
if active_user := session.exec(select(User).where(User.username == user.username)).first():
|
||||
return active_user
|
||||
session.add(user)
|
||||
session.commit()
|
||||
|
|
@ -368,7 +367,7 @@ def created_api_key(active_user):
|
|||
)
|
||||
db_manager = get_db_service()
|
||||
with session_getter(db_manager) as session:
|
||||
if existing_api_key := session.query(ApiKey).filter(ApiKey.api_key == api_key.api_key).first():
|
||||
if existing_api_key := session.exec(select(ApiKey).where(ApiKey.api_key == api_key.api_key)).first():
|
||||
return existing_api_key
|
||||
session.add(api_key)
|
||||
session.commit()
|
||||
|
|
|
|||
|
|
@ -1,14 +1,12 @@
|
|||
import time
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
from fastapi import status
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from langflow.interface.custom.directory_reader.directory_reader import DirectoryReader
|
||||
from langflow.services.auth.utils import get_password_hash
|
||||
from langflow.services.database.models.api_key.model import ApiKey
|
||||
from langflow.services.database.utils import session_getter
|
||||
from langflow.services.deps import get_db_service, get_settings_service
|
||||
from langflow.services.deps import get_settings_service
|
||||
from langflow.template.frontend_node.chains import TimeTravelGuideChainNode
|
||||
|
||||
|
||||
|
|
@ -112,25 +110,6 @@ PROMPT_REQUEST = {
|
|||
}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def created_api_key(active_user):
|
||||
hashed = get_password_hash("random_key")
|
||||
api_key = ApiKey(
|
||||
name="test_api_key",
|
||||
user_id=active_user.id,
|
||||
api_key="random_key",
|
||||
hashed_api_key=hashed,
|
||||
)
|
||||
db_manager = get_db_service()
|
||||
with session_getter(db_manager) as session:
|
||||
if existing_api_key := session.query(ApiKey).filter(ApiKey.api_key == api_key.api_key).first():
|
||||
return existing_api_key
|
||||
session.add(api_key)
|
||||
session.commit()
|
||||
session.refresh(api_key)
|
||||
return api_key
|
||||
|
||||
|
||||
# def test_process_flow_invalid_api_key(client, flow, monkeypatch):
|
||||
# # Mock de process_graph_cached
|
||||
# from langflow.api.v1 import endpoints
|
||||
|
|
@ -452,18 +431,24 @@ def test_successful_run(client, starter_project, created_api_key):
|
|||
assert response.status_code == status.HTTP_200_OK, response.text
|
||||
# Add more assertions here to validate the response content
|
||||
json_response = response.json()
|
||||
assert "session_id" in json_response
|
||||
assert "outputs" in json_response
|
||||
outer_outputs = json_response["outputs"]
|
||||
assert len(outer_outputs) == 1
|
||||
outputs = outer_outputs[0]
|
||||
assert len(outputs) == 2
|
||||
keys = ["results", "artifacts", "messages"]
|
||||
for output in outputs:
|
||||
assert all(key in output for key in keys)
|
||||
output = outputs[0]
|
||||
result = output["results"]["result"]
|
||||
assert result == "Write a press release \n\n- Cars\n- Bottle\n\n\nAnswer:\n\n"
|
||||
assert "session_id" in json_response
|
||||
outputs_dict = outer_outputs[0]
|
||||
assert len(outputs_dict) == 2
|
||||
assert "inputs" in outputs_dict
|
||||
assert "outputs" in outputs_dict
|
||||
assert outputs_dict.get("inputs") == {"input_value": ""}
|
||||
assert isinstance(outputs_dict.get("outputs"), list)
|
||||
assert len(outputs_dict.get("outputs")) == 2
|
||||
ids = [output.get("component_id") for output in outputs_dict.get("outputs")]
|
||||
assert all([id in ids for id in ["TextOutput-fTp5e", "ChatOutput-AVN8s"]])
|
||||
display_names = [output.get("component_display_name") for output in outputs_dict.get("outputs")]
|
||||
assert all([name in display_names for name in ["Prompt Output", "Chat Output"]])
|
||||
inner_results = [output.get("results").get("result") for output in outputs_dict.get("outputs")]
|
||||
expected_results = ["Write a press release \n\n- Cars\n- Bottle\n\n\nAnswer:\n\n", ""]
|
||||
assert all([result in inner_results for result in expected_results])
|
||||
|
||||
|
||||
def test_run_with_inputs_and_outputs(client, starter_project, created_api_key):
|
||||
|
|
@ -484,3 +469,89 @@ def test_invalid_flow_id(client, created_api_key):
|
|||
response = client.post(f"/api/v1/run/{flow_id}", headers=headers)
|
||||
assert response.status_code == status.HTTP_404_NOT_FOUND
|
||||
# Check if the error detail is as expected
|
||||
|
||||
|
||||
def test_run_flow_with_caching_success(client: TestClient, starter_project, created_api_key):
|
||||
flow_id = starter_project["id"]
|
||||
headers = {"x-api-key": created_api_key.api_key}
|
||||
payload = {
|
||||
"inputs": [
|
||||
{"components": ["component1"], "input_value": "value1"},
|
||||
{"components": ["component3"], "input_value": "value2"},
|
||||
],
|
||||
"outputs": ["Component Name", "component_id"],
|
||||
"tweaks": {
|
||||
"parameter_name": "value",
|
||||
"Component Name": {"parameter_name": "value"},
|
||||
"component_id": {"parameter_name": "value"},
|
||||
},
|
||||
"stream": False,
|
||||
}
|
||||
response = client.post(f"/api/v1/run/{flow_id}", json=payload, headers=headers)
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
data = response.json()
|
||||
assert "outputs" in data
|
||||
assert "session_id" in data
|
||||
|
||||
|
||||
def test_run_flow_with_caching_invalid_flow_id(client: TestClient, created_api_key):
|
||||
invalid_flow_id = uuid4()
|
||||
headers = {"x-api-key": created_api_key.api_key}
|
||||
payload = {"inputs": [], "outputs": [], "tweaks": {}, "stream": False}
|
||||
response = client.post(f"/api/v1/run/{invalid_flow_id}", json=payload, headers=headers)
|
||||
assert response.status_code == status.HTTP_404_NOT_FOUND
|
||||
data = response.json()
|
||||
assert "detail" in data
|
||||
assert f"Flow {invalid_flow_id} not found" in data["detail"]
|
||||
|
||||
|
||||
def test_run_flow_with_caching_invalid_input_format(client: TestClient, starter_project, created_api_key):
|
||||
flow_id = starter_project["id"]
|
||||
headers = {"x-api-key": created_api_key.api_key}
|
||||
payload = {"inputs": [{"invalid_key": "value"}], "outputs": [], "tweaks": {}, "stream": False}
|
||||
# This should raise an http 422 error not validation error
|
||||
response = client.post(f"/api/v1/run/{flow_id}", json=payload, headers=headers)
|
||||
assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY
|
||||
|
||||
|
||||
def test_run_flow_with_session_id(client, starter_project, created_api_key):
|
||||
headers = {"x-api-key": created_api_key.api_key}
|
||||
flow_id = starter_project["id"]
|
||||
payload = {
|
||||
"inputs": [{"components": ["component1"], "input_value": "value1"}],
|
||||
"outputs": ["Component Name", "component_id"],
|
||||
"session_id": "test-session-id",
|
||||
}
|
||||
response = client.post(f"/api/v1/run/{flow_id}", json=payload, headers=headers)
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
data = response.json()
|
||||
assert "outputs" in data
|
||||
assert "session_id" in data
|
||||
assert data["session_id"] == "test-session-id"
|
||||
|
||||
|
||||
def test_run_flow_with_invalid_session_id(client, starter_project, created_api_key):
|
||||
headers = {"x-api-key": created_api_key.api_key}
|
||||
flow_id = starter_project["id"]
|
||||
payload = {
|
||||
"inputs": [{"components": ["component1"], "input_value": "value1"}],
|
||||
"outputs": ["Component Name", "component_id"],
|
||||
"session_id": "invalid-session-id",
|
||||
}
|
||||
response = client.post(f"/api/v1/run/{flow_id}", json=payload, headers=headers)
|
||||
assert response.status_code == status.HTTP_404_NOT_FOUND
|
||||
data = response.json()
|
||||
assert "detail" in data
|
||||
assert f"Session {payload['session_id']} not found" in data["detail"]
|
||||
|
||||
|
||||
def test_run_flow_with_invalid_tweaks(client, starter_project, created_api_key):
|
||||
headers = {"x-api-key": created_api_key.api_key}
|
||||
flow_id = starter_project["id"]
|
||||
payload = {
|
||||
"inputs": [{"components": ["component1"], "input_value": "value1"}],
|
||||
"outputs": ["Component Name", "component_id"],
|
||||
"tweaks": {"invalid_tweak": "value"},
|
||||
}
|
||||
response = client.post(f"/api/v1/run/{flow_id}", json=payload, headers=headers)
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue