Add new run_flow_with_caching endpoint

This commit is contained in:
Gabriel Luiz Freitas Almeida 2024-02-27 13:46:00 -03:00
commit c26175aede

View file

@ -3,11 +3,15 @@ from typing import Annotated, Any, List, Optional, Union
import sqlalchemy as sa
from fastapi import APIRouter, Body, Depends, HTTPException, UploadFile, status
from loguru import logger
from sqlmodel import select
from langflow.api.utils import update_frontend_node_with_template_values
from langflow.api.v1.schemas import (
CustomComponentCode,
PreloadResponse,
ProcessResponse,
RunResponse,
TaskResponse,
TaskStatusResponse,
UploadFileResponse,
@ -15,15 +19,23 @@ from langflow.api.v1.schemas import (
from langflow.interface.custom.custom_component import CustomComponent
from langflow.interface.custom.directory_reader import DirectoryReader
from langflow.interface.custom.utils import build_custom_component_template
from langflow.processing.process import build_graph_and_generate_result, process_graph_cached, process_tweaks
from langflow.processing.process import (
build_graph_and_generate_result,
process_graph_cached,
process_tweaks,
run_graph,
)
from langflow.services.auth.utils import api_key_security, get_current_active_user
from langflow.services.cache.utils import save_uploaded_file
from langflow.services.database.models.flow import Flow
from langflow.services.database.models.user.model import User
from langflow.services.deps import get_session, get_session_service, get_settings_service, get_task_service
from langflow.services.deps import (
get_session,
get_session_service,
get_settings_service,
get_task_service,
)
from langflow.services.session.service import SessionService
from loguru import logger
from sqlmodel import select
try:
from langflow.worker import process_graph_cached_task
@ -33,9 +45,10 @@ except ImportError:
raise NotImplementedError("Celery is not installed")
from langflow.services.task.service import TaskService
from sqlmodel import Session
from langflow.services.task.service import TaskService
# build router
router = APIRouter(tags=["Base"])
@ -80,9 +93,15 @@ async def process_graph_data(
)
if session_id is None:
# Generate a session ID
session_id = get_session_service().generate_key(session_id=session_id, data_graph=graph_data)
session_id = get_session_service().generate_key(
session_id=session_id, data_graph=graph_data
)
task_id, task = await task_service.launch_task(
process_graph_cached_task if task_service.use_celery else process_graph_cached,
(
process_graph_cached_task
if task_service.use_celery
else process_graph_cached
),
graph_data,
inputs,
clear_cache,
@ -176,7 +195,11 @@ async def preload_flow(
else:
if session_id is None:
session_id = flow_id
flow = session.exec(select(Flow).where(Flow.id == flow_id).where(Flow.user_id == api_key_user.id)).first()
flow = session.exec(
select(Flow)
.where(Flow.id == flow_id)
.where(Flow.user_id == api_key_user.id)
).first()
if flow is None:
raise ValueError(f"Flow {flow_id} not found")
@ -197,6 +220,76 @@ async def preload_flow(
raise HTTPException(status_code=500, detail=str(exc)) from exc
@router.post("/run/{flow_id}", response_model=ProcessResponse)
async def run_flow_with_caching(
session: Annotated[Session, Depends(get_session)],
flow_id: str,
inputs: Optional[Union[List[dict], dict]] = None,
tweaks: Optional[dict] = None,
session_id: Annotated[Union[None, str], Body(embed=True)] = None, # noqa: F821
api_key_user: User = Depends(api_key_security),
session_service: SessionService = Depends(get_session_service),
):
try:
if session_id:
session_data = await session_service.load_session(session_id)
graph, artifacts = session_data if session_data else (None, None)
task_result: Any = None
task_status = None
if not graph:
raise ValueError("Graph not found in the session")
task_result = await run_graph(
graph,
session_id,
inputs,
artifacts=artifacts,
session_service=session_service,
)
else:
# Get the flow that matches the flow_id and belongs to the user
# flow = session.query(Flow).filter(Flow.id == flow_id).filter(Flow.user_id == api_key_user.id).first()
flow = session.exec(
select(Flow)
.where(Flow.id == flow_id)
.where(Flow.user_id == api_key_user.id)
).first()
if flow is None:
raise ValueError(f"Flow {flow_id} not found")
if flow.data is None:
raise ValueError(f"Flow {flow_id} has no data")
graph_data = flow.data
graph_data = process_tweaks(graph_data, tweaks)
task_result = await run_graph(
graph_data,
inputs,
tweaks,
session_id,
session_service=session_service,
)
return RunResponse(
outputs=task_result, session_id=session_id, status=task_status
)
except sa.exc.StatementError as exc:
# StatementError('(builtins.ValueError) badly formed hexadecimal UUID string')
if "badly formed hexadecimal UUID string" in str(exc):
# This means the Flow ID is not a valid UUID which means it can't find the flow
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail=str(exc)
) from exc
except ValueError as exc:
if f"Flow {flow_id} not found" in str(exc):
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail=str(exc)
) from exc
else:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(exc)
) from exc
@router.post(
"/predict/{flow_id}",
response_model=ProcessResponse,
@ -269,7 +362,11 @@ async def process(
# Get the flow that matches the flow_id and belongs to the user
# flow = session.query(Flow).filter(Flow.id == flow_id).filter(Flow.user_id == api_key_user.id).first()
flow = session.exec(select(Flow).where(Flow.id == flow_id).where(Flow.user_id == api_key_user.id)).first()
flow = session.exec(
select(Flow)
.where(Flow.id == flow_id)
.where(Flow.user_id == api_key_user.id)
).first()
if flow is None:
raise ValueError(f"Flow {flow_id} not found")
@ -289,12 +386,18 @@ async def process(
# StatementError('(builtins.ValueError) badly formed hexadecimal UUID string')
if "badly formed hexadecimal UUID string" in str(exc):
# This means the Flow ID is not a valid UUID which means it can't find the flow
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=str(exc)) from exc
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail=str(exc)
) from exc
except ValueError as exc:
if f"Flow {flow_id} not found" in str(exc):
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=str(exc)) from exc
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail=str(exc)
) from exc
else:
raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(exc)) from exc
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(exc)
) from exc
except Exception as e:
# Log stack trace
logger.exception(e)
@ -364,12 +467,16 @@ async def custom_component(
built_frontend_node = build_custom_component_template(component, user_id=user.id)
built_frontend_node = update_frontend_node_with_template_values(built_frontend_node, raw_code.frontend_node)
built_frontend_node = update_frontend_node_with_template_values(
built_frontend_node, raw_code.frontend_node
)
return built_frontend_node
@router.post("/custom_component/reload", status_code=HTTPStatus.OK)
async def reload_custom_component(path: str, user: User = Depends(get_current_active_user)):
async def reload_custom_component(
path: str, user: User = Depends(get_current_active_user)
):
from langflow.interface.custom.utils import build_custom_component_template
try:
@ -391,6 +498,8 @@ async def custom_component_update(
):
component = CustomComponent(code=raw_code.code)
component_node = build_custom_component_template(component, user_id=user.id, update_field=raw_code.field)
component_node = build_custom_component_template(
component, user_id=user.id, update_field=raw_code.field
)
# Update the field
return component_node