fix: removing dicts from inside class to stop recreating it

This commit is contained in:
Gabriel Almeida 2023-03-31 14:01:35 -03:00
commit 0858734eb0

View file

@ -2,9 +2,10 @@ from langflow.custom import customs
from langflow.interface.tools.constants import (
ALL_TOOLS_NAMES,
CUSTOM_TOOLS,
OTHER_TOOLS,
FILE_TOOLS,
)
from langflow.template.template import Field, Template
from langflow.template.base import Field
from langflow.template.base import Template
from langflow.utils import util
from langflow.settings import settings
from langflow.interface.base import LangChainTypeCreator
@ -22,6 +23,41 @@ from langflow.interface.tools.util import (
)
TOOL_INPUTS = {
"str": Field(
field_type="str",
required=True,
is_list=False,
show=True,
placeholder="",
value="",
),
"llm": Field(field_type="BaseLLM", required=True, is_list=False, show=True),
"func": Field(
field_type="function",
required=True,
is_list=False,
show=True,
multiline=True,
),
"code": Field(
field_type="str",
required=True,
is_list=False,
show=True,
value="",
multiline=True,
),
"dict_": Field(
field_type="file",
required=True,
is_list=False,
show=True,
value="",
),
}
class ToolCreator(LangChainTypeCreator):
type_name: str = "tools"
tools_dict: Dict | None = None
@ -35,7 +71,6 @@ class ToolCreator(LangChainTypeCreator):
def get_signature(self, name: str) -> Dict | None:
"""Get the signature of a tool."""
NODE_INPUTS = ["llm", "func"]
base_classes = ["Tool"]
all_tools = {}
for tool in self.type_to_loader_dict.keys():
@ -47,40 +82,6 @@ class ToolCreator(LangChainTypeCreator):
if name not in all_tools.keys():
raise ValueError("Tool not found")
type_dict = {
"str": Field(
field_type="str",
required=True,
is_list=False,
show=True,
placeholder="",
value="",
),
"llm": Field(field_type="BaseLLM", required=True, is_list=False, show=True),
"func": Field(
field_type="function",
required=True,
is_list=False,
show=True,
multiline=True,
),
"code": Field(
field_type="str",
required=True,
is_list=False,
show=True,
value="",
multiline=True,
),
"dict_": Field(
field_type="file",
required=True,
is_list=False,
show=True,
value="",
),
}
tool_type: str = all_tools[name]["type"] # type: ignore
if tool_type in _BASE_TOOLS:
@ -101,8 +102,9 @@ class ToolCreator(LangChainTypeCreator):
base_classes = ["function"]
if node := customs.get_custom_nodes("tools").get(tool_type):
return node
elif tool_type in OTHER_TOOLS:
elif tool_type in FILE_TOOLS:
params = all_tools[name]["params"] # type: ignore
base_classes += [name]
else:
params = []
@ -110,10 +112,7 @@ class ToolCreator(LangChainTypeCreator):
# Copy the field and add the name
fields = []
for param in params:
if param in NODE_INPUTS:
field = type_dict[param].copy()
else:
field = type_dict["str"].copy()
field = TOOL_INPUTS.get(param, TOOL_INPUTS["str"])
field.name = param
if param == "aiosession":
field.show = False
@ -122,9 +121,7 @@ class ToolCreator(LangChainTypeCreator):
template = Template(fields=fields, type_name=tool_type)
tool_params = get_tool_params(get_tool_by_name(tool_type))
if tool_params is None:
tool_params = {}
tool_params = all_tools[name]["params"]
return {
"template": util.format_dict(template.to_dict()),
**tool_params,