🔥 refactor(custom.py): remove unused imports and commented out code

🚀 feat(custom.py): refactor CustomComponent class to remove unused code and improve code organization
The changes in this commit remove unused imports and commented out code from the `custom.py` file. The `CustomComponent` class has been refactored to remove the `CustomComponent_old` class and unused methods. The code has been reorganized to improve readability and maintainability.
This commit is contained in:
Gabriel Luiz Freitas Almeida 2023-07-06 23:57:28 -03:00
commit 2775789ccb

View file

@ -1,6 +1,3 @@
import ast
import traceback
from typing import Callable, Optional
from langflow.interface.importing.utils import get_function
@ -9,8 +6,6 @@ from pydantic import BaseModel, validator
from langflow.utils import validate
from langchain.agents.tools import Tool
from fastapi import HTTPException
class Function(BaseModel):
code: str
@ -79,160 +74,3 @@ class CustomComponent_old(BaseModel):
function_name = validate.extract_function_name(self.code)
return validate.create_function(self.code, function_name)
class CustomComponent(BaseModel):
code: str
function: Optional[Callable] = None
function_entrypoint_name = "build"
return_type_valid_list = ["ConversationChain", "BaseLLM", "Tool"]
class_template = {
"imports": [],
"class": {"inherited_classes": "", "name": "", "init": ""},
"functions": [],
}
def __init__(self, **data):
super().__init__(**data)
def _handle_import(self, node):
for alias in node.names:
module_name = getattr(node, "module", None)
self.class_template["imports"].append(
f"{module_name}.{alias.name}" if module_name else alias.name
)
def _handle_class(self, node):
self.class_template["class"].update(
{
"name": node.name,
"inherited_classes": [ast.unparse(base) for base in node.bases],
}
)
for inner_node in node.body:
if isinstance(inner_node, ast.FunctionDef):
self._handle_function(inner_node)
def _handle_function(self, node):
function_name = node.name
function_args_str = ast.unparse(node.args)
function_args = function_args_str.split(", ") if function_args_str else []
return_type = ast.unparse(node.returns) if node.returns else "None"
function_data = {
"name": function_name,
"arguments": function_args,
"return_type": return_type,
}
if function_name == "__init__":
self.class_template["class"]["init"] = (
function_args_str.split(", ") if function_args_str else []
)
else:
self.class_template["functions"].append(function_data)
def transform_list(self, input_list):
output_list = []
for item in input_list:
# Split each item on ':' to separate variable name and type
split_item = item.split(":")
# If there is a type, strip any leading/trailing spaces from it
if len(split_item) > 1:
split_item[1] = split_item[1].strip()
# If there isn't a type, append None
else:
split_item.append(None)
output_list.append(split_item)
return output_list
def extract_class_info(self):
try:
module = ast.parse(self.code)
except SyntaxError as err:
raise HTTPException(
status_code=400,
detail={"error": err.msg, "traceback": traceback.format_exc()},
) from err
for node in module.body:
if isinstance(node, (ast.Import, ast.ImportFrom)):
self._handle_import(node)
elif isinstance(node, ast.ClassDef):
self._handle_class(node)
return self.class_template
def get_entrypoint_function_args_and_return_type(self):
data = self.extract_class_info()
functions = data.get("functions", [])
if build_function := next(
(f for f in functions if f["name"] == self.function_entrypoint_name),
None,
):
function_args = build_function.get("arguments", None)
function_args = self.transform_list(function_args)
return_type = build_function.get("return_type", None)
else:
function_args = None
return_type = None
return function_args, return_type
def _class_template_validation(self, code: dict):
class_name = code.get("class", {}).get("name", None)
if not class_name: # this will also check for None, empty string, etc.
raise HTTPException(
status_code=400,
detail={
"error": "The main class must have a valid name.",
"traceback": "",
},
)
functions = code.get("functions", [])
build_function = next(
(f for f in functions if f["name"] == self.function_entrypoint_name),
None,
)
if not build_function:
raise HTTPException(
status_code=400,
detail={
"error": "Invalid entrypoint function name",
"traceback": f"There needs to be at least one entrypoint function named '{self.function_entrypoint_name}' and it needs to return one of the types from this list {str(self.return_type_valid_list)}.",
},
)
return_type = build_function.get("return_type")
if return_type not in self.return_type_valid_list:
raise HTTPException(
status_code=400,
detail={
"error": "Invalid entrypoint function return",
"traceback": f"The entrypoint function return '{return_type}' needs to be an item from this list {str(self.return_type_valid_list)}.",
},
)
return True
def get_function(self):
return validate.create_function(self.code, self.function_entrypoint_name)
@property
def data(self):
return self.extract_class_info()
def is_check_valid(self):
return self._class_template_validation(self.data)
@property
def args_and_return_type(self):
return self.get_entrypoint_function_args_and_return_type()