From d66448a48696d69c7abfcebd7663ef82f97eb248 Mon Sep 17 00:00:00 2001 From: Sebastien Bourdeauducq Date: Wed, 10 Sep 2014 16:06:27 +0800 Subject: [PATCH] py2llvm: for loop and range support --- artiq/py2llvm/ast_body.py | 35 +++++++++++++++++++++++- artiq/py2llvm/infer_types.py | 5 ++++ artiq/py2llvm/iterators.py | 53 ++++++++++++++++++++++++++++++++++++ examples/coredev_test.py | 4 +-- examples/dds_test.py | 4 +-- test/py2llvm.py | 13 ++++----- 6 files changed, 100 insertions(+), 14 deletions(-) create mode 100644 artiq/py2llvm/iterators.py diff --git a/artiq/py2llvm/ast_body.py b/artiq/py2llvm/ast_body.py index 6afdd8a45..bd69ed62d 100644 --- a/artiq/py2llvm/ast_body.py +++ b/artiq/py2llvm/ast_body.py @@ -1,6 +1,6 @@ import ast -from artiq.py2llvm import values, base_types, fractions, arrays +from artiq.py2llvm import values, base_types, fractions, arrays, iterators from artiq.py2llvm.tools import is_terminated @@ -127,6 +127,10 @@ class Visitor: else: raise ValueError("Array size must be integer and constant") return arrays.VArray(element, count) + elif fn == "range": + return iterators.IRange( + self.builder, + [self.visit_expression(arg) for arg in node.args]) elif fn == "syscall": return self.env.syscall( node.args[0].s, @@ -221,6 +225,35 @@ class Visitor: self.builder.position_at_end(merge_block) + def _visit_stmt_For(self, node): + function = self.builder.basic_block.function + body_block = function.append_basic_block("f_body") + else_block = function.append_basic_block("f_else") + merge_block = function.append_basic_block("f_merge") + + it = self.visit_expression(node.iter) + target = self.visit_expression(node.target) + itval = it.get_value_ptr() + + cont = it.o_next(self.builder) + self.builder.cbranch( + cont.auto_load(self.builder), body_block, else_block) + + self.builder.position_at_end(body_block) + target.set_value(self.builder, itval) + self.visit_statements(node.body) + if not is_terminated(self.builder.basic_block): + cont = it.o_next(self.builder) + self.builder.cbranch( + cont.auto_load(self.builder), body_block, merge_block) + + self.builder.position_at_end(else_block) + self.visit_statements(node.orelse) + if not is_terminated(self.builder.basic_block): + self.builder.branch(merge_block) + + self.builder.position_at_end(merge_block) + def _visit_stmt_Return(self, node): if node.value is None: val = base_types.VNone() diff --git a/artiq/py2llvm/infer_types.py b/artiq/py2llvm/infer_types.py index 8df54d786..d02aeef36 100644 --- a/artiq/py2llvm/infer_types.py +++ b/artiq/py2llvm/infer_types.py @@ -42,6 +42,11 @@ class _TypeScanner(ast.NodeVisitor): op=node.op, left=node.target, right=node.value)) self._update_target(node.target, val) + def visit_For(self, node): + it = self.exprv.visit_expression(node.iter) + self._update_target(node.target, it.get_value_ptr()) + self.generic_visit(node) + def visit_Return(self, node): if node.value is None: val = base_types.VNone() diff --git a/artiq/py2llvm/iterators.py b/artiq/py2llvm/iterators.py new file mode 100644 index 000000000..3e8d1eba9 --- /dev/null +++ b/artiq/py2llvm/iterators.py @@ -0,0 +1,53 @@ +from llvm import core as lc + +from artiq.py2llvm.values import operators +from artiq.py2llvm.base_types import VBool, VInt + +class IRange: + def __init__(self, builder, args): + minimum, step = None, None + if len(args) == 1: + maximum = args[0] + elif len(args) == 2: + minimum, maximum = args + else: + minimum, maximum, step = args + if minimum is None: + minimum = VInt() + if builder is not None: + minimum.set_const_value(builder, 0) + if step is None: + step = VInt() + if builder is not None: + step.set_const_value(builder, 1) + + self._counter = minimum.new() + self._counter.merge(maximum) + self._counter.merge(step) + self._minimum = self._counter.new() + self._maximum = self._counter.new() + self._step = self._counter.new() + + if builder is not None: + self._minimum.alloca(builder, "irange_min") + self._maximum.alloca(builder, "irange_max") + self._step.alloca(builder, "irange_step") + self._counter.alloca(builder, "irange_count") + + self._minimum.set_value(builder, minimum) + self._maximum.set_value(builder, maximum) + self._step.set_value(builder, step) + + counter_init = operators.sub(self._minimum, self._step, builder) + self._counter.set_value(builder, counter_init) + + # must be a pointer value that can be dereferenced anytime + # to get the current value of the iterator + def get_value_ptr(self): + return self._counter + + def o_next(self, builder): + self._counter.set_value( + builder, + operators.add(self._counter, self._step, builder)) + return operators.lt(self._counter, self._maximum, builder) diff --git a/examples/coredev_test.py b/examples/coredev_test.py index f5e8cef1e..127648a25 100644 --- a/examples/coredev_test.py +++ b/examples/coredev_test.py @@ -14,9 +14,8 @@ class CompilerTest(AutoContext): @kernel def run(self): self.led.set(1) - x = 1 m = self.get_max() - while x < m: + for x in range(1, m): d = 2 prime = True while d*d <= x: @@ -25,7 +24,6 @@ class CompilerTest(AutoContext): d += 1 if prime: self.output(x) - x += 1 self.led.set(0) diff --git a/examples/dds_test.py b/examples/dds_test.py index d630dabdd..8b1869eb4 100644 --- a/examples/dds_test.py +++ b/examples/dds_test.py @@ -8,8 +8,7 @@ class DDSTest(AutoContext): @kernel def run(self): - i = 0 - while i < 10000: + for i in range(10000): if i & 0x200: self.led.set(1) else: @@ -21,7 +20,6 @@ class DDSTest(AutoContext): with sequential: self.c.pulse(200*MHz, 100*us) self.d.pulse(250*MHz, 200*us) - i += 1 self.led.set(0) diff --git a/test/py2llvm.py b/test/py2llvm.py index c5e41cf3c..f87136280 100644 --- a/test/py2llvm.py +++ b/test/py2llvm.py @@ -57,7 +57,8 @@ class FunctionBaseTypesCase(unittest.TestCase): def test_array_types(): a = array(0, 5) - a[3] = int64(8) + for i in range(2): + a[i] = int64(8) return a @@ -70,6 +71,8 @@ class FunctionArrayTypesCase(unittest.TestCase): self.assertIsInstance(self.ns["a"].el_init, base_types.VInt) self.assertEqual(self.ns["a"].el_init.nbits, 64) self.assertEqual(self.ns["a"].count, 5) + self.assertIsInstance(self.ns["i"], base_types.VInt) + self.assertEqual(self.ns["i"].nbits, 32) class CompiledFunction: @@ -127,13 +130,9 @@ def array_test(): a[0][0] += 6 acc = 0 - i = 0 - while i < 5: - j = 0 - while j < 5: + for i in range(5): + for j in range(5): acc += a[i][j] - j += 1 - i += 1 return acc