diff --git a/src/backend/langflow/graph/nodes.py b/src/backend/langflow/graph/nodes.py index 9db6260e9..21fe0f673 100644 --- a/src/backend/langflow/graph/nodes.py +++ b/src/backend/langflow/graph/nodes.py @@ -1,7 +1,12 @@ from typing import Any, Dict, List, Optional, Union from langflow.graph.base import Node -from langflow.graph.utils import extract_input_variables_from_prompt +from langflow.graph.utils import extract_input_variables_from_prompt, flatten_list + + +class ToolkitNode(Node): + def __init__(self, data: Dict): + super().__init__(data, base_type="toolkits") class AgentNode(Node): @@ -47,7 +52,7 @@ class PromptNode(Node): def build( self, force: bool = False, - tools: Optional[Union[List[Node], List[ToolNode]]] = None, + tools: Optional[List[Union[ToolNode, ToolkitNode]]] = None, ) -> Any: if not self._built or force: if ( @@ -65,8 +70,7 @@ class PromptNode(Node): # flatten the list of tools if it is a list of lists # first check if it is a list if tools and isinstance(tools, list) and isinstance(tools[0], list): - tools = [tool for sublist in tools for tool in sublist] - + tools = flatten_list(tools) self.params["tools"] = tools prompt_params = [ key @@ -85,30 +89,6 @@ class PromptNode(Node): return self._built_object -class ChainNode(Node): - def __init__(self, data: Dict): - super().__init__(data, base_type="chains") - - def build( - self, - force: bool = False, - tools: Optional[Union[List[Node], List[ToolNode]]] = None, - ) -> Any: - if not self._built or force: - # Check if the chain requires a PromptNode - for key, value in self.params.items(): - if isinstance(value, PromptNode): - # Build the PromptNode, passing the tools if available - self.params[key] = value.build(tools=tools, force=force) - - self._build() - - #! Cannot deepcopy SQLDatabaseChain - if self.node_type in ["SQLDatabaseChain"]: - return self._built_object - return self._built_object - - class LLMNode(Node): built_node_type = None class_built_object = None @@ -130,11 +110,6 @@ class LLMNode(Node): return self._built_object -class ToolkitNode(Node): - def __init__(self, data: Dict): - super().__init__(data, base_type="toolkits") - - class FileToolNode(ToolNode): def __init__(self, data: Dict): super().__init__(data) @@ -193,3 +168,24 @@ class TextSplitterNode(Node): if self._built_object: return f"""{self.node_type}({len(self._built_object)} documents)\nDocuments: {self._built_object[:3]}...""" return f"{self.node_type}()" + + +class ChainNode(Node): + def __init__(self, data: Dict): + super().__init__(data, base_type="chains") + + def build( + self, + force: bool = False, + tools: Optional[List[Union[ToolNode, ToolkitNode]]] = None, + ) -> Any: + if not self._built or force: + # Check if the chain requires a PromptNode + for key, value in self.params.items(): + if isinstance(value, PromptNode): + # Build the PromptNode, passing the tools if available + self.params[key] = value.build(tools=tools, force=force) + + self._build() + + return self._built_object diff --git a/src/backend/langflow/graph/utils.py b/src/backend/langflow/graph/utils.py index 6d56e933e..e22b27cf5 100644 --- a/src/backend/langflow/graph/utils.py +++ b/src/backend/langflow/graph/utils.py @@ -1,4 +1,5 @@ import re +from typing import Any, Union def validate_prompt(prompt: str): @@ -17,3 +18,14 @@ def fix_prompt(prompt: str): def extract_input_variables_from_prompt(prompt: str) -> list[str]: """Extract input variables from prompt.""" return re.findall(r"{(.*?)}", prompt) + + +def flatten_list(list_of_lists: list[Union[list, Any]]) -> list: + """Flatten list of lists.""" + new_list = [] + for item in list_of_lists: + if isinstance(item, list): + new_list.extend(item) + else: + new_list.append(item) + return new_list