From 128d7d7b88731d472dbeda33aec99e669f3169f0 Mon Sep 17 00:00:00 2001 From: Gabriel Luiz Freitas Almeida Date: Fri, 16 Feb 2024 16:15:23 -0300 Subject: [PATCH] Refactor Vertex class to improve readability and maintainability --- src/backend/langflow/graph/vertex/base.py | 88 ++++++++++++++++++----- 1 file changed, 69 insertions(+), 19 deletions(-) diff --git a/src/backend/langflow/graph/vertex/base.py b/src/backend/langflow/graph/vertex/base.py index 3a2005a63..503855e6e 100644 --- a/src/backend/langflow/graph/vertex/base.py +++ b/src/backend/langflow/graph/vertex/base.py @@ -78,14 +78,18 @@ class Vertex: ): if edge.target_id not in edge_results: edge_results[edge.target_id] = {} - edge_results[edge.target_id][edge.target_param] = await edge.get_result(source=self, target=target) + edge_results[edge.target_id][edge.target_param] = await edge.get_result( + source=self, target=target + ) return edge_results def get_built_result(self): # If the Vertex.type is a power component # then we need to return the built object # instead of the result dict - if self.is_interface_component and not isinstance(self._built_object, UnbuiltObject): + if self.is_interface_component and not isinstance( + self._built_object, UnbuiltObject + ): result = self._built_object # if it is not a dict or a string and hasattr model_dump then # return the model_dump @@ -95,7 +99,11 @@ class Vertex: if isinstance(self._built_result, UnbuiltResult): return {} - return self._built_result if isinstance(self._built_result, dict) else {"result": self._built_result} + return ( + self._built_result + if isinstance(self._built_result, dict) + else {"result": self._built_result} + ) def set_artifacts(self) -> None: pass @@ -149,18 +157,29 @@ class Vertex: self.data = self._data["data"] self.output = self.data["node"]["base_classes"] self.pinned = self.data["node"].get("pinned", False) - template_dicts = {key: value for key, value in self.data["node"]["template"].items() if isinstance(value, dict)} - template_dicts = {key: value for key, value in self.data["node"]["template"].items() if isinstance(value, dict)} + template_dicts = { + key: value + for key, value in self.data["node"]["template"].items() + if isinstance(value, dict) + } self.required_inputs = [ - template_dicts[key]["type"] for key, value in template_dicts.items() if value["required"] + template_dicts[key]["type"] + for key, value in template_dicts.items() + if value["required"] ] self.optional_inputs = [ - template_dicts[key]["type"] for key, value in template_dicts.items() if not value["required"] + template_dicts[key]["type"] + for key, value in template_dicts.items() + if not value["required"] ] # Add the template_dicts[key]["input_types"] to the optional_inputs self.optional_inputs.extend( - [input_type for value in template_dicts.values() for input_type in value.get("input_types", [])] + [ + input_type + for value in template_dicts.values() + for input_type in value.get("input_types", []) + ] ) template_dict = self.data["node"]["template"] @@ -203,7 +222,11 @@ class Vertex: if self.graph is None: raise ValueError("Graph not found") - template_dict = {key: value for key, value in self.data["node"]["template"].items() if isinstance(value, dict)} + template_dict = { + key: value + for key, value in self.data["node"]["template"].items() + if isinstance(value, dict) + } params = self.params.copy() if self.params else {} for edge in self.edges: @@ -255,7 +278,11 @@ class Vertex: # list of dicts, so we need to convert it to a dict # before passing it to the build method if isinstance(val, list): - params[key] = {k: v for item in value.get("value", []) for k, v in item.items()} + params[key] = { + k: v + for item in value.get("value", []) + for k, v in item.items() + } elif isinstance(val, dict): params[key] = val elif value.get("type") == "int" and val is not None: @@ -292,7 +319,12 @@ class Vertex: self._built = True - async def _run(self, user_id: str, inputs: Optional[dict] = None, session_id: Optional[str] = None): + async def _run( + self, + user_id: str, + inputs: Optional[dict] = None, + session_id: Optional[str] = None, + ): # user_id is just for compatibility with the other build methods inputs = inputs or {} # inputs = {key: value or "" for key, value in inputs.items()} @@ -307,7 +339,9 @@ class Vertex: if isinstance(self._built_object, str): self._built_result = self._built_object - result = await generate_result(self._built_object, inputs, self.has_external_output, session_id) + result = await generate_result( + self._built_object, inputs, self.has_external_output, session_id + ) self._built_result = result async def _build_each_node_in_params_dict(self, user_id=None): @@ -335,7 +369,9 @@ class Vertex: """ return all(self._is_node(node) for node in value) - async def get_result(self, requester: Optional["Vertex"] = None, user_id=None, timeout=None) -> Any: + async def get_result( + self, requester: Optional["Vertex"] = None, user_id=None, timeout=None + ) -> Any: # PLEASE REVIEW THIS IF STATEMENT # Check if the Vertex was built already if self._built: @@ -369,7 +405,9 @@ class Vertex: self._extend_params_list_with_result(key, result) self.params[key] = result - async def _build_list_of_nodes_and_update_params(self, key, nodes: List["Vertex"], user_id=None): + async def _build_list_of_nodes_and_update_params( + self, key, nodes: List["Vertex"], user_id=None + ): """ Iterates over a list of nodes, builds each and updates the params dictionary. """ @@ -421,7 +459,9 @@ class Vertex: self._update_built_object_and_artifacts(result) except Exception as exc: logger.exception(exc) - raise ValueError(f"Error building node {self.vertex_type}(ID:{self.id}): {str(exc)}") from exc + raise ValueError( + f"Error building node {self.vertex_type}(ID:{self.id}): {str(exc)}" + ) from exc def _update_built_object_and_artifacts(self, result): """ @@ -458,7 +498,7 @@ class Vertex: requester: Optional["Vertex"] = None, **kwargs, ) -> Any: - if self.pinned: + if self.pinned and self._built: return self.get_requester_result(requester) self._reset() @@ -480,9 +520,15 @@ class Vertex: return self._built_object # Get the requester edge - requester_edge = next((edge for edge in self.edges if edge.target_id == requester.id), None) + requester_edge = next( + (edge for edge in self.edges if edge.target_id == requester.id), None + ) # Return the result of the requester edge - return None if requester_edge is None else await requester_edge.get_result(source=self, target=requester) + return ( + None + if requester_edge is None + else await requester_edge.get_result(source=self, target=requester) + ) def add_edge(self, edge: "ContractEdge") -> None: if edge not in self.edges: @@ -502,7 +548,11 @@ class Vertex: def _built_object_repr(self): # Add a message with an emoji, stars for sucess, - return "Built sucessfully ✨" if self._built_object is not None else "Failed to build 😵‍💫" + return ( + "Built sucessfully ✨" + if self._built_object is not None + else "Failed to build 😵‍💫" + ) class StatefulVertex(Vertex):