From 66b387264d8c83a91c62361e215cf2647f06a8f6 Mon Sep 17 00:00:00 2001 From: Gabriel Almeida Date: Mon, 27 Mar 2023 06:32:42 -0300 Subject: [PATCH] fix: mypy errors --- src/backend/langflow/interface/loading.py | 2 +- src/backend/langflow/node/nodes.py | 24 +++++++++++------------ src/backend/langflow/node/template.py | 4 ++-- src/backend/langflow/utils/graph.py | 4 ++-- src/backend/langflow/utils/payload.py | 4 +++- 5 files changed, 20 insertions(+), 18 deletions(-) diff --git a/src/backend/langflow/interface/loading.py b/src/backend/langflow/interface/loading.py index 7f0e3756f..b376ac87f 100644 --- a/src/backend/langflow/interface/loading.py +++ b/src/backend/langflow/interface/loading.py @@ -124,7 +124,7 @@ def load_agent_executor_from_config( def load_agent_executor( - agent_class: agent_module.Agent, allowed_tools, llm_chain, **kwargs + agent_class: type[agent_module.Agent], allowed_tools, llm_chain, **kwargs ): """Load agent executor from agent class, tools and chain""" tool_names = [tool.name for tool in allowed_tools] diff --git a/src/backend/langflow/node/nodes.py b/src/backend/langflow/node/nodes.py index 16d0ae5d0..c38ad8e2f 100644 --- a/src/backend/langflow/node/nodes.py +++ b/src/backend/langflow/node/nodes.py @@ -4,8 +4,8 @@ from langflow.utils.constants import DEFAULT_PYTHON_FUNCTION class ZeroShotPromptNode(FrontendNode): - name = "ZeroShotPrompt" - template = Template( + name: str = "ZeroShotPrompt" + template: Template = Template( type_name="zero_shot", fields=[ Field( @@ -40,16 +40,16 @@ class ZeroShotPromptNode(FrontendNode): ), ], ) - description = "Prompt template for Zero Shot Agent." - base_classes = ["BasePromptTemplate"] + description: str = "Prompt template for Zero Shot Agent." + base_classes: list[str] = ["BasePromptTemplate"] def to_dict(self): return super().to_dict() class PythonFunctionNode(FrontendNode): - name = "PythonFunction" - template = Template( + name: str = "PythonFunction" + template: Template = Template( type_name="python_function", fields=[ Field( @@ -64,16 +64,16 @@ class PythonFunctionNode(FrontendNode): ), ], ) - description = "Python function to be executed." - base_classes = ["function"] + description: str = "Python function to be executed." + base_classes: list[str] = ["function"] def to_dict(self): return super().to_dict() class ToolNode(FrontendNode): - name = "Tool" - template = Template( + name: str = "Tool" + template: Template = Template( type_name="tool", fields=[ Field( @@ -108,8 +108,8 @@ class ToolNode(FrontendNode): ), ], ) - description = "Tool to be used in the flow." - base_classes = ["BaseTool"] + description: str = "Tool to be used in the flow." + base_classes: list[str] = ["BaseTool"] def to_dict(self): return super().to_dict() diff --git a/src/backend/langflow/node/template.py b/src/backend/langflow/node/template.py index 68b172662..cc0245730 100644 --- a/src/backend/langflow/node/template.py +++ b/src/backend/langflow/node/template.py @@ -12,7 +12,7 @@ class Field(BaseModel): value: Any = None # _name will be used to store the name of the field # in the template - name: str = None + name: str = "" def to_dict(self): result = self.dict() @@ -34,7 +34,7 @@ class FrontendNode(BaseModel): template: Template description: str base_classes: list - name: str = None + name: str = "" def to_dict(self): return { diff --git a/src/backend/langflow/utils/graph.py b/src/backend/langflow/utils/graph.py index a83c79cb6..36eda7660 100644 --- a/src/backend/langflow/utils/graph.py +++ b/src/backend/langflow/utils/graph.py @@ -63,7 +63,7 @@ class Node: continue if value["type"] not in ["str", "bool"]: # Get the edge that connects to this node - edge: Edge = next( + edge = next( ( edge for edge in self.edges @@ -222,7 +222,7 @@ class Graph: root_node = payload.get_root_node(self) return root_node.build() - def get_node_neighbors(self, node: Node) -> Dict[str, int]: + def get_node_neighbors(self, node: Node) -> Dict[Node, int]: neighbors: Dict[Node, int] = {} for edge in self.edges: if edge.source == node: diff --git a/src/backend/langflow/utils/payload.py b/src/backend/langflow/utils/payload.py index f35ac6f59..9f3fad38e 100644 --- a/src/backend/langflow/utils/payload.py +++ b/src/backend/langflow/utils/payload.py @@ -76,7 +76,9 @@ def build_json(root, graph) -> Dict: if value["required"] and not children: raise ValueError(f"No child with type {node_type} found") values = [build_json(child, graph) for child in children] - value = list(values) if value["list"] else next(iter(values), None) + value = ( + list(values) if value["list"] else next(iter(values), None) # type: ignore + ) final_dict[key] = value return final_dict