Add integration tests for every language construct.

This commit is contained in:
whitequark 2015-07-22 18:34:52 +03:00
parent dff4ce7e3a
commit f2a6110cc4
22 changed files with 384 additions and 92 deletions

View File

@ -1,7 +1,7 @@
import sys, fileinput, os import sys, fileinput, os
from pythonparser import source, diagnostic, algorithm, parse_buffer from pythonparser import source, diagnostic, algorithm, parse_buffer
from .. import prelude, types from .. import prelude, types
from ..transforms import ASTTypedRewriter, Inferencer from ..transforms import ASTTypedRewriter, Inferencer, IntMonomorphizer
class Printer(algorithm.Visitor): class Printer(algorithm.Visitor):
""" """
@ -42,6 +42,12 @@ class Printer(algorithm.Visitor):
":{}".format(self.type_printer.name(node.type))) ":{}".format(self.type_printer.name(node.type)))
def main(): def main():
if sys.argv[1] == "+mono":
del sys.argv[1]
monomorphize = True
else:
monomorphize = False
if len(sys.argv) > 1 and sys.argv[1] == "+diag": if len(sys.argv) > 1 and sys.argv[1] == "+diag":
del sys.argv[1] del sys.argv[1]
def process_diagnostic(diag): def process_diagnostic(diag):
@ -62,6 +68,9 @@ def main():
parsed, comments = parse_buffer(buf, engine=engine) parsed, comments = parse_buffer(buf, engine=engine)
typed = ASTTypedRewriter(engine=engine).visit(parsed) typed = ASTTypedRewriter(engine=engine).visit(parsed)
Inferencer(engine=engine).visit(typed) Inferencer(engine=engine).visit(typed)
if monomorphize:
IntMonomorphizer(engine=engine).visit(typed)
Inferencer(engine=engine).visit(typed)
printer = Printer(buf) printer = Printer(buf)
printer.visit(typed) printer.visit(typed)

View File

@ -13,6 +13,13 @@ def main():
engine.process = process_diagnostic engine.process = process_diagnostic
llmod = Module.from_string("".join(fileinput.input()).expandtabs(), engine=engine).llvm_ir llmod = Module.from_string("".join(fileinput.input()).expandtabs(), engine=engine).llvm_ir
# Add main so that the result can be executed with lli
llmain = ll.Function(llmod, ll.FunctionType(ll.VoidType(), []), "main")
llbuilder = ll.IRBuilder(llmain.append_basic_block("entry"))
llbuilder.call(llmod.get_global(llmod.name + ".__modinit__"), [])
llbuilder.ret_void()
print(llmod) print(llmod)
if __name__ == "__main__": if __name__ == "__main__":

View File

