diff --git a/src/backend/langflow/config.yaml b/src/backend/langflow/config.yaml index 605df68ad..9236d5996 100644 --- a/src/backend/langflow/config.yaml +++ b/src/backend/langflow/config.yaml @@ -6,6 +6,7 @@ chains: - SeriesCharacterChain - MidJourneyPromptChain - TimeTravelGuideChain + - SQLDatabaseChain agents: - ZeroShotAgent @@ -122,5 +123,6 @@ utilities: - WikipediaAPIWrapper - WolframAlphaAPIWrapper # - ZapierNLAWrapper + - SQLDatabase dev: false diff --git a/src/backend/langflow/custom/customs.py b/src/backend/langflow/custom/customs.py index e77b81ec6..d45221be7 100644 --- a/src/backend/langflow/custom/customs.py +++ b/src/backend/langflow/custom/customs.py @@ -12,6 +12,9 @@ CUSTOM_NODES = { "VectorStoreRouterAgent": nodes.VectorStoreRouterAgentNode(), "SQLAgent": nodes.SQLAgentNode(), }, + "utilities": { + "SQLDatabase": nodes.SQLDatabaseNode(), + }, } diff --git a/src/backend/langflow/graph/base.py b/src/backend/langflow/graph/base.py index ff586c6da..6d998eed6 100644 --- a/src/backend/langflow/graph/base.py +++ b/src/backend/langflow/graph/base.py @@ -202,7 +202,11 @@ class Node: "VectorStoreRouterAgent", "VectorStoreAgent", "VectorStoreInfo", - ] or self.node_type in ["VectorStoreInfo", "VectorStoreRouterToolkit"]: + ] or self.node_type in [ + "VectorStoreInfo", + "VectorStoreRouterToolkit", + "SQLDatabase", + ]: return self._built_object return deepcopy(self._built_object) diff --git a/src/backend/langflow/graph/nodes.py b/src/backend/langflow/graph/nodes.py index 7296a0c0d..018174334 100644 --- a/src/backend/langflow/graph/nodes.py +++ b/src/backend/langflow/graph/nodes.py @@ -101,6 +101,10 @@ class ChainNode(Node): self.params[key] = value.build(tools=tools, force=force) self._build() + + #! Cannot deepcopy SQLDatabaseChain + if self.node_type in ["SQLDatabaseChain"]: + return self._built_object return deepcopy(self._built_object) diff --git a/src/backend/langflow/interface/custom_lists.py b/src/backend/langflow/interface/custom_lists.py index fb97e8dae..f07b03f04 100644 --- a/src/backend/langflow/interface/custom_lists.py +++ b/src/backend/langflow/interface/custom_lists.py @@ -14,6 +14,7 @@ from langchain import ( ) from langchain.agents import agent_toolkits from langchain.chat_models import ChatOpenAI +from langchain.sql_database import SQLDatabase from langflow.interface.importing.utils import import_class @@ -82,3 +83,4 @@ textsplitter_type_to_cls_dict: dict[str, Any] = dict( utility_type_to_cls_dict: dict[str, Any] = dict( inspect.getmembers(utilities, inspect.isclass) ) +utility_type_to_cls_dict["SQLDatabase"] = SQLDatabase diff --git a/src/backend/langflow/interface/importing/utils.py b/src/backend/langflow/interface/importing/utils.py index c426eaf85..e303da0eb 100644 --- a/src/backend/langflow/interface/importing/utils.py +++ b/src/backend/langflow/interface/importing/utils.py @@ -44,6 +44,7 @@ def import_by_type(_type: str, name: str) -> Any: "vectorstores": import_vectorstore, "documentloaders": import_documentloader, "textsplitters": import_textsplitter, + "utilities": import_utility, } if _type == "llms": key = "chat" if "chat" in name.lower() else "llm" @@ -131,10 +132,16 @@ def import_vectorstore(vectorstore: str) -> Any: def import_documentloader(documentloader: str) -> Any: """Import documentloader from documentloader name""" - return import_class(f"langchain.document_loaders.{documentloader}") def import_textsplitter(textsplitter: str) -> Any: """Import textsplitter from textsplitter name""" return import_class(f"langchain.text_splitter.{textsplitter}") + + +def import_utility(utility: str) -> Any: + """Import utility from utility name""" + if utility == "SQLDatabase": + return import_class(f"langchain.sql_database.{utility}") + return import_class(f"langchain.utilities.{utility}") diff --git a/src/backend/langflow/interface/loading.py b/src/backend/langflow/interface/loading.py index 11db47ee6..0fa1c5f2e 100644 --- a/src/backend/langflow/interface/loading.py +++ b/src/backend/langflow/interface/loading.py @@ -75,6 +75,9 @@ def instantiate_class(node_type: str, base_type: str, params: Dict) -> Any: documents = params.pop("documents") text_splitter = class_object(**params) return text_splitter.split_documents(documents) + elif base_type == "utilities": + if node_type == "SQLDatabase": + return class_object.from_uri(params.pop("uri")) return class_object(**params) diff --git a/src/backend/langflow/interface/tools/base.py b/src/backend/langflow/interface/tools/base.py index f236831d8..5fd0c72f0 100644 --- a/src/backend/langflow/interface/tools/base.py +++ b/src/backend/langflow/interface/tools/base.py @@ -1,4 +1,3 @@ -import inspect from typing import Dict, List, Optional from langchain.agents.load_tools import ( diff --git a/src/backend/langflow/interface/utilities/base.py b/src/backend/langflow/interface/utilities/base.py index a56e8bce8..e60e344ad 100644 --- a/src/backend/langflow/interface/utilities/base.py +++ b/src/backend/langflow/interface/utilities/base.py @@ -1,5 +1,6 @@ from typing import Dict, List, Optional +from langflow.custom.customs import get_custom_nodes from langflow.interface.base import LangChainTypeCreator from langflow.interface.custom_lists import utility_type_to_cls_dict from langflow.settings import settings @@ -17,6 +18,8 @@ class UtilityCreator(LangChainTypeCreator): def get_signature(self, name: str) -> Optional[Dict]: """Get the signature of a utility.""" try: + if name in get_custom_nodes(self.type_name).keys(): + return get_custom_nodes(self.type_name)[name] return build_template_from_class(name, utility_type_to_cls_dict) except ValueError as exc: raise ValueError(f"Utility {name} not found") from exc diff --git a/src/backend/langflow/template/nodes.py b/src/backend/langflow/template/nodes.py index 6ac026e59..f2e8bd94f 100644 --- a/src/backend/langflow/template/nodes.py +++ b/src/backend/langflow/template/nodes.py @@ -256,6 +256,29 @@ class CSVAgentNode(FrontendNode): return super().to_dict() +class SQLDatabaseNode(FrontendNode): + name: str = "SQLDatabase" + template: Template = Template( + type_name="sql_database", + fields=[ + TemplateField( + field_type="str", + required=True, + is_list=False, + show=True, + multiline=False, + value="", + name="uri", + ), + ], + ) + description: str = """SQLAlchemy wrapper around a database.""" + base_classes: list[str] = ["SQLDatabase"] + + def to_dict(self): + return super().to_dict() + + class VectorStoreAgentNode(FrontendNode): name: str = "VectorStoreAgent" template: Template = Template(