langflow/src/backend/langflow/interface/tools/custom.py
gustavoschaedler ae5bc1e203 Refactor custom.py for better error handling
- Modifies the `CustomComponent` class to handle errors using `HTTPException` from FastAPI
- Removes the unused `HTTPExceptionWithTraceback` class
- Updates the error response format with `err.msg` and `traceback.format_exc()`
2023-07-05 16:23:32 +01:00

229 lines
6.4 KiB
Python

import ast
import traceback
from typing import Callable, Optional
from langflow.interface.importing.utils import get_function
from pydantic import BaseModel, validator
from langflow.utils import validate
from langchain.agents.tools import Tool
from fastapi import HTTPException
class Function(BaseModel):
    """Pydantic model holding the source code of a single Python function.

    The ``code`` field is passed through ``validate.eval_function`` while the
    model is being validated, and ``get_function`` builds a callable from it
    with ``validate.create_function``.
    """

    code: str
    # Declared for callers; nothing in this class assigns it.
    function: Optional[Callable] = None
    imports: Optional[str] = None

    @validator("code")
    def validate_func(cls, v):
        """Check ``code`` with ``validate.eval_function`` and return it as is.

        Exceptions raised by ``validate.eval_function`` are not caught here,
        so they reach pydantic's validation machinery untouched.
        """
        validate.eval_function(v)
        return v

    def get_function(self):
        """Build and return the callable described by ``code``.

        The function name is taken from the code with
        ``validate.extract_function_name`` and then handed, together with the
        code, to ``validate.create_function``.
        """
        function_name = validate.extract_function_name(self.code)
        return validate.create_function(self.code, function_name)
class PythonFunctionTool(Function, Tool):
    """Tool model that also carries the Python source code backing it."""

    name: str = "Custom Tool"
    description: str
    code: str

    # NOTE(review): this method is spelled with THREE leading underscores, so
    # Python never runs it as the constructor; it executes only when called
    # explicitly. Presumably ``__init__`` was intended -- confirm before
    # renaming, because doing so changes how instances are constructed.
    def ___init__(self, name: str, description: str, code: str):
        """Store the given values, resolve ``func`` and initialise the bases."""
        self.name, self.description, self.code = name, description, code
        # ``get_function`` is the helper imported from
        # ``langflow.interface.importing.utils``, not ``Function.get_function``.
        self.func = get_function(self.code)
        super().__init__(
            name=name,
            description=description,
            func=self.func,
        )
class PythonFunction(Function):
    """``Function`` subclass that restates the required ``code`` field."""

    code: str
class CustomComponent_old(BaseModel):
    """Earlier custom-component model that treats ``code`` as one function.

    The ``code`` field is passed through ``validate.eval_function`` while the
    model is being validated, and ``get_function`` builds a callable from it
    with ``validate.create_function``.
    """

    code: str
    # Declared for callers; nothing in this class assigns it.
    function: Optional[Callable] = None
    imports: Optional[str] = None

    @validator("code")
    def validate_func(cls, v):
        """Check ``code`` with ``validate.eval_function`` and return it as is.

        Exceptions raised by ``validate.eval_function`` are not caught here,
        so they reach pydantic's validation machinery untouched.
        """
        validate.eval_function(v)
        return v

    def get_function(self):
        """Build and return the callable described by ``code``.

        The function name is taken from the code with
        ``validate.extract_function_name`` and then handed, together with the
        code, to ``validate.create_function``.
        """
        function_name = validate.extract_function_name(self.code)
        return validate.create_function(self.code, function_name)