@ -275,7 +275,7 @@ class ARTIQIRGenerator(algorithm.Visitor):
self.current_assign = None self.current_assign = None
def visit_AugAssign(self, node): def visit_AugAssign(self, node):
lhs = self.visit(target) lhs = self.visit(node.target)
rhs = self.visit(node.value) rhs = self.visit(node.value)
value = self.append(ir.Arith(node.op, lhs, rhs)) value = self.append(ir.Arith(node.op, lhs, rhs))
try: try:
@ -291,20 +291,22 @@ class ARTIQIRGenerator(algorithm.Visitor):
if_true = self.add_block() if_true = self.add_block()
self.current_block = if_true self.current_block = if_true
self.visit(node.body) self.visit(node.body)
post_if_true = self.current_block
if any(node.orelse): if any(node.orelse):
if_false = self.add_block() if_false = self.add_block()
self.current_block = if_false self.current_block = if_false
self.visit(node.orelse) self.visit(node.orelse)
post_if_false = self.current_block
tail = self.add_block() tail = self.add_block()
self.current_block = tail self.current_block = tail
if not if_true.is_terminated(): if not post_if_true.is_terminated():
if_true.append(ir.Branch(tail)) post_if_true.append(ir.Branch(tail))
if any(node.orelse): if any(node.orelse):
if not if_false.is_terminated(): if not post_if_false.is_terminated():
if_false.append(ir.Branch(tail)) post_if_false.append(ir.Branch(tail))
self.append(ir.BranchIf(cond, if_true, if_false), block=head) self.append(ir.BranchIf(cond, if_true, if_false), block=head)
else: else:
self.append(ir.BranchIf(cond, if_true, tail), block=head) self.append(ir.BranchIf(cond, if_true, tail), block=head)
@ -323,38 +325,42 @@ class ARTIQIRGenerator(algorithm.Visitor):
body = self.add_block("while.body") body = self.add_block("while.body")
self.current_block = body self.current_block = body
self.visit(node.body) self.visit(node.body)
post_body = self.current_block
if any(node.orelse): if any(node.orelse):
else_tail = self.add_block("while.else") else_tail = self.add_block("while.else")
self.current_block = else_tail self.current_block = else_tail
self.visit(node.orelse) self.visit(node.orelse)
post_else_tail = self.current_block
tail = self.add_block("while.tail") tail = self.add_block("while.tail")
self.current_block = tail self.current_block = tail
if any(node.orelse): if any(node.orelse):
if not else_tail.is_terminated(): if not post_else_tail.is_terminated():
else_tail.append(ir.Branch(tail)) post_else_tail.append(ir.Branch(tail))
else: else:
else_tail = tail else_tail = tail
head.append(ir.BranchIf(cond, body, else_tail)) head.append(ir.BranchIf(cond, body, else_tail))
if not body.is_terminated(): if not post_body.is_terminated():
body.append(ir.Branch(head)) post_body.append(ir.Branch(head))
break_block.append(ir.Branch(tail)) break_block.append(ir.Branch(tail))
finally: finally:
self.break_target = old_break self.break_target = old_break
self.continue_target = old_continue self.continue_target = old_continue
def iterable_len(self, value, typ=builtins.TInt(types.TValue(32))): def iterable_len(self, value, typ=_size_type):
if builtins.is_list(value.type): if builtins.is_list(value.type):
return self.append(ir.Builtin("len", [value], typ)) return self.append(ir.Builtin("len", [value], typ,
name="{}.len".format(value.name)))
elif builtins.is_range(value.type): elif builtins.is_range(value.type):
start = self.append(ir.GetAttr(value, "start")) start = self.append(ir.GetAttr(value, "start"))
stop = self.append(ir.GetAttr(value, "stop")) stop = self.append(ir.GetAttr(value, "stop"))
step = self.append(ir.GetAttr(value, "step")) step = self.append(ir.GetAttr(value, "step"))
spread = self.append(ir.Arith(ast.Sub(loc=None), stop, start)) spread = self.append(ir.Arith(ast.Sub(loc=None), stop, start))
return self.append(ir.Arith(ast.FloorDiv(loc=None), spread, step)) return self.append(ir.Arith(ast.FloorDiv(loc=None), spread, step,
name="{}.len".format(value.name)))
else: else:
assert False assert False
@ -403,24 +409,26 @@ class ARTIQIRGenerator(algorithm.Visitor):
finally: finally:
self.current_assign = None self.current_assign = None
self.visit(node.body) self.visit(node.body)
post_body = self.current_block
if any(node.orelse): if any(node.orelse):
else_tail = self.add_block("for.else") else_tail = self.add_block("for.else")
self.current_block = else_tail self.current_block = else_tail
self.visit(node.orelse) self.visit(node.orelse)
post_else_tail = self.current_block
tail = self.add_block("for.tail") tail = self.add_block("for.tail")
self.current_block = tail self.current_block = tail
if any(node.orelse): if any(node.orelse):
if not else_tail.is_terminated(): if not post_else_tail.is_terminated():
else_tail.append(ir.Branch(tail)) post_else_tail.append(ir.Branch(tail))
else: else:
else_tail = tail else_tail = tail
head.append(ir.BranchIf(cond, body, else_tail)) head.append(ir.BranchIf(cond, body, else_tail))
if not body.is_terminated(): if not post_body.is_terminated():
body.append(ir.Branch(continue_block)) post_body.append(ir.Branch(continue_block))
break_block.append(ir.Branch(tail)) break_block.append(ir.Branch(tail))
finally: finally:
self.break_target = old_break self.break_target = old_break
@ -611,15 +619,15 @@ class ARTIQIRGenerator(algorithm.Visitor):
else: else:
self.append(ir.SetAttr(obj, node.attr, self.current_assign)) self.append(ir.SetAttr(obj, node.attr, self.current_assign))
def _map_index(self, length, index): def _map_index(self, length, index, one_past_the_end=False):
lt_0 = self.append(ir.Compare(ast.Lt(loc=None), lt_0 = self.append(ir.Compare(ast.Lt(loc=None),
index, ir.Constant(0, index.type))) index, ir.Constant(0, index.type)))
from_end = self.append(ir.Arith(ast.Add(loc=None), length, index)) from_end = self.append(ir.Arith(ast.Add(loc=None), length, index))
mapped_index = self.append(ir.Select(lt_0, from_end, index)) mapped_index = self.append(ir.Select(lt_0, from_end, index))
mapped_ge_0 = self.append(ir.Compare(ast.GtE(loc=None), mapped_ge_0 = self.append(ir.Compare(ast.GtE(loc=None),
mapped_index, ir.Constant(0, mapped_index.type))) mapped_index, ir.Constant(0, mapped_index.type)))
mapped_lt_len = self.append(ir.Compare(ast.Lt(loc=None), end_cmpop = ast.LtE(loc=None) if one_past_the_end else ast.Lt(loc=None)
mapped_index, length)) mapped_lt_len = self.append(ir.Compare(end_cmpop, mapped_index, length))
in_bounds = self.append(ir.Select(mapped_ge_0, mapped_lt_len, in_bounds = self.append(ir.Select(mapped_ge_0, mapped_lt_len,
ir.Constant(False, builtins.TBool()))) ir.Constant(False, builtins.TBool())))
@ -699,7 +707,7 @@ class ARTIQIRGenerator(algorithm.Visitor):
max_index = self.visit(node.slice.upper) max_index = self.visit(node.slice.upper)
else: else:
max_index = length max_index = length
mapped_max_index = self._map_index(length, max_index) mapped_max_index = self._map_index(length, max_index, one_past_the_end=True)
if node.slice.step is not None: if node.slice.step is not None:
step = self.visit(node.slice.step) step = self.visit(node.slice.step)
@ -708,9 +716,10 @@ class ARTIQIRGenerator(algorithm.Visitor):
unstepped_size = self.append(ir.Arith(ast.Sub(loc=None), unstepped_size = self.append(ir.Arith(ast.Sub(loc=None),
mapped_max_index, mapped_min_index)) mapped_max_index, mapped_min_index))
slice_size = self.append(ir.Arith(ast.FloorDiv(loc=None), unstepped_size, step)) slice_size = self.append(ir.Arith(ast.FloorDiv(loc=None), unstepped_size, step,
name="slice.size"))
self._make_check(self.append(ir.Compare(ast.Eq(loc=None), slice_size, length)), self._make_check(self.append(ir.Compare(ast.LtE(loc=None), slice_size, length)),
lambda: self.append(ir.Alloc([], builtins.TValueError()))) lambda: self.append(ir.Alloc([], builtins.TValueError())))
if self.current_assign is None: if self.current_assign is None:
@ -735,6 +744,9 @@ class ARTIQIRGenerator(algorithm.Visitor):
lambda index: self.append(ir.Compare(ast.Lt(loc=None), index, slice_size)), lambda index: self.append(ir.Compare(ast.Lt(loc=None), index, slice_size)),
body_gen) body_gen)
if self.current_assign is None:
return other_value
def visit_TupleT(self, node): def visit_TupleT(self, node):
if self.current_assign is None: if self.current_assign is None:
return self.append(ir.Alloc([self.visit(elt) for elt in node.elts], node.type)) return self.append(ir.Alloc([self.visit(elt) for elt in node.elts], node.type))
@ -759,7 +771,7 @@ class ARTIQIRGenerator(algorithm.Visitor):
self.append(ir.SetElem(lst, ir.Constant(index, self._size_type), elt_node)) self.append(ir.SetElem(lst, ir.Constant(index, self._size_type), elt_node))
return lst return lst
else: else:
length = self.append(ir.Builtin("len", [self.current_assign], self._size_type)) length = self.iterable_len(self.current_assign)
self._make_check(self.append(ir.Compare(ast.Eq(loc=None), length, self._make_check(self.append(ir.Compare(ast.Eq(loc=None), length,
ir.Constant(len(node.elts), self._size_type))), ir.Constant(len(node.elts), self._size_type))),
lambda: self.append(ir.Alloc([], builtins.TValueError()))) lambda: self.append(ir.Alloc([], builtins.TValueError())))
@ -793,7 +805,6 @@ class ARTIQIRGenerator(algorithm.Visitor):
elt = self.iterable_get(iterable, index) elt = self.iterable_get(iterable, index)
try: try:
old_assign, self.current_assign = self.current_assign, elt old_assign, self.current_assign = self.current_assign, elt
print(comprehension.target, self.current_assign)
self.visit(comprehension.target) self.visit(comprehension.target)
finally: finally:
self.current_assign = old_assign self.current_assign = old_assign
@ -837,7 +848,7 @@ class ARTIQIRGenerator(algorithm.Visitor):
def visit_UnaryOpT(self, node): def visit_UnaryOpT(self, node):
if isinstance(node.op, ast.Not): if isinstance(node.op, ast.Not):
return self.append(ir.Select(node.operand, return self.append(ir.Select(self.visit(node.operand),
ir.Constant(False, builtins.TBool()), ir.Constant(False, builtins.TBool()),
ir.Constant(True, builtins.TBool()))) ir.Constant(True, builtins.TBool())))
elif isinstance(node.op, ast.USub): elif isinstance(node.op, ast.USub):
@ -866,7 +877,7 @@ class ARTIQIRGenerator(algorithm.Visitor):
return self.append(ir.Arith(node.op, self.visit(node.left), self.visit(node.right))) return self.append(ir.Arith(node.op, self.visit(node.left), self.visit(node.right)))
elif isinstance(node.op, ast.Add): # list + list, tuple + tuple elif isinstance(node.op, ast.Add): # list + list, tuple + tuple
lhs, rhs = self.visit(node.left), self.visit(node.right) lhs, rhs = self.visit(node.left), self.visit(node.right)
if types.is_tuple(node.left.type) and builtins.is_tuple(node.right.type): if types.is_tuple(node.left.type) and types.is_tuple(node.right.type):
elts = [] elts = []
for index, elt in enumerate(node.left.type.elts): for index, elt in enumerate(node.left.type.elts):
elts.append(self.append(ir.GetAttr(lhs, index))) elts.append(self.append(ir.GetAttr(lhs, index)))
@ -874,8 +885,8 @@ class ARTIQIRGenerator(algorithm.Visitor):
elts.append(self.append(ir.GetAttr(rhs, index))) elts.append(self.append(ir.GetAttr(rhs, index)))
return self.append(ir.Alloc(elts, node.type)) return self.append(ir.Alloc(elts, node.type))
elif builtins.is_list(node.left.type) and builtins.is_list(node.right.type): elif builtins.is_list(node.left.type) and builtins.is_list(node.right.type):
lhs_length = self.append(ir.Builtin("len", [lhs], self._size_type)) lhs_length = self.iterable_len(lhs)
rhs_length = self.append(ir.Builtin("len", [rhs], self._size_type)) rhs_length = self.iterable_len(rhs)
result_length = self.append(ir.Arith(ast.Add(loc=None), lhs_length, rhs_length)) result_length = self.append(ir.Arith(ast.Add(loc=None), lhs_length, rhs_length))
result = self.append(ir.Alloc([result_length], node.type)) result = self.append(ir.Alloc([result_length], node.type))
@ -913,7 +924,7 @@ class ARTIQIRGenerator(algorithm.Visitor):
else: else:
assert False assert False
lst_length = self.append(ir.Builtin("len", [lst], self._size_type)) lst_length = self.iterable_len(lst)
result_length = self.append(ir.Arith(ast.Mult(loc=None), lst_length, num)) result_length = self.append(ir.Arith(ast.Mult(loc=None), lst_length, num))
result = self.append(ir.Alloc([result_length], node.type)) result = self.append(ir.Alloc([result_length], node.type))
@ -934,17 +945,21 @@ class ARTIQIRGenerator(algorithm.Visitor):
lambda index: self.append(ir.Compare(ast.Lt(loc=None), index, lst_length)), lambda index: self.append(ir.Compare(ast.Lt(loc=None), index, lst_length)),
body_gen) body_gen)
return self.append(ir.Arith(ast.Add(loc=None), lst_length, return self.append(ir.Arith(ast.Add(loc=None), num_index,
ir.Constant(1, self._size_type))) ir.Constant(1, self._size_type)))
self._make_loop(ir.Constant(0, self._size_type), self._make_loop(ir.Constant(0, self._size_type),
lambda index: self.append(ir.Compare(ast.Lt(loc=None), index, num)), lambda index: self.append(ir.Compare(ast.Lt(loc=None), index, num)),
body_gen) body_gen)
return result
else: else:
assert False assert False
def polymorphic_compare_pair_order(self, op, lhs, rhs): def polymorphic_compare_pair_order(self, op, lhs, rhs):
if builtins.is_numeric(lhs.type) and builtins.is_numeric(rhs.type): if builtins.is_numeric(lhs.type) and builtins.is_numeric(rhs.type):
return self.append(ir.Compare(op, lhs, rhs)) return self.append(ir.Compare(op, lhs, rhs))
elif builtins.is_bool(lhs.type) and builtins.is_bool(rhs.type):
return self.append(ir.Compare(op, lhs, rhs))
elif types.is_tuple(lhs.type) and types.is_tuple(rhs.type): elif types.is_tuple(lhs.type) and types.is_tuple(rhs.type):
result = None result = None
for index in range(len(lhs.type.elts)): for index in range(len(lhs.type.elts)):
@ -959,8 +974,8 @@ class ARTIQIRGenerator(algorithm.Visitor):
return result return result
elif builtins.is_list(lhs.type) and builtins.is_list(rhs.type): elif builtins.is_list(lhs.type) and builtins.is_list(rhs.type):
head = self.current_block head = self.current_block
lhs_length = self.append(ir.Builtin("len", [lhs], self._size_type)) lhs_length = self.iterable_len(lhs)
rhs_length = self.append(ir.Builtin("len", [rhs], self._size_type)) rhs_length = self.iterable_len(rhs)
compare_length = self.append(ir.Compare(op, lhs_length, rhs_length)) compare_length = self.append(ir.Compare(op, lhs_length, rhs_length))
eq_length = self.append(ir.Compare(ast.Eq(loc=None), lhs_length, rhs_length)) eq_length = self.append(ir.Compare(ast.Eq(loc=None), lhs_length, rhs_length))
@ -1056,24 +1071,10 @@ class ARTIQIRGenerator(algorithm.Visitor):
return result return result
def polymorphic_compare_pair_identity(self, op, lhs, rhs):
if builtins.is_allocated(lhs) and builtins.is_allocated(rhs):
# These are actually pointers, compare directly.
return self.append(ir.Compare(op, lhs, rhs))
else:
# Compare by value instead, our backend cannot handle
# equality of aggregates.
if isinstance(op, ast.Is):
op = ast.Eq(loc=None)
elif isinstance(op, ast.IsNot):
op = ast.NotEq(loc=None)
else:
assert False
return self.polymorphic_compare_pair_order(op, lhs, rhs)
def polymorphic_compare_pair(self, op, lhs, rhs): def polymorphic_compare_pair(self, op, lhs, rhs):
if isinstance(op, (ast.Is, ast.IsNot)): if isinstance(op, (ast.Is, ast.IsNot)):
return self.polymorphic_compare_pair_identity(op, lhs, rhs) # The backend will handle equality of aggregates.
return self.append(ir.Compare(op, lhs, rhs))
elif isinstance(op, (ast.In, ast.NotIn)): elif isinstance(op, (ast.In, ast.NotIn)):
return self.polymorphic_compare_pair_inclusion(op, lhs, rhs) return self.polymorphic_compare_pair_inclusion(op, lhs, rhs)
else: # Eq, NotEq, Lt, LtE, Gt, GtE else: # Eq, NotEq, Lt, LtE, Gt, GtE
@ -1086,21 +1087,25 @@ class ARTIQIRGenerator(algorithm.Visitor):
lhs = self.visit(node.left) lhs = self.visit(node.left)
self.instrument_assert(node.left, lhs) self.instrument_assert(node.left, lhs)
for op, rhs_node in zip(node.ops, node.comparators): for op, rhs_node in zip(node.ops, node.comparators):
result_head = self.current_block
rhs = self.visit(rhs_node) rhs = self.visit(rhs_node)
self.instrument_assert(rhs_node, rhs) self.instrument_assert(rhs_node, rhs)
result = self.polymorphic_compare_pair(op, lhs, rhs) result = self.polymorphic_compare_pair(op, lhs, rhs)
blocks.append((result, self.current_block)) result_tail = self.current_block
blocks.append((result, result_head, result_tail))
self.current_block = self.add_block() self.current_block = self.add_block()
lhs = rhs lhs = rhs
tail = self.current_block tail = self.current_block
phi = self.append(ir.Phi(node.type)) phi = self.append(ir.Phi(node.type))
for ((value, block), next_block) in zip(blocks, [b for (v,b) in blocks[1:]] + [tail]): for ((result, result_head, result_tail), (next_result_head, next_result_tail)) in \
phi.add_incoming(value, block) zip(blocks, [(h,t) for (v,h,t) in blocks[1:]] + [(tail, tail)]):
if next_block != tail: phi.add_incoming(result, result_tail)
block.append(ir.BranchIf(value, next_block, tail)) if next_result_head != tail:
result_tail.append(ir.BranchIf(result, next_result_head, tail))
else: else:
block.append(ir.Branch(tail)) result_tail.append(ir.Branch(tail))
return phi return phi
def visit_builtin_call(self, node): def visit_builtin_call(self, node):
@ -1138,7 +1143,7 @@ class ARTIQIRGenerator(algorithm.Visitor):
elif types.is_builtin(typ, "list"): elif types.is_builtin(typ, "list"):
if len(node.args) == 0 and len(node.keywords) == 0: if len(node.args) == 0 and len(node.keywords) == 0:
length = ir.Constant(0, builtins.TInt(types.TValue(32))) length = ir.Constant(0, builtins.TInt(types.TValue(32)))
return self.append(ir.Alloc(node.type, length)) return self.append(ir.Alloc([length], node.type))
elif len(node.args) == 1 and len(node.keywords) == 0: elif len(node.args) == 1 and len(node.keywords) == 0:
arg = self.visit(node.args[0]) arg = self.visit(node.args[0])
length = self.iterable_len(arg) length = self.iterable_len(arg)
@ -1157,7 +1162,7 @@ class ARTIQIRGenerator(algorithm.Visitor):
else: else:
assert False assert False
elif types.is_builtin(typ, "range"): elif types.is_builtin(typ, "range"):
elt_typ = builtins.getiterable_elt(node.type) elt_typ = builtins.get_iterable_elt(node.type)
if len(node.args) == 1 and len(node.keywords) == 0: if len(node.args) == 1 and len(node.keywords) == 0:
max_arg = self.visit(node.args[0]) max_arg = self.visit(node.args[0])
return self.append(ir.Alloc([ return self.append(ir.Alloc([
@ -1193,7 +1198,7 @@ class ARTIQIRGenerator(algorithm.Visitor):
elif types.is_builtin(typ, "round"): elif types.is_builtin(typ, "round"):
if len(node.args) == 1 and len(node.keywords) == 0: if len(node.args) == 1 and len(node.keywords) == 0:
arg = self.visit(node.args[0]) arg = self.visit(node.args[0])
return self.append(ir.Builtin("round", [arg])) return self.append(ir.Builtin("round", [arg], node.type))
else: else:
assert False assert False
elif types.is_builtin(typ, "print"): elif types.is_builtin(typ, "print"):
@ -1241,7 +1246,7 @@ class ARTIQIRGenerator(algorithm.Visitor):
for (subexpr, name) in self.current_assert_subexprs]): for (subexpr, name) in self.current_assert_subexprs]):
return # don't display the same subexpression twice return # don't display the same subexpression twice
name = self.current_assert_env.type.add("subexpr", ir.TOption(node.type)) name = self.current_assert_env.type.add(".subexpr", ir.TOption(node.type))
value_opt = self.append(ir.Alloc([value], ir.TOption(node.type)), value_opt = self.append(ir.Alloc([value], ir.TOption(node.type)),
loc=node.loc) loc=node.loc)
self.append(ir.SetLocal(self.current_assert_env, name, value_opt), self.append(ir.SetLocal(self.current_assert_env, name, value_opt),
@ -1325,7 +1330,10 @@ class ARTIQIRGenerator(algorithm.Visitor):
self.polymorphic_print([self.append(ir.GetAttr(value, index)) self.polymorphic_print([self.append(ir.GetAttr(value, index))
for index in range(len(value.type.elts))], for index in range(len(value.type.elts))],
separator=", ") separator=", ")
format_string += ")" if len(value.type.elts) == 1:
format_string += ",)"
else:
format_string += ")"
elif types.is_function(value.type): elif types.is_function(value.type):
format_string += "<closure %p(%p)>" format_string += "<closure %p(%p)>"
# We're relying on the internal layout of the closure here. # We're relying on the internal layout of the closure here.
@ -1341,7 +1349,7 @@ class ARTIQIRGenerator(algorithm.Visitor):
elif builtins.is_int(value.type): elif builtins.is_int(value.type):
width = builtins.get_int_width(value.type) width = builtins.get_int_width(value.type)
if width <= 32: if width <= 32:
format_string += "%ld" format_string += "%d"
elif width <= 64: elif width <= 64:
format_string += "%lld" format_string += "%lld"
else: else:

View File

@ -300,7 +300,7 @@ class ASTTypedRewriter(algorithm.Transformer):
node = self.generic_visit(node) node = self.generic_visit(node)
node = asttyped.SubscriptT(type=types.TVar(), node = asttyped.SubscriptT(type=types.TVar(),
value=node.value, slice=node.slice, ctx=node.ctx, value=node.value, slice=node.slice, ctx=node.ctx,
loc=node.loc) begin_loc=node.begin_loc, end_loc=node.end_loc, loc=node.loc)
return self.visit(node) return self.visit(node)
def visit_BoolOp(self, node): def visit_BoolOp(self, node):

View File

@ -109,6 +109,7 @@ class Inferencer(algorithm.Visitor):
self.engine.process(diag) self.engine.process(diag)
def visit_Index(self, node): def visit_Index(self, node):
self.generic_visit(node)
value = node.value value = node.value
if types.is_tuple(value.type): if types.is_tuple(value.type):
diag = diagnostic.Diagnostic("error", diag = diagnostic.Diagnostic("error",
@ -342,7 +343,8 @@ class Inferencer(algorithm.Visitor):
return self._coerce_numeric((left, right), lambda typ: (typ, typ, typ)) return self._coerce_numeric((left, right), lambda typ: (typ, typ, typ))
elif isinstance(op, ast.Div): elif isinstance(op, ast.Div):
# division always returns a float # division always returns a float
return self._coerce_numeric((left, right), lambda typ: (builtins.TFloat(), typ, typ)) return self._coerce_numeric((left, right),
lambda typ: (builtins.TFloat(), builtins.TFloat(), builtins.TFloat()))
else: # MatMult else: # MatMult
diag = diagnostic.Diagnostic("error", diag = diagnostic.Diagnostic("error",
"operator '{op}' is not supported", {"op": op.loc.source()}, "operator '{op}' is not supported", {"op": op.loc.source()},
@ -377,7 +379,7 @@ class Inferencer(algorithm.Visitor):
for left, right in pairs: for left, right in pairs:
self._unify(left.type, right.type, self._unify(left.type, right.type,
left.loc, right.loc) left.loc, right.loc)
else: elif any(map(builtins.is_numeric, operand_types)):
typ = self._coerce_numeric(operands) typ = self._coerce_numeric(operands)
if typ: if typ:
try: try:
@ -393,6 +395,8 @@ class Inferencer(algorithm.Visitor):
other_node = next(filter(wide_enough, operands)) other_node = next(filter(wide_enough, operands))
node.left, *node.comparators = \ node.left, *node.comparators = \
[self._coerce_one(typ, operand, other_node) for operand in operands] [self._coerce_one(typ, operand, other_node) for operand in operands]
else:
pass # No coercion required.
self._unify(node.type, builtins.TBool(), self._unify(node.type, builtins.TBool(),
node.loc, None) node.loc, None)

View File

@ -99,12 +99,16 @@ class LLVMIRGenerator:
llty = ll.FunctionType(ll.VoidType(), []) llty = ll.FunctionType(ll.VoidType(), [])
elif name in "llvm.trap": elif name in "llvm.trap":
llty = ll.FunctionType(ll.VoidType(), []) llty = ll.FunctionType(ll.VoidType(), [])
elif name == "llvm.floor.f64":
llty = ll.FunctionType(ll.DoubleType(), [ll.DoubleType()])
elif name == "llvm.round.f64": elif name == "llvm.round.f64":
llty = ll.FunctionType(ll.DoubleType(), [ll.DoubleType()]) llty = ll.FunctionType(ll.DoubleType(), [ll.DoubleType()])
elif name == "llvm.pow.f64": elif name == "llvm.pow.f64":
llty = ll.FunctionType(ll.DoubleType(), [ll.DoubleType(), ll.DoubleType()]) llty = ll.FunctionType(ll.DoubleType(), [ll.DoubleType(), ll.DoubleType()])
elif name == "llvm.powi.f64": elif name == "llvm.powi.f64":
llty = ll.FunctionType(ll.DoubleType(), [ll.DoubleType(), ll.IntType(32)]) llty = ll.FunctionType(ll.DoubleType(), [ll.DoubleType(), ll.IntType(32)])
elif name == "llvm.copysign.f64":
llty = ll.FunctionType(ll.DoubleType(), [ll.DoubleType(), ll.DoubleType()])
elif name == "printf": elif name == "printf":
llty = ll.FunctionType(ll.VoidType(), [ll.IntType(8).as_pointer()], var_arg=True) llty = ll.FunctionType(ll.VoidType(), [ll.IntType(8).as_pointer()], var_arg=True)
else: else:
@ -331,27 +335,36 @@ class LLVMIRGenerator:
elif isinstance(insn.op, ast.FloorDiv): elif isinstance(insn.op, ast.FloorDiv):
if builtins.is_float(insn.type): if builtins.is_float(insn.type):
llvalue = self.llbuilder.fdiv(self.map(insn.lhs()), self.map(insn.rhs())) llvalue = self.llbuilder.fdiv(self.map(insn.lhs()), self.map(insn.rhs()))
return self.llbuilder.call(self.llbuiltin("llvm.round.f64"), [llvalue], return self.llbuilder.call(self.llbuiltin("llvm.floor.f64"), [llvalue],
name=insn.name) name=insn.name)
else: else:
return self.llbuilder.sdiv(self.map(insn.lhs()), self.map(insn.rhs()), return self.llbuilder.sdiv(self.map(insn.lhs()), self.map(insn.rhs()),
name=insn.name) name=insn.name)
elif isinstance(insn.op, ast.Mod): elif isinstance(insn.op, ast.Mod):
# Python only has the modulo operator, LLVM only has the remainder
if builtins.is_float(insn.type): if builtins.is_float(insn.type):
return self.llbuilder.frem(self.map(insn.lhs()), self.map(insn.rhs()), llvalue = self.llbuilder.frem(self.map(insn.lhs()), self.map(insn.rhs()))
return self.llbuilder.call(self.llbuiltin("llvm.copysign.f64"),
[llvalue, self.map(insn.rhs())],
name=insn.name) name=insn.name)
else: else:
return self.llbuilder.srem(self.map(insn.lhs()), self.map(insn.rhs()), lllhs, llrhs = map(self.map, (insn.lhs(), insn.rhs()))
name=insn.name) llxorsign = self.llbuilder.and_(self.llbuilder.xor(lllhs, llrhs),
ll.Constant(lllhs.type, 1 << lllhs.type.width - 1))
llnegate = self.llbuilder.icmp_unsigned('!=',
llxorsign, ll.Constant(llxorsign.type, 0))
llvalue = self.llbuilder.srem(lllhs, llrhs)
llnegvalue = self.llbuilder.sub(ll.Constant(llvalue.type, 0), llvalue)
return self.llbuilder.select(llnegate, llnegvalue, llvalue)
elif isinstance(insn.op, ast.Pow): elif isinstance(insn.op, ast.Pow):
if builtins.is_float(insn.type): if builtins.is_float(insn.type):
return self.llbuilder.call(self.llbuiltin("llvm.pow.f64"), return self.llbuilder.call(self.llbuiltin("llvm.pow.f64"),
[self.map(insn.lhs()), self.map(insn.rhs())], [self.map(insn.lhs()), self.map(insn.rhs())],
name=insn.name) name=insn.name)
else: else:
lllhs = self.llbuilder.sitofp(self.map(insn.lhs()), ll.DoubleType())
llrhs = self.llbuilder.trunc(self.map(insn.rhs()), ll.IntType(32)) llrhs = self.llbuilder.trunc(self.map(insn.rhs()), ll.IntType(32))
llvalue = self.llbuilder.call(self.llbuiltin("llvm.powi.f64"), llvalue = self.llbuilder.call(self.llbuiltin("llvm.powi.f64"), [lllhs, llrhs])
[self.map(insn.lhs()), llrhs])
return self.llbuilder.fptosi(llvalue, self.llty_of_type(insn.type), return self.llbuilder.fptosi(llvalue, self.llty_of_type(insn.type),
name=insn.name) name=insn.name)
elif isinstance(insn.op, ast.LShift): elif isinstance(insn.op, ast.LShift):
@ -373,9 +386,9 @@ class LLVMIRGenerator:
assert False assert False
def process_Compare(self, insn): def process_Compare(self, insn):
if isinstance(insn.op, ast.Eq): if isinstance(insn.op, (ast.Eq, ast.Is)):
op = '==' op = '=='
elif isinstance(insn.op, ast.NotEq): elif isinstance(insn.op, (ast.NotEq, ast.IsNot)):
op = '!=' op = '!='
elif isinstance(insn.op, ast.Gt): elif isinstance(insn.op, ast.Gt):
op = '>' op = '>'
@ -388,12 +401,32 @@ class LLVMIRGenerator:
else: else:
assert False assert False
if builtins.is_float(insn.lhs().type): lllhs, llrhs = map(self.map, (insn.lhs(), insn.rhs()))
return self.llbuilder.fcmp_ordered(op, self.map(insn.lhs()), self.map(insn.rhs()), assert lllhs.type == llrhs.type
if isinstance(lllhs.type, ll.IntType):
return self.llbuilder.icmp_signed(op, lllhs, llrhs,
name=insn.name)
elif isinstance(lllhs.type, ll.PointerType):
return self.llbuilder.icmp_unsigned(op, lllhs, llrhs,
name=insn.name)
elif isinstance(lllhs.type, (ll.FloatType, ll.DoubleType)):
return self.llbuilder.fcmp_ordered(op, lllhs, llrhs,
name=insn.name) name=insn.name)
elif isinstance(lllhs.type, ll.LiteralStructType):
# Compare aggregates (such as lists or ranges) element-by-element.
llvalue = ll.Constant(ll.IntType(1), True)
for index in range(len(lllhs.type.elements)):
lllhselt = self.llbuilder.extract_value(lllhs, index)
llrhselt = self.llbuilder.extract_value(llrhs, index)
llresult = self.llbuilder.icmp_unsigned('==', lllhselt, llrhselt)
llvalue = self.llbuilder.select(llresult, llvalue,
ll.Constant(ll.IntType(1), False))
return self.llbuilder.icmp_unsigned(op, llvalue, ll.Constant(ll.IntType(1), True),
name=insn.name)
else: else:
return self.llbuilder.icmp_signed(op, self.map(insn.lhs()), self.map(insn.rhs()), print(lllhs, llrhs)
name=insn.name) assert False
def process_Builtin(self, insn): def process_Builtin(self, insn):
if insn.op == "nop": if insn.op == "nop":
@ -401,22 +434,24 @@ class LLVMIRGenerator:
if insn.op == "abort": if insn.op == "abort":
return self.llbuilder.call(self.llbuiltin("llvm.trap"), []) return self.llbuilder.call(self.llbuiltin("llvm.trap"), [])
elif insn.op == "is_some": elif insn.op == "is_some":
optarg = self.map(insn.operands[0]) lloptarg = self.map(insn.operands[0])
return self.llbuilder.extract_value(optarg, 0, return self.llbuilder.extract_value(lloptarg, 0,
name=insn.name) name=insn.name)
elif insn.op == "unwrap": elif insn.op == "unwrap":
optarg = self.map(insn.operands[0]) lloptarg = self.map(insn.operands[0])
return self.llbuilder.extract_value(optarg, 1, return self.llbuilder.extract_value(lloptarg, 1,
name=insn.name) name=insn.name)
elif insn.op == "unwrap_or": elif insn.op == "unwrap_or":
optarg, default = map(self.map, insn.operands) lloptarg, lldefault = map(self.map, insn.operands)
has_arg = self.llbuilder.extract_value(optarg, 0) llhas_arg = self.llbuilder.extract_value(lloptarg, 0)
arg = self.llbuilder.extract_value(optarg, 1) llarg = self.llbuilder.extract_value(lloptarg, 1)
return self.llbuilder.select(has_arg, arg, default, return self.llbuilder.select(llhas_arg, llarg, lldefault,
name=insn.name) name=insn.name)
elif insn.op == "round": elif insn.op == "round":
return self.llbuilder.call(self.llbuiltin("llvm.round.f64"), [llvalue], llarg = self.map(insn.operands[0])
name=insn.name) llvalue = self.llbuilder.call(self.llbuiltin("llvm.round.f64"), [llarg])
return self.llbuilder.fptosi(llvalue, self.llty_of_type(insn.type),
name=insn.name)
elif insn.op == "globalenv": elif insn.op == "globalenv":
def get_outer(llenv, env_ty): def get_outer(llenv, env_ty):
if ".outer" in env_ty.params: if ".outer" in env_ty.params:

View File

@ -13,14 +13,14 @@ def genalnum():
pos = len(ident) - 1 pos = len(ident) - 1
while pos >= 0: while pos >= 0:
cur_n = string.ascii_lowercase.index(ident[pos]) cur_n = string.ascii_lowercase.index(ident[pos])
if cur_n < 26: if cur_n < 25:
ident[pos] = string.ascii_lowercase[cur_n + 1] ident[pos] = string.ascii_lowercase[cur_n + 1]
break break
else: else:
ident[pos] = "a" ident[pos] = "a"
pos -= 1 pos -= 1
if pos < 0: if pos < 0:
ident = "a" + ident ident = ["a"] + ident
class UnificationError(Exception): class UnificationError(Exception):
def __init__(self, typea, typeb): def __init__(self, typea, typeb):

View File

@ -77,12 +77,15 @@ class RegionOf(algorithm.Visitor):
# Value lives as long as the current scope, if it's mutable, # Value lives as long as the current scope, if it's mutable,
# or else forever # or else forever
def visit_BinOpT(self, node): def visit_sometimes_allocating(self, node):
if builtins.is_allocated(node.type): if builtins.is_allocated(node.type):
return self.youngest_region return self.youngest_region
else: else:
return None return None
visit_BinOpT = visit_sometimes_allocating
visit_CallT = visit_sometimes_allocating
# Value lives as long as the object/container, if it's mutable, # Value lives as long as the object/container, if it's mutable,
# or else forever # or else forever
def visit_accessor(self, node): def visit_accessor(self, node):
@ -136,7 +139,6 @@ class RegionOf(algorithm.Visitor):
visit_EllipsisT = visit_immutable visit_EllipsisT = visit_immutable
visit_UnaryOpT = visit_immutable visit_UnaryOpT = visit_immutable
visit_CompareT = visit_immutable visit_CompareT = visit_immutable
visit_CallT = visit_immutable
# Value is mutable, but still lives forever # Value is mutable, but still lives forever
def visit_StrT(self, node): def visit_StrT(self, node):

View File

@ -55,8 +55,8 @@ class LocalAccessValidator:
# in order to be initialized in this block. # in order to be initialized in this block.
def merge_state(a, b): def merge_state(a, b):
return {var: a[var] and b[var] for var in a} return {var: a[var] and b[var] for var in a}
block_state[env] = reduce(lambda a, b: merge_state(a[env], b[env]), block_state[env] = reduce(merge_state,
pred_states) [state[env] for state in pred_states])
elif len(pred_states) == 1: elif len(pred_states) == 1:
# The state is the same as at the terminator of predecessor. # The state is the same as at the terminator of predecessor.
# We'll mutate it, so copy. # We'll mutate it, so copy.

View File

@ -0,0 +1,6 @@
# RUN: %python -m artiq.compiler.testbench.jit %s
r = range(10)
assert r.start == 0
assert r.stop == 10
assert r.step == 1

View File

@ -0,0 +1,24 @@
# RUN: %python -m artiq.compiler.testbench.jit %s
assert bool() is False
# bool(x) is tested in bool.py
assert int() is 0
assert int(1.0) is 1
assert int(1, width=64) << 40 is 1099511627776
assert float() is 0.0
assert float(1) is 1.0
x = list()
if False: x = [1]
assert x == []
assert range(10) is range(0, 10, 1)
assert range(1, 10) is range(1, 10, 1)
assert len([1, 2, 3]) is 3
assert len(range(10)) is 10
assert len(range(0, 10, 2)) is 5
assert round(1.4) is 1 and round(1.6) is 2

View File

@ -0,0 +1,18 @@
# RUN: %python -m artiq.compiler.testbench.jit %s
assert 1 < 2 and not (2 < 1)
assert 2 > 1 and not (1 > 2)
assert 1 == 1 and not (1 == 2)
assert 1 != 2 and not (1 != 1)
assert 1 <= 1 and 1 <= 2 and not (2 <= 1)
assert 1 >= 1 and 2 >= 1 and not (1 >= 2)
assert 1 is 1 and not (1 is 2)
assert 1 is not 2 and not (1 is not 1)
x, y = [1], [1]
assert x is x and x is not y
assert range(10) is range(10) and range(10) is not range(11)
lst = [1, 2, 3]
assert 1 in lst and 0 not in lst
assert 1 in range(10) and 11 not in range(10) and -1 not in range(10)

View File

@ -0,0 +1,28 @@
# RUN: %python -m artiq.compiler.testbench.jit %s
count = 0
for x in range(10):
count += 1
assert count == 10
for x in range(10):
assert True
else:
assert True
for x in range(0):
assert False
else:
assert True
for x in range(10):
continue
assert False
else:
assert True
for x in range(10):
break
assert False
else:
assert False

View File

@ -0,0 +1,20 @@
# RUN: %python -m artiq.compiler.testbench.jit %s
if True:
assert True
if False:
assert False
if True:
assert True
else:
assert False
if False:
assert False
else:
assert True
assert (0 if True else 1) == 0
assert (0 if False else 1) == 1

View File

@ -0,0 +1,7 @@
# RUN: %python -m artiq.compiler.testbench.jit %s
[x, y] = [1, 2]
assert (x, y) == (1, 2)
lst = [1, 2, 3]
assert [x*x for x in lst] == [1, 4, 9]

View File

@ -0,0 +1,6 @@
# RUN: %python -m artiq.compiler.testbench.jit %s
x = 1
assert x == 1
x += 1
assert x == 2

View File

@ -0,0 +1,38 @@
# RUN: %python -m artiq.compiler.testbench.jit %s
assert (not True) == False
assert (not False) == True
assert -(-1) == 1
assert -(-1.0) == 1.0
assert +1 == 1
assert +1.0 == 1.0
assert 1 + 1 == 2
assert 1.0 + 1.0 == 2.0
assert 1 - 1 == 0
assert 1.0 - 1.0 == 0.0
assert 2 * 2 == 4
assert 2.0 * 2.0 == 4.0
assert 3 / 2 == 1.5
assert 3.0 / 2.0 == 1.5
assert 3 // 2 == 1
assert 3.0 // 2.0 == 1.0
assert 3 % 2 == 1
assert -3 % 2 == 1
assert 3 % -2 == -1
assert -3 % -2 == -1
assert 3.0 % 2.0 == 1.0
assert -3.0 % 2.0 == 1.0
assert 3.0 % -2.0 == -1.0
assert -3.0 % -2.0 == -1.0
assert 3 ** 2 == 9
assert 3.0 ** 2.0 == 9.0
assert 9.0 ** 0.5 == 3.0
assert 1 << 1 == 2
assert 2 >> 1 == 1
assert -2 >> 1 == -1
assert 0x18 & 0x0f == 0x08
assert 0x18 | 0x0f == 0x1f
assert 0x18 ^ 0x0f == 0x17
assert [1] + [2] == [1, 2]
assert [1] * 3 == [1, 1, 1]

View File

@ -0,0 +1,32 @@
# RUN: %python -m artiq.compiler.testbench.jit %s >%t
# RUN: OutputCheck %s --file-to-check=%t
# CHECK-L: None
print(None)
# CHECK-L: True False
print(True, False)
# CHECK-L: 1 -1
print(1, -1)
# CHECK-L: 10000000000
print(10000000000)
# CHECK-L: 1.5
print(1.5)
# CHECK-L: (True, 1)
print((True, 1))
# CHECK-L: (True,)
print((True,))
# CHECK-L: [1, 2, 3]
print([1, 2, 3])
# CHECK-L: [[1, 2], [3]]
print([[1, 2], [3]])
# CHECK-L: range(0, 10, 1)
print(range(10))

View File

@ -0,0 +1,12 @@
# RUN: %python -m artiq.compiler.testbench.jit %s
lst = list(range(10))
assert lst[0] == 0
assert lst[1] == 1
assert lst[-1] == 9
assert lst[0:1] == [0]
assert lst[0:2] == [0, 1]
assert lst[0:10] == lst
assert lst[1:-1] == lst[1:9]
assert lst[0:1:2] == [0]
assert lst[0:2:2] == [0]

View File

@ -0,0 +1,6 @@
# RUN: %python -m artiq.compiler.testbench.jit %s
x, y = 2, 1
x, y = y, x
assert x == 1 and y == 2
assert (1, 2) + (3.0,) == (1, 2, 3.0)

View File

@ -0,0 +1,30 @@
# RUN: %python -m artiq.compiler.testbench.jit %s
cond, count = True, 0
while cond:
count += 1
cond = False
assert count == 1
while False:
pass
else:
assert True
cond = True
while cond:
cond = False
else:
assert True
while True:
break
assert False
else:
assert False
cond = True
while cond:
cond = False
continue
assert False