Refactor API endpoints and schemas

This commit is contained in:
Gabriel Luiz Freitas Almeida 2024-03-11 16:51:08 -03:00
commit bcab53818e
3 changed files with 28 additions and 85 deletions

View file

@ -52,9 +52,7 @@ def get_all(
raise HTTPException(status_code=500, detail=str(exc)) from exc
@router.post(
"/run/{flow_id}", response_model=RunResponse, response_model_exclude_none=True
)
@router.post("/run/{flow_id}", response_model=RunResponse, response_model_exclude_none=True)
async def run_flow_with_caching(
session: Annotated[Session, Depends(get_session)],
flow_id: str,
@ -105,9 +103,7 @@ async def run_flow_with_caching(
"""
try:
if inputs is not None:
input_values: list[dict[str, Union[str, list[str]]]] = [
_input.model_dump() for _input in inputs
]
input_values: list[dict[str, Union[str, list[str]]]] = [_input.model_dump() for _input in inputs]
else:
input_values = [{}]
@ -115,9 +111,7 @@ async def run_flow_with_caching(
outputs = []
if session_id:
session_data = await session_service.load_session(
session_id, flow_id=flow_id
)
session_data = await session_service.load_session(session_id, flow_id=flow_id)
graph, artifacts = session_data if session_data else (None, None)
task_result: Any = None
if not graph:
@ -136,11 +130,7 @@ async def run_flow_with_caching(
else:
# Get the flow that matches the flow_id and belongs to the user
# flow = session.query(Flow).filter(Flow.id == flow_id).filter(Flow.user_id == api_key_user.id).first()
flow = session.exec(
select(Flow)
.where(Flow.id == flow_id)
.where(Flow.user_id == api_key_user.id)
).first()
flow = session.exec(select(Flow).where(Flow.id == flow_id).where(Flow.user_id == api_key_user.id)).first()
if flow is None:
raise ValueError(f"Flow {flow_id} not found")
@ -164,18 +154,12 @@ async def run_flow_with_caching(
# StatementError('(builtins.ValueError) badly formed hexadecimal UUID string')
if "badly formed hexadecimal UUID string" in str(exc):
# This means the Flow ID is not a valid UUID which means it can't find the flow
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail=str(exc)
) from exc
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=str(exc)) from exc
except ValueError as exc:
if f"Flow {flow_id} not found" in str(exc):
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail=str(exc)
) from exc
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=str(exc)) from exc
else:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(exc)
) from exc
raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(exc)) from exc
@router.post(
@ -204,8 +188,7 @@ async def process(
"""
# Raise a depreciation warning
logger.warning(
"The /process endpoint is deprecated and will be removed in a future version. "
"Please use /run instead."
"The /process endpoint is deprecated and will be removed in a future version. " "Please use /run instead."
)
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
@ -277,16 +260,12 @@ async def custom_component(
built_frontend_node, _ = build_custom_component_template(component, user_id=user.id)
built_frontend_node = update_frontend_node_with_template_values(
built_frontend_node, raw_code.frontend_node
)
built_frontend_node = update_frontend_node_with_template_values(built_frontend_node, raw_code.frontend_node)
return built_frontend_node
@router.post("/custom_component/reload", status_code=HTTPStatus.OK)
async def reload_custom_component(
path: str, user: User = Depends(get_current_active_user)
):
async def reload_custom_component(path: str, user: User = Depends(get_current_active_user)):
from langflow.interface.custom.utils import build_custom_component_template
try:

View file

@ -167,9 +167,7 @@ class StreamData(BaseModel):
data: dict
def __str__(self) -> str:
return (
f"event: {self.event}\ndata: {orjson_dumps(self.data, indent_2=False)}\n\n"
)
return f"event: {self.event}\ndata: {orjson_dumps(self.data, indent_2=False)}\n\n"
class CustomComponentRequest(BaseModel):

View file

@ -34,18 +34,14 @@ class UpdateBuildConfigError(Exception):
pass
def add_output_types(
frontend_node: CustomComponentFrontendNode, return_types: List[str]
):
def add_output_types(frontend_node: CustomComponentFrontendNode, return_types: List[str]):
"""Add output types to the frontend node"""
for return_type in return_types:
if return_type is None:
raise HTTPException(
status_code=400,
detail={
"error": (
"Invalid return type. Please check your code and try again."
),
"error": ("Invalid return type. Please check your code and try again."),
"traceback": traceback.format_exc(),
},
)
@ -77,18 +73,14 @@ def reorder_fields(frontend_node: CustomComponentFrontendNode, field_order: List
frontend_node.field_order = field_order
def add_base_classes(
frontend_node: CustomComponentFrontendNode, return_types: List[str]
):
def add_base_classes(frontend_node: CustomComponentFrontendNode, return_types: List[str]):
"""Add base classes to the frontend node"""
for return_type_instance in return_types:
if return_type_instance is None:
raise HTTPException(
status_code=400,
detail={
"error": (
"Invalid return type. Please check your code and try again."
),
"error": ("Invalid return type. Please check your code and try again."),
"traceback": traceback.format_exc(),
},
)
@ -123,9 +115,7 @@ def get_field_properties(extra_field):
# a required field is a field that does not contain
# optional in field_type
# and a field that does not have a default value
field_required = "optional" not in field_type.lower() and isinstance(
field_value, MissingDefault
)
field_required = "optional" not in field_type.lower() and isinstance(field_value, MissingDefault)
field_value = field_value if not isinstance(field_value, MissingDefault) else None
if not field_required:
@ -174,14 +164,10 @@ def add_new_custom_field(
# If options is a list, then it's a dropdown
# If options is None, then it's a list of strings
is_list = isinstance(field_config.get("options"), list)
field_config["is_list"] = (
is_list or field_config.get("list", False) or field_contains_list
)
field_config["is_list"] = is_list or field_config.get("list", False) or field_contains_list
if "name" in field_config:
warnings.warn(
"The 'name' key in field_config is used to build the object and can't be changed."
)
warnings.warn("The 'name' key in field_config is used to build the object and can't be changed.")
required = field_config.pop("required", field_required)
placeholder = field_config.pop("placeholder", "")
@ -220,9 +206,7 @@ def add_extra_fields(frontend_node, field_config, function_args):
]:
continue
field_name, field_type, field_value, field_required = get_field_properties(
extra_field
)
field_name, field_type, field_value, field_required = get_field_properties(extra_field)
config = _field_config.pop(field_name, {})
frontend_node = add_new_custom_field(
frontend_node,
@ -232,17 +216,13 @@ def add_extra_fields(frontend_node, field_config, function_args):
field_required,
config,
)
if "kwargs" in function_args_names and not all(
key in function_args_names for key in field_config.keys()
):
if "kwargs" in function_args_names and not all(key in function_args_names for key in field_config.keys()):
for field_name, field_config in _field_config.copy().items():
if "name" not in field_config or field_name == "code":
continue
config = _field_config.get(field_name, {})
config = config.model_dump() if isinstance(config, BaseModel) else config
field_name, field_type, field_value, field_required = get_field_properties(
extra_field=config
)
field_name, field_type, field_value, field_required = get_field_properties(extra_field=config)
frontend_node = add_new_custom_field(
frontend_node,
field_name,
@ -278,9 +258,7 @@ def run_build_config(
raise HTTPException(
status_code=400,
detail={
"error": (
"Invalid type convertion. Please check your code and try again."
),
"error": ("Invalid type convertion. Please check your code and try again."),
"traceback": traceback.format_exc(),
},
) from exc
@ -360,16 +338,10 @@ def build_custom_component_template(
add_extra_fields(frontend_node, field_config, entrypoint_args)
frontend_node = add_code_field(
frontend_node, custom_component.code, field_config.get("code", {})
)
frontend_node = add_code_field(frontend_node, custom_component.code, field_config.get("code", {}))
add_base_classes(
frontend_node, custom_component.get_function_entrypoint_return_type
)
add_output_types(
frontend_node, custom_component.get_function_entrypoint_return_type
)
add_base_classes(frontend_node, custom_component.get_function_entrypoint_return_type)
add_output_types(frontend_node, custom_component.get_function_entrypoint_return_type)
reorder_fields(frontend_node, custom_instance._get_field_order())
@ -380,9 +352,7 @@ def build_custom_component_template(
raise HTTPException(
status_code=400,
detail={
"error": (
f"Something went wrong while building the custom component. Hints: {str(exc)}"
),
"error": (f"Something went wrong while building the custom component. Hints: {str(exc)}"),
"traceback": traceback.format_exc(),
},
) from exc
@ -418,9 +388,7 @@ def build_custom_components(components_paths: List[str]):
custom_component_dict = build_custom_component_list_from_path(path_str)
if custom_component_dict:
category = next(iter(custom_component_dict))
logger.info(
f"Loading {len(custom_component_dict[category])} component(s) from category {category}"
)
logger.info(f"Loading {len(custom_component_dict[category])} component(s) from category {category}")
custom_components_from_file = merge_nested_dicts_with_renaming(
custom_components_from_file, custom_component_dict
)
@ -455,9 +423,7 @@ def update_field_dict(
build_config = dd_build_config
except Exception as exc:
logger.error(f"Error while running update_build_config: {str(exc)}")
raise UpdateBuildConfigError(
f"Error while running update_build_config: {str(exc)}"
) from exc
raise UpdateBuildConfigError(f"Error while running update_build_config: {str(exc)}") from exc
# Let's check if "range_spec" is a RangeSpec object
if "rangeSpec" in field_dict and isinstance(field_dict["rangeSpec"], RangeSpec):