Refactor Graph update method to handle new and removed vertices

This commit is contained in:
Gabriel Luiz Freitas Almeida 2024-02-20 12:17:41 -03:00
commit fdfbe810c9

View file

@ -82,24 +82,60 @@ class Graph:
# then update the .data of the vertex to the self
# both graphs have the same vertices and edges
# but the data of the vertices might be different
def update(self, other: "Graph", different_vertices: List[str] = None) -> None:
if different_vertices is None:
different_vertices = []
for vertex in self.vertices:
other_vertex = other.get_vertex(vertex.id)
if other_vertex is None:
continue
if (
vertex.id in different_vertices
or vertex.__repr__() != other.get_vertex(vertex.id).__repr__()
):
vertex.data = other.get_vertex(vertex.id).data
vertex.params = {}
vertex._build_params()
vertex.graph = self
vertex._built = False
def update(self, other: "Graph") -> None:
# Existing vertices in self graph
existing_vertex_ids = set(vertex.id for vertex in self.vertices)
# Vertex IDs in the other graph
other_vertex_ids = set(other.vertex_map.keys())
# Find vertices that are in other but not in self (new vertices)
new_vertex_ids = other_vertex_ids - existing_vertex_ids
# Find vertices that are in self but not in other (removed vertices)
removed_vertex_ids = existing_vertex_ids - other_vertex_ids
# Update existing vertices that have changed
for vertex_id in existing_vertex_ids.intersection(other_vertex_ids):
self_vertex = self.get_vertex(vertex_id)
other_vertex = other.get_vertex(vertex_id)
if self_vertex.__repr__() != other_vertex.__repr__():
self_vertex.data = other_vertex.data
self_vertex.params = {}
self_vertex._build_params()
self_vertex.graph = self
self_vertex._built = False
self.reset_all_edges_of_vertex(self_vertex)
# Remove vertices
for vertex_id in removed_vertex_ids:
self.remove_vertex(vertex_id)
# Add new vertices
for vertex_id in new_vertex_ids:
new_vertex = other.get_vertex(vertex_id)
self._add_vertex(new_vertex)
return self
def reset_all_edges_of_vertex(self, vertex: Vertex) -> None:
"""Resets all the edges of a vertex."""
for edge in vertex.edges:
for vid in [edge.source_id, edge.target_id]:
if vid in self.vertex_map:
_vertex = self.vertex_map[vid]
if not _vertex.pinned:
_vertex._build_params()
def _add_vertex(self, vertex: Vertex) -> None:
"""Adds a new vertex to the graph."""
self.vertices.append(vertex)
self.vertex_map[vertex.id] = vertex
# Vertex has edges, so we need to update the edges
for edge in vertex.edges:
if edge.source_id in self.vertex_map and edge.target_id in self.vertex_map:
self.edges.append(edge)
def _build_graph(self) -> None:
"""Builds the graph from the vertices and edges."""
self.vertices = self._build_vertices()
@ -137,6 +173,19 @@ class Graph:
source_vertex.has_external_output = True
self.outputs.append(target_vertex.id)
def remove_vertex(self, vertex_id: str) -> None:
"""Removes a vertex from the graph."""
vertex = self.get_vertex(vertex_id)
if vertex is None:
return
self.vertices.remove(vertex)
self.vertex_map.pop(vertex_id)
self.edges = [
edge
for edge in self.edges
if edge.source_id != vertex_id and edge.target_id != vertex_id
]
def _build_vertex_params(self) -> None:
"""Identifies and handles the LLM vertex within the graph."""
llm_vertex = None
@ -339,24 +388,35 @@ class Graph:
# Get the vertices that are connected to the vertex
# and the vertex itself
vertex = self.get_vertex(vertex_id)
vertices = [vertex]
for edge in vertex.edges:
if edge.target_id == vertex_id:
vertices.append(self.get_vertex(edge.source_id))
# Get the edges that are connected to the vertices
# We need to remove the edge coming from the vertex
# and all vertices after the vertex
edges = []
for vertex in vertices:
edges.extend(self.get_vertex_edges(vertex.id))
source_vertex = self.get_vertex(edge.source_id)
target_vertex = self.get_vertex(edge.target_id)
if source_vertex not in vertices:
vertices.append(source_vertex)
if target_vertex not in vertices:
vertices.append(target_vertex)
edges = [edge for vertex in vertices for edge in vertex.edges]
vertices_to_remove = []
for edge in vertex.edges:
if edge.source_id != vertex_id:
edges.append(edge)
if edge.target_id != vertex_id:
vertices_to_remove.append(self.get_vertex(edge.target_id))
for vertex in vertices_to_remove:
for edge in vertex.edges:
if edge.target_id == vertex.id:
continue
vertices_to_remove.append(self.get_vertex(edge.target_id))
vertices_to_remove = set(vertices_to_remove)
vertices = []
edges_to_remove = {
edge for vertex in vertices_to_remove for edge in vertex.edges
}
for vertex in self.vertices:
if vertex in vertices_to_remove:
continue
vertices.append(vertex)
for edge in vertex.edges:
if edge in edges_to_remove:
continue
if edge not in edges:
edges.append(edge)
vertices = self.layered_topological_sort(vertices, edges)
return self.sort_chat_inputs_first(vertices)