feat: small changes to test func

This commit is contained in:
Gabriel Almeida 2023-03-26 01:14:26 -03:00
commit f70f9339a2
2 changed files with 71 additions and 8 deletions

View file

@ -28,8 +28,9 @@ def get_chain_signature(name: str):
"""Get the chain type by signature."""
try:
return util.build_template_from_function(
name, chains.loading.type_to_loader_dict
name, chains.loading.type_to_loader_dict, add_function=True
)
except ValueError as exc:
raise ValueError("Chain not found") from exc
@ -37,7 +38,9 @@ def get_chain_signature(name: str):
def get_agent_signature(name: str):
"""Get the signature of an agent."""
try:
return util.build_template_from_class(name, agents.loading.AGENT_TO_CLASS)
return util.build_template_from_class(
name, agents.loading.AGENT_TO_CLASS, add_function=True
)
except ValueError as exc:
raise ValueError("Agent not found") from exc
@ -65,11 +68,13 @@ def get_llm_signature(name: str):
def get_tool_signature(name: str):
"""Get the signature of a tool."""
NODE_INPUTS = ["llm", "func"]
all_tools = {}
for tool in get_all_tool_names():
if tool_params := util.get_tool_params(util.get_tools_dict(tool)):
all_tools[tool_params["name"]] = tool
all_tools["BaseTool"] = "BaseTool"
# Raise error if name is not in tools
if name not in all_tools.keys():
raise ValueError("Tool not found")
@ -84,6 +89,14 @@ def get_tool_signature(name: str):
"value": "",
},
"llm": {"type": "BaseLLM", "required": True, "list": False, "show": True},
"func": {
"type": "function",
"required": True,
"list": False,
"show": True,
"value": "",
"multiline": True,
},
}
tool_type = all_tools[name]
@ -98,11 +111,15 @@ def get_tool_signature(name: str):
elif tool_type in _EXTRA_OPTIONAL_TOOLS:
_, extra_keys = _EXTRA_OPTIONAL_TOOLS[tool_type]
params = extra_keys
elif tool_type == "BaseTool":
params = ["name", "description", "func"]
else:
params = []
template = {
param: (type_dict[param].copy() if param == "llm" else type_dict["str"].copy())
param: (
type_dict[param].copy() if param in NODE_INPUTS else type_dict["str"].copy()
)
for param in params
}

View file

@ -11,10 +11,16 @@ from langchain.agents.load_tools import (
)
from typing import Optional, Dict
from langchain.agents.tools import Tool
from langchain.tools import BaseTool
from langflow.utils import constants
def build_template_from_function(name: str, type_to_loader_dict: Dict):
def build_template_from_function(
name: str, type_to_loader_dict: Dict, add_function: bool = False
):
classes = [
item.__annotations__["return"].__name__ for item in type_to_loader_dict.values()
]
@ -52,6 +58,11 @@ def build_template_from_function(name: str, type_to_loader_dict: Dict):
if class_field_items in docs["Attributes"]
else ""
)
# Adding function to base classes to allow
# the output to be a function
base_classes = get_base_classes(_class)
if add_function:
base_classes.append("function")
return {
"template": format_dict(variables, name),
@ -60,7 +71,9 @@ def build_template_from_function(name: str, type_to_loader_dict: Dict):
}
def build_template_from_class(name: str, type_to_cls_dict: Dict):
def build_template_from_class(
name: str, type_to_cls_dict: Dict, add_function: bool = False
):
classes = [item.__name__ for item in type_to_cls_dict.values()]
# Raise error if name is not in chains
@ -96,15 +109,22 @@ def build_template_from_class(name: str, type_to_cls_dict: Dict):
if class_field_items in docs["Attributes"]
else ""
)
base_classes = get_base_classes(_class)
# Adding function to base classes to allow
# the output to be a function
if add_function:
base_classes.append("function")
return {
"template": format_dict(variables, name),
"description": docs["Description"],
"base_classes": get_base_classes(_class),
"base_classes": base_classes,
}
def get_base_classes(cls):
"""Get the base classes of a class.
These are used to determine the output of the nodes.
"""
bases = cls.__bases__
if not bases:
return []
@ -127,6 +147,30 @@ def get_default_factory(module: str, function: str):
return None
class GenericTool(Tool):
"""Base class for all tools."""
def default_func(self, **kwargs):
"""Default function for the tool."""
return "Default function"
def __init__(
self,
name: str = "Tool name",
description: str = "Tool description",
func: callable = None,
):
"""Initialize the tool."""
super().__init__(name=name, description=description, func=func)
# from langchain.llms.base import BaseLLM
def get_base_tool(name, description, func: callable) -> BaseTool:
return GenericTool(func=func, name="Generic Tool", description="Bacon")
def get_tools_dict(name: Optional[str] = None):
"""Get the tools dictionary."""
tools = {
@ -135,11 +179,13 @@ def get_tools_dict(name: Optional[str] = None):
**{k: v[0] for k, v in _EXTRA_LLM_TOOLS.items()}, # type: ignore
**{k: v[0] for k, v in _EXTRA_OPTIONAL_TOOLS.items()},
}
tools.update({"BaseTool": get_base_tool})
return tools[name] if name else tools
def get_tool_params(func, **kwargs):
# Parse the function code into an abstract syntax tree
tree = ast.parse(inspect.getsource(func))
# Iterate over the statements in the abstract syntax tree
@ -266,7 +312,7 @@ def format_dict(d, name: Optional[str] = None):
_type = _type.replace("Mapping", "dict")
# Change type from str to Tool
value["type"] = "Tool" if key == "allowed_tools" else _type
value["type"] = "Tool" if key in ["allowed_tools", "func"] else _type
# Show or not field
value["show"] = bool(