Refactor edge equality check and optimize edge creation in Graph class

This commit is contained in:
Gabriel Luiz Freitas Almeida 2024-03-30 16:49:50 -03:00
commit 87adaa4255
2 changed files with 9 additions and 8 deletions

View file

@ -101,7 +101,11 @@ class Edge:
def __eq__(self, __o: object) -> bool:
if not isinstance(__o, Edge):
return False
return self._source_handle == __o._source_handle and self._target_handle == __o._target_handle
return (
self._source_handle == __o._source_handle
and self._target_handle == __o._target_handle
and self.target_param == __o.target_param
)
class ContractEdge(Edge):

View file

@ -890,8 +890,7 @@ class Graph:
# and then build the edges
# if we can't find a vertex, we raise an error
edges: List[ContractEdge] = []
edges_added = set()
edges: set[ContractEdge] = set()
for edge in self._edges:
source = self.get_vertex(edge["source"])
target = self.get_vertex(edge["target"])
@ -900,13 +899,11 @@ class Graph:
raise ValueError(f"Source vertex {edge['source']} not found")
if target is None:
raise ValueError(f"Target vertex {edge['target']} not found")
edge = ContractEdge(source, target, edge)
if (source.id, target.id) in edges_added:
continue
edges.add(edge)
edges.append(ContractEdge(source, target, edge))
edges_added.add((source.id, target.id))
return edges
return list(edges)
def _get_vertex_class(self, node_type: str, node_base_type: str, node_id: str) -> Type[Vertex]:
"""Returns the node class based on the node type."""