🔥 refactor(custom.py): remove unused imports and commented out code
🚀 feat(custom.py): refactor CustomComponent class to remove unused code and improve code organization
The changes in this commit remove unused imports and commented out code from the `custom.py` file. The `CustomComponent` class has been refactored to remove the `CustomComponent_old` class and unused methods. The code has been reorganized to improve readability and maintainability.
This commit is contained in:
parent
97572bea25
commit
2775789ccb
1 changed files with 0 additions and 162 deletions
|
|
@ -1,6 +1,3 @@
|
|||
import ast
|
||||
import traceback
|
||||
|
||||
from typing import Callable, Optional
|
||||
from langflow.interface.importing.utils import get_function
|
||||
|
||||
|
|
@ -9,8 +6,6 @@ from pydantic import BaseModel, validator
|
|||
from langflow.utils import validate
|
||||
from langchain.agents.tools import Tool
|
||||
|
||||
from fastapi import HTTPException
|
||||
|
||||
|
||||
class Function(BaseModel):
|
||||
code: str
|
||||
|
|
@ -79,160 +74,3 @@ class CustomComponent_old(BaseModel):
|
|||
function_name = validate.extract_function_name(self.code)
|
||||
|
||||
return validate.create_function(self.code, function_name)
|
||||
|
||||
|
||||
class CustomComponent(BaseModel):
|
||||
code: str
|
||||
function: Optional[Callable] = None
|
||||
function_entrypoint_name = "build"
|
||||
return_type_valid_list = ["ConversationChain", "BaseLLM", "Tool"]
|
||||
class_template = {
|
||||
"imports": [],
|
||||
"class": {"inherited_classes": "", "name": "", "init": ""},
|
||||
"functions": [],
|
||||
}
|
||||
|
||||
def __init__(self, **data):
|
||||
super().__init__(**data)
|
||||
|
||||
def _handle_import(self, node):
|
||||
for alias in node.names:
|
||||
module_name = getattr(node, "module", None)
|
||||
self.class_template["imports"].append(
|
||||
f"{module_name}.{alias.name}" if module_name else alias.name
|
||||
)
|
||||
|
||||
def _handle_class(self, node):
|
||||
self.class_template["class"].update(
|
||||
{
|
||||
"name": node.name,
|
||||
"inherited_classes": [ast.unparse(base) for base in node.bases],
|
||||
}
|
||||
)
|
||||
|
||||
for inner_node in node.body:
|
||||
if isinstance(inner_node, ast.FunctionDef):
|
||||
self._handle_function(inner_node)
|
||||
|
||||
def _handle_function(self, node):
|
||||
function_name = node.name
|
||||
function_args_str = ast.unparse(node.args)
|
||||
function_args = function_args_str.split(", ") if function_args_str else []
|
||||
|
||||
return_type = ast.unparse(node.returns) if node.returns else "None"
|
||||
|
||||
function_data = {
|
||||
"name": function_name,
|
||||
"arguments": function_args,
|
||||
"return_type": return_type,
|
||||
}
|
||||
|
||||
if function_name == "__init__":
|
||||
self.class_template["class"]["init"] = (
|
||||
function_args_str.split(", ") if function_args_str else []
|
||||
)
|
||||
else:
|
||||
self.class_template["functions"].append(function_data)
|
||||
|
||||
def transform_list(self, input_list):
|
||||
output_list = []
|
||||
for item in input_list:
|
||||
# Split each item on ':' to separate variable name and type
|
||||
split_item = item.split(":")
|
||||
|
||||
# If there is a type, strip any leading/trailing spaces from it
|
||||
if len(split_item) > 1:
|
||||
split_item[1] = split_item[1].strip()
|
||||
# If there isn't a type, append None
|
||||
else:
|
||||
split_item.append(None)
|
||||
output_list.append(split_item)
|
||||
|
||||
return output_list
|
||||
|
||||
def extract_class_info(self):
|
||||
try:
|
||||
module = ast.parse(self.code)
|
||||
except SyntaxError as err:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail={"error": err.msg, "traceback": traceback.format_exc()},
|
||||
) from err
|
||||
|
||||
for node in module.body:
|
||||
if isinstance(node, (ast.Import, ast.ImportFrom)):
|
||||
self._handle_import(node)
|
||||
elif isinstance(node, ast.ClassDef):
|
||||
self._handle_class(node)
|
||||
|
||||
return self.class_template
|
||||
|
||||
def get_entrypoint_function_args_and_return_type(self):
|
||||
data = self.extract_class_info()
|
||||
functions = data.get("functions", [])
|
||||
|
||||
if build_function := next(
|
||||
(f for f in functions if f["name"] == self.function_entrypoint_name),
|
||||
None,
|
||||
):
|
||||
function_args = build_function.get("arguments", None)
|
||||
function_args = self.transform_list(function_args)
|
||||
|
||||
return_type = build_function.get("return_type", None)
|
||||
else:
|
||||
function_args = None
|
||||
return_type = None
|
||||
|
||||
return function_args, return_type
|
||||
|
||||
def _class_template_validation(self, code: dict):
|
||||
class_name = code.get("class", {}).get("name", None)
|
||||
if not class_name: # this will also check for None, empty string, etc.
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail={
|
||||
"error": "The main class must have a valid name.",
|
||||
"traceback": "",
|
||||
},
|
||||
)
|
||||
|
||||
functions = code.get("functions", [])
|
||||
build_function = next(
|
||||
(f for f in functions if f["name"] == self.function_entrypoint_name),
|
||||
None,
|
||||
)
|
||||
|
||||
if not build_function:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail={
|
||||
"error": "Invalid entrypoint function name",
|
||||
"traceback": f"There needs to be at least one entrypoint function named '{self.function_entrypoint_name}' and it needs to return one of the types from this list {str(self.return_type_valid_list)}.",
|
||||
},
|
||||
)
|
||||
|
||||
return_type = build_function.get("return_type")
|
||||
if return_type not in self.return_type_valid_list:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail={
|
||||
"error": "Invalid entrypoint function return",
|
||||
"traceback": f"The entrypoint function return '{return_type}' needs to be an item from this list {str(self.return_type_valid_list)}.",
|
||||
},
|
||||
)
|
||||
|
||||
return True
|
||||
|
||||
def get_function(self):
|
||||
return validate.create_function(self.code, self.function_entrypoint_name)
|
||||
|
||||
@property
|
||||
def data(self):
|
||||
return self.extract_class_info()
|
||||
|
||||
def is_check_valid(self):
|
||||
return self._class_template_validation(self.data)
|
||||
|
||||
@property
|
||||
def args_and_return_type(self):
|
||||
return self.get_entrypoint_function_args_and_return_type()
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue