feat(custom.py): add CustomComponent class to handle custom code components

🐛 fix(custom.py): fix typo in function_entrypoint_name variable assignment
The CustomComponent class is added to handle custom code components. It includes methods to handle imports, classes, and functions in the provided code. The class also has methods to extract class information, get entrypoint function arguments and return type, build a template configuration, validate the class template, and get the entrypoint function. A typo in the assignment of the function_entrypoint_name variable is fixed.
This commit is contained in:
Gabriel Luiz Freitas Almeida 2023-07-06 23:57:04 -03:00
commit 97399632e2

View file

@ -0,0 +1,192 @@
import ast
import traceback
from typing import Callable, Optional
from fastapi import HTTPException
from langflow.interface.custom.constants import LANGCHAIN_BASE_TYPES
from langflow.utils import validate
from pydantic import BaseModel
class CustomComponent(BaseModel):
field_config: dict = {}
code: str
function: Optional[Callable] = None
function_entrypoint_name = "build"
return_type_valid_list = list(LANGCHAIN_BASE_TYPES.keys())
class_template = {
"imports": [],
"class": {"inherited_classes": "", "name": "", "init": "", "attributes": {}},
"functions": [],
}
def __init__(self, **data):
super().__init__(**data)
def _handle_import(self, node):
for alias in node.names:
module_name = getattr(node, "module", None)
self.class_template["imports"].append(
f"{module_name}.{alias.name}" if module_name else alias.name
)
def _handle_class(self, node):
self.class_template["class"].update(
{
"name": node.name,
"inherited_classes": [ast.unparse(base) for base in node.bases],
}
)
attributes = {} # To store the attributes and their values
for inner_node in node.body:
if isinstance(inner_node, ast.Assign): # An assignment
for target in inner_node.targets: # Targets of the assignment
if isinstance(target, ast.Name): # A simple variable
# Add the attribute and its value to the dictionary
attributes[target.id] = ast.unparse(inner_node.value)
elif isinstance(inner_node, ast.FunctionDef):
self._handle_function(inner_node)
# You can add these attributes to your class_template if you want
self.class_template["class"]["attributes"] = attributes
def _handle_function(self, node):
function_name = node.name
function_args_str = ast.unparse(node.args)
function_args = function_args_str.split(", ") if function_args_str else []
return_type = ast.unparse(node.returns) if node.returns else "None"
function_data = {
"name": function_name,
"arguments": function_args,
"return_type": return_type,
}
if function_name == "__init__":
self.class_template["class"]["init"] = (
function_args_str.split(", ") if function_args_str else []
)
else:
self.class_template["functions"].append(function_data)
def transform_list(self, input_list):
output_list = []
for item in input_list:
# Split each item on ':' to separate variable name and type
split_item = item.split(":")
# If there is a type, strip any leading/trailing spaces from it
if len(split_item) > 1:
split_item[1] = split_item[1].strip()
# If there isn't a type, append None
else:
split_item.append(None)
output_list.append(split_item)
return output_list
def extract_class_info(self):
try:
module = ast.parse(self.code)
except SyntaxError as err:
raise HTTPException(
status_code=400,
detail={"error": err.msg, "traceback": traceback.format_exc()},
) from err
for node in module.body:
if isinstance(node, (ast.Import, ast.ImportFrom)):
self._handle_import(node)
elif isinstance(node, ast.ClassDef):
self._handle_class(node)
return self.class_template
def get_entrypoint_function_args_and_return_type(self):
data = self.extract_class_info()
attributes = data.get("class", {}).get("attributes", {})
functions = data.get("functions", [])
template_config = self._build_template_config(attributes)
if build_function := next(
(f for f in functions if f["name"] == self.function_entrypoint_name),
None,
):
function_args = build_function.get("arguments", None)
function_args = self.transform_list(function_args)
return_type = build_function.get("return_type", None)
else:
function_args = None
return_type = None
return function_args, return_type, template_config
def _build_template_config(self, attributes):
template_config = {}
if "field_config" in attributes:
template_config["field_config"] = ast.literal_eval(
attributes["field_config"]
)
if "display_name" in attributes:
template_config["display_name"] = ast.literal_eval(
attributes["display_name"]
)
return template_config
def _class_template_validation(self, code: dict):
class_name = code.get("class", {}).get("name", None)
if not class_name: # this will also check for None, empty string, etc.
raise HTTPException(
status_code=400,
detail={
"error": "The main class must have a valid name.",
"traceback": "",
},
)
functions = code.get("functions", [])
build_function = next(
(f for f in functions if f["name"] == self.function_entrypoint_name),
None,
)
if not build_function:
raise HTTPException(
status_code=400,
detail={
"error": "Invalid entrypoint function name",
"traceback": f"There needs to be at least one entrypoint function named '{self.function_entrypoint_name}' and it needs to return one of the types from this list {str(self.return_type_valid_list)}.",
},
)
return_type = build_function.get("return_type")
if return_type not in self.return_type_valid_list:
raise HTTPException(
status_code=400,
detail={
"error": "Invalid entrypoint function return",
"traceback": f"The entrypoint function return '{return_type}' needs to be an item from this list {str(self.return_type_valid_list)}.",
},
)
return True
def get_function(self):
return validate.create_function(self.code, self.function_entrypoint_name)
def build(self):
pass
@property
def data(self):
return self.extract_class_info()
def is_check_valid(self):
return self._class_template_validation(self.data)
@property
def args_and_return_type(self):
return self.get_entrypoint_function_args_and_return_type()