diff --git a/src/backend/langflow/interface/custom/code_parser.py b/src/backend/langflow/interface/custom/code_parser.py index 7fb3b0184..606208624 100644 --- a/src/backend/langflow/interface/custom/code_parser.py +++ b/src/backend/langflow/interface/custom/code_parser.py @@ -2,39 +2,15 @@ import ast import inspect import traceback -from typing import Dict, Any, Optional, Type, Union +from typing import Dict, Any, List, Type, Union from fastapi import HTTPException -from pydantic import BaseModel +from langflow.interface.custom.schema import CallableCodeDetails, ClassCodeDetails class CodeSyntaxError(HTTPException): pass -class CallableCodeDetails(BaseModel): - """ - A dataclass for storing details about a callable. - """ - - name: str - doc: Optional[str] - args: list - body: list - return_type: Optional[str] - - -class ClassCodeDetails(BaseModel): - """ - A dataclass for storing details about a class. - """ - - name: str - doc: str - bases: list - attributes: list - methods: list - - class CodeParser: """ A parser for Python source code, extracting code details. @@ -79,13 +55,13 @@ class CodeParser: return tree - def parse_node(self, node: ast.AST) -> None: + def parse_node(self, node: Union[ast.stmt, ast.AST]) -> None: """ Parses an AST node and updates the data dictionary with the relevant information. """ - if handler := self.handlers.get(type(node)): - handler(node) + if handler := self.handlers.get(type(node)): # type: ignore + handler(node) # type: ignore def parse_imports(self, node: Union[ast.Import, ast.ImportFrom]) -> None: """ @@ -117,13 +93,6 @@ class CodeParser: """ Extracts details from a single function or method node. """ - # func = { - # "name": node.name, - # "doc": ast.get_docstring(node), - # "args": [], - # "body": [], - # "return_type": ast.unparse(node.returns) if node.returns else None, - # } func = CallableCodeDetails( name=node.name, doc=ast.get_docstring(node), @@ -132,19 +101,58 @@ class CodeParser: return_type=ast.unparse(node.returns) if node.returns else None, ) - # Handle positional arguments with default values - defaults = [None] * (len(node.args.args) - len(node.args.defaults)) + [ + func.args = self.parse_function_args(node) + func.body = self.parse_function_body(node) + + return func.dict() + + def parse_function_args(self, node: ast.FunctionDef) -> List[Dict[str, Any]]: + """ + Parses the arguments of a function or method node. + """ + args = [] + + args += self.parse_positional_args(node) + args += self.parse_varargs(node) + args += self.parse_keyword_args(node) + args += self.parse_kwargs(node) + + return args + + def parse_positional_args(self, node: ast.FunctionDef) -> List[Dict[str, Any]]: + """ + Parses the positional arguments of a function or method node. + """ + num_args = len(node.args.args) + num_defaults = len(node.args.defaults) + num_missing_defaults = num_args - num_defaults + missing_defaults = [None] * num_missing_defaults + default_values = [ ast.unparse(default) if default else None for default in node.args.defaults ] + defaults = missing_defaults + default_values - for arg, default in zip(node.args.args, defaults): - func.args.append(self.parse_arg(arg, default)) + args = [ + self.parse_arg(arg, default) + for arg, default in zip(node.args.args, defaults) + ] + return args + + def parse_varargs(self, node: ast.FunctionDef) -> List[Dict[str, Any]]: + """ + Parses the *args argument of a function or method node. + """ + args = [] - # Handle *args if node.args.vararg: - func.args.append(self.parse_arg(node.args.vararg, None)) + args.append(self.parse_arg(node.args.vararg, None)) - # Handle keyword-only arguments with default values + return args + + def parse_keyword_args(self, node: ast.FunctionDef) -> List[Dict[str, Any]]: + """ + Parses the keyword-only arguments of a function or method node. + """ kw_defaults = [None] * ( len(node.args.kwonlyargs) - len(node.args.kw_defaults) ) + [ @@ -152,16 +160,28 @@ class CodeParser: for default in node.args.kw_defaults ] - for arg, default in zip(node.args.kwonlyargs, kw_defaults): - func.args.append(self.parse_arg(arg, default)) + args = [ + self.parse_arg(arg, default) + for arg, default in zip(node.args.kwonlyargs, kw_defaults) + ] + return args + + def parse_kwargs(self, node: ast.FunctionDef) -> List[Dict[str, Any]]: + """ + Parses the **kwargs argument of a function or method node. + """ + args = [] - # Handle **kwargs if node.args.kwarg: - func.args.append(self.parse_arg(node.args.kwarg, None)) + args.append(self.parse_arg(node.args.kwarg, None)) - for line in node.body: - func.body.append(ast.unparse(line)) - return func.dict() + return args + + def parse_function_body(self, node: ast.FunctionDef) -> List[str]: + """ + Parses the body of a function or method node. + """ + return [ast.unparse(line) for line in node.body] def parse_assign(self, stmt): """ @@ -196,29 +216,31 @@ class CodeParser: """ Extracts "classes" from the code, including inheritance and init methods. """ - class_dict = { - "name": node.name, - "doc": ast.get_docstring(node), - "bases": [ast.unparse(base) for base in node.bases], - "attributes": [], - "methods": [], - } + + class_details = ClassCodeDetails( + name=node.name, + doc=ast.get_docstring(node), + bases=[ast.unparse(base) for base in node.bases], + attributes=[], + methods=[], + init=None, + ) for stmt in node.body: if isinstance(stmt, ast.Assign): if attr := self.parse_assign(stmt): - class_dict["attributes"].append(attr) + class_details.attributes.append(attr) elif isinstance(stmt, ast.AnnAssign): if attr := self.parse_ann_assign(stmt): - class_dict["attributes"].append(attr) + class_details.attributes.append(attr) elif isinstance(stmt, ast.FunctionDef): method, is_init = self.parse_function_def(stmt) if is_init: - class_dict["init"] = method + class_details.init = method else: - class_dict["methods"].append(method) + class_details.methods.append(method) - self.data["classes"].append(class_dict) + self.data["classes"].append(class_details.dict()) def parse_global_vars(self, node: ast.Assign) -> None: """