refac: refactor tools and add QuerySQLDataBaseTool

This commit is contained in:
Ibis Prevedello 2023-04-17 15:59:15 -03:00 committed by Gabriel Luiz Freitas Almeida
commit 571f407ef3
5 changed files with 66 additions and 79 deletions

View file

@ -40,6 +40,10 @@ tools:
- Tool
- PythonFunction
- JsonSpec
- News API
- TMDB API
- Podcast API
- QuerySQLDataBaseTool
wrappers:
- RequestsWrapper

View file

@ -10,7 +10,7 @@ from langchain.chat_models.base import BaseChatModel
from langchain.llms.base import BaseLLM
from langchain.tools import BaseTool
from langflow.interface.tools.util import get_tool_by_name
from langflow.interface.tools.base import tool_creator
def import_module(module_path: str) -> Any:
@ -107,7 +107,7 @@ def import_llm(llm: str) -> BaseLLM:
def import_tool(tool: str) -> BaseTool:
"""Import tool from tool name"""
return get_tool_by_name(tool)
return tool_creator.type_to_loader_dict[tool]["fcn"]
def import_chain(chain: str) -> Type[Chain]:

View file

@ -1,7 +1,7 @@
import inspect
from typing import Dict, List, Optional
from langchain.agents.load_tools import (
_BASE_TOOLS,
_EXTRA_LLM_TOOLS,
_EXTRA_OPTIONAL_TOOLS,
_LLM_TOOLS,
@ -10,17 +10,16 @@ from langchain.agents.load_tools import (
from langflow.custom import customs
from langflow.interface.base import LangChainTypeCreator
from langflow.interface.tools.constants import (
ALL_TOOLS_NAMES,
CUSTOM_TOOLS,
FILE_TOOLS,
OTHER_TOOLS,
)
from langflow.interface.tools.util import (
get_tool_by_name,
get_tool_params,
get_tools_dict,
)
from langflow.interface.tools.util import get_tool_params
from langflow.settings import settings
from langflow.template.base import Template, TemplateField
from langflow.utils import util
from langflow.utils.util import build_template_from_class
TOOL_INPUTS = {
"str": TemplateField(
@ -66,64 +65,79 @@ class ToolCreator(LangChainTypeCreator):
@property
def type_to_loader_dict(self) -> Dict:
if self.tools_dict is None:
self.tools_dict = get_tools_dict()
all_tools = {}
for tool, tool_fcn in ALL_TOOLS_NAMES.items():
tool_params = get_tool_params(tool_fcn)
tool_name = tool_params.get("name", tool)
if tool_name in settings.tools or settings.dev:
if tool_name == "JsonSpec":
tool_params["path"] = tool_params.pop("dict_") # type: ignore
all_tools[tool_name] = {
"type": tool,
"params": tool_params,
"fcn": tool_fcn,
}
self.tools_dict = all_tools
return self.tools_dict
def get_signature(self, name: str) -> Optional[Dict]:
"""Get the signature of a tool."""
base_classes = ["Tool"]
all_tools = {}
for tool in self.type_to_loader_dict.keys():
tool_fcn = get_tool_by_name(tool)
if tool_params := get_tool_params(tool_fcn):
tool_name = tool_params.get("name") or str(tool)
all_tools[tool_name] = {
"type": tool,
"params": tool_params,
"fcn": tool_fcn,
}
fields = []
params = []
tool_params = {}
# Raise error if name is not in tools
if name not in all_tools.keys():
if name not in self.type_to_loader_dict.keys():
raise ValueError("Tool not found")
tool_type: str = all_tools[name]["type"] # type: ignore
tool_type: str = self.type_to_loader_dict[name]["type"] # type: ignore
if all_tools[tool_type]["fcn"] in _BASE_TOOLS.values():
params = []
elif all_tools[tool_type]["fcn"] in _LLM_TOOLS.values():
# if tool_type in _BASE_TOOLS.keys():
# params = []
if tool_type in _LLM_TOOLS.keys():
params = ["llm"]
elif all_tools[tool_type]["fcn"] in [
val[0] for val in _EXTRA_LLM_TOOLS.values()
]:
n_dict = {val[0]: val[1] for val in _EXTRA_LLM_TOOLS.values()}
extra_keys = n_dict[all_tools[tool_type]["fcn"]]
elif tool_type in _EXTRA_LLM_TOOLS.keys():
extra_keys = _EXTRA_LLM_TOOLS[tool_type][1]
params = ["llm"] + extra_keys
elif all_tools[tool_type]["fcn"] in [
val[0] for val in _EXTRA_OPTIONAL_TOOLS.values()
]:
n_dict = {val[0]: val[1] for val in _EXTRA_OPTIONAL_TOOLS.values()} # type: ignore
extra_keys = n_dict[all_tools[tool_type]["fcn"]]
elif tool_type in _EXTRA_OPTIONAL_TOOLS.keys():
extra_keys = _EXTRA_OPTIONAL_TOOLS[tool_type][1]
params = extra_keys
# elif tool_type == "Tool":
# params = ["name", "description", "func"]
elif tool_type in CUSTOM_TOOLS:
# Get custom tool params
params = all_tools[name]["params"] # type: ignore
params = self.type_to_loader_dict[name]["params"] # type: ignore
base_classes = ["function"]
if node := customs.get_custom_nodes("tools").get(tool_type):
return node
elif tool_type in FILE_TOOLS:
params = all_tools[name]["params"] # type: ignore
if tool_type == "JsonSpec":
params["path"] = params.pop("dict_") # type: ignore
params = self.type_to_loader_dict[name]["params"] # type: ignore
base_classes += [name]
else:
params = []
elif tool_type in OTHER_TOOLS:
tool_dict = build_template_from_class(tool_type, OTHER_TOOLS)
fields = tool_dict["template"]
# Pop unnecessary fields and add name
fields.pop("_type")
fields.pop("return_direct")
fields.pop("verbose")
tool_params = {
"name": fields.pop("name")["value"],
"description": fields.pop("description")["value"],
}
fields = [
TemplateField(name=name, **field) for name, field in fields.items()
]
base_classes += tool_dict["base_classes"]
# Copy the field and add the name
fields = []
for param in params:
field = TOOL_INPUTS.get(param, TOOL_INPUTS["str"]).copy()
field.name = param
@ -134,7 +148,7 @@ class ToolCreator(LangChainTypeCreator):
template = Template(fields=fields, type_name=tool_type)
tool_params = all_tools[name]["params"]
tool_params = {**tool_params, **self.type_to_loader_dict[name]["params"]}
return {
"template": util.format_dict(template.to_dict()),
**tool_params,
@ -144,21 +158,7 @@ class ToolCreator(LangChainTypeCreator):
def to_list(self) -> List[str]:
"""List all load tools"""
tools = []
for tool, fcn in get_tools_dict().items():
tool_params = get_tool_params(fcn)
if tool_params and not tool_params.get("name"):
tool_params["name"] = tool
if tool_params and (
tool_params.get("name") in settings.tools
or (tool_params.get("name") and settings.dev)
):
tools.append(tool_params["name"])
return tools
return list(self.type_to_loader_dict.keys())
tool_creator = ToolCreator()

View file

@ -6,11 +6,13 @@ from langchain.agents.load_tools import (
_LLM_TOOLS,
)
from langchain.tools.json.tool import JsonSpec
from langchain.tools.sql_database.tool import QuerySQLDataBaseTool
from langflow.interface.tools.custom import PythonFunction
FILE_TOOLS = {"JsonSpec": JsonSpec}
CUSTOM_TOOLS = {"Tool": Tool, "PythonFunction": PythonFunction}
OTHER_TOOLS = {"QuerySQLDataBaseTool": QuerySQLDataBaseTool}
ALL_TOOLS_NAMES = {
**_BASE_TOOLS,
**_LLM_TOOLS, # type: ignore
@ -18,4 +20,5 @@ ALL_TOOLS_NAMES = {
**{k: v[0] for k, v in _EXTRA_OPTIONAL_TOOLS.items()},
**CUSTOM_TOOLS,
**FILE_TOOLS, # type: ignore
**OTHER_TOOLS,
}

View file

@ -4,28 +4,6 @@ from typing import Dict, Union
from langchain.agents.tools import Tool
from langflow.interface.tools.constants import ALL_TOOLS_NAMES
def get_tools_dict():
"""Get the tools dictionary."""
all_tools = {}
for tool, fcn in ALL_TOOLS_NAMES.items():
if tool_params := get_tool_params(fcn):
tool_name = tool_params.get("name") or str(tool)
all_tools[tool_name] = fcn
return all_tools
def get_tool_by_name(name: str):
"""Get a tool from the tools dictionary."""
tools = get_tools_dict()
if name not in tools:
raise ValueError(f"{name} not found.")
return tools[name]
def get_func_tool_params(func, **kwargs) -> Union[Dict, None]:
@ -113,6 +91,8 @@ def get_tool_params(tool, **kwargs) -> Dict:
elif inspect.isclass(tool):
# Get the parameters necessary to
# instantiate the class
return get_class_tool_params(tool, **kwargs) or {}
else:
raise ValueError("Tool must be a function or class.")