pykaleidoscope/chapter1and2.py

463 lines
15 KiB
Python

# Chapter 1 & 2 - Lexer and Parser
from collections import namedtuple
from enum import Enum
# Each token is a tuple of kind and value. kind is one of the enumeration values
# in TokenKind. value is the textual value of the token in the input.
class TokenKind(Enum):
    """Kinds of token the lexer can produce."""

    # End of the input buffer.
    EOF = -1

    # Keywords.
    DEF = -2
    EXTERN = -3

    # Names and numeric literals.
    IDENTIFIER = -4
    NUMBER = -5

    # Any other single character.
    OPERATOR = -6
# A lexed token: `kind` is its category, `value` is its text in the input.
Token = namedtuple('Token', ['kind', 'value'])
class Lexer(object):
    """Lexer for Kaleidoscope.

    Initialize the lexer with a string buffer. tokens() returns a generator
    that can be queried for tokens. The generator will emit an EOF token
    before stopping.

    An empty buffer is valid input; it produces only the EOF token.
    """

    def __init__(self, buf):
        """Set up lexing of the string `buf`, positioned at its start."""
        self.buf = buf
        self.pos = 0
        # Slicing instead of indexing yields '' for an empty buffer, which is
        # the same end-of-input marker that _advance() stores in lastchar.
        self.lastchar = buf[:1]

    def tokens(self):
        """Generate the tokens of the buffer, finishing with an EOF token."""
        while self.lastchar:
            # Skip whitespace
            while self.lastchar.isspace():
                self._advance()
            if self.lastchar.isalpha():
                # Identifier or keyword: a letter followed by alphanumerics.
                start = self.pos
                while self.lastchar.isalnum():
                    self._advance()
                id_str = self.buf[start:self.pos]
                if id_str == 'def':
                    yield Token(kind=TokenKind.DEF, value=id_str)
                elif id_str == 'extern':
                    yield Token(kind=TokenKind.EXTERN, value=id_str)
                else:
                    yield Token(kind=TokenKind.IDENTIFIER, value=id_str)
            elif self.lastchar.isdigit() or self.lastchar == '.':
                # Number: any run of digits and dots; the text is not
                # validated here.
                start = self.pos
                while self.lastchar.isdigit() or self.lastchar == '.':
                    self._advance()
                yield Token(kind=TokenKind.NUMBER,
                            value=self.buf[start:self.pos])
            elif self.lastchar == '#':
                # Comment: discard everything up to the end of the line. The
                # truthiness check is needed because '' (end of input) is
                # "in" every string.
                self._advance()
                while self.lastchar and self.lastchar not in '\r\n':
                    self._advance()
            elif self.lastchar:
                # Some other char; emitted as a one-character operator.
                yield Token(kind=TokenKind.OPERATOR, value=self.lastchar)
                self._advance()
        yield Token(kind=TokenKind.EOF, value='')

    def _advance(self):
        """Move to the next character; lastchar becomes '' past the end."""
        try:
            self.pos += 1
            self.lastchar = self.buf[self.pos]
        except IndexError:
            self.lastchar = ''
# AST hierarchy
class ASTNode(object):
    """Common base class of all AST nodes."""

    def dump(self, indent=0):
        """Return a textual dump of this node.

        `indent` is the number of spaces to put in front of the node's line.
        The base class has no representation; subclasses provide one.
        """
        raise NotImplementedError()
class ExprAST(ASTNode):
    """Base class of the expression nodes; adds nothing to ASTNode."""
class NumberExprAST(ExprAST):
    """Expression node for a numeric literal."""

    def __init__(self, val):
        # Stored exactly as given; no numeric conversion happens here.
        self.val = val

    def dump(self, indent=0):
        """Return '<indent spaces>NumberExprAST[<val>]'."""
        padding = ' ' * indent
        node_name = type(self).__name__
        return '{0}{1}[{2}]'.format(padding, node_name, self.val)
class VariableExprAST(ExprAST):
    """Expression node for a reference to a variable by name."""

    def __init__(self, name):
        self.name = name

    def dump(self, indent=0):
        """Return '<indent spaces>VariableExprAST[<name>]'."""
        padding = ' ' * indent
        node_name = type(self).__name__
        return '{0}{1}[{2}]'.format(padding, node_name, self.name)
class BinaryExprAST(ExprAST):
    """Expression node for a binary operator applied to two operands."""

    def __init__(self, op, lhs, rhs):
        self.op, self.lhs, self.rhs = op, lhs, rhs

    def dump(self, indent=0):
        """Return the operator line followed by both operands.

        The operands are dumped one per line, left then right, each indented
        two spaces deeper than this node. There is no trailing newline.
        """
        header = '{0}{1}[{2}]\n'.format(
            ' ' * indent, type(self).__name__, self.op)
        child_indent = indent + 2
        left_text = self.lhs.dump(child_indent)
        right_text = self.rhs.dump(child_indent)
        return header + left_text + '\n' + right_text
