refac: implement Field in tool signature

This commit is contained in:
Gabriel Almeida 2023-03-27 17:34:38 -03:00
commit ea2c8ca985
2 changed files with 51 additions and 42 deletions

View file

@ -46,7 +46,9 @@ def instantiate_class(node_type: str, base_type: str, params: Dict) -> Any:
return util.eval_function(function_string)
raise ValueError("Function should be a string")
else:
return ZeroShotAgent.create_prompt(**params, tools=[])
if "tools" not in params:
params["tools"] = []
return ZeroShotAgent.create_prompt(**params)
def load_flow_from_json(path: str):

View file

@ -14,6 +14,7 @@ from langflow.interface.custom_lists import (
llm_type_to_cls_dict,
memory_type_to_cls_dict,
)
from langflow.node.template import Field, Template
from langflow.utils import util
from langflow.utils.constants import CUSTOM_TOOLS
@ -85,7 +86,7 @@ def get_tool_signature(name: str):
NODE_INPUTS = ["llm", "func"]
base_classes = ["Tool"]
all_tools = {}
all_tool_names = get_all_tool_names() + list(CUSTOM_TOOLS.keys())
all_tool_names: list[str] = get_all_tool_names() + list(CUSTOM_TOOLS.keys())
for tool in all_tool_names:
if tool_params := util.get_tool_params(util.get_tool_by_name(tool)):
tool_name = tool_params.get("name") or str(tool)
@ -96,34 +97,33 @@ def get_tool_signature(name: str):
raise ValueError("Tool not found")
type_dict = {
"str": {
"type": "str",
"required": True,
"list": False,
"show": True,
"placeholder": "",
"value": "",
},
"llm": {"type": "BaseLLM", "required": True, "list": False, "show": True},
"func": {
"type": "function",
"required": True,
"list": False,
"show": True,
"value": "",
"multiline": True,
},
"code": {
"type": "str",
"required": True,
"list": False,
"show": True,
"value": "",
"multiline": True,
},
"str": Field(
field_type="str",
required=True,
is_list=False,
show=True,
placeholder="",
value="",
),
"llm": Field(field_type="BaseLLM", required=True, is_list=False, show=True),
"func": Field(
field_type="function",
required=True,
is_list=False,
show=True,
multiline=True,
),
"code": Field(
field_type="str",
required=True,
is_list=False,
show=True,
value="",
multiline=True,
),
}
tool_type = all_tools[name]["type"]
tool_type: str = all_tools[name]["type"] # type: ignore
if tool_type in _BASE_TOOLS:
params = []
@ -139,27 +139,34 @@ def get_tool_signature(name: str):
params = ["name", "description", "func"]
elif tool_type in CUSTOM_TOOLS:
# Get custom tool params
params = all_tools[name]["params"]
params = all_tools[name]["params"] # type: ignore
base_classes = ["function"]
if node := customs.get_custom_nodes("tools").get(tool_type):
return node
else:
params = []
template = {
param: (
type_dict[param].copy() if param in NODE_INPUTS else type_dict["str"].copy()
)
for param in params
}
# Copy the field and add the name
fields = []
for param in params:
if param in NODE_INPUTS:
field = type_dict[param].copy()
else:
field = type_dict["str"].copy()
field.name = param
if param == "aiosession":
field.show = False
field.required = False
fields.append(field)
# Remove required from aiosession
if "aiosession" in template.keys():
template["aiosession"]["required"] = False
template["aiosession"]["show"] = False
template = Template(fields=fields, type_name=tool_type)
template["_type"] = tool_type # type: ignore
tool_params = util.get_tool_params(util.get_tool_by_name(tool_type))
if tool_params is None:
tool_params = {}
return {
"template": util.format_dict(template),
**util.get_tool_params(util.get_tool_by_name(tool_type)),
"template": util.format_dict(template.to_dict()),
**tool_params,
"base_classes": base_classes,
}