Implement update_edges_from_vertex

This commit is contained in:
Gabriel Luiz Freitas Almeida 2024-02-26 16:52:53 -03:00
commit 694418827d

View file

@ -85,7 +85,9 @@ class Graph:
def build_parent_child_map(self):
parent_child_map = defaultdict(list)
for vertex in self.vertices:
parent_child_map[vertex.id] = [child.id for child in self.get_successors(vertex)]
parent_child_map[vertex.id] = [
child.id for child in self.get_successors(vertex)
]
return parent_child_map
def increment_run_count(self):
@ -149,6 +151,16 @@ class Graph:
# both graphs have the same vertices and edges
# but the data of the vertices might be different
def update_edges_from_vertex(self, vertex: Vertex, other_vertex: Vertex) -> None:
"""Updates the edges of a vertex in the Graph."""
new_edges = []
for edge in self.edges:
if edge.source_id == other_vertex.id or edge.target_id == other_vertex.id:
continue
new_edges.append(edge)
new_edges += other_vertex.edges
self.edges = new_edges
def vertex_data_is_identical(self, vertex: Vertex, other_vertex: Vertex) -> bool:
return vertex.__repr__() == other_vertex.__repr__()
@ -173,10 +185,6 @@ class Graph:
# Find vertices that are in self but not in other (removed vertices)
removed_vertex_ids = existing_vertex_ids - other_vertex_ids
# Create a set for new edges
edges_to_add = set()
edges_to_remove = set()
# Update existing vertices that have changed
for vertex_id in existing_vertex_ids.intersection(other_vertex_ids):
self_vertex = self.get_vertex(vertex_id)
@ -184,6 +192,8 @@ class Graph:
if not self.vertex_data_is_identical(self_vertex, other_vertex):
self_vertex._data = other_vertex._data
self_vertex._parse_data()
# Now we update the edges of the vertex
self.update_edges_from_vertex(self_vertex, other_vertex)
self_vertex.params = {}
self_vertex._build_params()
self_vertex.graph = self
@ -195,25 +205,6 @@ class Graph:
self_vertex.artifacts = None
self_vertex.set_top_level(self.top_level_vertices)
self.reset_all_edges_of_vertex(self_vertex)
if not self.vertex_edges_are_identical(self_vertex, other_vertex):
# New edges are the edges of the other vertex and not the self vertex
# If there are more edges in the other vertex than in the self vertex
# then we need to add the new edges to the self vertex
# if there are less edges in the other vertex than in the self vertex
# then we need to remove the edges that are not in the other vertex
if len(self_vertex.edges) < len(other_vertex.edges):
edges_to_add.update(edge for edge in other_vertex.edges if edge not in self_vertex.edges)
elif len(self_vertex.edges) > len(other_vertex.edges):
edges_to_remove.update(edge for edge in self_vertex.edges if edge not in other_vertex.edges)
# Add new edges
# to self.edges if they are not already in self.edges
for edge in edges_to_add:
if edge not in self.edges:
self.edges.append(edge)
for edge in edges_to_remove:
self.edges.remove(edge)
# Remove vertices
for vertex_id in removed_vertex_ids:
@ -290,7 +281,11 @@ class Graph:
return
self.vertices.remove(vertex)
self.vertex_map.pop(vertex_id)
self.edges = [edge for edge in self.edges if edge.source_id != vertex_id and edge.target_id != vertex_id]
self.edges = [
edge
for edge in self.edges
if edge.source_id != vertex_id and edge.target_id != vertex_id
]
def _build_vertex_params(self) -> None:
"""Identifies and handles the LLM vertex within the graph."""
@ -311,7 +306,9 @@ class Graph:
return
for vertex in self.vertices:
if not self._validate_vertex(vertex):
raise ValueError(f"{vertex.vertex_type} is not connected to any other components")
raise ValueError(
f"{vertex.vertex_type} is not connected to any other components"
)
def _validate_vertex(self, vertex: Vertex) -> bool:
"""Validates a vertex."""
@ -327,7 +324,11 @@ class Graph:
def get_vertex_edges(self, vertex_id: str) -> List[ContractEdge]:
"""Returns a list of edges for a given vertex."""
return [edge for edge in self.edges if edge.source_id == vertex_id or edge.target_id == vertex_id]
return [
edge
for edge in self.edges
if edge.source_id == vertex_id or edge.target_id == vertex_id
]
def get_vertices_with_target(self, vertex_id: str) -> List[Vertex]:
"""Returns the vertices connected to a vertex."""
@ -365,7 +366,9 @@ class Graph:
def dfs(vertex):
if state[vertex] == 1:
# We have a cycle
raise ValueError("Graph contains a cycle, cannot perform topological sort")
raise ValueError(
"Graph contains a cycle, cannot perform topological sort"
)
if state[vertex] == 0:
state[vertex] = 1
for edge in vertex.edges:
@ -389,11 +392,17 @@ class Graph:
def get_predecessors(self, vertex):
"""Returns the predecessors of a vertex."""
return [self.get_vertex(source_id) for source_id in self.predecessor_map.get(vertex.id, [])]
return [
self.get_vertex(source_id)
for source_id in self.predecessor_map.get(vertex.id, [])
]
def get_successors(self, vertex):
"""Returns the successors of a vertex."""
return [self.get_vertex(target_id) for target_id in self.successor_map.get(vertex.id, [])]
return [
self.get_vertex(target_id)
for target_id in self.successor_map.get(vertex.id, [])
]
def get_vertex_neighbors(self, vertex: Vertex) -> Dict[Vertex, int]:
"""Returns the neighbors of a vertex."""
@ -432,7 +441,9 @@ class Graph:
edges.append(ContractEdge(source, target, edge))
return edges
def _get_vertex_class(self, node_type: str, node_base_type: str, node_id: str) -> Type[Vertex]:
def _get_vertex_class(
self, node_type: str, node_base_type: str, node_id: str
) -> Type[Vertex]:
"""Returns the node class based on the node type."""
# First we check for the node_base_type
node_name = node_id.split("-")[0]
@ -463,14 +474,18 @@ class Graph:
vertex_type: str = vertex_data["type"] # type: ignore
vertex_base_type: str = vertex_data["node"]["template"]["_type"] # type: ignore
VertexClass = self._get_vertex_class(vertex_type, vertex_base_type, vertex_data["id"])
VertexClass = self._get_vertex_class(
vertex_type, vertex_base_type, vertex_data["id"]
)
vertex_instance = VertexClass(vertex, graph=self)
vertex_instance.set_top_level(self.top_level_vertices)
vertices.append(vertex_instance)
return vertices
def get_children_by_vertex_type(self, vertex: Vertex, vertex_type: str) -> List[Vertex]:
def get_children_by_vertex_type(
self, vertex: Vertex, vertex_type: str
) -> List[Vertex]:
"""Returns the children of a vertex based on the vertex type."""
children = []
vertex_types = [vertex.data["type"]]
@ -482,7 +497,9 @@ class Graph:
def __repr__(self):
vertex_ids = [vertex.id for vertex in self.vertices]
edges_repr = "\n".join([f"{edge.source_id} --> {edge.target_id}" for edge in self.edges])
edges_repr = "\n".join(
[f"{edge.source_id} --> {edge.target_id}" for edge in self.edges]
)
return f"Graph:\nNodes: {vertex_ids}\nConnections:\n{edges_repr}"
def sort_up_to_vertex(self, vertex_id: str) -> "Graph":
@ -513,7 +530,9 @@ class Graph:
"""Performs a layered topological sort of the vertices in the graph."""
# Queue for vertices with no incoming edges
queue = deque(vertex.id for vertex in vertices if self.in_degree_map[vertex.id] == 0)
queue = deque(
vertex.id for vertex in vertices if self.in_degree_map[vertex.id] == 0
)
layers = []
current_layer = 0
@ -569,7 +588,9 @@ class Graph:
return refined_layers
def sort_chat_inputs_first(self, vertices_layers: List[List[str]]) -> List[List[str]]:
def sort_chat_inputs_first(
self, vertices_layers: List[List[str]]
) -> List[List[str]]:
chat_inputs_first = []
for layer in vertices_layers:
for vertex_id in layer:
@ -597,11 +618,15 @@ class Graph:
self.increment_run_count()
return vertices_layers
def sort_interface_components_first(self, vertices_layers: List[List[str]]) -> List[List[str]]:
def sort_interface_components_first(
self, vertices_layers: List[List[str]]
) -> List[List[str]]:
"""Sorts the vertices in the graph so that vertices containing ChatInput or ChatOutput come first."""
def contains_interface_component(vertex):
return any(component.value in vertex for component in InterfaceComponentTypes)
return any(
component.value in vertex for component in InterfaceComponentTypes
)
# Sort each inner list so that vertices containing ChatInput or ChatOutput come first
sorted_vertices = [
@ -620,9 +645,13 @@ class Graph:
"""Sorts the vertices in the graph so that vertices with the lowest average build time come first."""
if len(vertices_ids) == 1:
return vertices_ids
vertices_ids.sort(key=lambda vertex_id: self.get_vertex(vertex_id).avg_build_time)
vertices_ids.sort(
key=lambda vertex_id: self.get_vertex(vertex_id).avg_build_time
)
return vertices_ids
sorted_vertices = [sort_layer_by_avg_build_time(layer) for layer in vertices_layers]
sorted_vertices = [
sort_layer_by_avg_build_time(layer) for layer in vertices_layers
]
return sorted_vertices