Add new run_flow_with_caching endpoint
This commit is contained in:
parent
31fc5c4be0
commit
c26175aede
1 changed files with 124 additions and 15 deletions
|
|
@ -3,11 +3,15 @@ from typing import Annotated, Any, List, Optional, Union
|
|||
|
||||
import sqlalchemy as sa
|
||||
from fastapi import APIRouter, Body, Depends, HTTPException, UploadFile, status
|
||||
from loguru import logger
|
||||
from sqlmodel import select
|
||||
|
||||
from langflow.api.utils import update_frontend_node_with_template_values
|
||||
from langflow.api.v1.schemas import (
|
||||
CustomComponentCode,
|
||||
PreloadResponse,
|
||||
ProcessResponse,
|
||||
RunResponse,
|
||||
TaskResponse,
|
||||
TaskStatusResponse,
|
||||
UploadFileResponse,
|
||||
|
|
@ -15,15 +19,23 @@ from langflow.api.v1.schemas import (
|
|||
from langflow.interface.custom.custom_component import CustomComponent
|
||||
from langflow.interface.custom.directory_reader import DirectoryReader
|
||||
from langflow.interface.custom.utils import build_custom_component_template
|
||||
from langflow.processing.process import build_graph_and_generate_result, process_graph_cached, process_tweaks
|
||||
from langflow.processing.process import (
|
||||
build_graph_and_generate_result,
|
||||
process_graph_cached,
|
||||
process_tweaks,
|
||||
run_graph,
|
||||
)
|
||||
from langflow.services.auth.utils import api_key_security, get_current_active_user
|
||||
from langflow.services.cache.utils import save_uploaded_file
|
||||
from langflow.services.database.models.flow import Flow
|
||||
from langflow.services.database.models.user.model import User
|
||||
from langflow.services.deps import get_session, get_session_service, get_settings_service, get_task_service
|
||||
from langflow.services.deps import (
|
||||
get_session,
|
||||
get_session_service,
|
||||
get_settings_service,
|
||||
get_task_service,
|
||||
)
|
||||
from langflow.services.session.service import SessionService
|
||||
from loguru import logger
|
||||
from sqlmodel import select
|
||||
|
||||
try:
|
||||
from langflow.worker import process_graph_cached_task
|
||||
|
|
@ -33,9 +45,10 @@ except ImportError:
|
|||
raise NotImplementedError("Celery is not installed")
|
||||
|
||||
|
||||
from langflow.services.task.service import TaskService
|
||||
from sqlmodel import Session
|
||||
|
||||
from langflow.services.task.service import TaskService
|
||||
|
||||
# build router
|
||||
router = APIRouter(tags=["Base"])
|
||||
|
||||
|
|
@ -80,9 +93,15 @@ async def process_graph_data(
|
|||
)
|
||||
if session_id is None:
|
||||
# Generate a session ID
|
||||
session_id = get_session_service().generate_key(session_id=session_id, data_graph=graph_data)
|
||||
session_id = get_session_service().generate_key(
|
||||
session_id=session_id, data_graph=graph_data
|
||||
)
|
||||
task_id, task = await task_service.launch_task(
|
||||
process_graph_cached_task if task_service.use_celery else process_graph_cached,
|
||||
(
|
||||
process_graph_cached_task
|
||||
if task_service.use_celery
|
||||
else process_graph_cached
|
||||
),
|
||||
graph_data,
|
||||
inputs,
|
||||
clear_cache,
|
||||
|
|
@ -176,7 +195,11 @@ async def preload_flow(
|
|||
else:
|
||||
if session_id is None:
|
||||
session_id = flow_id
|
||||
flow = session.exec(select(Flow).where(Flow.id == flow_id).where(Flow.user_id == api_key_user.id)).first()
|
||||
flow = session.exec(
|
||||
select(Flow)
|
||||
.where(Flow.id == flow_id)
|
||||
.where(Flow.user_id == api_key_user.id)
|
||||
).first()
|
||||
if flow is None:
|
||||
raise ValueError(f"Flow {flow_id} not found")
|
||||
|
||||
|
|
@ -197,6 +220,76 @@ async def preload_flow(
|
|||
raise HTTPException(status_code=500, detail=str(exc)) from exc
|
||||
|
||||
|
||||
@router.post("/run/{flow_id}", response_model=ProcessResponse)
|
||||
async def run_flow_with_caching(
|
||||
session: Annotated[Session, Depends(get_session)],
|
||||
flow_id: str,
|
||||
inputs: Optional[Union[List[dict], dict]] = None,
|
||||
tweaks: Optional[dict] = None,
|
||||
session_id: Annotated[Union[None, str], Body(embed=True)] = None, # noqa: F821
|
||||
api_key_user: User = Depends(api_key_security),
|
||||
session_service: SessionService = Depends(get_session_service),
|
||||
):
|
||||
try:
|
||||
if session_id:
|
||||
session_data = await session_service.load_session(session_id)
|
||||
graph, artifacts = session_data if session_data else (None, None)
|
||||
task_result: Any = None
|
||||
task_status = None
|
||||
if not graph:
|
||||
raise ValueError("Graph not found in the session")
|
||||
task_result = await run_graph(
|
||||
graph,
|
||||
session_id,
|
||||
inputs,
|
||||
artifacts=artifacts,
|
||||
session_service=session_service,
|
||||
)
|
||||
|
||||
else:
|
||||
# Get the flow that matches the flow_id and belongs to the user
|
||||
# flow = session.query(Flow).filter(Flow.id == flow_id).filter(Flow.user_id == api_key_user.id).first()
|
||||
flow = session.exec(
|
||||
select(Flow)
|
||||
.where(Flow.id == flow_id)
|
||||
.where(Flow.user_id == api_key_user.id)
|
||||
).first()
|
||||
if flow is None:
|
||||
raise ValueError(f"Flow {flow_id} not found")
|
||||
|
||||
if flow.data is None:
|
||||
raise ValueError(f"Flow {flow_id} has no data")
|
||||
graph_data = flow.data
|
||||
graph_data = process_tweaks(graph_data, tweaks)
|
||||
task_result = await run_graph(
|
||||
graph_data,
|
||||
inputs,
|
||||
tweaks,
|
||||
session_id,
|
||||
session_service=session_service,
|
||||
)
|
||||
|
||||
return RunResponse(
|
||||
outputs=task_result, session_id=session_id, status=task_status
|
||||
)
|
||||
except sa.exc.StatementError as exc:
|
||||
# StatementError('(builtins.ValueError) badly formed hexadecimal UUID string')
|
||||
if "badly formed hexadecimal UUID string" in str(exc):
|
||||
# This means the Flow ID is not a valid UUID which means it can't find the flow
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND, detail=str(exc)
|
||||
) from exc
|
||||
except ValueError as exc:
|
||||
if f"Flow {flow_id} not found" in str(exc):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND, detail=str(exc)
|
||||
) from exc
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(exc)
|
||||
) from exc
|
||||
|
||||
|
||||
@router.post(
|
||||
"/predict/{flow_id}",
|
||||
response_model=ProcessResponse,
|
||||
|
|
@ -269,7 +362,11 @@ async def process(
|
|||
|
||||
# Get the flow that matches the flow_id and belongs to the user
|
||||
# flow = session.query(Flow).filter(Flow.id == flow_id).filter(Flow.user_id == api_key_user.id).first()
|
||||
flow = session.exec(select(Flow).where(Flow.id == flow_id).where(Flow.user_id == api_key_user.id)).first()
|
||||
flow = session.exec(
|
||||
select(Flow)
|
||||
.where(Flow.id == flow_id)
|
||||
.where(Flow.user_id == api_key_user.id)
|
||||
).first()
|
||||
if flow is None:
|
||||
raise ValueError(f"Flow {flow_id} not found")
|
||||
|
||||
|
|
@ -289,12 +386,18 @@ async def process(
|
|||
# StatementError('(builtins.ValueError) badly formed hexadecimal UUID string')
|
||||
if "badly formed hexadecimal UUID string" in str(exc):
|
||||
# This means the Flow ID is not a valid UUID which means it can't find the flow
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=str(exc)) from exc
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND, detail=str(exc)
|
||||
) from exc
|
||||
except ValueError as exc:
|
||||
if f"Flow {flow_id} not found" in str(exc):
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=str(exc)) from exc
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND, detail=str(exc)
|
||||
) from exc
|
||||
else:
|
||||
raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(exc)) from exc
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(exc)
|
||||
) from exc
|
||||
except Exception as e:
|
||||
# Log stack trace
|
||||
logger.exception(e)
|
||||
|
|
@ -364,12 +467,16 @@ async def custom_component(
|
|||
|
||||
built_frontend_node = build_custom_component_template(component, user_id=user.id)
|
||||
|
||||
built_frontend_node = update_frontend_node_with_template_values(built_frontend_node, raw_code.frontend_node)
|
||||
built_frontend_node = update_frontend_node_with_template_values(
|
||||
built_frontend_node, raw_code.frontend_node
|
||||
)
|
||||
return built_frontend_node
|
||||
|
||||
|
||||
@router.post("/custom_component/reload", status_code=HTTPStatus.OK)
|
||||
async def reload_custom_component(path: str, user: User = Depends(get_current_active_user)):
|
||||
async def reload_custom_component(
|
||||
path: str, user: User = Depends(get_current_active_user)
|
||||
):
|
||||
from langflow.interface.custom.utils import build_custom_component_template
|
||||
|
||||
try:
|
||||
|
|
@ -391,6 +498,8 @@ async def custom_component_update(
|
|||
):
|
||||
component = CustomComponent(code=raw_code.code)
|
||||
|
||||
component_node = build_custom_component_template(component, user_id=user.id, update_field=raw_code.field)
|
||||
component_node = build_custom_component_template(
|
||||
component, user_id=user.id, update_field=raw_code.field
|
||||
)
|
||||
# Update the field
|
||||
return component_node
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue