diff --git a/src/backend/langflow/graph/graph/base.py b/src/backend/langflow/graph/graph/base.py index a790e7686..f57f01f7d 100644 --- a/src/backend/langflow/graph/graph/base.py +++ b/src/backend/langflow/graph/graph/base.py @@ -240,10 +240,12 @@ class Graph: def _get_vertex_class(self, node_type: str, node_base_type: str, node_id: str) -> Type[Vertex]: """Returns the node class based on the node type.""" # First we check for the node_base_type + node_name = node_id.split("-")[0] + if node_name in ["ChatOutput", "ChatInput"]: + return ChatVertex if node_base_type in lazy_load_vertex_dict.VERTEX_TYPE_MAP: return lazy_load_vertex_dict.VERTEX_TYPE_MAP[node_base_type] - node_name = node_id.split("-")[0] if node_name in lazy_load_vertex_dict.VERTEX_TYPE_MAP: return lazy_load_vertex_dict.VERTEX_TYPE_MAP[node_name] @@ -317,3 +319,4 @@ class Graph: return layers return layers return layers + return layers