From eec52a2e291a77eba54dfb2a49223ef0cdd627ae Mon Sep 17 00:00:00 2001 From: Sebastien Bourdeauducq Date: Tue, 9 Sep 2014 17:13:48 +0800 Subject: [PATCH] py2llvm: array support --- artiq/py2llvm/arrays.py | 70 ++++++++++++++++++++++ artiq/py2llvm/ast_body.py | 112 ++++++++++++++++++++--------------- artiq/py2llvm/base_types.py | 40 ++++++------- artiq/py2llvm/fractions.py | 34 +++++------ artiq/py2llvm/infer_types.py | 14 +++++ artiq/py2llvm/module.py | 4 +- artiq/py2llvm/values.py | 30 ++++++---- test/py2llvm.py | 60 ++++++++++++++++--- 8 files changed, 256 insertions(+), 108 deletions(-) create mode 100644 artiq/py2llvm/arrays.py diff --git a/artiq/py2llvm/arrays.py b/artiq/py2llvm/arrays.py new file mode 100644 index 000000000..c53263ae1 --- /dev/null +++ b/artiq/py2llvm/arrays.py @@ -0,0 +1,70 @@ +from llvm import core as lc + +from artiq.py2llvm.values import VGeneric +from artiq.py2llvm.base_types import VInt + + +class VArray(VGeneric): + def __init__(self, el_init, count): + VGeneric.__init__(self) + self.el_init = el_init + self.count = count + if not count: + raise TypeError("Arrays must have at least one element") + + def get_llvm_type(self): + return lc.Type.array(self.el_init.get_llvm_type(), self.count) + + def __repr__(self): + return "".format(repr(self.el_init), self.count) + + def same_type(self, other): + return ( + isinstance(other, VArray) + and self.el_init.same_type(other.el_init) + and self.count == other.count) + + def merge(self, other): + if isinstance(other, VArray): + self.el_init.merge(other.el_init) + else: + raise TypeError("Incompatible types: {} and {}" + .format(repr(self), repr(other))) + + def merge_subscript(self, other): + self.el_init.merge(other) + + def set_value(self, builder, v): + if not isinstance(v, VArray): + raise TypeError + if v.llvm_value is not None: + raise NotImplementedError("Array aliasing is not supported") + + i = VInt() + i.alloca(builder, "ai_i") + i.auto_store(builder, lc.Constant.int(lc.Type.int(), 0)) + + function = builder.basic_block.function + copy_block = function.append_basic_block("ai_copy") + end_block = function.append_basic_block("ai_end") + builder.branch(copy_block) + + builder.position_at_end(copy_block) + self.o_subscript(i, builder).set_value(builder, v.el_init) + i.auto_store(builder, builder.add( + i.auto_load(builder), lc.Constant.int(lc.Type.int(), 1))) + cont = builder.icmp( + lc.ICMP_SLT, i.auto_load(builder), + lc.Constant.int(lc.Type.int(), self.count)) + builder.cbranch(cont, copy_block, end_block) + + builder.position_at_end(end_block) + + def o_subscript(self, index, builder): + r = self.el_init.new() + if builder is not None: + index = index.o_int(builder).auto_load(builder) + ssa_r = builder.gep(self.llvm_value, [ + lc.Constant.int(lc.Type.int(), 0), index]) + r.auto_store(builder, ssa_r) + return r diff --git a/artiq/py2llvm/ast_body.py b/artiq/py2llvm/ast_body.py index 889ae5c47..6afdd8a45 100644 --- a/artiq/py2llvm/ast_body.py +++ b/artiq/py2llvm/ast_body.py @@ -1,9 +1,41 @@ import ast -from artiq.py2llvm import values, base_types, fractions +from artiq.py2llvm import values, base_types, fractions, arrays from artiq.py2llvm.tools import is_terminated +_ast_unops = { + ast.Invert: "o_inv", + ast.Not: "o_not", + ast.UAdd: "o_pos", + ast.USub: "o_neg" +} + +_ast_binops = { + ast.Add: values.operators.add, + ast.Sub: values.operators.sub, + ast.Mult: values.operators.mul, + ast.Div: values.operators.truediv, + ast.FloorDiv: values.operators.floordiv, + ast.Mod: values.operators.mod, + ast.Pow: values.operators.pow, + ast.LShift: values.operators.lshift, + ast.RShift: values.operators.rshift, + ast.BitOr: values.operators.or_, + ast.BitXor: values.operators.xor, + ast.BitAnd: values.operators.and_ +} + +_ast_cmps = { + ast.Eq: values.operators.eq, + ast.NotEq: values.operators.ne, + ast.Lt: values.operators.lt, + ast.LtE: values.operators.le, + ast.Gt: values.operators.gt, + ast.GtE: values.operators.ge +} + + class Visitor: def __init__(self, env, ns, builder=None): self.env = env @@ -53,48 +85,20 @@ class Visitor: return r def _visit_expr_UnaryOp(self, node): - ast_unops = { - ast.Invert: "o_inv", - ast.Not: "o_not", - ast.UAdd: "o_pos", - ast.USub: "o_neg" - } value = self.visit_expression(node.operand) - return getattr(value, ast_unops[type(node.op)])(self.builder) + return getattr(value, _ast_unops[type(node.op)])(self.builder) def _visit_expr_BinOp(self, node): - ast_binops = { - ast.Add: values.operators.add, - ast.Sub: values.operators.sub, - ast.Mult: values.operators.mul, - ast.Div: values.operators.truediv, - ast.FloorDiv: values.operators.floordiv, - ast.Mod: values.operators.mod, - ast.Pow: values.operators.pow, - ast.LShift: values.operators.lshift, - ast.RShift: values.operators.rshift, - ast.BitOr: values.operators.or_, - ast.BitXor: values.operators.xor, - ast.BitAnd: values.operators.and_ - } - return ast_binops[type(node.op)](self.visit_expression(node.left), - self.visit_expression(node.right), - self.builder) + return _ast_binops[type(node.op)](self.visit_expression(node.left), + self.visit_expression(node.right), + self.builder) def _visit_expr_Compare(self, node): - ast_cmps = { - ast.Eq: values.operators.eq, - ast.NotEq: values.operators.ne, - ast.Lt: values.operators.lt, - ast.LtE: values.operators.le, - ast.Gt: values.operators.gt, - ast.GtE: values.operators.ge - } comparisons = [] old_comparator = self.visit_expression(node.left) for op, comparator_a in zip(node.ops, node.comparators): comparator = self.visit_expression(comparator_a) - comparison = ast_cmps[type(op)](old_comparator, comparator, + comparison = _ast_cmps[type(op)](old_comparator, comparator, self.builder) comparisons.append(comparison) old_comparator = comparator @@ -115,6 +119,14 @@ class Visitor: denominator = self.visit_expression(node.args[1]) r.set_value_nd(self.builder, numerator, denominator) return r + elif fn == "array": + element = self.visit_expression(node.args[0]) + if (isinstance(node.args[1], ast.Num) + and isinstance(node.args[1].n, int)): + count = node.args[1].n + else: + raise ValueError("Array size must be integer and constant") + return arrays.VArray(element, count) elif fn == "syscall": return self.env.syscall( node.args[0].s, @@ -127,6 +139,14 @@ class Visitor: value = self.visit_expression(node.value) return value.o_getattr(node.attr, self.builder) + def _visit_expr_Subscript(self, node): + value = self.visit_expression(node.value) + if isinstance(node.slice, ast.Index): + index = self.visit_expression(node.slice.value) + else: + raise NotImplementedError + return value.o_subscript(index, self.builder) + def visit_statements(self, stmts): for node in stmts: node_type = node.__class__.__name__ @@ -143,18 +163,14 @@ class Visitor: def _visit_stmt_Assign(self, node): val = self.visit_expression(node.value) for target in node.targets: - if isinstance(target, ast.Name): - self.ns[target.id].set_value(self.builder, val) - else: - raise NotImplementedError + target = self.visit_expression(target) + target.set_value(self.builder, val) def _visit_stmt_AugAssign(self, node): - val = self.visit_expression(ast.BinOp(op=node.op, left=node.target, - right=node.value)) - if isinstance(node.target, ast.Name): - self.ns[node.target.id].set_value(self.builder, val) - else: - raise NotImplementedError + target = self.visit_expression(node.target) + right = self.visit_expression(node.value) + val = _ast_binops[type(node.op)](target, right, self.builder) + target.set_value(self.builder, val) def _visit_stmt_Expr(self, node): self.visit_expression(node.value) @@ -166,7 +182,7 @@ class Visitor: merge_block = function.append_basic_block("i_merge") condition = self.visit_expression(node.test).o_bool(self.builder) - self.builder.cbranch(condition.get_ssa_value(self.builder), + self.builder.cbranch(condition.auto_load(self.builder), then_block, else_block) self.builder.position_at_end(then_block) @@ -189,14 +205,14 @@ class Visitor: condition = self.visit_expression(node.test).o_bool(self.builder) self.builder.cbranch( - condition.get_ssa_value(self.builder), body_block, else_block) + condition.auto_load(self.builder), body_block, else_block) self.builder.position_at_end(body_block) self.visit_statements(node.body) if not is_terminated(self.builder.basic_block): condition = self.visit_expression(node.test).o_bool(self.builder) self.builder.cbranch( - condition.get_ssa_value(self.builder), body_block, merge_block) + condition.auto_load(self.builder), body_block, merge_block) self.builder.position_at_end(else_block) self.visit_statements(node.orelse) @@ -213,4 +229,4 @@ class Visitor: if isinstance(val, base_types.VNone): self.builder.ret_void() else: - self.builder.ret(val.get_ssa_value(self.builder)) + self.builder.ret(val.auto_load(self.builder)) diff --git a/artiq/py2llvm/base_types.py b/artiq/py2llvm/base_types.py index f34646cf6..9e9a0aec5 100644 --- a/artiq/py2llvm/base_types.py +++ b/artiq/py2llvm/base_types.py @@ -46,19 +46,19 @@ class VInt(VGeneric): .format(repr(self), repr(other))) def set_value(self, builder, n): - self.set_ssa_value( - builder, n.o_intx(self.nbits, builder).get_ssa_value(builder)) + self.auto_store( + builder, n.o_intx(self.nbits, builder).auto_load(builder)) def set_const_value(self, builder, n): - self.set_ssa_value(builder, lc.Constant.int(self.get_llvm_type(), n)) + self.auto_store(builder, lc.Constant.int(self.get_llvm_type(), n)) def o_bool(self, builder, inv=False): r = VBool() if builder is not None: - r.set_ssa_value( + r.auto_store( builder, builder.icmp( lc.ICMP_EQ if inv else lc.ICMP_NE, - self.get_ssa_value(builder), + self.auto_load(builder), lc.Constant.int(self.get_llvm_type(), 0))) return r @@ -68,9 +68,9 @@ class VInt(VGeneric): def o_neg(self, builder): r = VInt(self.nbits) if builder is not None: - r.set_ssa_value( + r.auto_store( builder, builder.mul( - self.get_ssa_value(builder), + self.auto_load(builder), lc.Constant.int(self.get_llvm_type(), -1))) return r @@ -78,15 +78,15 @@ class VInt(VGeneric): r = VInt(target_bits) if builder is not None: if self.nbits == target_bits: - r.set_ssa_value( - builder, self.get_ssa_value(builder)) + r.auto_store( + builder, self.auto_load(builder)) if self.nbits > target_bits: - r.set_ssa_value( - builder, builder.trunc(self.get_ssa_value(builder), + r.auto_store( + builder, builder.trunc(self.auto_load(builder), r.get_llvm_type())) if self.nbits < target_bits: - r.set_ssa_value( - builder, builder.sext(self.get_ssa_value(builder), + r.auto_store( + builder, builder.sext(self.auto_load(builder), r.get_llvm_type())) return r o_roundx = o_intx @@ -101,9 +101,9 @@ def _make_vint_binop_method(builder_name): left = self.o_intx(target_bits, builder) right = other.o_intx(target_bits, builder) bf = getattr(builder, builder_name) - r.set_ssa_value( - builder, bf(left.get_ssa_value(builder), - right.get_ssa_value(builder))) + r.auto_store( + builder, bf(left.auto_load(builder), + right.auto_load(builder))) return r else: return NotImplemented @@ -128,11 +128,11 @@ def _make_vint_cmp_method(icmp_val): target_bits = max(self.nbits, other.nbits) left = self.o_intx(target_bits, builder) right = other.o_intx(target_bits, builder) - r.set_ssa_value( + r.auto_store( builder, builder.icmp( - icmp_val, left.get_ssa_value(builder), - right.get_ssa_value(builder))) + icmp_val, left.auto_load(builder), + right.auto_load(builder))) return r else: return NotImplemented @@ -161,5 +161,5 @@ class VBool(VInt): def o_bool(self, builder): r = VBool() if builder is not None: - r.set_ssa_value(builder, self.get_ssa_value(builder)) + r.auto_store(builder, self.auto_load(builder)) return r diff --git a/artiq/py2llvm/fractions.py b/artiq/py2llvm/fractions.py index 498a916e8..1c97369b8 100644 --- a/artiq/py2llvm/fractions.py +++ b/artiq/py2llvm/fractions.py @@ -71,7 +71,7 @@ class VFraction(VGeneric): return lc.Type.vector(lc.Type.int(64), 2) def _nd(self, builder): - ssa_value = self.get_ssa_value(builder) + ssa_value = self.auto_load(builder) a = builder.extract_element( ssa_value, lc.Constant.int(lc.Type.int(), 0)) b = builder.extract_element( @@ -79,16 +79,16 @@ class VFraction(VGeneric): return a, b def set_value_nd(self, builder, a, b): - a = a.o_int64(builder).get_ssa_value(builder) - b = b.o_int64(builder).get_ssa_value(builder) + a = a.o_int64(builder).auto_load(builder) + b = b.o_int64(builder).auto_load(builder) a, b = _reduce(builder, a, b) a, b = _signnum(builder, a, b) - self.set_ssa_value(builder, _make_ssa(builder, a, b)) + self.auto_store(builder, _make_ssa(builder, a, b)) def set_value(self, builder, v): if not isinstance(v, VFraction): raise TypeError - self.set_ssa_value(builder, v.get_ssa_value(builder)) + self.auto_store(builder, v.auto_load(builder)) def o_getattr(self, attr, builder): if attr == "numerator": @@ -100,9 +100,9 @@ class VFraction(VGeneric): r = VInt(64) if builder is not None: elt = builder.extract_element( - self.get_ssa_value(builder), + self.auto_load(builder), lc.Constant.int(lc.Type.int(), idx)) - r.set_ssa_value(builder, elt) + r.auto_store(builder, elt) return r def o_bool(self, builder): @@ -110,8 +110,8 @@ class VFraction(VGeneric): if builder is not None: zero = lc.Constant.int(lc.Type.int(64), 0) a = builder.extract_element( - self.get_ssa_value(builder), lc.Constant.int(lc.Type.int(), 0)) - r.set_ssa_value(builder, builder.icmp(lc.ICMP_NE, a, zero)) + self.auto_load(builder), lc.Constant.int(lc.Type.int(), 0)) + r.auto_store(builder, builder.icmp(lc.ICMP_NE, a, zero)) return r def o_intx(self, target_bits, builder): @@ -120,7 +120,7 @@ class VFraction(VGeneric): else: r = VInt(64) a, b = self._nd(builder) - r.set_ssa_value(builder, builder.sdiv(a, b)) + r.auto_store(builder, builder.sdiv(a, b)) return r.o_intx(target_bits, builder) def o_roundx(self, target_bits, builder): @@ -131,7 +131,7 @@ class VFraction(VGeneric): a, b = self._nd(builder) h_b = builder.ashr(b, lc.Constant.int(lc.Type.int(), 1)) a = builder.add(a, h_b) - r.set_ssa_value(builder, builder.sdiv(a, b)) + r.auto_store(builder, builder.sdiv(a, b)) return r.o_intx(target_bits, builder) def _o_eq_inv(self, other, builder, ne): @@ -144,7 +144,7 @@ class VFraction(VGeneric): a, b = self._nd(builder) ssa_r = builder.and_( builder.icmp(lc.ICMP_EQ, a, - other.get_ssa_value()), + other.auto_load()), builder.icmp(lc.ICMP_EQ, b, lc.Constant.int(lc.Type.int(64), 1))) else: @@ -156,7 +156,7 @@ class VFraction(VGeneric): if ne: ssa_r = builder.xor(ssa_r, lc.Constant.int(lc.Type.int(1), 1)) - r.set_ssa_value(builder, ssa_r) + r.auto_store(builder, ssa_r) return r def o_eq(self, other, builder): @@ -171,7 +171,7 @@ class VFraction(VGeneric): r = VFraction() if builder is not None: if isinstance(other, VInt): - i = other.o_int64(builder).get_ssa_value() + i = other.o_int64(builder).auto_load() x, rd = self._nd(builder) y = builder.mul(rd, i) else: @@ -188,7 +188,7 @@ class VFraction(VGeneric): else: rn = builder.add(x, y) rn, rd = _reduce(builder, rn, rd) # rd is already > 0 - r.set_ssa_value(builder, _make_ssa(builder, rn, rd)) + r.auto_store(builder, _make_ssa(builder, rn, rd)) return r def o_add(self, other, builder): @@ -212,7 +212,7 @@ class VFraction(VGeneric): if invert: a, b = b, a if isinstance(other, VInt): - i = other.o_int64(builder).get_ssa_value(builder) + i = other.o_int64(builder).auto_load(builder) if div: b = builder.mul(b, i) else: @@ -228,7 +228,7 @@ class VFraction(VGeneric): if div or invert: a, b = _signnum(builder, a, b) a, b = _reduce(builder, a, b) - r.set_ssa_value(builder, _make_ssa(builder, a, b)) + r.auto_store(builder, _make_ssa(builder, a, b)) return r def o_mul(self, other, builder): diff --git a/artiq/py2llvm/infer_types.py b/artiq/py2llvm/infer_types.py index 9cd830735..8df54d786 100644 --- a/artiq/py2llvm/infer_types.py +++ b/artiq/py2llvm/infer_types.py @@ -16,6 +16,19 @@ class _TypeScanner(ast.NodeVisitor): ns[target.id].merge(val) else: ns[target.id] = deepcopy(val) + elif isinstance(target, ast.Subscript): + target = target.value + levels = 0 + while isinstance(target, ast.Subscript): + target = target.value + levels += 1 + if isinstance(target, ast.Name): + target_value = ns[target.id] + for i in range(levels): + target_value = target_value.o_subscript(None, None) + target_value.merge_subscript(val) + else: + raise NotImplementedError else: raise NotImplementedError @@ -40,6 +53,7 @@ class _TypeScanner(ast.NodeVisitor): else: ns["return"] = deepcopy(val) + def infer_function_types(env, node, param_types): ns = deepcopy(param_types) ts = _TypeScanner(env, ns) diff --git a/artiq/py2llvm/module.py b/artiq/py2llvm/module.py index a0a03c512..0ab6f5562 100644 --- a/artiq/py2llvm/module.py +++ b/artiq/py2llvm/module.py @@ -46,7 +46,7 @@ class Module: for k, v in ns.items(): v.alloca(builder, k) for arg_ast, arg_llvm in zip(funcdef.args.args, function.args): - ns[arg_ast.arg].set_ssa_value(builder, arg_llvm) + ns[arg_ast.arg].auto_store(builder, arg_llvm) visitor = ast_body.Visitor(self.env, ns, builder) visitor.visit_statements(funcdef.body) @@ -55,6 +55,6 @@ class Module: if isinstance(retval, base_types.VNone): builder.ret_void() else: - builder.ret(retval.get_ssa_value(builder)) + builder.ret(retval.auto_load(builder)) return function, retval diff --git a/artiq/py2llvm/values.py b/artiq/py2llvm/values.py index e5536181a..037ba48a3 100644 --- a/artiq/py2llvm/values.py +++ b/artiq/py2llvm/values.py @@ -1,11 +1,17 @@ from types import SimpleNamespace +from copy import copy from llvm import core as lc class VGeneric: def __init__(self): - self._llvm_value = None + self.llvm_value = None + + def new(self): + r = copy(self) + r.llvm_value = None + return r def __repr__(self): return "<" + self.__class__.__name__ + ">" @@ -18,25 +24,25 @@ class VGeneric: raise TypeError("Incompatible types: {} and {}" .format(repr(self), repr(other))) - def get_ssa_value(self, builder): - if isinstance(self._llvm_value, lc.AllocaInstruction): - return builder.load(self._llvm_value) + def auto_load(self, builder): + if isinstance(self.llvm_value.type, lc.PointerType): + return builder.load(self.llvm_value) else: - return self._llvm_value + return self.llvm_value - def set_ssa_value(self, builder, value): - if self._llvm_value is None: - self._llvm_value = value - elif isinstance(self._llvm_value, lc.AllocaInstruction): - builder.store(value, self._llvm_value) + def auto_store(self, builder, llvm_value): + if self.llvm_value is None: + self.llvm_value = llvm_value + elif isinstance(self.llvm_value.type, lc.PointerType): + builder.store(llvm_value, self.llvm_value) else: raise RuntimeError( "Attempted to set LLVM SSA value multiple times") def alloca(self, builder, name): - if self._llvm_value is not None: + if self.llvm_value is not None: raise RuntimeError("Attempted to alloca existing LLVM value "+name) - self._llvm_value = builder.alloca(self.get_llvm_type(), name=name) + self.llvm_value = builder.alloca(self.get_llvm_type(), name=name) def o_int(self, builder): return self.o_intx(32, builder) diff --git a/test/py2llvm.py b/test/py2llvm.py index c619fd364..c5e41cf3c 100644 --- a/test/py2llvm.py +++ b/test/py2llvm.py @@ -5,13 +5,13 @@ from fractions import Fraction from llvm import ee as le -from artiq.language.core import int64 +from artiq.language.core import int64, array from artiq.py2llvm.infer_types import infer_function_types -from artiq.py2llvm import base_types +from artiq.py2llvm import base_types, arrays from artiq.py2llvm.module import Module -def test_types(choice): +def test_base_types(choice): a = 2 # promoted later to int64 b = a + 1 # initially int32, becomes int64 after a is promoted c = b//2 # initially int32, becomes int64 after b is promoted @@ -27,13 +27,17 @@ def test_types(choice): return x + c -class FunctionTypesCase(unittest.TestCase): - def setUp(self): - self.ns = infer_function_types( - None, ast.parse(inspect.getsource(test_types)), - dict()) +def _build_function_types(f): + return infer_function_types( + None, ast.parse(inspect.getsource(f)), + dict()) - def test_base_types(self): + +class FunctionBaseTypesCase(unittest.TestCase): + def setUp(self): + self.ns = _build_function_types(test_base_types) + + def test_simple_types(self): self.assertIsInstance(self.ns["foo"], base_types.VBool) self.assertIsInstance(self.ns["bar"], base_types.VNone) self.assertIsInstance(self.ns["d"], base_types.VInt) @@ -51,6 +55,23 @@ class FunctionTypesCase(unittest.TestCase): self.assertEqual(self.ns["return"].nbits, 64) +def test_array_types(): + a = array(0, 5) + a[3] = int64(8) + return a + + +class FunctionArrayTypesCase(unittest.TestCase): + def setUp(self): + self.ns = _build_function_types(test_array_types) + + def test_array_types(self): + self.assertIsInstance(self.ns["a"], arrays.VArray) + self.assertIsInstance(self.ns["a"].el_init, base_types.VInt) + self.assertEqual(self.ns["a"].el_init.nbits, 64) + self.assertEqual(self.ns["a"].count, 5) + + class CompiledFunction: def __init__(self, function, param_types): module = Module() @@ -99,6 +120,23 @@ def arith_encode(op, a, b, c, d): return f.numerator*1000 + f.denominator +def array_test(): + a = array(array(2, 5), 5) + a[3][2] = 11 + a[4][1] = 42 + a[0][0] += 6 + + acc = 0 + i = 0 + while i < 5: + j = 0 + while j < 5: + acc += a[i][j] + j += 1 + i += 1 + return acc + + class CodeGenCase(unittest.TestCase): def test_is_prime(self): is_prime_c = CompiledFunction(is_prime, {"x": base_types.VInt()}) @@ -138,3 +176,7 @@ class CodeGenCase(unittest.TestCase): def test_frac_div(self): self._test_frac_arith(3) + + def test_array(self): + array_test_c = CompiledFunction(array_test, dict()) + self.assertEqual(array_test_c(), array_test())