parent
843ae8efc5
commit
46e76c8ca4
10 changed files with 53 additions and 3 deletions
|
|
@ -6,6 +6,7 @@ chains:
|
|||
- SeriesCharacterChain
|
||||
- MidJourneyPromptChain
|
||||
- TimeTravelGuideChain
|
||||
- SQLDatabaseChain
|
||||
|
||||
agents:
|
||||
- ZeroShotAgent
|
||||
|
|
@ -122,5 +123,6 @@ utilities:
|
|||
- WikipediaAPIWrapper
|
||||
- WolframAlphaAPIWrapper
|
||||
# - ZapierNLAWrapper
|
||||
- SQLDatabase
|
||||
|
||||
dev: false
|
||||
|
|
|
|||
|
|
@ -12,6 +12,9 @@ CUSTOM_NODES = {
|
|||
"VectorStoreRouterAgent": nodes.VectorStoreRouterAgentNode(),
|
||||
"SQLAgent": nodes.SQLAgentNode(),
|
||||
},
|
||||
"utilities": {
|
||||
"SQLDatabase": nodes.SQLDatabaseNode(),
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -202,7 +202,11 @@ class Node:
|
|||
"VectorStoreRouterAgent",
|
||||
"VectorStoreAgent",
|
||||
"VectorStoreInfo",
|
||||
] or self.node_type in ["VectorStoreInfo", "VectorStoreRouterToolkit"]:
|
||||
] or self.node_type in [
|
||||
"VectorStoreInfo",
|
||||
"VectorStoreRouterToolkit",
|
||||
"SQLDatabase",
|
||||
]:
|
||||
return self._built_object
|
||||
return deepcopy(self._built_object)
|
||||
|
||||
|
|
|
|||
|
|
@ -101,6 +101,10 @@ class ChainNode(Node):
|
|||
self.params[key] = value.build(tools=tools, force=force)
|
||||
|
||||
self._build()
|
||||
|
||||
#! Cannot deepcopy SQLDatabaseChain
|
||||
if self.node_type in ["SQLDatabaseChain"]:
|
||||
return self._built_object
|
||||
return deepcopy(self._built_object)
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -14,6 +14,7 @@ from langchain import (
|
|||
)
|
||||
from langchain.agents import agent_toolkits
|
||||
from langchain.chat_models import ChatOpenAI
|
||||
from langchain.sql_database import SQLDatabase
|
||||
|
||||
from langflow.interface.importing.utils import import_class
|
||||
|
||||
|
|
@ -82,3 +83,4 @@ textsplitter_type_to_cls_dict: dict[str, Any] = dict(
|
|||
utility_type_to_cls_dict: dict[str, Any] = dict(
|
||||
inspect.getmembers(utilities, inspect.isclass)
|
||||
)
|
||||
utility_type_to_cls_dict["SQLDatabase"] = SQLDatabase
|
||||
|
|
|
|||
|
|
@ -44,6 +44,7 @@ def import_by_type(_type: str, name: str) -> Any:
|
|||
"vectorstores": import_vectorstore,
|
||||
"documentloaders": import_documentloader,
|
||||
"textsplitters": import_textsplitter,
|
||||
"utilities": import_utility,
|
||||
}
|
||||
if _type == "llms":
|
||||
key = "chat" if "chat" in name.lower() else "llm"
|
||||
|
|
@ -131,10 +132,16 @@ def import_vectorstore(vectorstore: str) -> Any:
|
|||
|
||||
def import_documentloader(documentloader: str) -> Any:
|
||||
"""Import documentloader from documentloader name"""
|
||||
|
||||
return import_class(f"langchain.document_loaders.{documentloader}")
|
||||
|
||||
|
||||
def import_textsplitter(textsplitter: str) -> Any:
|
||||
"""Import textsplitter from textsplitter name"""
|
||||
return import_class(f"langchain.text_splitter.{textsplitter}")
|
||||
|
||||
|
||||
def import_utility(utility: str) -> Any:
|
||||
"""Import utility from utility name"""
|
||||
if utility == "SQLDatabase":
|
||||
return import_class(f"langchain.sql_database.{utility}")
|
||||
return import_class(f"langchain.utilities.{utility}")
|
||||
|
|
|
|||
|
|
@ -75,6 +75,9 @@ def instantiate_class(node_type: str, base_type: str, params: Dict) -> Any:
|
|||
documents = params.pop("documents")
|
||||
text_splitter = class_object(**params)
|
||||
return text_splitter.split_documents(documents)
|
||||
elif base_type == "utilities":
|
||||
if node_type == "SQLDatabase":
|
||||
return class_object.from_uri(params.pop("uri"))
|
||||
|
||||
return class_object(**params)
|
||||
|
||||
|
|
|
|||
|
|
@ -1,4 +1,3 @@
|
|||
import inspect
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
from langchain.agents.load_tools import (
|
||||
|
|
|
|||
|
|
@ -1,5 +1,6 @@
|
|||
from typing import Dict, List, Optional
|
||||
|
||||
from langflow.custom.customs import get_custom_nodes
|
||||
from langflow.interface.base import LangChainTypeCreator
|
||||
from langflow.interface.custom_lists import utility_type_to_cls_dict
|
||||
from langflow.settings import settings
|
||||
|
|
@ -17,6 +18,8 @@ class UtilityCreator(LangChainTypeCreator):
|
|||
def get_signature(self, name: str) -> Optional[Dict]:
|
||||
"""Get the signature of a utility."""
|
||||
try:
|
||||
if name in get_custom_nodes(self.type_name).keys():
|
||||
return get_custom_nodes(self.type_name)[name]
|
||||
return build_template_from_class(name, utility_type_to_cls_dict)
|
||||
except ValueError as exc:
|
||||
raise ValueError(f"Utility {name} not found") from exc
|
||||
|
|
|
|||
|
|
@ -256,6 +256,29 @@ class CSVAgentNode(FrontendNode):
|
|||
return super().to_dict()
|
||||
|
||||
|
||||
class SQLDatabaseNode(FrontendNode):
|
||||
name: str = "SQLDatabase"
|
||||
template: Template = Template(
|
||||
type_name="sql_database",
|
||||
fields=[
|
||||
TemplateField(
|
||||
field_type="str",
|
||||
required=True,
|
||||
is_list=False,
|
||||
show=True,
|
||||
multiline=False,
|
||||
value="",
|
||||
name="uri",
|
||||
),
|
||||
],
|
||||
)
|
||||
description: str = """SQLAlchemy wrapper around a database."""
|
||||
base_classes: list[str] = ["SQLDatabase"]
|
||||
|
||||
def to_dict(self):
|
||||
return super().to_dict()
|
||||
|
||||
|
||||
class VectorStoreAgentNode(FrontendNode):
|
||||
name: str = "VectorStoreAgent"
|
||||
template: Template = Template(
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue