diff --git a/src/backend/langflow/graph/graph/base.py b/src/backend/langflow/graph/graph/base.py index 325fa7b13..59a8966b5 100644 --- a/src/backend/langflow/graph/graph/base.py +++ b/src/backend/langflow/graph/graph/base.py @@ -85,9 +85,7 @@ class Graph: def build_parent_child_map(self): parent_child_map = defaultdict(list) for vertex in self.vertices: - parent_child_map[vertex.id] = [ - child.id for child in self.get_successors(vertex) - ] + parent_child_map[vertex.id] = [child.id for child in self.get_successors(vertex)] return parent_child_map def increment_run_count(self): @@ -176,7 +174,8 @@ class Graph: removed_vertex_ids = existing_vertex_ids - other_vertex_ids # Create a set for new edges - new_edges = set() + edges_to_add = set() + edges_to_remove = set() # Update existing vertices that have changed for vertex_id in existing_vertex_ids.intersection(other_vertex_ids): @@ -197,13 +196,24 @@ class Graph: self_vertex.set_top_level(self.top_level_vertices) self.reset_all_edges_of_vertex(self_vertex) if not self.vertex_edges_are_identical(self_vertex, other_vertex): - new_edges.update(other_vertex.edges) + # New edges are the edges of the other vertex and not the self vertex + # If there are more edges in the other vertex than in the self vertex + # then we need to add the new edges to the self vertex + # if there are less edges in the other vertex than in the self vertex + # then we need to remove the edges that are not in the other vertex + + if len(self_vertex.edges) < len(other_vertex.edges): + edges_to_add.update(edge for edge in other_vertex.edges if edge not in self_vertex.edges) + elif len(self_vertex.edges) > len(other_vertex.edges): + edges_to_remove.update(edge for edge in self_vertex.edges if edge not in other_vertex.edges) # Add new edges # to self.edges if they are not already in self.edges - for edge in new_edges: + for edge in edges_to_add: if edge not in self.edges: self.edges.append(edge) + for edge in edges_to_remove: + self.edges.remove(edge) # Remove vertices for vertex_id in removed_vertex_ids: @@ -280,11 +290,7 @@ class Graph: return self.vertices.remove(vertex) self.vertex_map.pop(vertex_id) - self.edges = [ - edge - for edge in self.edges - if edge.source_id != vertex_id and edge.target_id != vertex_id - ] + self.edges = [edge for edge in self.edges if edge.source_id != vertex_id and edge.target_id != vertex_id] def _build_vertex_params(self) -> None: """Identifies and handles the LLM vertex within the graph.""" @@ -305,9 +311,7 @@ class Graph: return for vertex in self.vertices: if not self._validate_vertex(vertex): - raise ValueError( - f"{vertex.vertex_type} is not connected to any other components" - ) + raise ValueError(f"{vertex.vertex_type} is not connected to any other components") def _validate_vertex(self, vertex: Vertex) -> bool: """Validates a vertex.""" @@ -323,11 +327,7 @@ class Graph: def get_vertex_edges(self, vertex_id: str) -> List[ContractEdge]: """Returns a list of edges for a given vertex.""" - return [ - edge - for edge in self.edges - if edge.source_id == vertex_id or edge.target_id == vertex_id - ] + return [edge for edge in self.edges if edge.source_id == vertex_id or edge.target_id == vertex_id] def get_vertices_with_target(self, vertex_id: str) -> List[Vertex]: """Returns the vertices connected to a vertex.""" @@ -365,9 +365,7 @@ class Graph: def dfs(vertex): if state[vertex] == 1: # We have a cycle - raise ValueError( - "Graph contains a cycle, cannot perform topological sort" - ) + raise ValueError("Graph contains a cycle, cannot perform topological sort") if state[vertex] == 0: state[vertex] = 1 for edge in vertex.edges: @@ -391,17 +389,11 @@ class Graph: def get_predecessors(self, vertex): """Returns the predecessors of a vertex.""" - return [ - self.get_vertex(source_id) - for source_id in self.predecessor_map.get(vertex.id, []) - ] + return [self.get_vertex(source_id) for source_id in self.predecessor_map.get(vertex.id, [])] def get_successors(self, vertex): """Returns the successors of a vertex.""" - return [ - self.get_vertex(target_id) - for target_id in self.successor_map.get(vertex.id, []) - ] + return [self.get_vertex(target_id) for target_id in self.successor_map.get(vertex.id, [])] def get_vertex_neighbors(self, vertex: Vertex) -> Dict[Vertex, int]: """Returns the neighbors of a vertex.""" @@ -440,9 +432,7 @@ class Graph: edges.append(ContractEdge(source, target, edge)) return edges - def _get_vertex_class( - self, node_type: str, node_base_type: str, node_id: str - ) -> Type[Vertex]: + def _get_vertex_class(self, node_type: str, node_base_type: str, node_id: str) -> Type[Vertex]: """Returns the node class based on the node type.""" # First we check for the node_base_type node_name = node_id.split("-")[0] @@ -473,18 +463,14 @@ class Graph: vertex_type: str = vertex_data["type"] # type: ignore vertex_base_type: str = vertex_data["node"]["template"]["_type"] # type: ignore - VertexClass = self._get_vertex_class( - vertex_type, vertex_base_type, vertex_data["id"] - ) + VertexClass = self._get_vertex_class(vertex_type, vertex_base_type, vertex_data["id"]) vertex_instance = VertexClass(vertex, graph=self) vertex_instance.set_top_level(self.top_level_vertices) vertices.append(vertex_instance) return vertices - def get_children_by_vertex_type( - self, vertex: Vertex, vertex_type: str - ) -> List[Vertex]: + def get_children_by_vertex_type(self, vertex: Vertex, vertex_type: str) -> List[Vertex]: """Returns the children of a vertex based on the vertex type.""" children = [] vertex_types = [vertex.data["type"]] @@ -496,9 +482,7 @@ class Graph: def __repr__(self): vertex_ids = [vertex.id for vertex in self.vertices] - edges_repr = "\n".join( - [f"{edge.source_id} --> {edge.target_id}" for edge in self.edges] - ) + edges_repr = "\n".join([f"{edge.source_id} --> {edge.target_id}" for edge in self.edges]) return f"Graph:\nNodes: {vertex_ids}\nConnections:\n{edges_repr}" def sort_up_to_vertex(self, vertex_id: str) -> "Graph": @@ -529,9 +513,7 @@ class Graph: """Performs a layered topological sort of the vertices in the graph.""" # Queue for vertices with no incoming edges - queue = deque( - vertex.id for vertex in vertices if self.in_degree_map[vertex.id] == 0 - ) + queue = deque(vertex.id for vertex in vertices if self.in_degree_map[vertex.id] == 0) layers = [] current_layer = 0 @@ -587,9 +569,7 @@ class Graph: return refined_layers - def sort_chat_inputs_first( - self, vertices_layers: List[List[str]] - ) -> List[List[str]]: + def sort_chat_inputs_first(self, vertices_layers: List[List[str]]) -> List[List[str]]: chat_inputs_first = [] for layer in vertices_layers: for vertex_id in layer: @@ -617,15 +597,11 @@ class Graph: self.increment_run_count() return vertices_layers - def sort_interface_components_first( - self, vertices_layers: List[List[str]] - ) -> List[List[str]]: + def sort_interface_components_first(self, vertices_layers: List[List[str]]) -> List[List[str]]: """Sorts the vertices in the graph so that vertices containing ChatInput or ChatOutput come first.""" def contains_interface_component(vertex): - return any( - component.value in vertex for component in InterfaceComponentTypes - ) + return any(component.value in vertex for component in InterfaceComponentTypes) # Sort each inner list so that vertices containing ChatInput or ChatOutput come first sorted_vertices = [ @@ -644,13 +620,9 @@ class Graph: """Sorts the vertices in the graph so that vertices with the lowest average build time come first.""" if len(vertices_ids) == 1: return vertices_ids - vertices_ids.sort( - key=lambda vertex_id: self.get_vertex(vertex_id).avg_build_time - ) + vertices_ids.sort(key=lambda vertex_id: self.get_vertex(vertex_id).avg_build_time) return vertices_ids - sorted_vertices = [ - sort_layer_by_avg_build_time(layer) for layer in vertices_layers - ] + sorted_vertices = [sort_layer_by_avg_build_time(layer) for layer in vertices_layers] return sorted_vertices