Refactor vertex class selection in Graph class

This commit is contained in:
Gabriel Luiz Freitas Almeida 2023-12-13 10:38:25 -03:00
commit a2336aa7c8
2 changed files with 16 additions and 10 deletions

View file

@ -209,15 +209,18 @@ class Graph:
edges.append(Edge(source, target, edge))
return edges
def _get_vertex_class(self, vertex_type: str, vertex_lc_type: str) -> Type[Vertex]:
def _get_vertex_class(self, vertex_type: str, vertex_base_type: str) -> Type[Vertex]:
"""Returns the vertex class based on the vertex type."""
if vertex_type in FILE_TOOLS:
return FileToolVertex
if vertex_type in lazy_load_vertex_dict.VERTEX_TYPE_MAP:
return lazy_load_vertex_dict.VERTEX_TYPE_MAP[vertex_type]
if vertex_base_type == "CustomComponent":
return lazy_load_vertex_dict.get_custom_component_vertex_type()
if vertex_base_type in lazy_load_vertex_dict.VERTEX_TYPE_MAP:
return lazy_load_vertex_dict.VERTEX_TYPE_MAP[vertex_base_type]
return (
lazy_load_vertex_dict.VERTEX_TYPE_MAP[vertex_lc_type]
if vertex_lc_type in lazy_load_vertex_dict.VERTEX_TYPE_MAP
lazy_load_vertex_dict.VERTEX_TYPE_MAP[vertex_type]
if vertex_type in lazy_load_vertex_dict.VERTEX_TYPE_MAP
else Vertex
)
@ -227,9 +230,9 @@ class Graph:
for vertex in self._vertices:
vertex_data = vertex["data"]
vertex_type: str = vertex_data["type"] # type: ignore
vertex_lc_type: str = vertex_data["node"]["template"]["_type"] # type: ignore
vertex_base_type: str = vertex_data["node"]["template"]["_type"] # type: ignore
VertexClass = self._get_vertex_class(vertex_type, vertex_lc_type)
VertexClass = self._get_vertex_class(vertex_type, vertex_base_type)
vertex_instance = VertexClass(vertex, graph=self)
vertex_instance.set_top_level(self.top_level_vertices)
vertices.append(vertex_instance)

View file

@ -1,19 +1,19 @@
from langflow.graph.vertex import types
from langflow.interface.agents.base import agent_creator
from langflow.interface.chains.base import chain_creator
from langflow.interface.custom.base import custom_component_creator
from langflow.interface.document_loaders.base import documentloader_creator
from langflow.interface.embeddings.base import embedding_creator
from langflow.interface.llms.base import llm_creator
from langflow.interface.memories.base import memory_creator
from langflow.interface.output_parsers.base import output_parser_creator
from langflow.interface.prompts.base import prompt_creator
from langflow.interface.retrievers.base import retriever_creator
from langflow.interface.text_splitters.base import textsplitter_creator
from langflow.interface.toolkits.base import toolkits_creator
from langflow.interface.tools.base import tool_creator
from langflow.interface.vector_store.base import vectorstore_creator
from langflow.interface.wrappers.base import wrapper_creator
from langflow.interface.output_parsers.base import output_parser_creator
from langflow.interface.retrievers.base import retriever_creator
from langflow.interface.custom.base import custom_component_creator
from langflow.utils.lazy_load import LazyLoadDictBase
@ -32,6 +32,9 @@ class VertexTypesDict(LazyLoadDictBase):
"Custom": ["Custom Tool", "Python Function"],
}
def get_custom_component_vertex_type(self):
return types.CustomComponentVertex
def get_type_dict(self):
return {
**{t: types.PromptVertex for t in prompt_creator.to_list()},