1
0
forked from M-Labs/artiq

py2llvm: replace array with list

This commit is contained in:
Sebastien Bourdeauducq 2014-12-17 21:54:10 +08:00
parent 6ca39f7415
commit f3b727b59d
10 changed files with 174 additions and 129 deletions

View File

@ -4,7 +4,6 @@ Core ARTIQ extensions to the Python language.
""" """
from collections import namedtuple as _namedtuple from collections import namedtuple as _namedtuple
from copy import copy as _copy
from functools import wraps as _wraps from functools import wraps as _wraps
from artiq.language import units as _units from artiq.language import units as _units
@ -71,21 +70,6 @@ def round64(x):
return int64(round(x)) return int64(round(x))
def array(element, count):
"""Creates an array.
The array is initialized with the value of ``element`` repeated ``count``
times. Elements can be read and written using the regular Python index
syntax.
For static compilation, ``count`` must be a fixed integer.
Arrays of arrays are supported.
"""
return [_copy(element) for i in range(count)]
_KernelFunctionInfo = _namedtuple("_KernelFunctionInfo", "core_name k_function") _KernelFunctionInfo = _namedtuple("_KernelFunctionInfo", "core_name k_function")

View File

@ -1,70 +0,0 @@
import llvmlite.ir as ll
from artiq.py2llvm.values import VGeneric
from artiq.py2llvm.base_types import VInt
class VArray(VGeneric):
def __init__(self, el_init, count):
VGeneric.__init__(self)
self.el_init = el_init
self.count = count
if not count:
raise TypeError("Arrays must have at least one element")
def get_llvm_type(self):
return ll.ArrayType(self.el_init.get_llvm_type(), self.count)
def __repr__(self):
return "<VArray:{} x{}>".format(repr(self.el_init), self.count)
def same_type(self, other):
return (
isinstance(other, VArray)
and self.el_init.same_type(other.el_init)
and self.count == other.count)
def merge(self, other):
if isinstance(other, VArray):
self.el_init.merge(other.el_init)
else:
raise TypeError("Incompatible types: {} and {}"
.format(repr(self), repr(other)))
def merge_subscript(self, other):
self.el_init.merge(other)
def set_value(self, builder, v):
if not isinstance(v, VArray):
raise TypeError
if v.llvm_value is not None:
raise NotImplementedError("Array aliasing is not supported")
i = VInt()
i.alloca(builder, "ai_i")
i.auto_store(builder, ll.Constant(ll.IntType(32), 0))
function = builder.basic_block.function
copy_block = function.append_basic_block("ai_copy")
end_block = function.append_basic_block("ai_end")
builder.branch(copy_block)
builder.position_at_end(copy_block)
self.o_subscript(i, builder).set_value(builder, v.el_init)
i.auto_store(builder, builder.add(
i.auto_load(builder), ll.Constant(ll.IntType(32), 1)))
cont = builder.icmp_signed(
"<", i.auto_load(builder),
ll.Constant(ll.IntType(32), self.count))
builder.cbranch(cont, copy_block, end_block)
builder.position_at_end(end_block)
def o_subscript(self, index, builder):
r = self.el_init.new()
if builder is not None:
index = index.o_int(builder).auto_load(builder)
ssa_r = builder.gep(self.llvm_value, [
ll.Constant(ll.IntType(32), 0), index])
r.auto_store(builder, ssa_r)
return r

View File

@ -2,7 +2,7 @@ import ast
import llvmlite.ir as ll import llvmlite.ir as ll
from artiq.py2llvm import values, base_types, fractions, arrays, iterators from artiq.py2llvm import values, base_types, fractions, lists, iterators
from artiq.py2llvm.tools import is_terminated from artiq.py2llvm.tools import is_terminated
@ -177,14 +177,6 @@ class Visitor:
denominator = self.visit_expression(node.args[1]) denominator = self.visit_expression(node.args[1])
r.set_value_nd(self.builder, numerator, denominator) r.set_value_nd(self.builder, numerator, denominator)
return r return r
elif fn == "array":
element = self.visit_expression(node.args[0])
if (isinstance(node.args[1], ast.Num)
and isinstance(node.args[1].n, int)):
count = node.args[1].n
else:
raise ValueError("Array size must be integer and constant")
return arrays.VArray(element, count)
elif fn == "range": elif fn == "range":
return iterators.IRange( return iterators.IRange(
self.builder, self.builder,
@ -201,6 +193,56 @@ class Visitor:
value = self.visit_expression(node.value) value = self.visit_expression(node.value)
return value.o_getattr(node.attr, self.builder) return value.o_getattr(node.attr, self.builder)
def _visit_expr_List(self, node):
elts = [self.visit_expression(elt) for elt in node.elts]
if elts:
el_type = elts[0].new()
for elt in elts[1:]:
el_type.merge(elt)
else:
el_type = VNone()
count = len(elts)
r = lists.VList(el_type, count)
r.elts = elts
return r
def _visit_expr_ListComp(self, node):
if len(node.generators) != 1:
raise NotImplementedError
generator = node.generators[0]
if not isinstance(generator, ast.comprehension):
raise NotImplementedError
if not isinstance(generator.target, ast.Name):
raise NotImplementedError
target = generator.target.id
if not isinstance(generator.iter, ast.Call):
raise NotImplementedError
if not isinstance(generator.iter.func, ast.Name):
raise NotImplementedError
if generator.iter.func.id != "range":
raise NotImplementedError
if len(generator.iter.args) != 1:
raise NotImplementedError
if not isinstance(generator.iter.args[0], ast.Num):
raise NotImplementedError
count = generator.iter.args[0].n
# Prevent incorrect use of the generator target, if it is defined in
# the local function namespace.
if target in self.ns:
old_target_val = self.ns[target]
del self.ns[target]
else:
old_target_val = None
elt = self.visit_expression(node.elt)
if old_target_val is not None:
self.ns[target] = old_target_val
el_type = elt.new()
r = lists.VList(el_type, count)
r.elt = elt
return r
def _visit_expr_Subscript(self, node): def _visit_expr_Subscript(self, node):
value = self.visit_expression(node.value) value = self.visit_expression(node.value)
if isinstance(node.slice, ast.Index): if isinstance(node.slice, ast.Index):
@ -227,9 +269,47 @@ class Visitor:
def _visit_stmt_Assign(self, node): def _visit_stmt_Assign(self, node):
val = self.visit_expression(node.value) val = self.visit_expression(node.value)
for target in node.targets: if isinstance(node.value, ast.List):
target = self.visit_expression(target) if len(node.targets) > 1:
target.set_value(self.builder, val) raise NotImplementedError
target = self.visit_expression(node.targets[0])
target.set_count(self.builder, val.alloc_count)
for i, elt in enumerate(val.elts):
idx = base_types.VInt()
idx.set_const_value(self.builder, i)
target.o_subscript(idx, self.builder).set_value(self.builder,
elt)
elif isinstance(node.value, ast.ListComp):
if len(node.targets) > 1:
raise NotImplementedError
target = self.visit_expression(node.targets[0])
target.set_count(self.builder, val.alloc_count)
i = base_types.VInt()
i.alloca(self.builder)
i.auto_store(self.builder, ll.Constant(ll.IntType(32), 0))
function = self.builder.basic_block.function
copy_block = function.append_basic_block("ai_copy")
end_block = function.append_basic_block("ai_end")
self.builder.branch(copy_block)
self.builder.position_at_end(copy_block)
target.o_subscript(i, self.builder).set_value(self.builder,
val.elt)
i.auto_store(self.builder, self.builder.add(
i.auto_load(self.builder),
ll.Constant(ll.IntType(32), 1)))
cont = self.builder.icmp_signed(
"<", i.auto_load(self.builder),
ll.Constant(ll.IntType(32), val.alloc_count))
self.builder.cbranch(cont, copy_block, end_block)
self.builder.position_at_end(end_block)
else:
for target in node.targets:
target = self.visit_expression(target)
target.set_value(self.builder, val)
def _visit_stmt_AugAssign(self, node): def _visit_stmt_AugAssign(self, node):
target = self.visit_expression(node.target) target = self.visit_expression(node.target)

52
artiq/py2llvm/lists.py Normal file
View File

@ -0,0 +1,52 @@
import llvmlite.ir as ll
from artiq.py2llvm.values import VGeneric
class VList(VGeneric):
def __init__(self, el_type, alloc_count):
VGeneric.__init__(self)
self.el_type = el_type
self.alloc_count = alloc_count
def get_llvm_type(self):
count = 0 if self.alloc_count is None else self.alloc_count
return ll.LiteralStructType([ll.IntType(32),
ll.ArrayType(self.el_type.get_llvm_type(),
count)])
def __repr__(self):
return "<VList:{} x{}>".format(
repr(self.el_type),
"?" if self.alloc_count is None else self.alloc_count)
def same_type(self, other):
return (isinstance(other, VList)
and self.el_type.same_type(other.el_type))
def merge(self, other):
if isinstance(other, VList):
self.el_type.merge(other.el_type)
else:
raise TypeError("Incompatible types: {} and {}"
.format(repr(self), repr(other)))
def merge_subscript(self, other):
self.el_type.merge(other)
def set_count(self, builder, count):
count_ptr = builder.gep(self.llvm_value, [
ll.Constant(ll.IntType(32), 0),
ll.Constant(ll.IntType(32), 0)])
builder.store(ll.Constant(ll.IntType(32), count), count_ptr)
def o_subscript(self, index, builder):
r = self.el_type.new()
if builder is not None:
index = index.o_int(builder).auto_load(builder)
ssa_r = builder.gep(self.llvm_value, [
ll.Constant(ll.IntType(32), 0),
ll.Constant(ll.IntType(32), 1),
index])
r.auto_store(builder, ssa_r)
return r

View File

@ -39,7 +39,7 @@ class VGeneric:
raise RuntimeError( raise RuntimeError(
"Attempted to set LLVM SSA value multiple times") "Attempted to set LLVM SSA value multiple times")
def alloca(self, builder, name): def alloca(self, builder, name=""):
if self.llvm_value is not None: if self.llvm_value is not None:
raise RuntimeError("Attempted to alloca existing LLVM value "+name) raise RuntimeError("Attempted to alloca existing LLVM value "+name)
self.llvm_value = builder.alloca(self.get_llvm_type(), name=name) self.llvm_value = builder.alloca(self.get_llvm_type(), name=name)

View File

@ -7,9 +7,9 @@ import struct
import llvmlite.binding as llvm import llvmlite.binding as llvm
from artiq.language.core import int64, array from artiq.language.core import int64
from artiq.py2llvm.infer_types import infer_function_types from artiq.py2llvm.infer_types import infer_function_types
from artiq.py2llvm import base_types, arrays from artiq.py2llvm import base_types, lists
from artiq.py2llvm.module import Module from artiq.py2llvm.module import Module
@ -71,22 +71,22 @@ class FunctionBaseTypesCase(unittest.TestCase):
self.assertEqual(self.ns["return"].nbits, 64) self.assertEqual(self.ns["return"].nbits, 64)
def test_array_types(): def test_list_types():
a = array(0, 5) a = [0, 0, 0, 0, 0]
for i in range(2): for i in range(2):
a[i] = int64(8) a[i] = int64(8)
return a return a
class FunctionArrayTypesCase(unittest.TestCase): class FunctionListTypesCase(unittest.TestCase):
def setUp(self): def setUp(self):
self.ns = _build_function_types(test_array_types) self.ns = _build_function_types(test_list_types)
def test_array_types(self): def test_list_types(self):
self.assertIsInstance(self.ns["a"], arrays.VArray) self.assertIsInstance(self.ns["a"], lists.VList)
self.assertIsInstance(self.ns["a"].el_init, base_types.VInt) self.assertIsInstance(self.ns["a"].el_type, base_types.VInt)
self.assertEqual(self.ns["a"].el_init.nbits, 64) self.assertEqual(self.ns["a"].el_type.nbits, 64)
self.assertEqual(self.ns["a"].count, 5) self.assertEqual(self.ns["a"].alloc_count, 5)
self.assertIsInstance(self.ns["i"], base_types.VInt) self.assertIsInstance(self.ns["i"], base_types.VInt)
self.assertEqual(self.ns["i"].nbits, 32) self.assertEqual(self.ns["i"].nbits, 32)
@ -212,20 +212,19 @@ def frac_arith_float_rev(op, a, b, x):
return x / Fraction(a, b) return x / Fraction(a, b)
def array_test(): def list_test():
a = array(array(2, 5), 5) x = 80
a[3][2] = 11 a = [3 for x in range(7)]
a[4][1] = 42 b = [1, 2, 4, 5, 4, 0, 5]
a[0][0] += 6 a[3] = x
a[0] += 6
a[1] = b[1] + b[2]
acc = 0 acc = 0
for i in range(5): for i in range(7):
for j in range(5): if i and a[i]:
if i + j == 2 or i + j == 1: acc += 1
continue acc += a[i]
if i and j and a[i][j]:
acc += 1
acc += a[i][j]
return acc return acc
@ -364,9 +363,9 @@ class CodeGenCase(unittest.TestCase):
self._test_frac_arith_float(3, False) self._test_frac_arith_float(3, False)
self._test_frac_arith_float(3, True) self._test_frac_arith_float(3, True)
def test_array(self): def test_list(self):
array_test_c = CompiledFunction(array_test, dict()) list_test_c = CompiledFunction(list_test, dict())
self.assertEqual(array_test_c(), array_test()) self.assertEqual(list_test_c(), list_test())
def test_corner_cases(self): def test_corner_cases(self):
corner_cases_c = CompiledFunction(corner_cases, dict()) corner_cases_c = CompiledFunction(corner_cases, dict())

View File

@ -10,7 +10,7 @@ embeddable_funcs = (
core_language.time_to_cycles, core_language.cycles_to_time, core_language.time_to_cycles, core_language.cycles_to_time,
core_language.syscall, core_language.syscall,
range, bool, int, float, round, range, bool, int, float, round,
core_language.int64, core_language.round64, core_language.array, core_language.int64, core_language.round64,
Fraction, units.Quantity, units.check_unit, core_language.EncodedException Fraction, units.Quantity, units.check_unit, core_language.EncodedException
) )
embeddable_func_names = {func.__name__ for func in embeddable_funcs} embeddable_func_names = {func.__name__ for func in embeddable_funcs}

View File

@ -66,7 +66,7 @@ A number of Python algorithmic features can be used inside a kernel for compilat
* 64-bit signed integers (:class:`artiq.language.core.int64`) * 64-bit signed integers (:class:`artiq.language.core.int64`)
* Signed rational numbers with 64-bit numerator and 64-bit denominator * Signed rational numbers with 64-bit numerator and 64-bit denominator
* Double-precision floating point numbers * Double-precision floating point numbers
* Arrays of the above types and arrays of arrays, at an arbitrary depth (:class:`artiq.language.core.array`) * Lists of the above types. Lists of lists are not supported.
For a demonstration of some of these features, see the ``mandelbrot.py`` example. For a demonstration of some of these features, see the ``mandelbrot.py`` example.

View File

@ -23,7 +23,7 @@ class PhotonHistogram(AutoContext):
@kernel @kernel
def run(self): def run(self):
hist = array(0, self.nbins) hist = [0 for _ in range (self.nbins)]
for i in range(self.repeats): for i in range(self.repeats):
n = self.cool_detect() n = self.cool_detect()

View File

@ -94,7 +94,7 @@ class Transport(AutoContext):
@kernel @kernel
def repeat(self): def repeat(self):
hist = array(0, self.nbins) hist = [0 for _ in range(self.nbins)]
for i in range(self.repeats): for i in range(self.repeats):
n = self.one() n = self.one()