Refactor vertex class selection in Graph class
This commit is contained in:
parent
f109a4097e
commit
a2336aa7c8
2 changed files with 16 additions and 10 deletions
|
|
@ -209,15 +209,18 @@ class Graph:
|
|||
edges.append(Edge(source, target, edge))
|
||||
return edges
|
||||
|
||||
def _get_vertex_class(self, vertex_type: str, vertex_lc_type: str) -> Type[Vertex]:
|
||||
def _get_vertex_class(self, vertex_type: str, vertex_base_type: str) -> Type[Vertex]:
|
||||
"""Returns the vertex class based on the vertex type."""
|
||||
if vertex_type in FILE_TOOLS:
|
||||
return FileToolVertex
|
||||
if vertex_type in lazy_load_vertex_dict.VERTEX_TYPE_MAP:
|
||||
return lazy_load_vertex_dict.VERTEX_TYPE_MAP[vertex_type]
|
||||
if vertex_base_type == "CustomComponent":
|
||||
return lazy_load_vertex_dict.get_custom_component_vertex_type()
|
||||
|
||||
if vertex_base_type in lazy_load_vertex_dict.VERTEX_TYPE_MAP:
|
||||
return lazy_load_vertex_dict.VERTEX_TYPE_MAP[vertex_base_type]
|
||||
return (
|
||||
lazy_load_vertex_dict.VERTEX_TYPE_MAP[vertex_lc_type]
|
||||
if vertex_lc_type in lazy_load_vertex_dict.VERTEX_TYPE_MAP
|
||||
lazy_load_vertex_dict.VERTEX_TYPE_MAP[vertex_type]
|
||||
if vertex_type in lazy_load_vertex_dict.VERTEX_TYPE_MAP
|
||||
else Vertex
|
||||
)
|
||||
|
||||
|
|
@ -227,9 +230,9 @@ class Graph:
|
|||
for vertex in self._vertices:
|
||||
vertex_data = vertex["data"]
|
||||
vertex_type: str = vertex_data["type"] # type: ignore
|
||||
vertex_lc_type: str = vertex_data["node"]["template"]["_type"] # type: ignore
|
||||
vertex_base_type: str = vertex_data["node"]["template"]["_type"] # type: ignore
|
||||
|
||||
VertexClass = self._get_vertex_class(vertex_type, vertex_lc_type)
|
||||
VertexClass = self._get_vertex_class(vertex_type, vertex_base_type)
|
||||
vertex_instance = VertexClass(vertex, graph=self)
|
||||
vertex_instance.set_top_level(self.top_level_vertices)
|
||||
vertices.append(vertex_instance)
|
||||
|
|
|
|||
|
|
@ -1,19 +1,19 @@
|
|||
from langflow.graph.vertex import types
|
||||
from langflow.interface.agents.base import agent_creator
|
||||
from langflow.interface.chains.base import chain_creator
|
||||
from langflow.interface.custom.base import custom_component_creator
|
||||
from langflow.interface.document_loaders.base import documentloader_creator
|
||||
from langflow.interface.embeddings.base import embedding_creator
|
||||
from langflow.interface.llms.base import llm_creator
|
||||
from langflow.interface.memories.base import memory_creator
|
||||
from langflow.interface.output_parsers.base import output_parser_creator
|
||||
from langflow.interface.prompts.base import prompt_creator
|
||||
from langflow.interface.retrievers.base import retriever_creator
|
||||
from langflow.interface.text_splitters.base import textsplitter_creator
|
||||
from langflow.interface.toolkits.base import toolkits_creator
|
||||
from langflow.interface.tools.base import tool_creator
|
||||
from langflow.interface.vector_store.base import vectorstore_creator
|
||||
from langflow.interface.wrappers.base import wrapper_creator
|
||||
from langflow.interface.output_parsers.base import output_parser_creator
|
||||
from langflow.interface.retrievers.base import retriever_creator
|
||||
from langflow.interface.custom.base import custom_component_creator
|
||||
from langflow.utils.lazy_load import LazyLoadDictBase
|
||||
|
||||
|
||||
|
|
@ -32,6 +32,9 @@ class VertexTypesDict(LazyLoadDictBase):
|
|||
"Custom": ["Custom Tool", "Python Function"],
|
||||
}
|
||||
|
||||
def get_custom_component_vertex_type(self):
|
||||
return types.CustomComponentVertex
|
||||
|
||||
def get_type_dict(self):
|
||||
return {
|
||||
**{t: types.PromptVertex for t in prompt_creator.to_list()},
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue