From ae629a528cce44f7c44ae7eacef3f6d14d5e7afa Mon Sep 17 00:00:00 2001 From: Gabriel Luiz Freitas Almeida Date: Mon, 26 Feb 2024 13:25:57 -0300 Subject: [PATCH] Add edges validation --- src/backend/langflow/graph/graph/base.py | 99 +++++++++++++++++++----- 1 file changed, 81 insertions(+), 18 deletions(-) diff --git a/src/backend/langflow/graph/graph/base.py b/src/backend/langflow/graph/graph/base.py index 453d7a98f..325fa7b13 100644 --- a/src/backend/langflow/graph/graph/base.py +++ b/src/backend/langflow/graph/graph/base.py @@ -85,7 +85,9 @@ class Graph: def build_parent_child_map(self): parent_child_map = defaultdict(list) for vertex in self.vertices: - parent_child_map[vertex.id] = [child.id for child in self.get_successors(vertex)] + parent_child_map[vertex.id] = [ + child.id for child in self.get_successors(vertex) + ] return parent_child_map def increment_run_count(self): @@ -149,6 +151,18 @@ class Graph: # both graphs have the same vertices and edges # but the data of the vertices might be different + def vertex_data_is_identical(self, vertex: Vertex, other_vertex: Vertex) -> bool: + return vertex.__repr__() == other_vertex.__repr__() + + def vertex_edges_are_identical(self, vertex: Vertex, other_vertex: Vertex) -> bool: + same_length = len(vertex.edges) == len(other_vertex.edges) + if not same_length: + return False + for edge in vertex.edges: + if edge not in other_vertex.edges: + return False + return True + def update(self, other: "Graph") -> None: # Existing vertices in self graph existing_vertex_ids = set(vertex.id for vertex in self.vertices) @@ -161,11 +175,14 @@ class Graph: # Find vertices that are in self but not in other (removed vertices) removed_vertex_ids = existing_vertex_ids - other_vertex_ids + # Create a set for new edges + new_edges = set() + # Update existing vertices that have changed for vertex_id in existing_vertex_ids.intersection(other_vertex_ids): self_vertex = self.get_vertex(vertex_id) other_vertex = other.get_vertex(vertex_id) - if self_vertex.__repr__() != other_vertex.__repr__(): + if not self.vertex_data_is_identical(self_vertex, other_vertex): self_vertex._data = other_vertex._data self_vertex._parse_data() self_vertex.params = {} @@ -179,6 +196,14 @@ class Graph: self_vertex.artifacts = None self_vertex.set_top_level(self.top_level_vertices) self.reset_all_edges_of_vertex(self_vertex) + if not self.vertex_edges_are_identical(self_vertex, other_vertex): + new_edges.update(other_vertex.edges) + + # Add new edges + # to self.edges if they are not already in self.edges + for edge in new_edges: + if edge not in self.edges: + self.edges.append(edge) # Remove vertices for vertex_id in removed_vertex_ids: @@ -255,7 +280,11 @@ class Graph: return self.vertices.remove(vertex) self.vertex_map.pop(vertex_id) - self.edges = [edge for edge in self.edges if edge.source_id != vertex_id and edge.target_id != vertex_id] + self.edges = [ + edge + for edge in self.edges + if edge.source_id != vertex_id and edge.target_id != vertex_id + ] def _build_vertex_params(self) -> None: """Identifies and handles the LLM vertex within the graph.""" @@ -276,7 +305,9 @@ class Graph: return for vertex in self.vertices: if not self._validate_vertex(vertex): - raise ValueError(f"{vertex.vertex_type} is not connected to any other components") + raise ValueError( + f"{vertex.vertex_type} is not connected to any other components" + ) def _validate_vertex(self, vertex: Vertex) -> bool: """Validates a vertex.""" @@ -292,7 +323,11 @@ class Graph: def get_vertex_edges(self, vertex_id: str) -> List[ContractEdge]: """Returns a list of edges for a given vertex.""" - return [edge for edge in self.edges if edge.source_id == vertex_id or edge.target_id == vertex_id] + return [ + edge + for edge in self.edges + if edge.source_id == vertex_id or edge.target_id == vertex_id + ] def get_vertices_with_target(self, vertex_id: str) -> List[Vertex]: """Returns the vertices connected to a vertex.""" @@ -330,7 +365,9 @@ class Graph: def dfs(vertex): if state[vertex] == 1: # We have a cycle - raise ValueError("Graph contains a cycle, cannot perform topological sort") + raise ValueError( + "Graph contains a cycle, cannot perform topological sort" + ) if state[vertex] == 0: state[vertex] = 1 for edge in vertex.edges: @@ -354,11 +391,17 @@ class Graph: def get_predecessors(self, vertex): """Returns the predecessors of a vertex.""" - return [self.get_vertex(source_id) for source_id in self.predecessor_map.get(vertex.id, [])] + return [ + self.get_vertex(source_id) + for source_id in self.predecessor_map.get(vertex.id, []) + ] def get_successors(self, vertex): """Returns the successors of a vertex.""" - return [self.get_vertex(target_id) for target_id in self.successor_map.get(vertex.id, [])] + return [ + self.get_vertex(target_id) + for target_id in self.successor_map.get(vertex.id, []) + ] def get_vertex_neighbors(self, vertex: Vertex) -> Dict[Vertex, int]: """Returns the neighbors of a vertex.""" @@ -397,7 +440,9 @@ class Graph: edges.append(ContractEdge(source, target, edge)) return edges - def _get_vertex_class(self, node_type: str, node_base_type: str, node_id: str) -> Type[Vertex]: + def _get_vertex_class( + self, node_type: str, node_base_type: str, node_id: str + ) -> Type[Vertex]: """Returns the node class based on the node type.""" # First we check for the node_base_type node_name = node_id.split("-")[0] @@ -428,14 +473,18 @@ class Graph: vertex_type: str = vertex_data["type"] # type: ignore vertex_base_type: str = vertex_data["node"]["template"]["_type"] # type: ignore - VertexClass = self._get_vertex_class(vertex_type, vertex_base_type, vertex_data["id"]) + VertexClass = self._get_vertex_class( + vertex_type, vertex_base_type, vertex_data["id"] + ) vertex_instance = VertexClass(vertex, graph=self) vertex_instance.set_top_level(self.top_level_vertices) vertices.append(vertex_instance) return vertices - def get_children_by_vertex_type(self, vertex: Vertex, vertex_type: str) -> List[Vertex]: + def get_children_by_vertex_type( + self, vertex: Vertex, vertex_type: str + ) -> List[Vertex]: """Returns the children of a vertex based on the vertex type.""" children = [] vertex_types = [vertex.data["type"]] @@ -447,7 +496,9 @@ class Graph: def __repr__(self): vertex_ids = [vertex.id for vertex in self.vertices] - edges_repr = "\n".join([f"{edge.source_id} --> {edge.target_id}" for edge in self.edges]) + edges_repr = "\n".join( + [f"{edge.source_id} --> {edge.target_id}" for edge in self.edges] + ) return f"Graph:\nNodes: {vertex_ids}\nConnections:\n{edges_repr}" def sort_up_to_vertex(self, vertex_id: str) -> "Graph": @@ -478,7 +529,9 @@ class Graph: """Performs a layered topological sort of the vertices in the graph.""" # Queue for vertices with no incoming edges - queue = deque(vertex.id for vertex in vertices if self.in_degree_map[vertex.id] == 0) + queue = deque( + vertex.id for vertex in vertices if self.in_degree_map[vertex.id] == 0 + ) layers = [] current_layer = 0 @@ -534,7 +587,9 @@ class Graph: return refined_layers - def sort_chat_inputs_first(self, vertices_layers: List[List[str]]) -> List[List[str]]: + def sort_chat_inputs_first( + self, vertices_layers: List[List[str]] + ) -> List[List[str]]: chat_inputs_first = [] for layer in vertices_layers: for vertex_id in layer: @@ -562,11 +617,15 @@ class Graph: self.increment_run_count() return vertices_layers - def sort_interface_components_first(self, vertices_layers: List[List[str]]) -> List[List[str]]: + def sort_interface_components_first( + self, vertices_layers: List[List[str]] + ) -> List[List[str]]: """Sorts the vertices in the graph so that vertices containing ChatInput or ChatOutput come first.""" def contains_interface_component(vertex): - return any(component.value in vertex for component in InterfaceComponentTypes) + return any( + component.value in vertex for component in InterfaceComponentTypes + ) # Sort each inner list so that vertices containing ChatInput or ChatOutput come first sorted_vertices = [ @@ -585,9 +644,13 @@ class Graph: """Sorts the vertices in the graph so that vertices with the lowest average build time come first.""" if len(vertices_ids) == 1: return vertices_ids - vertices_ids.sort(key=lambda vertex_id: self.get_vertex(vertex_id).avg_build_time) + vertices_ids.sort( + key=lambda vertex_id: self.get_vertex(vertex_id).avg_build_time + ) return vertices_ids - sorted_vertices = [sort_layer_by_avg_build_time(layer) for layer in vertices_layers] + sorted_vertices = [ + sort_layer_by_avg_build_time(layer) for layer in vertices_layers + ] return sorted_vertices