🔧 fix(constants.py): add support for additional Python base types in CUSTOM_COMPONENT_SUPPORTED_TYPES dictionary

🔧 fix(custom_component.py): update return_type_valid_list to use CUSTOM_COMPONENT_SUPPORTED_TYPES dictionary
🔧 fix(types.py): update add_base_classes function to use CUSTOM_COMPONENT_SUPPORTED_TYPES dictionary instead of LANGCHAIN_BASE_TYPES
This commit is contained in:
Gabriel Luiz Freitas Almeida 2023-07-26 10:15:22 -03:00
commit 63ead274c4
3 changed files with 16 additions and 6 deletions

View file

@ -20,7 +20,17 @@ LANGCHAIN_BASE_TYPES = {
"VectorStore": VectorStore,
"Embeddings": Embeddings,
"BaseRetriever": BaseRetriever,
}
# Langchain base types plus Python base types
CUSTOM_COMPONENT_SUPPORTED_TYPES = {
**LANGCHAIN_BASE_TYPES,
"str": str,
"int": int,
"float": float,
"bool": bool,
"list": list,
"dict": dict,
}

View file

@ -1,6 +1,6 @@
from typing import Callable, Optional
from fastapi import HTTPException
from langflow.interface.custom.constants import LANGCHAIN_BASE_TYPES
from langflow.interface.custom.constants import CUSTOM_COMPONENT_SUPPORTED_TYPES
from langflow.interface.custom.component import Component
from langflow.utils import validate
@ -16,7 +16,7 @@ class CustomComponent(Component, extra=Extra.allow):
code_class_base_inheritance = "CustomComponent"
function_entrypoint_name = "build"
function: Optional[Callable] = None
return_type_valid_list = list(LANGCHAIN_BASE_TYPES.keys())
return_type_valid_list = list(CUSTOM_COMPONENT_SUPPORTED_TYPES.keys())
repr_value: Optional[str] = ""
def __init__(self, **data):

View file

@ -1,6 +1,6 @@
from langflow.interface.agents.base import agent_creator
from langflow.interface.chains.base import chain_creator
from langflow.interface.custom.constants import LANGCHAIN_BASE_TYPES
from langflow.interface.custom.constants import CUSTOM_COMPONENT_SUPPORTED_TYPES
from langflow.interface.document_loaders.base import documentloader_creator
from langflow.interface.embeddings.base import embedding_creator
from langflow.interface.importing.utils import get_function_custom
@ -232,19 +232,19 @@ def get_field_properties(extra_field):
def add_base_classes(frontend_node, return_type):
"""Add base classes to the frontend node"""
if return_type not in LANGCHAIN_BASE_TYPES or return_type is None:
if return_type not in CUSTOM_COMPONENT_SUPPORTED_TYPES or return_type is None:
raise HTTPException(
status_code=400,
detail={
"error": (
"Invalid return type should be one of: "
f"{list(LANGCHAIN_BASE_TYPES.keys())}"
f"{list(CUSTOM_COMPONENT_SUPPORTED_TYPES.keys())}"
),
"traceback": traceback.format_exc(),
},
)
return_type_instance = LANGCHAIN_BASE_TYPES.get(return_type)
return_type_instance = CUSTOM_COMPONENT_SUPPORTED_TYPES.get(return_type)
base_classes = get_base_classes(return_type_instance)
for base_class in base_classes: