feat: add dynamic outputs to Component (#4308)
* Refine condition to check for undefined value in use-handle-new-value hook * update mutateTemplate to use APIClassType * Refactor Component class to update inputs and validate outputs This commit refactors the Component class in the `component.py` file. It introduces two new methods: `update_inputs` and `run_and_validate_update_outputs`. The `update_inputs` method allows for updating the build configuration with new field values, while the `run_and_validate_update_outputs` method updates the frontend node and validates the outputs. Additionally, the `_validate_frontend_node` method is added to check if all outputs are valid. The `_set_output_types` method is modified to accept a list of outputs and set their return types. Overall, these changes improve the functionality and maintainability of the Component class. * Add dynamic output validation in API endpoint for component updates * Update build_config to store field_value under "value" key * Refactor: Convert dict values to list in _set_output_types call * Add type check for `cc_instance` before calling `run_and_validate_update_outputs` * Add DynamicOutputComponent with configurable outputs based on input * Add test for updating component outputs with dynamic code input * Refactor: Make get_dynamic_output_component_code asynchronous for improved performance --------- Co-authored-by: Edwin Jose <edwin.jose@datastax.com>
This commit is contained in:
parent
dd4a9f908c
commit
42cc1dacd2
9 changed files with 145 additions and 38 deletions
|
|
@ -49,12 +49,7 @@ from langflow.services.database.models.flow.utils import (
|
|||
get_all_webhook_components_in_flow,
|
||||
)
|
||||
from langflow.services.database.models.user.model import User, UserRead
|
||||
from langflow.services.deps import (
|
||||
get_session_service,
|
||||
get_settings_service,
|
||||
get_task_service,
|
||||
get_telemetry_service,
|
||||
)
|
||||
from langflow.services.deps import get_session_service, get_settings_service, get_task_service, get_telemetry_service
|
||||
from langflow.services.settings.feature_flags import FEATURE_FLAGS
|
||||
from langflow.services.telemetry.schema import RunPayload
|
||||
from langflow.utils.version import get_version_info
|
||||
|
|
@ -637,10 +632,16 @@ async def custom_component_update(
|
|||
field_value=code_request.field_value,
|
||||
field_name=code_request.field,
|
||||
)
|
||||
component_node["template"] = updated_build_config
|
||||
if isinstance(cc_instance, Component):
|
||||
cc_instance.run_and_validate_update_outputs(
|
||||
frontend_node=component_node,
|
||||
field_name=code_request.field,
|
||||
field_value=code_request.field_value,
|
||||
)
|
||||
|
||||
except Exception as exc:
|
||||
raise HTTPException(status_code=400, detail=str(exc)) from exc
|
||||
|
||||
component_node["template"] = updated_build_config
|
||||
return component_node
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -9,7 +9,7 @@ from typing import TYPE_CHECKING, Any, ClassVar, get_type_hints
|
|||
|
||||
import nanoid
|
||||
import yaml
|
||||
from pydantic import BaseModel
|
||||
from pydantic import BaseModel, ValidationError
|
||||
|
||||
from langflow.base.tools.constants import TOOL_OUTPUT_NAME
|
||||
from langflow.custom.tree_visitor import RequiredInputsVisitor
|
||||
|
|
@ -35,6 +35,7 @@ if TYPE_CHECKING:
|
|||
from langflow.graph.edge.schema import EdgeData
|
||||
from langflow.graph.vertex.base import Vertex
|
||||
from langflow.inputs.inputs import InputTypes
|
||||
from langflow.schema import dotdict
|
||||
from langflow.schema.log import LoggableType
|
||||
|
||||
|
||||
|
|
@ -104,7 +105,7 @@ class Component(CustomComponent):
|
|||
if self.outputs is not None:
|
||||
self.map_outputs(self.outputs)
|
||||
# Set output types
|
||||
self._set_output_types()
|
||||
self._set_output_types(list(self._outputs_map.values()))
|
||||
self.set_class_code()
|
||||
self._set_output_required_inputs()
|
||||
|
||||
|
|
@ -310,14 +311,57 @@ class Component(CustomComponent):
|
|||
self._validate_inputs(params)
|
||||
self._validate_outputs()
|
||||
|
||||
def _set_output_types(self) -> None:
|
||||
for output in self._outputs_map.values():
|
||||
if output.method is None:
|
||||
msg = f"Output {output.name} does not have a method"
|
||||
raise ValueError(msg)
|
||||
return_types = self._get_method_return_type(output.method)
|
||||
output.add_types(return_types)
|
||||
output.set_selected()
|
||||
def update_inputs(
|
||||
self,
|
||||
build_config: dotdict,
|
||||
field_value: Any,
|
||||
field_name: str | None = None,
|
||||
):
|
||||
return self.update_build_config(build_config, field_value, field_name)
|
||||
|
||||
def run_and_validate_update_outputs(self, frontend_node: dict, field_name: str, field_value: Any):
|
||||
frontend_node = self.update_outputs(frontend_node, field_name, field_value)
|
||||
return self._validate_frontend_node(frontend_node)
|
||||
|
||||
def _validate_frontend_node(self, frontend_node: dict):
|
||||
# Check if all outputs are either Output or a valid Output model
|
||||
for index, output in enumerate(frontend_node["outputs"]):
|
||||
if isinstance(output, dict):
|
||||
try:
|
||||
_output = Output(**output)
|
||||
self._set_output_return_type(_output)
|
||||
_output_dict = _output.model_dump()
|
||||
except ValidationError as e:
|
||||
msg = f"Invalid output: {e}"
|
||||
raise ValueError(msg) from e
|
||||
elif isinstance(output, Output):
|
||||
# we need to serialize it
|
||||
self._set_output_return_type(output)
|
||||
_output_dict = output.model_dump()
|
||||
else:
|
||||
msg = f"Invalid output type: {type(output)}"
|
||||
raise TypeError(msg)
|
||||
frontend_node["outputs"][index] = _output_dict
|
||||
return frontend_node
|
||||
|
||||
def update_outputs(self, frontend_node: dict, field_name: str, field_value: Any) -> dict: # noqa: ARG002
|
||||
"""Default implementation for updating outputs based on field changes.
|
||||
|
||||
Subclasses can override this to modify outputs based on field_name and field_value.
|
||||
"""
|
||||
return frontend_node
|
||||
|
||||
def _set_output_types(self, outputs: list[Output]) -> None:
|
||||
for output in outputs:
|
||||
self._set_output_return_type(output)
|
||||
|
||||
def _set_output_return_type(self, output: Output) -> None:
|
||||
if output.method is None:
|
||||
msg = f"Output {output.name} does not have a method"
|
||||
raise ValueError(msg)
|
||||
return_types = self._get_method_return_type(output.method)
|
||||
output.add_types(return_types)
|
||||
output.set_selected()
|
||||
|
||||
def _set_output_required_inputs(self) -> None:
|
||||
for output in self.outputs:
|
||||
|
|
|
|||
|
|
@ -228,7 +228,7 @@ class CustomComponent(BaseComponent):
|
|||
field_value: Any,
|
||||
field_name: str | None = None,
|
||||
):
|
||||
build_config[field_name] = field_value
|
||||
build_config[field_name]["value"] = field_value
|
||||
return build_config
|
||||
|
||||
@property
|
||||
|
|
|
|||
41
src/backend/tests/data/dynamic_output_component.py
Normal file
41
src/backend/tests/data/dynamic_output_component.py
Normal file
|
|
@ -0,0 +1,41 @@
|
|||
# from langflow.field_typing import Data
|
||||
from typing import Any
|
||||
|
||||
from langflow.custom import Component
|
||||
from langflow.io import BoolInput, MessageTextInput, Output
|
||||
from langflow.schema import Data
|
||||
|
||||
|
||||
class DynamicOutputComponent(Component):
|
||||
display_name = "Dynamic Output Component"
|
||||
description = "Use as a template to create your own component."
|
||||
documentation: str = "http://docs.langflow.org/components/custom"
|
||||
icon = "custom_components"
|
||||
name = "DynamicOutputComponent"
|
||||
|
||||
inputs = [
|
||||
MessageTextInput(name="input_value", display_name="Input Value", value="Hello, World!"),
|
||||
BoolInput(name="show_output", display_name="Show Output", value=True, real_time_refresh=True),
|
||||
]
|
||||
|
||||
outputs = [
|
||||
Output(display_name="Output", name="output", method="build_output"),
|
||||
]
|
||||
|
||||
def update_outputs(self, frontend_node: dict, field_name: str, field_value: Any):
|
||||
if field_name == "show_output":
|
||||
if field_value:
|
||||
frontend_node["outputs"].append(
|
||||
Output(display_name="Tool Output", name="tool_output", method="build_output")
|
||||
)
|
||||
else:
|
||||
# remove the output
|
||||
frontend_node["outputs"] = [
|
||||
output for output in frontend_node["outputs"] if output["name"] != "tool_output"
|
||||
]
|
||||
return frontend_node
|
||||
|
||||
def build_output(self) -> Data:
|
||||
data = Data(value=self.input_value)
|
||||
self.status = data
|
||||
return data
|
||||
|
|
@ -1,5 +1,16 @@
|
|||
import asyncio
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from fastapi import status
|
||||
from httpx import AsyncClient
|
||||
from langflow.api.v1.schemas import UpdateCustomComponentRequest
|
||||
|
||||
|
||||
async def get_dynamic_output_component_code():
|
||||
return await asyncio.to_thread(
|
||||
Path("src/backend/tests/data/dynamic_output_component.py").read_text, encoding="utf-8"
|
||||
)
|
||||
|
||||
|
||||
async def test_get_version(client: AsyncClient):
|
||||
|
|
@ -23,3 +34,21 @@ async def test_get_config(client: AsyncClient):
|
|||
assert "auto_saving" in result, "The dictionary must contain a key called 'auto_saving'"
|
||||
assert "health_check_max_retries" in result, "The dictionary must contain a 'health_check_max_retries' key"
|
||||
assert "max_file_size_upload" in result, "The dictionary must contain a key called 'max_file_size_upload'"
|
||||
|
||||
|
||||
async def test_update_component_outputs(client: AsyncClient, logged_in_headers: dict):
|
||||
code = await get_dynamic_output_component_code()
|
||||
frontend_node: dict[str, Any] = {"outputs": []}
|
||||
request = UpdateCustomComponentRequest(
|
||||
code=code,
|
||||
frontend_node=frontend_node,
|
||||
field="show_output",
|
||||
field_value=True,
|
||||
template={},
|
||||
)
|
||||
response = await client.post("api/v1/custom_component/update", json=request.model_dump(), headers=logged_in_headers)
|
||||
result = response.json()
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
output_names = [output["name"] for output in result["outputs"]]
|
||||
assert "tool_output" in output_names
|
||||
|
|
|
|||
|
|
@ -3,11 +3,7 @@ import {
|
|||
SAVE_DEBOUNCE_TIME,
|
||||
TITLE_ERROR_UPDATING_COMPONENT,
|
||||
} from "@/constants/constants";
|
||||
import {
|
||||
APIClassType,
|
||||
APITemplateType,
|
||||
ResponseErrorDetailAPI,
|
||||
} from "@/types/api";
|
||||
import { APIClassType, ResponseErrorDetailAPI } from "@/types/api";
|
||||
import { UseMutationResult } from "@tanstack/react-query";
|
||||
import { cloneDeep, debounce } from "lodash";
|
||||
|
||||
|
|
@ -17,7 +13,7 @@ export const mutateTemplate = debounce(
|
|||
node: APIClassType,
|
||||
setNodeClass,
|
||||
postTemplateValue: UseMutationResult<
|
||||
APITemplateType | undefined,
|
||||
APIClassType | undefined,
|
||||
ResponseErrorDetailAPI,
|
||||
any
|
||||
>,
|
||||
|
|
@ -29,7 +25,8 @@ export const mutateTemplate = debounce(
|
|||
value: newValue,
|
||||
});
|
||||
if (newTemplate) {
|
||||
newNode.template = newTemplate;
|
||||
newNode.template = newTemplate.template;
|
||||
newNode.outputs = newTemplate.outputs;
|
||||
}
|
||||
setNodeClass(newNode);
|
||||
} catch (e) {
|
||||
|
|
|
|||
|
|
@ -1,8 +1,4 @@
|
|||
import {
|
||||
APIClassType,
|
||||
APITemplateType,
|
||||
ResponseErrorDetailAPI,
|
||||
} from "@/types/api";
|
||||
import { APIClassType, ResponseErrorDetailAPI } from "@/types/api";
|
||||
import { UseMutationResult } from "@tanstack/react-query";
|
||||
import { useEffect } from "react";
|
||||
import useAlertStore from "../../stores/alertStore";
|
||||
|
|
@ -13,7 +9,7 @@ const useFetchDataOnMount = (
|
|||
setNodeClass: (node: APIClassType) => void,
|
||||
name: string,
|
||||
postTemplateValue: UseMutationResult<
|
||||
APITemplateType | undefined,
|
||||
APIClassType | undefined,
|
||||
ResponseErrorDetailAPI,
|
||||
any
|
||||
>,
|
||||
|
|
|
|||
|
|
@ -80,7 +80,7 @@ const useHandleOnNewValue = ({
|
|||
});
|
||||
};
|
||||
|
||||
if (shouldUpdate && changes.value) {
|
||||
if (shouldUpdate && changes.value !== undefined) {
|
||||
mutateTemplate(
|
||||
changes.value,
|
||||
newNode,
|
||||
|
|
|
|||
|
|
@ -1,6 +1,5 @@
|
|||
import {
|
||||
APIClassType,
|
||||
APITemplateType,
|
||||
ResponseErrorDetailAPI,
|
||||
useMutationFunctionType,
|
||||
} from "@/types/api";
|
||||
|
|
@ -22,14 +21,14 @@ interface IPostTemplateValueParams {
|
|||
export const usePostTemplateValue: useMutationFunctionType<
|
||||
IPostTemplateValueParams,
|
||||
IPostTemplateValue,
|
||||
APITemplateType | undefined,
|
||||
APIClassType,
|
||||
ResponseErrorDetailAPI
|
||||
> = ({ parameterId, nodeId, node }, options?) => {
|
||||
const { mutate } = UseRequestProcessor();
|
||||
|
||||
const postTemplateValueFn = async (
|
||||
payload: IPostTemplateValue,
|
||||
): Promise<APITemplateType | undefined> => {
|
||||
): Promise<APIClassType | undefined> => {
|
||||
const template = node.template;
|
||||
|
||||
if (!template) return;
|
||||
|
|
@ -44,11 +43,11 @@ export const usePostTemplateValue: useMutationFunctionType<
|
|||
},
|
||||
);
|
||||
|
||||
return response.data.template;
|
||||
return response.data;
|
||||
};
|
||||
|
||||
const mutation: UseMutationResult<
|
||||
APITemplateType | undefined,
|
||||
APIClassType,
|
||||
ResponseErrorDetailAPI,
|
||||
IPostTemplateValue
|
||||
> = mutate(
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue