feat: run flow component with tool mode option to run a flow as a tool (#5518)
* Update calculator.py * Update json_cleaner.py * [autofix.ci] apply automated fixes * updated init * updated format format * update component sample component update * Update flow_orchestrator.py * test files tests * update in flow Orchestrator * Update flow_orchestrator.py solves issues with agents and session ids * Update flow_orchestrator.py * update to FlowOrchetstor update to FlowOrchetstor * draft Commit * updated to run_flow * updates * refactor code * [autofix.ci] apply automated fixes * default_name of tool to be the name of the flow * add tool_mode default activation * [autofix.ci] apply automated fixes * updates to build schema * [autofix.ci] apply automated fixes * update to schema and run flow * cleanup * Update run_flow.py * [autofix.ci] apply automated fixes * fix(run_flow.py): update condition to check if field_order exists in field_template for improved input validation * Update run_flow.py * [autofix.ci] apply automated fixes * debug updates * updates in flow tool outputs and tool mode * code cleanup * Update component_tool.py * Update run_flow.py * updated description added beta tag as per Simons Suggestion * [autofix.ci] apply automated fixes * Update component_tool.py * Update schema.py * [autofix.ci] apply automated fixes * feat: Add ToolModeMixin to CodeInput class * Update component.py * updated tests * [autofix.ci] apply automated fixes --------- Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> Co-authored-by: Gabriel Luiz Freitas Almeida <gabriel@langflow.org>
This commit is contained in:
parent
f4a7d9d797
commit
040a84ed52
20 changed files with 437 additions and 114 deletions
|
|
@ -748,7 +748,7 @@ async def custom_component_update(
|
|||
)
|
||||
component_node["template"] = updated_build_config
|
||||
if isinstance(cc_instance, Component):
|
||||
cc_instance.run_and_validate_update_outputs(
|
||||
await cc_instance.run_and_validate_update_outputs(
|
||||
frontend_node=component_node,
|
||||
field_name=code_request.field,
|
||||
field_value=code_request.field_value,
|
||||
|
|
|
|||
|
|
@ -233,7 +233,7 @@ class LCToolsAgentComponent(LCAgentComponent):
|
|||
tools_names = ", ".join([tool.name for tool in self.tools])
|
||||
return tools_names
|
||||
|
||||
def to_toolkit(self) -> list[Tool]:
|
||||
async def to_toolkit(self) -> list[Tool]:
|
||||
component_toolkit = _get_component_toolkit()
|
||||
tools_names = self._build_tools_names()
|
||||
agent_description = self.get_tool_description()
|
||||
|
|
|
|||
|
|
@ -11,7 +11,7 @@ from loguru import logger
|
|||
from pydantic import BaseModel
|
||||
|
||||
from langflow.base.tools.constants import TOOL_OUTPUT_NAME
|
||||
from langflow.io.schema import create_input_schema
|
||||
from langflow.io.schema import create_input_schema, create_input_schema_from_dict
|
||||
from langflow.schema.data import Data
|
||||
from langflow.schema.message import Message
|
||||
|
||||
|
|
@ -25,6 +25,7 @@ if TYPE_CHECKING:
|
|||
from langflow.inputs.inputs import InputTypes
|
||||
from langflow.io import Output
|
||||
from langflow.schema.content_block import ContentBlock
|
||||
from langflow.schema.dotdict import dotdict
|
||||
|
||||
|
||||
TOOL_TYPES_SET = {"Tool", "BaseTool", "StructuredTool"}
|
||||
|
|
@ -169,7 +170,11 @@ class ComponentToolkit:
|
|||
self.metadata = metadata
|
||||
|
||||
def get_tools(
|
||||
self, tool_name: str | None = None, tool_description: str | None = None, callbacks: Callbacks | None = None
|
||||
self,
|
||||
tool_name: str | None = None,
|
||||
tool_description: str | None = None,
|
||||
callbacks: Callbacks | None = None,
|
||||
flow_mode_inputs: list[dotdict] | None = None,
|
||||
) -> list[BaseTool]:
|
||||
tools = []
|
||||
for output in self.component.outputs:
|
||||
|
|
@ -183,7 +188,12 @@ class ComponentToolkit:
|
|||
output_method: Callable = getattr(self.component, output.method)
|
||||
args_schema = None
|
||||
tool_mode_inputs = [_input for _input in self.component.inputs if getattr(_input, "tool_mode", False)]
|
||||
if output.required_inputs:
|
||||
if flow_mode_inputs:
|
||||
args_schema = create_input_schema_from_dict(
|
||||
inputs=flow_mode_inputs,
|
||||
param_key="flow_tweak_data",
|
||||
)
|
||||
elif output.required_inputs:
|
||||
inputs = [
|
||||
self.component._inputs[input_name]
|
||||
for input_name in output.required_inputs
|
||||
|
|
@ -239,8 +249,16 @@ class ComponentToolkit:
|
|||
)
|
||||
if len(tools) == 1 and (tool_name or tool_description):
|
||||
tool = tools[0]
|
||||
tool.name = tool_name or tool.name
|
||||
tool.name = _format_tool_name(str(tool_name)) or tool.name
|
||||
tool.description = tool_description or tool.description
|
||||
tool.tags = [tool.name]
|
||||
elif flow_mode_inputs and (tool_name or tool_description):
|
||||
for tool in tools:
|
||||
tool.name = _format_tool_name(str(tool_name) + "_" + str(tool.name)) or tool.name
|
||||
tool.description = (
|
||||
str(tool_description) + " Output details: " + str(tool.description)
|
||||
) or tool.description
|
||||
tool.tags = [tool.name]
|
||||
elif tool_name or tool_description:
|
||||
msg = (
|
||||
"When passing a tool name or description, there must be only one tool, "
|
||||
|
|
|
|||
216
src/backend/base/langflow/base/tools/run_flow.py
Normal file
216
src/backend/base/langflow/base/tools/run_flow.py
Normal file
|
|
@ -0,0 +1,216 @@
|
|||
from abc import abstractmethod
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from langflow.base.tools.constants import TOOLS_METADATA_INPUT_NAME
|
||||
from langflow.custom import Component
|
||||
from langflow.custom.custom_component.component import _get_component_toolkit
|
||||
from langflow.field_typing import Tool
|
||||
from langflow.graph.graph.base import Graph
|
||||
from langflow.graph.vertex.base import Vertex
|
||||
from langflow.helpers.flow import get_flow_inputs
|
||||
from langflow.inputs.inputs import (
|
||||
DropdownInput,
|
||||
InputTypes,
|
||||
MessageInput,
|
||||
)
|
||||
from langflow.schema import Data, dotdict
|
||||
from langflow.schema.dataframe import DataFrame
|
||||
from langflow.schema.message import Message
|
||||
from langflow.template import Output
|
||||
|
||||
|
||||
class RunFlowBaseComponent(Component):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.add_tool_output = True
|
||||
|
||||
_base_inputs: list[InputTypes] = [
|
||||
DropdownInput(
|
||||
name="flow_name_selected",
|
||||
display_name="Flow Name",
|
||||
info="The name of the flow to run.",
|
||||
options=[],
|
||||
real_time_refresh=True,
|
||||
refresh_button=True,
|
||||
value=None,
|
||||
),
|
||||
MessageInput(
|
||||
name="session_id",
|
||||
display_name="Session ID",
|
||||
info="The session ID to run the flow in.",
|
||||
value="",
|
||||
advanced=True,
|
||||
),
|
||||
]
|
||||
_base_outputs: list[Output] = [
|
||||
Output(name="flow_outputs_data", display_name="Flow Data Output", method="data_output", hidden=True),
|
||||
Output(
|
||||
name="flow_outputs_dataframe", display_name="Flow Dataframe Output", method="dataframe_output", hidden=True
|
||||
),
|
||||
Output(name="flow_outputs_message", display_name="Flow Message Output", method="message_output"),
|
||||
]
|
||||
default_keys = ["code", "_type", "flow_name_selected", "session_id"]
|
||||
FLOW_INPUTS: list[dotdict] = []
|
||||
flow_tweak_data: dict = {}
|
||||
|
||||
@abstractmethod
|
||||
async def run_flow_with_tweaks(self) -> list[Data]:
|
||||
"""Run the flow with tweaks."""
|
||||
|
||||
async def data_output(self) -> Data:
|
||||
"""Return the data output."""
|
||||
run_outputs = await self.run_flow_with_tweaks()
|
||||
first_output = run_outputs[0]
|
||||
|
||||
if isinstance(first_output, Data):
|
||||
return first_output
|
||||
|
||||
message_data = first_output.outputs[0].results["message"].data
|
||||
return Data(data=message_data)
|
||||
|
||||
async def dataframe_output(self) -> DataFrame:
|
||||
"""Return the dataframe output."""
|
||||
run_outputs = await self.run_flow_with_tweaks()
|
||||
first_output = run_outputs[0]
|
||||
|
||||
if isinstance(first_output, DataFrame):
|
||||
return first_output
|
||||
|
||||
message_data = first_output.outputs[0].results["message"].data
|
||||
return DataFrame(data=message_data if isinstance(message_data, list) else [message_data])
|
||||
|
||||
async def message_output(self) -> Message:
|
||||
"""Return the message output."""
|
||||
run_outputs = await self.run_flow_with_tweaks()
|
||||
message_result = run_outputs[0].outputs[0].results["message"]
|
||||
|
||||
if isinstance(message_result, Message):
|
||||
return message_result
|
||||
|
||||
if isinstance(message_result, str):
|
||||
return Message(content=message_result)
|
||||
|
||||
return Message(content=message_result.data["text"])
|
||||
|
||||
async def get_flow_names(self) -> list[str]:
|
||||
# TODO: get flfow ID with flow name
|
||||
flow_data = await self.alist_flows()
|
||||
return [flow_data.data["name"] for flow_data in flow_data]
|
||||
|
||||
async def get_flow(self, flow_name_selected: str) -> Data | None:
|
||||
# get flow from flow id
|
||||
flow_datas = await self.alist_flows()
|
||||
for flow_data in flow_datas:
|
||||
if flow_data.data["name"] == flow_name_selected:
|
||||
return flow_data
|
||||
return None
|
||||
|
||||
async def get_graph(self, flow_name_selected: str | None = None) -> Graph:
|
||||
if flow_name_selected:
|
||||
flow_data = await self.get_flow(flow_name_selected)
|
||||
if flow_data:
|
||||
return Graph.from_payload(flow_data.data["data"])
|
||||
msg = "Flow not found"
|
||||
raise ValueError(msg)
|
||||
# Ensure a Graph is always returned or an exception is raised
|
||||
msg = "No valid flow JSON or flow name selected."
|
||||
raise ValueError(msg)
|
||||
|
||||
def get_new_fields_from_graph(self, graph: Graph) -> list[dotdict]:
|
||||
inputs = get_flow_inputs(graph)
|
||||
return self.get_new_fields(inputs)
|
||||
|
||||
def update_build_config_from_graph(self, build_config: dotdict, graph: Graph):
|
||||
try:
|
||||
# Get all inputs from the graph
|
||||
new_fields = self.get_new_fields_from_graph(graph)
|
||||
old_fields = self.get_old_fields(build_config, new_fields)
|
||||
self.delete_fields(build_config, old_fields)
|
||||
build_config = self.add_new_fields(build_config, new_fields)
|
||||
|
||||
except Exception as e:
|
||||
msg = "Error updating build config from graph"
|
||||
logger.exception(msg)
|
||||
raise RuntimeError(msg) from e
|
||||
|
||||
def get_new_fields(self, inputs_vertex: list[Vertex]) -> list[dotdict]:
|
||||
new_fields: list[dotdict] = []
|
||||
|
||||
for vertex in inputs_vertex:
|
||||
field_template = vertex.data.get("node", {}).get("template", {})
|
||||
field_order = vertex.data.get("node", {}).get("field_order", [])
|
||||
if field_order and field_template:
|
||||
new_vertex_inputs = [
|
||||
dotdict(
|
||||
{
|
||||
**field_template[input_name],
|
||||
"display_name": vertex.display_name + " - " + field_template[input_name]["display_name"],
|
||||
"name": f"{vertex.id}~{input_name}",
|
||||
"tool_mode": not (field_template[input_name].get("advanced", False)),
|
||||
}
|
||||
)
|
||||
for input_name in field_order
|
||||
if input_name in field_template
|
||||
]
|
||||
new_fields += new_vertex_inputs
|
||||
return new_fields
|
||||
|
||||
def add_new_fields(self, build_config: dotdict, new_fields: list[dotdict]) -> dotdict:
|
||||
"""Add new fields to the build_config."""
|
||||
for field in new_fields:
|
||||
build_config[field["name"]] = field
|
||||
return build_config
|
||||
|
||||
def delete_fields(self, build_config: dotdict, fields: dict | list[str]) -> None:
|
||||
"""Delete specified fields from build_config."""
|
||||
if isinstance(fields, dict):
|
||||
fields = list(fields.keys())
|
||||
for field in fields:
|
||||
build_config.pop(field, None)
|
||||
|
||||
def get_old_fields(self, build_config: dotdict, new_fields: list[dotdict]) -> list[str]:
|
||||
"""Get fields that are in build_config but not in new_fields."""
|
||||
return [
|
||||
field
|
||||
for field in build_config
|
||||
if field not in [new_field["name"] for new_field in new_fields] + self.default_keys
|
||||
]
|
||||
|
||||
async def get_required_data(self, flow_name_selected):
|
||||
self.flow_data = await self.alist_flows()
|
||||
for flow_data in self.flow_data:
|
||||
if flow_data.data["name"] == flow_name_selected:
|
||||
graph = Graph.from_payload(flow_data.data["data"])
|
||||
new_fields = self.get_new_fields_from_graph(graph)
|
||||
new_fields = self.update_input_types(new_fields)
|
||||
|
||||
return flow_data.data["description"], [field for field in new_fields if field.get("tool_mode") is True]
|
||||
return None
|
||||
|
||||
def update_input_types(self, fields: list[dotdict]) -> list[dotdict]:
|
||||
for field in fields:
|
||||
if isinstance(field, dict):
|
||||
if field.get("input_types") is None:
|
||||
field["input_types"] = []
|
||||
elif hasattr(field, "input_types") and field.input_types is None:
|
||||
field.input_types = []
|
||||
return fields
|
||||
|
||||
async def to_toolkit(self) -> list[Tool]:
|
||||
component_toolkit = _get_component_toolkit()
|
||||
flow_description, tool_mode_inputs = await self.get_required_data(self.flow_name_selected)
|
||||
# # convert list of dicts to list of dotdicts
|
||||
tool_mode_inputs = [dotdict(field) for field in tool_mode_inputs]
|
||||
tools = component_toolkit(component=self).get_tools(
|
||||
tool_name=f"{self.flow_name_selected}_tool",
|
||||
tool_description=(
|
||||
f"Tool designed to execute the flow '{self.flow_name_selected}'. Flow details: {flow_description}."
|
||||
),
|
||||
callbacks=self.get_langchain_callbacks(),
|
||||
flow_mode_inputs=tool_mode_inputs,
|
||||
)
|
||||
if hasattr(self, TOOLS_METADATA_INPUT_NAME):
|
||||
tools = component_toolkit(component=self, metadata=self.tools_metadata).update_tools_metadata(tools=tools)
|
||||
# self.status = tools
|
||||
return tools
|
||||
|
|
@ -86,7 +86,8 @@ class AgentComponent(ToolCallingAgentComponent):
|
|||
if not isinstance(self.tools, list): # type: ignore[has-type]
|
||||
self.tools = []
|
||||
# Convert CurrentDateComponent to a StructuredTool
|
||||
current_date_tool = CurrentDateComponent().to_toolkit()[0]
|
||||
current_date_tool = (await CurrentDateComponent().to_toolkit()).pop(0)
|
||||
# current_date_tool = CurrentDateComponent().to_toolkit()[0]
|
||||
if isinstance(current_date_tool, StructuredTool):
|
||||
self.tools.append(current_date_tool)
|
||||
else:
|
||||
|
|
|
|||
|
|
@ -1,74 +1,69 @@
|
|||
from typing import TYPE_CHECKING, Any
|
||||
from typing import Any
|
||||
|
||||
from typing_extensions import override
|
||||
from loguru import logger
|
||||
|
||||
from langflow.base.flow_processing.utils import build_data_from_run_outputs
|
||||
from langflow.custom import Component
|
||||
from langflow.io import DropdownInput, MessageTextInput, NestedDictInput, Output
|
||||
from langflow.schema import Data, dotdict
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from langflow.graph.schema import RunOutputs
|
||||
from langflow.base.tools.run_flow import RunFlowBaseComponent
|
||||
from langflow.helpers.flow import run_flow
|
||||
from langflow.schema import dotdict
|
||||
|
||||
|
||||
class RunFlowComponent(Component):
|
||||
class RunFlowComponent(RunFlowBaseComponent):
|
||||
display_name = "Run Flow"
|
||||
description = "A component to run a flow."
|
||||
description = "Creates a tool component from a Flow that takes all its inputs and runs it."
|
||||
beta = True
|
||||
name = "RunFlow"
|
||||
legacy: bool = True
|
||||
icon = "workflow"
|
||||
icon = "Workflow"
|
||||
|
||||
async def get_flow_names(self) -> list[str]:
|
||||
flow_data = await self.alist_flows()
|
||||
return [flow_data.data["name"] for flow_data in flow_data]
|
||||
inputs = RunFlowBaseComponent._base_inputs
|
||||
outputs = RunFlowBaseComponent._base_outputs
|
||||
|
||||
@override
|
||||
async def update_build_config(self, build_config: dotdict, field_value: Any, field_name: str | None = None):
|
||||
if field_name == "flow_name":
|
||||
build_config["flow_name"]["options"] = await self.get_flow_names()
|
||||
|
||||
if field_name == "flow_name_selected":
|
||||
build_config["flow_name_selected"]["options"] = await self.get_flow_names()
|
||||
missing_keys = [key for key in self.default_keys if key not in build_config]
|
||||
if missing_keys:
|
||||
msg = f"Missing required keys in build_config: {missing_keys}"
|
||||
raise ValueError(msg)
|
||||
if field_value is not None:
|
||||
try:
|
||||
graph = await self.get_graph(field_value)
|
||||
build_config = self.update_build_config_from_graph(build_config, graph)
|
||||
except Exception as e:
|
||||
msg = f"Error building graph for flow {field_value}"
|
||||
logger.exception(msg)
|
||||
raise RuntimeError(msg) from e
|
||||
return build_config
|
||||
|
||||
inputs = [
|
||||
MessageTextInput(
|
||||
name="input_value",
|
||||
display_name="Input Value",
|
||||
info="The input value to be processed by the flow.",
|
||||
),
|
||||
DropdownInput(
|
||||
name="flow_name",
|
||||
display_name="Flow Name",
|
||||
info="The name of the flow to run.",
|
||||
options=[],
|
||||
refresh_button=True,
|
||||
),
|
||||
NestedDictInput(
|
||||
name="tweaks",
|
||||
display_name="Tweaks",
|
||||
info="Tweaks to apply to the flow.",
|
||||
),
|
||||
]
|
||||
async def run_flow_with_tweaks(self):
|
||||
tweaks: dict = {}
|
||||
|
||||
outputs = [
|
||||
Output(display_name="Run Outputs", name="run_outputs", method="generate_results"),
|
||||
]
|
||||
flow_name_selected = self._attributes.get("flow_name_selected")
|
||||
parsed_flow_tweak_data = self._attributes.get("flow_tweak_data", {})
|
||||
if not isinstance(parsed_flow_tweak_data, dict):
|
||||
parsed_flow_tweak_data = parsed_flow_tweak_data.dict()
|
||||
|
||||
async def generate_results(self) -> list[Data]:
|
||||
if "flow_name" not in self._attributes or not self._attributes["flow_name"]:
|
||||
msg = "Flow name is required"
|
||||
raise ValueError(msg)
|
||||
flow_name = self._attributes["flow_name"]
|
||||
|
||||
results: list[RunOutputs | None] = await self.run_flow(
|
||||
inputs={"input_value": self.input_value}, flow_name=flow_name, tweaks=self.tweaks
|
||||
)
|
||||
if isinstance(results, list):
|
||||
data = []
|
||||
for result in results:
|
||||
if result:
|
||||
data.extend(build_data_from_run_outputs(result))
|
||||
if parsed_flow_tweak_data != {}:
|
||||
for field in parsed_flow_tweak_data:
|
||||
if "~" in field:
|
||||
[node, name] = field.split("~")
|
||||
if node not in tweaks:
|
||||
tweaks[node] = {}
|
||||
tweaks[node][name] = parsed_flow_tweak_data[field]
|
||||
else:
|
||||
data = build_data_from_run_outputs()(results)
|
||||
|
||||
self.status = data
|
||||
return data
|
||||
for field in self._attributes:
|
||||
if field not in self.default_keys and "~" in field:
|
||||
[node, name] = field.split("~")
|
||||
if node not in tweaks:
|
||||
tweaks[node] = {}
|
||||
tweaks[node][name] = self._attributes[field]
|
||||
# import pdb; pdb.set_trace()
|
||||
return await run_flow(
|
||||
inputs=None,
|
||||
output_type="all",
|
||||
flow_id=None,
|
||||
flow_name=flow_name_selected,
|
||||
tweaks=tweaks,
|
||||
user_id=str(self.user_id),
|
||||
# run_id=self.graph.run_id,
|
||||
session_id=self.graph.session_id or self.session_id,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -406,14 +406,15 @@ class Component(CustomComponent):
|
|||
self._validate_inputs(params)
|
||||
self._validate_outputs()
|
||||
|
||||
def run_and_validate_update_outputs(self, frontend_node: dict, field_name: str, field_value: Any):
|
||||
async def run_and_validate_update_outputs(self, frontend_node: dict, field_name: str, field_value: Any):
|
||||
frontend_node = self.update_outputs(frontend_node, field_name, field_value)
|
||||
if field_name == "tool_mode" or frontend_node.get("tool_mode"):
|
||||
is_tool_mode = field_value or frontend_node.get("tool_mode")
|
||||
frontend_node["outputs"] = [self._build_tool_output()] if is_tool_mode else frontend_node["outputs"]
|
||||
if is_tool_mode:
|
||||
frontend_node.setdefault("template", {})
|
||||
frontend_node["template"][TOOLS_METADATA_INPUT_NAME] = self._build_tools_metadata_input().to_dict()
|
||||
tools_metadata_input = await self._build_tools_metadata_input()
|
||||
frontend_node["template"][TOOLS_METADATA_INPUT_NAME] = tools_metadata_input.to_dict()
|
||||
elif "template" in frontend_node:
|
||||
frontend_node["template"].pop(TOOLS_METADATA_INPUT_NAME, None)
|
||||
self.tools_metadata = frontend_node.get("template", {}).get(TOOLS_METADATA_INPUT_NAME, {}).get("value")
|
||||
|
|
@ -994,7 +995,7 @@ class Component(CustomComponent):
|
|||
def _get_fallback_input(self, **kwargs):
|
||||
return Input(**kwargs)
|
||||
|
||||
def to_toolkit(self) -> list[Tool]:
|
||||
async def to_toolkit(self) -> list[Tool]:
|
||||
component_toolkit = _get_component_toolkit()
|
||||
tools = component_toolkit(component=self).get_tools(callbacks=self.get_langchain_callbacks())
|
||||
if hasattr(self, TOOLS_METADATA_INPUT_NAME):
|
||||
|
|
@ -1212,8 +1213,8 @@ class Component(CustomComponent):
|
|||
def _build_tool_output(self) -> Output:
|
||||
return Output(name=TOOL_OUTPUT_NAME, display_name=TOOL_OUTPUT_DISPLAY_NAME, method="to_toolkit", types=["Tool"])
|
||||
|
||||
def _build_tools_metadata_input(self):
|
||||
tools = self.to_toolkit()
|
||||
async def _build_tools_metadata_input(self):
|
||||
tools = await self.to_toolkit()
|
||||
tool_data = (
|
||||
self.tools_metadata
|
||||
if hasattr(self, TOOLS_METADATA_INPUT_NAME)
|
||||
|
|
|
|||
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
|
|
@ -86,7 +86,7 @@ class PromptInput(BaseInputMixin, ListableInputMixin, InputTraceMixin, ToolModeM
|
|||
field_type: SerializableFieldTypes = FieldTypes.PROMPT
|
||||
|
||||
|
||||
class CodeInput(BaseInputMixin, ListableInputMixin, InputTraceMixin):
|
||||
class CodeInput(BaseInputMixin, ListableInputMixin, InputTraceMixin, ToolModeMixin):
|
||||
field_type: SerializableFieldTypes = FieldTypes.CODE
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -3,6 +3,7 @@ from typing import TYPE_CHECKING, Literal
|
|||
from pydantic import BaseModel, Field, create_model
|
||||
|
||||
from langflow.inputs.inputs import FieldTypes
|
||||
from langflow.schema.dotdict import dotdict
|
||||
|
||||
_convert_field_type_to_type: dict[FieldTypes, type] = {
|
||||
FieldTypes.TEXT: str,
|
||||
|
|
@ -62,3 +63,51 @@ def create_input_schema(inputs: list["InputTypes"]) -> type[BaseModel]:
|
|||
model = create_model("InputSchema", **fields)
|
||||
model.model_rebuild()
|
||||
return model
|
||||
|
||||
|
||||
def create_input_schema_from_dict(inputs: list[dotdict], param_key: str | None = None) -> type[BaseModel]:
|
||||
if not isinstance(inputs, list):
|
||||
msg = "inputs must be a list of Inputs"
|
||||
raise TypeError(msg)
|
||||
fields = {}
|
||||
for input_model in inputs:
|
||||
# Create a Pydantic Field for each input field
|
||||
field_type = input_model.type
|
||||
if hasattr(input_model, "options") and isinstance(input_model.options, list) and input_model.options:
|
||||
literal_string = f"Literal{input_model.options}"
|
||||
# validate that the literal_string is a valid literal
|
||||
|
||||
field_type = eval(literal_string, {"Literal": Literal}) # noqa: S307
|
||||
if hasattr(input_model, "is_list") and input_model.is_list:
|
||||
field_type = list[field_type] # type: ignore[valid-type]
|
||||
if input_model.name:
|
||||
name = input_model.name.replace("_", " ").title()
|
||||
elif input_model.display_name:
|
||||
name = input_model.display_name
|
||||
else:
|
||||
msg = "Input name or display_name is required"
|
||||
raise ValueError(msg)
|
||||
field_dict = {
|
||||
"title": name,
|
||||
"description": input_model.info or "",
|
||||
}
|
||||
if input_model.required is False:
|
||||
field_dict["default"] = input_model.value # type: ignore[assignment]
|
||||
pydantic_field = Field(**field_dict)
|
||||
|
||||
fields[input_model.name] = (field_type, pydantic_field)
|
||||
|
||||
# Wrap fields in a dictionary with the key as param_key
|
||||
if param_key is not None:
|
||||
# Create an inner model with the fields
|
||||
inner_model = create_model("InnerModel", **fields)
|
||||
|
||||
# Ensure the model is wrapped correctly in a dictionary
|
||||
# model = create_model("InputSchema", **{param_key: (inner_model, Field(default=..., description=description))})
|
||||
model = create_model("InputSchema", **{param_key: (inner_model, ...)})
|
||||
else:
|
||||
# Create and return the InputSchema model
|
||||
model = create_model("InputSchema", **fields)
|
||||
|
||||
model.model_rebuild()
|
||||
return model
|
||||
|
|
|
|||
|
|
@ -1,3 +1,5 @@
|
|||
import logging
|
||||
|
||||
from fastapi.encoders import jsonable_encoder
|
||||
from pydantic import BaseModel, field_serializer
|
||||
from pydantic.v1 import BaseModel as V1BaseModel
|
||||
|
|
@ -5,6 +7,8 @@ from pydantic_core import PydanticSerializationError
|
|||
|
||||
from langflow.schema.log import LoggableType
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class Log(BaseModel):
|
||||
name: str
|
||||
|
|
@ -13,21 +17,23 @@ class Log(BaseModel):
|
|||
|
||||
@field_serializer("message")
|
||||
def serialize_message(self, value):
|
||||
# We need to make sure everything inside the message has been serialized
|
||||
if isinstance(value, dict):
|
||||
return {key: self.serialize_message(value[key]) for key in value}
|
||||
if isinstance(value, list):
|
||||
return [self.serialize_message(item) for item in value]
|
||||
# To json is for LangChain Serializable objects
|
||||
if hasattr(value, "dict") and isinstance(value, V1BaseModel):
|
||||
# This is for Pydantic V1 models
|
||||
return value.dict()
|
||||
if hasattr(value, "to_json"):
|
||||
return value.to_json()
|
||||
if isinstance(value, BaseModel):
|
||||
return value.model_dump(exclude_none=True)
|
||||
try:
|
||||
# We need to make sure everything inside the message has been serialized
|
||||
if isinstance(value, dict):
|
||||
return {key: self.serialize_message(value[key]) for key in value}
|
||||
if isinstance(value, list):
|
||||
return [self.serialize_message(item) for item in value]
|
||||
# To json is for LangChain Serializable objects
|
||||
if hasattr(value, "dict") and isinstance(value, V1BaseModel):
|
||||
# This is for Pydantic V1 models
|
||||
return value.dict()
|
||||
if hasattr(value, "to_json"):
|
||||
return value.to_json()
|
||||
if isinstance(value, BaseModel):
|
||||
return value.model_dump(exclude_none=True)
|
||||
value = jsonable_encoder(value)
|
||||
except UnicodeDecodeError:
|
||||
return str(value) # Fallback to string representation
|
||||
except PydanticSerializationError:
|
||||
return str(value)
|
||||
return value
|
||||
|
|
|
|||
33
src/backend/tests/unit/base/tools/test_create_schema.py
Normal file
33
src/backend/tests/unit/base/tools/test_create_schema.py
Normal file
|
|
@ -0,0 +1,33 @@
|
|||
from langflow.io.schema import create_input_schema_from_dict
|
||||
from langflow.schema.dotdict import dotdict
|
||||
|
||||
|
||||
def test_create_schema():
|
||||
sample_input = [
|
||||
{
|
||||
"_input_type": "MultilineInput",
|
||||
"advanced": False,
|
||||
"display_name": "Chat Input - Text",
|
||||
"dynamic": False,
|
||||
"info": "Message to be passed as input.",
|
||||
"input_types": ["Message"],
|
||||
"list": False,
|
||||
"load_from_db": False,
|
||||
"multiline": True,
|
||||
"name": "ChatInput-xNZ0a|input_value",
|
||||
"placeholder": "",
|
||||
"required": False,
|
||||
"show": True,
|
||||
"title_case": False,
|
||||
"tool_mode": True,
|
||||
"trace_as_input": True,
|
||||
"trace_as_metadata": True,
|
||||
"type": "str",
|
||||
"value": "add 1+1",
|
||||
}
|
||||
]
|
||||
# convert to dotdict
|
||||
# change the key type
|
||||
sample_input = [dotdict(field) for field in sample_input]
|
||||
schema = create_input_schema_from_dict(sample_input)
|
||||
assert schema is not None
|
||||
|
|
@ -5,11 +5,11 @@ from langflow.components.agents.agent import AgentComponent
|
|||
from langflow.components.tools.calculator import CalculatorToolComponent
|
||||
|
||||
|
||||
def test_component_to_toolkit():
|
||||
async def test_component_to_toolkit():
|
||||
calculator_component = CalculatorToolComponent()
|
||||
agent_component = AgentComponent().set(tools=[calculator_component])
|
||||
|
||||
tools = agent_component.to_toolkit()
|
||||
tools = await agent_component.to_toolkit()
|
||||
assert len(tools) == 1
|
||||
tool = tools[0]
|
||||
|
||||
|
|
|
|||
|
|
@ -4,7 +4,7 @@ from langflow.custom.custom_component.component import Component
|
|||
|
||||
|
||||
class TestComponentOutputs:
|
||||
def test_run_and_validate_update_outputs_tool_mode(self):
|
||||
async def test_run_and_validate_update_outputs_tool_mode(self):
|
||||
"""Test run_and_validate_update_outputs with tool_mode field."""
|
||||
|
||||
class TestComponent(Component):
|
||||
|
|
@ -33,7 +33,7 @@ class TestComponentOutputs:
|
|||
}
|
||||
|
||||
# Test enabling tool mode
|
||||
updated_node = component.run_and_validate_update_outputs(
|
||||
updated_node = await component.run_and_validate_update_outputs(
|
||||
frontend_node=frontend_node.copy(), # Use a copy to avoid modifying original
|
||||
field_name="tool_mode",
|
||||
field_value=True,
|
||||
|
|
@ -45,7 +45,7 @@ class TestComponentOutputs:
|
|||
assert updated_node["outputs"][0]["display_name"] == TOOL_OUTPUT_DISPLAY_NAME
|
||||
|
||||
# Test disabling tool mode - use the original frontend node
|
||||
updated_node = component.run_and_validate_update_outputs(
|
||||
updated_node = await component.run_and_validate_update_outputs(
|
||||
frontend_node={"outputs": original_outputs.copy()}, # Use original outputs
|
||||
field_name="tool_mode",
|
||||
field_value=False,
|
||||
|
|
@ -60,7 +60,7 @@ class TestComponentOutputs:
|
|||
assert "types" in updated_node["outputs"][0]
|
||||
assert "selected" in updated_node["outputs"][0]
|
||||
|
||||
def test_run_and_validate_update_outputs_invalid_output(self):
|
||||
async def test_run_and_validate_update_outputs_invalid_output(self):
|
||||
"""Test run_and_validate_update_outputs with invalid output structure."""
|
||||
|
||||
class TestComponent(Component):
|
||||
|
|
@ -74,11 +74,11 @@ class TestComponentOutputs:
|
|||
|
||||
# Test validation fails for invalid output
|
||||
with pytest.raises(ValueError, match="Invalid output: 1 validation error for Output"):
|
||||
component.run_and_validate_update_outputs(
|
||||
await component.run_and_validate_update_outputs(
|
||||
frontend_node=frontend_node, field_name="some_field", field_value="some_value"
|
||||
)
|
||||
|
||||
def test_run_and_validate_update_outputs_custom_update(self):
|
||||
async def test_run_and_validate_update_outputs_custom_update(self):
|
||||
"""Test run_and_validate_update_outputs with custom update logic."""
|
||||
|
||||
class CustomComponent(Component):
|
||||
|
|
@ -111,7 +111,7 @@ class TestComponentOutputs:
|
|||
frontend_node = {"outputs": []}
|
||||
|
||||
# Test custom update logic
|
||||
updated_node = component.run_and_validate_update_outputs(
|
||||
updated_node = await component.run_and_validate_update_outputs(
|
||||
frontend_node=frontend_node, field_name="custom_field", field_value="custom_value"
|
||||
)
|
||||
|
||||
|
|
@ -122,14 +122,14 @@ class TestComponentOutputs:
|
|||
assert "types" in updated_node["outputs"][0]
|
||||
assert "selected" in updated_node["outputs"][0]
|
||||
|
||||
def test_run_and_validate_update_outputs_with_existing_tool_output(self):
|
||||
async def test_run_and_validate_update_outputs_with_existing_tool_output(self):
|
||||
"""Test run_and_validate_update_outputs when tool output already exists."""
|
||||
|
||||
class TestComponent(Component):
|
||||
def build(self) -> None:
|
||||
pass
|
||||
|
||||
def to_toolkit(self) -> list:
|
||||
async def to_toolkit(self) -> list:
|
||||
"""Method that returns a list of tools."""
|
||||
return []
|
||||
|
||||
|
|
@ -154,7 +154,7 @@ class TestComponentOutputs:
|
|||
}
|
||||
|
||||
# Test enabling tool mode doesn't duplicate tool output
|
||||
updated_node = component.run_and_validate_update_outputs(
|
||||
updated_node = await component.run_and_validate_update_outputs(
|
||||
frontend_node=frontend_node, field_name="tool_mode", field_value=True
|
||||
)
|
||||
|
||||
|
|
@ -164,7 +164,7 @@ class TestComponentOutputs:
|
|||
assert "types" in updated_node["outputs"][0]
|
||||
assert "selected" in updated_node["outputs"][0]
|
||||
|
||||
def test_run_and_validate_update_outputs_with_multiple_outputs(self):
|
||||
async def test_run_and_validate_update_outputs_with_multiple_outputs(self):
|
||||
"""Test run_and_validate_update_outputs with multiple outputs."""
|
||||
|
||||
class TestComponent(Component):
|
||||
|
|
@ -203,7 +203,7 @@ class TestComponentOutputs:
|
|||
frontend_node = {"outputs": []}
|
||||
|
||||
# Test adding multiple outputs
|
||||
updated_node = component.run_and_validate_update_outputs(
|
||||
updated_node = await component.run_and_validate_update_outputs(
|
||||
frontend_node=frontend_node, field_name="add_output", field_value=True
|
||||
)
|
||||
|
||||
|
|
@ -217,7 +217,7 @@ class TestComponentOutputs:
|
|||
assert set(output["types"]) == {"Text"}
|
||||
assert output["selected"] == "Text"
|
||||
|
||||
def test_run_and_validate_update_outputs_output_validation(self):
|
||||
async def test_run_and_validate_update_outputs_output_validation(self):
|
||||
"""Test output validation in run_and_validate_update_outputs."""
|
||||
|
||||
class TestComponent(Component):
|
||||
|
|
@ -236,10 +236,14 @@ class TestComponentOutputs:
|
|||
}
|
||||
|
||||
with pytest.raises(AttributeError, match="nonexistent_method not found in TestComponent"):
|
||||
component.run_and_validate_update_outputs(frontend_node=invalid_node, field_name="test", field_value=True)
|
||||
await component.run_and_validate_update_outputs(
|
||||
frontend_node=invalid_node, field_name="test", field_value=True
|
||||
)
|
||||
|
||||
# Test missing method case
|
||||
invalid_node = {"outputs": [{"name": "test", "type": "str", "display_name": "Test"}]}
|
||||
|
||||
with pytest.raises(ValueError, match="Output test does not have a method"):
|
||||
component.run_and_validate_update_outputs(frontend_node=invalid_node, field_name="test", field_value=True)
|
||||
await component.run_and_validate_update_outputs(
|
||||
frontend_node=invalid_node, field_name="test", field_value=True
|
||||
)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue