Refactor Graph update method to handle new and removed vertices
This commit is contained in:
parent
790a9bc112
commit
fdfbe810c9
1 changed files with 92 additions and 32 deletions
|
|
@ -82,24 +82,60 @@ class Graph:
|
|||
# then update the .data of the vertex to the self
|
||||
# both graphs have the same vertices and edges
|
||||
# but the data of the vertices might be different
|
||||
def update(self, other: "Graph", different_vertices: List[str] = None) -> None:
|
||||
if different_vertices is None:
|
||||
different_vertices = []
|
||||
for vertex in self.vertices:
|
||||
other_vertex = other.get_vertex(vertex.id)
|
||||
if other_vertex is None:
|
||||
continue
|
||||
if (
|
||||
vertex.id in different_vertices
|
||||
or vertex.__repr__() != other.get_vertex(vertex.id).__repr__()
|
||||
):
|
||||
vertex.data = other.get_vertex(vertex.id).data
|
||||
vertex.params = {}
|
||||
vertex._build_params()
|
||||
vertex.graph = self
|
||||
vertex._built = False
|
||||
|
||||
def update(self, other: "Graph") -> None:
|
||||
# Existing vertices in self graph
|
||||
existing_vertex_ids = set(vertex.id for vertex in self.vertices)
|
||||
# Vertex IDs in the other graph
|
||||
other_vertex_ids = set(other.vertex_map.keys())
|
||||
|
||||
# Find vertices that are in other but not in self (new vertices)
|
||||
new_vertex_ids = other_vertex_ids - existing_vertex_ids
|
||||
|
||||
# Find vertices that are in self but not in other (removed vertices)
|
||||
removed_vertex_ids = existing_vertex_ids - other_vertex_ids
|
||||
|
||||
# Update existing vertices that have changed
|
||||
for vertex_id in existing_vertex_ids.intersection(other_vertex_ids):
|
||||
self_vertex = self.get_vertex(vertex_id)
|
||||
other_vertex = other.get_vertex(vertex_id)
|
||||
if self_vertex.__repr__() != other_vertex.__repr__():
|
||||
self_vertex.data = other_vertex.data
|
||||
self_vertex.params = {}
|
||||
self_vertex._build_params()
|
||||
self_vertex.graph = self
|
||||
self_vertex._built = False
|
||||
self.reset_all_edges_of_vertex(self_vertex)
|
||||
|
||||
# Remove vertices
|
||||
for vertex_id in removed_vertex_ids:
|
||||
self.remove_vertex(vertex_id)
|
||||
|
||||
# Add new vertices
|
||||
for vertex_id in new_vertex_ids:
|
||||
new_vertex = other.get_vertex(vertex_id)
|
||||
self._add_vertex(new_vertex)
|
||||
|
||||
return self
|
||||
|
||||
def reset_all_edges_of_vertex(self, vertex: Vertex) -> None:
|
||||
"""Resets all the edges of a vertex."""
|
||||
for edge in vertex.edges:
|
||||
for vid in [edge.source_id, edge.target_id]:
|
||||
if vid in self.vertex_map:
|
||||
_vertex = self.vertex_map[vid]
|
||||
if not _vertex.pinned:
|
||||
_vertex._build_params()
|
||||
|
||||
def _add_vertex(self, vertex: Vertex) -> None:
|
||||
"""Adds a new vertex to the graph."""
|
||||
self.vertices.append(vertex)
|
||||
self.vertex_map[vertex.id] = vertex
|
||||
# Vertex has edges, so we need to update the edges
|
||||
for edge in vertex.edges:
|
||||
if edge.source_id in self.vertex_map and edge.target_id in self.vertex_map:
|
||||
self.edges.append(edge)
|
||||
|
||||
def _build_graph(self) -> None:
|
||||
"""Builds the graph from the vertices and edges."""
|
||||
self.vertices = self._build_vertices()
|
||||
|
|
@ -137,6 +173,19 @@ class Graph:
|
|||
source_vertex.has_external_output = True
|
||||
self.outputs.append(target_vertex.id)
|
||||
|
||||
def remove_vertex(self, vertex_id: str) -> None:
|
||||
"""Removes a vertex from the graph."""
|
||||
vertex = self.get_vertex(vertex_id)
|
||||
if vertex is None:
|
||||
return
|
||||
self.vertices.remove(vertex)
|
||||
self.vertex_map.pop(vertex_id)
|
||||
self.edges = [
|
||||
edge
|
||||
for edge in self.edges
|
||||
if edge.source_id != vertex_id and edge.target_id != vertex_id
|
||||
]
|
||||
|
||||
def _build_vertex_params(self) -> None:
|
||||
"""Identifies and handles the LLM vertex within the graph."""
|
||||
llm_vertex = None
|
||||
|
|
@ -339,24 +388,35 @@ class Graph:
|
|||
# Get the vertices that are connected to the vertex
|
||||
# and the vertex itself
|
||||
vertex = self.get_vertex(vertex_id)
|
||||
vertices = [vertex]
|
||||
for edge in vertex.edges:
|
||||
if edge.target_id == vertex_id:
|
||||
vertices.append(self.get_vertex(edge.source_id))
|
||||
|
||||
# Get the edges that are connected to the vertices
|
||||
# We need to remove the edge coming from the vertex
|
||||
# and all vertices after the vertex
|
||||
edges = []
|
||||
for vertex in vertices:
|
||||
edges.extend(self.get_vertex_edges(vertex.id))
|
||||
source_vertex = self.get_vertex(edge.source_id)
|
||||
target_vertex = self.get_vertex(edge.target_id)
|
||||
if source_vertex not in vertices:
|
||||
vertices.append(source_vertex)
|
||||
if target_vertex not in vertices:
|
||||
vertices.append(target_vertex)
|
||||
|
||||
edges = [edge for vertex in vertices for edge in vertex.edges]
|
||||
vertices_to_remove = []
|
||||
for edge in vertex.edges:
|
||||
if edge.source_id != vertex_id:
|
||||
edges.append(edge)
|
||||
if edge.target_id != vertex_id:
|
||||
vertices_to_remove.append(self.get_vertex(edge.target_id))
|
||||
|
||||
for vertex in vertices_to_remove:
|
||||
for edge in vertex.edges:
|
||||
if edge.target_id == vertex.id:
|
||||
continue
|
||||
vertices_to_remove.append(self.get_vertex(edge.target_id))
|
||||
vertices_to_remove = set(vertices_to_remove)
|
||||
vertices = []
|
||||
edges_to_remove = {
|
||||
edge for vertex in vertices_to_remove for edge in vertex.edges
|
||||
}
|
||||
for vertex in self.vertices:
|
||||
if vertex in vertices_to_remove:
|
||||
continue
|
||||
vertices.append(vertex)
|
||||
for edge in vertex.edges:
|
||||
if edge in edges_to_remove:
|
||||
continue
|
||||
if edge not in edges:
|
||||
edges.append(edge)
|
||||
vertices = self.layered_topological_sort(vertices, edges)
|
||||
return self.sort_chat_inputs_first(vertices)
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue