🔨 refactor(nodes.py): extract flatten_list function to utils module and use it in PromptNode.build method

🐛 fix(nodes.py): change tools parameter type hint in PromptNode.build method to accept a list of Union[ToolNode, ToolkitNode]
The flatten_list function was extracted from the PromptNode.build method and moved to the utils module to improve code reusability. The PromptNode.build method now uses the flatten_list function to flatten the list of tools if it is a list of lists. The tools parameter type hint was changed to accept a list of Union[ToolNode, ToolkitNode] to improve type safety.
This commit is contained in:
Gabriel Almeida 2023-05-30 21:47:42 -03:00
commit 041748b2fb
2 changed files with 41 additions and 33 deletions

View file

@ -1,7 +1,12 @@
from typing import Any, Dict, List, Optional, Union
from langflow.graph.base import Node
from langflow.graph.utils import extract_input_variables_from_prompt
from langflow.graph.utils import extract_input_variables_from_prompt, flatten_list
class ToolkitNode(Node):
def __init__(self, data: Dict):
super().__init__(data, base_type="toolkits")
class AgentNode(Node):
@ -47,7 +52,7 @@ class PromptNode(Node):
def build(
self,
force: bool = False,
tools: Optional[Union[List[Node], List[ToolNode]]] = None,
tools: Optional[List[Union[ToolNode, ToolkitNode]]] = None,
) -> Any:
if not self._built or force:
if (
@ -65,8 +70,7 @@ class PromptNode(Node):
# flatten the list of tools if it is a list of lists
# first check if it is a list
if tools and isinstance(tools, list) and isinstance(tools[0], list):
tools = [tool for sublist in tools for tool in sublist]
tools = flatten_list(tools)
self.params["tools"] = tools
prompt_params = [
key
@ -85,30 +89,6 @@ class PromptNode(Node):
return self._built_object
class ChainNode(Node):
def __init__(self, data: Dict):
super().__init__(data, base_type="chains")
def build(
self,
force: bool = False,
tools: Optional[Union[List[Node], List[ToolNode]]] = None,
) -> Any:
if not self._built or force:
# Check if the chain requires a PromptNode
for key, value in self.params.items():
if isinstance(value, PromptNode):
# Build the PromptNode, passing the tools if available
self.params[key] = value.build(tools=tools, force=force)
self._build()
#! Cannot deepcopy SQLDatabaseChain
if self.node_type in ["SQLDatabaseChain"]:
return self._built_object
return self._built_object
class LLMNode(Node):
built_node_type = None
class_built_object = None
@ -130,11 +110,6 @@ class LLMNode(Node):
return self._built_object
class ToolkitNode(Node):
def __init__(self, data: Dict):
super().__init__(data, base_type="toolkits")
class FileToolNode(ToolNode):
def __init__(self, data: Dict):
super().__init__(data)
@ -193,3 +168,24 @@ class TextSplitterNode(Node):
if self._built_object:
return f"""{self.node_type}({len(self._built_object)} documents)\nDocuments: {self._built_object[:3]}..."""
return f"{self.node_type}()"
class ChainNode(Node):
def __init__(self, data: Dict):
super().__init__(data, base_type="chains")
def build(
self,
force: bool = False,
tools: Optional[List[Union[ToolNode, ToolkitNode]]] = None,
) -> Any:
if not self._built or force:
# Check if the chain requires a PromptNode
for key, value in self.params.items():
if isinstance(value, PromptNode):
# Build the PromptNode, passing the tools if available
self.params[key] = value.build(tools=tools, force=force)
self._build()
return self._built_object

View file

@ -1,4 +1,5 @@
import re
from typing import Any, Union
def validate_prompt(prompt: str):
@ -17,3 +18,14 @@ def fix_prompt(prompt: str):
def extract_input_variables_from_prompt(prompt: str) -> list[str]:
"""Extract input variables from prompt."""
return re.findall(r"{(.*?)}", prompt)
def flatten_list(list_of_lists: list[Union[list, Any]]) -> list:
"""Flatten list of lists."""
new_list = []
for item in list_of_lists:
if isinstance(item, list):
new_list.extend(item)
else:
new_list.append(item)
return new_list