feat: small changes to test func
This commit is contained in:
parent
cdb01f2a39
commit
f70f9339a2
2 changed files with 71 additions and 8 deletions
|
|
@ -28,8 +28,9 @@ def get_chain_signature(name: str):
|
|||
"""Get the chain type by signature."""
|
||||
try:
|
||||
return util.build_template_from_function(
|
||||
name, chains.loading.type_to_loader_dict
|
||||
name, chains.loading.type_to_loader_dict, add_function=True
|
||||
)
|
||||
|
||||
except ValueError as exc:
|
||||
raise ValueError("Chain not found") from exc
|
||||
|
||||
|
|
@ -37,7 +38,9 @@ def get_chain_signature(name: str):
|
|||
def get_agent_signature(name: str):
|
||||
"""Get the signature of an agent."""
|
||||
try:
|
||||
return util.build_template_from_class(name, agents.loading.AGENT_TO_CLASS)
|
||||
return util.build_template_from_class(
|
||||
name, agents.loading.AGENT_TO_CLASS, add_function=True
|
||||
)
|
||||
except ValueError as exc:
|
||||
raise ValueError("Agent not found") from exc
|
||||
|
||||
|
|
@ -65,11 +68,13 @@ def get_llm_signature(name: str):
|
|||
def get_tool_signature(name: str):
|
||||
"""Get the signature of a tool."""
|
||||
|
||||
NODE_INPUTS = ["llm", "func"]
|
||||
all_tools = {}
|
||||
for tool in get_all_tool_names():
|
||||
if tool_params := util.get_tool_params(util.get_tools_dict(tool)):
|
||||
all_tools[tool_params["name"]] = tool
|
||||
|
||||
all_tools["BaseTool"] = "BaseTool"
|
||||
# Raise error if name is not in tools
|
||||
if name not in all_tools.keys():
|
||||
raise ValueError("Tool not found")
|
||||
|
|
@ -84,6 +89,14 @@ def get_tool_signature(name: str):
|
|||
"value": "",
|
||||
},
|
||||
"llm": {"type": "BaseLLM", "required": True, "list": False, "show": True},
|
||||
"func": {
|
||||
"type": "function",
|
||||
"required": True,
|
||||
"list": False,
|
||||
"show": True,
|
||||
"value": "",
|
||||
"multiline": True,
|
||||
},
|
||||
}
|
||||
|
||||
tool_type = all_tools[name]
|
||||
|
|
@ -98,11 +111,15 @@ def get_tool_signature(name: str):
|
|||
elif tool_type in _EXTRA_OPTIONAL_TOOLS:
|
||||
_, extra_keys = _EXTRA_OPTIONAL_TOOLS[tool_type]
|
||||
params = extra_keys
|
||||
elif tool_type == "BaseTool":
|
||||
params = ["name", "description", "func"]
|
||||
else:
|
||||
params = []
|
||||
|
||||
template = {
|
||||
param: (type_dict[param].copy() if param == "llm" else type_dict["str"].copy())
|
||||
param: (
|
||||
type_dict[param].copy() if param in NODE_INPUTS else type_dict["str"].copy()
|
||||
)
|
||||
for param in params
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -11,10 +11,16 @@ from langchain.agents.load_tools import (
|
|||
)
|
||||
from typing import Optional, Dict
|
||||
|
||||
|
||||
from langchain.agents.tools import Tool
|
||||
from langchain.tools import BaseTool
|
||||
|
||||
from langflow.utils import constants
|
||||
|
||||
|
||||
def build_template_from_function(name: str, type_to_loader_dict: Dict):
|
||||
def build_template_from_function(
|
||||
name: str, type_to_loader_dict: Dict, add_function: bool = False
|
||||
):
|
||||
classes = [
|
||||
item.__annotations__["return"].__name__ for item in type_to_loader_dict.values()
|
||||
]
|
||||
|
|
@ -52,6 +58,11 @@ def build_template_from_function(name: str, type_to_loader_dict: Dict):
|
|||
if class_field_items in docs["Attributes"]
|
||||
else ""
|
||||
)
|
||||
# Adding function to base classes to allow
|
||||
# the output to be a function
|
||||
base_classes = get_base_classes(_class)
|
||||
if add_function:
|
||||
base_classes.append("function")
|
||||
|
||||
return {
|
||||
"template": format_dict(variables, name),
|
||||
|
|
@ -60,7 +71,9 @@ def build_template_from_function(name: str, type_to_loader_dict: Dict):
|
|||
}
|
||||
|
||||
|
||||
def build_template_from_class(name: str, type_to_cls_dict: Dict):
|
||||
def build_template_from_class(
|
||||
name: str, type_to_cls_dict: Dict, add_function: bool = False
|
||||
):
|
||||
classes = [item.__name__ for item in type_to_cls_dict.values()]
|
||||
|
||||
# Raise error if name is not in chains
|
||||
|
|
@ -96,15 +109,22 @@ def build_template_from_class(name: str, type_to_cls_dict: Dict):
|
|||
if class_field_items in docs["Attributes"]
|
||||
else ""
|
||||
)
|
||||
|
||||
base_classes = get_base_classes(_class)
|
||||
# Adding function to base classes to allow
|
||||
# the output to be a function
|
||||
if add_function:
|
||||
base_classes.append("function")
|
||||
return {
|
||||
"template": format_dict(variables, name),
|
||||
"description": docs["Description"],
|
||||
"base_classes": get_base_classes(_class),
|
||||
"base_classes": base_classes,
|
||||
}
|
||||
|
||||
|
||||
def get_base_classes(cls):
|
||||
"""Get the base classes of a class.
|
||||
These are used to determine the output of the nodes.
|
||||
"""
|
||||
bases = cls.__bases__
|
||||
if not bases:
|
||||
return []
|
||||
|
|
@ -127,6 +147,30 @@ def get_default_factory(module: str, function: str):
|
|||
return None
|
||||
|
||||
|
||||
class GenericTool(Tool):
|
||||
"""Base class for all tools."""
|
||||
|
||||
def default_func(self, **kwargs):
|
||||
"""Default function for the tool."""
|
||||
return "Default function"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
name: str = "Tool name",
|
||||
description: str = "Tool description",
|
||||
func: callable = None,
|
||||
):
|
||||
"""Initialize the tool."""
|
||||
super().__init__(name=name, description=description, func=func)
|
||||
|
||||
|
||||
# from langchain.llms.base import BaseLLM
|
||||
|
||||
|
||||
def get_base_tool(name, description, func: callable) -> BaseTool:
|
||||
return GenericTool(func=func, name="Generic Tool", description="Bacon")
|
||||
|
||||
|
||||
def get_tools_dict(name: Optional[str] = None):
|
||||
"""Get the tools dictionary."""
|
||||
tools = {
|
||||
|
|
@ -135,11 +179,13 @@ def get_tools_dict(name: Optional[str] = None):
|
|||
**{k: v[0] for k, v in _EXTRA_LLM_TOOLS.items()}, # type: ignore
|
||||
**{k: v[0] for k, v in _EXTRA_OPTIONAL_TOOLS.items()},
|
||||
}
|
||||
tools.update({"BaseTool": get_base_tool})
|
||||
return tools[name] if name else tools
|
||||
|
||||
|
||||
def get_tool_params(func, **kwargs):
|
||||
# Parse the function code into an abstract syntax tree
|
||||
|
||||
tree = ast.parse(inspect.getsource(func))
|
||||
|
||||
# Iterate over the statements in the abstract syntax tree
|
||||
|
|
@ -266,7 +312,7 @@ def format_dict(d, name: Optional[str] = None):
|
|||
_type = _type.replace("Mapping", "dict")
|
||||
|
||||
# Change type from str to Tool
|
||||
value["type"] = "Tool" if key == "allowed_tools" else _type
|
||||
value["type"] = "Tool" if key in ["allowed_tools", "func"] else _type
|
||||
|
||||
# Show or not field
|
||||
value["show"] = bool(
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue