🔧 chore(config): add VectorStoreToolkit to toolkits list
🐛 fix(base.py): remove deepcopy for VectorStore and VectorStoreRouter agents 🐛 fix(nodes.py): remove deepcopy for VectorStore and VectorStoreRouter agents 🔧 chore(loading.py): comment out unused code for loading toolkits 🐛 fix(toolkits/base.py): add Tool to base_classes in get_signature method The changes to the config file add the VectorStoreToolkit to the list of toolkits. The deepcopy for VectorStore and VectorStoreRouter agents was causing issues, so it was removed from the base.py and nodes.py files. The loading.py file had some unused code for loading toolkits, so it was commented out. Finally, the base.py file had a bug where the Tool class was not being added to the base_classes list in the get_signature method, so it was added.
This commit is contained in:
parent
09cb59d9d6
commit
a9c7fc0a69
5 changed files with 17 additions and 25 deletions
|
|
@ -74,6 +74,7 @@ toolkits:
|
|||
- JsonToolkit
|
||||
- VectorStoreInfo
|
||||
- VectorStoreRouterToolkit
|
||||
- VectorStoreToolkit
|
||||
tools:
|
||||
- Search
|
||||
- PAL-MATH
|
||||
|
|
|
|||
|
|
@ -212,19 +212,7 @@ class Node:
|
|||
if not self._built or force:
|
||||
self._build()
|
||||
|
||||
#! Deepcopy is breaking for vectorstores
|
||||
if self.base_type in [
|
||||
"vectorstores",
|
||||
"VectorStoreRouterAgent",
|
||||
"VectorStoreAgent",
|
||||
"VectorStoreInfo",
|
||||
] or self.node_type in [
|
||||
"VectorStoreInfo",
|
||||
"VectorStoreRouterToolkit",
|
||||
"SQLDatabase",
|
||||
]:
|
||||
return self._built_object
|
||||
return deepcopy(self._built_object)
|
||||
return self._built_object
|
||||
|
||||
def add_edge(self, edge: "Edge") -> None:
|
||||
self.edges.append(edge)
|
||||
|
|
|
|||
|
|
@ -14,7 +14,7 @@ class AgentNode(Node):
|
|||
def _set_tools_and_chains(self) -> None:
|
||||
for edge in self.edges:
|
||||
source_node = edge.source
|
||||
if isinstance(source_node, ToolNode):
|
||||
if isinstance(source_node, (ToolNode, ToolkitNode)):
|
||||
self.tools.append(source_node)
|
||||
elif isinstance(source_node, ChainNode):
|
||||
self.chains.append(source_node)
|
||||
|
|
@ -32,9 +32,6 @@ class AgentNode(Node):
|
|||
|
||||
self._build()
|
||||
|
||||
#! Cannot deepcopy VectorStore, VectorStoreRouter, or SQL agents
|
||||
if self.node_type in ["VectorStoreAgent", "VectorStoreRouterAgent", "SQLAgent"]:
|
||||
return self._built_object
|
||||
return self._built_object
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -101,8 +101,11 @@ def instantiate_tool(node_type, class_object, params):
|
|||
|
||||
def instantiate_toolkit(node_type, class_object, params):
|
||||
loaded_toolkit = class_object(**params)
|
||||
if toolkits_creator.has_create_function(node_type):
|
||||
return load_toolkits_executor(node_type, loaded_toolkit, params)
|
||||
# Commenting this out for now to use toolkits as normal tools
|
||||
# if toolkits_creator.has_create_function(node_type):
|
||||
# return load_toolkits_executor(node_type, loaded_toolkit, params)
|
||||
if isinstance(loaded_toolkit, BaseToolkit):
|
||||
return loaded_toolkit.get_tools()
|
||||
return loaded_toolkit
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -42,24 +42,27 @@ class ToolkitCreator(LangChainTypeCreator):
|
|||
|
||||
def get_signature(self, name: str) -> Optional[Dict]:
|
||||
try:
|
||||
return build_template_from_class(name, self.type_to_loader_dict)
|
||||
template = build_template_from_class(name, self.type_to_loader_dict)
|
||||
# add Tool to base_classes
|
||||
if template:
|
||||
template["base_classes"].append("Tool")
|
||||
return template
|
||||
except ValueError as exc:
|
||||
raise ValueError("Prompt not found") from exc
|
||||
raise ValueError("Toolkit not found") from exc
|
||||
except AttributeError as exc:
|
||||
logger.error(f"Prompt {name} not loaded: {exc}")
|
||||
logger.error(f"Toolkit {name} not loaded: {exc}")
|
||||
return None
|
||||
|
||||
def to_list(self) -> List[str]:
|
||||
return list(self.type_to_loader_dict.keys())
|
||||
|
||||
def get_create_function(self, name: str) -> Callable:
|
||||
if loader_name := self.create_functions.get(name, None):
|
||||
# import loader
|
||||
if loader_name := self.create_functions.get(name):
|
||||
return import_module(
|
||||
f"from langchain.agents.agent_toolkits import {loader_name[0]}"
|
||||
)
|
||||
else:
|
||||
raise ValueError("Loader not found")
|
||||
raise ValueError("Toolkit not found")
|
||||
|
||||
def has_create_function(self, name: str) -> bool:
|
||||
# check if the function list is not empty
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue