Refactor _get_vertex_class method to handle node_base_type

This commit is contained in:
Gabriel Luiz Freitas Almeida 2024-02-07 14:07:13 -03:00
commit 0abeda51e9

View file

@ -237,8 +237,12 @@ class Graph:
edges.append(ContractEdge(source, target, edge))
return edges
def _get_vertex_class(self, node_type: str, node_lc_type: str, node_id: str) -> Type[Vertex]:
def _get_vertex_class(self, node_type: str, node_base_type: str, node_id: str) -> Type[Vertex]:
"""Returns the node class based on the node type."""
# First we check for the node_base_type
if node_base_type in lazy_load_vertex_dict.VERTEX_TYPE_MAP:
return lazy_load_vertex_dict.VERTEX_TYPE_MAP[node_base_type]
node_name = node_id.split("-")[0]
if node_name in lazy_load_vertex_dict.VERTEX_TYPE_MAP:
return lazy_load_vertex_dict.VERTEX_TYPE_MAP[node_name]
@ -248,8 +252,8 @@ class Graph:
if node_type in lazy_load_vertex_dict.VERTEX_TYPE_MAP:
return lazy_load_vertex_dict.VERTEX_TYPE_MAP[node_type]
return (
lazy_load_vertex_dict.VERTEX_TYPE_MAP[node_lc_type]
if node_lc_type in lazy_load_vertex_dict.VERTEX_TYPE_MAP
lazy_load_vertex_dict.VERTEX_TYPE_MAP[node_base_type]
if node_base_type in lazy_load_vertex_dict.VERTEX_TYPE_MAP
else Vertex
)
@ -312,3 +316,4 @@ class Graph:
return layers
return layers
return layers
return layers