Add edges validation

This commit is contained in:
Gabriel Luiz Freitas Almeida 2024-02-26 13:25:57 -03:00
commit ae629a528c

View file

@ -85,7 +85,9 @@ class Graph:
def build_parent_child_map(self):
parent_child_map = defaultdict(list)
for vertex in self.vertices:
parent_child_map[vertex.id] = [child.id for child in self.get_successors(vertex)]
parent_child_map[vertex.id] = [
child.id for child in self.get_successors(vertex)
]
return parent_child_map
def increment_run_count(self):
@ -149,6 +151,18 @@ class Graph:
# both graphs have the same vertices and edges
# but the data of the vertices might be different
def vertex_data_is_identical(self, vertex: Vertex, other_vertex: Vertex) -> bool:
return vertex.__repr__() == other_vertex.__repr__()
def vertex_edges_are_identical(self, vertex: Vertex, other_vertex: Vertex) -> bool:
same_length = len(vertex.edges) == len(other_vertex.edges)
if not same_length:
return False
for edge in vertex.edges:
if edge not in other_vertex.edges:
return False
return True
def update(self, other: "Graph") -> None:
# Existing vertices in self graph
existing_vertex_ids = set(vertex.id for vertex in self.vertices)
@ -161,11 +175,14 @@ class Graph:
# Find vertices that are in self but not in other (removed vertices)
removed_vertex_ids = existing_vertex_ids - other_vertex_ids
# Create a set for new edges
new_edges = set()
# Update existing vertices that have changed
for vertex_id in existing_vertex_ids.intersection(other_vertex_ids):
self_vertex = self.get_vertex(vertex_id)
other_vertex = other.get_vertex(vertex_id)
if self_vertex.__repr__() != other_vertex.__repr__():
if not self.vertex_data_is_identical(self_vertex, other_vertex):
self_vertex._data = other_vertex._data
self_vertex._parse_data()
self_vertex.params = {}
@ -179,6 +196,14 @@ class Graph:
self_vertex.artifacts = None
self_vertex.set_top_level(self.top_level_vertices)
self.reset_all_edges_of_vertex(self_vertex)
if not self.vertex_edges_are_identical(self_vertex, other_vertex):
new_edges.update(other_vertex.edges)
# Add new edges
# to self.edges if they are not already in self.edges
for edge in new_edges:
if edge not in self.edges:
self.edges.append(edge)
# Remove vertices
for vertex_id in removed_vertex_ids:
@ -255,7 +280,11 @@ class Graph:
return
self.vertices.remove(vertex)
self.vertex_map.pop(vertex_id)
self.edges = [edge for edge in self.edges if edge.source_id != vertex_id and edge.target_id != vertex_id]
self.edges = [
edge
for edge in self.edges
if edge.source_id != vertex_id and edge.target_id != vertex_id
]
def _build_vertex_params(self) -> None:
"""Identifies and handles the LLM vertex within the graph."""
@ -276,7 +305,9 @@ class Graph:
return
for vertex in self.vertices:
if not self._validate_vertex(vertex):
raise ValueError(f"{vertex.vertex_type} is not connected to any other components")
raise ValueError(
f"{vertex.vertex_type} is not connected to any other components"
)
def _validate_vertex(self, vertex: Vertex) -> bool:
"""Validates a vertex."""
@ -292,7 +323,11 @@ class Graph:
def get_vertex_edges(self, vertex_id: str) -> List[ContractEdge]:
"""Returns a list of edges for a given vertex."""
return [edge for edge in self.edges if edge.source_id == vertex_id or edge.target_id == vertex_id]
return [
edge
for edge in self.edges
if edge.source_id == vertex_id or edge.target_id == vertex_id
]
def get_vertices_with_target(self, vertex_id: str) -> List[Vertex]:
"""Returns the vertices connected to a vertex."""
@ -330,7 +365,9 @@ class Graph:
def dfs(vertex):
if state[vertex] == 1:
# We have a cycle
raise ValueError("Graph contains a cycle, cannot perform topological sort")
raise ValueError(
"Graph contains a cycle, cannot perform topological sort"
)
if state[vertex] == 0:
state[vertex] = 1
for edge in vertex.edges:
@ -354,11 +391,17 @@ class Graph:
def get_predecessors(self, vertex):
"""Returns the predecessors of a vertex."""
return [self.get_vertex(source_id) for source_id in self.predecessor_map.get(vertex.id, [])]
return [
self.get_vertex(source_id)
for source_id in self.predecessor_map.get(vertex.id, [])
]
def get_successors(self, vertex):
"""Returns the successors of a vertex."""
return [self.get_vertex(target_id) for target_id in self.successor_map.get(vertex.id, [])]
return [
self.get_vertex(target_id)
for target_id in self.successor_map.get(vertex.id, [])
]
def get_vertex_neighbors(self, vertex: Vertex) -> Dict[Vertex, int]:
"""Returns the neighbors of a vertex."""
@ -397,7 +440,9 @@ class Graph:
edges.append(ContractEdge(source, target, edge))
return edges
def _get_vertex_class(self, node_type: str, node_base_type: str, node_id: str) -> Type[Vertex]:
def _get_vertex_class(
self, node_type: str, node_base_type: str, node_id: str
) -> Type[Vertex]:
"""Returns the node class based on the node type."""
# First we check for the node_base_type
node_name = node_id.split("-")[0]
@ -428,14 +473,18 @@ class Graph:
vertex_type: str = vertex_data["type"] # type: ignore
vertex_base_type: str = vertex_data["node"]["template"]["_type"] # type: ignore
VertexClass = self._get_vertex_class(vertex_type, vertex_base_type, vertex_data["id"])
VertexClass = self._get_vertex_class(
vertex_type, vertex_base_type, vertex_data["id"]
)
vertex_instance = VertexClass(vertex, graph=self)
vertex_instance.set_top_level(self.top_level_vertices)
vertices.append(vertex_instance)
return vertices
def get_children_by_vertex_type(self, vertex: Vertex, vertex_type: str) -> List[Vertex]:
def get_children_by_vertex_type(
self, vertex: Vertex, vertex_type: str
) -> List[Vertex]:
"""Returns the children of a vertex based on the vertex type."""
children = []
vertex_types = [vertex.data["type"]]
@ -447,7 +496,9 @@ class Graph:
def __repr__(self):
vertex_ids = [vertex.id for vertex in self.vertices]
edges_repr = "\n".join([f"{edge.source_id} --> {edge.target_id}" for edge in self.edges])
edges_repr = "\n".join(
[f"{edge.source_id} --> {edge.target_id}" for edge in self.edges]
)
return f"Graph:\nNodes: {vertex_ids}\nConnections:\n{edges_repr}"
def sort_up_to_vertex(self, vertex_id: str) -> "Graph":
@ -478,7 +529,9 @@ class Graph:
"""Performs a layered topological sort of the vertices in the graph."""
# Queue for vertices with no incoming edges
queue = deque(vertex.id for vertex in vertices if self.in_degree_map[vertex.id] == 0)
queue = deque(
vertex.id for vertex in vertices if self.in_degree_map[vertex.id] == 0
)
layers = []
current_layer = 0
@ -534,7 +587,9 @@ class Graph:
return refined_layers
def sort_chat_inputs_first(self, vertices_layers: List[List[str]]) -> List[List[str]]:
def sort_chat_inputs_first(
self, vertices_layers: List[List[str]]
) -> List[List[str]]:
chat_inputs_first = []
for layer in vertices_layers:
for vertex_id in layer:
@ -562,11 +617,15 @@ class Graph:
self.increment_run_count()
return vertices_layers
def sort_interface_components_first(self, vertices_layers: List[List[str]]) -> List[List[str]]:
def sort_interface_components_first(
self, vertices_layers: List[List[str]]
) -> List[List[str]]:
"""Sorts the vertices in the graph so that vertices containing ChatInput or ChatOutput come first."""
def contains_interface_component(vertex):
return any(component.value in vertex for component in InterfaceComponentTypes)
return any(
component.value in vertex for component in InterfaceComponentTypes
)
# Sort each inner list so that vertices containing ChatInput or ChatOutput come first
sorted_vertices = [
@ -585,9 +644,13 @@ class Graph:
"""Sorts the vertices in the graph so that vertices with the lowest average build time come first."""
if len(vertices_ids) == 1:
return vertices_ids
vertices_ids.sort(key=lambda vertex_id: self.get_vertex(vertex_id).avg_build_time)
vertices_ids.sort(
key=lambda vertex_id: self.get_vertex(vertex_id).avg_build_time
)
return vertices_ids
sorted_vertices = [sort_layer_by_avg_build_time(layer) for layer in vertices_layers]
sorted_vertices = [
sort_layer_by_avg_build_time(layer) for layer in vertices_layers
]
return sorted_vertices