From f92fefba46c252b8812ca8fc5b6e7d6d05b62dda Mon Sep 17 00:00:00 2001 From: gustavoschaedler Date: Thu, 6 Jul 2023 00:21:35 +0100 Subject: [PATCH] Refactor code for CustomComponent class and build_langchain_template_custom_component function The code in the CustomComponent class has been refactored to improve readability and maintainability. The `_handle_function` method now handles function arguments correctly. Additionally, the `is_valid_class_template` method has been renamed to `_class_template_validation`, and it now raises an HTTPException with a detailed error message when the main class or the build function contains invalid information. In the `build_langchain_template_custom_component` function, base classes are now retrieved from the `return_type` and added to the `template.base_classes` list. A try-except block is used to handle possible KeyError or AttributeError exceptions, and an HTTPException is raised with the corresponding error message and traceback if an error occurs. These changes ensure more accurate validation and handle potential errors more gracefully. --- .../langflow/interface/tools/custom.py | 20 ++++++++++++--- src/backend/langflow/interface/types.py | 25 ++++++++++++++----- 2 files changed, 35 insertions(+), 10 deletions(-) diff --git a/src/backend/langflow/interface/tools/custom.py b/src/backend/langflow/interface/tools/custom.py index 5b0b589da..60541960e 100644 --- a/src/backend/langflow/interface/tools/custom.py +++ b/src/backend/langflow/interface/tools/custom.py @@ -185,10 +185,16 @@ class CustomComponent(BaseModel): return function_args, return_type - def is_valid_class_template(self, code: dict): + def _class_template_validation(self, code: dict): class_name = code.get("class", {}).get("name", None) if not class_name: # this will also check for None, empty string, etc. - return False + raise HTTPException( + status_code=400, + detail={ + "error": "The main class must have a valid name.", + "traceback": "", + }, + ) functions = code.get("functions", []) if build_function := next( @@ -198,7 +204,13 @@ class CustomComponent(BaseModel): # Check if the return type of the build function is valid return build_function.get("return_type") in self.return_type_valid_list else: - return False + raise HTTPException( + status_code=400, + detail={ + "error": f"The class return [{str(build_function.get('return_type'))}] needs to be an item from this list. [{str(self.return_type_valid_list)}]", + "traceback": "", + }, + ) def get_function(self): return validate.create_function(self.code, self.function_entrypoint_name) @@ -209,7 +221,7 @@ class CustomComponent(BaseModel): @property def is_valid(self): - return self.is_valid_class_template(self.data) + return self._class_template_validation(self.data) @property def args_and_return_type(self): diff --git a/src/backend/langflow/interface/types.py b/src/backend/langflow/interface/types.py index 854e0a9ac..58238958c 100644 --- a/src/backend/langflow/interface/types.py +++ b/src/backend/langflow/interface/types.py @@ -18,6 +18,13 @@ from langflow.template.field.base import TemplateField from langflow.template.frontend_node.tools import CustomComponentNode from langflow.interface.retrievers.base import retriever_creator +from langflow.utils.util import get_base_classes + +from fastapi import HTTPException +import traceback + +# Used to get the base_classes list + def get_type_list(): """Get a list of all langchain types""" @@ -117,12 +124,18 @@ def build_langchain_template_custom_component(extractor: CustomComponent): template = add_code_field(template, raw_code) - # TODO: Get base classes from "return_type" and add to template.base_classes - template.get("base_classes").append("ConversationChain") - template.get("base_classes").append("LLMChain") - template.get("base_classes").append("Chain") - template.get("base_classes").append("Serializable") - template.get("base_classes").append("function") + # Get base classes from "return_type" and add to template.base_classes + try: + return_type_instance = globals()[return_type] + base_classes = get_base_classes(return_type_instance) + except (KeyError, AttributeError) as err: + raise HTTPException( + status_code=400, + detail={"error": type(err).__name__, "traceback": traceback.format_exc()}, + ) from err + + for base_class in base_classes: + template.get("base_classes").append(base_class) return template