Refactor code to improve performance and readability

This commit is contained in:
Gabriel Luiz Freitas Almeida 2024-03-07 09:57:23 -03:00
commit 6c415d0865

View file

@ -76,7 +76,9 @@ class Graph:
"""Returns the state of the graph."""
return self.state_manager.get_state(name, run_id=self._run_id)
def update_state(self, name: str, record: Union[str, Record], caller: Optional[str] = None) -> None:
def update_state(
self, name: str, record: Union[str, Record], caller: Optional[str] = None
) -> None:
"""Updates the state of the graph."""
if caller:
# If there is a caller which is a vertex_id, I want to activate
@ -108,7 +110,9 @@ class Graph:
def reset_activated_vertices(self):
self.activated_vertices = []
def append_state(self, name: str, record: Union[str, Record], caller: Optional[str] = None) -> None:
def append_state(
self, name: str, record: Union[str, Record], caller: Optional[str] = None
) -> None:
"""Appends the state of the graph."""
if caller:
self.activate_state_vertices(name, caller)
@ -156,7 +160,10 @@ class Graph:
"""Runs the graph with the given inputs."""
for vertex_id in self._is_input_vertices:
vertex = self.get_vertex(vertex_id)
if input_components and (vertex_id not in input_components or vertex.display_name not in input_components):
if input_components and (
vertex_id not in input_components
or vertex.display_name not in input_components
):
continue
if vertex is None:
raise ValueError(f"Vertex {vertex_id} not found")
@ -179,9 +186,13 @@ class Graph:
if vertex is None:
raise ValueError(f"Vertex {vertex_id} not found")
if not vertex.result and not stream and hasattr(vertex, "consume_async_generator"):
if (
not vertex.result
and not stream
and hasattr(vertex, "consume_async_generator")
):
await vertex.consume_async_generator()
if vertex.display_name in outputs or vertex.id in outputs:
if not outputs or (vertex.display_name in outputs or vertex.id in outputs):
vertex_outputs.append(vertex.result)
return vertex_outputs
@ -189,8 +200,8 @@ class Graph:
self,
inputs: Dict[str, Union[str, list[str]]],
outputs: list[str],
stream: bool,
session_id: str,
stream: Optional[bool] = False,
) -> List[Optional["ResultData"]]:
"""Runs the graph with the given inputs."""
@ -257,7 +268,9 @@ class Graph:
def build_parent_child_map(self):
parent_child_map = defaultdict(list)
for vertex in self.vertices:
parent_child_map[vertex.id] = [child.id for child in self.get_successors(vertex)]
parent_child_map[vertex.id] = [
child.id for child in self.get_successors(vertex)
]
return parent_child_map
def increment_run_count(self):
@ -442,7 +455,11 @@ class Graph:
"""Updates the edges of a vertex."""
# Vertex has edges, so we need to update the edges
for edge in vertex.edges:
if edge not in self.edges and edge.source_id in self.vertex_map and edge.target_id in self.vertex_map:
if (
edge not in self.edges
and edge.source_id in self.vertex_map
and edge.target_id in self.vertex_map
):
self.edges.append(edge)
def _build_graph(self) -> None:
@ -467,7 +484,11 @@ class Graph:
return
self.vertices.remove(vertex)
self.vertex_map.pop(vertex_id)
self.edges = [edge for edge in self.edges if edge.source_id != vertex_id and edge.target_id != vertex_id]
self.edges = [
edge
for edge in self.edges
if edge.source_id != vertex_id and edge.target_id != vertex_id
]
def _build_vertex_params(self) -> None:
"""Identifies and handles the LLM vertex within the graph."""
@ -488,7 +509,9 @@ class Graph:
return
for vertex in self.vertices:
if not self._validate_vertex(vertex):
raise ValueError(f"{vertex.display_name} is not connected to any other components")
raise ValueError(
f"{vertex.display_name} is not connected to any other components"
)
def _validate_vertex(self, vertex: Vertex) -> bool:
"""Validates a vertex."""
@ -550,7 +573,9 @@ class Graph:
name=f"{vertex.display_name} Run {vertex_task_run_count.get(vertex_id, 0)}",
)
tasks.append(task)
vertex_task_run_count[vertex_id] = vertex_task_run_count.get(vertex_id, 0) + 1
vertex_task_run_count[vertex_id] = (
vertex_task_run_count.get(vertex_id, 0) + 1
)
logger.debug(f"Running layer {layer_index} with {len(tasks)} tasks")
await self._execute_tasks(tasks)
logger.debug("Graph processing complete")
@ -592,7 +617,9 @@ class Graph:
def dfs(vertex):
if state[vertex] == 1:
# We have a cycle
raise ValueError("Graph contains a cycle, cannot perform topological sort")
raise ValueError(
"Graph contains a cycle, cannot perform topological sort"
)
if state[vertex] == 0:
state[vertex] = 1
for edge in vertex.edges:
@ -616,7 +643,10 @@ class Graph:
def get_predecessors(self, vertex):
"""Returns the predecessors of a vertex."""
return [self.get_vertex(source_id) for source_id in self.predecessor_map.get(vertex.id, [])]
return [
self.get_vertex(source_id)
for source_id in self.predecessor_map.get(vertex.id, [])
]
def get_all_successors(self, vertex, recursive=True, flat=True):
# Recursively get the successors of the current vertex
@ -657,7 +687,10 @@ class Graph:
def get_successors(self, vertex):
"""Returns the successors of a vertex."""
return [self.get_vertex(target_id) for target_id in self.successor_map.get(vertex.id, [])]
return [
self.get_vertex(target_id)
for target_id in self.successor_map.get(vertex.id, [])
]
def get_vertex_neighbors(self, vertex: Vertex) -> Dict[Vertex, int]:
"""Returns the neighbors of a vertex."""
@ -703,7 +736,9 @@ class Graph:
edges_added.add((source.id, target.id))
return edges
def _get_vertex_class(self, node_type: str, node_base_type: str, node_id: str) -> Type[Vertex]:
def _get_vertex_class(
self, node_type: str, node_base_type: str, node_id: str
) -> Type[Vertex]:
"""Returns the node class based on the node type."""
# First we check for the node_base_type
node_name = node_id.split("-")[0]
@ -736,14 +771,18 @@ class Graph:
vertex_type: str = vertex_data["type"] # type: ignore
vertex_base_type: str = vertex_data["node"]["template"]["_type"] # type: ignore
VertexClass = self._get_vertex_class(vertex_type, vertex_base_type, vertex_data["id"])
VertexClass = self._get_vertex_class(
vertex_type, vertex_base_type, vertex_data["id"]
)
vertex_instance = VertexClass(vertex, graph=self)
vertex_instance.set_top_level(self.top_level_vertices)
vertices.append(vertex_instance)
return vertices
def get_children_by_vertex_type(self, vertex: Vertex, vertex_type: str) -> List[Vertex]:
def get_children_by_vertex_type(
self, vertex: Vertex, vertex_type: str
) -> List[Vertex]:
"""Returns the children of a vertex based on the vertex type."""
children = []
vertex_types = [vertex.data["type"]]
@ -755,7 +794,9 @@ class Graph:
def __repr__(self):
vertex_ids = [vertex.id for vertex in self.vertices]
edges_repr = "\n".join([f"{edge.source_id} --> {edge.target_id}" for edge in self.edges])
edges_repr = "\n".join(
[f"{edge.source_id} --> {edge.target_id}" for edge in self.edges]
)
return f"Graph:\nNodes: {vertex_ids}\nConnections:\n{edges_repr}"
def sort_up_to_vertex(self, vertex_id: str, is_start: bool = False) -> List[Vertex]:
@ -823,7 +864,8 @@ class Graph:
vertex.id
for vertex in vertices
# if filter_graphs then only vertex.is_input will be considered
if self.in_degree_map[vertex.id] == 0 and (not filter_graphs or vertex.is_input)
if self.in_degree_map[vertex.id] == 0
and (not filter_graphs or vertex.is_input)
)
layers: List[List[str]] = []
visited = set(queue)
@ -897,7 +939,9 @@ class Graph:
return refined_layers
def sort_chat_inputs_first(self, vertices_layers: List[List[str]]) -> List[List[str]]:
def sort_chat_inputs_first(
self, vertices_layers: List[List[str]]
) -> List[List[str]]:
chat_inputs_first = []
for layer in vertices_layers:
for vertex_id in layer:
@ -938,7 +982,9 @@ class Graph:
first_layer = vertices_layers[0]
# save the only the rest
self.vertices_layers = vertices_layers[1:]
self.vertices_to_run = {vertex_id for vertex_id in chain.from_iterable(vertices_layers)}
self.vertices_to_run = {
vertex_id for vertex_id in chain.from_iterable(vertices_layers)
}
# Return just the first layer
return first_layer
@ -949,11 +995,15 @@ class Graph:
self.vertices_to_run.remove(vertex_id)
return should_run
def sort_interface_components_first(self, vertices_layers: List[List[str]]) -> List[List[str]]:
def sort_interface_components_first(
self, vertices_layers: List[List[str]]
) -> List[List[str]]:
"""Sorts the vertices in the graph so that vertices containing ChatInput or ChatOutput come first."""
def contains_interface_component(vertex):
return any(component.value in vertex for component in InterfaceComponentTypes)
return any(
component.value in vertex for component in InterfaceComponentTypes
)
# Sort each inner list so that vertices containing ChatInput or ChatOutput come first
sorted_vertices = [
@ -965,16 +1015,22 @@ class Graph:
]
return sorted_vertices
def sort_by_avg_build_time(self, vertices_layers: List[List[str]]) -> List[List[str]]:
def sort_by_avg_build_time(
self, vertices_layers: List[List[str]]
) -> List[List[str]]:
"""Sorts the vertices in the graph so that vertices with the lowest average build time come first."""
def sort_layer_by_avg_build_time(vertices_ids: List[str]) -> List[str]:
"""Sorts the vertices in the graph so that vertices with the lowest average build time come first."""
if len(vertices_ids) == 1:
return vertices_ids
vertices_ids.sort(key=lambda vertex_id: self.get_vertex(vertex_id).avg_build_time)
vertices_ids.sort(
key=lambda vertex_id: self.get_vertex(vertex_id).avg_build_time
)
return vertices_ids
sorted_vertices = [sort_layer_by_avg_build_time(layer) for layer in vertices_layers]
sorted_vertices = [
sort_layer_by_avg_build_time(layer) for layer in vertices_layers
]
return sorted_vertices