From 6c415d08654ace4c2067cb22c19686363718df15 Mon Sep 17 00:00:00 2001 From: Gabriel Luiz Freitas Almeida Date: Thu, 7 Mar 2024 09:57:23 -0300 Subject: [PATCH] Refactor code to improve performance and readability --- src/backend/langflow/graph/graph/base.py | 108 +++++++++++++++++------ 1 file changed, 82 insertions(+), 26 deletions(-) diff --git a/src/backend/langflow/graph/graph/base.py b/src/backend/langflow/graph/graph/base.py index a7d5a1d42..edb053f4d 100644 --- a/src/backend/langflow/graph/graph/base.py +++ b/src/backend/langflow/graph/graph/base.py @@ -76,7 +76,9 @@ class Graph: """Returns the state of the graph.""" return self.state_manager.get_state(name, run_id=self._run_id) - def update_state(self, name: str, record: Union[str, Record], caller: Optional[str] = None) -> None: + def update_state( + self, name: str, record: Union[str, Record], caller: Optional[str] = None + ) -> None: """Updates the state of the graph.""" if caller: # If there is a caller which is a vertex_id, I want to activate @@ -108,7 +110,9 @@ class Graph: def reset_activated_vertices(self): self.activated_vertices = [] - def append_state(self, name: str, record: Union[str, Record], caller: Optional[str] = None) -> None: + def append_state( + self, name: str, record: Union[str, Record], caller: Optional[str] = None + ) -> None: """Appends the state of the graph.""" if caller: self.activate_state_vertices(name, caller) @@ -156,7 +160,10 @@ class Graph: """Runs the graph with the given inputs.""" for vertex_id in self._is_input_vertices: vertex = self.get_vertex(vertex_id) - if input_components and (vertex_id not in input_components or vertex.display_name not in input_components): + if input_components and ( + vertex_id not in input_components + or vertex.display_name not in input_components + ): continue if vertex is None: raise ValueError(f"Vertex {vertex_id} not found") @@ -179,9 +186,13 @@ class Graph: if vertex is None: raise ValueError(f"Vertex {vertex_id} not found") - if not vertex.result and not stream and hasattr(vertex, "consume_async_generator"): + if ( + not vertex.result + and not stream + and hasattr(vertex, "consume_async_generator") + ): await vertex.consume_async_generator() - if vertex.display_name in outputs or vertex.id in outputs: + if not outputs or (vertex.display_name in outputs or vertex.id in outputs): vertex_outputs.append(vertex.result) return vertex_outputs @@ -189,8 +200,8 @@ class Graph: self, inputs: Dict[str, Union[str, list[str]]], outputs: list[str], - stream: bool, session_id: str, + stream: Optional[bool] = False, ) -> List[Optional["ResultData"]]: """Runs the graph with the given inputs.""" @@ -257,7 +268,9 @@ class Graph: def build_parent_child_map(self): parent_child_map = defaultdict(list) for vertex in self.vertices: - parent_child_map[vertex.id] = [child.id for child in self.get_successors(vertex)] + parent_child_map[vertex.id] = [ + child.id for child in self.get_successors(vertex) + ] return parent_child_map def increment_run_count(self): @@ -442,7 +455,11 @@ class Graph: """Updates the edges of a vertex.""" # Vertex has edges, so we need to update the edges for edge in vertex.edges: - if edge not in self.edges and edge.source_id in self.vertex_map and edge.target_id in self.vertex_map: + if ( + edge not in self.edges + and edge.source_id in self.vertex_map + and edge.target_id in self.vertex_map + ): self.edges.append(edge) def _build_graph(self) -> None: @@ -467,7 +484,11 @@ class Graph: return self.vertices.remove(vertex) self.vertex_map.pop(vertex_id) - self.edges = [edge for edge in self.edges if edge.source_id != vertex_id and edge.target_id != vertex_id] + self.edges = [ + edge + for edge in self.edges + if edge.source_id != vertex_id and edge.target_id != vertex_id + ] def _build_vertex_params(self) -> None: """Identifies and handles the LLM vertex within the graph.""" @@ -488,7 +509,9 @@ class Graph: return for vertex in self.vertices: if not self._validate_vertex(vertex): - raise ValueError(f"{vertex.display_name} is not connected to any other components") + raise ValueError( + f"{vertex.display_name} is not connected to any other components" + ) def _validate_vertex(self, vertex: Vertex) -> bool: """Validates a vertex.""" @@ -550,7 +573,9 @@ class Graph: name=f"{vertex.display_name} Run {vertex_task_run_count.get(vertex_id, 0)}", ) tasks.append(task) - vertex_task_run_count[vertex_id] = vertex_task_run_count.get(vertex_id, 0) + 1 + vertex_task_run_count[vertex_id] = ( + vertex_task_run_count.get(vertex_id, 0) + 1 + ) logger.debug(f"Running layer {layer_index} with {len(tasks)} tasks") await self._execute_tasks(tasks) logger.debug("Graph processing complete") @@ -592,7 +617,9 @@ class Graph: def dfs(vertex): if state[vertex] == 1: # We have a cycle - raise ValueError("Graph contains a cycle, cannot perform topological sort") + raise ValueError( + "Graph contains a cycle, cannot perform topological sort" + ) if state[vertex] == 0: state[vertex] = 1 for edge in vertex.edges: @@ -616,7 +643,10 @@ class Graph: def get_predecessors(self, vertex): """Returns the predecessors of a vertex.""" - return [self.get_vertex(source_id) for source_id in self.predecessor_map.get(vertex.id, [])] + return [ + self.get_vertex(source_id) + for source_id in self.predecessor_map.get(vertex.id, []) + ] def get_all_successors(self, vertex, recursive=True, flat=True): # Recursively get the successors of the current vertex @@ -657,7 +687,10 @@ class Graph: def get_successors(self, vertex): """Returns the successors of a vertex.""" - return [self.get_vertex(target_id) for target_id in self.successor_map.get(vertex.id, [])] + return [ + self.get_vertex(target_id) + for target_id in self.successor_map.get(vertex.id, []) + ] def get_vertex_neighbors(self, vertex: Vertex) -> Dict[Vertex, int]: """Returns the neighbors of a vertex.""" @@ -703,7 +736,9 @@ class Graph: edges_added.add((source.id, target.id)) return edges - def _get_vertex_class(self, node_type: str, node_base_type: str, node_id: str) -> Type[Vertex]: + def _get_vertex_class( + self, node_type: str, node_base_type: str, node_id: str + ) -> Type[Vertex]: """Returns the node class based on the node type.""" # First we check for the node_base_type node_name = node_id.split("-")[0] @@ -736,14 +771,18 @@ class Graph: vertex_type: str = vertex_data["type"] # type: ignore vertex_base_type: str = vertex_data["node"]["template"]["_type"] # type: ignore - VertexClass = self._get_vertex_class(vertex_type, vertex_base_type, vertex_data["id"]) + VertexClass = self._get_vertex_class( + vertex_type, vertex_base_type, vertex_data["id"] + ) vertex_instance = VertexClass(vertex, graph=self) vertex_instance.set_top_level(self.top_level_vertices) vertices.append(vertex_instance) return vertices - def get_children_by_vertex_type(self, vertex: Vertex, vertex_type: str) -> List[Vertex]: + def get_children_by_vertex_type( + self, vertex: Vertex, vertex_type: str + ) -> List[Vertex]: """Returns the children of a vertex based on the vertex type.""" children = [] vertex_types = [vertex.data["type"]] @@ -755,7 +794,9 @@ class Graph: def __repr__(self): vertex_ids = [vertex.id for vertex in self.vertices] - edges_repr = "\n".join([f"{edge.source_id} --> {edge.target_id}" for edge in self.edges]) + edges_repr = "\n".join( + [f"{edge.source_id} --> {edge.target_id}" for edge in self.edges] + ) return f"Graph:\nNodes: {vertex_ids}\nConnections:\n{edges_repr}" def sort_up_to_vertex(self, vertex_id: str, is_start: bool = False) -> List[Vertex]: @@ -823,7 +864,8 @@ class Graph: vertex.id for vertex in vertices # if filter_graphs then only vertex.is_input will be considered - if self.in_degree_map[vertex.id] == 0 and (not filter_graphs or vertex.is_input) + if self.in_degree_map[vertex.id] == 0 + and (not filter_graphs or vertex.is_input) ) layers: List[List[str]] = [] visited = set(queue) @@ -897,7 +939,9 @@ class Graph: return refined_layers - def sort_chat_inputs_first(self, vertices_layers: List[List[str]]) -> List[List[str]]: + def sort_chat_inputs_first( + self, vertices_layers: List[List[str]] + ) -> List[List[str]]: chat_inputs_first = [] for layer in vertices_layers: for vertex_id in layer: @@ -938,7 +982,9 @@ class Graph: first_layer = vertices_layers[0] # save the only the rest self.vertices_layers = vertices_layers[1:] - self.vertices_to_run = {vertex_id for vertex_id in chain.from_iterable(vertices_layers)} + self.vertices_to_run = { + vertex_id for vertex_id in chain.from_iterable(vertices_layers) + } # Return just the first layer return first_layer @@ -949,11 +995,15 @@ class Graph: self.vertices_to_run.remove(vertex_id) return should_run - def sort_interface_components_first(self, vertices_layers: List[List[str]]) -> List[List[str]]: + def sort_interface_components_first( + self, vertices_layers: List[List[str]] + ) -> List[List[str]]: """Sorts the vertices in the graph so that vertices containing ChatInput or ChatOutput come first.""" def contains_interface_component(vertex): - return any(component.value in vertex for component in InterfaceComponentTypes) + return any( + component.value in vertex for component in InterfaceComponentTypes + ) # Sort each inner list so that vertices containing ChatInput or ChatOutput come first sorted_vertices = [ @@ -965,16 +1015,22 @@ class Graph: ] return sorted_vertices - def sort_by_avg_build_time(self, vertices_layers: List[List[str]]) -> List[List[str]]: + def sort_by_avg_build_time( + self, vertices_layers: List[List[str]] + ) -> List[List[str]]: """Sorts the vertices in the graph so that vertices with the lowest average build time come first.""" def sort_layer_by_avg_build_time(vertices_ids: List[str]) -> List[str]: """Sorts the vertices in the graph so that vertices with the lowest average build time come first.""" if len(vertices_ids) == 1: return vertices_ids - vertices_ids.sort(key=lambda vertex_id: self.get_vertex(vertex_id).avg_build_time) + vertices_ids.sort( + key=lambda vertex_id: self.get_vertex(vertex_id).avg_build_time + ) return vertices_ids - sorted_vertices = [sort_layer_by_avg_build_time(layer) for layer in vertices_layers] + sorted_vertices = [ + sort_layer_by_avg_build_time(layer) for layer in vertices_layers + ] return sorted_vertices