feat: run flow component with tool mode option to run a flow as a tool (#5518)

* Update calculator.py

* Update json_cleaner.py

* [autofix.ci] apply automated fixes

* updated init

* updated format

format

* update component

sample component update

* Update flow_orchestrator.py

* test files

tests

* update in flow Orchestrator

* Update flow_orchestrator.py

solves issues with agents and session ids

* Update flow_orchestrator.py

* update to FlowOrchestrator

update to FlowOrchestrator

* draft Commit

* updated to run_flow

* updates

* refactor code

* [autofix.ci] apply automated fixes

* default_name of tool to be the name of the flow

* add tool_mode default activation

* [autofix.ci] apply automated fixes

* updates to build schema

* [autofix.ci] apply automated fixes

* update to schema and run flow

* cleanup

* Update run_flow.py

* [autofix.ci] apply automated fixes

* fix(run_flow.py): update condition to check if field_order exists in field_template for improved input validation

* Update run_flow.py

* [autofix.ci] apply automated fixes

* debug updates

* updates in flow tool outputs and tool mode

* code cleanup

* Update component_tool.py

* Update run_flow.py

* updated description

Added beta tag as per Simon's suggestion

* [autofix.ci] apply automated fixes

* Update component_tool.py

* Update schema.py

* [autofix.ci] apply automated fixes

* feat: Add ToolModeMixin to CodeInput class

* Update component.py

* updated tests

* [autofix.ci] apply automated fixes

---------

Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
Co-authored-by: Gabriel Luiz Freitas Almeida <gabriel@langflow.org>
This commit is contained in:
Edwin Jose 2025-01-17 18:12:56 -05:00 committed by GitHub
commit 040a84ed52
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
20 changed files with 437 additions and 114 deletions

View file

@ -748,7 +748,7 @@ async def custom_component_update(
)
component_node["template"] = updated_build_config
if isinstance(cc_instance, Component):
cc_instance.run_and_validate_update_outputs(
await cc_instance.run_and_validate_update_outputs(
frontend_node=component_node,
field_name=code_request.field,
field_value=code_request.field_value,

View file

@ -233,7 +233,7 @@ class LCToolsAgentComponent(LCAgentComponent):
tools_names = ", ".join([tool.name for tool in self.tools])
return tools_names
def to_toolkit(self) -> list[Tool]:
async def to_toolkit(self) -> list[Tool]:
component_toolkit = _get_component_toolkit()
tools_names = self._build_tools_names()
agent_description = self.get_tool_description()

View file

@ -11,7 +11,7 @@ from loguru import logger
from pydantic import BaseModel
from langflow.base.tools.constants import TOOL_OUTPUT_NAME
from langflow.io.schema import create_input_schema
from langflow.io.schema import create_input_schema, create_input_schema_from_dict
from langflow.schema.data import Data
from langflow.schema.message import Message
@ -25,6 +25,7 @@ if TYPE_CHECKING:
from langflow.inputs.inputs import InputTypes
from langflow.io import Output
from langflow.schema.content_block import ContentBlock
from langflow.schema.dotdict import dotdict
TOOL_TYPES_SET = {"Tool", "BaseTool", "StructuredTool"}
@ -169,7 +170,11 @@ class ComponentToolkit:
self.metadata = metadata
def get_tools(
self, tool_name: str | None = None, tool_description: str | None = None, callbacks: Callbacks | None = None
self,
tool_name: str | None = None,
tool_description: str | None = None,
callbacks: Callbacks | None = None,
flow_mode_inputs: list[dotdict] | None = None,
) -> list[BaseTool]:
tools = []
for output in self.component.outputs:
@ -183,7 +188,12 @@ class ComponentToolkit:
output_method: Callable = getattr(self.component, output.method)
args_schema = None
tool_mode_inputs = [_input for _input in self.component.inputs if getattr(_input, "tool_mode", False)]
if output.required_inputs:
if flow_mode_inputs:
args_schema = create_input_schema_from_dict(
inputs=flow_mode_inputs,
param_key="flow_tweak_data",
)
elif output.required_inputs:
inputs = [
self.component._inputs[input_name]
for input_name in output.required_inputs
@ -239,8 +249,16 @@ class ComponentToolkit:
)
if len(tools) == 1 and (tool_name or tool_description):
tool = tools[0]
tool.name = tool_name or tool.name
tool.name = _format_tool_name(str(tool_name)) or tool.name
tool.description = tool_description or tool.description
tool.tags = [tool.name]
elif flow_mode_inputs and (tool_name or tool_description):
for tool in tools:
tool.name = _format_tool_name(str(tool_name) + "_" + str(tool.name)) or tool.name
tool.description = (
str(tool_description) + " Output details: " + str(tool.description)
) or tool.description
tool.tags = [tool.name]
elif tool_name or tool_description:
msg = (
"When passing a tool name or description, there must be only one tool, "

View file

@ -0,0 +1,216 @@
from abc import abstractmethod
from loguru import logger
from langflow.base.tools.constants import TOOLS_METADATA_INPUT_NAME
from langflow.custom import Component
from langflow.custom.custom_component.component import _get_component_toolkit
from langflow.field_typing import Tool
from langflow.graph.graph.base import Graph
from langflow.graph.vertex.base import Vertex
from langflow.helpers.flow import get_flow_inputs
from langflow.inputs.inputs import (
DropdownInput,
InputTypes,
MessageInput,
)
from langflow.schema import Data, dotdict
from langflow.schema.dataframe import DataFrame
from langflow.schema.message import Message
from langflow.template import Output
class RunFlowBaseComponent(Component):
    """Base class for components that execute another Langflow flow, optionally as a tool.

    Subclasses implement ``run_flow_with_tweaks`` to actually run the selected flow.
    This base provides the shared inputs/outputs, synchronizes the component's build
    config with the target flow's own input fields (one dynamic field per flow input,
    named ``<vertex_id>~<input_name>``), and converts the component into LangChain
    tools for tool mode.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Always expose a tool output so the component can be consumed as a tool.
        self.add_tool_output = True

    _base_inputs: list[InputTypes] = [
        DropdownInput(
            name="flow_name_selected",
            display_name="Flow Name",
            info="The name of the flow to run.",
            options=[],
            real_time_refresh=True,
            refresh_button=True,
            value=None,
        ),
        MessageInput(
            name="session_id",
            display_name="Session ID",
            info="The session ID to run the flow in.",
            value="",
            advanced=True,
        ),
    ]
    _base_outputs: list[Output] = [
        Output(name="flow_outputs_data", display_name="Flow Data Output", method="data_output", hidden=True),
        Output(
            name="flow_outputs_dataframe", display_name="Flow Dataframe Output", method="dataframe_output", hidden=True
        ),
        Output(name="flow_outputs_message", display_name="Flow Message Output", method="message_output"),
    ]
    # Keys that belong to the component itself and must never be treated as
    # dynamically generated flow-input fields.
    default_keys = ["code", "_type", "flow_name_selected", "session_id"]
    FLOW_INPUTS: list[dotdict] = []
    flow_tweak_data: dict = {}

    @abstractmethod
    async def run_flow_with_tweaks(self) -> list[Data]:
        """Run the selected flow, applying the tweaks collected from this component's inputs."""

    async def data_output(self) -> Data:
        """Return the first flow output as `Data`, unwrapping a message result if needed."""
        run_outputs = await self.run_flow_with_tweaks()
        first_output = run_outputs[0]
        if isinstance(first_output, Data):
            return first_output
        # Fall back to the first message result of the first run output.
        message_data = first_output.outputs[0].results["message"].data
        return Data(data=message_data)

    async def dataframe_output(self) -> DataFrame:
        """Return the first flow output as a `DataFrame`, wrapping a message result if needed."""
        run_outputs = await self.run_flow_with_tweaks()
        first_output = run_outputs[0]
        if isinstance(first_output, DataFrame):
            return first_output
        message_data = first_output.outputs[0].results["message"].data
        return DataFrame(data=message_data if isinstance(message_data, list) else [message_data])

    async def message_output(self) -> Message:
        """Return the first flow output as a `Message`, coercing str and data-shaped results."""
        run_outputs = await self.run_flow_with_tweaks()
        message_result = run_outputs[0].outputs[0].results["message"]
        if isinstance(message_result, Message):
            return message_result
        if isinstance(message_result, str):
            return Message(content=message_result)
        return Message(content=message_result.data["text"])

    async def get_flow_names(self) -> list[str]:
        """Return the names of all flows available to the current user."""
        # TODO: resolve the flow ID from the flow name instead of matching by name.
        flows = await self.alist_flows()
        return [flow.data["name"] for flow in flows]

    async def get_flow(self, flow_name_selected: str) -> Data | None:
        """Return the flow record whose name matches, or None if no flow matches."""
        flow_datas = await self.alist_flows()
        for flow_data in flow_datas:
            if flow_data.data["name"] == flow_name_selected:
                return flow_data
        return None

    async def get_graph(self, flow_name_selected: str | None = None) -> Graph:
        """Build a `Graph` for the selected flow.

        Raises:
            ValueError: If no flow name is given or the named flow cannot be found.
        """
        if flow_name_selected:
            flow_data = await self.get_flow(flow_name_selected)
            if flow_data:
                return Graph.from_payload(flow_data.data["data"])
            msg = "Flow not found"
            raise ValueError(msg)
        # Ensure a Graph is always returned or an exception is raised.
        msg = "No valid flow JSON or flow name selected."
        raise ValueError(msg)

    def get_new_fields_from_graph(self, graph: Graph) -> list[dotdict]:
        """Collect the input fields exposed by the graph's input vertices."""
        inputs = get_flow_inputs(graph)
        return self.get_new_fields(inputs)

    def update_build_config_from_graph(self, build_config: dotdict, graph: Graph) -> dotdict:
        """Sync ``build_config`` with the graph's input fields and return the updated config.

        Stale dynamic fields are removed and one field per flow input is added.

        Raises:
            RuntimeError: If reading the graph's inputs or updating the config fails.
        """
        try:
            new_fields = self.get_new_fields_from_graph(graph)
            old_fields = self.get_old_fields(build_config, new_fields)
            self.delete_fields(build_config, old_fields)
            build_config = self.add_new_fields(build_config, new_fields)
        except Exception as e:
            msg = "Error updating build config from graph"
            logger.exception(msg)
            raise RuntimeError(msg) from e
        # Fix: callers assign this method's result, so the updated config must be returned.
        return build_config

    def get_new_fields(self, inputs_vertex: list[Vertex]) -> list[dotdict]:
        """Build one dynamic field per input of each input vertex.

        Fields are named ``<vertex_id>~<input_name>`` and marked tool-mode unless
        the underlying flow input is advanced.
        """
        new_fields: list[dotdict] = []
        for vertex in inputs_vertex:
            field_template = vertex.data.get("node", {}).get("template", {})
            field_order = vertex.data.get("node", {}).get("field_order", [])
            if field_order and field_template:
                new_vertex_inputs = [
                    dotdict(
                        {
                            **field_template[input_name],
                            "display_name": vertex.display_name + " - " + field_template[input_name]["display_name"],
                            "name": f"{vertex.id}~{input_name}",
                            "tool_mode": not (field_template[input_name].get("advanced", False)),
                        }
                    )
                    for input_name in field_order
                    if input_name in field_template
                ]
                new_fields += new_vertex_inputs
        return new_fields

    def add_new_fields(self, build_config: dotdict, new_fields: list[dotdict]) -> dotdict:
        """Add new fields to the build_config."""
        for field in new_fields:
            build_config[field["name"]] = field
        return build_config

    def delete_fields(self, build_config: dotdict, fields: dict | list[str]) -> None:
        """Delete specified fields from build_config."""
        if isinstance(fields, dict):
            fields = list(fields.keys())
        for field in fields:
            build_config.pop(field, None)

    def get_old_fields(self, build_config: dotdict, new_fields: list[dotdict]) -> list[str]:
        """Get fields that are in build_config but not in new_fields."""
        return [
            field
            for field in build_config
            if field not in [new_field["name"] for new_field in new_fields] + self.default_keys
        ]

    async def get_required_data(self, flow_name_selected):
        """Return ``(description, tool_mode_fields)`` for the named flow, or None if not found."""
        self.flow_data = await self.alist_flows()
        for flow_data in self.flow_data:
            if flow_data.data["name"] == flow_name_selected:
                graph = Graph.from_payload(flow_data.data["data"])
                new_fields = self.get_new_fields_from_graph(graph)
                new_fields = self.update_input_types(new_fields)
                return flow_data.data["description"], [field for field in new_fields if field.get("tool_mode") is True]
        return None

    def update_input_types(self, fields: list[dotdict]) -> list[dotdict]:
        """Normalize missing ``input_types`` to an empty list on every field."""
        for field in fields:
            if isinstance(field, dict):
                if field.get("input_types") is None:
                    field["input_types"] = []
            elif hasattr(field, "input_types") and field.input_types is None:
                field.input_types = []
        return fields

    async def to_toolkit(self) -> list[Tool]:
        """Convert this component into LangChain tools that run the selected flow.

        Raises:
            ValueError: If the selected flow cannot be found.
        """
        component_toolkit = _get_component_toolkit()
        required_data = await self.get_required_data(self.flow_name_selected)
        # Fix: guard against a missing flow instead of tuple-unpacking None (TypeError).
        if required_data is None:
            msg = f"Flow '{self.flow_name_selected}' not found"
            raise ValueError(msg)
        flow_description, tool_mode_inputs = required_data
        # Convert list of dicts to list of dotdicts.
        tool_mode_inputs = [dotdict(field) for field in tool_mode_inputs]
        tools = component_toolkit(component=self).get_tools(
            tool_name=f"{self.flow_name_selected}_tool",
            tool_description=(
                f"Tool designed to execute the flow '{self.flow_name_selected}'. Flow details: {flow_description}."
            ),
            callbacks=self.get_langchain_callbacks(),
            flow_mode_inputs=tool_mode_inputs,
        )
        if hasattr(self, TOOLS_METADATA_INPUT_NAME):
            tools = component_toolkit(component=self, metadata=self.tools_metadata).update_tools_metadata(tools=tools)
        return tools

View file

@ -86,7 +86,8 @@ class AgentComponent(ToolCallingAgentComponent):
if not isinstance(self.tools, list): # type: ignore[has-type]
self.tools = []
# Convert CurrentDateComponent to a StructuredTool
current_date_tool = CurrentDateComponent().to_toolkit()[0]
current_date_tool = (await CurrentDateComponent().to_toolkit()).pop(0)
# current_date_tool = CurrentDateComponent().to_toolkit()[0]
if isinstance(current_date_tool, StructuredTool):
self.tools.append(current_date_tool)
else:

View file

@ -1,74 +1,69 @@
from typing import TYPE_CHECKING, Any
from typing import Any
from typing_extensions import override
from loguru import logger
from langflow.base.flow_processing.utils import build_data_from_run_outputs
from langflow.custom import Component
from langflow.io import DropdownInput, MessageTextInput, NestedDictInput, Output
from langflow.schema import Data, dotdict
if TYPE_CHECKING:
from langflow.graph.schema import RunOutputs
from langflow.base.tools.run_flow import RunFlowBaseComponent
from langflow.helpers.flow import run_flow
from langflow.schema import dotdict
class RunFlowComponent(Component):
class RunFlowComponent(RunFlowBaseComponent):
display_name = "Run Flow"
description = "A component to run a flow."
description = "Creates a tool component from a Flow that takes all its inputs and runs it."
beta = True
name = "RunFlow"
legacy: bool = True
icon = "workflow"
icon = "Workflow"
async def get_flow_names(self) -> list[str]:
flow_data = await self.alist_flows()
return [flow_data.data["name"] for flow_data in flow_data]
inputs = RunFlowBaseComponent._base_inputs
outputs = RunFlowBaseComponent._base_outputs
@override
async def update_build_config(self, build_config: dotdict, field_value: Any, field_name: str | None = None):
if field_name == "flow_name":
build_config["flow_name"]["options"] = await self.get_flow_names()
if field_name == "flow_name_selected":
build_config["flow_name_selected"]["options"] = await self.get_flow_names()
missing_keys = [key for key in self.default_keys if key not in build_config]
if missing_keys:
msg = f"Missing required keys in build_config: {missing_keys}"
raise ValueError(msg)
if field_value is not None:
try:
graph = await self.get_graph(field_value)
build_config = self.update_build_config_from_graph(build_config, graph)
except Exception as e:
msg = f"Error building graph for flow {field_value}"
logger.exception(msg)
raise RuntimeError(msg) from e
return build_config
inputs = [
MessageTextInput(
name="input_value",
display_name="Input Value",
info="The input value to be processed by the flow.",
),
DropdownInput(
name="flow_name",
display_name="Flow Name",
info="The name of the flow to run.",
options=[],
refresh_button=True,
),
NestedDictInput(
name="tweaks",
display_name="Tweaks",
info="Tweaks to apply to the flow.",
),
]
async def run_flow_with_tweaks(self):
tweaks: dict = {}
outputs = [
Output(display_name="Run Outputs", name="run_outputs", method="generate_results"),
]
flow_name_selected = self._attributes.get("flow_name_selected")
parsed_flow_tweak_data = self._attributes.get("flow_tweak_data", {})
if not isinstance(parsed_flow_tweak_data, dict):
parsed_flow_tweak_data = parsed_flow_tweak_data.dict()
async def generate_results(self) -> list[Data]:
if "flow_name" not in self._attributes or not self._attributes["flow_name"]:
msg = "Flow name is required"
raise ValueError(msg)
flow_name = self._attributes["flow_name"]
results: list[RunOutputs | None] = await self.run_flow(
inputs={"input_value": self.input_value}, flow_name=flow_name, tweaks=self.tweaks
)
if isinstance(results, list):
data = []
for result in results:
if result:
data.extend(build_data_from_run_outputs(result))
if parsed_flow_tweak_data != {}:
for field in parsed_flow_tweak_data:
if "~" in field:
[node, name] = field.split("~")
if node not in tweaks:
tweaks[node] = {}
tweaks[node][name] = parsed_flow_tweak_data[field]
else:
data = build_data_from_run_outputs()(results)
self.status = data
return data
for field in self._attributes:
if field not in self.default_keys and "~" in field:
[node, name] = field.split("~")
if node not in tweaks:
tweaks[node] = {}
tweaks[node][name] = self._attributes[field]
# import pdb; pdb.set_trace()
return await run_flow(
inputs=None,
output_type="all",
flow_id=None,
flow_name=flow_name_selected,
tweaks=tweaks,
user_id=str(self.user_id),
# run_id=self.graph.run_id,
session_id=self.graph.session_id or self.session_id,
)

View file

@ -406,14 +406,15 @@ class Component(CustomComponent):
self._validate_inputs(params)
self._validate_outputs()
def run_and_validate_update_outputs(self, frontend_node: dict, field_name: str, field_value: Any):
async def run_and_validate_update_outputs(self, frontend_node: dict, field_name: str, field_value: Any):
frontend_node = self.update_outputs(frontend_node, field_name, field_value)
if field_name == "tool_mode" or frontend_node.get("tool_mode"):
is_tool_mode = field_value or frontend_node.get("tool_mode")
frontend_node["outputs"] = [self._build_tool_output()] if is_tool_mode else frontend_node["outputs"]
if is_tool_mode:
frontend_node.setdefault("template", {})
frontend_node["template"][TOOLS_METADATA_INPUT_NAME] = self._build_tools_metadata_input().to_dict()
tools_metadata_input = await self._build_tools_metadata_input()
frontend_node["template"][TOOLS_METADATA_INPUT_NAME] = tools_metadata_input.to_dict()
elif "template" in frontend_node:
frontend_node["template"].pop(TOOLS_METADATA_INPUT_NAME, None)
self.tools_metadata = frontend_node.get("template", {}).get(TOOLS_METADATA_INPUT_NAME, {}).get("value")
@ -994,7 +995,7 @@ class Component(CustomComponent):
def _get_fallback_input(self, **kwargs):
return Input(**kwargs)
def to_toolkit(self) -> list[Tool]:
async def to_toolkit(self) -> list[Tool]:
component_toolkit = _get_component_toolkit()
tools = component_toolkit(component=self).get_tools(callbacks=self.get_langchain_callbacks())
if hasattr(self, TOOLS_METADATA_INPUT_NAME):
@ -1212,8 +1213,8 @@ class Component(CustomComponent):
def _build_tool_output(self) -> Output:
return Output(name=TOOL_OUTPUT_NAME, display_name=TOOL_OUTPUT_DISPLAY_NAME, method="to_toolkit", types=["Tool"])
def _build_tools_metadata_input(self):
tools = self.to_toolkit()
async def _build_tools_metadata_input(self):
tools = await self.to_toolkit()
tool_data = (
self.tools_metadata
if hasattr(self, TOOLS_METADATA_INPUT_NAME)

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

View file

@ -86,7 +86,7 @@ class PromptInput(BaseInputMixin, ListableInputMixin, InputTraceMixin, ToolModeM
field_type: SerializableFieldTypes = FieldTypes.PROMPT
class CodeInput(BaseInputMixin, ListableInputMixin, InputTraceMixin):
class CodeInput(BaseInputMixin, ListableInputMixin, InputTraceMixin, ToolModeMixin):
field_type: SerializableFieldTypes = FieldTypes.CODE

View file

@ -3,6 +3,7 @@ from typing import TYPE_CHECKING, Literal
from pydantic import BaseModel, Field, create_model
from langflow.inputs.inputs import FieldTypes
from langflow.schema.dotdict import dotdict
_convert_field_type_to_type: dict[FieldTypes, type] = {
FieldTypes.TEXT: str,
@ -62,3 +63,51 @@ def create_input_schema(inputs: list["InputTypes"]) -> type[BaseModel]:
model = create_model("InputSchema", **fields)
model.model_rebuild()
return model
def create_input_schema_from_dict(inputs: list[dotdict], param_key: str | None = None) -> type[BaseModel]:
    """Build a Pydantic input schema from serialized input-field dicts.

    Args:
        inputs: Serialized input fields (dotdicts) with at least ``name``/``type``;
            ``options``, ``is_list``, ``info``, ``required`` and ``value`` are honored.
        param_key: When given, all fields are nested under a single parameter of
            this name (an inner model), instead of appearing at the top level.

    Returns:
        A dynamically created ``InputSchema`` model class.

    Raises:
        TypeError: If ``inputs`` is not a list.
        ValueError: If a field has neither a name nor a display_name.
    """
    if not isinstance(inputs, list):
        msg = "inputs must be a list of Inputs"
        raise TypeError(msg)
    fields = {}

    for input_model in inputs:
        # Create a Pydantic Field for each input field.
        field_type = input_model.type
        if hasattr(input_model, "options") and isinstance(input_model.options, list) and input_model.options:
            # Fix: subscript Literal with the options tuple directly instead of
            # building and eval()-ing a "Literal[...]" source string.
            field_type = Literal[tuple(input_model.options)]

        if hasattr(input_model, "is_list") and input_model.is_list:
            field_type = list[field_type]  # type: ignore[valid-type]

        # Human-readable title shown in the schema.
        if input_model.name:
            name = input_model.name.replace("_", " ").title()
        elif input_model.display_name:
            name = input_model.display_name
        else:
            msg = "Input name or display_name is required"
            raise ValueError(msg)

        field_dict = {
            "title": name,
            "description": input_model.info or "",
        }

        if input_model.required is False:
            # Optional fields default to the field's current value.
            field_dict["default"] = input_model.value  # type: ignore[assignment]

        pydantic_field = Field(**field_dict)

        fields[input_model.name] = (field_type, pydantic_field)

    # Wrap fields in a dictionary with the key as param_key.
    if param_key is not None:
        # Create an inner model with the fields, then nest it under param_key.
        inner_model = create_model("InnerModel", **fields)
        model = create_model("InputSchema", **{param_key: (inner_model, ...)})
    else:
        # Create and return the flat InputSchema model.
        model = create_model("InputSchema", **fields)
    model.model_rebuild()
    return model

View file

@ -1,3 +1,5 @@
import logging
from fastapi.encoders import jsonable_encoder
from pydantic import BaseModel, field_serializer
from pydantic.v1 import BaseModel as V1BaseModel
@ -5,6 +7,8 @@ from pydantic_core import PydanticSerializationError
from langflow.schema.log import LoggableType
logger = logging.getLogger(__name__)
class Log(BaseModel):
name: str
@ -13,21 +17,23 @@ class Log(BaseModel):
@field_serializer("message")
def serialize_message(self, value):
# We need to make sure everything inside the message has been serialized
if isinstance(value, dict):
return {key: self.serialize_message(value[key]) for key in value}
if isinstance(value, list):
return [self.serialize_message(item) for item in value]
# To json is for LangChain Serializable objects
if hasattr(value, "dict") and isinstance(value, V1BaseModel):
# This is for Pydantic V1 models
return value.dict()
if hasattr(value, "to_json"):
return value.to_json()
if isinstance(value, BaseModel):
return value.model_dump(exclude_none=True)
try:
# We need to make sure everything inside the message has been serialized
if isinstance(value, dict):
return {key: self.serialize_message(value[key]) for key in value}
if isinstance(value, list):
return [self.serialize_message(item) for item in value]
# To json is for LangChain Serializable objects
if hasattr(value, "dict") and isinstance(value, V1BaseModel):
# This is for Pydantic V1 models
return value.dict()
if hasattr(value, "to_json"):
return value.to_json()
if isinstance(value, BaseModel):
return value.model_dump(exclude_none=True)
value = jsonable_encoder(value)
except UnicodeDecodeError:
return str(value) # Fallback to string representation
except PydanticSerializationError:
return str(value)
return value

View file

@ -0,0 +1,33 @@
from langflow.io.schema import create_input_schema_from_dict
from langflow.schema.dotdict import dotdict
def test_create_schema():
    """A schema can be built from a single serialized flow-input field."""
    raw_field = {
        "_input_type": "MultilineInput",
        "advanced": False,
        "display_name": "Chat Input - Text",
        "dynamic": False,
        "info": "Message to be passed as input.",
        "input_types": ["Message"],
        "list": False,
        "load_from_db": False,
        "multiline": True,
        "name": "ChatInput-xNZ0a|input_value",
        "placeholder": "",
        "required": False,
        "show": True,
        "title_case": False,
        "tool_mode": True,
        "trace_as_input": True,
        "trace_as_metadata": True,
        "type": "str",
        "value": "add 1+1",
    }
    # The schema builder expects dotdict entries (attribute-style access).
    schema = create_input_schema_from_dict([dotdict(raw_field)])
    assert schema is not None

View file

@ -5,11 +5,11 @@ from langflow.components.agents.agent import AgentComponent
from langflow.components.tools.calculator import CalculatorToolComponent
def test_component_to_toolkit():
async def test_component_to_toolkit():
calculator_component = CalculatorToolComponent()
agent_component = AgentComponent().set(tools=[calculator_component])
tools = agent_component.to_toolkit()
tools = await agent_component.to_toolkit()
assert len(tools) == 1
tool = tools[0]

View file

@ -4,7 +4,7 @@ from langflow.custom.custom_component.component import Component
class TestComponentOutputs:
def test_run_and_validate_update_outputs_tool_mode(self):
async def test_run_and_validate_update_outputs_tool_mode(self):
"""Test run_and_validate_update_outputs with tool_mode field."""
class TestComponent(Component):
@ -33,7 +33,7 @@ class TestComponentOutputs:
}
# Test enabling tool mode
updated_node = component.run_and_validate_update_outputs(
updated_node = await component.run_and_validate_update_outputs(
frontend_node=frontend_node.copy(), # Use a copy to avoid modifying original
field_name="tool_mode",
field_value=True,
@ -45,7 +45,7 @@ class TestComponentOutputs:
assert updated_node["outputs"][0]["display_name"] == TOOL_OUTPUT_DISPLAY_NAME
# Test disabling tool mode - use the original frontend node
updated_node = component.run_and_validate_update_outputs(
updated_node = await component.run_and_validate_update_outputs(
frontend_node={"outputs": original_outputs.copy()}, # Use original outputs
field_name="tool_mode",
field_value=False,
@ -60,7 +60,7 @@ class TestComponentOutputs:
assert "types" in updated_node["outputs"][0]
assert "selected" in updated_node["outputs"][0]
def test_run_and_validate_update_outputs_invalid_output(self):
async def test_run_and_validate_update_outputs_invalid_output(self):
"""Test run_and_validate_update_outputs with invalid output structure."""
class TestComponent(Component):
@ -74,11 +74,11 @@ class TestComponentOutputs:
# Test validation fails for invalid output
with pytest.raises(ValueError, match="Invalid output: 1 validation error for Output"):
component.run_and_validate_update_outputs(
await component.run_and_validate_update_outputs(
frontend_node=frontend_node, field_name="some_field", field_value="some_value"
)
def test_run_and_validate_update_outputs_custom_update(self):
async def test_run_and_validate_update_outputs_custom_update(self):
"""Test run_and_validate_update_outputs with custom update logic."""
class CustomComponent(Component):
@ -111,7 +111,7 @@ class TestComponentOutputs:
frontend_node = {"outputs": []}
# Test custom update logic
updated_node = component.run_and_validate_update_outputs(
updated_node = await component.run_and_validate_update_outputs(
frontend_node=frontend_node, field_name="custom_field", field_value="custom_value"
)
@ -122,14 +122,14 @@ class TestComponentOutputs:
assert "types" in updated_node["outputs"][0]
assert "selected" in updated_node["outputs"][0]
def test_run_and_validate_update_outputs_with_existing_tool_output(self):
async def test_run_and_validate_update_outputs_with_existing_tool_output(self):
"""Test run_and_validate_update_outputs when tool output already exists."""
class TestComponent(Component):
def build(self) -> None:
pass
def to_toolkit(self) -> list:
async def to_toolkit(self) -> list:
"""Method that returns a list of tools."""
return []
@ -154,7 +154,7 @@ class TestComponentOutputs:
}
# Test enabling tool mode doesn't duplicate tool output
updated_node = component.run_and_validate_update_outputs(
updated_node = await component.run_and_validate_update_outputs(
frontend_node=frontend_node, field_name="tool_mode", field_value=True
)
@ -164,7 +164,7 @@ class TestComponentOutputs:
assert "types" in updated_node["outputs"][0]
assert "selected" in updated_node["outputs"][0]
def test_run_and_validate_update_outputs_with_multiple_outputs(self):
async def test_run_and_validate_update_outputs_with_multiple_outputs(self):
"""Test run_and_validate_update_outputs with multiple outputs."""
class TestComponent(Component):
@ -203,7 +203,7 @@ class TestComponentOutputs:
frontend_node = {"outputs": []}
# Test adding multiple outputs
updated_node = component.run_and_validate_update_outputs(
updated_node = await component.run_and_validate_update_outputs(
frontend_node=frontend_node, field_name="add_output", field_value=True
)
@ -217,7 +217,7 @@ class TestComponentOutputs:
assert set(output["types"]) == {"Text"}
assert output["selected"] == "Text"
def test_run_and_validate_update_outputs_output_validation(self):
async def test_run_and_validate_update_outputs_output_validation(self):
"""Test output validation in run_and_validate_update_outputs."""
class TestComponent(Component):
@ -236,10 +236,14 @@ class TestComponentOutputs:
}
with pytest.raises(AttributeError, match="nonexistent_method not found in TestComponent"):
component.run_and_validate_update_outputs(frontend_node=invalid_node, field_name="test", field_value=True)
await component.run_and_validate_update_outputs(
frontend_node=invalid_node, field_name="test", field_value=True
)
# Test missing method case
invalid_node = {"outputs": [{"name": "test", "type": "str", "display_name": "Test"}]}
with pytest.raises(ValueError, match="Output test does not have a method"):
component.run_and_validate_update_outputs(frontend_node=invalid_node, field_name="test", field_value=True)
await component.run_and_validate_update_outputs(
frontend_node=invalid_node, field_name="test", field_value=True
)