diff --git a/src/backend/langflow/graph/graph/base.py b/src/backend/langflow/graph/graph/base.py index 59a8966b5..de126374d 100644 --- a/src/backend/langflow/graph/graph/base.py +++ b/src/backend/langflow/graph/graph/base.py @@ -85,7 +85,9 @@ class Graph: def build_parent_child_map(self): parent_child_map = defaultdict(list) for vertex in self.vertices: - parent_child_map[vertex.id] = [child.id for child in self.get_successors(vertex)] + parent_child_map[vertex.id] = [ + child.id for child in self.get_successors(vertex) + ] return parent_child_map def increment_run_count(self): @@ -149,6 +151,16 @@ class Graph: # both graphs have the same vertices and edges # but the data of the vertices might be different + def update_edges_from_vertex(self, vertex: Vertex, other_vertex: Vertex) -> None: + """Updates the edges of a vertex in the Graph.""" + new_edges = [] + for edge in self.edges: + if edge.source_id == other_vertex.id or edge.target_id == other_vertex.id: + continue + new_edges.append(edge) + new_edges += other_vertex.edges + self.edges = new_edges + def vertex_data_is_identical(self, vertex: Vertex, other_vertex: Vertex) -> bool: return vertex.__repr__() == other_vertex.__repr__() @@ -173,10 +185,6 @@ class Graph: # Find vertices that are in self but not in other (removed vertices) removed_vertex_ids = existing_vertex_ids - other_vertex_ids - # Create a set for new edges - edges_to_add = set() - edges_to_remove = set() - # Update existing vertices that have changed for vertex_id in existing_vertex_ids.intersection(other_vertex_ids): self_vertex = self.get_vertex(vertex_id) @@ -184,6 +192,8 @@ class Graph: if not self.vertex_data_is_identical(self_vertex, other_vertex): self_vertex._data = other_vertex._data self_vertex._parse_data() + # Now we update the edges of the vertex + self.update_edges_from_vertex(self_vertex, other_vertex) self_vertex.params = {} self_vertex._build_params() self_vertex.graph = self @@ -195,25 +205,6 @@ class Graph: self_vertex.artifacts = None self_vertex.set_top_level(self.top_level_vertices) self.reset_all_edges_of_vertex(self_vertex) - if not self.vertex_edges_are_identical(self_vertex, other_vertex): - # New edges are the edges of the other vertex and not the self vertex - # If there are more edges in the other vertex than in the self vertex - # then we need to add the new edges to the self vertex - # if there are less edges in the other vertex than in the self vertex - # then we need to remove the edges that are not in the other vertex - - if len(self_vertex.edges) < len(other_vertex.edges): - edges_to_add.update(edge for edge in other_vertex.edges if edge not in self_vertex.edges) - elif len(self_vertex.edges) > len(other_vertex.edges): - edges_to_remove.update(edge for edge in self_vertex.edges if edge not in other_vertex.edges) - - # Add new edges - # to self.edges if they are not already in self.edges - for edge in edges_to_add: - if edge not in self.edges: - self.edges.append(edge) - for edge in edges_to_remove: - self.edges.remove(edge) # Remove vertices for vertex_id in removed_vertex_ids: @@ -290,7 +281,11 @@ class Graph: return self.vertices.remove(vertex) self.vertex_map.pop(vertex_id) - self.edges = [edge for edge in self.edges if edge.source_id != vertex_id and edge.target_id != vertex_id] + self.edges = [ + edge + for edge in self.edges + if edge.source_id != vertex_id and edge.target_id != vertex_id + ] def _build_vertex_params(self) -> None: """Identifies and handles the LLM vertex within the graph.""" @@ -311,7 +306,9 @@ class Graph: return for vertex in self.vertices: if not self._validate_vertex(vertex): - raise ValueError(f"{vertex.vertex_type} is not connected to any other components") + raise ValueError( + f"{vertex.vertex_type} is not connected to any other components" + ) def _validate_vertex(self, vertex: Vertex) -> bool: """Validates a vertex.""" @@ -327,7 +324,11 @@ class Graph: def get_vertex_edges(self, vertex_id: str) -> List[ContractEdge]: """Returns a list of edges for a given vertex.""" - return [edge for edge in self.edges if edge.source_id == vertex_id or edge.target_id == vertex_id] + return [ + edge + for edge in self.edges + if edge.source_id == vertex_id or edge.target_id == vertex_id + ] def get_vertices_with_target(self, vertex_id: str) -> List[Vertex]: """Returns the vertices connected to a vertex.""" @@ -365,7 +366,9 @@ class Graph: def dfs(vertex): if state[vertex] == 1: # We have a cycle - raise ValueError("Graph contains a cycle, cannot perform topological sort") + raise ValueError( + "Graph contains a cycle, cannot perform topological sort" + ) if state[vertex] == 0: state[vertex] = 1 for edge in vertex.edges: @@ -389,11 +392,17 @@ class Graph: def get_predecessors(self, vertex): """Returns the predecessors of a vertex.""" - return [self.get_vertex(source_id) for source_id in self.predecessor_map.get(vertex.id, [])] + return [ + self.get_vertex(source_id) + for source_id in self.predecessor_map.get(vertex.id, []) + ] def get_successors(self, vertex): """Returns the successors of a vertex.""" - return [self.get_vertex(target_id) for target_id in self.successor_map.get(vertex.id, [])] + return [ + self.get_vertex(target_id) + for target_id in self.successor_map.get(vertex.id, []) + ] def get_vertex_neighbors(self, vertex: Vertex) -> Dict[Vertex, int]: """Returns the neighbors of a vertex.""" @@ -432,7 +441,9 @@ class Graph: edges.append(ContractEdge(source, target, edge)) return edges - def _get_vertex_class(self, node_type: str, node_base_type: str, node_id: str) -> Type[Vertex]: + def _get_vertex_class( + self, node_type: str, node_base_type: str, node_id: str + ) -> Type[Vertex]: """Returns the node class based on the node type.""" # First we check for the node_base_type node_name = node_id.split("-")[0] @@ -463,14 +474,18 @@ class Graph: vertex_type: str = vertex_data["type"] # type: ignore vertex_base_type: str = vertex_data["node"]["template"]["_type"] # type: ignore - VertexClass = self._get_vertex_class(vertex_type, vertex_base_type, vertex_data["id"]) + VertexClass = self._get_vertex_class( + vertex_type, vertex_base_type, vertex_data["id"] + ) vertex_instance = VertexClass(vertex, graph=self) vertex_instance.set_top_level(self.top_level_vertices) vertices.append(vertex_instance) return vertices - def get_children_by_vertex_type(self, vertex: Vertex, vertex_type: str) -> List[Vertex]: + def get_children_by_vertex_type( + self, vertex: Vertex, vertex_type: str + ) -> List[Vertex]: """Returns the children of a vertex based on the vertex type.""" children = [] vertex_types = [vertex.data["type"]] @@ -482,7 +497,9 @@ class Graph: def __repr__(self): vertex_ids = [vertex.id for vertex in self.vertices] - edges_repr = "\n".join([f"{edge.source_id} --> {edge.target_id}" for edge in self.edges]) + edges_repr = "\n".join( + [f"{edge.source_id} --> {edge.target_id}" for edge in self.edges] + ) return f"Graph:\nNodes: {vertex_ids}\nConnections:\n{edges_repr}" def sort_up_to_vertex(self, vertex_id: str) -> "Graph": @@ -513,7 +530,9 @@ class Graph: """Performs a layered topological sort of the vertices in the graph.""" # Queue for vertices with no incoming edges - queue = deque(vertex.id for vertex in vertices if self.in_degree_map[vertex.id] == 0) + queue = deque( + vertex.id for vertex in vertices if self.in_degree_map[vertex.id] == 0 + ) layers = [] current_layer = 0 @@ -569,7 +588,9 @@ class Graph: return refined_layers - def sort_chat_inputs_first(self, vertices_layers: List[List[str]]) -> List[List[str]]: + def sort_chat_inputs_first( + self, vertices_layers: List[List[str]] + ) -> List[List[str]]: chat_inputs_first = [] for layer in vertices_layers: for vertex_id in layer: @@ -597,11 +618,15 @@ class Graph: self.increment_run_count() return vertices_layers - def sort_interface_components_first(self, vertices_layers: List[List[str]]) -> List[List[str]]: + def sort_interface_components_first( + self, vertices_layers: List[List[str]] + ) -> List[List[str]]: """Sorts the vertices in the graph so that vertices containing ChatInput or ChatOutput come first.""" def contains_interface_component(vertex): - return any(component.value in vertex for component in InterfaceComponentTypes) + return any( + component.value in vertex for component in InterfaceComponentTypes + ) # Sort each inner list so that vertices containing ChatInput or ChatOutput come first sorted_vertices = [ @@ -620,9 +645,13 @@ class Graph: """Sorts the vertices in the graph so that vertices with the lowest average build time come first.""" if len(vertices_ids) == 1: return vertices_ids - vertices_ids.sort(key=lambda vertex_id: self.get_vertex(vertex_id).avg_build_time) + vertices_ids.sort( + key=lambda vertex_id: self.get_vertex(vertex_id).avg_build_time + ) return vertices_ids - sorted_vertices = [sort_layer_by_avg_build_time(layer) for layer in vertices_layers] + sorted_vertices = [ + sort_layer_by_avg_build_time(layer) for layer in vertices_layers + ] return sorted_vertices