refactor: separate initialization of nodes and edges in test_graph.py (#2828)

* refactor: move test_graph.py

* refactor: allow Graph to be initialized with no nodes and edges

The Graph class in `base.py` was refactored to separate the initialization of nodes and edges into a separate method called `add_nodes_and_edges()`. This improves code readability and maintainability by organizing the code logic more effectively.

* refactor: separate initialization of nodes and edges in get_graph()

The `get_graph()` function in `conftest.py` was refactored to separate the initialization of nodes and edges. This improves code readability and maintainability by organizing the code logic more effectively.

* refactor: separate initialization of nodes and edges in test_graph.py

* refactor: separate initialization of nodes and edges in base.py

The `add_node()` and `add_edge()` methods were added to the `Graph` class in `base.py` to separate the initialization of nodes and edges. This improves code readability and maintainability by organizing the code logic more effectively.
This commit is contained in:
Gabriel Luiz Freitas Almeida 2024-07-22 16:48:29 -03:00 committed by GitHub
commit 77cc789e62
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 51 additions and 33 deletions

View file

@ -34,8 +34,6 @@ class Graph:
def __init__(
self,
nodes: List[Dict],
edges: List[Dict[str, str]],
flow_id: Optional[str] = None,
flow_name: Optional[str] = None,
user_id: Optional[str] = None,
@ -48,9 +46,6 @@ class Graph:
edges (List[Dict[str, str]]): A list of dictionaries representing the edges of the graph.
flow_id (Optional[str], optional): The ID of the flow. Defaults to None.
"""
self._vertices = nodes
self._edges = edges
self.raw_graph_data = {"nodes": nodes, "edges": edges}
self._runs = 0
self._updates = 0
self.flow_id = flow_id
@ -63,7 +58,26 @@ class Graph:
self._sorted_vertices_layers: List[List[str]] = []
self._run_id = ""
self._start_time = datetime.now(timezone.utc)
self.inactivated_vertices: set = set()
self.activated_vertices: List[str] = []
self.vertices_layers: List[List[str]] = []
self.vertices_to_run: set[str] = set()
self.stop_vertex: Optional[str] = None
self.inactive_vertices: set = set()
self.edges: List[ContractEdge] = []
self.vertices: List[Vertex] = []
self.run_manager = RunnableVerticesManager()
self.state_manager = GraphStateManager()
try:
self.tracing_service: "TracingService" | None = get_tracing_service()
except Exception as exc:
logger.error(f"Error getting tracing service: {exc}")
self.tracing_service = None
def add_nodes_and_edges(self, nodes: List[Dict], edges: List[Dict[str, str]]):
self._vertices = nodes
self._edges = edges
self.raw_graph_data = {"nodes": nodes, "edges": edges}
self.top_level_vertices = []
for vertex in self._vertices:
if vertex_id := vertex.get("id"):
@ -72,25 +86,20 @@ class Graph:
self._vertices = self._graph_data["nodes"]
self._edges = self._graph_data["edges"]
self.inactivated_vertices: set = set()
self.activated_vertices: List[str] = []
self.vertices_layers: List[List[str]] = []
self.vertices_to_run: set[str] = set()
self.stop_vertex: Optional[str] = None
self.initialize()
self.inactive_vertices: set = set()
self.edges: List[ContractEdge] = []
self.vertices: List[Vertex] = []
self.run_manager = RunnableVerticesManager()
# TODO: Create a TypedDict to represente the node
def add_node(self, node: dict):
self._vertices.append(node)
# TODO: Create a TypedDict to represente the edge
def add_edge(self, edge: dict):
self._edges.append(edge)
def initialize(self):
self._build_graph()
self.build_graph_maps(self.edges)
self.define_vertices_lists()
self.state_manager = GraphStateManager()
try:
self.tracing_service: "TracingService" | None = get_tracing_service()
except Exception as exc:
logger.error(f"Error getting tracing service: {exc}")
self.tracing_service = None
def get_state(self, name: str) -> Optional[Data]:
"""
@ -638,7 +647,9 @@ class Graph:
try:
vertices = payload["nodes"]
edges = payload["edges"]
return cls(vertices, edges, flow_id, flow_name, user_id)
graph = cls(flow_id, flow_name, user_id)
graph.add_nodes_and_edges(vertices, edges)
return graph
except KeyError as exc:
logger.exception(exc)
if "nodes" not in payload and "edges" not in payload:
@ -1188,20 +1199,24 @@ class Graph:
"""Builds the vertices of the graph."""
vertices: List[Vertex] = []
for vertex in self._vertices:
vertex_data = vertex["data"]
vertex_type: str = vertex_data["type"] # type: ignore
vertex_base_type: str = vertex_data["node"]["template"]["_type"] # type: ignore
if "id" not in vertex_data:
raise ValueError(f"Vertex data for {vertex_data['display_name']} does not contain an id")
VertexClass = self._get_vertex_class(vertex_type, vertex_base_type, vertex_data["id"])
vertex_instance = VertexClass(vertex, graph=self)
vertex_instance.set_top_level(self.top_level_vertices)
vertex_instance = self._create_vertex(vertex)
vertices.append(vertex_instance)
return vertices
def _create_vertex(self, vertex: dict):
vertex_data = vertex["data"]
vertex_type: str = vertex_data["type"] # type: ignore
vertex_base_type: str = vertex_data["node"]["template"]["_type"] # type: ignore
if "id" not in vertex_data:
raise ValueError(f"Vertex data for {vertex_data['display_name']} does not contain an id")
VertexClass = self._get_vertex_class(vertex_type, vertex_base_type, vertex_data["id"])
vertex_instance = VertexClass(vertex, graph=self)
vertex_instance.set_top_level(self.top_level_vertices)
return vertex_instance
def get_children_by_vertex_type(self, vertex: Vertex, vertex_type: str) -> List[Vertex]:
"""Returns the children of a vertex based on the vertex type."""
children = []

View file

@ -158,7 +158,9 @@ def get_graph(_type="basic"):
data_graph = flow_graph["data"]
nodes = data_graph["nodes"]
edges = data_graph["edges"]
return Graph(nodes, edges)
graph = Graph()
graph.add_nodes_and_edges(nodes, edges)
return graph
@pytest.fixture

View file

@ -116,7 +116,8 @@ def test_invalid_node_types():
"edges": [],
}
with pytest.raises(Exception):
Graph(graph_data["nodes"], graph_data["edges"])
g = Graph()
g.add_nodes_and_edges(graph_data["nodes"], graph_data["edges"])
def test_get_vertices_with_target(basic_graph):