fix: mypy errors

This commit is contained in:
Gabriel Almeida 2023-03-27 06:32:42 -03:00
commit 66b387264d
5 changed files with 20 additions and 18 deletions

View file

@ -124,7 +124,7 @@ def load_agent_executor_from_config(
def load_agent_executor(
agent_class: agent_module.Agent, allowed_tools, llm_chain, **kwargs
agent_class: type[agent_module.Agent], allowed_tools, llm_chain, **kwargs
):
"""Load agent executor from agent class, tools and chain"""
tool_names = [tool.name for tool in allowed_tools]

View file

@ -4,8 +4,8 @@ from langflow.utils.constants import DEFAULT_PYTHON_FUNCTION
class ZeroShotPromptNode(FrontendNode):
name = "ZeroShotPrompt"
template = Template(
name: str = "ZeroShotPrompt"
template: Template = Template(
type_name="zero_shot",
fields=[
Field(
@ -40,16 +40,16 @@ class ZeroShotPromptNode(FrontendNode):
),
],
)
description = "Prompt template for Zero Shot Agent."
base_classes = ["BasePromptTemplate"]
description: str = "Prompt template for Zero Shot Agent."
base_classes: list[str] = ["BasePromptTemplate"]
def to_dict(self):
return super().to_dict()
class PythonFunctionNode(FrontendNode):
name = "PythonFunction"
template = Template(
name: str = "PythonFunction"
template: Template = Template(
type_name="python_function",
fields=[
Field(
@ -64,16 +64,16 @@ class PythonFunctionNode(FrontendNode):
),
],
)
description = "Python function to be executed."
base_classes = ["function"]
description: str = "Python function to be executed."
base_classes: list[str] = ["function"]
def to_dict(self):
return super().to_dict()
class ToolNode(FrontendNode):
name = "Tool"
template = Template(
name: str = "Tool"
template: Template = Template(
type_name="tool",
fields=[
Field(
@ -108,8 +108,8 @@ class ToolNode(FrontendNode):
),
],
)
description = "Tool to be used in the flow."
base_classes = ["BaseTool"]
description: str = "Tool to be used in the flow."
base_classes: list[str] = ["BaseTool"]
def to_dict(self):
return super().to_dict()

View file

@ -12,7 +12,7 @@ class Field(BaseModel):
value: Any = None
# _name will be used to store the name of the field
# in the template
name: str = None
name: str = ""
def to_dict(self):
result = self.dict()
@ -34,7 +34,7 @@ class FrontendNode(BaseModel):
template: Template
description: str
base_classes: list
name: str = None
name: str = ""
def to_dict(self):
return {

View file

@ -63,7 +63,7 @@ class Node:
continue
if value["type"] not in ["str", "bool"]:
# Get the edge that connects to this node
edge: Edge = next(
edge = next(
(
edge
for edge in self.edges
@ -222,7 +222,7 @@ class Graph:
root_node = payload.get_root_node(self)
return root_node.build()
def get_node_neighbors(self, node: Node) -> Dict[str, int]:
def get_node_neighbors(self, node: Node) -> Dict[Node, int]:
neighbors: Dict[Node, int] = {}
for edge in self.edges:
if edge.source == node:

View file

@ -76,7 +76,9 @@ def build_json(root, graph) -> Dict:
if value["required"] and not children:
raise ValueError(f"No child with type {node_type} found")
values = [build_json(child, graph) for child in children]
value = list(values) if value["list"] else next(iter(values), None)
value = (
list(values) if value["list"] else next(iter(values), None) # type: ignore
)
final_dict[key] = value
return final_dict