diff --git a/src/backend/base/langflow/custom/code_parser/code_parser.py b/src/backend/base/langflow/custom/code_parser/code_parser.py index 705e779f4..9dc736dc0 100644 --- a/src/backend/base/langflow/custom/code_parser/code_parser.py +++ b/src/backend/base/langflow/custom/code_parser/code_parser.py @@ -1,10 +1,9 @@ import ast import inspect -import operator import traceback from typing import Any, Dict, List, Type, Union -from cachetools import TTLCache, cachedmethod, keys +from cachetools import TTLCache, keys from fastapi import HTTPException from loguru import logger @@ -22,6 +21,32 @@ def get_data_type(): return Data +def find_class_ast_node(class_obj): + """Finds the AST node corresponding to the given class object.""" + # Get the source file where the class is defined + source_file = inspect.getsourcefile(class_obj) + if not source_file: + return None, [] + + # Read the source code from the file + with open(source_file, "r") as file: + source_code = file.read() + + # Parse the source code into an AST + tree = ast.parse(source_code) + + # Search for the class definition node in the AST + class_node = None + import_nodes = [] + for node in ast.walk(tree): + if isinstance(node, ast.ClassDef) and node.name == class_obj.__name__: + class_node = node + elif isinstance(node, (ast.Import, ast.ImportFrom)): + import_nodes.append(node) + + return class_node, import_nodes + + def imports_key(*args, **kwargs): imports = kwargs.pop("imports") key = keys.methodkey(*args, **kwargs) @@ -114,7 +139,7 @@ class CodeParser: arg_dict["type"] = ast.unparse(arg.annotation) return arg_dict - @cachedmethod(operator.attrgetter("cache")) + # @cachedmethod(operator.attrgetter("cache")) def construct_eval_env(self, return_type_str: str, imports) -> dict: """ Constructs an evaluation environment with the necessary imports for the return type, @@ -136,7 +161,6 @@ class CodeParser: exec(f"import {module} as {alias if alias else module}", eval_env) return eval_env - @cachedmethod(cache=operator.attrgetter("cache")) def parse_callable_details(self, node: ast.FunctionDef) -> Dict[str, Any]: """ Extracts details from a single function or method node. @@ -157,7 +181,7 @@ class CodeParser: doc=ast.get_docstring(node), args=self.parse_function_args(node), body=self.parse_function_body(node), - return_type=return_type or get_data_type(), + return_type=return_type, has_return=self.parse_return_statement(node), ) @@ -297,7 +321,6 @@ class CodeParser: bases = self.execute_and_inspect_classes(self.code) except Exception as e: # If the code cannot be executed, return an empty list - logger.debug(e) bases = [] raise e return bases @@ -306,16 +329,37 @@ class CodeParser: """ Extracts "classes" from the code, including inheritance and init methods. """ - bases = self.get_base_classes() or [ast.unparse(b) for b in node.bases] + if node.name in ["CustomComponent", "Component", "BaseComponent"]: + return + bases = self.get_base_classes() + nodes = [] + for base in bases: + if base.__name__ == node.name or base.__name__ in ["CustomComponent", "Component", "BaseComponent"]: + continue + try: + class_node, import_nodes = find_class_ast_node(base) + if class_node is None: + continue + for import_node in import_nodes: + self.parse_imports(import_node) + nodes.append(class_node) + except Exception as exc: + logger.error(f"Error finding base class node: {exc}") + pass + nodes.insert(0, node) class_details = ClassCodeDetails( name=node.name, doc=ast.get_docstring(node), - bases=bases, + bases=[b.__name__ for b in bases], attributes=[], methods=[], init=None, ) + for node in nodes: + self.process_class_node(node, class_details) + self.data["classes"].append(class_details.model_dump()) + def process_class_node(self, node, class_details): for stmt in node.body: if isinstance(stmt, ast.Assign): if attr := self.parse_assign(stmt): @@ -330,8 +374,6 @@ class CodeParser: else: class_details.methods.append(method) - self.data["classes"].append(class_details.model_dump()) - def parse_global_vars(self, node: ast.Assign) -> None: """ Extracts global variables from the code. @@ -349,9 +391,9 @@ class CodeParser: # Get the base classes at two levels of inheritance bases = [] for base in dunder_class.__bases__: - bases.append(base.__name__) + bases.append(base) for bases_base in base.__bases__: - bases.append(bases_base.__name__) + bases.append(bases_base) return bases def parse_code(self) -> Dict[str, Any]: