refactor: separate initialization of nodes and edges in test_graph.py (#2828)
* refactor: move test_graph.py * refactor: allow Graph to be initialized with no nodes and edges The Graph class in `base.py` was refactored to separate the initialization of nodes and edges into a separate method called `add_nodes_and_edges()`. This improves code readability and maintainability by organizing the code logic more effectively. * refactor: separate initialization of nodes and edges in get_graph() The `get_graph()` function in `conftest.py` was refactored to separate the initialization of nodes and edges. This improves code readability and maintainability by organizing the code logic more effectively. * refactor: separate initialization of nodes and edges in test_graph.py * refactor: separate initialization of nodes and edges in base.py The `add_node()` and `add_edge()` methods were added to the `Graph` class in `base.py` to separate the initialization of nodes and edges. This improves code readability and maintainability by organizing the code logic more effectively.
This commit is contained in:
parent
077f68f17b
commit
77cc789e62
3 changed files with 51 additions and 33 deletions
|
|
@ -34,8 +34,6 @@ class Graph:
|
|||
|
||||
def __init__(
|
||||
self,
|
||||
nodes: List[Dict],
|
||||
edges: List[Dict[str, str]],
|
||||
flow_id: Optional[str] = None,
|
||||
flow_name: Optional[str] = None,
|
||||
user_id: Optional[str] = None,
|
||||
|
|
@ -48,9 +46,6 @@ class Graph:
|
|||
edges (List[Dict[str, str]]): A list of dictionaries representing the edges of the graph.
|
||||
flow_id (Optional[str], optional): The ID of the flow. Defaults to None.
|
||||
"""
|
||||
self._vertices = nodes
|
||||
self._edges = edges
|
||||
self.raw_graph_data = {"nodes": nodes, "edges": edges}
|
||||
self._runs = 0
|
||||
self._updates = 0
|
||||
self.flow_id = flow_id
|
||||
|
|
@ -63,7 +58,26 @@ class Graph:
|
|||
self._sorted_vertices_layers: List[List[str]] = []
|
||||
self._run_id = ""
|
||||
self._start_time = datetime.now(timezone.utc)
|
||||
self.inactivated_vertices: set = set()
|
||||
self.activated_vertices: List[str] = []
|
||||
self.vertices_layers: List[List[str]] = []
|
||||
self.vertices_to_run: set[str] = set()
|
||||
self.stop_vertex: Optional[str] = None
|
||||
self.inactive_vertices: set = set()
|
||||
self.edges: List[ContractEdge] = []
|
||||
self.vertices: List[Vertex] = []
|
||||
self.run_manager = RunnableVerticesManager()
|
||||
self.state_manager = GraphStateManager()
|
||||
try:
|
||||
self.tracing_service: "TracingService" | None = get_tracing_service()
|
||||
except Exception as exc:
|
||||
logger.error(f"Error getting tracing service: {exc}")
|
||||
self.tracing_service = None
|
||||
|
||||
def add_nodes_and_edges(self, nodes: List[Dict], edges: List[Dict[str, str]]):
|
||||
self._vertices = nodes
|
||||
self._edges = edges
|
||||
self.raw_graph_data = {"nodes": nodes, "edges": edges}
|
||||
self.top_level_vertices = []
|
||||
for vertex in self._vertices:
|
||||
if vertex_id := vertex.get("id"):
|
||||
|
|
@ -72,25 +86,20 @@ class Graph:
|
|||
|
||||
self._vertices = self._graph_data["nodes"]
|
||||
self._edges = self._graph_data["edges"]
|
||||
self.inactivated_vertices: set = set()
|
||||
self.activated_vertices: List[str] = []
|
||||
self.vertices_layers: List[List[str]] = []
|
||||
self.vertices_to_run: set[str] = set()
|
||||
self.stop_vertex: Optional[str] = None
|
||||
self.initialize()
|
||||
|
||||
self.inactive_vertices: set = set()
|
||||
self.edges: List[ContractEdge] = []
|
||||
self.vertices: List[Vertex] = []
|
||||
self.run_manager = RunnableVerticesManager()
|
||||
# TODO: Create a TypedDict to represente the node
|
||||
def add_node(self, node: dict):
|
||||
self._vertices.append(node)
|
||||
|
||||
# TODO: Create a TypedDict to represente the edge
|
||||
def add_edge(self, edge: dict):
|
||||
self._edges.append(edge)
|
||||
|
||||
def initialize(self):
|
||||
self._build_graph()
|
||||
self.build_graph_maps(self.edges)
|
||||
self.define_vertices_lists()
|
||||
self.state_manager = GraphStateManager()
|
||||
try:
|
||||
self.tracing_service: "TracingService" | None = get_tracing_service()
|
||||
except Exception as exc:
|
||||
logger.error(f"Error getting tracing service: {exc}")
|
||||
self.tracing_service = None
|
||||
|
||||
def get_state(self, name: str) -> Optional[Data]:
|
||||
"""
|
||||
|
|
@ -638,7 +647,9 @@ class Graph:
|
|||
try:
|
||||
vertices = payload["nodes"]
|
||||
edges = payload["edges"]
|
||||
return cls(vertices, edges, flow_id, flow_name, user_id)
|
||||
graph = cls(flow_id, flow_name, user_id)
|
||||
graph.add_nodes_and_edges(vertices, edges)
|
||||
return graph
|
||||
except KeyError as exc:
|
||||
logger.exception(exc)
|
||||
if "nodes" not in payload and "edges" not in payload:
|
||||
|
|
@ -1188,20 +1199,24 @@ class Graph:
|
|||
"""Builds the vertices of the graph."""
|
||||
vertices: List[Vertex] = []
|
||||
for vertex in self._vertices:
|
||||
vertex_data = vertex["data"]
|
||||
vertex_type: str = vertex_data["type"] # type: ignore
|
||||
vertex_base_type: str = vertex_data["node"]["template"]["_type"] # type: ignore
|
||||
if "id" not in vertex_data:
|
||||
raise ValueError(f"Vertex data for {vertex_data['display_name']} does not contain an id")
|
||||
|
||||
VertexClass = self._get_vertex_class(vertex_type, vertex_base_type, vertex_data["id"])
|
||||
|
||||
vertex_instance = VertexClass(vertex, graph=self)
|
||||
vertex_instance.set_top_level(self.top_level_vertices)
|
||||
vertex_instance = self._create_vertex(vertex)
|
||||
vertices.append(vertex_instance)
|
||||
|
||||
return vertices
|
||||
|
||||
def _create_vertex(self, vertex: dict):
|
||||
vertex_data = vertex["data"]
|
||||
vertex_type: str = vertex_data["type"] # type: ignore
|
||||
vertex_base_type: str = vertex_data["node"]["template"]["_type"] # type: ignore
|
||||
if "id" not in vertex_data:
|
||||
raise ValueError(f"Vertex data for {vertex_data['display_name']} does not contain an id")
|
||||
|
||||
VertexClass = self._get_vertex_class(vertex_type, vertex_base_type, vertex_data["id"])
|
||||
|
||||
vertex_instance = VertexClass(vertex, graph=self)
|
||||
vertex_instance.set_top_level(self.top_level_vertices)
|
||||
return vertex_instance
|
||||
|
||||
def get_children_by_vertex_type(self, vertex: Vertex, vertex_type: str) -> List[Vertex]:
|
||||
"""Returns the children of a vertex based on the vertex type."""
|
||||
children = []
|
||||
|
|
|
|||
|
|
@ -158,7 +158,9 @@ def get_graph(_type="basic"):
|
|||
data_graph = flow_graph["data"]
|
||||
nodes = data_graph["nodes"]
|
||||
edges = data_graph["edges"]
|
||||
return Graph(nodes, edges)
|
||||
graph = Graph()
|
||||
graph.add_nodes_and_edges(nodes, edges)
|
||||
return graph
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
|
|
|
|||
|
|
@ -116,7 +116,8 @@ def test_invalid_node_types():
|
|||
"edges": [],
|
||||
}
|
||||
with pytest.raises(Exception):
|
||||
Graph(graph_data["nodes"], graph_data["edges"])
|
||||
g = Graph()
|
||||
g.add_nodes_and_edges(graph_data["nodes"], graph_data["edges"])
|
||||
|
||||
|
||||
def test_get_vertices_with_target(basic_graph):
|
||||
Loading…
Add table
Add a link
Reference in a new issue