artiq/artiq/py2llvm/ast_body.py

540 lines
21 KiB
Python

import ast
import llvmlite_or1k.ir as ll
from artiq.py2llvm import values, base_types, fractions, lists, iterators
from artiq.py2llvm.tools import is_terminated
_ast_unops = {
ast.Invert: "o_inv",
ast.Not: "o_not",
ast.UAdd: "o_pos",
ast.USub: "o_neg"
}
_ast_binops = {
ast.Add: values.operators.add,
ast.Sub: values.operators.sub,
ast.Mult: values.operators.mul,
ast.Div: values.operators.truediv,
ast.FloorDiv: values.operators.floordiv,
ast.Mod: values.operators.mod,
ast.Pow: values.operators.pow,
ast.LShift: values.operators.lshift,
ast.RShift: values.operators.rshift,
ast.BitOr: values.operators.or_,
ast.BitXor: values.operators.xor,
ast.BitAnd: values.operators.and_
}
_ast_cmps = {
ast.Eq: values.operators.eq,
ast.NotEq: values.operators.ne,
ast.Lt: values.operators.lt,
ast.LtE: values.operators.le,
ast.Gt: values.operators.gt,
ast.GtE: values.operators.ge
}
class Visitor:
def __init__(self, runtime, ns, builder=None):
self.runtime = runtime
self.ns = ns
self.builder = builder
self._break_stack = []
self._continue_stack = []
self._active_exception_stack = []
self._exception_level_stack = [0]
# builder can be None for visit_expression
def visit_expression(self, node):
method = "_visit_expr_" + node.__class__.__name__
try:
visitor = getattr(self, method)
except AttributeError:
raise NotImplementedError("Unsupported node '{}' in expression"
.format(node.__class__.__name__))
return visitor(node)
def _visit_expr_Name(self, node):
try:
r = self.ns[node.id]
except KeyError:
raise NameError("Name '{}' is not defined".format(node.id))
return r
def _visit_expr_NameConstant(self, node):
v = node.value
if v is None:
r = base_types.VNone()
elif isinstance(v, bool):
r = base_types.VBool()
else:
raise NotImplementedError
if self.builder is not None:
r.set_const_value(self.builder, v)
return r
def _visit_expr_Num(self, node):
n = node.n
if isinstance(n, int):
if abs(n) < 2**31:
r = base_types.VInt()
else:
r = base_types.VInt(64)
elif isinstance(n, float):
r = base_types.VFloat()
else:
raise NotImplementedError
if self.builder is not None:
r.set_const_value(self.builder, n)
return r
def _visit_expr_UnaryOp(self, node):
value = self.visit_expression(node.operand)
return getattr(value, _ast_unops[type(node.op)])(self.builder)
def _visit_expr_BinOp(self, node):
return _ast_binops[type(node.op)](self.visit_expression(node.left),
self.visit_expression(node.right),
self.builder)
def _visit_expr_BoolOp(self, node):
if self.builder is not None:
initial_block = self.builder.basic_block
function = initial_block.function
merge_block = function.append_basic_block("b_merge")
test_blocks = []
test_values = []
for i, value in enumerate(node.values):
if self.builder is not None:
test_block = function.append_basic_block("b_{}_test".format(i))
test_blocks.append(test_block)
self.builder.position_at_end(test_block)
test_values.append(self.visit_expression(value))
result = test_values[0].new()
for value in test_values[1:]:
result.merge(value)
if self.builder is not None:
self.builder.position_at_end(initial_block)
result.alloca(self.builder, "b_result")
self.builder.branch(test_blocks[0])
next_test_blocks = test_blocks[1:]
next_test_blocks.append(None)
for block, next_block, value in zip(test_blocks,
next_test_blocks,
test_values):
self.builder.position_at_end(block)
bval = value.o_bool(self.builder)
result.auto_store(self.builder,
value.auto_load(self.builder))
if next_block is None:
self.builder.branch(merge_block)
else:
if isinstance(node.op, ast.Or):
self.builder.cbranch(bval.auto_load(self.builder),
merge_block,
next_block)
elif isinstance(node.op, ast.And):
self.builder.cbranch(bval.auto_load(self.builder),
next_block,
merge_block)
else:
raise NotImplementedError
self.builder.position_at_end(merge_block)
return result
def _visit_expr_Compare(self, node):
comparisons = []
old_comparator = self.visit_expression(node.left)
for op, comparator_a in zip(node.ops, node.comparators):
comparator = self.visit_expression(comparator_a)
comparison = _ast_cmps[type(op)](old_comparator, comparator,
self.builder)
comparisons.append(comparison)
old_comparator = comparator
r = comparisons[0]
for comparison in comparisons[1:]:
r = values.operators.and_(r, comparison)
return r
def _visit_expr_Call(self, node):
fn = node.func.id
if fn in {"bool", "int", "int64", "round", "round64", "float", "len"}:
value = self.visit_expression(node.args[0])
return getattr(value, "o_" + fn)(self.builder)
elif fn == "Fraction":
r = fractions.VFraction()
if self.builder is not None:
numerator = self.visit_expression(node.args[0])
denominator = self.visit_expression(node.args[1])
r.set_value_nd(self.builder, numerator, denominator)
return r
elif fn == "range":
return iterators.IRange(
self.builder,
[self.visit_expression(arg) for arg in node.args])
elif fn == "syscall":
return self.runtime.build_syscall(
node.args[0].s,
[self.visit_expression(expr) for expr in node.args[1:]],
self.builder)
else:
raise NameError("Function '{}' is not defined".format(fn))
def _visit_expr_Attribute(self, node):
value = self.visit_expression(node.value)
return value.o_getattr(node.attr, self.builder)
def _visit_expr_List(self, node):
elts = [self.visit_expression(elt) for elt in node.elts]
if elts:
el_type = elts[0].new()
for elt in elts[1:]:
el_type.merge(elt)
else:
el_type = base_types.VNone()
count = len(elts)
r = lists.VList(el_type, count)
r.elts = elts
return r
def _visit_expr_ListComp(self, node):
if len(node.generators) != 1:
raise NotImplementedError
generator = node.generators[0]
if not isinstance(generator, ast.comprehension):
raise NotImplementedError
if not isinstance(generator.target, ast.Name):
raise NotImplementedError
target = generator.target.id
if not isinstance(generator.iter, ast.Call):
raise NotImplementedError
if not isinstance(generator.iter.func, ast.Name):
raise NotImplementedError
if generator.iter.func.id != "range":
raise NotImplementedError
if len(generator.iter.args) != 1:
raise NotImplementedError
if not isinstance(generator.iter.args[0], ast.Num):
raise NotImplementedError
count = generator.iter.args[0].n
# Prevent incorrect use of the generator target, if it is defined in
# the local function namespace.
if target in self.ns:
old_target_val = self.ns[target]
del self.ns[target]
else:
old_target_val = None
elt = self.visit_expression(node.elt)
if old_target_val is not None:
self.ns[target] = old_target_val
el_type = elt.new()
r = lists.VList(el_type, count)
r.elt = elt
return r
def _visit_expr_Subscript(self, node):
value = self.visit_expression(node.value)
if isinstance(node.slice, ast.Index):
index = self.visit_expression(node.slice.value)
else:
raise NotImplementedError
return value.o_subscript(index, self.builder)
def visit_statements(self, stmts):
for node in stmts:
node_type = node.__class__.__name__
method = "_visit_stmt_" + node_type
try:
visitor = getattr(self, method)
except AttributeError:
raise NotImplementedError("Unsupported node '{}' in statement"
.format(node_type))
visitor(node)
if node_type in ("Return", "Break", "Continue"):
break
def _bb_terminated(self):
return is_terminated(self.builder.basic_block)
def _visit_stmt_Assign(self, node):
val = self.visit_expression(node.value)
if isinstance(node.value, ast.List):
if len(node.targets) > 1:
raise NotImplementedError
target = self.visit_expression(node.targets[0])
target.set_count(self.builder, val.alloc_count)
for i, elt in enumerate(val.elts):
idx = base_types.VInt()
idx.set_const_value(self.builder, i)
target.o_subscript(idx, self.builder).set_value(self.builder,
elt)
elif isinstance(node.value, ast.ListComp):
if len(node.targets) > 1:
raise NotImplementedError
target = self.visit_expression(node.targets[0])
target.set_count(self.builder, val.alloc_count)
i = base_types.VInt()
i.alloca(self.builder)
i.auto_store(self.builder, ll.Constant(ll.IntType(32), 0))
function = self.builder.basic_block.function
copy_block = function.append_basic_block("ai_copy")
end_block = function.append_basic_block("ai_end")
self.builder.branch(copy_block)
self.builder.position_at_end(copy_block)
target.o_subscript(i, self.builder).set_value(self.builder,
val.elt)
i.auto_store(self.builder, self.builder.add(
i.auto_load(self.builder),
ll.Constant(ll.IntType(32), 1)))
cont = self.builder.icmp_signed(
"<", i.auto_load(self.builder),
ll.Constant(ll.IntType(32), val.alloc_count))
self.builder.cbranch(cont, copy_block, end_block)
self.builder.position_at_end(end_block)
else:
for target in node.targets:
target = self.visit_expression(target)
target.set_value(self.builder, val)
def _visit_stmt_AugAssign(self, node):
target = self.visit_expression(node.target)
right = self.visit_expression(node.value)
val = _ast_binops[type(node.op)](target, right, self.builder)
target.set_value(self.builder, val)
def _visit_stmt_Expr(self, node):
self.visit_expression(node.value)
def _visit_stmt_If(self, node):
function = self.builder.basic_block.function
then_block = function.append_basic_block("i_then")
else_block = function.append_basic_block("i_else")
merge_block = function.append_basic_block("i_merge")
condition = self.visit_expression(node.test).o_bool(self.builder)
self.builder.cbranch(condition.auto_load(self.builder),
then_block, else_block)
self.builder.position_at_end(then_block)
self.visit_statements(node.body)
if not self._bb_terminated():
self.builder.branch(merge_block)
self.builder.position_at_end(else_block)
self.visit_statements(node.orelse)
if not self._bb_terminated():
self.builder.branch(merge_block)
self.builder.position_at_end(merge_block)
def _enter_loop_body(self, break_block, continue_block):
self._break_stack.append(break_block)
self._continue_stack.append(continue_block)
self._exception_level_stack.append(0)
def _leave_loop_body(self):
self._exception_level_stack.pop()
self._continue_stack.pop()
self._break_stack.pop()
def _visit_stmt_While(self, node):
function = self.builder.basic_block.function
body_block = function.append_basic_block("w_body")
else_block = function.append_basic_block("w_else")
condition = self.visit_expression(node.test).o_bool(self.builder)
self.builder.cbranch(
condition.auto_load(self.builder), body_block, else_block)
continue_block = function.append_basic_block("w_continue")
merge_block = function.append_basic_block("w_merge")
self.builder.position_at_end(body_block)
self._enter_loop_body(merge_block, continue_block)
self.visit_statements(node.body)
self._leave_loop_body()
if not self._bb_terminated():
self.builder.branch(continue_block)
self.builder.position_at_end(continue_block)
condition = self.visit_expression(node.test).o_bool(self.builder)
self.builder.cbranch(
condition.auto_load(self.builder), body_block, merge_block)
self.builder.position_at_end(else_block)
self.visit_statements(node.orelse)
if not self._bb_terminated():
self.builder.branch(merge_block)
self.builder.position_at_end(merge_block)
def _visit_stmt_For(self, node):
function = self.builder.basic_block.function
it = self.visit_expression(node.iter)
target = self.visit_expression(node.target)
itval = it.get_value_ptr()
body_block = function.append_basic_block("f_body")
else_block = function.append_basic_block("f_else")
cont = it.o_next(self.builder)
self.builder.cbranch(
cont.auto_load(self.builder), body_block, else_block)
continue_block = function.append_basic_block("f_continue")
merge_block = function.append_basic_block("f_merge")
self.builder.position_at_end(body_block)
target.set_value(self.builder, itval)
self._enter_loop_body(merge_block, continue_block)
self.visit_statements(node.body)
self._leave_loop_body()
if not self._bb_terminated():
self.builder.branch(continue_block)
self.builder.position_at_end(continue_block)
cont = it.o_next(self.builder)
self.builder.cbranch(
cont.auto_load(self.builder), body_block, merge_block)
self.builder.position_at_end(else_block)
self.visit_statements(node.orelse)
if not self._bb_terminated():
self.builder.branch(merge_block)
self.builder.position_at_end(merge_block)
def _break_loop_body(self, target_block):
exception_levels = self._exception_level_stack[-1]
if exception_levels:
self.runtime.build_pop(self.builder, exception_levels)
self.builder.branch(target_block)
def _visit_stmt_Break(self, node):
self._break_loop_body(self._break_stack[-1])
def _visit_stmt_Continue(self, node):
self._break_loop_body(self._continue_stack[-1])
def _visit_stmt_Return(self, node):
if node.value is None:
val = base_types.VNone()
else:
val = self.visit_expression(node.value)
exception_levels = sum(self._exception_level_stack)
if exception_levels:
self.runtime.build_pop(self.builder, exception_levels)
if isinstance(val, base_types.VNone):
self.builder.ret_void()
else:
self.builder.ret(val.auto_load(self.builder))
def _visit_stmt_Pass(self, node):
pass
def _visit_stmt_Raise(self, node):
if self._active_exception_stack:
finally_block, propagate, propagate_eid = (
self._active_exception_stack[-1])
self.builder.store(ll.Constant(ll.IntType(1), 1), propagate)
if node.exc is not None:
eid = ll.Constant(ll.IntType(32), node.exc.args[0].n)
self.builder.store(eid, propagate_eid)
self.builder.branch(finally_block)
else:
eid = ll.Constant(ll.IntType(32), node.exc.args[0].n)
self.runtime.build_raise(self.builder, eid)
def _handle_exception(self, function, finally_block,
propagate, propagate_eid, handlers):
eid = self.runtime.build_getid(self.builder)
self._active_exception_stack.append(
(finally_block, propagate, propagate_eid))
self.builder.store(ll.Constant(ll.IntType(1), 1), propagate)
self.builder.store(eid, propagate_eid)
for handler in handlers:
handled_exc_block = function.append_basic_block("try_exc_h")
cont_exc_block = function.append_basic_block("try_exc_c")
if handler.type is None:
self.builder.branch(handled_exc_block)
else:
if isinstance(handler.type, ast.Tuple):
match = self.builder.icmp_signed(
"==", eid,
ll.Constant(ll.IntType(32),
handler.type.elts[0].args[0].n))
for elt in handler.type.elts[1:]:
match = self.builder.or_(
match,
self.builder.icmp_signed(
"==", eid,
ll.Constant(ll.IntType(32), elt.args[0].n)))
else:
match = self.builder.icmp_signed(
"==", eid,
ll.Constant(ll.IntType(32), handler.type.args[0].n))
self.builder.cbranch(match, handled_exc_block, cont_exc_block)
self.builder.position_at_end(handled_exc_block)
self.builder.store(ll.Constant(ll.IntType(1), 0), propagate)
self.visit_statements(handler.body)
if not self._bb_terminated():
self.builder.branch(finally_block)
self.builder.position_at_end(cont_exc_block)
self.builder.branch(finally_block)
self._active_exception_stack.pop()
def _visit_stmt_Try(self, node):
function = self.builder.basic_block.function
noexc_block = function.append_basic_block("try_noexc")
exc_block = function.append_basic_block("try_exc")
finally_block = function.append_basic_block("try_finally")
propagate = self.builder.alloca(ll.IntType(1),
name="propagate")
self.builder.store(ll.Constant(ll.IntType(1), 0), propagate)
propagate_eid = self.builder.alloca(ll.IntType(32),
name="propagate_eid")
exception_occured = self.runtime.build_catch(self.builder)
self.builder.cbranch(exception_occured, exc_block, noexc_block)
self.builder.position_at_end(noexc_block)
self._exception_level_stack[-1] += 1
self.visit_statements(node.body)
self._exception_level_stack[-1] -= 1
if not self._bb_terminated():
self.runtime.build_pop(self.builder, 1)
self.visit_statements(node.orelse)
if not self._bb_terminated():
self.builder.branch(finally_block)
self.builder.position_at_end(exc_block)
self._handle_exception(function, finally_block,
propagate, propagate_eid, node.handlers)
propagate_block = function.append_basic_block("try_propagate")
merge_block = function.append_basic_block("try_merge")
self.builder.position_at_end(finally_block)
self.visit_statements(node.finalbody)
if not self._bb_terminated():
self.builder.cbranch(
self.builder.load(propagate),
propagate_block, merge_block)
self.builder.position_at_end(propagate_block)
self.runtime.build_raise(self.builder, self.builder.load(propagate_eid))
self.builder.branch(merge_block)
self.builder.position_at_end(merge_block)