Refactor code for CustomComponent class and build_langchain_template_custom_component function
The code in the CustomComponent class has been refactored to improve readability and maintainability. The `_handle_function` method now handles function arguments correctly. Additionally, the `is_valid_class_template` method has been renamed to `_class_template_validation`, and it now raises an HTTPException with a detailed error message when the main class or the build function contains invalid information. In the `build_langchain_template_custom_component` function, base classes are now retrieved from the `return_type` and added to the `template.base_classes` list. A try-except block is used to handle possible KeyError or AttributeError exceptions, and an HTTPException is raised with the corresponding error message and traceback if an error occurs. These changes ensure more accurate validation and handle potential errors more gracefully.
This commit is contained in:
parent
c9a2ba5821
commit
f92fefba46
2 changed files with 35 additions and 10 deletions
|
|
@ -185,10 +185,16 @@ class CustomComponent(BaseModel):
|
|||
|
||||
return function_args, return_type
|
||||
|
||||
def is_valid_class_template(self, code: dict):
|
||||
def _class_template_validation(self, code: dict):
|
||||
class_name = code.get("class", {}).get("name", None)
|
||||
if not class_name: # this will also check for None, empty string, etc.
|
||||
return False
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail={
|
||||
"error": "The main class must have a valid name.",
|
||||
"traceback": "",
|
||||
},
|
||||
)
|
||||
|
||||
functions = code.get("functions", [])
|
||||
if build_function := next(
|
||||
|
|
@ -198,7 +204,13 @@ class CustomComponent(BaseModel):
|
|||
# Check if the return type of the build function is valid
|
||||
return build_function.get("return_type") in self.return_type_valid_list
|
||||
else:
|
||||
return False
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail={
|
||||
"error": f"The class return [{str(build_function.get('return_type'))}] needs to be an item from this list. [{str(self.return_type_valid_list)}]",
|
||||
"traceback": "",
|
||||
},
|
||||
)
|
||||
|
||||
def get_function(self):
|
||||
return validate.create_function(self.code, self.function_entrypoint_name)
|
||||
|
|
@ -209,7 +221,7 @@ class CustomComponent(BaseModel):
|
|||
|
||||
@property
|
||||
def is_valid(self):
|
||||
return self.is_valid_class_template(self.data)
|
||||
return self._class_template_validation(self.data)
|
||||
|
||||
@property
|
||||
def args_and_return_type(self):
|
||||
|
|
|
|||
|
|
@ -18,6 +18,13 @@ from langflow.template.field.base import TemplateField
|
|||
from langflow.template.frontend_node.tools import CustomComponentNode
|
||||
from langflow.interface.retrievers.base import retriever_creator
|
||||
|
||||
from langflow.utils.util import get_base_classes
|
||||
|
||||
from fastapi import HTTPException
|
||||
import traceback
|
||||
|
||||
# Used to get the base_classes list
|
||||
|
||||
|
||||
def get_type_list():
|
||||
"""Get a list of all langchain types"""
|
||||
|
|
@ -117,12 +124,18 @@ def build_langchain_template_custom_component(extractor: CustomComponent):
|
|||
|
||||
template = add_code_field(template, raw_code)
|
||||
|
||||
# TODO: Get base classes from "return_type" and add to template.base_classes
|
||||
template.get("base_classes").append("ConversationChain")
|
||||
template.get("base_classes").append("LLMChain")
|
||||
template.get("base_classes").append("Chain")
|
||||
template.get("base_classes").append("Serializable")
|
||||
template.get("base_classes").append("function")
|
||||
# Get base classes from "return_type" and add to template.base_classes
|
||||
try:
|
||||
return_type_instance = globals()[return_type]
|
||||
base_classes = get_base_classes(return_type_instance)
|
||||
except (KeyError, AttributeError) as err:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail={"error": type(err).__name__, "traceback": traceback.format_exc()},
|
||||
) from err
|
||||
|
||||
for base_class in base_classes:
|
||||
template.get("base_classes").append(base_class)
|
||||
|
||||
return template
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue