diff --git a/src/backend/langflow/interface/types.py b/src/backend/langflow/interface/types.py index f331c3416..c91abcfbe 100644 --- a/src/backend/langflow/interface/types.py +++ b/src/backend/langflow/interface/types.py @@ -1,5 +1,6 @@ from langflow.interface.agents.base import agent_creator from langflow.interface.chains.base import chain_creator +from langflow.interface.custom.constants import LANGCHAIN_BASE_TYPES from langflow.interface.document_loaders.base import documentloader_creator from langflow.interface.embeddings.base import embedding_creator from langflow.interface.llms.base import llm_creator @@ -12,7 +13,8 @@ from langflow.interface.utilities.base import utility_creator from langflow.interface.vector_store.base import vectorstore_creator from langflow.interface.wrappers.base import wrapper_creator from langflow.interface.output_parsers.base import output_parser_creator -from langflow.interface.tools.custom import CustomComponent +from langflow.interface.custom.base import custom_component_creator +from langflow.interface.custom.custom import CustomComponent from langflow.template.field.base import TemplateField from langflow.template.frontend_node.tools import CustomComponentNode @@ -24,9 +26,6 @@ from fastapi import HTTPException import traceback # Used to get the base_classes list -from langchain.chains import ConversationChain # noqa: F401 -from langchain.llms.base import BaseLLM # noqa: F401 -from langchain.tools import Tool # noqa: F401 def get_type_list(): @@ -62,6 +61,7 @@ def build_langchain_types_dict(): # sourcery skip: dict-assign-update-to-union utility_creator, output_parser_creator, retriever_creator, + custom_component_creator, ] all_types = {} @@ -73,9 +73,16 @@ def build_langchain_types_dict(): # sourcery skip: dict-assign-update-to-union # TODO: Move to correct place -def add_new_custom_field(template, field_name: str, field_type: str): +def add_new_custom_field( + template, field_name: str, field_type: str, field_config: dict +): new_field = TemplateField( - name=field_name, field_type=field_type, show=True, required=True, advanced=False + name=field_name, + field_type=field_type, + show=True, + required=True, + advanced=False, + **field_config, ) template.get("template")[field_name] = new_field.to_dict() template.get("custom_fields")[field_name] = None @@ -108,28 +115,45 @@ def add_code_field(template, raw_code): def build_langchain_template_custom_component(extractor: CustomComponent): # Build base "CustomComponent" template - template = CustomComponentNode().to_dict().get(type(extractor).__name__) + frontend_node = CustomComponentNode().to_dict().get(type(extractor).__name__) - function_args, return_type = extractor.args_and_return_type + function_args, return_type, template_config = extractor.args_and_return_type + + if "display_name" in template_config and frontend_node is not None: + frontend_node["display_name"] = template_config["display_name"] raw_code = extractor.code + field_config = template_config.get("field_config", {}) + if function_args is not None: + # Add extra fields + for extra_field in function_args: + def_field = extra_field[0] + def_type = extra_field[1] - # Add extra fields - for extra_field in function_args: - def_field = extra_field[0] - def_type = extra_field[1] + if def_field != "self": + # TODO: Validate type - if is possible to render into frontend + if not def_type: + def_type = "str" + config = field_config.get(def_field, {}) + frontend_node = add_new_custom_field( + frontend_node, def_field, def_type, config + ) - if def_field != "self": - # TODO: Validate type - if is possible to render into frontend - if not def_type: - def_type = "str" - - template = add_new_custom_field(template, def_field, def_type) - - template = add_code_field(template, raw_code) + frontend_node = add_code_field(frontend_node, raw_code) # Get base classes from "return_type" and add to template.base_classes try: - return_type_instance = globals()[return_type] + if return_type not in LANGCHAIN_BASE_TYPES or return_type is None: + raise HTTPException( + status_code=400, + detail={ + "error": ( + "Invalid return type should be one of: " + f"{list(LANGCHAIN_BASE_TYPES.keys())}" + ), + "traceback": traceback.format_exc(), + }, + ) + return_type_instance = LANGCHAIN_BASE_TYPES.get(return_type) base_classes = get_base_classes(return_type_instance) except (KeyError, AttributeError) as err: raise HTTPException( @@ -138,9 +162,9 @@ def build_langchain_template_custom_component(extractor: CustomComponent): ) from err for base_class in base_classes: - template.get("base_classes").append(base_class) + frontend_node.get("base_classes").append(base_class) - return template + return frontend_node langchain_types_dict = build_langchain_types_dict()