From df3cdb90b7d13c4c38008f0aa8fbb17bf504ebef Mon Sep 17 00:00:00 2001 From: Gabriel Luiz Freitas Almeida Date: Fri, 2 Jun 2023 14:21:38 -0300 Subject: [PATCH] =?UTF-8?q?=F0=9F=94=A8=20refactor(types.py):=20reorder=20?= =?UTF-8?q?class=20definitions=20to=20match=20the=20order=20of=20their=20u?= =?UTF-8?q?sage=20in=20the=20code=20The=20order=20of=20the=20class=20defin?= =?UTF-8?q?itions=20in=20the=20file=20has=20been=20changed=20to=20match=20?= =?UTF-8?q?the=20order=20of=20their=20usage=20in=20the=20code.=20This=20im?= =?UTF-8?q?proves=20the=20readability=20of=20the=20code=20and=20makes=20it?= =?UTF-8?q?=20easier=20to=20understand=20the=20relationships=20between=20t?= =?UTF-8?q?he=20classes.=20No=20functionality=20has=20been=20changed.?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/backend/langflow/graph/vertex/types.py | 137 ++++++++++----------- 1 file changed, 67 insertions(+), 70 deletions(-) diff --git a/src/backend/langflow/graph/vertex/types.py b/src/backend/langflow/graph/vertex/types.py index 0b0d0923f..b81e72439 100644 --- a/src/backend/langflow/graph/vertex/types.py +++ b/src/backend/langflow/graph/vertex/types.py @@ -8,13 +8,13 @@ class AgentVertex(Vertex): def __init__(self, data: Dict): super().__init__(data, base_type="agents") - self.tools: List[Union[ToolVertex, ToolkitVertex]] = [] + self.tools: List[Union[ToolkitVertex, ToolVertex]] = [] self.chains: List[ChainVertex] = [] def _set_tools_and_chains(self) -> None: for edge in self.edges: source_node = edge.source - if isinstance(source_node, ToolVertex): + if isinstance(source_node, (ToolVertex, ToolkitVertex)): self.tools.append(source_node) elif isinstance(source_node, ChainVertex): self.chains.append(source_node) @@ -40,74 +40,6 @@ class ToolVertex(Vertex): super().__init__(data, base_type="tools") -class PromptVertex(Vertex): - def __init__(self, data: Dict): - super().__init__(data, base_type="prompts") - - def build( - self, - force: bool = False, - tools: Optional[Union[List[Vertex], List[ToolVertex]]] = None, - ) -> Any: - if not self._built or force: - if ( - "input_variables" not in self.params - or self.params["input_variables"] is None - ): - self.params["input_variables"] = [] - # Check if it is a ZeroShotPrompt and needs a tool - if "ShotPrompt" in self.vertex_type: - tools = ( - [tool_node.build() for tool_node in tools] - if tools is not None - else [] - ) - # flatten the list of tools if it is a list of lists - # first check if it is a list - if tools and isinstance(tools, list) and isinstance(tools[0], list): - tools = flatten_list(tools) - self.params["tools"] = tools - prompt_params = [ - key - for key, value in self.params.items() - if isinstance(value, str) and key != "format_instructions" - ] - else: - prompt_params = ["template"] - for param in prompt_params: - prompt_text = self.params[param] - variables = extract_input_variables_from_prompt(prompt_text) - self.params["input_variables"].extend(variables) - self.params["input_variables"] = list(set(self.params["input_variables"])) - - self._build() - return self._built_object - - -class ChainVertex(Vertex): - def __init__(self, data: Dict): - super().__init__(data, base_type="chains") - - def build( - self, - force: bool = False, - tools: Optional[Union[List[Vertex], List[ToolVertex]]] = None, - ) -> Any: - if not self._built or force: - # Check if the chain requires a PromptVertex - for key, value in self.params.items(): - if isinstance(value, PromptVertex): - # Build the PromptVertex, passing the tools if available - self.params[key] = value.build(tools=tools, force=force) - - self._build() - - #! Cannot deepcopy SQLDatabaseChain - if self.vertex_type in ["SQLDatabaseChain"]: - return self._built_object - return self._built_object - - class LLMVertex(Vertex): built_node_type = None class_built_object = None @@ -193,3 +125,68 @@ class TextSplitterVertex(Vertex): return f"""{self.vertex_type}({len(self._built_object)} documents) \nDocuments: {self._built_object[:3]}...""" return f"{self.vertex_type}()" + + +class ChainVertex(Vertex): + def __init__(self, data: Dict): + super().__init__(data, base_type="chains") + + def build( + self, + force: bool = False, + tools: Optional[List[Union[ToolkitVertex, ToolVertex]]] = None, + ) -> Any: + if not self._built or force: + # Check if the chain requires a PromptVertex + for key, value in self.params.items(): + if isinstance(value, PromptVertex): + # Build the PromptVertex, passing the tools if available + self.params[key] = value.build(tools=tools, force=force) + + self._build() + + return self._built_object + + +class PromptVertex(Vertex): + def __init__(self, data: Dict): + super().__init__(data, base_type="prompts") + + def build( + self, + force: bool = False, + tools: Optional[List[Union[ToolkitVertex, ToolVertex]]] = None, + ) -> Any: + if not self._built or force: + if ( + "input_variables" not in self.params + or self.params["input_variables"] is None + ): + self.params["input_variables"] = [] + # Check if it is a ZeroShotPrompt and needs a tool + if "ShotPrompt" in self.vertex_type: + tools = ( + [tool_node.build() for tool_node in tools] + if tools is not None + else [] + ) + # flatten the list of tools if it is a list of lists + # first check if it is a list + if tools and isinstance(tools, list) and isinstance(tools[0], list): + tools = flatten_list(tools) + self.params["tools"] = tools + prompt_params = [ + key + for key, value in self.params.items() + if isinstance(value, str) and key != "format_instructions" + ] + else: + prompt_params = ["template"] + for param in prompt_params: + prompt_text = self.params[param] + variables = extract_input_variables_from_prompt(prompt_text) + self.params["input_variables"].extend(variables) + self.params["input_variables"] = list(set(self.params["input_variables"])) + + self._build() + return self._built_object