From f3b727b59d9613aa46e4fea8d58b38dc87ea1989 Mon Sep 17 00:00:00 2001 From: Sebastien Bourdeauducq Date: Wed, 17 Dec 2014 21:54:10 +0800 Subject: [PATCH] py2llvm: replace array with list --- artiq/language/core.py | 16 ----- artiq/py2llvm/arrays.py | 70 ---------------------- artiq/py2llvm/ast_body.py | 104 +++++++++++++++++++++++++++++---- artiq/py2llvm/lists.py | 52 +++++++++++++++++ artiq/py2llvm/values.py | 2 +- artiq/test/py2llvm.py | 51 ++++++++-------- artiq/transforms/tools.py | 2 +- doc/manual/getting_started.rst | 2 +- examples/photon_histogram.py | 2 +- examples/transport.py | 2 +- 10 files changed, 174 insertions(+), 129 deletions(-) delete mode 100644 artiq/py2llvm/arrays.py create mode 100644 artiq/py2llvm/lists.py diff --git a/artiq/language/core.py b/artiq/language/core.py index dc063ab73..974ee886b 100644 --- a/artiq/language/core.py +++ b/artiq/language/core.py @@ -4,7 +4,6 @@ Core ARTIQ extensions to the Python language. """ from collections import namedtuple as _namedtuple -from copy import copy as _copy from functools import wraps as _wraps from artiq.language import units as _units @@ -71,21 +70,6 @@ def round64(x): return int64(round(x)) -def array(element, count): - """Creates an array. - - The array is initialized with the value of ``element`` repeated ``count`` - times. Elements can be read and written using the regular Python index - syntax. - - For static compilation, ``count`` must be a fixed integer. - - Arrays of arrays are supported. - - """ - return [_copy(element) for i in range(count)] - - _KernelFunctionInfo = _namedtuple("_KernelFunctionInfo", "core_name k_function") diff --git a/artiq/py2llvm/arrays.py b/artiq/py2llvm/arrays.py deleted file mode 100644 index 3a5e34d9b..000000000 --- a/artiq/py2llvm/arrays.py +++ /dev/null @@ -1,70 +0,0 @@ -import llvmlite.ir as ll - -from artiq.py2llvm.values import VGeneric -from artiq.py2llvm.base_types import VInt - - -class VArray(VGeneric): - def __init__(self, el_init, count): - VGeneric.__init__(self) - self.el_init = el_init - self.count = count - if not count: - raise TypeError("Arrays must have at least one element") - - def get_llvm_type(self): - return ll.ArrayType(self.el_init.get_llvm_type(), self.count) - - def __repr__(self): - return "".format(repr(self.el_init), self.count) - - def same_type(self, other): - return ( - isinstance(other, VArray) - and self.el_init.same_type(other.el_init) - and self.count == other.count) - - def merge(self, other): - if isinstance(other, VArray): - self.el_init.merge(other.el_init) - else: - raise TypeError("Incompatible types: {} and {}" - .format(repr(self), repr(other))) - - def merge_subscript(self, other): - self.el_init.merge(other) - - def set_value(self, builder, v): - if not isinstance(v, VArray): - raise TypeError - if v.llvm_value is not None: - raise NotImplementedError("Array aliasing is not supported") - - i = VInt() - i.alloca(builder, "ai_i") - i.auto_store(builder, ll.Constant(ll.IntType(32), 0)) - - function = builder.basic_block.function - copy_block = function.append_basic_block("ai_copy") - end_block = function.append_basic_block("ai_end") - builder.branch(copy_block) - - builder.position_at_end(copy_block) - self.o_subscript(i, builder).set_value(builder, v.el_init) - i.auto_store(builder, builder.add( - i.auto_load(builder), ll.Constant(ll.IntType(32), 1))) - cont = builder.icmp_signed( - "<", i.auto_load(builder), - ll.Constant(ll.IntType(32), self.count)) - builder.cbranch(cont, copy_block, end_block) - - builder.position_at_end(end_block) - - def o_subscript(self, index, builder): - r = self.el_init.new() - if builder is not None: - index = index.o_int(builder).auto_load(builder) - ssa_r = builder.gep(self.llvm_value, [ - ll.Constant(ll.IntType(32), 0), index]) - r.auto_store(builder, ssa_r) - return r diff --git a/artiq/py2llvm/ast_body.py b/artiq/py2llvm/ast_body.py index 783aa5f2f..7f67af79d 100644 --- a/artiq/py2llvm/ast_body.py +++ b/artiq/py2llvm/ast_body.py @@ -2,7 +2,7 @@ import ast import llvmlite.ir as ll -from artiq.py2llvm import values, base_types, fractions, arrays, iterators +from artiq.py2llvm import values, base_types, fractions, lists, iterators from artiq.py2llvm.tools import is_terminated @@ -177,14 +177,6 @@ class Visitor: denominator = self.visit_expression(node.args[1]) r.set_value_nd(self.builder, numerator, denominator) return r - elif fn == "array": - element = self.visit_expression(node.args[0]) - if (isinstance(node.args[1], ast.Num) - and isinstance(node.args[1].n, int)): - count = node.args[1].n - else: - raise ValueError("Array size must be integer and constant") - return arrays.VArray(element, count) elif fn == "range": return iterators.IRange( self.builder, @@ -201,6 +193,56 @@ class Visitor: value = self.visit_expression(node.value) return value.o_getattr(node.attr, self.builder) + def _visit_expr_List(self, node): + elts = [self.visit_expression(elt) for elt in node.elts] + if elts: + el_type = elts[0].new() + for elt in elts[1:]: + el_type.merge(elt) + else: + el_type = VNone() + count = len(elts) + r = lists.VList(el_type, count) + r.elts = elts + return r + + def _visit_expr_ListComp(self, node): + if len(node.generators) != 1: + raise NotImplementedError + generator = node.generators[0] + if not isinstance(generator, ast.comprehension): + raise NotImplementedError + if not isinstance(generator.target, ast.Name): + raise NotImplementedError + target = generator.target.id + if not isinstance(generator.iter, ast.Call): + raise NotImplementedError + if not isinstance(generator.iter.func, ast.Name): + raise NotImplementedError + if generator.iter.func.id != "range": + raise NotImplementedError + if len(generator.iter.args) != 1: + raise NotImplementedError + if not isinstance(generator.iter.args[0], ast.Num): + raise NotImplementedError + count = generator.iter.args[0].n + + # Prevent incorrect use of the generator target, if it is defined in + # the local function namespace. + if target in self.ns: + old_target_val = self.ns[target] + del self.ns[target] + else: + old_target_val = None + elt = self.visit_expression(node.elt) + if old_target_val is not None: + self.ns[target] = old_target_val + + el_type = elt.new() + r = lists.VList(el_type, count) + r.elt = elt + return r + def _visit_expr_Subscript(self, node): value = self.visit_expression(node.value) if isinstance(node.slice, ast.Index): @@ -227,9 +269,47 @@ class Visitor: def _visit_stmt_Assign(self, node): val = self.visit_expression(node.value) - for target in node.targets: - target = self.visit_expression(target) - target.set_value(self.builder, val) + if isinstance(node.value, ast.List): + if len(node.targets) > 1: + raise NotImplementedError + target = self.visit_expression(node.targets[0]) + target.set_count(self.builder, val.alloc_count) + for i, elt in enumerate(val.elts): + idx = base_types.VInt() + idx.set_const_value(self.builder, i) + target.o_subscript(idx, self.builder).set_value(self.builder, + elt) + elif isinstance(node.value, ast.ListComp): + if len(node.targets) > 1: + raise NotImplementedError + target = self.visit_expression(node.targets[0]) + target.set_count(self.builder, val.alloc_count) + + i = base_types.VInt() + i.alloca(self.builder) + i.auto_store(self.builder, ll.Constant(ll.IntType(32), 0)) + + function = self.builder.basic_block.function + copy_block = function.append_basic_block("ai_copy") + end_block = function.append_basic_block("ai_end") + self.builder.branch(copy_block) + + self.builder.position_at_end(copy_block) + target.o_subscript(i, self.builder).set_value(self.builder, + val.elt) + i.auto_store(self.builder, self.builder.add( + i.auto_load(self.builder), + ll.Constant(ll.IntType(32), 1))) + cont = self.builder.icmp_signed( + "<", i.auto_load(self.builder), + ll.Constant(ll.IntType(32), val.alloc_count)) + self.builder.cbranch(cont, copy_block, end_block) + + self.builder.position_at_end(end_block) + else: + for target in node.targets: + target = self.visit_expression(target) + target.set_value(self.builder, val) def _visit_stmt_AugAssign(self, node): target = self.visit_expression(node.target) diff --git a/artiq/py2llvm/lists.py b/artiq/py2llvm/lists.py new file mode 100644 index 000000000..cd705f3be --- /dev/null +++ b/artiq/py2llvm/lists.py @@ -0,0 +1,52 @@ +import llvmlite.ir as ll + +from artiq.py2llvm.values import VGeneric + + +class VList(VGeneric): + def __init__(self, el_type, alloc_count): + VGeneric.__init__(self) + self.el_type = el_type + self.alloc_count = alloc_count + + def get_llvm_type(self): + count = 0 if self.alloc_count is None else self.alloc_count + return ll.LiteralStructType([ll.IntType(32), + ll.ArrayType(self.el_type.get_llvm_type(), + count)]) + + def __repr__(self): + return "".format( + repr(self.el_type), + "?" if self.alloc_count is None else self.alloc_count) + + def same_type(self, other): + return (isinstance(other, VList) + and self.el_type.same_type(other.el_type)) + + def merge(self, other): + if isinstance(other, VList): + self.el_type.merge(other.el_type) + else: + raise TypeError("Incompatible types: {} and {}" + .format(repr(self), repr(other))) + + def merge_subscript(self, other): + self.el_type.merge(other) + + def set_count(self, builder, count): + count_ptr = builder.gep(self.llvm_value, [ + ll.Constant(ll.IntType(32), 0), + ll.Constant(ll.IntType(32), 0)]) + builder.store(ll.Constant(ll.IntType(32), count), count_ptr) + + def o_subscript(self, index, builder): + r = self.el_type.new() + if builder is not None: + index = index.o_int(builder).auto_load(builder) + ssa_r = builder.gep(self.llvm_value, [ + ll.Constant(ll.IntType(32), 0), + ll.Constant(ll.IntType(32), 1), + index]) + r.auto_store(builder, ssa_r) + return r diff --git a/artiq/py2llvm/values.py b/artiq/py2llvm/values.py index 5963d007b..554aad7c6 100644 --- a/artiq/py2llvm/values.py +++ b/artiq/py2llvm/values.py @@ -39,7 +39,7 @@ class VGeneric: raise RuntimeError( "Attempted to set LLVM SSA value multiple times") - def alloca(self, builder, name): + def alloca(self, builder, name=""): if self.llvm_value is not None: raise RuntimeError("Attempted to alloca existing LLVM value "+name) self.llvm_value = builder.alloca(self.get_llvm_type(), name=name) diff --git a/artiq/test/py2llvm.py b/artiq/test/py2llvm.py index 7f756dc00..9f2948db8 100644 --- a/artiq/test/py2llvm.py +++ b/artiq/test/py2llvm.py @@ -7,9 +7,9 @@ import struct import llvmlite.binding as llvm -from artiq.language.core import int64, array +from artiq.language.core import int64 from artiq.py2llvm.infer_types import infer_function_types -from artiq.py2llvm import base_types, arrays +from artiq.py2llvm import base_types, lists from artiq.py2llvm.module import Module @@ -71,22 +71,22 @@ class FunctionBaseTypesCase(unittest.TestCase): self.assertEqual(self.ns["return"].nbits, 64) -def test_array_types(): - a = array(0, 5) +def test_list_types(): + a = [0, 0, 0, 0, 0] for i in range(2): a[i] = int64(8) return a -class FunctionArrayTypesCase(unittest.TestCase): +class FunctionListTypesCase(unittest.TestCase): def setUp(self): - self.ns = _build_function_types(test_array_types) + self.ns = _build_function_types(test_list_types) - def test_array_types(self): - self.assertIsInstance(self.ns["a"], arrays.VArray) - self.assertIsInstance(self.ns["a"].el_init, base_types.VInt) - self.assertEqual(self.ns["a"].el_init.nbits, 64) - self.assertEqual(self.ns["a"].count, 5) + def test_list_types(self): + self.assertIsInstance(self.ns["a"], lists.VList) + self.assertIsInstance(self.ns["a"].el_type, base_types.VInt) + self.assertEqual(self.ns["a"].el_type.nbits, 64) + self.assertEqual(self.ns["a"].alloc_count, 5) self.assertIsInstance(self.ns["i"], base_types.VInt) self.assertEqual(self.ns["i"].nbits, 32) @@ -212,20 +212,19 @@ def frac_arith_float_rev(op, a, b, x): return x / Fraction(a, b) -def array_test(): - a = array(array(2, 5), 5) - a[3][2] = 11 - a[4][1] = 42 - a[0][0] += 6 +def list_test(): + x = 80 + a = [3 for x in range(7)] + b = [1, 2, 4, 5, 4, 0, 5] + a[3] = x + a[0] += 6 + a[1] = b[1] + b[2] acc = 0 - for i in range(5): - for j in range(5): - if i + j == 2 or i + j == 1: - continue - if i and j and a[i][j]: - acc += 1 - acc += a[i][j] + for i in range(7): + if i and a[i]: + acc += 1 + acc += a[i] return acc @@ -364,9 +363,9 @@ class CodeGenCase(unittest.TestCase): self._test_frac_arith_float(3, False) self._test_frac_arith_float(3, True) - def test_array(self): - array_test_c = CompiledFunction(array_test, dict()) - self.assertEqual(array_test_c(), array_test()) + def test_list(self): + list_test_c = CompiledFunction(list_test, dict()) + self.assertEqual(list_test_c(), list_test()) def test_corner_cases(self): corner_cases_c = CompiledFunction(corner_cases, dict()) diff --git a/artiq/transforms/tools.py b/artiq/transforms/tools.py index 2ae3a6cb5..e796490a4 100644 --- a/artiq/transforms/tools.py +++ b/artiq/transforms/tools.py @@ -10,7 +10,7 @@ embeddable_funcs = ( core_language.time_to_cycles, core_language.cycles_to_time, core_language.syscall, range, bool, int, float, round, - core_language.int64, core_language.round64, core_language.array, + core_language.int64, core_language.round64, Fraction, units.Quantity, units.check_unit, core_language.EncodedException ) embeddable_func_names = {func.__name__ for func in embeddable_funcs} diff --git a/doc/manual/getting_started.rst b/doc/manual/getting_started.rst index 212f08acc..c74e140fb 100644 --- a/doc/manual/getting_started.rst +++ b/doc/manual/getting_started.rst @@ -66,7 +66,7 @@ A number of Python algorithmic features can be used inside a kernel for compilat * 64-bit signed integers (:class:`artiq.language.core.int64`) * Signed rational numbers with 64-bit numerator and 64-bit denominator * Double-precision floating point numbers -* Arrays of the above types and arrays of arrays, at an arbitrary depth (:class:`artiq.language.core.array`) +* Lists of the above types. Lists of lists are not supported. For a demonstration of some of these features, see the ``mandelbrot.py`` example. diff --git a/examples/photon_histogram.py b/examples/photon_histogram.py index 566204309..3a2400597 100644 --- a/examples/photon_histogram.py +++ b/examples/photon_histogram.py @@ -23,7 +23,7 @@ class PhotonHistogram(AutoContext): @kernel def run(self): - hist = array(0, self.nbins) + hist = [0 for _ in range (self.nbins)] for i in range(self.repeats): n = self.cool_detect() diff --git a/examples/transport.py b/examples/transport.py index 804de18d6..9deff73dd 100644 --- a/examples/transport.py +++ b/examples/transport.py @@ -94,7 +94,7 @@ class Transport(AutoContext): @kernel def repeat(self): - hist = array(0, self.nbins) + hist = [0 for _ in range(self.nbins)] for i in range(self.repeats): n = self.one()