class CallExprAST(ExprAST):
    """Expression node for a call of `callee` with a list of arguments."""

    def __init__(self, callee, args):
        self.callee = callee
        self.args = args

    def dump(self, indent=0):
        """Return the callee line followed by one dumped argument per line.

        Arguments are indented two spaces deeper than this node. There is no
        trailing newline, including when the call has no arguments.
        """
        header = '{0}{1}[{2}]\n'.format(
            ' ' * indent, type(self).__name__, self.callee)
        arg_lines = ''.join(
            arg.dump(indent + 2) + '\n' for arg in self.args)
        # Every piece ends in '\n', so dropping the final character removes
        # exactly the one trailing newline.
        return (header + arg_lines)[:-1]
class PrototypeAST(ASTNode):
    """Node for a function prototype: its name and argument names."""

    def __init__(self, name, argnames):
        self.name = name
        self.argnames = argnames

    def dump(self, indent=0):
        """Return '<indent spaces>PrototypeAST[<arg>, <arg>, ...]'.

        Only the argument names appear in the dump; the prototype's own name
        is not included.
        """
        padding = ' ' * indent
        arg_list = ', '.join(self.argnames)
        return '{0}{1}[{2}]'.format(padding, type(self).__name__, arg_list)
class FunctionAST(ASTNode):
    """Node for a function definition: a prototype plus a body expression."""

    def __init__(self, proto, body):
        self.proto, self.body = proto, body

    def dump(self, indent=0):
        """Return the function line followed by the dumped body.

        The prototype's unindented dump is embedded in the brackets of the
        first line; the body is indented two spaces deeper than this node.
        The result ends with a newline.
        """
        header = '{0}{1}[{2}]\n'.format(
            ' ' * indent, type(self).__name__, self.proto.dump())
        return header + self.body.dump(indent + 2) + '\n'
class ParseError(Exception):
    """Raised by the parser when the input does not match the grammar."""
class Parser(object):
    """Recursive-descent parser for the Kaleidoscope language.

    Create the parser once, then call parse_toplevel() for each piece of
    Kaleidoscope source to turn it into an AST.
    """

    # Precedence of the supported binary operators; a larger number binds
    # more tightly.
    _precedence_map = {'<': 10, '+': 20, '-': 20, '*': 40}

    def __init__(self):
        self.token_generator = None
        self.cur_tok = None

    # toplevel ::= definition | external | expression | ';'
    def parse_toplevel(self, buf):
        """Parse the string `buf` and return the AST node representing it.

        Returns None when the input starts with a lone ';'.
        """
        self.token_generator = Lexer(buf).tokens()
        self.cur_tok = None
        self._get_next_token()

        first_kind = self.cur_tok.kind
        if first_kind == TokenKind.EXTERN:
            return self._parse_external()
        if first_kind == TokenKind.DEF:
            return self._parse_definition()
        if self._cur_tok_is_operator(';'):
            self._get_next_token()
            return None
        return self._parse_toplevel_expression()

    def _get_next_token(self):
        """Pull the next token from the lexer into cur_tok."""
        self.cur_tok = next(self.token_generator)

    def _match(self, expected_kind, expected_value=None):
        """Check the current token against the expectation, then consume it.

        For TokenKind.OPERATOR the token must be the operator
        `expected_value`; for every other kind only the kind is compared.
        Raises ParseError on a mismatch, leaving the token unconsumed.
        """
        if expected_kind == TokenKind.OPERATOR:
            if not self._cur_tok_is_operator(expected_value):
                raise ParseError('Expected "{0}"'.format(expected_value))
        elif expected_kind != self.cur_tok.kind:
            raise ParseError('Expected "{0}"'.format(expected_kind))
        self._get_next_token()

    def _cur_tok_precedence(self):
        """Return the current token's operator precedence, or -1 if none."""
        return Parser._precedence_map.get(self.cur_tok.value, -1)

    def _cur_tok_is_operator(self, op):
        """Tell whether the current token is the operator `op`."""
        tok = self.cur_tok
        return tok.kind == TokenKind.OPERATOR and tok.value == op

    # identifierexpr
    #   ::= identifier
    #   ::= identifier '(' expression* ')'
    def _parse_identifier_expr(self):
        name = self.cur_tok.value
        self._get_next_token()
        # Without a following '(' this is a plain variable reference.
        if not self._cur_tok_is_operator('('):
            return VariableExprAST(name)

        self._get_next_token()  # consume the '('
        call_args = []
        more_args = not self._cur_tok_is_operator(')')
        while more_args:
            call_args.append(self._parse_expression())
            more_args = not self._cur_tok_is_operator(')')
            if more_args:
                # Another argument must follow the comma, so a trailing
                # comma before ')' is rejected by the next iteration.
                self._match(TokenKind.OPERATOR, ',')
        self._get_next_token()  # consume the ')'
        return CallExprAST(name, call_args)

    # numberexpr ::= number
    def _parse_number_expr(self):
        node = NumberExprAST(self.cur_tok.value)
        self._get_next_token()  # consume the number
        return node

    # parenexpr ::= '(' expression ')'
    def _parse_paren_expr(self):
        self._get_next_token()  # consume the '('
        inner = self._parse_expression()
        self._match(TokenKind.OPERATOR, ')')
        return inner

    # primary
    #   ::= identifierexpr
    #   ::= numberexpr
    #   ::= parenexpr
    def _parse_primary(self):
        kind = self.cur_tok.kind
        if kind == TokenKind.IDENTIFIER:
            return self._parse_identifier_expr()
        if kind == TokenKind.NUMBER:
            return self._parse_number_expr()
        if self._cur_tok_is_operator('('):
            return self._parse_paren_expr()
        raise ParseError('Unknown token when expecting an expression')

    # binoprhs ::= (<binop> primary)*
    def _parse_binop_rhs(self, expr_prec, lhs):
        """Parse the right-hand side of a binary expression.

        expr_prec: minimal precedence an operator needs for parsing to
            continue (precedence climbing).
        lhs: AST of the left-hand side parsed so far.
        """
        while True:
            op_prec = self._cur_tok_precedence()
            # Stop at an operator that binds less tightly than the
            # sub-expression being parsed. Non-operators have precedence -1,
            # so the end of the expression lands here as well.
            if op_prec < expr_prec:
                return lhs

            operator = self.cur_tok.value
            self._get_next_token()  # consume the operator
            rhs = self._parse_primary()

            # Only a tighter-binding operator after rhs needs a recursive
            # call. An equal one is picked up by the next loop iteration
            # (left associativity); a looser one makes that iteration return.
            if self._cur_tok_precedence() > op_prec:
                rhs = self._parse_binop_rhs(op_prec + 1, rhs)

            # Merge lhs/rhs
            lhs = BinaryExprAST(operator, lhs, rhs)

    # expression ::= primary binoprhs
    def _parse_expression(self):
        # Minimal precedence 0 lets any operator bind to the leading primary.
        return self._parse_binop_rhs(0, self._parse_primary())

    # prototype ::= id '(' id* ')'
    def _parse_prototype(self):
        func_name = self.cur_tok.value
        self._match(TokenKind.IDENTIFIER)
        self._match(TokenKind.OPERATOR, '(')
        params = []
        while self.cur_tok.kind == TokenKind.IDENTIFIER:
            params.append(self.cur_tok.value)
            self._get_next_token()
        self._match(TokenKind.OPERATOR, ')')
        return PrototypeAST(func_name, params)

    # external ::= 'extern' prototype
    def _parse_external(self):
        self._get_next_token()  # consume 'extern'
        return self._parse_prototype()

    # definition ::= 'def' prototype expression
    def _parse_definition(self):
        self._get_next_token()  # consume 'def'
        prototype = self._parse_prototype()
        body = self._parse_expression()
        return FunctionAST(prototype, body)

    # toplevel ::= expression
    def _parse_toplevel_expression(self):
        body = self._parse_expression()
        # A bare expression is wrapped in an anonymous function.
        return FunctionAST(PrototypeAST('', []), body)
#---- Some unit tests ----#
import unittest
class TestLexer(unittest.TestCase):
    """Unit tests for Lexer."""

    def _assert_toks(self, toks, kinds):
        """Assert that the list of toks has the given kinds."""
        actual_kinds = [tok.kind.name for tok in toks]
        self.assertEqual(actual_kinds, kinds)

    def test_lexer_simple_tokens_and_values(self):
        toks = list(Lexer('a+1').tokens())
        expected_prefix = [
            Token(TokenKind.IDENTIFIER, 'a'),
            Token(TokenKind.OPERATOR, '+'),
            Token(TokenKind.NUMBER, '1'),
            Token(TokenKind.EOF, ''),
        ]
        for index, expected in enumerate(expected_prefix):
            self.assertEqual(toks[index], expected)

        # A number may start with the decimal point.
        toks = list(Lexer('.1519').tokens())
        self.assertEqual(toks[0], Token(TokenKind.NUMBER, '.1519'))

    def test_token_kinds(self):
        cases = [
            ('10.1 def der extern foo (',
             ['NUMBER', 'DEF', 'IDENTIFIER', 'EXTERN', 'IDENTIFIER',
              'OPERATOR', 'EOF']),
            ('+- 1 2 22 22.4 a b2 C3d',
             ['OPERATOR', 'OPERATOR', 'NUMBER', 'NUMBER', 'NUMBER', 'NUMBER',
              'IDENTIFIER', 'IDENTIFIER', 'IDENTIFIER', 'EOF']),
        ]
        for source, kinds in cases:
            self._assert_toks(list(Lexer(source).tokens()), kinds)

    def test_skip_whitespace_comments(self):
        source = '''
def foo # this is a comment
# another comment
\t\t\t10
'''
        self._assert_toks(
            list(Lexer(source).tokens()),
            ['DEF', 'IDENTIFIER', 'NUMBER', 'EOF'])
class TestParser(unittest.TestCase):
    """Unit tests for Parser."""

    def _flatten(self, ast):
        """Test helper - flattens the AST into a sexpr-like nested list."""
        if isinstance(ast, NumberExprAST):
            return ['Number', ast.val]
        if isinstance(ast, VariableExprAST):
            return ['Variable', ast.name]
        if isinstance(ast, BinaryExprAST):
            left = self._flatten(ast.lhs)
            right = self._flatten(ast.rhs)
            return ['Binop', ast.op, left, right]
        if isinstance(ast, CallExprAST):
            return ['Call', ast.callee, [self._flatten(a) for a in ast.args]]
        if isinstance(ast, PrototypeAST):
            return ['Proto', ast.name, ' '.join(ast.argnames)]
        if isinstance(ast, FunctionAST):
            proto = self._flatten(ast.proto)
            body = self._flatten(ast.body)
            return ['Function', proto, body]
        raise TypeError('unknown type in _flatten: {0}'.format(type(ast)))

    def _assert_body(self, toplevel, expected):
        """Assert the flattened body of the given toplevel function"""
        self.assertIsInstance(toplevel, FunctionAST)
        flattened = self._flatten(toplevel.body)
        self.assertEqual(flattened, expected)

    def test_basic(self):
        tree = Parser().parse_toplevel('2')
        self.assertIsInstance(tree, FunctionAST)
        self.assertIsInstance(tree.body, NumberExprAST)
        # The literal keeps its textual form.
        self.assertEqual(tree.body.val, '2')

    def test_basic_with_flattening(self):
        self._assert_body(Parser().parse_toplevel('2'), ['Number', '2'])
        self._assert_body(
            Parser().parse_toplevel('foobar'), ['Variable', 'foobar'])

    def test_expr_singleprec(self):
        # Operators of equal precedence associate to the left.
        two_plus_three = ['Binop', '+', ['Number', '2'], ['Number', '3']]
        tree = Parser().parse_toplevel('2+ 3-4')
        self._assert_body(
            tree, ['Binop', '-', two_plus_three, ['Number', '4']])

    def test_expr_multiprec(self):
        # '*' binds tighter than '+' and '-'.
        product = ['Binop', '*', ['Number', '3'], ['Number', '4']]
        total = ['Binop', '+', ['Number', '2'], product]
        tree = Parser().parse_toplevel('2+3*4-9')
        self._assert_body(tree, ['Binop', '-', total, ['Number', '9']])

    def test_expr_parens(self):
        difference = ['Binop', '-', ['Number', '3'], ['Number', '4']]
        inner_product = ['Binop', '*', ['Number', '2'], difference]
        tree = Parser().parse_toplevel('2*(3-4)*7')
        self._assert_body(
            tree, ['Binop', '*', inner_product, ['Number', '7']])

    def test_externals(self):
        cases = [
            ('extern sin(arg)', ['Proto', 'sin', 'arg']),
            ('extern Foobar(nom denom abom)',
             ['Proto', 'Foobar', 'nom denom abom']),
        ]
        for source, expected in cases:
            tree = Parser().parse_toplevel(source)
            self.assertEqual(self._flatten(tree), expected)

    def test_funcdef(self):
        call = ['Call', 'bar', [['Variable', 'x']]]
        body = ['Binop', '+', ['Number', '1'], call]
        tree = Parser().parse_toplevel('def foo(x) 1 + bar(x)')
        self.assertEqual(
            self._flatten(tree),
            ['Function', ['Proto', 'foo', 'x'], body])
if __name__ == '__main__':
    # Only the lexer and parser exist at this stage (no code generation), so
    # running the module just parses one sample definition and prints the
    # dump of the resulting AST.
    sample_source = 'def bina(a b) a + b'
    tree = Parser().parse_toplevel(sample_source)
    print(tree.dump())