From 7ae42bf7d2649380fe44e54d61f7ddfda4584d58 Mon Sep 17 00:00:00 2001 From: Gabriel Almeida Date: Sat, 6 May 2023 10:26:21 -0300 Subject: [PATCH] refactor(importing/utils.py): move tool_creator import to import_tool function refactor(tools/base.py): remove redundant variable assignment refactor(tools/constants.py): import all tools dynamically using __all__ attribute of langchain.tools module --- .../langflow/interface/importing/utils.py | 9 ++-- src/backend/langflow/interface/tools/base.py | 1 + .../langflow/interface/tools/constants.py | 45 +++---------------- 3 files changed, 13 insertions(+), 42 deletions(-) diff --git a/src/backend/langflow/interface/importing/utils.py b/src/backend/langflow/interface/importing/utils.py index e303da0eb..499b70a65 100644 --- a/src/backend/langflow/interface/importing/utils.py +++ b/src/backend/langflow/interface/importing/utils.py @@ -10,8 +10,6 @@ from langchain.chat_models.base import BaseChatModel from langchain.llms.base import BaseLLM from langchain.tools import BaseTool -from langflow.interface.tools.base import tool_creator - def import_module(module_path: str) -> Any: """Import module from module path""" @@ -107,8 +105,13 @@ def import_llm(llm: str) -> BaseLLM: def import_tool(tool: str) -> BaseTool: """Import tool from tool name""" + from langflow.interface.tools.base import tool_creator + from langflow.interface.tools.constants import ALL_TOOLS_NAMES - return tool_creator.type_to_loader_dict[tool]["fcn"] + if tool in ALL_TOOLS_NAMES: + return tool_creator.type_to_loader_dict[tool]["fcn"] + + return import_class(f"langchain.tools.{tool}") def import_chain(chain: str) -> Type[Chain]: diff --git a/src/backend/langflow/interface/tools/base.py b/src/backend/langflow/interface/tools/base.py index b94f4d62e..10eeead03 100644 --- a/src/backend/langflow/interface/tools/base.py +++ b/src/backend/langflow/interface/tools/base.py @@ -65,6 +65,7 @@ class ToolCreator(LangChainTypeCreator): def type_to_loader_dict(self) -> Dict: if self.tools_dict is None: all_tools = {} + for tool, tool_fcn in ALL_TOOLS_NAMES.items(): tool_params = get_tool_params(tool_fcn) tool_name = tool_params.get("name", tool) diff --git a/src/backend/langflow/interface/tools/constants.py b/src/backend/langflow/interface/tools/constants.py index 34890a684..f939d55ad 100644 --- a/src/backend/langflow/interface/tools/constants.py +++ b/src/backend/langflow/interface/tools/constants.py @@ -1,3 +1,4 @@ +from langchain import tools from langchain.agents import Tool from langchain.agents.load_tools import ( _BASE_TOOLS, @@ -5,50 +6,16 @@ from langchain.agents.load_tools import ( _EXTRA_OPTIONAL_TOOLS, _LLM_TOOLS, ) -from langchain.tools.bing_search.tool import BingSearchRun -from langchain.tools.google_search.tool import GoogleSearchResults, GoogleSearchRun -from langchain.tools.json.tool import JsonGetValueTool, JsonListKeysTool, JsonSpec -from langchain.tools.python.tool import PythonAstREPLTool, PythonREPLTool -from langchain.tools.requests.tool import ( - RequestsDeleteTool, - RequestsGetTool, - RequestsPatchTool, - RequestsPostTool, - RequestsPutTool, -) -from langchain.tools.sql_database.tool import ( - InfoSQLDatabaseTool, - ListSQLDatabaseTool, - QueryCheckerTool, - QuerySQLDataBaseTool, -) -from langchain.tools.wikipedia.tool import WikipediaQueryRun -from langchain.tools.wolfram_alpha.tool import WolframAlphaQueryRun +from langchain.tools.json.tool import JsonSpec +from langflow.interface.importing.utils import import_class from langflow.interface.tools.custom import PythonFunction FILE_TOOLS = {"JsonSpec": JsonSpec} CUSTOM_TOOLS = {"Tool": Tool, "PythonFunction": PythonFunction} -OTHER_TOOLS = { - "QuerySQLDataBaseTool": QuerySQLDataBaseTool, - "InfoSQLDatabaseTool": InfoSQLDatabaseTool, - "ListSQLDatabaseTool": ListSQLDatabaseTool, - "QueryCheckerTool": QueryCheckerTool, - "BingSearchRun": BingSearchRun, - "GoogleSearchRun": GoogleSearchRun, - "GoogleSearchResults": GoogleSearchResults, - "JsonListKeysTool": JsonListKeysTool, - "JsonGetValueTool": JsonGetValueTool, - "PythonREPLTool": PythonREPLTool, - "PythonAstREPLTool": PythonAstREPLTool, - "RequestsGetTool": RequestsGetTool, - "RequestsPostTool": RequestsPostTool, - "RequestsPatchTool": RequestsPatchTool, - "RequestsPutTool": RequestsPutTool, - "RequestsDeleteTool": RequestsDeleteTool, - "WikipediaQueryRun": WikipediaQueryRun, - "WolframAlphaQueryRun": WolframAlphaQueryRun, -} + +OTHER_TOOLS = {tool: import_class(f"langchain.tools.{tool}") for tool in tools.__all__} + ALL_TOOLS_NAMES = { **_BASE_TOOLS, **_LLM_TOOLS, # type: ignore