class CustomComponent(BaseModel):
    """Inspect source code that defines a custom component class.

    The code in ``code`` is parsed with ``ast`` (it is not executed by the
    inspection methods) and summarised into ``class_template``: the
    module-level imports, the class name and bases, the ``__init__``
    parameters and every other method with its parameters and return
    annotation. The method named by ``function_entrypoint_name`` is the
    component's entry point.
    """

    code: str
    # Declared for callers; nothing in this class assigns it.
    function: Optional[Callable] = None
    # Name of the method looked up as the entry point.
    function_entrypoint_name = "build"
    # Entry-point return annotations (compared as strings) considered valid.
    return_type_valid_list = ["ConversationChain", "BaseLLM", "Tool"]
    # Container that ``extract_class_info`` fills in and returns.
    class_template = {
        "imports": [],
        "class": {"inherited_classes": "", "name": "", "init": ""},
        "functions": [],
    }

    def _reset_class_template(self):
        """Put ``class_template`` back into its empty initial state.

        The dictionary is emptied in place (fresh inner containers are stored
        under the existing keys) so that results of an earlier parse do not
        leak into the next one.
        """
        self.class_template["imports"] = []
        self.class_template["class"] = {
            "inherited_classes": "",
            "name": "",
            "init": "",
        }
        self.class_template["functions"] = []

    def _handle_import(self, node):
        """Record every name brought in by an ``import``/``from`` node.

        ``from pkg import x`` is stored as ``"pkg.x"``; a plain ``import x``
        (or a ``from`` node without a module) is stored as ``"x"``.
        """
        module_name = getattr(node, "module", None)
        for alias in node.names:
            self.class_template["imports"].append(
                f"{module_name}.{alias.name}" if module_name else alias.name
            )

    def _handle_class(self, node):
        """Record a class' name and bases, then each method it defines."""
        self.class_template["class"].update(
            {
                "name": node.name,
                "inherited_classes": [ast.unparse(base) for base in node.bases],
            }
        )
        for inner_node in node.body:
            if isinstance(inner_node, ast.FunctionDef):
                self._handle_function(inner_node)

    def _argument_strings(self, arguments):
        """Return one source string per parameter of an ``ast.arguments`` node.

        Each parameter is unparsed on its own, in the same textual form that
        ``ast.unparse`` uses for a whole signature (``name: annotation=default``,
        ``/`` and ``*`` markers included). Rendering parameter by parameter
        keeps annotations or defaults that contain ``", "`` -- for example
        ``Dict[str, int]`` -- in one piece, which splitting the unparsed
        signature on ``", "`` does not.
        """

        def render(arg, default=None, prefix=""):
            text = prefix + ast.unparse(arg)
            if default is not None:
                text += "=" + ast.unparse(default)
            return text

        positional = [*arguments.posonlyargs, *arguments.args]
        # Defaults belong to the LAST positional parameters; pad the front.
        missing = len(positional) - len(arguments.defaults)
        defaults = [*([None] * missing), *arguments.defaults]

        rendered = []
        for position, (arg, default) in enumerate(zip(positional, defaults), 1):
            rendered.append(render(arg, default))
            if position == len(arguments.posonlyargs):
                rendered.append("/")

        if arguments.vararg:
            rendered.append(render(arguments.vararg, prefix="*"))
        elif arguments.kwonlyargs:
            # Bare ``*`` separating keyword-only parameters.
            rendered.append("*")

        for arg, default in zip(arguments.kwonlyargs, arguments.kw_defaults):
            rendered.append(render(arg, default))

        if arguments.kwarg:
            rendered.append(render(arguments.kwarg, prefix="**"))
        return rendered

    def _handle_function(self, node):
        """Record one method of the inspected class.

        ``__init__`` parameters are stored under ``class.init``; every other
        method is appended to ``functions`` with its name, parameter strings
        and return annotation (the string ``"None"`` when it has none).
        """
        function_args = self._argument_strings(node.args)
        if node.name == "__init__":
            self.class_template["class"]["init"] = function_args
            return

        return_type = ast.unparse(node.returns) if node.returns else "None"
        self.class_template["functions"].append(
            {
                "name": node.name,
                "arguments": function_args,
                "return_type": return_type,
            }
        )

    def transform_list(self, input_list):
        """Split parameter strings into ``[name, type]`` pairs.

        Each item is split on its FIRST ``:`` only, so annotations or
        defaults that themselves contain a colon stay intact. The type part
        is stripped of surrounding whitespace; it is ``None`` when the item
        has no ``:`` at all.
        """
        output_list = []
        for item in input_list:
            name, separator, annotation = item.partition(":")
            output_list.append([name, annotation.strip() if separator else None])
        return output_list

    def extract_class_info(self):
        """Parse ``code`` and return the filled-in ``class_template``.

        Only top-level imports and class definitions are inspected. The
        template is reset before it is filled, so calling this repeatedly on
        the same instance does not accumulate duplicate entries.

        Raises:
            HTTPException: status 400 with the syntax error message and the
                formatted traceback when ``code`` is not valid Python.
        """
        try:
            module = ast.parse(self.code)
        except SyntaxError as err:
            raise HTTPException(
                status_code=400,
                detail={
                    'error': err.msg,
                    'traceback': traceback.format_exc()
                },
            ) from err

        self._reset_class_template()
        for node in module.body:
            if isinstance(node, (ast.Import, ast.ImportFrom)):
                self._handle_import(node)
            elif isinstance(node, ast.ClassDef):
                self._handle_class(node)
        return self.class_template

    def _find_entrypoint(self, functions):
        """Return the first function entry named like the entry point, or None."""
        return next(
            (f for f in functions if f["name"] == self.function_entrypoint_name),
            None,
        )

    def get_entrypoint_function_args_and_return_type(self):
        """Return ``(arguments, return_type)`` of the entry-point method.

        ``arguments`` is the ``transform_list`` form of the method's
        parameters. Both values are ``None`` when the code defines no method
        named ``function_entrypoint_name``.
        """
        data = self.extract_class_info()
        build_function = self._find_entrypoint(data.get("functions", []))
        if not build_function:
            return None, None

        function_args = self.transform_list(build_function.get("arguments", None))
        return function_args, build_function.get("return_type", None)

    def is_valid_class_template(self, code: dict):
        """Tell whether a class template describes a usable component.

        The template must name a class and contain an entry-point method
        whose return type is listed in ``return_type_valid_list``.
        """
        class_name = code.get("class", {}).get("name", None)
        if not class_name:  # covers None and the empty string alike
            return False

        build_function = self._find_entrypoint(code.get("functions", []))
        if not build_function:
            return False
        # Valid only when the entry point returns one of the accepted types.
        return build_function.get("return_type") in self.return_type_valid_list

    def get_function(self):
        """Create the entry-point callable via ``validate.create_function``."""
        return validate.create_function(
            self.code,
            self.function_entrypoint_name
        )

    @property
    def data(self):
        """Class template freshly extracted from ``code``."""
        return self.extract_class_info()

    @property
    def is_valid(self):
        """Whether the freshly extracted template passes validation."""
        return self.is_valid_class_template(self.data)

    @property
    def args_and_return_type(self):
        """Entry-point ``(arguments, return_type)`` extracted from ``code``."""
        return self.get_entrypoint_function_args_and_return_type()