Refactor edge equality check and optimize edge creation in Graph class
This commit is contained in:
parent
4aa726706f
commit
87adaa4255
2 changed files with 9 additions and 8 deletions
|
|
@ -101,7 +101,11 @@ class Edge:
|
|||
def __eq__(self, __o: object) -> bool:
|
||||
if not isinstance(__o, Edge):
|
||||
return False
|
||||
return self._source_handle == __o._source_handle and self._target_handle == __o._target_handle
|
||||
return (
|
||||
self._source_handle == __o._source_handle
|
||||
and self._target_handle == __o._target_handle
|
||||
and self.target_param == __o.target_param
|
||||
)
|
||||
|
||||
|
||||
class ContractEdge(Edge):
|
||||
|
|
|
|||
|
|
@ -890,8 +890,7 @@ class Graph:
|
|||
# and then build the edges
|
||||
# if we can't find a vertex, we raise an error
|
||||
|
||||
edges: List[ContractEdge] = []
|
||||
edges_added = set()
|
||||
edges: set[ContractEdge] = set()
|
||||
for edge in self._edges:
|
||||
source = self.get_vertex(edge["source"])
|
||||
target = self.get_vertex(edge["target"])
|
||||
|
|
@ -900,13 +899,11 @@ class Graph:
|
|||
raise ValueError(f"Source vertex {edge['source']} not found")
|
||||
if target is None:
|
||||
raise ValueError(f"Target vertex {edge['target']} not found")
|
||||
edge = ContractEdge(source, target, edge)
|
||||
|
||||
if (source.id, target.id) in edges_added:
|
||||
continue
|
||||
edges.add(edge)
|
||||
|
||||
edges.append(ContractEdge(source, target, edge))
|
||||
edges_added.add((source.id, target.id))
|
||||
return edges
|
||||
return list(edges)
|
||||
|
||||
def _get_vertex_class(self, node_type: str, node_base_type: str, node_id: str) -> Type[Vertex]:
|
||||
"""Returns the node class based on the node type."""
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue