add or remove new edges

This commit is contained in:
Gabriel Luiz Freitas Almeida 2024-02-26 16:18:16 -03:00
commit 5cbeb8abb1

View file

@ -85,9 +85,7 @@ class Graph:
def build_parent_child_map(self):
parent_child_map = defaultdict(list)
for vertex in self.vertices:
parent_child_map[vertex.id] = [
child.id for child in self.get_successors(vertex)
]
parent_child_map[vertex.id] = [child.id for child in self.get_successors(vertex)]
return parent_child_map
def increment_run_count(self):
@ -176,7 +174,8 @@ class Graph:
removed_vertex_ids = existing_vertex_ids - other_vertex_ids
# Create a set for new edges
new_edges = set()
edges_to_add = set()
edges_to_remove = set()
# Update existing vertices that have changed
for vertex_id in existing_vertex_ids.intersection(other_vertex_ids):
@ -197,13 +196,24 @@ class Graph:
self_vertex.set_top_level(self.top_level_vertices)
self.reset_all_edges_of_vertex(self_vertex)
if not self.vertex_edges_are_identical(self_vertex, other_vertex):
new_edges.update(other_vertex.edges)
# New edges are the edges of the other vertex and not the self vertex
# If there are more edges in the other vertex than in the self vertex
# then we need to add the new edges to the self vertex
# if there are less edges in the other vertex than in the self vertex
# then we need to remove the edges that are not in the other vertex
if len(self_vertex.edges) < len(other_vertex.edges):
edges_to_add.update(edge for edge in other_vertex.edges if edge not in self_vertex.edges)
elif len(self_vertex.edges) > len(other_vertex.edges):
edges_to_remove.update(edge for edge in self_vertex.edges if edge not in other_vertex.edges)
# Add new edges
# to self.edges if they are not already in self.edges
for edge in new_edges:
for edge in edges_to_add:
if edge not in self.edges:
self.edges.append(edge)
for edge in edges_to_remove:
self.edges.remove(edge)
# Remove vertices
for vertex_id in removed_vertex_ids:
@ -280,11 +290,7 @@ class Graph:
return
self.vertices.remove(vertex)
self.vertex_map.pop(vertex_id)
self.edges = [
edge
for edge in self.edges
if edge.source_id != vertex_id and edge.target_id != vertex_id
]
self.edges = [edge for edge in self.edges if edge.source_id != vertex_id and edge.target_id != vertex_id]
def _build_vertex_params(self) -> None:
"""Identifies and handles the LLM vertex within the graph."""
@ -305,9 +311,7 @@ class Graph:
return
for vertex in self.vertices:
if not self._validate_vertex(vertex):
raise ValueError(
f"{vertex.vertex_type} is not connected to any other components"
)
raise ValueError(f"{vertex.vertex_type} is not connected to any other components")
def _validate_vertex(self, vertex: Vertex) -> bool:
"""Validates a vertex."""
@ -323,11 +327,7 @@ class Graph:
def get_vertex_edges(self, vertex_id: str) -> List[ContractEdge]:
"""Returns a list of edges for a given vertex."""
return [
edge
for edge in self.edges
if edge.source_id == vertex_id or edge.target_id == vertex_id
]
return [edge for edge in self.edges if edge.source_id == vertex_id or edge.target_id == vertex_id]
def get_vertices_with_target(self, vertex_id: str) -> List[Vertex]:
"""Returns the vertices connected to a vertex."""
@ -365,9 +365,7 @@ class Graph:
def dfs(vertex):
if state[vertex] == 1:
# We have a cycle
raise ValueError(
"Graph contains a cycle, cannot perform topological sort"
)
raise ValueError("Graph contains a cycle, cannot perform topological sort")
if state[vertex] == 0:
state[vertex] = 1
for edge in vertex.edges:
@ -391,17 +389,11 @@ class Graph:
def get_predecessors(self, vertex):
"""Returns the predecessors of a vertex."""
return [
self.get_vertex(source_id)
for source_id in self.predecessor_map.get(vertex.id, [])
]
return [self.get_vertex(source_id) for source_id in self.predecessor_map.get(vertex.id, [])]
def get_successors(self, vertex):
"""Returns the successors of a vertex."""
return [
self.get_vertex(target_id)
for target_id in self.successor_map.get(vertex.id, [])
]
return [self.get_vertex(target_id) for target_id in self.successor_map.get(vertex.id, [])]
def get_vertex_neighbors(self, vertex: Vertex) -> Dict[Vertex, int]:
"""Returns the neighbors of a vertex."""
@ -440,9 +432,7 @@ class Graph:
edges.append(ContractEdge(source, target, edge))
return edges
def _get_vertex_class(
self, node_type: str, node_base_type: str, node_id: str
) -> Type[Vertex]:
def _get_vertex_class(self, node_type: str, node_base_type: str, node_id: str) -> Type[Vertex]:
"""Returns the node class based on the node type."""
# First we check for the node_base_type
node_name = node_id.split("-")[0]
@ -473,18 +463,14 @@ class Graph:
vertex_type: str = vertex_data["type"] # type: ignore
vertex_base_type: str = vertex_data["node"]["template"]["_type"] # type: ignore
VertexClass = self._get_vertex_class(
vertex_type, vertex_base_type, vertex_data["id"]
)
VertexClass = self._get_vertex_class(vertex_type, vertex_base_type, vertex_data["id"])
vertex_instance = VertexClass(vertex, graph=self)
vertex_instance.set_top_level(self.top_level_vertices)
vertices.append(vertex_instance)
return vertices
def get_children_by_vertex_type(
self, vertex: Vertex, vertex_type: str
) -> List[Vertex]:
def get_children_by_vertex_type(self, vertex: Vertex, vertex_type: str) -> List[Vertex]:
"""Returns the children of a vertex based on the vertex type."""
children = []
vertex_types = [vertex.data["type"]]
@ -496,9 +482,7 @@ class Graph:
def __repr__(self):
vertex_ids = [vertex.id for vertex in self.vertices]
edges_repr = "\n".join(
[f"{edge.source_id} --> {edge.target_id}" for edge in self.edges]
)
edges_repr = "\n".join([f"{edge.source_id} --> {edge.target_id}" for edge in self.edges])
return f"Graph:\nNodes: {vertex_ids}\nConnections:\n{edges_repr}"
def sort_up_to_vertex(self, vertex_id: str) -> "Graph":
@ -529,9 +513,7 @@ class Graph:
"""Performs a layered topological sort of the vertices in the graph."""
# Queue for vertices with no incoming edges
queue = deque(
vertex.id for vertex in vertices if self.in_degree_map[vertex.id] == 0
)
queue = deque(vertex.id for vertex in vertices if self.in_degree_map[vertex.id] == 0)
layers = []
current_layer = 0
@ -587,9 +569,7 @@ class Graph:
return refined_layers
def sort_chat_inputs_first(
self, vertices_layers: List[List[str]]
) -> List[List[str]]:
def sort_chat_inputs_first(self, vertices_layers: List[List[str]]) -> List[List[str]]:
chat_inputs_first = []
for layer in vertices_layers:
for vertex_id in layer:
@ -617,15 +597,11 @@ class Graph:
self.increment_run_count()
return vertices_layers
def sort_interface_components_first(
self, vertices_layers: List[List[str]]
) -> List[List[str]]:
def sort_interface_components_first(self, vertices_layers: List[List[str]]) -> List[List[str]]:
"""Sorts the vertices in the graph so that vertices containing ChatInput or ChatOutput come first."""
def contains_interface_component(vertex):
return any(
component.value in vertex for component in InterfaceComponentTypes
)
return any(component.value in vertex for component in InterfaceComponentTypes)
# Sort each inner list so that vertices containing ChatInput or ChatOutput come first
sorted_vertices = [
@ -644,13 +620,9 @@ class Graph:
"""Sorts the vertices in the graph so that vertices with the lowest average build time come first."""
if len(vertices_ids) == 1:
return vertices_ids
vertices_ids.sort(
key=lambda vertex_id: self.get_vertex(vertex_id).avg_build_time
)
vertices_ids.sort(key=lambda vertex_id: self.get_vertex(vertex_id).avg_build_time)
return vertices_ids
sorted_vertices = [
sort_layer_by_avg_build_time(layer) for layer in vertices_layers
]
sorted_vertices = [sort_layer_by_avg_build_time(layer) for layer in vertices_layers]
return sorted_vertices