add or remove new edges
This commit is contained in:
parent
d7f1a5cc8c
commit
5cbeb8abb1
1 changed files with 32 additions and 60 deletions
|
|
@ -85,9 +85,7 @@ class Graph:
|
|||
def build_parent_child_map(self):
|
||||
parent_child_map = defaultdict(list)
|
||||
for vertex in self.vertices:
|
||||
parent_child_map[vertex.id] = [
|
||||
child.id for child in self.get_successors(vertex)
|
||||
]
|
||||
parent_child_map[vertex.id] = [child.id for child in self.get_successors(vertex)]
|
||||
return parent_child_map
|
||||
|
||||
def increment_run_count(self):
|
||||
|
|
@ -176,7 +174,8 @@ class Graph:
|
|||
removed_vertex_ids = existing_vertex_ids - other_vertex_ids
|
||||
|
||||
# Create a set for new edges
|
||||
new_edges = set()
|
||||
edges_to_add = set()
|
||||
edges_to_remove = set()
|
||||
|
||||
# Update existing vertices that have changed
|
||||
for vertex_id in existing_vertex_ids.intersection(other_vertex_ids):
|
||||
|
|
@ -197,13 +196,24 @@ class Graph:
|
|||
self_vertex.set_top_level(self.top_level_vertices)
|
||||
self.reset_all_edges_of_vertex(self_vertex)
|
||||
if not self.vertex_edges_are_identical(self_vertex, other_vertex):
|
||||
new_edges.update(other_vertex.edges)
|
||||
# New edges are the edges of the other vertex and not the self vertex
|
||||
# If there are more edges in the other vertex than in the self vertex
|
||||
# then we need to add the new edges to the self vertex
|
||||
# if there are less edges in the other vertex than in the self vertex
|
||||
# then we need to remove the edges that are not in the other vertex
|
||||
|
||||
if len(self_vertex.edges) < len(other_vertex.edges):
|
||||
edges_to_add.update(edge for edge in other_vertex.edges if edge not in self_vertex.edges)
|
||||
elif len(self_vertex.edges) > len(other_vertex.edges):
|
||||
edges_to_remove.update(edge for edge in self_vertex.edges if edge not in other_vertex.edges)
|
||||
|
||||
# Add new edges
|
||||
# to self.edges if they are not already in self.edges
|
||||
for edge in new_edges:
|
||||
for edge in edges_to_add:
|
||||
if edge not in self.edges:
|
||||
self.edges.append(edge)
|
||||
for edge in edges_to_remove:
|
||||
self.edges.remove(edge)
|
||||
|
||||
# Remove vertices
|
||||
for vertex_id in removed_vertex_ids:
|
||||
|
|
@ -280,11 +290,7 @@ class Graph:
|
|||
return
|
||||
self.vertices.remove(vertex)
|
||||
self.vertex_map.pop(vertex_id)
|
||||
self.edges = [
|
||||
edge
|
||||
for edge in self.edges
|
||||
if edge.source_id != vertex_id and edge.target_id != vertex_id
|
||||
]
|
||||
self.edges = [edge for edge in self.edges if edge.source_id != vertex_id and edge.target_id != vertex_id]
|
||||
|
||||
def _build_vertex_params(self) -> None:
|
||||
"""Identifies and handles the LLM vertex within the graph."""
|
||||
|
|
@ -305,9 +311,7 @@ class Graph:
|
|||
return
|
||||
for vertex in self.vertices:
|
||||
if not self._validate_vertex(vertex):
|
||||
raise ValueError(
|
||||
f"{vertex.vertex_type} is not connected to any other components"
|
||||
)
|
||||
raise ValueError(f"{vertex.vertex_type} is not connected to any other components")
|
||||
|
||||
def _validate_vertex(self, vertex: Vertex) -> bool:
|
||||
"""Validates a vertex."""
|
||||
|
|
@ -323,11 +327,7 @@ class Graph:
|
|||
|
||||
def get_vertex_edges(self, vertex_id: str) -> List[ContractEdge]:
|
||||
"""Returns a list of edges for a given vertex."""
|
||||
return [
|
||||
edge
|
||||
for edge in self.edges
|
||||
if edge.source_id == vertex_id or edge.target_id == vertex_id
|
||||
]
|
||||
return [edge for edge in self.edges if edge.source_id == vertex_id or edge.target_id == vertex_id]
|
||||
|
||||
def get_vertices_with_target(self, vertex_id: str) -> List[Vertex]:
|
||||
"""Returns the vertices connected to a vertex."""
|
||||
|
|
@ -365,9 +365,7 @@ class Graph:
|
|||
def dfs(vertex):
|
||||
if state[vertex] == 1:
|
||||
# We have a cycle
|
||||
raise ValueError(
|
||||
"Graph contains a cycle, cannot perform topological sort"
|
||||
)
|
||||
raise ValueError("Graph contains a cycle, cannot perform topological sort")
|
||||
if state[vertex] == 0:
|
||||
state[vertex] = 1
|
||||
for edge in vertex.edges:
|
||||
|
|
@ -391,17 +389,11 @@ class Graph:
|
|||
|
||||
def get_predecessors(self, vertex):
|
||||
"""Returns the predecessors of a vertex."""
|
||||
return [
|
||||
self.get_vertex(source_id)
|
||||
for source_id in self.predecessor_map.get(vertex.id, [])
|
||||
]
|
||||
return [self.get_vertex(source_id) for source_id in self.predecessor_map.get(vertex.id, [])]
|
||||
|
||||
def get_successors(self, vertex):
|
||||
"""Returns the successors of a vertex."""
|
||||
return [
|
||||
self.get_vertex(target_id)
|
||||
for target_id in self.successor_map.get(vertex.id, [])
|
||||
]
|
||||
return [self.get_vertex(target_id) for target_id in self.successor_map.get(vertex.id, [])]
|
||||
|
||||
def get_vertex_neighbors(self, vertex: Vertex) -> Dict[Vertex, int]:
|
||||
"""Returns the neighbors of a vertex."""
|
||||
|
|
@ -440,9 +432,7 @@ class Graph:
|
|||
edges.append(ContractEdge(source, target, edge))
|
||||
return edges
|
||||
|
||||
def _get_vertex_class(
|
||||
self, node_type: str, node_base_type: str, node_id: str
|
||||
) -> Type[Vertex]:
|
||||
def _get_vertex_class(self, node_type: str, node_base_type: str, node_id: str) -> Type[Vertex]:
|
||||
"""Returns the node class based on the node type."""
|
||||
# First we check for the node_base_type
|
||||
node_name = node_id.split("-")[0]
|
||||
|
|
@ -473,18 +463,14 @@ class Graph:
|
|||
vertex_type: str = vertex_data["type"] # type: ignore
|
||||
vertex_base_type: str = vertex_data["node"]["template"]["_type"] # type: ignore
|
||||
|
||||
VertexClass = self._get_vertex_class(
|
||||
vertex_type, vertex_base_type, vertex_data["id"]
|
||||
)
|
||||
VertexClass = self._get_vertex_class(vertex_type, vertex_base_type, vertex_data["id"])
|
||||
vertex_instance = VertexClass(vertex, graph=self)
|
||||
vertex_instance.set_top_level(self.top_level_vertices)
|
||||
vertices.append(vertex_instance)
|
||||
|
||||
return vertices
|
||||
|
||||
def get_children_by_vertex_type(
|
||||
self, vertex: Vertex, vertex_type: str
|
||||
) -> List[Vertex]:
|
||||
def get_children_by_vertex_type(self, vertex: Vertex, vertex_type: str) -> List[Vertex]:
|
||||
"""Returns the children of a vertex based on the vertex type."""
|
||||
children = []
|
||||
vertex_types = [vertex.data["type"]]
|
||||
|
|
@ -496,9 +482,7 @@ class Graph:
|
|||
|
||||
def __repr__(self):
|
||||
vertex_ids = [vertex.id for vertex in self.vertices]
|
||||
edges_repr = "\n".join(
|
||||
[f"{edge.source_id} --> {edge.target_id}" for edge in self.edges]
|
||||
)
|
||||
edges_repr = "\n".join([f"{edge.source_id} --> {edge.target_id}" for edge in self.edges])
|
||||
return f"Graph:\nNodes: {vertex_ids}\nConnections:\n{edges_repr}"
|
||||
|
||||
def sort_up_to_vertex(self, vertex_id: str) -> "Graph":
|
||||
|
|
@ -529,9 +513,7 @@ class Graph:
|
|||
"""Performs a layered topological sort of the vertices in the graph."""
|
||||
|
||||
# Queue for vertices with no incoming edges
|
||||
queue = deque(
|
||||
vertex.id for vertex in vertices if self.in_degree_map[vertex.id] == 0
|
||||
)
|
||||
queue = deque(vertex.id for vertex in vertices if self.in_degree_map[vertex.id] == 0)
|
||||
layers = []
|
||||
|
||||
current_layer = 0
|
||||
|
|
@ -587,9 +569,7 @@ class Graph:
|
|||
|
||||
return refined_layers
|
||||
|
||||
def sort_chat_inputs_first(
|
||||
self, vertices_layers: List[List[str]]
|
||||
) -> List[List[str]]:
|
||||
def sort_chat_inputs_first(self, vertices_layers: List[List[str]]) -> List[List[str]]:
|
||||
chat_inputs_first = []
|
||||
for layer in vertices_layers:
|
||||
for vertex_id in layer:
|
||||
|
|
@ -617,15 +597,11 @@ class Graph:
|
|||
self.increment_run_count()
|
||||
return vertices_layers
|
||||
|
||||
def sort_interface_components_first(
|
||||
self, vertices_layers: List[List[str]]
|
||||
) -> List[List[str]]:
|
||||
def sort_interface_components_first(self, vertices_layers: List[List[str]]) -> List[List[str]]:
|
||||
"""Sorts the vertices in the graph so that vertices containing ChatInput or ChatOutput come first."""
|
||||
|
||||
def contains_interface_component(vertex):
|
||||
return any(
|
||||
component.value in vertex for component in InterfaceComponentTypes
|
||||
)
|
||||
return any(component.value in vertex for component in InterfaceComponentTypes)
|
||||
|
||||
# Sort each inner list so that vertices containing ChatInput or ChatOutput come first
|
||||
sorted_vertices = [
|
||||
|
|
@ -644,13 +620,9 @@ class Graph:
|
|||
"""Sorts the vertices in the graph so that vertices with the lowest average build time come first."""
|
||||
if len(vertices_ids) == 1:
|
||||
return vertices_ids
|
||||
vertices_ids.sort(
|
||||
key=lambda vertex_id: self.get_vertex(vertex_id).avg_build_time
|
||||
)
|
||||
vertices_ids.sort(key=lambda vertex_id: self.get_vertex(vertex_id).avg_build_time)
|
||||
|
||||
return vertices_ids
|
||||
|
||||
sorted_vertices = [
|
||||
sort_layer_by_avg_build_time(layer) for layer in vertices_layers
|
||||
]
|
||||
sorted_vertices = [sort_layer_by_avg_build_time(layer) for layer in vertices_layers]
|
||||
return sorted_vertices
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue