From 42cc1dacd25c9bf90e2af0be7ee6ce81b4838642 Mon Sep 17 00:00:00 2001 From: Gabriel Luiz Freitas Almeida Date: Mon, 4 Nov 2024 09:31:56 -0300 Subject: [PATCH] feat: add dynamic outputs to Component (#4308) * Refine condition to check for undefined value in use-handle-new-value hook * update mutateTemplate to use APIClassType * Refactor Component class to update inputs and validate outputs This commit refactors the Component class in the `component.py` file. It introduces two new methods: `update_inputs` and `run_and_validate_update_outputs`. The `update_inputs` method allows for updating the build configuration with new field values, while the `run_and_validate_update_outputs` method updates the frontend node and validates the outputs. Additionally, the `_validate_frontend_node` method is added to check if all outputs are valid. The `_set_output_types` method is modified to accept a list of outputs and set their return types. Overall, these changes improve the functionality and maintainability of the Component class. * Add dynamic output validation in API endpoint for component updates * Update build_config to store field_value under "value" key * Refactor: Convert dict values to list in _set_output_types call * Add type check for `cc_instance` before calling `run_and_validate_update_outputs` * Add DynamicOutputComponent with configurable outputs based on input * Add test for updating component outputs with dynamic code input * Refactor: Make get_dynamic_output_component_code asynchronous for improved performance --------- Co-authored-by: Edwin Jose --- src/backend/base/langflow/api/v1/endpoints.py | 17 ++--- .../custom/custom_component/component.py | 64 ++++++++++++++++--- .../custom_component/custom_component.py | 2 +- .../tests/data/dynamic_output_component.py | 41 ++++++++++++ .../tests/unit/api/v1/test_endpoints.py | 29 +++++++++ .../CustomNodes/helpers/mutate-template.ts | 11 ++-- .../hooks/use-fetch-data-on-mount.tsx | 8 +-- .../hooks/use-handle-new-value.tsx | 2 +- .../queries/nodes/use-post-template-value.ts | 9 ++- 9 files changed, 145 insertions(+), 38 deletions(-) create mode 100644 src/backend/tests/data/dynamic_output_component.py diff --git a/src/backend/base/langflow/api/v1/endpoints.py b/src/backend/base/langflow/api/v1/endpoints.py index fa93f799d..000618fb6 100644 --- a/src/backend/base/langflow/api/v1/endpoints.py +++ b/src/backend/base/langflow/api/v1/endpoints.py @@ -49,12 +49,7 @@ from langflow.services.database.models.flow.utils import ( get_all_webhook_components_in_flow, ) from langflow.services.database.models.user.model import User, UserRead -from langflow.services.deps import ( - get_session_service, - get_settings_service, - get_task_service, - get_telemetry_service, -) +from langflow.services.deps import get_session_service, get_settings_service, get_task_service, get_telemetry_service from langflow.services.settings.feature_flags import FEATURE_FLAGS from langflow.services.telemetry.schema import RunPayload from langflow.utils.version import get_version_info @@ -637,10 +632,16 @@ async def custom_component_update( field_value=code_request.field_value, field_name=code_request.field, ) + component_node["template"] = updated_build_config + if isinstance(cc_instance, Component): + cc_instance.run_and_validate_update_outputs( + frontend_node=component_node, + field_name=code_request.field, + field_value=code_request.field_value, + ) + except Exception as exc: raise HTTPException(status_code=400, detail=str(exc)) from exc - - component_node["template"] = updated_build_config return component_node diff --git a/src/backend/base/langflow/custom/custom_component/component.py b/src/backend/base/langflow/custom/custom_component/component.py index 2a4c73f80..22db8d35e 100644 --- a/src/backend/base/langflow/custom/custom_component/component.py +++ b/src/backend/base/langflow/custom/custom_component/component.py @@ -9,7 +9,7 @@ from typing import TYPE_CHECKING, Any, ClassVar, get_type_hints import nanoid import yaml -from pydantic import BaseModel +from pydantic import BaseModel, ValidationError from langflow.base.tools.constants import TOOL_OUTPUT_NAME from langflow.custom.tree_visitor import RequiredInputsVisitor @@ -35,6 +35,7 @@ if TYPE_CHECKING: from langflow.graph.edge.schema import EdgeData from langflow.graph.vertex.base import Vertex from langflow.inputs.inputs import InputTypes + from langflow.schema import dotdict from langflow.schema.log import LoggableType @@ -104,7 +105,7 @@ class Component(CustomComponent): if self.outputs is not None: self.map_outputs(self.outputs) # Set output types - self._set_output_types() + self._set_output_types(list(self._outputs_map.values())) self.set_class_code() self._set_output_required_inputs() @@ -310,14 +311,57 @@ class Component(CustomComponent): self._validate_inputs(params) self._validate_outputs() - def _set_output_types(self) -> None: - for output in self._outputs_map.values(): - if output.method is None: - msg = f"Output {output.name} does not have a method" - raise ValueError(msg) - return_types = self._get_method_return_type(output.method) - output.add_types(return_types) - output.set_selected() + def update_inputs( + self, + build_config: dotdict, + field_value: Any, + field_name: str | None = None, + ): + return self.update_build_config(build_config, field_value, field_name) + + def run_and_validate_update_outputs(self, frontend_node: dict, field_name: str, field_value: Any): + frontend_node = self.update_outputs(frontend_node, field_name, field_value) + return self._validate_frontend_node(frontend_node) + + def _validate_frontend_node(self, frontend_node: dict): + # Check if all outputs are either Output or a valid Output model + for index, output in enumerate(frontend_node["outputs"]): + if isinstance(output, dict): + try: + _output = Output(**output) + self._set_output_return_type(_output) + _output_dict = _output.model_dump() + except ValidationError as e: + msg = f"Invalid output: {e}" + raise ValueError(msg) from e + elif isinstance(output, Output): + # we need to serialize it + self._set_output_return_type(output) + _output_dict = output.model_dump() + else: + msg = f"Invalid output type: {type(output)}" + raise TypeError(msg) + frontend_node["outputs"][index] = _output_dict + return frontend_node + + def update_outputs(self, frontend_node: dict, field_name: str, field_value: Any) -> dict: # noqa: ARG002 + """Default implementation for updating outputs based on field changes. + + Subclasses can override this to modify outputs based on field_name and field_value. + """ + return frontend_node + + def _set_output_types(self, outputs: list[Output]) -> None: + for output in outputs: + self._set_output_return_type(output) + + def _set_output_return_type(self, output: Output) -> None: + if output.method is None: + msg = f"Output {output.name} does not have a method" + raise ValueError(msg) + return_types = self._get_method_return_type(output.method) + output.add_types(return_types) + output.set_selected() def _set_output_required_inputs(self) -> None: for output in self.outputs: diff --git a/src/backend/base/langflow/custom/custom_component/custom_component.py b/src/backend/base/langflow/custom/custom_component/custom_component.py index 3f25f4dd7..7544e4243 100644 --- a/src/backend/base/langflow/custom/custom_component/custom_component.py +++ b/src/backend/base/langflow/custom/custom_component/custom_component.py @@ -228,7 +228,7 @@ class CustomComponent(BaseComponent): field_value: Any, field_name: str | None = None, ): - build_config[field_name] = field_value + build_config[field_name]["value"] = field_value return build_config @property diff --git a/src/backend/tests/data/dynamic_output_component.py b/src/backend/tests/data/dynamic_output_component.py new file mode 100644 index 000000000..b2a23aa37 --- /dev/null +++ b/src/backend/tests/data/dynamic_output_component.py @@ -0,0 +1,41 @@ +# from langflow.field_typing import Data +from typing import Any + +from langflow.custom import Component +from langflow.io import BoolInput, MessageTextInput, Output +from langflow.schema import Data + + +class DynamicOutputComponent(Component): + display_name = "Dynamic Output Component" + description = "Use as a template to create your own component." + documentation: str = "http://docs.langflow.org/components/custom" + icon = "custom_components" + name = "DynamicOutputComponent" + + inputs = [ + MessageTextInput(name="input_value", display_name="Input Value", value="Hello, World!"), + BoolInput(name="show_output", display_name="Show Output", value=True, real_time_refresh=True), + ] + + outputs = [ + Output(display_name="Output", name="output", method="build_output"), + ] + + def update_outputs(self, frontend_node: dict, field_name: str, field_value: Any): + if field_name == "show_output": + if field_value: + frontend_node["outputs"].append( + Output(display_name="Tool Output", name="tool_output", method="build_output") + ) + else: + # remove the output + frontend_node["outputs"] = [ + output for output in frontend_node["outputs"] if output["name"] != "tool_output" + ] + return frontend_node + + def build_output(self) -> Data: + data = Data(value=self.input_value) + self.status = data + return data diff --git a/src/backend/tests/unit/api/v1/test_endpoints.py b/src/backend/tests/unit/api/v1/test_endpoints.py index 908b4edaf..5003bc838 100644 --- a/src/backend/tests/unit/api/v1/test_endpoints.py +++ b/src/backend/tests/unit/api/v1/test_endpoints.py @@ -1,5 +1,16 @@ +import asyncio +from pathlib import Path +from typing import Any + from fastapi import status from httpx import AsyncClient +from langflow.api.v1.schemas import UpdateCustomComponentRequest + + +async def get_dynamic_output_component_code(): + return await asyncio.to_thread( + Path("src/backend/tests/data/dynamic_output_component.py").read_text, encoding="utf-8" + ) async def test_get_version(client: AsyncClient): @@ -23,3 +34,21 @@ async def test_get_config(client: AsyncClient): assert "auto_saving" in result, "The dictionary must contain a key called 'auto_saving'" assert "health_check_max_retries" in result, "The dictionary must contain a 'health_check_max_retries' key" assert "max_file_size_upload" in result, "The dictionary must contain a key called 'max_file_size_upload'" + + +async def test_update_component_outputs(client: AsyncClient, logged_in_headers: dict): + code = await get_dynamic_output_component_code() + frontend_node: dict[str, Any] = {"outputs": []} + request = UpdateCustomComponentRequest( + code=code, + frontend_node=frontend_node, + field="show_output", + field_value=True, + template={}, + ) + response = await client.post("api/v1/custom_component/update", json=request.model_dump(), headers=logged_in_headers) + result = response.json() + + assert response.status_code == status.HTTP_200_OK + output_names = [output["name"] for output in result["outputs"]] + assert "tool_output" in output_names diff --git a/src/frontend/src/CustomNodes/helpers/mutate-template.ts b/src/frontend/src/CustomNodes/helpers/mutate-template.ts index 407157e2d..713400753 100644 --- a/src/frontend/src/CustomNodes/helpers/mutate-template.ts +++ b/src/frontend/src/CustomNodes/helpers/mutate-template.ts @@ -3,11 +3,7 @@ import { SAVE_DEBOUNCE_TIME, TITLE_ERROR_UPDATING_COMPONENT, } from "@/constants/constants"; -import { - APIClassType, - APITemplateType, - ResponseErrorDetailAPI, -} from "@/types/api"; +import { APIClassType, ResponseErrorDetailAPI } from "@/types/api"; import { UseMutationResult } from "@tanstack/react-query"; import { cloneDeep, debounce } from "lodash"; @@ -17,7 +13,7 @@ export const mutateTemplate = debounce( node: APIClassType, setNodeClass, postTemplateValue: UseMutationResult< - APITemplateType | undefined, + APIClassType | undefined, ResponseErrorDetailAPI, any >, @@ -29,7 +25,8 @@ export const mutateTemplate = debounce( value: newValue, }); if (newTemplate) { - newNode.template = newTemplate; + newNode.template = newTemplate.template; + newNode.outputs = newTemplate.outputs; } setNodeClass(newNode); } catch (e) { diff --git a/src/frontend/src/CustomNodes/hooks/use-fetch-data-on-mount.tsx b/src/frontend/src/CustomNodes/hooks/use-fetch-data-on-mount.tsx index 6379202fd..471e0bc0d 100644 --- a/src/frontend/src/CustomNodes/hooks/use-fetch-data-on-mount.tsx +++ b/src/frontend/src/CustomNodes/hooks/use-fetch-data-on-mount.tsx @@ -1,8 +1,4 @@ -import { - APIClassType, - APITemplateType, - ResponseErrorDetailAPI, -} from "@/types/api"; +import { APIClassType, ResponseErrorDetailAPI } from "@/types/api"; import { UseMutationResult } from "@tanstack/react-query"; import { useEffect } from "react"; import useAlertStore from "../../stores/alertStore"; @@ -13,7 +9,7 @@ const useFetchDataOnMount = ( setNodeClass: (node: APIClassType) => void, name: string, postTemplateValue: UseMutationResult< - APITemplateType | undefined, + APIClassType | undefined, ResponseErrorDetailAPI, any >, diff --git a/src/frontend/src/CustomNodes/hooks/use-handle-new-value.tsx b/src/frontend/src/CustomNodes/hooks/use-handle-new-value.tsx index dc2e2c309..a8faffa49 100644 --- a/src/frontend/src/CustomNodes/hooks/use-handle-new-value.tsx +++ b/src/frontend/src/CustomNodes/hooks/use-handle-new-value.tsx @@ -80,7 +80,7 @@ const useHandleOnNewValue = ({ }); }; - if (shouldUpdate && changes.value) { + if (shouldUpdate && changes.value !== undefined) { mutateTemplate( changes.value, newNode, diff --git a/src/frontend/src/controllers/API/queries/nodes/use-post-template-value.ts b/src/frontend/src/controllers/API/queries/nodes/use-post-template-value.ts index 73d339bbb..40d751078 100644 --- a/src/frontend/src/controllers/API/queries/nodes/use-post-template-value.ts +++ b/src/frontend/src/controllers/API/queries/nodes/use-post-template-value.ts @@ -1,6 +1,5 @@ import { APIClassType, - APITemplateType, ResponseErrorDetailAPI, useMutationFunctionType, } from "@/types/api"; @@ -22,14 +21,14 @@ interface IPostTemplateValueParams { export const usePostTemplateValue: useMutationFunctionType< IPostTemplateValueParams, IPostTemplateValue, - APITemplateType | undefined, + APIClassType, ResponseErrorDetailAPI > = ({ parameterId, nodeId, node }, options?) => { const { mutate } = UseRequestProcessor(); const postTemplateValueFn = async ( payload: IPostTemplateValue, - ): Promise => { + ): Promise => { const template = node.template; if (!template) return; @@ -44,11 +43,11 @@ export const usePostTemplateValue: useMutationFunctionType< }, ); - return response.data.template; + return response.data; }; const mutation: UseMutationResult< - APITemplateType | undefined, + APIClassType, ResponseErrorDetailAPI, IPostTemplateValue > = mutate(