Refactor run_flow_with_caching endpoint to include simplified and experimental versions

This commit is contained in:
Gabriel Luiz Freitas Almeida 2024-03-31 23:12:24 -03:00
commit f43c558f7a

View file

@ -1,5 +1,5 @@
from http import HTTPStatus
from typing import Annotated, List, Optional, Union
from typing import Annotated, List, Literal, Optional, Union
import sqlalchemy as sa
from fastapi import APIRouter, Body, Depends, HTTPException, UploadFile, status
@ -51,7 +51,123 @@ def get_all(
@router.post("/run/{flow_id}", response_model=RunResponse, response_model_exclude_none=True)
async def run_flow_with_caching(
async def simplified_run_flow_with_caching(
session: Annotated[Session, Depends(get_session)],
flow_id: str,
input_value: Optional[str] = "",
input_type: Optional[Literal["chat", "text", "any"]] = "chat",
output_type: Optional[Literal["chat", "text", "any", "debug"]] = "chat",
tweaks: Annotated[Optional[Tweaks], Body(embed=True)] = None, # noqa: F821
stream: Annotated[bool, Body(embed=True)] = False, # noqa: F821
session_id: Annotated[Union[None, str], Body(embed=True)] = None, # noqa: F821
api_key_user: User = Depends(api_key_security),
session_service: SessionService = Depends(get_session_service),
):
"""
Executes a specified flow by ID, offering options for input, output customization, and performance enhancements through caching.
Parameters:
- `session` (Session): Database session for executing queries.
- `flow_id` (str): Unique identifier of the flow to execute.
- `input_value` (Optional[str], default=""): Input value to pass to the flow. Defaults to an empty string.
- `input_type` (Optional[Literal["chat", "text", "any"]], default="chat"): Type of the input value.
- `output_type` (Optional[Literal["chat", "text", "any", "debug"]], default="chat"): Desired type of output. If "debug", all outputs are returned.
- `tweaks` (Optional[Tweaks], default=None): Optional parameter tweaks to customize flow execution.
- `stream` (bool, default=False): If true, outputs are streamed back as they are generated.
- `session_id` (Union[None, str], default=None): Session ID to reuse existing session data, enhancing efficiency.
- `api_key_user` (User): User object derived from the provided API key, ensuring secure access.
- `session_service` (SessionService): Service for session management, crucial for caching and session reuse.
Returns:
- `RunResponse`: Object containing the flow execution results and the session ID, allowing for result retrieval and session management.
Raises:
- HTTPException: 404 if the specified flow or session cannot be found; 500 for internal errors during execution.
Example:
```http
POST /run/{flow_id}
Content-Type: application/json
x-api-key: YOUR_API_KEY
{
"input_value": "Sample input",
"input_type": "text",
"output_type": "debug",
"tweaks": {"example_tweak": "value"},
"stream": true
}
```
This endpoint serves as a flexible and efficient way to execute flows with customizable inputs and outputs, leveraging caching for improved performance.
"""
try:
task_result: List[RunOutputs] = []
artifacts = {}
if session_id:
session_data = await session_service.load_session(session_id, flow_id=flow_id)
graph, artifacts = session_data if session_data else (None, None)
if graph is None:
raise ValueError(f"Session {session_id} not found")
else:
# Get the flow that matches the flow_id and belongs to the user
# flow = session.query(Flow).filter(Flow.id == flow_id).filter(Flow.user_id == api_key_user.id).first()
flow = session.exec(select(Flow).where(Flow.id == flow_id).where(Flow.user_id == api_key_user.id)).first()
if flow is None:
raise ValueError(f"Flow {flow_id} not found")
if flow.data is None:
raise ValueError(f"Flow {flow_id} has no data")
graph_data = flow.data
graph_data = process_tweaks(graph_data, tweaks or {})
graph = Graph.from_payload(graph_data, flow_id=flow_id)
inputs = [InputValueRequest(components=[], input_value=input_value, type=input_type)]
# outputs is a list of all components that should return output
# we need to get them by checking their type
# if the output type is debug, we return all outputs
# if the output type is any, we return all outputs that are either chat or text
# if the output type is chat or text, we return only the outputs that match the type
outputs = [
vertex
for vertex in graph.vertices
if output_type == "debug"
or (vertex.is_output and (output_type == "any" or output_type in vertex.id.lower()))
]
task_result, session_id = await run_graph(
graph=graph,
flow_id=flow_id,
session_id=session_id,
inputs=inputs,
outputs=outputs,
artifacts=artifacts,
session_service=session_service,
stream=stream,
)
return RunResponse(outputs=task_result, session_id=session_id)
except sa.exc.StatementError as exc:
# StatementError('(builtins.ValueError) badly formed hexadecimal UUID string')
if "badly formed hexadecimal UUID string" in str(exc):
logger.error(f"Flow ID {flow_id} is not a valid UUID")
# This means the Flow ID is not a valid UUID which means it can't find the flow
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=str(exc)) from exc
except ValueError as exc:
if f"Flow {flow_id} not found" in str(exc):
logger.error(f"Flow {flow_id} not found")
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=str(exc)) from exc
elif f"Session {session_id} not found" in str(exc):
logger.error(f"Session {session_id} not found")
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=str(exc)) from exc
else:
logger.exception(exc)
raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(exc)) from exc
except Exception as exc:
logger.exception(exc)
raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(exc)) from exc
@router.post("/run/advanced/{flow_id}", response_model=RunResponse, response_model_exclude_none=True)
async def experimental_run_flow_with_caching(
session: Annotated[Session, Depends(get_session)],
flow_id: str,
inputs: Optional[List[InputValueRequest]] = [InputValueRequest(components=[], input_value="")],
@ -85,6 +201,7 @@ async def run_flow_with_caching(
### Example usage:
```json
POST /run/{flow_id}
x-api-key: YOUR_API_KEY
Payload:
{
"inputs": [