diff --git a/src/backend/langflow/interface/tools/custom.py b/src/backend/langflow/interface/tools/custom.py index 5b0b589da..60541960e 100644 --- a/src/backend/langflow/interface/tools/custom.py +++ b/src/backend/langflow/interface/tools/custom.py @@ -185,10 +185,16 @@ class CustomComponent(BaseModel): return function_args, return_type - def is_valid_class_template(self, code: dict): + def _class_template_validation(self, code: dict): class_name = code.get("class", {}).get("name", None) if not class_name: # this will also check for None, empty string, etc. - return False + raise HTTPException( + status_code=400, + detail={ + "error": "The main class must have a valid name.", + "traceback": "", + }, + ) functions = code.get("functions", []) if build_function := next( @@ -198,7 +204,13 @@ class CustomComponent(BaseModel): # Check if the return type of the build function is valid return build_function.get("return_type") in self.return_type_valid_list else: - return False + raise HTTPException( + status_code=400, + detail={ + "error": f"The class return [{str(build_function.get('return_type'))}] needs to be an item from this list. [{str(self.return_type_valid_list)}]", + "traceback": "", + }, + ) def get_function(self): return validate.create_function(self.code, self.function_entrypoint_name) @@ -209,7 +221,7 @@ class CustomComponent(BaseModel): @property def is_valid(self): - return self.is_valid_class_template(self.data) + return self._class_template_validation(self.data) @property def args_and_return_type(self): diff --git a/src/backend/langflow/interface/types.py b/src/backend/langflow/interface/types.py index 854e0a9ac..58238958c 100644 --- a/src/backend/langflow/interface/types.py +++ b/src/backend/langflow/interface/types.py @@ -18,6 +18,13 @@ from langflow.template.field.base import TemplateField from langflow.template.frontend_node.tools import CustomComponentNode from langflow.interface.retrievers.base import retriever_creator +from langflow.utils.util import get_base_classes + +from fastapi import HTTPException +import traceback + +# Used to get the base_classes list + def get_type_list(): """Get a list of all langchain types""" @@ -117,12 +124,18 @@ def build_langchain_template_custom_component(extractor: CustomComponent): template = add_code_field(template, raw_code) - # TODO: Get base classes from "return_type" and add to template.base_classes - template.get("base_classes").append("ConversationChain") - template.get("base_classes").append("LLMChain") - template.get("base_classes").append("Chain") - template.get("base_classes").append("Serializable") - template.get("base_classes").append("function") + # Get base classes from "return_type" and add to template.base_classes + try: + return_type_instance = globals()[return_type] + base_classes = get_base_classes(return_type_instance) + except (KeyError, AttributeError) as err: + raise HTTPException( + status_code=400, + detail={"error": type(err).__name__, "traceback": traceback.format_exc()}, + ) from err + + for base_class in base_classes: + template.get("base_classes").append(base_class) return template