Refactor run_flow_with_caching endpoint to include simplified and experimental versions
This commit is contained in:
parent
888a6904d5
commit
f43c558f7a
1 changed files with 119 additions and 2 deletions
|
|
@ -1,5 +1,5 @@
|
|||
from http import HTTPStatus
|
||||
from typing import Annotated, List, Optional, Union
|
||||
from typing import Annotated, List, Literal, Optional, Union
|
||||
|
||||
import sqlalchemy as sa
|
||||
from fastapi import APIRouter, Body, Depends, HTTPException, UploadFile, status
|
||||
|
|
@ -51,7 +51,123 @@ def get_all(
|
|||
|
||||
|
||||
@router.post("/run/{flow_id}", response_model=RunResponse, response_model_exclude_none=True)
|
||||
async def run_flow_with_caching(
|
||||
async def simplified_run_flow_with_caching(
|
||||
session: Annotated[Session, Depends(get_session)],
|
||||
flow_id: str,
|
||||
input_value: Optional[str] = "",
|
||||
input_type: Optional[Literal["chat", "text", "any"]] = "chat",
|
||||
output_type: Optional[Literal["chat", "text", "any", "debug"]] = "chat",
|
||||
tweaks: Annotated[Optional[Tweaks], Body(embed=True)] = None, # noqa: F821
|
||||
stream: Annotated[bool, Body(embed=True)] = False, # noqa: F821
|
||||
session_id: Annotated[Union[None, str], Body(embed=True)] = None, # noqa: F821
|
||||
api_key_user: User = Depends(api_key_security),
|
||||
session_service: SessionService = Depends(get_session_service),
|
||||
):
|
||||
"""
|
||||
Executes a specified flow by ID, offering options for input, output customization, and performance enhancements through caching.
|
||||
|
||||
Parameters:
|
||||
- `session` (Session): Database session for executing queries.
|
||||
- `flow_id` (str): Unique identifier of the flow to execute.
|
||||
- `input_value` (Optional[str], default=""): Input value to pass to the flow. Defaults to an empty string.
|
||||
- `input_type` (Optional[Literal["chat", "text", "any"]], default="chat"): Type of the input value.
|
||||
- `output_type` (Optional[Literal["chat", "text", "any", "debug"]], default="chat"): Desired type of output. If "debug", all outputs are returned.
|
||||
- `tweaks` (Optional[Tweaks], default=None): Optional parameter tweaks to customize flow execution.
|
||||
- `stream` (bool, default=False): If true, outputs are streamed back as they are generated.
|
||||
- `session_id` (Union[None, str], default=None): Session ID to reuse existing session data, enhancing efficiency.
|
||||
- `api_key_user` (User): User object derived from the provided API key, ensuring secure access.
|
||||
- `session_service` (SessionService): Service for session management, crucial for caching and session reuse.
|
||||
|
||||
Returns:
|
||||
- `RunResponse`: Object containing the flow execution results and the session ID, allowing for result retrieval and session management.
|
||||
|
||||
Raises:
|
||||
- HTTPException: 404 if the specified flow or session cannot be found; 500 for internal errors during execution.
|
||||
|
||||
Example:
|
||||
```http
|
||||
POST /run/{flow_id}
|
||||
Content-Type: application/json
|
||||
x-api-key: YOUR_API_KEY
|
||||
|
||||
{
|
||||
"input_value": "Sample input",
|
||||
"input_type": "text",
|
||||
"output_type": "debug",
|
||||
"tweaks": {"example_tweak": "value"},
|
||||
"stream": true
|
||||
}
|
||||
```
|
||||
|
||||
This endpoint serves as a flexible and efficient way to execute flows with customizable inputs and outputs, leveraging caching for improved performance.
|
||||
"""
|
||||
try:
|
||||
task_result: List[RunOutputs] = []
|
||||
artifacts = {}
|
||||
if session_id:
|
||||
session_data = await session_service.load_session(session_id, flow_id=flow_id)
|
||||
graph, artifacts = session_data if session_data else (None, None)
|
||||
if graph is None:
|
||||
raise ValueError(f"Session {session_id} not found")
|
||||
else:
|
||||
# Get the flow that matches the flow_id and belongs to the user
|
||||
# flow = session.query(Flow).filter(Flow.id == flow_id).filter(Flow.user_id == api_key_user.id).first()
|
||||
flow = session.exec(select(Flow).where(Flow.id == flow_id).where(Flow.user_id == api_key_user.id)).first()
|
||||
if flow is None:
|
||||
raise ValueError(f"Flow {flow_id} not found")
|
||||
|
||||
if flow.data is None:
|
||||
raise ValueError(f"Flow {flow_id} has no data")
|
||||
graph_data = flow.data
|
||||
graph_data = process_tweaks(graph_data, tweaks or {})
|
||||
graph = Graph.from_payload(graph_data, flow_id=flow_id)
|
||||
inputs = [InputValueRequest(components=[], input_value=input_value, type=input_type)]
|
||||
# outputs is a list of all components that should return output
|
||||
# we need to get them by checking their type
|
||||
# if the output type is debug, we return all outputs
|
||||
# if the output type is any, we return all outputs that are either chat or text
|
||||
# if the output type is chat or text, we return only the outputs that match the type
|
||||
outputs = [
|
||||
vertex
|
||||
for vertex in graph.vertices
|
||||
if output_type == "debug"
|
||||
or (vertex.is_output and (output_type == "any" or output_type in vertex.id.lower()))
|
||||
]
|
||||
task_result, session_id = await run_graph(
|
||||
graph=graph,
|
||||
flow_id=flow_id,
|
||||
session_id=session_id,
|
||||
inputs=inputs,
|
||||
outputs=outputs,
|
||||
artifacts=artifacts,
|
||||
session_service=session_service,
|
||||
stream=stream,
|
||||
)
|
||||
|
||||
return RunResponse(outputs=task_result, session_id=session_id)
|
||||
except sa.exc.StatementError as exc:
|
||||
# StatementError('(builtins.ValueError) badly formed hexadecimal UUID string')
|
||||
if "badly formed hexadecimal UUID string" in str(exc):
|
||||
logger.error(f"Flow ID {flow_id} is not a valid UUID")
|
||||
# This means the Flow ID is not a valid UUID which means it can't find the flow
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=str(exc)) from exc
|
||||
except ValueError as exc:
|
||||
if f"Flow {flow_id} not found" in str(exc):
|
||||
logger.error(f"Flow {flow_id} not found")
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=str(exc)) from exc
|
||||
elif f"Session {session_id} not found" in str(exc):
|
||||
logger.error(f"Session {session_id} not found")
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=str(exc)) from exc
|
||||
else:
|
||||
logger.exception(exc)
|
||||
raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(exc)) from exc
|
||||
except Exception as exc:
|
||||
logger.exception(exc)
|
||||
raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(exc)) from exc
|
||||
|
||||
|
||||
@router.post("/run/advanced/{flow_id}", response_model=RunResponse, response_model_exclude_none=True)
|
||||
async def experimental_run_flow_with_caching(
|
||||
session: Annotated[Session, Depends(get_session)],
|
||||
flow_id: str,
|
||||
inputs: Optional[List[InputValueRequest]] = [InputValueRequest(components=[], input_value="")],
|
||||
|
|
@ -85,6 +201,7 @@ async def run_flow_with_caching(
|
|||
### Example usage:
|
||||
```json
|
||||
POST /run/{flow_id}
|
||||
x-api-key: YOUR_API_KEY
|
||||
Payload:
|
||||
{
|
||||
"inputs": [
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue