Add process_json endpoint for processing JSON data

This commit is contained in:
Gabriel Luiz Freitas Almeida 2023-12-18 10:52:33 -03:00
commit 34e116ddf2

View file

@ -3,24 +3,26 @@ from typing import Annotated, Optional, Union
import sqlalchemy as sa
from fastapi import APIRouter, Body, Depends, HTTPException, UploadFile, status
from loguru import logger
from sqlmodel import select
from langflow.api.utils import update_frontend_node_with_template_values
from langflow.api.v1.schemas import (CustomComponentCode, ProcessResponse,
TaskResponse, TaskStatusResponse,
UploadFileResponse)
from langflow.api.v1.schemas import (
CustomComponentCode,
ProcessResponse,
TaskResponse,
TaskStatusResponse,
UploadFileResponse,
)
from langflow.interface.custom.custom_component import CustomComponent
from langflow.interface.custom.directory_reader import DirectoryReader
from langflow.interface.custom.utils import (build_custom_component_template,
create_and_validate_component)
from langflow.interface.custom.utils import build_custom_component_template, create_and_validate_component
from langflow.processing.process import process_graph_cached, process_tweaks
from langflow.services.auth.utils import (api_key_security,
get_current_active_user)
from langflow.services.auth.utils import api_key_security, get_current_active_user
from langflow.services.cache.utils import save_uploaded_file
from langflow.services.database.models.flow import Flow
from langflow.services.database.models.user.model import User
from langflow.services.deps import (get_session, get_session_service,
get_settings_service, get_task_service)
from loguru import logger
from sqlmodel import select
from langflow.services.deps import get_session, get_session_service, get_settings_service, get_task_service
try:
from langflow.worker import process_graph_cached_task
@ -30,13 +32,81 @@ except ImportError:
raise NotImplementedError("Celery is not installed")
from langflow.services.task.service import TaskService
from sqlmodel import Session
from langflow.services.task.service import TaskService
# build router
router = APIRouter(tags=["Base"])
async def process_graph_data(
    graph_data: dict,
    inputs: Optional[dict] = None,
    tweaks: Optional[dict] = None,
    clear_cache: bool = False,
    session_id: Optional[str] = None,
    task_service: "TaskService" = Depends(get_task_service),
    sync: bool = True,
):
    """Run a graph through the task service and wrap the outcome in a ProcessResponse.

    Applies ``tweaks`` to ``graph_data`` first (best-effort: a failure is logged
    and the untweaked graph is used). With ``sync=True`` the task is launched and
    awaited; otherwise it is fired asynchronously and only the launch status is
    reported back.

    Args:
        graph_data: Serialized graph to execute.
        inputs: Optional inputs forwarded to the graph processor.
        tweaks: Optional per-node overrides applied before execution.
        clear_cache: Whether to bypass the cached result.
        session_id: Session key; generated for async runs when missing.
        task_service: Injected task service (Celery-backed or in-process).
        sync: Await the result (True) or launch fire-and-forget (False).

    Returns:
        ProcessResponse with result, status, task handle, session id and backend.
    """
    task_result = None
    task_status = None

    if tweaks:
        try:
            graph_data = process_tweaks(graph_data, tweaks)
        except Exception as exc:
            # Best-effort: a bad tweak set must not abort the whole run.
            logger.error(f"Error processing tweaks: {exc}")

    # Pick the Celery task when the backend supports it, else the plain callable.
    target = process_graph_cached_task if task_service.use_celery else process_graph_cached

    if sync:
        task_id, result = await task_service.launch_and_await_task(
            target,
            graph_data,
            inputs,
            clear_cache,
            session_id,
        )
        # The awaited result may be a plain dict or a result-like object.
        if isinstance(result, dict) and "result" in result:
            task_result = result["result"]
            session_id = result["session_id"]
        elif hasattr(result, "result") and hasattr(result, "session_id"):
            task_result = result.result
            session_id = result.session_id
    else:
        logger.warning(
            "This is an experimental feature and may not work as expected."
            "Please report any issues to our GitHub repository."
        )
        if session_id is None:
            # Generate a session ID so the async task can be tracked later.
            session_id = get_session_service().generate_key(session_id=session_id, data_graph=graph_data)
        task_id, task = await task_service.launch_task(
            target,
            graph_data,
            inputs,
            clear_cache,
            session_id,
        )
        task_status = task.status
        if task.status == "FAILURE":
            logger.error(f"Task {task_id} failed: {task.traceback}")
            # NOTE(review): reaches into the private `_exception` attribute of the
            # task object — confirm the backend exposes a public accessor instead.
            task_result = str(task._exception)
        else:
            task_result = task.result

    # Only build a task handle when a task id was actually produced.
    task_response = TaskResponse(id=task_id, href=f"api/v1/task/{task_id}") if task_id else None

    return ProcessResponse(
        result=task_result,
        status=task_status,
        task=task_response,
        session_id=session_id,
        backend=task_service.backend_name,
    )
