Add support for ChatVertex in _get_vertex_class() method

This commit is contained in:
Gabriel Luiz Freitas Almeida 2024-02-07 15:31:23 -03:00
commit 5d7f00c5ad

View file

@ -240,10 +240,12 @@ class Graph:
def _get_vertex_class(self, node_type: str, node_base_type: str, node_id: str) -> Type[Vertex]:
"""Returns the node class based on the node type."""
# First we check for the node_base_type
node_name = node_id.split("-")[0]
if node_name in ["ChatOutput", "ChatInput"]:
return ChatVertex
if node_base_type in lazy_load_vertex_dict.VERTEX_TYPE_MAP:
return lazy_load_vertex_dict.VERTEX_TYPE_MAP[node_base_type]
node_name = node_id.split("-")[0]
if node_name in lazy_load_vertex_dict.VERTEX_TYPE_MAP:
return lazy_load_vertex_dict.VERTEX_TYPE_MAP[node_name]
@ -317,3 +319,4 @@ class Graph:
return layers
return layers
return layers
return layers