From 041748b2fb29a1641e889973baddf56941611fe9 Mon Sep 17 00:00:00 2001 From: Gabriel Almeida Date: Tue, 30 May 2023 21:47:42 -0300 Subject: [PATCH] =?UTF-8?q?=F0=9F=94=A8=20refactor(nodes.py):=20extract=20?= =?UTF-8?q?flatten=5Flist=20function=20to=20utils=20module=20and=20use=20i?= =?UTF-8?q?t=20in=20PromptNode.build=20method=20=F0=9F=90=9B=20fix(nodes.p?= =?UTF-8?q?y):=20change=20tools=20parameter=20type=20hint=20in=20PromptNod?= =?UTF-8?q?e.build=20method=20to=20accept=20a=20list=20of=20Union[ToolNode?= =?UTF-8?q?,=20ToolkitNode]=20The=20flatten=5Flist=20function=20was=20extr?= =?UTF-8?q?acted=20from=20the=20PromptNode.build=20method=20and=20moved=20?= =?UTF-8?q?to=20the=20utils=20module=20to=20improve=20code=20reusability.?= =?UTF-8?q?=20The=20PromptNode.build=20method=20now=20uses=20the=20flatten?= =?UTF-8?q?=5Flist=20function=20to=20flatten=20the=20list=20of=20tools=20i?= =?UTF-8?q?f=20it=20is=20a=20list=20of=20lists.=20The=20tools=20parameter?= =?UTF-8?q?=20type=20hint=20was=20changed=20to=20accept=20a=20list=20of=20?= =?UTF-8?q?Union[ToolNode,=20ToolkitNode]=20to=20improve=20type=20safety.?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/backend/langflow/graph/nodes.py | 62 ++++++++++++++--------------- src/backend/langflow/graph/utils.py | 12 ++++++ 2 files changed, 41 insertions(+), 33 deletions(-) diff --git a/src/backend/langflow/graph/nodes.py b/src/backend/langflow/graph/nodes.py index 9db6260e9..21fe0f673 100644 --- a/src/backend/langflow/graph/nodes.py +++ b/src/backend/langflow/graph/nodes.py @@ -1,7 +1,12 @@ from typing import Any, Dict, List, Optional, Union from langflow.graph.base import Node -from langflow.graph.utils import extract_input_variables_from_prompt +from langflow.graph.utils import extract_input_variables_from_prompt, flatten_list + + +class ToolkitNode(Node): + def __init__(self, data: Dict): + super().__init__(data, base_type="toolkits") class AgentNode(Node): @@ -47,7 +52,7 @@ class PromptNode(Node): def build( self, force: bool = False, - tools: Optional[Union[List[Node], List[ToolNode]]] = None, + tools: Optional[List[Union[ToolNode, ToolkitNode]]] = None, ) -> Any: if not self._built or force: if ( @@ -65,8 +70,7 @@ class PromptNode(Node): # flatten the list of tools if it is a list of lists # first check if it is a list if tools and isinstance(tools, list) and isinstance(tools[0], list): - tools = [tool for sublist in tools for tool in sublist] - + tools = flatten_list(tools) self.params["tools"] = tools prompt_params = [ key @@ -85,30 +89,6 @@ class PromptNode(Node): return self._built_object -class ChainNode(Node): - def __init__(self, data: Dict): - super().__init__(data, base_type="chains") - - def build( - self, - force: bool = False, - tools: Optional[Union[List[Node], List[ToolNode]]] = None, - ) -> Any: - if not self._built or force: - # Check if the chain requires a PromptNode - for key, value in self.params.items(): - if isinstance(value, PromptNode): - # Build the PromptNode, passing the tools if available - self.params[key] = value.build(tools=tools, force=force) - - self._build() - - #! Cannot deepcopy SQLDatabaseChain - if self.node_type in ["SQLDatabaseChain"]: - return self._built_object - return self._built_object - - class LLMNode(Node): built_node_type = None class_built_object = None @@ -130,11 +110,6 @@ class LLMNode(Node): return self._built_object -class ToolkitNode(Node): - def __init__(self, data: Dict): - super().__init__(data, base_type="toolkits") - - class FileToolNode(ToolNode): def __init__(self, data: Dict): super().__init__(data) @@ -193,3 +168,24 @@ class TextSplitterNode(Node): if self._built_object: return f"""{self.node_type}({len(self._built_object)} documents)\nDocuments: {self._built_object[:3]}...""" return f"{self.node_type}()" + + +class ChainNode(Node): + def __init__(self, data: Dict): + super().__init__(data, base_type="chains") + + def build( + self, + force: bool = False, + tools: Optional[List[Union[ToolNode, ToolkitNode]]] = None, + ) -> Any: + if not self._built or force: + # Check if the chain requires a PromptNode + for key, value in self.params.items(): + if isinstance(value, PromptNode): + # Build the PromptNode, passing the tools if available + self.params[key] = value.build(tools=tools, force=force) + + self._build() + + return self._built_object diff --git a/src/backend/langflow/graph/utils.py b/src/backend/langflow/graph/utils.py index 6d56e933e..e22b27cf5 100644 --- a/src/backend/langflow/graph/utils.py +++ b/src/backend/langflow/graph/utils.py @@ -1,4 +1,5 @@ import re +from typing import Any, Union def validate_prompt(prompt: str): @@ -17,3 +18,14 @@ def fix_prompt(prompt: str): def extract_input_variables_from_prompt(prompt: str) -> list[str]: """Extract input variables from prompt.""" return re.findall(r"{(.*?)}", prompt) + + +def flatten_list(list_of_lists: list[Union[list, Any]]) -> list: + """Flatten list of lists.""" + new_list = [] + for item in list_of_lists: + if isinstance(item, list): + new_list.extend(item) + else: + new_list.append(item) + return new_list