@router.get("/all", dependencies=[Depends(get_current_active_user)])
def get_all(
settings_service=Depends(get_settings_service),
@ -50,7 +120,32 @@ def get_all(
raise HTTPException(status_code=500, detail=str(exc)) from exc
# For backwards compatibility we will keep the old endpoint
@router.post("/process/json", response_model=ProcessResponse)
async def process_json(
    session: Annotated[Session, Depends(get_session)],
    data: dict,
    inputs: Optional[dict] = None,
    tweaks: Optional[dict] = None,
    clear_cache: Annotated[bool, Body(embed=True)] = False,  # noqa: F821
    session_id: Annotated[Union[None, str], Body(embed=True)] = None,  # noqa: F821
    task_service: "TaskService" = Depends(get_task_service),
    sync: Annotated[bool, Body(embed=True)] = True,  # noqa: F821
):
    """Process a raw JSON graph payload.

    Thin endpoint wrapper around ``process_graph_data``: forwards the request
    body and re-raises any failure as an HTTP 500 with the error text.
    """
    try:
        response = await process_graph_data(
            graph_data=data,
            inputs=inputs,
            tweaks=tweaks,
            clear_cache=clear_cache,
            session_id=session_id,
            task_service=task_service,
            sync=sync,
        )
    except Exception as exc:
        # Log the full traceback before surfacing a generic server error.
        logger.exception(exc)
        raise HTTPException(status_code=500, detail=str(exc)) from exc
    return response
@router.post(
"/predict/{flow_id}",
response_model=ProcessResponse,
@ -91,54 +186,14 @@ async def process(
if flow.data is None:
raise ValueError(f"Flow {flow_id} has no data")
graph_data = flow.data
task_result = None
if tweaks:
try:
graph_data = process_tweaks(graph_data, tweaks)
except Exception as exc:
logger.error(f"Error processing tweaks: {exc}")
if sync:
task_id, result = await task_service.launch_and_await_task(
process_graph_cached_task if task_service.use_celery else process_graph_cached,
graph_data,
inputs,
clear_cache,
session_id,
)
if isinstance(result, dict) and "result" in result:
task_result = result["result"]
session_id = result["session_id"]
elif hasattr(result, "result") and hasattr(result, "session_id"):
task_result = result.result
session_id = result.session_id
else:
logger.warning(
"This is an experimental feature and may not work as expected."
"Please report any issues to our GitHub repository."
)
if session_id is None:
# Generate a session ID
session_id = get_session_service().generate_key(session_id=session_id, data_graph=graph_data)
task_id, task = await task_service.launch_task(
process_graph_cached_task if task_service.use_celery else process_graph_cached,
graph_data,
inputs,
clear_cache,
session_id,
)
task_result = task.status
if task_id:
task_response = TaskResponse(id=task_id, href=f"api/v1/task/{task_id}")
else:
task_response = None
return ProcessResponse(
result=task_result,
task=task_response,
return await process_graph_data(
graph_data=graph_data,
inputs=inputs,
tweaks=tweaks,
clear_cache=clear_cache,
session_id=session_id,
backend=task_service.backend_name,
task_service=task_service,
sync=sync,
)
except sa.exc.StatementError as exc:
# StatementError('(builtins.ValueError) badly formed hexadecimal UUID string')
@ -161,6 +216,8 @@ async def get_task_status(task_id: str):
task_service = get_task_service()
task = task_service.get_task(task_id)
result = None
if task is None:
raise HTTPException(status_code=404, detail="Task not found")
if task.ready():
result = task.result
# If result isinstance of Exception, can we get the traceback?
@ -172,8 +229,6 @@ async def get_task_status(task_id: str):
elif hasattr(result, "result"):
result = result.result
if task is None:
raise HTTPException(status_code=404, detail="Task not found")
if task.status == "FAILURE":
result = str(task.result)
logger.error(f"Task {task_id} failed: {task.traceback}")