Refactor process function in endpoints.py

This commit is contained in:
Gabriel Luiz Freitas Almeida 2023-11-21 14:13:33 -03:00
commit 686e15da1a

View file

@ -92,12 +92,7 @@ async def process(
)
# Get the flow that matches the flow_id and belongs to the user
flow = (
session.query(Flow)
.filter(Flow.id == flow_id)
.filter(Flow.user_id == api_key_user.id)
.first()
)
flow = session.query(Flow).filter(Flow.id == flow_id).filter(Flow.user_id == api_key_user.id).first()
if flow is None:
raise ValueError(f"Flow {flow_id} not found")
@ -111,9 +106,7 @@ async def process(
logger.error(f"Error processing tweaks: {exc}")
if sync:
task_id, result = await task_service.launch_and_await_task(
process_graph_cached_task
if task_service.use_celery
else process_graph_cached,
process_graph_cached_task if task_service.use_celery else process_graph_cached,
graph_data,
inputs,
clear_cache,
@ -133,13 +126,9 @@ async def process(
)
if session_id is None:
# Generate a session ID
session_id = get_session_service().generate_key(
session_id=session_id, data_graph=graph_data
)
session_id = get_session_service().generate_key(session_id=session_id, data_graph=graph_data)
task_id, task = await task_service.launch_task(
process_graph_cached_task
if task_service.use_celery
else process_graph_cached,
process_graph_cached_task if task_service.use_celery else process_graph_cached,
graph_data,
inputs,
clear_cache,
@ -162,18 +151,12 @@ async def process(
# StatementError('(builtins.ValueError) badly formed hexadecimal UUID string')
if "badly formed hexadecimal UUID string" in str(exc):
# This means the Flow ID is not a valid UUID which means it can't find the flow
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail=str(exc)
) from exc
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=str(exc)) from exc
except ValueError as exc:
if f"Flow {flow_id} not found" in str(exc):
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail=str(exc)
) from exc
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=str(exc)) from exc
else:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(exc)
) from exc
raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(exc)) from exc
except Exception as e:
# Log stack trace
logger.exception(e)
@ -237,3 +220,20 @@ async def custom_component(
extractor.is_check_valid()
return build_langchain_template_custom_component(extractor, user_id=user.id)
@router.post("/custom_component/update", status_code=HTTPStatus.OK)
async def custom_component_update(
raw_code: CustomComponentCode,
field: str,
user: User = Depends(get_current_active_user),
):
from langflow.interface.types import (
build_langchain_template_custom_component,
)
extractor = CustomComponent(code=raw_code.code)
extractor.is_check_valid()
component_node = build_langchain_template_custom_component(extractor, user_id=user.id)
# Update the field