refactor: Update code_parser.py to improve class node processing and base class handling
This commit is contained in:
parent
8f0ca52e9c
commit
ed391c010d
1 changed files with 54 additions and 12 deletions
|
|
@ -1,10 +1,9 @@
|
|||
import ast
|
||||
import inspect
|
||||
import operator
|
||||
import traceback
|
||||
from typing import Any, Dict, List, Type, Union
|
||||
|
||||
from cachetools import TTLCache, cachedmethod, keys
|
||||
from cachetools import TTLCache, keys
|
||||
from fastapi import HTTPException
|
||||
from loguru import logger
|
||||
|
||||
|
|
@ -22,6 +21,32 @@ def get_data_type():
|
|||
return Data
|
||||
|
||||
|
||||
def find_class_ast_node(class_obj):
|
||||
"""Finds the AST node corresponding to the given class object."""
|
||||
# Get the source file where the class is defined
|
||||
source_file = inspect.getsourcefile(class_obj)
|
||||
if not source_file:
|
||||
return None, []
|
||||
|
||||
# Read the source code from the file
|
||||
with open(source_file, "r") as file:
|
||||
source_code = file.read()
|
||||
|
||||
# Parse the source code into an AST
|
||||
tree = ast.parse(source_code)
|
||||
|
||||
# Search for the class definition node in the AST
|
||||
class_node = None
|
||||
import_nodes = []
|
||||
for node in ast.walk(tree):
|
||||
if isinstance(node, ast.ClassDef) and node.name == class_obj.__name__:
|
||||
class_node = node
|
||||
elif isinstance(node, (ast.Import, ast.ImportFrom)):
|
||||
import_nodes.append(node)
|
||||
|
||||
return class_node, import_nodes
|
||||
|
||||
|
||||
def imports_key(*args, **kwargs):
|
||||
imports = kwargs.pop("imports")
|
||||
key = keys.methodkey(*args, **kwargs)
|
||||
|
|
@ -114,7 +139,7 @@ class CodeParser:
|
|||
arg_dict["type"] = ast.unparse(arg.annotation)
|
||||
return arg_dict
|
||||
|
||||
@cachedmethod(operator.attrgetter("cache"))
|
||||
# @cachedmethod(operator.attrgetter("cache"))
|
||||
def construct_eval_env(self, return_type_str: str, imports) -> dict:
|
||||
"""
|
||||
Constructs an evaluation environment with the necessary imports for the return type,
|
||||
|
|
@ -136,7 +161,6 @@ class CodeParser:
|
|||
exec(f"import {module} as {alias if alias else module}", eval_env)
|
||||
return eval_env
|
||||
|
||||
@cachedmethod(cache=operator.attrgetter("cache"))
|
||||
def parse_callable_details(self, node: ast.FunctionDef) -> Dict[str, Any]:
|
||||
"""
|
||||
Extracts details from a single function or method node.
|
||||
|
|
@ -157,7 +181,7 @@ class CodeParser:
|
|||
doc=ast.get_docstring(node),
|
||||
args=self.parse_function_args(node),
|
||||
body=self.parse_function_body(node),
|
||||
return_type=return_type or get_data_type(),
|
||||
return_type=return_type,
|
||||
has_return=self.parse_return_statement(node),
|
||||
)
|
||||
|
||||
|
|
@ -297,7 +321,6 @@ class CodeParser:
|
|||
bases = self.execute_and_inspect_classes(self.code)
|
||||
except Exception as e:
|
||||
# If the code cannot be executed, return an empty list
|
||||
logger.debug(e)
|
||||
bases = []
|
||||
raise e
|
||||
return bases
|
||||
|
|
@ -306,16 +329,37 @@ class CodeParser:
|
|||
"""
|
||||
Extracts "classes" from the code, including inheritance and init methods.
|
||||
"""
|
||||
bases = self.get_base_classes() or [ast.unparse(b) for b in node.bases]
|
||||
if node.name in ["CustomComponent", "Component", "BaseComponent"]:
|
||||
return
|
||||
bases = self.get_base_classes()
|
||||
nodes = []
|
||||
for base in bases:
|
||||
if base.__name__ == node.name or base.__name__ in ["CustomComponent", "Component", "BaseComponent"]:
|
||||
continue
|
||||
try:
|
||||
class_node, import_nodes = find_class_ast_node(base)
|
||||
if class_node is None:
|
||||
continue
|
||||
for import_node in import_nodes:
|
||||
self.parse_imports(import_node)
|
||||
nodes.append(class_node)
|
||||
except Exception as exc:
|
||||
logger.error(f"Error finding base class node: {exc}")
|
||||
pass
|
||||
nodes.insert(0, node)
|
||||
class_details = ClassCodeDetails(
|
||||
name=node.name,
|
||||
doc=ast.get_docstring(node),
|
||||
bases=bases,
|
||||
bases=[b.__name__ for b in bases],
|
||||
attributes=[],
|
||||
methods=[],
|
||||
init=None,
|
||||
)
|
||||
for node in nodes:
|
||||
self.process_class_node(node, class_details)
|
||||
self.data["classes"].append(class_details.model_dump())
|
||||
|
||||
def process_class_node(self, node, class_details):
|
||||
for stmt in node.body:
|
||||
if isinstance(stmt, ast.Assign):
|
||||
if attr := self.parse_assign(stmt):
|
||||
|
|
@ -330,8 +374,6 @@ class CodeParser:
|
|||
else:
|
||||
class_details.methods.append(method)
|
||||
|
||||
self.data["classes"].append(class_details.model_dump())
|
||||
|
||||
def parse_global_vars(self, node: ast.Assign) -> None:
|
||||
"""
|
||||
Extracts global variables from the code.
|
||||
|
|
@ -349,9 +391,9 @@ class CodeParser:
|
|||
# Get the base classes at two levels of inheritance
|
||||
bases = []
|
||||
for base in dunder_class.__bases__:
|
||||
bases.append(base.__name__)
|
||||
bases.append(base)
|
||||
for bases_base in base.__bases__:
|
||||
bases.append(bases_base.__name__)
|
||||
bases.append(bases_base)
|
||||
return bases
|
||||
|
||||
def parse_code(self) -> Dict[str, Any]:
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue