Refactor code for CustomComponent class and build_langchain_template_custom_component function

The code in the CustomComponent class has been refactored to improve readability and maintainability. The `_handle_function` method now handles function arguments correctly. Additionally, the `is_valid_class_template` method has been renamed to `_class_template_validation`, and it now raises an HTTPException with a detailed error message when the main class or the build function contains invalid information.

In the `build_langchain_template_custom_component` function, base classes are now retrieved from the `return_type` and added to the `template.base_classes` list. A try-except block is used to handle possible KeyError or AttributeError exceptions, and an HTTPException is raised with the corresponding error message and traceback if an error occurs.

These changes ensure more accurate validation and handle potential errors more gracefully.
This commit is contained in:
gustavoschaedler 2023-07-06 00:21:35 +01:00
commit f92fefba46
2 changed files with 35 additions and 10 deletions

View file

@ -185,10 +185,16 @@ class CustomComponent(BaseModel):
return function_args, return_type
def is_valid_class_template(self, code: dict):
def _class_template_validation(self, code: dict):
class_name = code.get("class", {}).get("name", None)
if not class_name: # this will also check for None, empty string, etc.
return False
raise HTTPException(
status_code=400,
detail={
"error": "The main class must have a valid name.",
"traceback": "",
},
)
functions = code.get("functions", [])
if build_function := next(
@ -198,7 +204,13 @@ class CustomComponent(BaseModel):
# Check if the return type of the build function is valid
return build_function.get("return_type") in self.return_type_valid_list
else:
return False
raise HTTPException(
status_code=400,
detail={
"error": f"The class return [{str(build_function.get('return_type'))}] needs to be an item from this list. [{str(self.return_type_valid_list)}]",
"traceback": "",
},
)
def get_function(self):
return validate.create_function(self.code, self.function_entrypoint_name)
@ -209,7 +221,7 @@ class CustomComponent(BaseModel):
@property
def is_valid(self):
return self.is_valid_class_template(self.data)
return self._class_template_validation(self.data)
@property
def args_and_return_type(self):

View file

@ -18,6 +18,13 @@ from langflow.template.field.base import TemplateField
from langflow.template.frontend_node.tools import CustomComponentNode
from langflow.interface.retrievers.base import retriever_creator
from langflow.utils.util import get_base_classes
from fastapi import HTTPException
import traceback
# Used to get the base_classes list
def get_type_list():
"""Get a list of all langchain types"""
@ -117,12 +124,18 @@ def build_langchain_template_custom_component(extractor: CustomComponent):
template = add_code_field(template, raw_code)
# TODO: Get base classes from "return_type" and add to template.base_classes
template.get("base_classes").append("ConversationChain")
template.get("base_classes").append("LLMChain")
template.get("base_classes").append("Chain")
template.get("base_classes").append("Serializable")
template.get("base_classes").append("function")
# Get base classes from "return_type" and add to template.base_classes
try:
return_type_instance = globals()[return_type]
base_classes = get_base_classes(return_type_instance)
except (KeyError, AttributeError) as err:
raise HTTPException(
status_code=400,
detail={"error": type(err).__name__, "traceback": traceback.format_exc()},
) from err
for base_class in base_classes:
template.get("base_classes").append(base_class)
return template