✨ feat(custom.py): add CustomComponent class to handle custom code components
🐛 fix(custom.py): fix typo in function_entrypoint_name variable assignment
The CustomComponent class is added to handle custom code components. It includes methods to handle imports, classes, and functions in the provided code. The class also has methods to extract class information, get entrypoint function arguments and return type, build a template configuration, validate the class template, and get the entrypoint function. A typo in the assignment of the function_entrypoint_name variable is fixed.
This commit is contained in:
parent
692994f100
commit
97399632e2
1 changed files with 192 additions and 0 deletions
192
src/backend/langflow/interface/custom/custom.py
Normal file
192
src/backend/langflow/interface/custom/custom.py
Normal file
|
|
@ -0,0 +1,192 @@
|
|||
import ast
|
||||
import traceback
|
||||
from typing import Callable, Optional
|
||||
from fastapi import HTTPException
|
||||
from langflow.interface.custom.constants import LANGCHAIN_BASE_TYPES
|
||||
|
||||
from langflow.utils import validate
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class CustomComponent(BaseModel):
|
||||
field_config: dict = {}
|
||||
code: str
|
||||
function: Optional[Callable] = None
|
||||
function_entrypoint_name = "build"
|
||||
return_type_valid_list = list(LANGCHAIN_BASE_TYPES.keys())
|
||||
class_template = {
|
||||
"imports": [],
|
||||
"class": {"inherited_classes": "", "name": "", "init": "", "attributes": {}},
|
||||
"functions": [],
|
||||
}
|
||||
|
||||
def __init__(self, **data):
|
||||
super().__init__(**data)
|
||||
|
||||
def _handle_import(self, node):
|
||||
for alias in node.names:
|
||||
module_name = getattr(node, "module", None)
|
||||
self.class_template["imports"].append(
|
||||
f"{module_name}.{alias.name}" if module_name else alias.name
|
||||
)
|
||||
|
||||
def _handle_class(self, node):
|
||||
self.class_template["class"].update(
|
||||
{
|
||||
"name": node.name,
|
||||
"inherited_classes": [ast.unparse(base) for base in node.bases],
|
||||
}
|
||||
)
|
||||
|
||||
attributes = {} # To store the attributes and their values
|
||||
|
||||
for inner_node in node.body:
|
||||
if isinstance(inner_node, ast.Assign): # An assignment
|
||||
for target in inner_node.targets: # Targets of the assignment
|
||||
if isinstance(target, ast.Name): # A simple variable
|
||||
# Add the attribute and its value to the dictionary
|
||||
attributes[target.id] = ast.unparse(inner_node.value)
|
||||
elif isinstance(inner_node, ast.FunctionDef):
|
||||
self._handle_function(inner_node)
|
||||
|
||||
# You can add these attributes to your class_template if you want
|
||||
self.class_template["class"]["attributes"] = attributes
|
||||
|
||||
def _handle_function(self, node):
|
||||
function_name = node.name
|
||||
function_args_str = ast.unparse(node.args)
|
||||
function_args = function_args_str.split(", ") if function_args_str else []
|
||||
|
||||
return_type = ast.unparse(node.returns) if node.returns else "None"
|
||||
|
||||
function_data = {
|
||||
"name": function_name,
|
||||
"arguments": function_args,
|
||||
"return_type": return_type,
|
||||
}
|
||||
|
||||
if function_name == "__init__":
|
||||
self.class_template["class"]["init"] = (
|
||||
function_args_str.split(", ") if function_args_str else []
|
||||
)
|
||||
else:
|
||||
self.class_template["functions"].append(function_data)
|
||||
|
||||
def transform_list(self, input_list):
|
||||
output_list = []
|
||||
for item in input_list:
|
||||
# Split each item on ':' to separate variable name and type
|
||||
split_item = item.split(":")
|
||||
|
||||
# If there is a type, strip any leading/trailing spaces from it
|
||||
if len(split_item) > 1:
|
||||
split_item[1] = split_item[1].strip()
|
||||
# If there isn't a type, append None
|
||||
else:
|
||||
split_item.append(None)
|
||||
output_list.append(split_item)
|
||||
|
||||
return output_list
|
||||
|
||||
def extract_class_info(self):
|
||||
try:
|
||||
module = ast.parse(self.code)
|
||||
except SyntaxError as err:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail={"error": err.msg, "traceback": traceback.format_exc()},
|
||||
) from err
|
||||
|
||||
for node in module.body:
|
||||
if isinstance(node, (ast.Import, ast.ImportFrom)):
|
||||
self._handle_import(node)
|
||||
elif isinstance(node, ast.ClassDef):
|
||||
self._handle_class(node)
|
||||
|
||||
return self.class_template
|
||||
|
||||
def get_entrypoint_function_args_and_return_type(self):
|
||||
data = self.extract_class_info()
|
||||
attributes = data.get("class", {}).get("attributes", {})
|
||||
functions = data.get("functions", [])
|
||||
template_config = self._build_template_config(attributes)
|
||||
if build_function := next(
|
||||
(f for f in functions if f["name"] == self.function_entrypoint_name),
|
||||
None,
|
||||
):
|
||||
function_args = build_function.get("arguments", None)
|
||||
function_args = self.transform_list(function_args)
|
||||
|
||||
return_type = build_function.get("return_type", None)
|
||||
else:
|
||||
function_args = None
|
||||
return_type = None
|
||||
|
||||
return function_args, return_type, template_config
|
||||
|
||||
def _build_template_config(self, attributes):
|
||||
template_config = {}
|
||||
if "field_config" in attributes:
|
||||
template_config["field_config"] = ast.literal_eval(
|
||||
attributes["field_config"]
|
||||
)
|
||||
if "display_name" in attributes:
|
||||
template_config["display_name"] = ast.literal_eval(
|
||||
attributes["display_name"]
|
||||
)
|
||||
return template_config
|
||||
|
||||
def _class_template_validation(self, code: dict):
|
||||
class_name = code.get("class", {}).get("name", None)
|
||||
if not class_name: # this will also check for None, empty string, etc.
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail={
|
||||
"error": "The main class must have a valid name.",
|
||||
"traceback": "",
|
||||
},
|
||||
)
|
||||
|
||||
functions = code.get("functions", [])
|
||||
build_function = next(
|
||||
(f for f in functions if f["name"] == self.function_entrypoint_name),
|
||||
None,
|
||||
)
|
||||
|
||||
if not build_function:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail={
|
||||
"error": "Invalid entrypoint function name",
|
||||
"traceback": f"There needs to be at least one entrypoint function named '{self.function_entrypoint_name}' and it needs to return one of the types from this list {str(self.return_type_valid_list)}.",
|
||||
},
|
||||
)
|
||||
|
||||
return_type = build_function.get("return_type")
|
||||
if return_type not in self.return_type_valid_list:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail={
|
||||
"error": "Invalid entrypoint function return",
|
||||
"traceback": f"The entrypoint function return '{return_type}' needs to be an item from this list {str(self.return_type_valid_list)}.",
|
||||
},
|
||||
)
|
||||
|
||||
return True
|
||||
|
||||
def get_function(self):
|
||||
return validate.create_function(self.code, self.function_entrypoint_name)
|
||||
|
||||
def build(self):
|
||||
pass
|
||||
|
||||
@property
|
||||
def data(self):
|
||||
return self.extract_class_info()
|
||||
|
||||
def is_check_valid(self):
|
||||
return self._class_template_validation(self.data)
|
||||
|
||||
@property
|
||||
def args_and_return_type(self):
|
||||
return self.get_entrypoint_function_args_and_return_type()
|
||||
Loading…
Add table
Add a link
Reference in a new issue