🔨 refactor(nodes.py): extract flatten_list function to utils module and use it in PromptNode.build method
🐛 fix(nodes.py): change tools parameter type hint in PromptNode.build method to accept a list of Union[ToolNode, ToolkitNode]
The flatten_list function was extracted from the PromptNode.build method and moved to the utils module to improve code reusability. The PromptNode.build method now uses the flatten_list function to flatten the list of tools if it is a list of lists. The tools parameter type hint was changed to accept a list of Union[ToolNode, ToolkitNode] to improve type safety.
This commit is contained in:
parent
4b6a8595df
commit
041748b2fb
2 changed files with 41 additions and 33 deletions
|
|
@ -1,7 +1,12 @@
|
|||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
from langflow.graph.base import Node
|
||||
from langflow.graph.utils import extract_input_variables_from_prompt
|
||||
from langflow.graph.utils import extract_input_variables_from_prompt, flatten_list
|
||||
|
||||
|
||||
class ToolkitNode(Node):
|
||||
def __init__(self, data: Dict):
|
||||
super().__init__(data, base_type="toolkits")
|
||||
|
||||
|
||||
class AgentNode(Node):
|
||||
|
|
@ -47,7 +52,7 @@ class PromptNode(Node):
|
|||
def build(
|
||||
self,
|
||||
force: bool = False,
|
||||
tools: Optional[Union[List[Node], List[ToolNode]]] = None,
|
||||
tools: Optional[List[Union[ToolNode, ToolkitNode]]] = None,
|
||||
) -> Any:
|
||||
if not self._built or force:
|
||||
if (
|
||||
|
|
@ -65,8 +70,7 @@ class PromptNode(Node):
|
|||
# flatten the list of tools if it is a list of lists
|
||||
# first check if it is a list
|
||||
if tools and isinstance(tools, list) and isinstance(tools[0], list):
|
||||
tools = [tool for sublist in tools for tool in sublist]
|
||||
|
||||
tools = flatten_list(tools)
|
||||
self.params["tools"] = tools
|
||||
prompt_params = [
|
||||
key
|
||||
|
|
@ -85,30 +89,6 @@ class PromptNode(Node):
|
|||
return self._built_object
|
||||
|
||||
|
||||
class ChainNode(Node):
|
||||
def __init__(self, data: Dict):
|
||||
super().__init__(data, base_type="chains")
|
||||
|
||||
def build(
|
||||
self,
|
||||
force: bool = False,
|
||||
tools: Optional[Union[List[Node], List[ToolNode]]] = None,
|
||||
) -> Any:
|
||||
if not self._built or force:
|
||||
# Check if the chain requires a PromptNode
|
||||
for key, value in self.params.items():
|
||||
if isinstance(value, PromptNode):
|
||||
# Build the PromptNode, passing the tools if available
|
||||
self.params[key] = value.build(tools=tools, force=force)
|
||||
|
||||
self._build()
|
||||
|
||||
#! Cannot deepcopy SQLDatabaseChain
|
||||
if self.node_type in ["SQLDatabaseChain"]:
|
||||
return self._built_object
|
||||
return self._built_object
|
||||
|
||||
|
||||
class LLMNode(Node):
|
||||
built_node_type = None
|
||||
class_built_object = None
|
||||
|
|
@ -130,11 +110,6 @@ class LLMNode(Node):
|
|||
return self._built_object
|
||||
|
||||
|
||||
class ToolkitNode(Node):
|
||||
def __init__(self, data: Dict):
|
||||
super().__init__(data, base_type="toolkits")
|
||||
|
||||
|
||||
class FileToolNode(ToolNode):
|
||||
def __init__(self, data: Dict):
|
||||
super().__init__(data)
|
||||
|
|
@ -193,3 +168,24 @@ class TextSplitterNode(Node):
|
|||
if self._built_object:
|
||||
return f"""{self.node_type}({len(self._built_object)} documents)\nDocuments: {self._built_object[:3]}..."""
|
||||
return f"{self.node_type}()"
|
||||
|
||||
|
||||
class ChainNode(Node):
|
||||
def __init__(self, data: Dict):
|
||||
super().__init__(data, base_type="chains")
|
||||
|
||||
def build(
|
||||
self,
|
||||
force: bool = False,
|
||||
tools: Optional[List[Union[ToolNode, ToolkitNode]]] = None,
|
||||
) -> Any:
|
||||
if not self._built or force:
|
||||
# Check if the chain requires a PromptNode
|
||||
for key, value in self.params.items():
|
||||
if isinstance(value, PromptNode):
|
||||
# Build the PromptNode, passing the tools if available
|
||||
self.params[key] = value.build(tools=tools, force=force)
|
||||
|
||||
self._build()
|
||||
|
||||
return self._built_object
|
||||
|
|
|
|||
|
|
@ -1,4 +1,5 @@
|
|||
import re
|
||||
from typing import Any, Union
|
||||
|
||||
|
||||
def validate_prompt(prompt: str):
|
||||
|
|
@ -17,3 +18,14 @@ def fix_prompt(prompt: str):
|
|||
def extract_input_variables_from_prompt(prompt: str) -> list[str]:
|
||||
"""Extract input variables from prompt."""
|
||||
return re.findall(r"{(.*?)}", prompt)
|
||||
|
||||
|
||||
def flatten_list(list_of_lists: list[Union[list, Any]]) -> list:
|
||||
"""Flatten list of lists."""
|
||||
new_list = []
|
||||
for item in list_of_lists:
|
||||
if isinstance(item, list):
|
||||
new_list.extend(item)
|
||||
else:
|
||||
new_list.append(item)
|
||||
return new_list
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue