feat: add dynamic outputs to Component (#4308)

* Refine condition to check for undefined value in use-handle-new-value hook

* update mutateTemplate to use APIClassType

* Refactor Component class to update inputs and validate outputs

This commit refactors the Component class in the `component.py` file. It introduces two new methods: `update_inputs` and `run_and_validate_update_outputs`. The `update_inputs` method allows for updating the build configuration with new field values, while the `run_and_validate_update_outputs` method updates the frontend node and validates the outputs. Additionally, the `_validate_frontend_node` method is added to check if all outputs are valid. The `_set_output_types` method is modified to accept a list of outputs and set their return types. Overall, these changes improve the functionality and maintainability of the Component class.

* Add dynamic output validation in API endpoint for component updates

* Update build_config to store field_value under "value" key

* Refactor: Convert dict values to list in _set_output_types call

* Add type check for `cc_instance` before calling `run_and_validate_update_outputs`

* Add DynamicOutputComponent with configurable outputs based on input

* Add test for updating component outputs with dynamic code input

* Refactor: Make get_dynamic_output_component_code asynchronous for improved performance

---------

Co-authored-by: Edwin Jose <edwin.jose@datastax.com>
This commit is contained in:
Gabriel Luiz Freitas Almeida 2024-11-04 09:31:56 -03:00 committed by GitHub
commit 42cc1dacd2
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
9 changed files with 145 additions and 38 deletions

View file

@ -49,12 +49,7 @@ from langflow.services.database.models.flow.utils import (
get_all_webhook_components_in_flow,
)
from langflow.services.database.models.user.model import User, UserRead
from langflow.services.deps import (
get_session_service,
get_settings_service,
get_task_service,
get_telemetry_service,
)
from langflow.services.deps import get_session_service, get_settings_service, get_task_service, get_telemetry_service
from langflow.services.settings.feature_flags import FEATURE_FLAGS
from langflow.services.telemetry.schema import RunPayload
from langflow.utils.version import get_version_info
@ -637,10 +632,16 @@ async def custom_component_update(
field_value=code_request.field_value,
field_name=code_request.field,
)
component_node["template"] = updated_build_config
if isinstance(cc_instance, Component):
cc_instance.run_and_validate_update_outputs(
frontend_node=component_node,
field_name=code_request.field,
field_value=code_request.field_value,
)
except Exception as exc:
raise HTTPException(status_code=400, detail=str(exc)) from exc
component_node["template"] = updated_build_config
return component_node

View file

@ -9,7 +9,7 @@ from typing import TYPE_CHECKING, Any, ClassVar, get_type_hints
import nanoid
import yaml
from pydantic import BaseModel
from pydantic import BaseModel, ValidationError
from langflow.base.tools.constants import TOOL_OUTPUT_NAME
from langflow.custom.tree_visitor import RequiredInputsVisitor
@ -35,6 +35,7 @@ if TYPE_CHECKING:
from langflow.graph.edge.schema import EdgeData
from langflow.graph.vertex.base import Vertex
from langflow.inputs.inputs import InputTypes
from langflow.schema import dotdict
from langflow.schema.log import LoggableType
@ -104,7 +105,7 @@ class Component(CustomComponent):
if self.outputs is not None:
self.map_outputs(self.outputs)
# Set output types
self._set_output_types()
self._set_output_types(list(self._outputs_map.values()))
self.set_class_code()
self._set_output_required_inputs()
@ -310,14 +311,57 @@ class Component(CustomComponent):
self._validate_inputs(params)
self._validate_outputs()
def _set_output_types(self) -> None:
for output in self._outputs_map.values():
if output.method is None:
msg = f"Output {output.name} does not have a method"
raise ValueError(msg)
return_types = self._get_method_return_type(output.method)
output.add_types(return_types)
output.set_selected()
def update_inputs(
self,
build_config: dotdict,
field_value: Any,
field_name: str | None = None,
):
return self.update_build_config(build_config, field_value, field_name)
def run_and_validate_update_outputs(self, frontend_node: dict, field_name: str, field_value: Any):
frontend_node = self.update_outputs(frontend_node, field_name, field_value)
return self._validate_frontend_node(frontend_node)
def _validate_frontend_node(self, frontend_node: dict):
# Check if all outputs are either Output or a valid Output model
for index, output in enumerate(frontend_node["outputs"]):
if isinstance(output, dict):
try:
_output = Output(**output)
self._set_output_return_type(_output)
_output_dict = _output.model_dump()
except ValidationError as e:
msg = f"Invalid output: {e}"
raise ValueError(msg) from e
elif isinstance(output, Output):
# we need to serialize it
self._set_output_return_type(output)
_output_dict = output.model_dump()
else:
msg = f"Invalid output type: {type(output)}"
raise TypeError(msg)
frontend_node["outputs"][index] = _output_dict
return frontend_node
def update_outputs(self, frontend_node: dict, field_name: str, field_value: Any) -> dict: # noqa: ARG002
"""Default implementation for updating outputs based on field changes.
Subclasses can override this to modify outputs based on field_name and field_value.
"""
return frontend_node
def _set_output_types(self, outputs: list[Output]) -> None:
for output in outputs:
self._set_output_return_type(output)
def _set_output_return_type(self, output: Output) -> None:
if output.method is None:
msg = f"Output {output.name} does not have a method"
raise ValueError(msg)
return_types = self._get_method_return_type(output.method)
output.add_types(return_types)
output.set_selected()
def _set_output_required_inputs(self) -> None:
for output in self.outputs:

View file

@ -228,7 +228,7 @@ class CustomComponent(BaseComponent):
field_value: Any,
field_name: str | None = None,
):
build_config[field_name] = field_value
build_config[field_name]["value"] = field_value
return build_config
@property

View file

@ -0,0 +1,41 @@
# from langflow.field_typing import Data
from typing import Any
from langflow.custom import Component
from langflow.io import BoolInput, MessageTextInput, Output
from langflow.schema import Data
class DynamicOutputComponent(Component):
display_name = "Dynamic Output Component"
description = "Use as a template to create your own component."
documentation: str = "http://docs.langflow.org/components/custom"
icon = "custom_components"
name = "DynamicOutputComponent"
inputs = [
MessageTextInput(name="input_value", display_name="Input Value", value="Hello, World!"),
BoolInput(name="show_output", display_name="Show Output", value=True, real_time_refresh=True),
]
outputs = [
Output(display_name="Output", name="output", method="build_output"),
]
def update_outputs(self, frontend_node: dict, field_name: str, field_value: Any):
if field_name == "show_output":
if field_value:
frontend_node["outputs"].append(
Output(display_name="Tool Output", name="tool_output", method="build_output")
)
else:
# remove the output
frontend_node["outputs"] = [
output for output in frontend_node["outputs"] if output["name"] != "tool_output"
]
return frontend_node
def build_output(self) -> Data:
data = Data(value=self.input_value)
self.status = data
return data

View file

@ -1,5 +1,16 @@
import asyncio
from pathlib import Path
from typing import Any
from fastapi import status
from httpx import AsyncClient
from langflow.api.v1.schemas import UpdateCustomComponentRequest
async def get_dynamic_output_component_code():
return await asyncio.to_thread(
Path("src/backend/tests/data/dynamic_output_component.py").read_text, encoding="utf-8"
)
async def test_get_version(client: AsyncClient):
@ -23,3 +34,21 @@ async def test_get_config(client: AsyncClient):
assert "auto_saving" in result, "The dictionary must contain a key called 'auto_saving'"
assert "health_check_max_retries" in result, "The dictionary must contain a 'health_check_max_retries' key"
assert "max_file_size_upload" in result, "The dictionary must contain a key called 'max_file_size_upload'"
async def test_update_component_outputs(client: AsyncClient, logged_in_headers: dict):
code = await get_dynamic_output_component_code()
frontend_node: dict[str, Any] = {"outputs": []}
request = UpdateCustomComponentRequest(
code=code,
frontend_node=frontend_node,
field="show_output",
field_value=True,
template={},
)
response = await client.post("api/v1/custom_component/update", json=request.model_dump(), headers=logged_in_headers)
result = response.json()
assert response.status_code == status.HTTP_200_OK
output_names = [output["name"] for output in result["outputs"]]
assert "tool_output" in output_names

View file

@ -3,11 +3,7 @@ import {
SAVE_DEBOUNCE_TIME,
TITLE_ERROR_UPDATING_COMPONENT,
} from "@/constants/constants";
import {
APIClassType,
APITemplateType,
ResponseErrorDetailAPI,
} from "@/types/api";
import { APIClassType, ResponseErrorDetailAPI } from "@/types/api";
import { UseMutationResult } from "@tanstack/react-query";
import { cloneDeep, debounce } from "lodash";
@ -17,7 +13,7 @@ export const mutateTemplate = debounce(
node: APIClassType,
setNodeClass,
postTemplateValue: UseMutationResult<
APITemplateType | undefined,
APIClassType | undefined,
ResponseErrorDetailAPI,
any
>,
@ -29,7 +25,8 @@ export const mutateTemplate = debounce(
value: newValue,
});
if (newTemplate) {
newNode.template = newTemplate;
newNode.template = newTemplate.template;
newNode.outputs = newTemplate.outputs;
}
setNodeClass(newNode);
} catch (e) {

View file

@ -1,8 +1,4 @@
import {
APIClassType,
APITemplateType,
ResponseErrorDetailAPI,
} from "@/types/api";
import { APIClassType, ResponseErrorDetailAPI } from "@/types/api";
import { UseMutationResult } from "@tanstack/react-query";
import { useEffect } from "react";
import useAlertStore from "../../stores/alertStore";
@ -13,7 +9,7 @@ const useFetchDataOnMount = (
setNodeClass: (node: APIClassType) => void,
name: string,
postTemplateValue: UseMutationResult<
APITemplateType | undefined,
APIClassType | undefined,
ResponseErrorDetailAPI,
any
>,

View file

@ -80,7 +80,7 @@ const useHandleOnNewValue = ({
});
};
if (shouldUpdate && changes.value) {
if (shouldUpdate && changes.value !== undefined) {
mutateTemplate(
changes.value,
newNode,

View file

@ -1,6 +1,5 @@
import {
APIClassType,
APITemplateType,
ResponseErrorDetailAPI,
useMutationFunctionType,
} from "@/types/api";
@ -22,14 +21,14 @@ interface IPostTemplateValueParams {
export const usePostTemplateValue: useMutationFunctionType<
IPostTemplateValueParams,
IPostTemplateValue,
APITemplateType | undefined,
APIClassType,
ResponseErrorDetailAPI
> = ({ parameterId, nodeId, node }, options?) => {
const { mutate } = UseRequestProcessor();
const postTemplateValueFn = async (
payload: IPostTemplateValue,
): Promise<APITemplateType | undefined> => {
): Promise<APIClassType | undefined> => {
const template = node.template;
if (!template) return;
@ -44,11 +43,11 @@ export const usePostTemplateValue: useMutationFunctionType<
},
);
return response.data.template;
return response.data;
};
const mutation: UseMutationResult<
APITemplateType | undefined,
APIClassType,
ResponseErrorDetailAPI,
IPostTemplateValue
> = mutate(