diff --git a/src/backend/langflow/interface/custom/code_parser/code_parser.py b/src/backend/langflow/interface/custom/code_parser/code_parser.py index 7a102e33f..e54051a5c 100644 --- a/src/backend/langflow/interface/custom/code_parser/code_parser.py +++ b/src/backend/langflow/interface/custom/code_parser/code_parser.py @@ -6,6 +6,9 @@ from typing import Any, Dict, List, Type, Union from cachetools import TTLCache, cachedmethod, keys from fastapi import HTTPException +from loguru import logger + +from langflow.interface.custom.eval import eval_custom_component_code from langflow.interface.custom.schema import CallableCodeDetails, ClassCodeDetails @@ -92,7 +95,9 @@ class CodeParser: elif isinstance(node, ast.ImportFrom): for alias in node.names: if alias.asname: - self.data["imports"].append((node.module, f"{alias.name} as {alias.asname}")) + self.data["imports"].append( + (node.module, f"{alias.name} as {alias.asname}") + ) else: self.data["imports"].append((node.module, alias.name)) @@ -141,7 +146,9 @@ class CodeParser: return_type = None if node.returns: return_type_str = ast.unparse(node.returns) - eval_env = self.construct_eval_env(return_type_str, tuple(self.data["imports"])) + eval_env = self.construct_eval_env( + return_type_str, tuple(self.data["imports"]) + ) try: return_type = eval(return_type_str, eval_env) @@ -183,14 +190,22 @@ class CodeParser: num_defaults = len(node.args.defaults) num_missing_defaults = num_args - num_defaults missing_defaults = [None] * num_missing_defaults - default_values = [ast.unparse(default).strip("'") if default else None for default in node.args.defaults] + default_values = [ + ast.unparse(default).strip("'") if default else None + for default in node.args.defaults + ] # Now check all default values to see if there # are any "None" values in the middle - default_values = [None if value == "None" else value for value in default_values] + default_values = [ + None if value == "None" else value for value in default_values + ] defaults = missing_defaults + default_values - args = [self.parse_arg(arg, default) for arg, default in zip(node.args.args, defaults)] + args = [ + self.parse_arg(arg, default) + for arg, default in zip(node.args.args, defaults) + ] return args def parse_varargs(self, node: ast.FunctionDef) -> List[Dict[str, Any]]: @@ -208,11 +223,17 @@ class CodeParser: """ Parses the keyword-only arguments of a function or method node. """ - kw_defaults = [None] * (len(node.args.kwonlyargs) - len(node.args.kw_defaults)) + [ - ast.unparse(default) if default else None for default in node.args.kw_defaults + kw_defaults = [None] * ( + len(node.args.kwonlyargs) - len(node.args.kw_defaults) + ) + [ + ast.unparse(default) if default else None + for default in node.args.kw_defaults ] - args = [self.parse_arg(arg, default) for arg, default in zip(node.args.kwonlyargs, kw_defaults)] + args = [ + self.parse_arg(arg, default) + for arg, default in zip(node.args.kwonlyargs, kw_defaults) + ] return args def parse_kwargs(self, node: ast.FunctionDef) -> List[Dict[str, Any]]: @@ -268,15 +289,28 @@ class CodeParser: method = self.parse_callable_details(stmt) return (method, True) if stmt.name == "__init__" else (method, False) + def get_base_classes(self): + """ + Returns the base classes of the custom component class. + """ + try: + bases = self.execute_and_inspect_classes(self.code) + except Exception as e: + # If the code cannot be executed, return an empty list + logger.exception(e) + bases = [] + raise e + return bases + def parse_classes(self, node: ast.ClassDef) -> None: """ Extracts "classes" from the code, including inheritance and init methods. """ - + bases = self.get_base_classes() or [ast.unparse(b) for b in node.bases] class_details = ClassCodeDetails( name=node.name, doc=ast.get_docstring(node), - bases=[ast.unparse(base) for base in node.bases], + bases=bases, attributes=[], methods=[], init=None, @@ -303,11 +337,25 @@ class CodeParser: Extracts global variables from the code. """ global_var = { - "targets": [t.id if hasattr(t, "id") else ast.dump(t) for t in node.targets], + "targets": [ + t.id if hasattr(t, "id") else ast.dump(t) for t in node.targets + ], "value": ast.unparse(node.value), } self.data["global_vars"].append(global_var) + def execute_and_inspect_classes(self, code: str): + custom_component_class = eval_custom_component_code(code) + custom_component = custom_component_class() + dunder_class = custom_component.__class__ + # Get the base classes at two levels of inheritance + bases = [] + for base in dunder_class.__bases__: + bases.append(base.__name__) + for bases_base in base.__bases__: + bases.append(bases_base.__name__) + return bases + def parse_code(self) -> Dict[str, Any]: """ Runs all parsing operations and returns the resulting data.