Refactor code to improve performance and readability
This commit is contained in:
parent
544f4e8265
commit
6c415d0865
1 changed files with 82 additions and 26 deletions
|
|
@ -76,7 +76,9 @@ class Graph:
|
|||
"""Returns the state of the graph."""
|
||||
return self.state_manager.get_state(name, run_id=self._run_id)
|
||||
|
||||
def update_state(self, name: str, record: Union[str, Record], caller: Optional[str] = None) -> None:
|
||||
def update_state(
|
||||
self, name: str, record: Union[str, Record], caller: Optional[str] = None
|
||||
) -> None:
|
||||
"""Updates the state of the graph."""
|
||||
if caller:
|
||||
# If there is a caller which is a vertex_id, I want to activate
|
||||
|
|
@ -108,7 +110,9 @@ class Graph:
|
|||
def reset_activated_vertices(self):
|
||||
self.activated_vertices = []
|
||||
|
||||
def append_state(self, name: str, record: Union[str, Record], caller: Optional[str] = None) -> None:
|
||||
def append_state(
|
||||
self, name: str, record: Union[str, Record], caller: Optional[str] = None
|
||||
) -> None:
|
||||
"""Appends the state of the graph."""
|
||||
if caller:
|
||||
self.activate_state_vertices(name, caller)
|
||||
|
|
@ -156,7 +160,10 @@ class Graph:
|
|||
"""Runs the graph with the given inputs."""
|
||||
for vertex_id in self._is_input_vertices:
|
||||
vertex = self.get_vertex(vertex_id)
|
||||
if input_components and (vertex_id not in input_components or vertex.display_name not in input_components):
|
||||
if input_components and (
|
||||
vertex_id not in input_components
|
||||
or vertex.display_name not in input_components
|
||||
):
|
||||
continue
|
||||
if vertex is None:
|
||||
raise ValueError(f"Vertex {vertex_id} not found")
|
||||
|
|
@ -179,9 +186,13 @@ class Graph:
|
|||
if vertex is None:
|
||||
raise ValueError(f"Vertex {vertex_id} not found")
|
||||
|
||||
if not vertex.result and not stream and hasattr(vertex, "consume_async_generator"):
|
||||
if (
|
||||
not vertex.result
|
||||
and not stream
|
||||
and hasattr(vertex, "consume_async_generator")
|
||||
):
|
||||
await vertex.consume_async_generator()
|
||||
if vertex.display_name in outputs or vertex.id in outputs:
|
||||
if not outputs or (vertex.display_name in outputs or vertex.id in outputs):
|
||||
vertex_outputs.append(vertex.result)
|
||||
return vertex_outputs
|
||||
|
||||
|
|
@ -189,8 +200,8 @@ class Graph:
|
|||
self,
|
||||
inputs: Dict[str, Union[str, list[str]]],
|
||||
outputs: list[str],
|
||||
stream: bool,
|
||||
session_id: str,
|
||||
stream: Optional[bool] = False,
|
||||
) -> List[Optional["ResultData"]]:
|
||||
"""Runs the graph with the given inputs."""
|
||||
|
||||
|
|
@ -257,7 +268,9 @@ class Graph:
|
|||
def build_parent_child_map(self):
|
||||
parent_child_map = defaultdict(list)
|
||||
for vertex in self.vertices:
|
||||
parent_child_map[vertex.id] = [child.id for child in self.get_successors(vertex)]
|
||||
parent_child_map[vertex.id] = [
|
||||
child.id for child in self.get_successors(vertex)
|
||||
]
|
||||
return parent_child_map
|
||||
|
||||
def increment_run_count(self):
|
||||
|
|
@ -442,7 +455,11 @@ class Graph:
|
|||
"""Updates the edges of a vertex."""
|
||||
# Vertex has edges, so we need to update the edges
|
||||
for edge in vertex.edges:
|
||||
if edge not in self.edges and edge.source_id in self.vertex_map and edge.target_id in self.vertex_map:
|
||||
if (
|
||||
edge not in self.edges
|
||||
and edge.source_id in self.vertex_map
|
||||
and edge.target_id in self.vertex_map
|
||||
):
|
||||
self.edges.append(edge)
|
||||
|
||||
def _build_graph(self) -> None:
|
||||
|
|
@ -467,7 +484,11 @@ class Graph:
|
|||
return
|
||||
self.vertices.remove(vertex)
|
||||
self.vertex_map.pop(vertex_id)
|
||||
self.edges = [edge for edge in self.edges if edge.source_id != vertex_id and edge.target_id != vertex_id]
|
||||
self.edges = [
|
||||
edge
|
||||
for edge in self.edges
|
||||
if edge.source_id != vertex_id and edge.target_id != vertex_id
|
||||
]
|
||||
|
||||
def _build_vertex_params(self) -> None:
|
||||
"""Identifies and handles the LLM vertex within the graph."""
|
||||
|
|
@ -488,7 +509,9 @@ class Graph:
|
|||
return
|
||||
for vertex in self.vertices:
|
||||
if not self._validate_vertex(vertex):
|
||||
raise ValueError(f"{vertex.display_name} is not connected to any other components")
|
||||
raise ValueError(
|
||||
f"{vertex.display_name} is not connected to any other components"
|
||||
)
|
||||
|
||||
def _validate_vertex(self, vertex: Vertex) -> bool:
|
||||
"""Validates a vertex."""
|
||||
|
|
@ -550,7 +573,9 @@ class Graph:
|
|||
name=f"{vertex.display_name} Run {vertex_task_run_count.get(vertex_id, 0)}",
|
||||
)
|
||||
tasks.append(task)
|
||||
vertex_task_run_count[vertex_id] = vertex_task_run_count.get(vertex_id, 0) + 1
|
||||
vertex_task_run_count[vertex_id] = (
|
||||
vertex_task_run_count.get(vertex_id, 0) + 1
|
||||
)
|
||||
logger.debug(f"Running layer {layer_index} with {len(tasks)} tasks")
|
||||
await self._execute_tasks(tasks)
|
||||
logger.debug("Graph processing complete")
|
||||
|
|
@ -592,7 +617,9 @@ class Graph:
|
|||
def dfs(vertex):
|
||||
if state[vertex] == 1:
|
||||
# We have a cycle
|
||||
raise ValueError("Graph contains a cycle, cannot perform topological sort")
|
||||
raise ValueError(
|
||||
"Graph contains a cycle, cannot perform topological sort"
|
||||
)
|
||||
if state[vertex] == 0:
|
||||
state[vertex] = 1
|
||||
for edge in vertex.edges:
|
||||
|
|
@ -616,7 +643,10 @@ class Graph:
|
|||
|
||||
def get_predecessors(self, vertex):
|
||||
"""Returns the predecessors of a vertex."""
|
||||
return [self.get_vertex(source_id) for source_id in self.predecessor_map.get(vertex.id, [])]
|
||||
return [
|
||||
self.get_vertex(source_id)
|
||||
for source_id in self.predecessor_map.get(vertex.id, [])
|
||||
]
|
||||
|
||||
def get_all_successors(self, vertex, recursive=True, flat=True):
|
||||
# Recursively get the successors of the current vertex
|
||||
|
|
@ -657,7 +687,10 @@ class Graph:
|
|||
|
||||
def get_successors(self, vertex):
|
||||
"""Returns the successors of a vertex."""
|
||||
return [self.get_vertex(target_id) for target_id in self.successor_map.get(vertex.id, [])]
|
||||
return [
|
||||
self.get_vertex(target_id)
|
||||
for target_id in self.successor_map.get(vertex.id, [])
|
||||
]
|
||||
|
||||
def get_vertex_neighbors(self, vertex: Vertex) -> Dict[Vertex, int]:
|
||||
"""Returns the neighbors of a vertex."""
|
||||
|
|
@ -703,7 +736,9 @@ class Graph:
|
|||
edges_added.add((source.id, target.id))
|
||||
return edges
|
||||
|
||||
def _get_vertex_class(self, node_type: str, node_base_type: str, node_id: str) -> Type[Vertex]:
|
||||
def _get_vertex_class(
|
||||
self, node_type: str, node_base_type: str, node_id: str
|
||||
) -> Type[Vertex]:
|
||||
"""Returns the node class based on the node type."""
|
||||
# First we check for the node_base_type
|
||||
node_name = node_id.split("-")[0]
|
||||
|
|
@ -736,14 +771,18 @@ class Graph:
|
|||
vertex_type: str = vertex_data["type"] # type: ignore
|
||||
vertex_base_type: str = vertex_data["node"]["template"]["_type"] # type: ignore
|
||||
|
||||
VertexClass = self._get_vertex_class(vertex_type, vertex_base_type, vertex_data["id"])
|
||||
VertexClass = self._get_vertex_class(
|
||||
vertex_type, vertex_base_type, vertex_data["id"]
|
||||
)
|
||||
vertex_instance = VertexClass(vertex, graph=self)
|
||||
vertex_instance.set_top_level(self.top_level_vertices)
|
||||
vertices.append(vertex_instance)
|
||||
|
||||
return vertices
|
||||
|
||||
def get_children_by_vertex_type(self, vertex: Vertex, vertex_type: str) -> List[Vertex]:
|
||||
def get_children_by_vertex_type(
|
||||
self, vertex: Vertex, vertex_type: str
|
||||
) -> List[Vertex]:
|
||||
"""Returns the children of a vertex based on the vertex type."""
|
||||
children = []
|
||||
vertex_types = [vertex.data["type"]]
|
||||
|
|
@ -755,7 +794,9 @@ class Graph:
|
|||
|
||||
def __repr__(self):
|
||||
vertex_ids = [vertex.id for vertex in self.vertices]
|
||||
edges_repr = "\n".join([f"{edge.source_id} --> {edge.target_id}" for edge in self.edges])
|
||||
edges_repr = "\n".join(
|
||||
[f"{edge.source_id} --> {edge.target_id}" for edge in self.edges]
|
||||
)
|
||||
return f"Graph:\nNodes: {vertex_ids}\nConnections:\n{edges_repr}"
|
||||
|
||||
def sort_up_to_vertex(self, vertex_id: str, is_start: bool = False) -> List[Vertex]:
|
||||
|
|
@ -823,7 +864,8 @@ class Graph:
|
|||
vertex.id
|
||||
for vertex in vertices
|
||||
# if filter_graphs then only vertex.is_input will be considered
|
||||
if self.in_degree_map[vertex.id] == 0 and (not filter_graphs or vertex.is_input)
|
||||
if self.in_degree_map[vertex.id] == 0
|
||||
and (not filter_graphs or vertex.is_input)
|
||||
)
|
||||
layers: List[List[str]] = []
|
||||
visited = set(queue)
|
||||
|
|
@ -897,7 +939,9 @@ class Graph:
|
|||
|
||||
return refined_layers
|
||||
|
||||
def sort_chat_inputs_first(self, vertices_layers: List[List[str]]) -> List[List[str]]:
|
||||
def sort_chat_inputs_first(
|
||||
self, vertices_layers: List[List[str]]
|
||||
) -> List[List[str]]:
|
||||
chat_inputs_first = []
|
||||
for layer in vertices_layers:
|
||||
for vertex_id in layer:
|
||||
|
|
@ -938,7 +982,9 @@ class Graph:
|
|||
first_layer = vertices_layers[0]
|
||||
# save the only the rest
|
||||
self.vertices_layers = vertices_layers[1:]
|
||||
self.vertices_to_run = {vertex_id for vertex_id in chain.from_iterable(vertices_layers)}
|
||||
self.vertices_to_run = {
|
||||
vertex_id for vertex_id in chain.from_iterable(vertices_layers)
|
||||
}
|
||||
# Return just the first layer
|
||||
return first_layer
|
||||
|
||||
|
|
@ -949,11 +995,15 @@ class Graph:
|
|||
self.vertices_to_run.remove(vertex_id)
|
||||
return should_run
|
||||
|
||||
def sort_interface_components_first(self, vertices_layers: List[List[str]]) -> List[List[str]]:
|
||||
def sort_interface_components_first(
|
||||
self, vertices_layers: List[List[str]]
|
||||
) -> List[List[str]]:
|
||||
"""Sorts the vertices in the graph so that vertices containing ChatInput or ChatOutput come first."""
|
||||
|
||||
def contains_interface_component(vertex):
|
||||
return any(component.value in vertex for component in InterfaceComponentTypes)
|
||||
return any(
|
||||
component.value in vertex for component in InterfaceComponentTypes
|
||||
)
|
||||
|
||||
# Sort each inner list so that vertices containing ChatInput or ChatOutput come first
|
||||
sorted_vertices = [
|
||||
|
|
@ -965,16 +1015,22 @@ class Graph:
|
|||
]
|
||||
return sorted_vertices
|
||||
|
||||
def sort_by_avg_build_time(self, vertices_layers: List[List[str]]) -> List[List[str]]:
|
||||
def sort_by_avg_build_time(
|
||||
self, vertices_layers: List[List[str]]
|
||||
) -> List[List[str]]:
|
||||
"""Sorts the vertices in the graph so that vertices with the lowest average build time come first."""
|
||||
|
||||
def sort_layer_by_avg_build_time(vertices_ids: List[str]) -> List[str]:
|
||||
"""Sorts the vertices in the graph so that vertices with the lowest average build time come first."""
|
||||
if len(vertices_ids) == 1:
|
||||
return vertices_ids
|
||||
vertices_ids.sort(key=lambda vertex_id: self.get_vertex(vertex_id).avg_build_time)
|
||||
vertices_ids.sort(
|
||||
key=lambda vertex_id: self.get_vertex(vertex_id).avg_build_time
|
||||
)
|
||||
|
||||
return vertices_ids
|
||||
|
||||
sorted_vertices = [sort_layer_by_avg_build_time(layer) for layer in vertices_layers]
|
||||
sorted_vertices = [
|
||||
sort_layer_by_avg_build_time(layer) for layer in vertices_layers
|
||||
]
|
||||
return sorted_vertices
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue