From 571f407ef3eadfaa695f896c1cdc812a3f49ae47 Mon Sep 17 00:00:00 2001 From: Ibis Prevedello Date: Mon, 17 Apr 2023 15:59:15 -0300 Subject: [PATCH] refac: refactor tools and add QuerySQLDataBaseTool --- src/backend/langflow/config.yaml | 4 + .../langflow/interface/importing/utils.py | 4 +- src/backend/langflow/interface/tools/base.py | 110 +++++++++--------- .../langflow/interface/tools/constants.py | 3 + src/backend/langflow/interface/tools/util.py | 24 +--- 5 files changed, 66 insertions(+), 79 deletions(-) diff --git a/src/backend/langflow/config.yaml b/src/backend/langflow/config.yaml index 7f00167b4..3508f5196 100644 --- a/src/backend/langflow/config.yaml +++ b/src/backend/langflow/config.yaml @@ -40,6 +40,10 @@ tools: - Tool - PythonFunction - JsonSpec + - News API + - TMDB API + - Podcast API + - QuerySQLDataBaseTool wrappers: - RequestsWrapper diff --git a/src/backend/langflow/interface/importing/utils.py b/src/backend/langflow/interface/importing/utils.py index a3480928e..c426eaf85 100644 --- a/src/backend/langflow/interface/importing/utils.py +++ b/src/backend/langflow/interface/importing/utils.py @@ -10,7 +10,7 @@ from langchain.chat_models.base import BaseChatModel from langchain.llms.base import BaseLLM from langchain.tools import BaseTool -from langflow.interface.tools.util import get_tool_by_name +from langflow.interface.tools.base import tool_creator def import_module(module_path: str) -> Any: @@ -107,7 +107,7 @@ def import_llm(llm: str) -> BaseLLM: def import_tool(tool: str) -> BaseTool: """Import tool from tool name""" - return get_tool_by_name(tool) + return tool_creator.type_to_loader_dict[tool]["fcn"] def import_chain(chain: str) -> Type[Chain]: diff --git a/src/backend/langflow/interface/tools/base.py b/src/backend/langflow/interface/tools/base.py index 6a3439bf0..a5745f1d3 100644 --- a/src/backend/langflow/interface/tools/base.py +++ b/src/backend/langflow/interface/tools/base.py @@ -1,7 +1,7 @@ +import inspect from typing import Dict, List, Optional from langchain.agents.load_tools import ( - _BASE_TOOLS, _EXTRA_LLM_TOOLS, _EXTRA_OPTIONAL_TOOLS, _LLM_TOOLS, @@ -10,17 +10,16 @@ from langchain.agents.load_tools import ( from langflow.custom import customs from langflow.interface.base import LangChainTypeCreator from langflow.interface.tools.constants import ( + ALL_TOOLS_NAMES, CUSTOM_TOOLS, FILE_TOOLS, + OTHER_TOOLS, ) -from langflow.interface.tools.util import ( - get_tool_by_name, - get_tool_params, - get_tools_dict, -) +from langflow.interface.tools.util import get_tool_params from langflow.settings import settings from langflow.template.base import Template, TemplateField from langflow.utils import util +from langflow.utils.util import build_template_from_class TOOL_INPUTS = { "str": TemplateField( @@ -66,64 +65,79 @@ class ToolCreator(LangChainTypeCreator): @property def type_to_loader_dict(self) -> Dict: if self.tools_dict is None: - self.tools_dict = get_tools_dict() + all_tools = {} + for tool, tool_fcn in ALL_TOOLS_NAMES.items(): + tool_params = get_tool_params(tool_fcn) + tool_name = tool_params.get("name", tool) + + if tool_name in settings.tools or settings.dev: + if tool_name == "JsonSpec": + tool_params["path"] = tool_params.pop("dict_") # type: ignore + all_tools[tool_name] = { + "type": tool, + "params": tool_params, + "fcn": tool_fcn, + } + + self.tools_dict = all_tools + return self.tools_dict def get_signature(self, name: str) -> Optional[Dict]: """Get the signature of a tool.""" base_classes = ["Tool"] - all_tools = {} - for tool in self.type_to_loader_dict.keys(): - tool_fcn = get_tool_by_name(tool) - if tool_params := get_tool_params(tool_fcn): - tool_name = tool_params.get("name") or str(tool) - all_tools[tool_name] = { - "type": tool, - "params": tool_params, - "fcn": tool_fcn, - } + fields = [] + params = [] + tool_params = {} # Raise error if name is not in tools - if name not in all_tools.keys(): + if name not in self.type_to_loader_dict.keys(): raise ValueError("Tool not found") - tool_type: str = all_tools[name]["type"] # type: ignore + tool_type: str = self.type_to_loader_dict[name]["type"] # type: ignore - if all_tools[tool_type]["fcn"] in _BASE_TOOLS.values(): - params = [] - elif all_tools[tool_type]["fcn"] in _LLM_TOOLS.values(): + # if tool_type in _BASE_TOOLS.keys(): + # params = [] + if tool_type in _LLM_TOOLS.keys(): params = ["llm"] - elif all_tools[tool_type]["fcn"] in [ - val[0] for val in _EXTRA_LLM_TOOLS.values() - ]: - n_dict = {val[0]: val[1] for val in _EXTRA_LLM_TOOLS.values()} - extra_keys = n_dict[all_tools[tool_type]["fcn"]] + elif tool_type in _EXTRA_LLM_TOOLS.keys(): + extra_keys = _EXTRA_LLM_TOOLS[tool_type][1] params = ["llm"] + extra_keys - elif all_tools[tool_type]["fcn"] in [ - val[0] for val in _EXTRA_OPTIONAL_TOOLS.values() - ]: - n_dict = {val[0]: val[1] for val in _EXTRA_OPTIONAL_TOOLS.values()} # type: ignore - extra_keys = n_dict[all_tools[tool_type]["fcn"]] + elif tool_type in _EXTRA_OPTIONAL_TOOLS.keys(): + extra_keys = _EXTRA_OPTIONAL_TOOLS[tool_type][1] params = extra_keys # elif tool_type == "Tool": # params = ["name", "description", "func"] elif tool_type in CUSTOM_TOOLS: # Get custom tool params - params = all_tools[name]["params"] # type: ignore + params = self.type_to_loader_dict[name]["params"] # type: ignore base_classes = ["function"] if node := customs.get_custom_nodes("tools").get(tool_type): return node elif tool_type in FILE_TOOLS: - params = all_tools[name]["params"] # type: ignore - if tool_type == "JsonSpec": - params["path"] = params.pop("dict_") # type: ignore + params = self.type_to_loader_dict[name]["params"] # type: ignore base_classes += [name] - else: - params = [] + elif tool_type in OTHER_TOOLS: + tool_dict = build_template_from_class(tool_type, OTHER_TOOLS) + fields = tool_dict["template"] + + # Pop unnecessary fields and add name + fields.pop("_type") + fields.pop("return_direct") + fields.pop("verbose") + + tool_params = { + "name": fields.pop("name")["value"], + "description": fields.pop("description")["value"], + } + + fields = [ + TemplateField(name=name, **field) for name, field in fields.items() + ] + base_classes += tool_dict["base_classes"] # Copy the field and add the name - fields = [] for param in params: field = TOOL_INPUTS.get(param, TOOL_INPUTS["str"]).copy() field.name = param @@ -134,7 +148,7 @@ class ToolCreator(LangChainTypeCreator): template = Template(fields=fields, type_name=tool_type) - tool_params = all_tools[name]["params"] + tool_params = {**tool_params, **self.type_to_loader_dict[name]["params"]} return { "template": util.format_dict(template.to_dict()), **tool_params, @@ -144,21 +158,7 @@ class ToolCreator(LangChainTypeCreator): def to_list(self) -> List[str]: """List all load tools""" - tools = [] - - for tool, fcn in get_tools_dict().items(): - tool_params = get_tool_params(fcn) - - if tool_params and not tool_params.get("name"): - tool_params["name"] = tool - - if tool_params and ( - tool_params.get("name") in settings.tools - or (tool_params.get("name") and settings.dev) - ): - tools.append(tool_params["name"]) - - return tools + return list(self.type_to_loader_dict.keys()) tool_creator = ToolCreator() diff --git a/src/backend/langflow/interface/tools/constants.py b/src/backend/langflow/interface/tools/constants.py index 2ec20cef9..101fefa8b 100644 --- a/src/backend/langflow/interface/tools/constants.py +++ b/src/backend/langflow/interface/tools/constants.py @@ -6,11 +6,13 @@ from langchain.agents.load_tools import ( _LLM_TOOLS, ) from langchain.tools.json.tool import JsonSpec +from langchain.tools.sql_database.tool import QuerySQLDataBaseTool from langflow.interface.tools.custom import PythonFunction FILE_TOOLS = {"JsonSpec": JsonSpec} CUSTOM_TOOLS = {"Tool": Tool, "PythonFunction": PythonFunction} +OTHER_TOOLS = {"QuerySQLDataBaseTool": QuerySQLDataBaseTool} ALL_TOOLS_NAMES = { **_BASE_TOOLS, **_LLM_TOOLS, # type: ignore @@ -18,4 +20,5 @@ ALL_TOOLS_NAMES = { **{k: v[0] for k, v in _EXTRA_OPTIONAL_TOOLS.items()}, **CUSTOM_TOOLS, **FILE_TOOLS, # type: ignore + **OTHER_TOOLS, } diff --git a/src/backend/langflow/interface/tools/util.py b/src/backend/langflow/interface/tools/util.py index 8f8673b6b..41fb4bdb7 100644 --- a/src/backend/langflow/interface/tools/util.py +++ b/src/backend/langflow/interface/tools/util.py @@ -4,28 +4,6 @@ from typing import Dict, Union from langchain.agents.tools import Tool -from langflow.interface.tools.constants import ALL_TOOLS_NAMES - - -def get_tools_dict(): - """Get the tools dictionary.""" - - all_tools = {} - - for tool, fcn in ALL_TOOLS_NAMES.items(): - if tool_params := get_tool_params(fcn): - tool_name = tool_params.get("name") or str(tool) - all_tools[tool_name] = fcn - - return all_tools - - -def get_tool_by_name(name: str): - """Get a tool from the tools dictionary.""" - tools = get_tools_dict() - if name not in tools: - raise ValueError(f"{name} not found.") - return tools[name] def get_func_tool_params(func, **kwargs) -> Union[Dict, None]: @@ -113,6 +91,8 @@ def get_tool_params(tool, **kwargs) -> Dict: elif inspect.isclass(tool): # Get the parameters necessary to # instantiate the class + return get_class_tool_params(tool, **kwargs) or {} + else: raise ValueError("Tool must be a function or class.")