diff --git a/src/backend/langflow/graph/graph/base.py b/src/backend/langflow/graph/graph/base.py index 3481fda87..d1629d49b 100644 --- a/src/backend/langflow/graph/graph/base.py +++ b/src/backend/langflow/graph/graph/base.py @@ -209,15 +209,18 @@ class Graph: edges.append(Edge(source, target, edge)) return edges - def _get_vertex_class(self, vertex_type: str, vertex_lc_type: str) -> Type[Vertex]: + def _get_vertex_class(self, vertex_type: str, vertex_base_type: str) -> Type[Vertex]: """Returns the vertex class based on the vertex type.""" if vertex_type in FILE_TOOLS: return FileToolVertex - if vertex_type in lazy_load_vertex_dict.VERTEX_TYPE_MAP: - return lazy_load_vertex_dict.VERTEX_TYPE_MAP[vertex_type] + if vertex_base_type == "CustomComponent": + return lazy_load_vertex_dict.get_custom_component_vertex_type() + + if vertex_base_type in lazy_load_vertex_dict.VERTEX_TYPE_MAP: + return lazy_load_vertex_dict.VERTEX_TYPE_MAP[vertex_base_type] return ( - lazy_load_vertex_dict.VERTEX_TYPE_MAP[vertex_lc_type] - if vertex_lc_type in lazy_load_vertex_dict.VERTEX_TYPE_MAP + lazy_load_vertex_dict.VERTEX_TYPE_MAP[vertex_type] + if vertex_type in lazy_load_vertex_dict.VERTEX_TYPE_MAP else Vertex ) @@ -227,9 +230,9 @@ class Graph: for vertex in self._vertices: vertex_data = vertex["data"] vertex_type: str = vertex_data["type"] # type: ignore - vertex_lc_type: str = vertex_data["node"]["template"]["_type"] # type: ignore + vertex_base_type: str = vertex_data["node"]["template"]["_type"] # type: ignore - VertexClass = self._get_vertex_class(vertex_type, vertex_lc_type) + VertexClass = self._get_vertex_class(vertex_type, vertex_base_type) vertex_instance = VertexClass(vertex, graph=self) vertex_instance.set_top_level(self.top_level_vertices) vertices.append(vertex_instance) diff --git a/src/backend/langflow/graph/graph/constants.py b/src/backend/langflow/graph/graph/constants.py index abfc2970f..9514764b7 100644 --- a/src/backend/langflow/graph/graph/constants.py +++ b/src/backend/langflow/graph/graph/constants.py @@ -1,19 +1,19 @@ from langflow.graph.vertex import types from langflow.interface.agents.base import agent_creator from langflow.interface.chains.base import chain_creator +from langflow.interface.custom.base import custom_component_creator from langflow.interface.document_loaders.base import documentloader_creator from langflow.interface.embeddings.base import embedding_creator from langflow.interface.llms.base import llm_creator from langflow.interface.memories.base import memory_creator +from langflow.interface.output_parsers.base import output_parser_creator from langflow.interface.prompts.base import prompt_creator +from langflow.interface.retrievers.base import retriever_creator from langflow.interface.text_splitters.base import textsplitter_creator from langflow.interface.toolkits.base import toolkits_creator from langflow.interface.tools.base import tool_creator from langflow.interface.vector_store.base import vectorstore_creator from langflow.interface.wrappers.base import wrapper_creator -from langflow.interface.output_parsers.base import output_parser_creator -from langflow.interface.retrievers.base import retriever_creator -from langflow.interface.custom.base import custom_component_creator from langflow.utils.lazy_load import LazyLoadDictBase @@ -32,6 +32,9 @@ class VertexTypesDict(LazyLoadDictBase): "Custom": ["Custom Tool", "Python Function"], } + def get_custom_component_vertex_type(self): + return types.CustomComponentVertex + def get_type_dict(self): return { **{t: types.PromptVertex for t in prompt_creator.to_list()},