refactor: Update code_parser.py to improve class node processing and base class handling

This commit is contained in:
Gabriel Luiz Freitas Almeida 2024-06-20 15:50:19 -03:00
commit ed391c010d

View file

@ -1,10 +1,9 @@
import ast
import inspect
import operator
import traceback
from typing import Any, Dict, List, Type, Union
from cachetools import TTLCache, cachedmethod, keys
from cachetools import TTLCache, keys
from fastapi import HTTPException
from loguru import logger
@ -22,6 +21,32 @@ def get_data_type():
return Data
def find_class_ast_node(class_obj):
"""Finds the AST node corresponding to the given class object."""
# Get the source file where the class is defined
source_file = inspect.getsourcefile(class_obj)
if not source_file:
return None, []
# Read the source code from the file
with open(source_file, "r") as file:
source_code = file.read()
# Parse the source code into an AST
tree = ast.parse(source_code)
# Search for the class definition node in the AST
class_node = None
import_nodes = []
for node in ast.walk(tree):
if isinstance(node, ast.ClassDef) and node.name == class_obj.__name__:
class_node = node
elif isinstance(node, (ast.Import, ast.ImportFrom)):
import_nodes.append(node)
return class_node, import_nodes
def imports_key(*args, **kwargs):
imports = kwargs.pop("imports")
key = keys.methodkey(*args, **kwargs)
@ -114,7 +139,7 @@ class CodeParser:
arg_dict["type"] = ast.unparse(arg.annotation)
return arg_dict
@cachedmethod(operator.attrgetter("cache"))
# @cachedmethod(operator.attrgetter("cache"))
def construct_eval_env(self, return_type_str: str, imports) -> dict:
"""
Constructs an evaluation environment with the necessary imports for the return type,
@ -136,7 +161,6 @@ class CodeParser:
exec(f"import {module} as {alias if alias else module}", eval_env)
return eval_env
@cachedmethod(cache=operator.attrgetter("cache"))
def parse_callable_details(self, node: ast.FunctionDef) -> Dict[str, Any]:
"""
Extracts details from a single function or method node.
@ -157,7 +181,7 @@ class CodeParser:
doc=ast.get_docstring(node),
args=self.parse_function_args(node),
body=self.parse_function_body(node),
return_type=return_type or get_data_type(),
return_type=return_type,
has_return=self.parse_return_statement(node),
)
@ -297,7 +321,6 @@ class CodeParser:
bases = self.execute_and_inspect_classes(self.code)
except Exception as e:
# If the code cannot be executed, return an empty list
logger.debug(e)
bases = []
raise e
return bases
@ -306,16 +329,37 @@ class CodeParser:
"""
Extracts "classes" from the code, including inheritance and init methods.
"""
bases = self.get_base_classes() or [ast.unparse(b) for b in node.bases]
if node.name in ["CustomComponent", "Component", "BaseComponent"]:
return
bases = self.get_base_classes()
nodes = []
for base in bases:
if base.__name__ == node.name or base.__name__ in ["CustomComponent", "Component", "BaseComponent"]:
continue
try:
class_node, import_nodes = find_class_ast_node(base)
if class_node is None:
continue
for import_node in import_nodes:
self.parse_imports(import_node)
nodes.append(class_node)
except Exception as exc:
logger.error(f"Error finding base class node: {exc}")
pass
nodes.insert(0, node)
class_details = ClassCodeDetails(
name=node.name,
doc=ast.get_docstring(node),
bases=bases,
bases=[b.__name__ for b in bases],
attributes=[],
methods=[],
init=None,
)
for node in nodes:
self.process_class_node(node, class_details)
self.data["classes"].append(class_details.model_dump())
def process_class_node(self, node, class_details):
for stmt in node.body:
if isinstance(stmt, ast.Assign):
if attr := self.parse_assign(stmt):
@ -330,8 +374,6 @@ class CodeParser:
else:
class_details.methods.append(method)
self.data["classes"].append(class_details.model_dump())
def parse_global_vars(self, node: ast.Assign) -> None:
"""
Extracts global variables from the code.
@ -349,9 +391,9 @@ class CodeParser:
# Get the base classes at two levels of inheritance
bases = []
for base in dunder_class.__bases__:
bases.append(base.__name__)
bases.append(base)
for bases_base in base.__bases__:
bases.append(bases_base.__name__)
bases.append(bases_base)
return bases
def parse_code(self) -> Dict[str, Any]: