From fdfbe810c986f61ee0096f249a3ad8a4e0f12ed2 Mon Sep 17 00:00:00 2001 From: Gabriel Luiz Freitas Almeida Date: Tue, 20 Feb 2024 12:17:41 -0300 Subject: [PATCH] Refactor Graph update method to handle new and removed vertices --- src/backend/langflow/graph/graph/base.py | 124 +++++++++++++++++------ 1 file changed, 92 insertions(+), 32 deletions(-) diff --git a/src/backend/langflow/graph/graph/base.py b/src/backend/langflow/graph/graph/base.py index 719f8be01..3c772c897 100644 --- a/src/backend/langflow/graph/graph/base.py +++ b/src/backend/langflow/graph/graph/base.py @@ -82,24 +82,60 @@ class Graph: # then update the .data of the vertex to the self # both graphs have the same vertices and edges # but the data of the vertices might be different - def update(self, other: "Graph", different_vertices: List[str] = None) -> None: - if different_vertices is None: - different_vertices = [] - for vertex in self.vertices: - other_vertex = other.get_vertex(vertex.id) - if other_vertex is None: - continue - if ( - vertex.id in different_vertices - or vertex.__repr__() != other.get_vertex(vertex.id).__repr__() - ): - vertex.data = other.get_vertex(vertex.id).data - vertex.params = {} - vertex._build_params() - vertex.graph = self - vertex._built = False + + def update(self, other: "Graph") -> None: + # Existing vertices in self graph + existing_vertex_ids = set(vertex.id for vertex in self.vertices) + # Vertex IDs in the other graph + other_vertex_ids = set(other.vertex_map.keys()) + + # Find vertices that are in other but not in self (new vertices) + new_vertex_ids = other_vertex_ids - existing_vertex_ids + + # Find vertices that are in self but not in other (removed vertices) + removed_vertex_ids = existing_vertex_ids - other_vertex_ids + + # Update existing vertices that have changed + for vertex_id in existing_vertex_ids.intersection(other_vertex_ids): + self_vertex = self.get_vertex(vertex_id) + other_vertex = other.get_vertex(vertex_id) + if self_vertex.__repr__() != other_vertex.__repr__(): + self_vertex.data = other_vertex.data + self_vertex.params = {} + self_vertex._build_params() + self_vertex.graph = self + self_vertex._built = False + self.reset_all_edges_of_vertex(self_vertex) + + # Remove vertices + for vertex_id in removed_vertex_ids: + self.remove_vertex(vertex_id) + + # Add new vertices + for vertex_id in new_vertex_ids: + new_vertex = other.get_vertex(vertex_id) + self._add_vertex(new_vertex) + return self + def reset_all_edges_of_vertex(self, vertex: Vertex) -> None: + """Resets all the edges of a vertex.""" + for edge in vertex.edges: + for vid in [edge.source_id, edge.target_id]: + if vid in self.vertex_map: + _vertex = self.vertex_map[vid] + if not _vertex.pinned: + _vertex._build_params() + + def _add_vertex(self, vertex: Vertex) -> None: + """Adds a new vertex to the graph.""" + self.vertices.append(vertex) + self.vertex_map[vertex.id] = vertex + # Vertex has edges, so we need to update the edges + for edge in vertex.edges: + if edge.source_id in self.vertex_map and edge.target_id in self.vertex_map: + self.edges.append(edge) + def _build_graph(self) -> None: """Builds the graph from the vertices and edges.""" self.vertices = self._build_vertices() @@ -137,6 +173,19 @@ class Graph: source_vertex.has_external_output = True self.outputs.append(target_vertex.id) + def remove_vertex(self, vertex_id: str) -> None: + """Removes a vertex from the graph.""" + vertex = self.get_vertex(vertex_id) + if vertex is None: + return + self.vertices.remove(vertex) + self.vertex_map.pop(vertex_id) + self.edges = [ + edge + for edge in self.edges + if edge.source_id != vertex_id and edge.target_id != vertex_id + ] + def _build_vertex_params(self) -> None: """Identifies and handles the LLM vertex within the graph.""" llm_vertex = None @@ -339,24 +388,35 @@ class Graph: # Get the vertices that are connected to the vertex # and the vertex itself vertex = self.get_vertex(vertex_id) - vertices = [vertex] - for edge in vertex.edges: - if edge.target_id == vertex_id: - vertices.append(self.get_vertex(edge.source_id)) - - # Get the edges that are connected to the vertices + # We need to remove the edge coming from the vertex + # and all vertices after the vertex edges = [] - for vertex in vertices: - edges.extend(self.get_vertex_edges(vertex.id)) - source_vertex = self.get_vertex(edge.source_id) - target_vertex = self.get_vertex(edge.target_id) - if source_vertex not in vertices: - vertices.append(source_vertex) - if target_vertex not in vertices: - vertices.append(target_vertex) - - edges = [edge for vertex in vertices for edge in vertex.edges] + vertices_to_remove = [] + for edge in vertex.edges: + if edge.source_id != vertex_id: + edges.append(edge) + if edge.target_id != vertex_id: + vertices_to_remove.append(self.get_vertex(edge.target_id)) + for vertex in vertices_to_remove: + for edge in vertex.edges: + if edge.target_id == vertex.id: + continue + vertices_to_remove.append(self.get_vertex(edge.target_id)) + vertices_to_remove = set(vertices_to_remove) + vertices = [] + edges_to_remove = { + edge for vertex in vertices_to_remove for edge in vertex.edges + } + for vertex in self.vertices: + if vertex in vertices_to_remove: + continue + vertices.append(vertex) + for edge in vertex.edges: + if edge in edges_to_remove: + continue + if edge not in edges: + edges.append(edge) vertices = self.layered_topological_sort(vertices, edges) return self.sort_chat_inputs_first(vertices)