diff --git a/translator.py b/translator.py index ac7338f..a857fc5 100644 --- a/translator.py +++ b/translator.py @@ -8,7 +8,12 @@ def translate(func): source = inspect.getsource(func) lines = list(source.splitlines()) - # skip first two lines + # skip first two lines + # assume they are + # @translate + # def blahblah(): + assert lines[0].lstrip().startswith('@') + assert lines[1].lstrip().startswith('def') first_line_len = len(lines[2]) indent = first_line_len - len(lines[2].lstrip()) source = '\n'.join(line[indent:] for line in lines[2:]) @@ -38,10 +43,18 @@ with self.ifelse(__CONDITION__) as _ifelse_: _return_template = 'self.ret(__RETURN__)' def load_template(string): + ''' + Since ast.parse() returns a ast.Module node, + it is more useful to trim the Module and get to the first item of body + ''' tree = ast.parse(string) # return a Module return tree.body[0] # get the first item of body class ExpandControlFlow(ast.NodeTransformer): + ''' + Expand control flow contructs. + These are the most tedious thing to do in llvm_cbuilder. + ''' def visit_If(self, node): condition = node.test mapping = { @@ -51,22 +64,29 @@ class ExpandControlFlow(ast.NodeTransformer): } ifelse = load_template(_if_else_template) - ifelse = NameReplacer(mapping).visit(ifelse) + ifelse = MacroExpander(mapping).visit(ifelse) newnode = ast.copy_location(ifelse, node) return self.generic_visit(newnode) def visit_Return(self, node): mapping = {'__RETURN__' : node.value} ret = load_template(_return_template) - repl = NameReplacer(mapping).visit(ret) + repl = MacroExpander(mapping).visit(ret) return ast.copy_location(repl, node) -class NameReplacer(ast.NodeTransformer): +class MacroExpander(ast.NodeTransformer): def __init__(self, mapping): self.mapping = mapping def visit_With(self, node): - if (len(node.body)==1 + ''' + Expand X in the following: + with blah: + X + Nothing should go before or after X. + X must be a list of nodes. + ''' + if (len(node.body)==1 # the body of and isinstance(node.body[0], ast.Expr) and isinstance(node.body[0].value, ast.Name)): try: @@ -76,9 +96,12 @@ class NameReplacer(ast.NodeTransformer): else: old = node.body[0] node.body = repl - return self.generic_visit(node) + return self.generic_visit(node) # recursively apply expand all macros def visit_Name(self, node): + ''' + Expand all Name node to simple value + ''' if type(node.ctx) is ast.Load: try: repl = self.mapping.pop(node.id)