diff --git a/src/backend/langflow/interface/signature.py b/src/backend/langflow/interface/signature.py index baa2956dc..a915a2c3b 100644 --- a/src/backend/langflow/interface/signature.py +++ b/src/backend/langflow/interface/signature.py @@ -28,8 +28,9 @@ def get_chain_signature(name: str): """Get the chain type by signature.""" try: return util.build_template_from_function( - name, chains.loading.type_to_loader_dict + name, chains.loading.type_to_loader_dict, add_function=True ) + except ValueError as exc: raise ValueError("Chain not found") from exc @@ -37,7 +38,9 @@ def get_chain_signature(name: str): def get_agent_signature(name: str): """Get the signature of an agent.""" try: - return util.build_template_from_class(name, agents.loading.AGENT_TO_CLASS) + return util.build_template_from_class( + name, agents.loading.AGENT_TO_CLASS, add_function=True + ) except ValueError as exc: raise ValueError("Agent not found") from exc @@ -65,11 +68,13 @@ def get_llm_signature(name: str): def get_tool_signature(name: str): """Get the signature of a tool.""" + NODE_INPUTS = ["llm", "func"] all_tools = {} for tool in get_all_tool_names(): if tool_params := util.get_tool_params(util.get_tools_dict(tool)): all_tools[tool_params["name"]] = tool + all_tools["BaseTool"] = "BaseTool" # Raise error if name is not in tools if name not in all_tools.keys(): raise ValueError("Tool not found") @@ -84,6 +89,14 @@ def get_tool_signature(name: str): "value": "", }, "llm": {"type": "BaseLLM", "required": True, "list": False, "show": True}, + "func": { + "type": "function", + "required": True, + "list": False, + "show": True, + "value": "", + "multiline": True, + }, } tool_type = all_tools[name] @@ -98,11 +111,15 @@ def get_tool_signature(name: str): elif tool_type in _EXTRA_OPTIONAL_TOOLS: _, extra_keys = _EXTRA_OPTIONAL_TOOLS[tool_type] params = extra_keys + elif tool_type == "BaseTool": + params = ["name", "description", "func"] else: params = [] template = { - param: (type_dict[param].copy() if param == "llm" else type_dict["str"].copy()) + param: ( + type_dict[param].copy() if param in NODE_INPUTS else type_dict["str"].copy() + ) for param in params } diff --git a/src/backend/langflow/utils/util.py b/src/backend/langflow/utils/util.py index dae450f0c..93293155a 100644 --- a/src/backend/langflow/utils/util.py +++ b/src/backend/langflow/utils/util.py @@ -11,10 +11,16 @@ from langchain.agents.load_tools import ( ) from typing import Optional, Dict + +from langchain.agents.tools import Tool +from langchain.tools import BaseTool + from langflow.utils import constants -def build_template_from_function(name: str, type_to_loader_dict: Dict): +def build_template_from_function( + name: str, type_to_loader_dict: Dict, add_function: bool = False +): classes = [ item.__annotations__["return"].__name__ for item in type_to_loader_dict.values() ] @@ -52,6 +58,11 @@ def build_template_from_function(name: str, type_to_loader_dict: Dict): if class_field_items in docs["Attributes"] else "" ) + # Adding function to base classes to allow + # the output to be a function + base_classes = get_base_classes(_class) + if add_function: + base_classes.append("function") return { "template": format_dict(variables, name), @@ -60,7 +71,9 @@ def build_template_from_function(name: str, type_to_loader_dict: Dict): } -def build_template_from_class(name: str, type_to_cls_dict: Dict): +def build_template_from_class( + name: str, type_to_cls_dict: Dict, add_function: bool = False +): classes = [item.__name__ for item in type_to_cls_dict.values()] # Raise error if name is not in chains @@ -96,15 +109,22 @@ def build_template_from_class(name: str, type_to_cls_dict: Dict): if class_field_items in docs["Attributes"] else "" ) - + base_classes = get_base_classes(_class) + # Adding function to base classes to allow + # the output to be a function + if add_function: + base_classes.append("function") return { "template": format_dict(variables, name), "description": docs["Description"], - "base_classes": get_base_classes(_class), + "base_classes": base_classes, } def get_base_classes(cls): + """Get the base classes of a class. + These are used to determine the output of the nodes. + """ bases = cls.__bases__ if not bases: return [] @@ -127,6 +147,30 @@ def get_default_factory(module: str, function: str): return None +class GenericTool(Tool): + """Base class for all tools.""" + + def default_func(self, **kwargs): + """Default function for the tool.""" + return "Default function" + + def __init__( + self, + name: str = "Tool name", + description: str = "Tool description", + func: callable = None, + ): + """Initialize the tool.""" + super().__init__(name=name, description=description, func=func) + + +# from langchain.llms.base import BaseLLM + + +def get_base_tool(name, description, func: callable) -> BaseTool: + return GenericTool(func=func, name="Generic Tool", description="Bacon") + + def get_tools_dict(name: Optional[str] = None): """Get the tools dictionary.""" tools = { @@ -135,11 +179,13 @@ def get_tools_dict(name: Optional[str] = None): **{k: v[0] for k, v in _EXTRA_LLM_TOOLS.items()}, # type: ignore **{k: v[0] for k, v in _EXTRA_OPTIONAL_TOOLS.items()}, } + tools.update({"BaseTool": get_base_tool}) return tools[name] if name else tools def get_tool_params(func, **kwargs): # Parse the function code into an abstract syntax tree + tree = ast.parse(inspect.getsource(func)) # Iterate over the statements in the abstract syntax tree @@ -266,7 +312,7 @@ def format_dict(d, name: Optional[str] = None): _type = _type.replace("Mapping", "dict") # Change type from str to Tool - value["type"] = "Tool" if key == "allowed_tools" else _type + value["type"] = "Tool" if key in ["allowed_tools", "func"] else _type # Show or not field value["show"] = bool(