diff --git a/src/backend/langflow/interface/tools/base.py b/src/backend/langflow/interface/tools/base.py index 2e00ea0d2..0bb395ac9 100644 --- a/src/backend/langflow/interface/tools/base.py +++ b/src/backend/langflow/interface/tools/base.py @@ -2,9 +2,10 @@ from langflow.custom import customs from langflow.interface.tools.constants import ( ALL_TOOLS_NAMES, CUSTOM_TOOLS, - OTHER_TOOLS, + FILE_TOOLS, ) -from langflow.template.template import Field, Template +from langflow.template.base import Field +from langflow.template.base import Template from langflow.utils import util from langflow.settings import settings from langflow.interface.base import LangChainTypeCreator @@ -22,6 +23,41 @@ from langflow.interface.tools.util import ( ) +TOOL_INPUTS = { + "str": Field( + field_type="str", + required=True, + is_list=False, + show=True, + placeholder="", + value="", + ), + "llm": Field(field_type="BaseLLM", required=True, is_list=False, show=True), + "func": Field( + field_type="function", + required=True, + is_list=False, + show=True, + multiline=True, + ), + "code": Field( + field_type="str", + required=True, + is_list=False, + show=True, + value="", + multiline=True, + ), + "dict_": Field( + field_type="file", + required=True, + is_list=False, + show=True, + value="", + ), +} + + class ToolCreator(LangChainTypeCreator): type_name: str = "tools" tools_dict: Dict | None = None @@ -35,7 +71,6 @@ class ToolCreator(LangChainTypeCreator): def get_signature(self, name: str) -> Dict | None: """Get the signature of a tool.""" - NODE_INPUTS = ["llm", "func"] base_classes = ["Tool"] all_tools = {} for tool in self.type_to_loader_dict.keys(): @@ -47,40 +82,6 @@ class ToolCreator(LangChainTypeCreator): if name not in all_tools.keys(): raise ValueError("Tool not found") - type_dict = { - "str": Field( - field_type="str", - required=True, - is_list=False, - show=True, - placeholder="", - value="", - ), - "llm": Field(field_type="BaseLLM", required=True, is_list=False, show=True), - "func": Field( - field_type="function", - required=True, - is_list=False, - show=True, - multiline=True, - ), - "code": Field( - field_type="str", - required=True, - is_list=False, - show=True, - value="", - multiline=True, - ), - "dict_": Field( - field_type="file", - required=True, - is_list=False, - show=True, - value="", - ), - } - tool_type: str = all_tools[name]["type"] # type: ignore if tool_type in _BASE_TOOLS: @@ -101,8 +102,9 @@ class ToolCreator(LangChainTypeCreator): base_classes = ["function"] if node := customs.get_custom_nodes("tools").get(tool_type): return node - elif tool_type in OTHER_TOOLS: + elif tool_type in FILE_TOOLS: params = all_tools[name]["params"] # type: ignore + base_classes += [name] else: params = [] @@ -110,10 +112,7 @@ class ToolCreator(LangChainTypeCreator): # Copy the field and add the name fields = [] for param in params: - if param in NODE_INPUTS: - field = type_dict[param].copy() - else: - field = type_dict["str"].copy() + field = TOOL_INPUTS.get(param, TOOL_INPUTS["str"]) field.name = param if param == "aiosession": field.show = False @@ -122,9 +121,7 @@ class ToolCreator(LangChainTypeCreator): template = Template(fields=fields, type_name=tool_type) - tool_params = get_tool_params(get_tool_by_name(tool_type)) - if tool_params is None: - tool_params = {} + tool_params = all_tools[name]["params"] return { "template": util.format_dict(template.to_dict()), **tool_params,