From ea2c8ca985981f7009802a568f4ade2033abcfc3 Mon Sep 17 00:00:00 2001 From: Gabriel Almeida Date: Mon, 27 Mar 2023 17:34:38 -0300 Subject: [PATCH] refac: implement Field in tool signature --- src/backend/langflow/interface/loading.py | 4 +- src/backend/langflow/interface/signature.py | 89 +++++++++++---------- 2 files changed, 51 insertions(+), 42 deletions(-) diff --git a/src/backend/langflow/interface/loading.py b/src/backend/langflow/interface/loading.py index b376ac87f..ee4764977 100644 --- a/src/backend/langflow/interface/loading.py +++ b/src/backend/langflow/interface/loading.py @@ -46,7 +46,9 @@ def instantiate_class(node_type: str, base_type: str, params: Dict) -> Any: return util.eval_function(function_string) raise ValueError("Function should be a string") else: - return ZeroShotAgent.create_prompt(**params, tools=[]) + if "tools" not in params: + params["tools"] = [] + return ZeroShotAgent.create_prompt(**params) def load_flow_from_json(path: str): diff --git a/src/backend/langflow/interface/signature.py b/src/backend/langflow/interface/signature.py index b1414a88e..b3a238cdb 100644 --- a/src/backend/langflow/interface/signature.py +++ b/src/backend/langflow/interface/signature.py @@ -14,6 +14,7 @@ from langflow.interface.custom_lists import ( llm_type_to_cls_dict, memory_type_to_cls_dict, ) +from langflow.node.template import Field, Template from langflow.utils import util from langflow.utils.constants import CUSTOM_TOOLS @@ -85,7 +86,7 @@ def get_tool_signature(name: str): NODE_INPUTS = ["llm", "func"] base_classes = ["Tool"] all_tools = {} - all_tool_names = get_all_tool_names() + list(CUSTOM_TOOLS.keys()) + all_tool_names: list[str] = get_all_tool_names() + list(CUSTOM_TOOLS.keys()) for tool in all_tool_names: if tool_params := util.get_tool_params(util.get_tool_by_name(tool)): tool_name = tool_params.get("name") or str(tool) @@ -96,34 +97,33 @@ def get_tool_signature(name: str): raise ValueError("Tool not found") type_dict = { - "str": { - "type": "str", - "required": True, - "list": False, - "show": True, - "placeholder": "", - "value": "", - }, - "llm": {"type": "BaseLLM", "required": True, "list": False, "show": True}, - "func": { - "type": "function", - "required": True, - "list": False, - "show": True, - "value": "", - "multiline": True, - }, - "code": { - "type": "str", - "required": True, - "list": False, - "show": True, - "value": "", - "multiline": True, - }, + "str": Field( + field_type="str", + required=True, + is_list=False, + show=True, + placeholder="", + value="", + ), + "llm": Field(field_type="BaseLLM", required=True, is_list=False, show=True), + "func": Field( + field_type="function", + required=True, + is_list=False, + show=True, + multiline=True, + ), + "code": Field( + field_type="str", + required=True, + is_list=False, + show=True, + value="", + multiline=True, + ), } - tool_type = all_tools[name]["type"] + tool_type: str = all_tools[name]["type"] # type: ignore if tool_type in _BASE_TOOLS: params = [] @@ -139,27 +139,34 @@ def get_tool_signature(name: str): params = ["name", "description", "func"] elif tool_type in CUSTOM_TOOLS: # Get custom tool params - params = all_tools[name]["params"] + params = all_tools[name]["params"] # type: ignore base_classes = ["function"] + if node := customs.get_custom_nodes("tools").get(tool_type): + return node else: params = [] - template = { - param: ( - type_dict[param].copy() if param in NODE_INPUTS else type_dict["str"].copy() - ) - for param in params - } + # Copy the field and add the name + fields = [] + for param in params: + if param in NODE_INPUTS: + field = type_dict[param].copy() + else: + field = type_dict["str"].copy() + field.name = param + if param == "aiosession": + field.show = False + field.required = False + fields.append(field) - # Remove required from aiosession - if "aiosession" in template.keys(): - template["aiosession"]["required"] = False - template["aiosession"]["show"] = False + template = Template(fields=fields, type_name=tool_type) - template["_type"] = tool_type # type: ignore + tool_params = util.get_tool_params(util.get_tool_by_name(tool_type)) + if tool_params is None: + tool_params = {} return { - "template": util.format_dict(template), - **util.get_tool_params(util.get_tool_by_name(tool_type)), + "template": util.format_dict(template.to_dict()), + **tool_params, "base_classes": base_classes, }