refactor: enhance tool creation logic and add FeatureFlags (#3662)
* Add `required_inputs` field to `Output` model in `base.py` * Refactor ComponentTool to ComponentToolkit and enhance tool creation logic - Replaced `ComponentTool` with `ComponentToolkit` to better encapsulate component-related tools. - Introduced `build_description` and `_build_output_function` helper functions for dynamic tool creation. - Updated tool initialization to handle multiple outputs and required inputs using `StructuredTool`. - Improved schema creation for tool arguments based on component inputs. * Refactor `to_tool` method to `to_toolkit` to use `ComponentToolkit` instead of `ComponentTool` * Refactor `ComponentTool` to `ComponentToolkit` in unit tests - Updated import statements to reflect the new `ComponentToolkit` class. - Modified test logic to use `ComponentToolkit` for retrieving tools. - Adjusted assertions to match the new structure and output format. - Ensured compatibility with `Message` schema for output validation. * Refactor `test_component_to_tool` to validate `ComponentToolkit` and tool properties * Refactor `build_description` to include input types in the output format * Add method to set required inputs for outputs based on method analysis - Introduced `_set_output_required_inputs` method to determine and set required inputs for each output by analyzing the method's source code. - Added necessary imports (`ast` and `dedent`) to support the new functionality. * Update test to assert full tool description in test_component_to_tool.py * Add unit tests for verifying required inputs of various components - Added tests to ensure that required inputs for outputs are present in the inputs of `ChatInput`, `ChatOutput`, `SequentialTaskComponent`, `ToolCallingAgentComponent`, and `OpenAIModelComponent`. - Included helper functions to check if required inputs are in inputs and to assert that all outputs have different required inputs. * Add RequiredInputsVisitor to identify required inputs in AST nodes - Introduced RequiredInputsVisitor class to traverse AST nodes and collect required inputs. - The visitor checks for 'self' attributes matching the provided inputs and adds them to the required_inputs set. * Refactor required inputs extraction using `RequiredInputsVisitor` * Add feature flags configuration for toolkit output in settings * Add toolkit output handling based on feature flag in custom component utils * Add method to append 'component_as_tool' output in custom component * Add unit test for toolkit output feature flag in custom component * Add utility functions for lazy loading and instantiating input types in langflow - Introduced `get_InputTypesMap` for lazy loading of `InputTypesMap`. - Added `instantiate_input` function to create instances of input types dynamically. - Included type checking and error handling for invalid input types. * Refactor input instantiation logic and update imports - Removed `instantiate_input` function from `inputs.py` and moved it to `utils.py`. - Updated imports in `base.py` to reflect the new location of `instantiate_input`. - Added missing import for `Callable` in `base.py`. * Refactor import statement to use `instantiate_input` from `langflow.inputs.utils` in test_inputs.py * Add TOOL_OUTPUT_NAME constant to tools module * Add type checking and TOOL_OUTPUT_NAME filter in ComponentToolkit - Introduced `TYPE_CHECKING` for type hints to avoid circular imports. - Added `TOOL_OUTPUT_NAME` constant to filter specific outputs in `ComponentToolkit`. - Updated type annotations to use forward references. * Refactor component toolkit import to avoid circular dependency and use constant for tool output name * Refactor `ComponentToolkit` class to remove inheritance from `BaseToolkit` and add an initializer for `component` * Add unit test for ComponentToolkit in test_component_to_tool - Added `test_component_to_tool_has_no_component_as_tool` to verify that `ComponentToolkit` correctly initializes with a `ChatInput` component and returns the expected tools. * Refactor toolkit output handling to `custom_component` module * fix: mypy errors union-attr and arg-type * Add 'OTHER' field type to schema in langflow/io/schema.py * Add tool name formatting to ComponentToolkit to ensure valid characters * Refactor toolkit output handling and add type hint for `to_toolkit` method * Add `is_interface_component` attribute to vertex types and update import order * Add tests for ToolCallingAgentComponent and ChatOutput with API key handling - Updated `test_component_tool` to reflect new description format. - Added `test_component_tool_with_api_key` to test `ToolCallingAgentComponent` with `ChatOutput` and OpenAI API key. - Enabled `add_toolkit_output` feature flag for testing. * Refactor `_find_matching_output_method` to accept `input_name` parameter for more precise input-output matching * Replace ValueError with warning in build_description function * use chat_output component directly in set * [autofix.ci] apply automated fixes * [autofix.ci] apply automated fixes * [autofix.ci] apply automated fixes (attempt 3/3) * Refactor: Reorder method calls in `__init__` for logical consistency Moved `set_class_code` method call to ensure output types and required inputs are set before class code initialization. * Update _format_tool_name to allow '.' in tool names * Refactor `_format_tool_name` to remove non-alphanumeric characters * Update test assertions for component tool name and output mapping * Handle case where 'required_inputs' is empty in 'component_tool.py' * Refactor import statements for better readability in `base.py` * [autofix.ci] apply automated fixes * Add noqa comment to suppress import warning and re-add Any import in base.py --------- Co-authored-by: italojohnny <italojohnnydosanjos@gmail.com> Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
This commit is contained in:
parent
df792cb19d
commit
270f609fe7
29 changed files with 903 additions and 206 deletions
|
|
@ -1,42 +1,92 @@
|
|||
from typing import Any
|
||||
from __future__ import annotations
|
||||
|
||||
from langchain_core.tools import BaseTool, ToolException
|
||||
import re
|
||||
import warnings
|
||||
from collections.abc import Callable
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from langflow.custom.custom_component.component import Component
|
||||
from langchain_core.tools import BaseTool
|
||||
from langchain_core.tools.structured import StructuredTool
|
||||
|
||||
from langflow.base.tools.constants import TOOL_OUTPUT_NAME
|
||||
from langflow.io.schema import create_input_schema
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from langflow.custom.custom_component.component import Component
|
||||
from langflow.inputs.inputs import InputTypes
|
||||
from langflow.io import Output
|
||||
|
||||
|
||||
class ComponentTool(BaseTool):
|
||||
name: str
|
||||
description: str
|
||||
component: "Component"
|
||||
|
||||
def __init__(self, component: "Component") -> None:
|
||||
"""Initialize the tool."""
|
||||
from langflow.io.schema import create_input_schema
|
||||
|
||||
name = component.name or component.__class__.__name__
|
||||
description = component.description or ""
|
||||
args_schema = create_input_schema(component.inputs)
|
||||
super().__init__(name=name, description=description, args_schema=args_schema, component=component)
|
||||
# self.component = component
|
||||
|
||||
@property
|
||||
def args(self) -> dict:
|
||||
schema = self.get_input_schema()
|
||||
return schema.schema()["properties"]
|
||||
|
||||
def _run(
|
||||
self,
|
||||
*args: Any,
|
||||
**kwargs: Any,
|
||||
) -> dict:
|
||||
"""Use the tool."""
|
||||
try:
|
||||
results, _ = self.component(**kwargs)
|
||||
return results
|
||||
except Exception as e:
|
||||
msg = f"Error running {self.name}: {e}"
|
||||
raise ToolException(msg)
|
||||
def _get_input_type(input: InputTypes):
|
||||
if input.input_types:
|
||||
if len(input.input_types) == 1:
|
||||
return input.input_types[0]
|
||||
return " | ".join(input.input_types)
|
||||
return input.field_type
|
||||
|
||||
|
||||
ComponentTool.update_forward_refs()
|
||||
def build_description(component: Component, output: Output):
|
||||
if not output.required_inputs:
|
||||
warnings.warn(f"Output {output.name} does not have required inputs defined")
|
||||
|
||||
if output.required_inputs:
|
||||
args = ", ".join(
|
||||
sorted(
|
||||
[
|
||||
f"{input_name}: {_get_input_type(component._inputs[input_name])}"
|
||||
for input_name in output.required_inputs
|
||||
]
|
||||
)
|
||||
)
|
||||
else:
|
||||
args = ""
|
||||
return f"{output.method}({args}) - {component.description}"
|
||||
|
||||
|
||||
def _build_output_function(component: Component, output_method: Callable):
|
||||
def output_function(*args, **kwargs):
|
||||
component.set(*args, **kwargs)
|
||||
return output_method()
|
||||
|
||||
return output_function
|
||||
|
||||
|
||||
def _format_tool_name(name: str):
|
||||
# format to '^[a-zA-Z0-9_-]+$'."
|
||||
# to do that we must remove all non-alphanumeric characters
|
||||
|
||||
return re.sub(r"[^a-zA-Z0-9_-]", "-", name)
|
||||
|
||||
|
||||
class ComponentToolkit: # type: ignore
|
||||
def __init__(self, component: Component):
|
||||
self.component = component
|
||||
|
||||
def get_tools(self) -> list[BaseTool]:
|
||||
tools = []
|
||||
for output in self.component.outputs:
|
||||
if output.name == TOOL_OUTPUT_NAME:
|
||||
continue
|
||||
|
||||
if not output.method:
|
||||
msg = f"Output {output.name} does not have a method defined"
|
||||
raise ValueError(msg)
|
||||
|
||||
output_method: Callable = getattr(self.component, output.method)
|
||||
args_schema = None
|
||||
if output.required_inputs:
|
||||
inputs = [self.component._inputs[input_name] for input_name in output.required_inputs]
|
||||
args_schema = create_input_schema(inputs)
|
||||
else:
|
||||
args_schema = create_input_schema(self.component.inputs)
|
||||
name = f"{self.component.name}.{output.method}"
|
||||
formatted_name = _format_tool_name(name)
|
||||
tools.append(
|
||||
StructuredTool(
|
||||
name=formatted_name,
|
||||
description=build_description(self.component, output),
|
||||
func=_build_output_function(self.component, output_method),
|
||||
args_schema=args_schema,
|
||||
)
|
||||
)
|
||||
return tools
|
||||
|
|
|
|||
1
src/backend/base/langflow/base/tools/constants.py
Normal file
1
src/backend/base/langflow/base/tools/constants.py
Normal file
|
|
@ -0,0 +1 @@
|
|||
TOOL_OUTPUT_NAME = "component_as_tool"
|
||||
|
|
@ -1,8 +1,10 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import ast
|
||||
import inspect
|
||||
from collections.abc import Callable
|
||||
from copy import deepcopy
|
||||
from textwrap import dedent
|
||||
from typing import TYPE_CHECKING, Any, ClassVar, get_type_hints
|
||||
from uuid import UUID
|
||||
|
||||
|
|
@ -10,13 +12,17 @@ import nanoid # type: ignore
|
|||
import yaml
|
||||
from pydantic import BaseModel
|
||||
|
||||
from langflow.base.tools.constants import TOOL_OUTPUT_NAME
|
||||
from langflow.custom.tree_visitor import RequiredInputsVisitor
|
||||
from langflow.events.event_manager import EventManager
|
||||
from langflow.field_typing import Tool
|
||||
from langflow.graph.state.model import create_state_model
|
||||
from langflow.helpers.custom import format_type
|
||||
from langflow.schema.artifact import get_artifact_type, post_process_raw
|
||||
from langflow.schema.data import Data
|
||||
from langflow.schema.log import LoggableType
|
||||
from langflow.schema.message import Message
|
||||
from langflow.services.settings.feature_flags import FEATURE_FLAGS
|
||||
from langflow.services.tracing.schema import Log
|
||||
from langflow.template.field.base import UNDEFINED, Input, Output
|
||||
from langflow.template.frontend_node.custom_components import ComponentFrontendNode
|
||||
|
|
@ -30,6 +36,19 @@ if TYPE_CHECKING:
|
|||
from langflow.graph.vertex.base import Vertex
|
||||
from langflow.inputs.inputs import InputTypes
|
||||
|
||||
|
||||
_ComponentToolkit = None
|
||||
|
||||
|
||||
def _get_component_toolkit():
|
||||
global _ComponentToolkit
|
||||
if _ComponentToolkit is None:
|
||||
from langflow.base.tools.component_tool import ComponentToolkit
|
||||
|
||||
_ComponentToolkit = ComponentToolkit
|
||||
return _ComponentToolkit
|
||||
|
||||
|
||||
BACKWARDS_COMPATIBLE_ATTRIBUTES = ["user_id", "vertex", "tracing_service"]
|
||||
CONFIG_ATTRIBUTES = ["_display_name", "_description", "_icon", "_name", "_metadata"]
|
||||
|
||||
|
|
@ -72,6 +91,8 @@ class Component(CustomComponent):
|
|||
self.__inputs = inputs
|
||||
self.__config = config
|
||||
self._reset_all_output_values()
|
||||
if FEATURE_FLAGS.add_toolkit_output and hasattr(self, "_append_tool_output"):
|
||||
self._append_tool_output()
|
||||
super().__init__(**config)
|
||||
if hasattr(self, "_trace_type"):
|
||||
self.trace_type = self._trace_type
|
||||
|
|
@ -84,6 +105,7 @@ class Component(CustomComponent):
|
|||
# Set output types
|
||||
self._set_output_types()
|
||||
self.set_class_code()
|
||||
self._set_output_required_inputs()
|
||||
|
||||
def set_event_manager(self, event_manager: EventManager | None = None):
|
||||
self._event_manager = event_manager
|
||||
|
|
@ -305,6 +327,24 @@ class Component(CustomComponent):
|
|||
output.add_types(return_types)
|
||||
output.set_selected()
|
||||
|
||||
def _set_output_required_inputs(self):
|
||||
for output in self.outputs:
|
||||
if not output.method:
|
||||
continue
|
||||
method = getattr(self, output.method, None)
|
||||
if not method or not callable(method):
|
||||
continue
|
||||
try:
|
||||
source_code = inspect.getsource(method)
|
||||
ast_tree = ast.parse(dedent(source_code))
|
||||
except Exception:
|
||||
source_code = self._code
|
||||
ast_tree = ast.parse(dedent(source_code))
|
||||
|
||||
visitor = RequiredInputsVisitor(self._inputs)
|
||||
visitor.visit(ast_tree)
|
||||
output.required_inputs = sorted(visitor.required_inputs)
|
||||
|
||||
def get_output_by_method(self, method: Callable):
|
||||
# method is a callable and output.method is a string
|
||||
# we need to find the output that has the same method
|
||||
|
|
@ -335,24 +375,56 @@ class Component(CustomComponent):
|
|||
text += f"{output.name}[{','.join(output.types)}]->{input_.name}[{','.join(input_.input_types or [])}]\n"
|
||||
return text
|
||||
|
||||
def _find_matching_output_method(self, value: Component):
|
||||
# get all outputs of the value component
|
||||
def _find_matching_output_method(self, input_name: str, value: Component):
|
||||
"""
|
||||
Find the output method from the given component (`value`) that matches the specified input (`input_name`)
|
||||
in the current component.
|
||||
|
||||
This method searches through all outputs of the provided component to find outputs whose types match
|
||||
the input types of the specified input in the current component. If exactly one matching output is found,
|
||||
it returns the corresponding method. If multiple matching outputs are found, it raises an error indicating
|
||||
ambiguity. If no matching outputs are found, it raises an error indicating that no suitable output was found.
|
||||
|
||||
Args:
|
||||
input_name (str): The name of the input in the current component to match.
|
||||
value (Component): The component whose outputs are to be considered.
|
||||
|
||||
Returns:
|
||||
Callable: The method corresponding to the matching output.
|
||||
|
||||
Raises:
|
||||
ValueError: If multiple matching outputs are found, if no matching outputs are found,
|
||||
or if the output method is invalid.
|
||||
"""
|
||||
# Retrieve all outputs from the given component
|
||||
outputs = value._outputs_map.values()
|
||||
# check if the any of the types in the output.types matches ONLY one input in the current component
|
||||
# Prepare to collect matching output-input pairs
|
||||
matching_pairs = []
|
||||
# Get the input object from the current component
|
||||
input_ = self._inputs[input_name]
|
||||
# Iterate over outputs to find matches based on types
|
||||
for output in outputs:
|
||||
for input_ in self.inputs:
|
||||
for output_type in output.types:
|
||||
if input_.input_types and output_type in input_.input_types:
|
||||
matching_pairs.append((output, input_))
|
||||
for output_type in output.types:
|
||||
# Check if the output type matches the input's accepted types
|
||||
if input_.input_types and output_type in input_.input_types:
|
||||
matching_pairs.append((output, input_))
|
||||
# If multiple matches are found, raise an error indicating ambiguity
|
||||
if len(matching_pairs) > 1:
|
||||
matching_pairs_str = self._build_error_string_from_matching_pairs(matching_pairs)
|
||||
msg = (
|
||||
f"There are multiple outputs from {value.__class__.__name__} "
|
||||
f"that can connect to inputs in {self.__class__.__name__}: {matching_pairs_str}"
|
||||
)
|
||||
# If no matches are found, raise an error indicating no suitable output
|
||||
if not matching_pairs:
|
||||
msg = (
|
||||
f"No matching output from {value.__class__.__name__} found for input '{input_name}' "
|
||||
f"in {self.__class__.__name__}."
|
||||
)
|
||||
raise ValueError(msg)
|
||||
# Get the matching output and input pair
|
||||
output, input_ = matching_pairs[0]
|
||||
# Ensure that the output method is a valid method name (string)
|
||||
if not isinstance(output.method, str):
|
||||
msg = f"Method {output.method} is not a valid output of {value.__class__.__name__}"
|
||||
raise ValueError(msg)
|
||||
|
|
@ -365,7 +437,7 @@ class Component(CustomComponent):
|
|||
# We need to find the Output that can connect to an input of the current component
|
||||
# if there's more than one output that matches, we need to raise an error
|
||||
# because we don't know which one to connect to
|
||||
value = self._find_matching_output_method(value)
|
||||
value = self._find_matching_output_method(key, value)
|
||||
if callable(value) and self._inherits_from_component(value):
|
||||
try:
|
||||
self._method_is_valid_output(value)
|
||||
|
|
@ -744,11 +816,9 @@ class Component(CustomComponent):
|
|||
def _get_fallback_input(self, **kwargs):
|
||||
return Input(**kwargs)
|
||||
|
||||
def to_tool(self):
|
||||
# TODO: This is a temporary solution to avoid circular imports
|
||||
from langflow.base.tools.component_tool import ComponentTool
|
||||
|
||||
return ComponentTool(component=self)
|
||||
def to_toolkit(self) -> list[Tool]:
|
||||
ComponentToolkit = _get_component_toolkit()
|
||||
return ComponentToolkit(component=self).get_tools()
|
||||
|
||||
def get_project_name(self):
|
||||
if hasattr(self, "_tracing_service") and self._tracing_service:
|
||||
|
|
@ -773,3 +843,7 @@ class Component(CustomComponent):
|
|||
data["output"] = self._current_output
|
||||
data["component_id"] = self._id
|
||||
self._event_manager.on_log(data=data)
|
||||
|
||||
def _append_tool_output(self):
|
||||
if next((output for output in self.outputs if output.name == TOOL_OUTPUT_NAME), None) is None:
|
||||
self.outputs.append(Output(name=TOOL_OUTPUT_NAME, display_name="Tool", method="to_toolkit", types=["Tool"]))
|
||||
|
|
|
|||
12
src/backend/base/langflow/custom/tree_visitor.py
Normal file
12
src/backend/base/langflow/custom/tree_visitor.py
Normal file
|
|
@ -0,0 +1,12 @@
|
|||
import ast
|
||||
|
||||
|
||||
class RequiredInputsVisitor(ast.NodeVisitor):
|
||||
def __init__(self, inputs):
|
||||
self.inputs = inputs
|
||||
self.required_inputs = set()
|
||||
|
||||
def visit_Attribute(self, node):
|
||||
if isinstance(node.value, ast.Name) and node.value.id == "self" and node.attr in self.inputs:
|
||||
self.required_inputs.add(node.attr)
|
||||
self.generic_visit(node)
|
||||
|
|
@ -208,6 +208,7 @@ class InterfaceVertex(ComponentVertex):
|
|||
super().__init__(data, graph=graph)
|
||||
self._added_message = None
|
||||
self.steps = [self._build, self._run]
|
||||
self.is_interface_component = True
|
||||
|
||||
def build_stream_url(self):
|
||||
return f"/api/v1/build/{self.graph.flow_id}/{self.id}/stream"
|
||||
|
|
|
|||
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
32
src/backend/base/langflow/inputs/utils.py
Normal file
32
src/backend/base/langflow/inputs/utils.py
Normal file
|
|
@ -0,0 +1,32 @@
|
|||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from langflow.inputs.inputs import InputTypes, InputTypesMap
|
||||
else:
|
||||
InputTypes = Any
|
||||
InputTypesMap = Any
|
||||
|
||||
# Lazy import for InputTypesMap
|
||||
_InputTypesMap: dict[str, type["InputTypes"]] | None = None
|
||||
|
||||
|
||||
def get_InputTypesMap():
|
||||
global _InputTypesMap
|
||||
if _InputTypesMap is None:
|
||||
from langflow.inputs.inputs import InputTypesMap
|
||||
|
||||
_InputTypesMap = InputTypesMap
|
||||
return _InputTypesMap
|
||||
|
||||
|
||||
def instantiate_input(input_type: str, data: dict) -> InputTypes:
|
||||
InputTypesMap = get_InputTypesMap()
|
||||
|
||||
input_type_class = InputTypesMap.get(input_type)
|
||||
if "type" in data:
|
||||
# Replace with field_type
|
||||
data["field_type"] = data.pop("type")
|
||||
if input_type_class:
|
||||
return input_type_class(**data)
|
||||
msg = f"Invalid input type: {input_type}"
|
||||
raise ValueError(msg)
|
||||
|
|
@ -14,6 +14,7 @@ _convert_field_type_to_type: dict[FieldTypes, type] = {
|
|||
FieldTypes.TABLE: dict,
|
||||
FieldTypes.FILE: str,
|
||||
FieldTypes.PROMPT: str,
|
||||
FieldTypes.OTHER: str,
|
||||
}
|
||||
|
||||
if TYPE_CHECKING:
|
||||
|
|
@ -30,6 +31,9 @@ def create_input_schema(inputs: list["InputTypes"]) -> type[BaseModel]:
|
|||
field_type = input_model.field_type
|
||||
if isinstance(field_type, FieldTypes):
|
||||
field_type = _convert_field_type_to_type[field_type]
|
||||
else:
|
||||
msg = f"Invalid field type: {field_type}"
|
||||
raise ValueError(msg)
|
||||
if hasattr(input_model, "options") and isinstance(input_model.options, list) and input_model.options:
|
||||
literal_string = f"Literal{input_model.options}"
|
||||
# validate that the literal_string is a valid literal
|
||||
|
|
|
|||
11
src/backend/base/langflow/services/settings/feature_flags.py
Normal file
11
src/backend/base/langflow/services/settings/feature_flags.py
Normal file
|
|
@ -0,0 +1,11 @@
|
|||
from pydantic_settings import BaseSettings
|
||||
|
||||
|
||||
class FeatureFlags(BaseSettings):
|
||||
add_toolkit_output: bool = False
|
||||
|
||||
class Config:
|
||||
env_prefix = "LANGFLOW_FEATURE_"
|
||||
|
||||
|
||||
FEATURE_FLAGS = FeatureFlags()
|
||||
|
|
@ -1,16 +1,11 @@
|
|||
from collections.abc import Callable
|
||||
from collections.abc import Callable # noqa: I001
|
||||
from enum import Enum
|
||||
from typing import Any, GenericAlias, _GenericAlias, _UnionGenericAlias # type: ignore
|
||||
from typing import Any # noqa
|
||||
from typing import GenericAlias # type: ignore
|
||||
from typing import _GenericAlias # type: ignore
|
||||
from typing import _UnionGenericAlias # type: ignore
|
||||
|
||||
from pydantic import (
|
||||
BaseModel,
|
||||
ConfigDict,
|
||||
Field,
|
||||
field_serializer,
|
||||
field_validator,
|
||||
model_serializer,
|
||||
model_validator,
|
||||
)
|
||||
from pydantic import BaseModel, ConfigDict, Field, field_serializer, field_validator, model_serializer, model_validator
|
||||
|
||||
from langflow.field_typing import Text
|
||||
from langflow.field_typing.range_spec import RangeSpec
|
||||
|
|
@ -189,6 +184,9 @@ class Output(BaseModel):
|
|||
|
||||
cache: bool = Field(default=True)
|
||||
|
||||
required_inputs: list[str] | None = Field(default=None)
|
||||
"""List of required inputs for this output."""
|
||||
|
||||
def to_dict(self):
|
||||
return self.model_dump(by_alias=True, exclude_none=True)
|
||||
|
||||
|
|
|
|||
|
|
@ -3,7 +3,8 @@ from typing import cast
|
|||
|
||||
from pydantic import BaseModel, Field, model_serializer
|
||||
|
||||
from langflow.inputs.inputs import InputTypes, instantiate_input
|
||||
from langflow.inputs.inputs import InputTypes
|
||||
from langflow.inputs.utils import instantiate_input
|
||||
from langflow.template.field.base import Input
|
||||
from langflow.utils.constants import DIRECT_TYPES
|
||||
|
||||
|
|
|
|||
|
|
@ -367,7 +367,8 @@ def added_flow_with_prompt_and_history(client, json_flow_with_prompt_and_history
|
|||
assert response.status_code == 201
|
||||
assert response.json()["name"] == flow.name
|
||||
assert response.json()["data"] == flow.data
|
||||
return response.json()
|
||||
yield response.json()
|
||||
client.delete(f"api/v1/flows/{response.json()['id']}", headers=logged_in_headers)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
|
|
@ -379,7 +380,8 @@ def added_flow_chat_input(client, json_chat_input, logged_in_headers):
|
|||
assert response.status_code == 201
|
||||
assert response.json()["name"] == flow.name
|
||||
assert response.json()["data"] == flow.data
|
||||
return response.json()
|
||||
yield response.json()
|
||||
client.delete(f"api/v1/flows/{response.json()['id']}", headers=logged_in_headers)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
|
|
@ -391,7 +393,8 @@ def added_flow_two_outputs(client, json_two_outputs, logged_in_headers):
|
|||
assert response.status_code == 201
|
||||
assert response.json()["name"] == flow.name
|
||||
assert response.json()["data"] == flow.data
|
||||
return response.json()
|
||||
yield response.json()
|
||||
client.delete(f"api/v1/flows/{response.json()['id']}", headers=logged_in_headers)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
|
|
@ -403,7 +406,8 @@ def added_vector_store(client, json_vector_store, logged_in_headers):
|
|||
assert response.status_code == 201
|
||||
assert response.json()["name"] == vector_store.name
|
||||
assert response.json()["data"] == vector_store.data
|
||||
return response.json()
|
||||
yield response.json()
|
||||
client.delete(f"api/v1/flows/{response.json()['id']}", headers=logged_in_headers)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
|
|
@ -429,7 +433,8 @@ def flow_component(client: TestClient, logged_in_headers):
|
|||
flow = FlowCreate(**graph_dict)
|
||||
response = client.post("api/v1/flows/", json=flow.model_dump(), headers=logged_in_headers)
|
||||
assert response.status_code == 201
|
||||
return response.json()
|
||||
yield response.json()
|
||||
client.delete(f"api/v1/flows/{response.json()['id']}", headers=logged_in_headers)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
|
|
@ -460,7 +465,8 @@ def get_simple_api_test(client, logged_in_headers, json_simple_api_test):
|
|||
flow = FlowCreate(name="Simple API Test", data=data, description="Simple API Test")
|
||||
response = client.post("api/v1/flows/", json=flow.model_dump(), headers=logged_in_headers)
|
||||
assert response.status_code == 201
|
||||
return response.json()
|
||||
yield response.json()
|
||||
client.delete(f"api/v1/flows/{response.json()['id']}", headers=logged_in_headers)
|
||||
|
||||
|
||||
@pytest.fixture(name="starter_project")
|
||||
|
|
|
|||
|
|
@ -1,62 +0,0 @@
|
|||
import pytest
|
||||
|
||||
from langflow.base.tools.component_tool import ComponentTool
|
||||
from langflow.components.inputs.ChatInput import ChatInput
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def client():
|
||||
pass
|
||||
|
||||
|
||||
def test_component_tool():
|
||||
chat_input = ChatInput()
|
||||
component_tool = ComponentTool(component=chat_input)
|
||||
assert component_tool.name == "ChatInput"
|
||||
assert component_tool.description == chat_input.description
|
||||
assert component_tool.args == {
|
||||
"input_value": {
|
||||
"default": "",
|
||||
"description": "Message to be passed as input.",
|
||||
"title": "Input Value",
|
||||
"type": "string",
|
||||
},
|
||||
"should_store_message": {
|
||||
"default": True,
|
||||
"description": "Store the message in the history.",
|
||||
"title": "Should Store Message",
|
||||
"type": "boolean",
|
||||
},
|
||||
"sender": {
|
||||
"default": "User",
|
||||
"description": "Type of sender.",
|
||||
"enum": ["Machine", "User"],
|
||||
"title": "Sender",
|
||||
"type": "string",
|
||||
},
|
||||
"sender_name": {
|
||||
"default": "User",
|
||||
"description": "Name of the sender.",
|
||||
"title": "Sender Name",
|
||||
"type": "string",
|
||||
},
|
||||
"session_id": {
|
||||
"default": "",
|
||||
"description": "The session ID of the chat. If empty, the current session ID parameter will be used.",
|
||||
"title": "Session Id",
|
||||
"type": "string",
|
||||
},
|
||||
"files": {
|
||||
"default": "",
|
||||
"description": "Files to be sent with the message.",
|
||||
"items": {"type": "string"},
|
||||
"title": "Files",
|
||||
"type": "array",
|
||||
},
|
||||
}
|
||||
assert component_tool.component == chat_input
|
||||
|
||||
result = component_tool.invoke(input=dict(input_value="test"))
|
||||
assert isinstance(result, dict)
|
||||
assert hasattr(result["message"], "get_text")
|
||||
assert result["message"].get_text() == "test"
|
||||
103
src/backend/tests/unit/base/tools/test_component_toolkit.py
Normal file
103
src/backend/tests/unit/base/tools/test_component_toolkit.py
Normal file
|
|
@ -0,0 +1,103 @@
|
|||
import os
|
||||
|
||||
import pytest
|
||||
|
||||
from langflow.base.tools.component_tool import ComponentToolkit
|
||||
from langflow.components.agents.ToolCallingAgent import ToolCallingAgentComponent
|
||||
from langflow.components.inputs.ChatInput import ChatInput
|
||||
from langflow.components.models.OpenAIModel import OpenAIModelComponent
|
||||
from langflow.components.outputs import ChatOutput
|
||||
from langflow.graph.graph.base import Graph
|
||||
from langflow.schema.message import Message
|
||||
from langflow.services.settings.feature_flags import FEATURE_FLAGS
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def client():
|
||||
pass
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def add_toolkit_output():
|
||||
FEATURE_FLAGS.add_toolkit_output = True
|
||||
yield
|
||||
FEATURE_FLAGS.add_toolkit_output = False
|
||||
|
||||
|
||||
def test_component_tool():
|
||||
chat_input = ChatInput()
|
||||
component_toolkit = ComponentToolkit(component=chat_input)
|
||||
component_tool = component_toolkit.get_tools()[0]
|
||||
assert component_tool.name == "ChatInput-message_response"
|
||||
terms = [
|
||||
"message_response",
|
||||
"files",
|
||||
"input_value",
|
||||
"sender",
|
||||
"sender_name",
|
||||
"session_id",
|
||||
"should_store_message",
|
||||
]
|
||||
assert all(term in component_tool.description for term in terms)
|
||||
assert component_tool.args == {
|
||||
"input_value": {
|
||||
"default": "",
|
||||
"description": "Message to be passed as input.",
|
||||
"title": "Input Value",
|
||||
"type": "string",
|
||||
},
|
||||
"should_store_message": {
|
||||
"default": True,
|
||||
"description": "Store the message in the history.",
|
||||
"title": "Should Store Message",
|
||||
"type": "boolean",
|
||||
},
|
||||
"sender": {
|
||||
"default": "User",
|
||||
"description": "Type of sender.",
|
||||
"enum": ["Machine", "User"],
|
||||
"title": "Sender",
|
||||
"type": "string",
|
||||
},
|
||||
"sender_name": {
|
||||
"default": "User",
|
||||
"description": "Name of the sender.",
|
||||
"title": "Sender Name",
|
||||
"type": "string",
|
||||
},
|
||||
"session_id": {
|
||||
"default": "",
|
||||
"description": "The session ID of the chat. If empty, the current session ID parameter will be used.",
|
||||
"title": "Session Id",
|
||||
"type": "string",
|
||||
},
|
||||
"files": {
|
||||
"default": "",
|
||||
"description": "Files to be sent with the message.",
|
||||
"items": {"type": "string"},
|
||||
"title": "Files",
|
||||
"type": "array",
|
||||
},
|
||||
}
|
||||
assert component_toolkit.component == chat_input
|
||||
|
||||
result = component_tool.invoke(input=dict(input_value="test"))
|
||||
assert isinstance(result, Message)
|
||||
assert result.get_text() == "test"
|
||||
|
||||
|
||||
@pytest.mark.api_key_required
|
||||
def test_component_tool_with_api_key(client, add_toolkit_output):
|
||||
chat_output = ChatOutput()
|
||||
openai_llm = OpenAIModelComponent()
|
||||
openai_llm.set(api_key=os.environ["OPENAI_API_KEY"])
|
||||
tool_calling_agent = ToolCallingAgentComponent()
|
||||
tool_calling_agent.set(
|
||||
llm=openai_llm.build_model, tools=[chat_output], input_value="Which tools are available? Please tell its name."
|
||||
)
|
||||
|
||||
g = Graph(start=tool_calling_agent, end=tool_calling_agent)
|
||||
assert g is not None
|
||||
results = list(g.start())
|
||||
assert len(results) == 4
|
||||
assert "message_response" in tool_calling_agent._outputs_map["response"].value.get_text()
|
||||
|
|
@ -1,3 +1,5 @@
|
|||
from collections.abc import Callable
|
||||
|
||||
import pytest
|
||||
|
||||
from langflow.components.inputs.ChatInput import ChatInput
|
||||
|
|
@ -8,9 +10,29 @@ def client():
|
|||
pass
|
||||
|
||||
|
||||
def test_component_to_tool():
|
||||
def test_component_to_toolkit():
|
||||
chat_input = ChatInput()
|
||||
tool = chat_input.to_tool()
|
||||
assert tool.name == "ChatInput"
|
||||
assert tool.description == "Get chat inputs from the Playground."
|
||||
assert tool.component._id == chat_input._id
|
||||
tools = chat_input.to_toolkit()
|
||||
assert len(tools) == 1
|
||||
tool = tools[0]
|
||||
|
||||
assert tool.name == "ChatInput-message_response"
|
||||
terms = [
|
||||
"message_response",
|
||||
"files",
|
||||
"input_value",
|
||||
"sender",
|
||||
"sender_name",
|
||||
"session_id",
|
||||
"should_store_message",
|
||||
]
|
||||
assert all(term in tool.description for term in terms)
|
||||
|
||||
assert isinstance(tool.func, Callable)
|
||||
assert tool.args_schema is not None
|
||||
|
||||
|
||||
def test_component_to_tool_has_no_component_as_tool():
|
||||
chat_input = ChatInput()
|
||||
tools = chat_input.to_toolkit()
|
||||
assert len(tools) == 1
|
||||
|
|
|
|||
|
|
@ -1,9 +1,12 @@
|
|||
import pytest
|
||||
|
||||
from langflow.components.agents.CrewAIAgent import CrewAIAgentComponent
|
||||
from langflow.components.agents.ToolCallingAgent import ToolCallingAgentComponent
|
||||
from langflow.components.helpers.SequentialTask import SequentialTaskComponent
|
||||
from langflow.components.inputs.ChatInput import ChatInput
|
||||
from langflow.components.models.OpenAIModel import OpenAIModelComponent
|
||||
from langflow.components.outputs import ChatOutput
|
||||
from langflow.template.field.base import Output
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
|
|
@ -24,3 +27,45 @@ def test_set_component():
|
|||
task.set(agent=crewai_agent)
|
||||
assert task._edges[0]["source"] == crewai_agent._id
|
||||
assert crewai_agent in task._components
|
||||
|
||||
|
||||
def _output_required_inputs_are_in_inputs(output: Output, inputs: list[str]):
|
||||
return all(input_type in inputs for input_type in output.required_inputs)
|
||||
|
||||
|
||||
def _assert_all_outputs_have_different_required_inputs(outputs: list[Output]):
|
||||
required_inputs = [tuple(output.required_inputs) for output in outputs]
|
||||
assert len(required_inputs) == len(set(required_inputs)), "All outputs must have different required inputs"
|
||||
return True
|
||||
|
||||
|
||||
def test_set_required_inputs():
|
||||
chatinput = ChatInput()
|
||||
|
||||
assert all(_output_required_inputs_are_in_inputs(output, chatinput._inputs) for output in chatinput.outputs)
|
||||
assert _assert_all_outputs_have_different_required_inputs(chatinput.outputs)
|
||||
|
||||
|
||||
def test_set_required_inputs_various_components():
|
||||
chatinput = ChatInput()
|
||||
chatoutput = ChatOutput()
|
||||
task = SequentialTaskComponent()
|
||||
tool_calling_agent = ToolCallingAgentComponent()
|
||||
openai_component = OpenAIModelComponent()
|
||||
|
||||
assert all(_output_required_inputs_are_in_inputs(output, chatinput._inputs) for output in chatinput.outputs)
|
||||
assert all(_output_required_inputs_are_in_inputs(output, chatoutput._inputs) for output in chatoutput.outputs)
|
||||
assert all(_output_required_inputs_are_in_inputs(output, task._inputs) for output in task.outputs)
|
||||
assert all(
|
||||
_output_required_inputs_are_in_inputs(output, tool_calling_agent._inputs)
|
||||
for output in tool_calling_agent.outputs
|
||||
)
|
||||
assert all(
|
||||
_output_required_inputs_are_in_inputs(output, openai_component._inputs) for output in openai_component.outputs
|
||||
)
|
||||
|
||||
assert _assert_all_outputs_have_different_required_inputs(chatinput.outputs)
|
||||
assert _assert_all_outputs_have_different_required_inputs(chatoutput.outputs)
|
||||
assert _assert_all_outputs_have_different_required_inputs(task.outputs)
|
||||
assert _assert_all_outputs_have_different_required_inputs(tool_calling_agent.outputs)
|
||||
assert _assert_all_outputs_have_different_required_inputs(openai_component.outputs)
|
||||
|
|
|
|||
|
|
@ -145,13 +145,6 @@ def test_graph_functional_start_end():
|
|||
assert results[-1] == Finish()
|
||||
|
||||
|
||||
def test_graph_set_with_invalid_component():
|
||||
chat_input = ChatInput(_id="chat_input")
|
||||
chat_output = ChatOutput(input_value="test", _id="chat_output")
|
||||
with pytest.raises(ValueError, match="There are multiple outputs"):
|
||||
chat_output.set(sender_name=chat_input)
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="Temporarily disabled")
|
||||
def test_graph_set_with_valid_component():
|
||||
tool = YfinanceToolComponent()
|
||||
|
|
|
|||
|
|
@ -21,8 +21,8 @@ from langflow.inputs.inputs import (
|
|||
SecretStrInput,
|
||||
StrInput,
|
||||
TableInput,
|
||||
instantiate_input,
|
||||
)
|
||||
from langflow.inputs.utils import instantiate_input
|
||||
from langflow.schema.message import Message
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -218,18 +218,6 @@ class TestCreateInputSchema:
|
|||
assert field_info.description == ""
|
||||
|
||||
# Handling invalid field types
|
||||
def test_invalid_field_types_handling(self):
|
||||
from langflow.inputs.inputs import StrInput
|
||||
from langflow.io.schema import create_input_schema
|
||||
|
||||
class InvalidFieldType:
|
||||
pass
|
||||
|
||||
input_instance = StrInput(name="test_field")
|
||||
input_instance.field_type = InvalidFieldType()
|
||||
|
||||
with pytest.raises(KeyError):
|
||||
create_input_schema([input_instance])
|
||||
|
||||
# Handling input types with None as default value
|
||||
def test_none_default_value_handling(self):
|
||||
|
|
|
|||
|
|
@ -11,6 +11,7 @@ from langflow.custom.code_parser.code_parser import CodeParser, CodeSyntaxError
|
|||
from langflow.custom.custom_component.base_component import BaseComponent, ComponentCodeNullError
|
||||
from langflow.custom.utils import build_custom_component_template
|
||||
from langflow.services.database.models.flow import FlowCreate
|
||||
from langflow.services.settings.feature_flags import FEATURE_FLAGS
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
|
|
@ -543,6 +544,14 @@ def test_custom_component_multiple_outputs(code_component_with_multiple_outputs,
|
|||
assert frontnd_node_dict["outputs"][0]["types"] == ["Text"]
|
||||
|
||||
|
||||
def test_feature_flags_add_toolkit_output(active_user, code_component_with_multiple_outputs):
|
||||
frontnd_node_dict, _ = build_custom_component_template(code_component_with_multiple_outputs, active_user.id)
|
||||
len_outputs = len(frontnd_node_dict["outputs"])
|
||||
FEATURE_FLAGS.add_toolkit_output = True
|
||||
frontnd_node_dict, _ = build_custom_component_template(code_component_with_multiple_outputs, active_user.id)
|
||||
assert len(frontnd_node_dict["outputs"]) == len_outputs + 1
|
||||
|
||||
|
||||
def test_custom_component_subclass_from_lctoolcomponent():
|
||||
# Import LCToolComponent and create a subclass
|
||||
code = dedent("""
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue