pykaleidoscope/chapter3.py
2015-01-28 11:43:35 -08:00

405 lines
13 KiB
Python

# Chapter 3 - Code generation to LLVM IR
from collections import namedtuple
from enum import Enum
import llvmlite.ir as ir
# A token pairs a kind (a TokenKind member) with the exact text it was lexed
# from in the input.
class TokenKind(Enum):
    """Categories of token emitted by the Kaleidoscope lexer.

    Each member carries a distinct negative integer value.
    """
    EOF = -1
    DEF = -2
    EXTERN = -3
    IDENTIFIER = -4
    NUMBER = -5
    OPERATOR = -6
Token = namedtuple('Token', ['kind', 'value'])
class Lexer(object):
    """Lexer for Kaleidoscope.

    Initialize the lexer with a string buffer. tokens() returns a generator
    that can be queried for tokens. The generator will emit an EOF token before
    stopping. An empty buffer is valid input and lexes to just the EOF token.
    """

    def __init__(self, buf):
        self.buf = buf
        self.pos = 0
        # lastchar is the character at self.pos, or '' once the buffer is
        # exhausted. '' is falsy, which is what terminates every scanning loop
        # below; starting with '' for an empty buffer makes tokens() emit EOF
        # immediately instead of failing on buf[0].
        self.lastchar = self.buf[0] if self.buf else ''

    def tokens(self):
        """Generate Token tuples for the buffer, finishing with an EOF token."""
        while self.lastchar:
            # Skip whitespace
            while self.lastchar.isspace():
                self._advance()
            # Identifier or keyword: a letter followed by letters/digits.
            if self.lastchar.isalpha():
                id_str = self._consume_while(lambda c: c.isalnum())
                if id_str == 'def':
                    yield Token(kind=TokenKind.DEF, value=id_str)
                elif id_str == 'extern':
                    yield Token(kind=TokenKind.EXTERN, value=id_str)
                else:
                    yield Token(kind=TokenKind.IDENTIFIER, value=id_str)
            # Number: a run of digits and '.'; malformed runs such as '1.2.3'
            # are not rejected here, the text is passed through as-is.
            elif self.lastchar.isdigit() or self.lastchar == '.':
                num_str = self._consume_while(
                    lambda c: c.isdigit() or c == '.')
                yield Token(kind=TokenKind.NUMBER, value=num_str)
            # Comment: skip everything up to the end of the line.
            elif self.lastchar == '#':
                self._advance()
                # The `self.lastchar and` guard matters: '' is "in" any string.
                while self.lastchar and self.lastchar not in '\r\n':
                    self._advance()
            elif self.lastchar:
                # Some other char: emitted as a single-character operator.
                yield Token(kind=TokenKind.OPERATOR, value=self.lastchar)
                self._advance()
        yield Token(kind=TokenKind.EOF, value='')

    def _consume_while(self, pred):
        """Consume characters while pred(lastchar) holds; return them joined.

        Stops at the end of the buffer (lastchar == '').
        """
        chars = []
        while self.lastchar and pred(self.lastchar):
            chars.append(self.lastchar)
            self._advance()
        return ''.join(chars)

    def _advance(self):
        """Move to the next character; lastchar becomes '' past the end."""
        try:
            self.pos += 1
            self.lastchar = self.buf[self.pos]
        except IndexError:
            self.lastchar = ''
# AST hierarchy
class ASTNode(object):
    """Root of the Kaleidoscope AST class hierarchy."""

    def dump(self, indent=0):
        """Return a human-readable dump of the node, indented by `indent`.

        Concrete node classes override this; the base version always raises
        NotImplementedError.
        """
        raise NotImplementedError()
class ExprAST(ASTNode):
    """Common base class for every expression node."""
class NumberExprAST(ExprAST):
    """Numeric literal expression; `val` is stored exactly as given."""

    def __init__(self, val):
        self.val = val

    def dump(self, indent=0):
        """Return a single line: indentation, class name, then [val]."""
        return '{pad}{cls}[{payload}]'.format(
            pad=' ' * indent, cls=self.__class__.__name__, payload=self.val)
class VariableExprAST(ExprAST):
    """Reference to a variable by its `name`."""

    def __init__(self, name):
        self.name = name

    def dump(self, indent=0):
        """Return a single line: indentation, class name, then [name]."""
        return '{pad}{cls}[{payload}]'.format(
            pad=' ' * indent, cls=self.__class__.__name__, payload=self.name)
class BinaryExprAST(ExprAST):
    """Binary operator `op` applied to sub-expressions `lhs` and `rhs`."""

    def __init__(self, op, lhs, rhs):
        self.op = op
        self.lhs = lhs
        self.rhs = rhs

    def dump(self, indent=0):
        """Return a header line for the operator followed by the lhs and rhs
        dumps, each indented two extra spaces (lhs first, no trailing newline).
        """
        child_indent = indent + 2
        lines = [
            '{0}{1}[{2}]'.format(' ' * indent, self.__class__.__name__,
                                 self.op),
            self.lhs.dump(child_indent),
            self.rhs.dump(child_indent),
        ]
        return '\n'.join(lines)
class CallExprAST(ExprAST):
    """Call of the function named `callee` with argument expressions `args`."""

    def __init__(self, callee, args):
        self.callee = callee
        self.args = args

    def dump(self, indent=0):
        """Return a header line naming the callee, then one dump per argument
        indented two extra spaces; lines are newline-separated with no
        trailing newline.
        """
        header = '{0}{1}[{2}]'.format(
            ' ' * indent, self.__class__.__name__, self.callee)
        parts = [header]
        parts.extend(arg.dump(indent + 2) for arg in self.args)
        return '\n'.join(parts)
class PrototypeAST(ASTNode):
    """Function signature: its `name` and the list of argument names."""

    def __init__(self, name, argnames):
        self.name = name
        self.argnames = argnames

    def dump(self, indent=0):
        """Return a single line listing the argument names, comma-separated."""
        arg_list = ', '.join(self.argnames)
        return '{0}{1}[{2}]'.format(' ' * indent, type(self).__name__,
                                    arg_list)
class FunctionAST(ASTNode):
    """Function definition: a prototype `proto` plus a `body` expression."""

    def __init__(self, proto, body):
        self.proto = proto
        self.body = body

    def dump(self, indent=0):
        """Return a header line embedding the prototype dump, then the body
        dump indented two extra spaces.
        """
        header = '{0}{1}[{2}]'.format(
            ' ' * indent, type(self).__name__, self.proto.dump())
        body_text = self.body.dump(indent + 2)
        # Unlike the expression nodes' dumps, this one ends with a newline.
        return header + '\n' + body_text + '\n'
class ParseError(Exception):
    """Raised by the parser when the token stream does not fit the grammar."""
class Parser(object):
    """Recursive-descent parser for Kaleidoscope.

    Construct it with the source text. The text is lexed with Lexer, and the
    current token is kept in cur_tok; the constructor already reads the first
    token. Each call to parse_toplevel() parses one top-level construct and
    leaves cur_tok on the token that follows it.
    """

    def __init__(self, buf):
        self.token_generator = Lexer(buf).tokens()
        self.cur_tok = None
        self._get_next_token()

    # toplevel ::= definition | external | expression | ';'
    def parse_toplevel(self):
        """Parse one top-level construct.

        Returns a PrototypeAST for 'extern', a FunctionAST for 'def', None for
        a lone ';', and otherwise a FunctionAST wrapping the expression in a
        prototype with name '' and no arguments. Raises ParseError on
        malformed input.
        """
        if self.cur_tok.kind == TokenKind.EXTERN:
            return self._parse_external()
        elif self.cur_tok.kind == TokenKind.DEF:
            return self._parse_definition()
        elif self._cur_tok_is_operator(';'):
            self._get_next_token()
            return None
        else:
            return self._parse_toplevel_expression()

    def _get_next_token(self):
        """Advance cur_tok to the next token from the lexer.

        The lexer's generator stops after yielding EOF, so calling this once
        cur_tok is already EOF makes next() raise StopIteration.
        """
        self.cur_tok = next(self.token_generator)

    # Binary operator precedences; a higher number binds more tightly.
    _precedence_map = {'<': 10, '+': 20, '-': 20, '*': 40}

    def _cur_tok_precedence(self):
        """Get the operator precedence of the current token.

        The lookup is by token value only. Values of non-operator tokens
        (keywords, alphanumeric identifiers, numbers, '' for EOF) are not in
        the map, so those tokens yield -1.
        """
        try:
            return Parser._precedence_map[self.cur_tok.value]
        except KeyError:
            return -1

    def _cur_tok_is_operator(self, op):
        """Return True if the current token is the OPERATOR token `op`."""
        return (self.cur_tok.kind == TokenKind.OPERATOR and
                self.cur_tok.value == op)

    # identifierexpr
    #   ::= identifier
    #   ::= identifier '(' expression* ')'
    def _parse_identifier_expr(self):
        """Parse a variable reference or a call; expects cur_tok to be an
        identifier. Arguments are comma-separated expressions.
        """
        id_name = self.cur_tok.value
        self._get_next_token()
        # If followed by a '(' it's a call; otherwise, a simple variable ref.
        if not self._cur_tok_is_operator('('):
            return VariableExprAST(id_name)
        self._get_next_token()  # consume the '('
        args = []
        if not self._cur_tok_is_operator(')'):
            while True:
                args.append(self._parse_expression())
                if self._cur_tok_is_operator(')'):
                    break
                if not self._cur_tok_is_operator(','):
                    raise ParseError('Expected ")" or "," in argument list')
                self._get_next_token()  # consume the ','
        self._get_next_token() # consume the ')'
        return CallExprAST(id_name, args)

    # numberexpr ::= number
    def _parse_number_expr(self):
        """Wrap the current NUMBER token's text in a NumberExprAST."""
        result = NumberExprAST(self.cur_tok.value)
        self._get_next_token() # consume the number
        return result

    # parenexpr ::= '(' expression ')'
    def _parse_paren_expr(self):
        """Parse a parenthesized expression; returns the inner expression
        itself (no node is created for the parentheses).
        """
        self._get_next_token() # consume the '('
        expr = self._parse_expression()
        if not self._cur_tok_is_operator(')'):
            raise ParseError('Expected ")"')
        self._get_next_token() # consume the ')'
        return expr

    # primary
    #   ::= identifierexpr
    #   ::= numberexpr
    #   ::= parenexpr
    def _parse_primary(self):
        """Dispatch on the current token to parse a primary expression.

        Raises ParseError for any token that cannot start an expression
        (including EOF).
        """
        if self.cur_tok.kind == TokenKind.IDENTIFIER:
            return self._parse_identifier_expr()
        elif self.cur_tok.kind == TokenKind.NUMBER:
            return self._parse_number_expr()
        elif self._cur_tok_is_operator('('):
            return self._parse_paren_expr()
        else:
            raise ParseError('Unknown token when expecting an expression')

    # binoprhs ::= (<binop> primary)*
    def _parse_binop_rhs(self, expr_prec, lhs):
        """Parse the right-hand-side of a binary expression.

        expr_prec: minimal precedence to keep going (precedence climbing).
        lhs: AST of the left-hand-side.

        Returns the AST for lhs combined with every operator/operand pair
        that binds at least as tightly as expr_prec. Operators of equal
        precedence are combined left-associatively.
        """
        while True:
            cur_prec = self._cur_tok_precedence()
            # If this is a binary operator with precedence lower than the
            # currently parsed sub-expression, bail out. If it binds at least
            # as tightly, keep going.
            # Note that the precedence of non-operators is defined to be -1,
            # so this condition handles cases when the expression ended.
            if cur_prec < expr_prec:
                return lhs
            op = self.cur_tok.value
            self._get_next_token() # consume the operator
            rhs = self._parse_primary()
            next_prec = self._cur_tok_precedence()
            # There are three options:
            # 1. next_prec > cur_prec: we need to make a recursive call
            # 2. next_prec == cur_prec: no need for a recursive call, the next
            #    iteration of this loop will handle it.
            # 3. next_prec < cur_prec: no need for a recursive call, combine
            #    lhs and the next iteration will immediately bail out.
            if cur_prec < next_prec:
                # cur_prec + 1 makes the recursive call stop at operators no
                # tighter than the current one, so they stay at this level.
                rhs = self._parse_binop_rhs(cur_prec + 1, rhs)
            # Merge lhs/rhs
            lhs = BinaryExprAST(op, lhs, rhs)

    # expression ::= primary binoprhs
    def _parse_expression(self):
        """Parse a full expression: a primary followed by binary operators."""
        lhs = self._parse_primary()
        # Start with precedence 0 because we want to bind any operator to the
        # expression at this point.
        return self._parse_binop_rhs(0, lhs)

    # prototype ::= id '(' id* ')'
    def _parse_prototype(self):
        """Parse `name(arg1 arg2 ...)` into a PrototypeAST.

        Argument names are whitespace-separated identifiers (no commas).
        """
        if self.cur_tok.kind != TokenKind.IDENTIFIER:
            raise ParseError('Expected function name in prototype')
        name = self.cur_tok.value
        self._get_next_token() # consume the name
        if not self._cur_tok_is_operator('('):
            raise ParseError('Expected "(" in prototype')
        self._get_next_token() # consume '('
        argnames = []
        while self.cur_tok.kind == TokenKind.IDENTIFIER:
            argnames.append(self.cur_tok.value)
            self._get_next_token()
        if not self._cur_tok_is_operator(')'):
            raise ParseError('Expected ")" in prototype')
        self._get_next_token() # consume ')'
        return PrototypeAST(name, argnames)

    # external ::= 'extern' prototype
    def _parse_external(self):
        """Parse an 'extern' declaration; returns its PrototypeAST."""
        self._get_next_token() # consume 'extern'
        return self._parse_prototype()

    # definition ::= 'def' prototype expression
    def _parse_definition(self):
        """Parse a 'def' into a FunctionAST (prototype plus body)."""
        self._get_next_token() # consume 'def'
        proto = self._parse_prototype()
        expr = self._parse_expression()
        return FunctionAST(proto, expr)

    # toplevel ::= expression
    def _parse_toplevel_expression(self):
        """Parse a bare expression and wrap it in a FunctionAST whose
        prototype has an empty name and no arguments.
        """
        expr = self._parse_expression()
        # Anonymous function
        return FunctionAST(PrototypeAST('', []), expr)
class CodegenError(Exception):
    """Raised by the code generator when an AST node cannot be turned into IR
    (for example, a function name that is already defined)."""
class LLVMCodeGenerator(object):
    """Lowers Kaleidoscope AST nodes into an llvmlite IR module.

    Every Kaleidoscope value is a double, so all functions take and return
    doubles.
    """

    def __init__(self):
        """Initialize the code generator.

        This creates a new LLVM module into which code is generated. The
        generate_code() method can be called multiple times. It adds the code
        generated for this node into the module, and returns the module.
        """
        self.module = ir.Module()

        # Current IR builder.
        self.builder = None

        # Manages a symbol table while a function is being codegen'd. Maps var
        # names to ir.Value.
        self.func_symtab = {}

        # Counter that gives each anonymous (top-level expression) function
        # a distinct, non-empty name in the module.
        self._anon_counter = 0

    def generate_code(self, node):
        """Generate code for a top-level node and return the module.

        node must be a PrototypeAST (from 'extern') or a FunctionAST (from
        'def' or a top-level expression).
        """
        assert isinstance(node, (PrototypeAST, FunctionAST))
        self._codegen(node)
        return self.module

    def _codegen(self, node):
        """Node visitor. Dispatches upon node type.

        For AST node of class Foo, calls self._codegen_Foo. Each visitor is
        expected to return a llvmlite.ir.Value.

        Raises CodegenError if there is no visitor for the node's class.
        """
        class_name = node.__class__.__name__
        visitor = getattr(self, '_codegen_' + class_name, None)
        if visitor is None:
            raise CodegenError(
                'No code generator for node type {0}'.format(class_name))
        return visitor(node)

    def _codegen_NumberExprAST(self, node):
        # node.val may be the literal's source text, so convert explicitly.
        return ir.Constant(ir.DoubleType(), float(node.val))

    def _codegen_VariableExprAST(self, node):
        value = self.func_symtab.get(node.name)
        if value is None:
            raise CodegenError('Unknown variable name: {0}'.format(node.name))
        return value

    def _codegen_BinaryExprAST(self, node):
        lhs = self._codegen(node.lhs)
        rhs = self._codegen(node.rhs)
        if node.op == '+':
            return self.builder.fadd(lhs, rhs, 'addtmp')
        elif node.op == '-':
            return self.builder.fsub(lhs, rhs, 'subtmp')
        elif node.op == '*':
            return self.builder.fmul(lhs, rhs, 'multmp')
        elif node.op == '<':
            cmp = self.builder.fcmp_unordered('<', lhs, rhs, 'cmptmp')
            # The comparison yields an i1; convert it to 0.0 / 1.0 so the
            # result is a double like every other Kaleidoscope value.
            return self.builder.uitofp(cmp, ir.DoubleType(), 'booltmp')
        else:
            raise CodegenError('Unknown binary operator', node.op)

    def _codegen_CallExprAST(self, node):
        callee_func = self.module.globals.get(node.callee)
        if not isinstance(callee_func, ir.Function):
            raise CodegenError('Call to unknown function', node.callee)
        if len(callee_func.args) != len(node.args):
            raise CodegenError('Call argument length mismatch', node.callee)
        call_args = [self._codegen(arg) for arg in node.args]
        return self.builder.call(callee_func, call_args, 'calltmp')

    def _codegen_PrototypeAST(self, node):
        funcname = node.name
        if not funcname:
            # Top-level expressions are parsed into functions with an empty
            # name. Give each one a unique name so several can coexist in one
            # module (a repeated name would hit the redefinition check below).
            # Source identifiers must start with a letter, so the leading
            # underscore cannot clash with user-defined names.
            funcname = '_anon{0}'.format(self._anon_counter)
            self._anon_counter += 1

        # Create a function type: (double, double, ...) -> double
        func_ty = ir.FunctionType(ir.DoubleType(),
                                  [ir.DoubleType()] * len(node.argnames))

        # If a function with this name already exists in the module...
        if funcname in self.module.globals:
            # We only allow the case in which a declaration exists and now the
            # function is defined (or redeclared) with the same number of args.
            existing_func = self.module.globals[funcname]
            if not isinstance(existing_func, ir.Function):
                raise CodegenError('Function/Global name collision', funcname)
            # is_declaration is a property: True while the function has no body.
            if not existing_func.is_declaration:
                raise CodegenError('Redefinition of {0}'.format(funcname))
            if len(existing_func.args) != len(func_ty.args):
                raise CodegenError(
                    'Redefinition with different number of arguments')
            func = existing_func
        else:
            # Otherwise create a new function
            func = ir.Function(self.module, func_ty, funcname)

        # Set function argument names from AST
        for arg, argname in zip(func.args, node.argnames):
            arg.name = argname
        return func

    def _codegen_FunctionAST(self, node):
        # We're going to generate a new function. Reset the symbol table so
        # the body only sees this function's own arguments.
        self.func_symtab = {}

        # Create the function skeleton from the prototype.
        func = self._codegen(node.proto)

        # Make the arguments visible to the body under their source names.
        # Keying by the AST names matters because the IR argument names may
        # be adjusted by llvmlite to stay unique.
        for arg, argname in zip(func.args, node.proto.argnames):
            self.func_symtab[argname] = arg

        # Create the entry BB in the function and set the builder to it.
        bb_entry = func.append_basic_block('entry')
        self.builder = ir.IRBuilder(bb_entry)

        # NOTE: if codegen of the body raises, the partially built function
        # is left in the module.
        retval = self._codegen(node.body)
        self.builder.ret(retval)
        return func
if __name__ == '__main__':
    # Running this file directly performs no work; the lexer, parser and
    # code generator above are only defined, not exercised.
    pass