refac: implement Field in tool signature
This commit is contained in:
parent
c6ebbe3517
commit
ea2c8ca985
2 changed files with 51 additions and 42 deletions
|
|
@ -46,7 +46,9 @@ def instantiate_class(node_type: str, base_type: str, params: Dict) -> Any:
|
|||
return util.eval_function(function_string)
|
||||
raise ValueError("Function should be a string")
|
||||
else:
|
||||
return ZeroShotAgent.create_prompt(**params, tools=[])
|
||||
if "tools" not in params:
|
||||
params["tools"] = []
|
||||
return ZeroShotAgent.create_prompt(**params)
|
||||
|
||||
|
||||
def load_flow_from_json(path: str):
|
||||
|
|
|
|||
|
|
@ -14,6 +14,7 @@ from langflow.interface.custom_lists import (
|
|||
llm_type_to_cls_dict,
|
||||
memory_type_to_cls_dict,
|
||||
)
|
||||
from langflow.node.template import Field, Template
|
||||
from langflow.utils import util
|
||||
from langflow.utils.constants import CUSTOM_TOOLS
|
||||
|
||||
|
|
@ -85,7 +86,7 @@ def get_tool_signature(name: str):
|
|||
NODE_INPUTS = ["llm", "func"]
|
||||
base_classes = ["Tool"]
|
||||
all_tools = {}
|
||||
all_tool_names = get_all_tool_names() + list(CUSTOM_TOOLS.keys())
|
||||
all_tool_names: list[str] = get_all_tool_names() + list(CUSTOM_TOOLS.keys())
|
||||
for tool in all_tool_names:
|
||||
if tool_params := util.get_tool_params(util.get_tool_by_name(tool)):
|
||||
tool_name = tool_params.get("name") or str(tool)
|
||||
|
|
@ -96,34 +97,33 @@ def get_tool_signature(name: str):
|
|||
raise ValueError("Tool not found")
|
||||
|
||||
type_dict = {
|
||||
"str": {
|
||||
"type": "str",
|
||||
"required": True,
|
||||
"list": False,
|
||||
"show": True,
|
||||
"placeholder": "",
|
||||
"value": "",
|
||||
},
|
||||
"llm": {"type": "BaseLLM", "required": True, "list": False, "show": True},
|
||||
"func": {
|
||||
"type": "function",
|
||||
"required": True,
|
||||
"list": False,
|
||||
"show": True,
|
||||
"value": "",
|
||||
"multiline": True,
|
||||
},
|
||||
"code": {
|
||||
"type": "str",
|
||||
"required": True,
|
||||
"list": False,
|
||||
"show": True,
|
||||
"value": "",
|
||||
"multiline": True,
|
||||
},
|
||||
"str": Field(
|
||||
field_type="str",
|
||||
required=True,
|
||||
is_list=False,
|
||||
show=True,
|
||||
placeholder="",
|
||||
value="",
|
||||
),
|
||||
"llm": Field(field_type="BaseLLM", required=True, is_list=False, show=True),
|
||||
"func": Field(
|
||||
field_type="function",
|
||||
required=True,
|
||||
is_list=False,
|
||||
show=True,
|
||||
multiline=True,
|
||||
),
|
||||
"code": Field(
|
||||
field_type="str",
|
||||
required=True,
|
||||
is_list=False,
|
||||
show=True,
|
||||
value="",
|
||||
multiline=True,
|
||||
),
|
||||
}
|
||||
|
||||
tool_type = all_tools[name]["type"]
|
||||
tool_type: str = all_tools[name]["type"] # type: ignore
|
||||
|
||||
if tool_type in _BASE_TOOLS:
|
||||
params = []
|
||||
|
|
@ -139,27 +139,34 @@ def get_tool_signature(name: str):
|
|||
params = ["name", "description", "func"]
|
||||
elif tool_type in CUSTOM_TOOLS:
|
||||
# Get custom tool params
|
||||
params = all_tools[name]["params"]
|
||||
params = all_tools[name]["params"] # type: ignore
|
||||
base_classes = ["function"]
|
||||
if node := customs.get_custom_nodes("tools").get(tool_type):
|
||||
return node
|
||||
|
||||
else:
|
||||
params = []
|
||||
|
||||
template = {
|
||||
param: (
|
||||
type_dict[param].copy() if param in NODE_INPUTS else type_dict["str"].copy()
|
||||
)
|
||||
for param in params
|
||||
}
|
||||
# Copy the field and add the name
|
||||
fields = []
|
||||
for param in params:
|
||||
if param in NODE_INPUTS:
|
||||
field = type_dict[param].copy()
|
||||
else:
|
||||
field = type_dict["str"].copy()
|
||||
field.name = param
|
||||
if param == "aiosession":
|
||||
field.show = False
|
||||
field.required = False
|
||||
fields.append(field)
|
||||
|
||||
# Remove required from aiosession
|
||||
if "aiosession" in template.keys():
|
||||
template["aiosession"]["required"] = False
|
||||
template["aiosession"]["show"] = False
|
||||
template = Template(fields=fields, type_name=tool_type)
|
||||
|
||||
template["_type"] = tool_type # type: ignore
|
||||
tool_params = util.get_tool_params(util.get_tool_by_name(tool_type))
|
||||
if tool_params is None:
|
||||
tool_params = {}
|
||||
return {
|
||||
"template": util.format_dict(template),
|
||||
**util.get_tool_params(util.get_tool_by_name(tool_type)),
|
||||
"template": util.format_dict(template.to_dict()),
|
||||
**tool_params,
|
||||
"base_classes": base_classes,
|
||||
}
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue