1
0
forked from M-Labs/artiq

py2llvm: for loop and range support

This commit is contained in:
Sebastien Bourdeauducq 2014-09-10 16:06:27 +08:00
parent a580d44007
commit d66448a486
6 changed files with 100 additions and 14 deletions

View File

@ -1,6 +1,6 @@
import ast
from artiq.py2llvm import values, base_types, fractions, arrays
from artiq.py2llvm import values, base_types, fractions, arrays, iterators
from artiq.py2llvm.tools import is_terminated
@ -127,6 +127,10 @@ class Visitor:
else:
raise ValueError("Array size must be integer and constant")
return arrays.VArray(element, count)
elif fn == "range":
return iterators.IRange(
self.builder,
[self.visit_expression(arg) for arg in node.args])
elif fn == "syscall":
return self.env.syscall(
node.args[0].s,
@ -221,6 +225,35 @@ class Visitor:
self.builder.position_at_end(merge_block)
def _visit_stmt_For(self, node):
function = self.builder.basic_block.function
body_block = function.append_basic_block("f_body")
else_block = function.append_basic_block("f_else")
merge_block = function.append_basic_block("f_merge")
it = self.visit_expression(node.iter)
target = self.visit_expression(node.target)
itval = it.get_value_ptr()
cont = it.o_next(self.builder)
self.builder.cbranch(
cont.auto_load(self.builder), body_block, else_block)
self.builder.position_at_end(body_block)
target.set_value(self.builder, itval)
self.visit_statements(node.body)
if not is_terminated(self.builder.basic_block):
cont = it.o_next(self.builder)
self.builder.cbranch(
cont.auto_load(self.builder), body_block, merge_block)
self.builder.position_at_end(else_block)
self.visit_statements(node.orelse)
if not is_terminated(self.builder.basic_block):
self.builder.branch(merge_block)
self.builder.position_at_end(merge_block)
def _visit_stmt_Return(self, node):
if node.value is None:
val = base_types.VNone()

View File

@ -42,6 +42,11 @@ class _TypeScanner(ast.NodeVisitor):
op=node.op, left=node.target, right=node.value))
self._update_target(node.target, val)
def visit_For(self, node):
it = self.exprv.visit_expression(node.iter)
self._update_target(node.target, it.get_value_ptr())
self.generic_visit(node)
def visit_Return(self, node):
if node.value is None:
val = base_types.VNone()

View File

@ -0,0 +1,53 @@
from llvm import core as lc
from artiq.py2llvm.values import operators
from artiq.py2llvm.base_types import VBool, VInt
class IRange:
def __init__(self, builder, args):
minimum, step = None, None
if len(args) == 1:
maximum = args[0]
elif len(args) == 2:
minimum, maximum = args
else:
minimum, maximum, step = args
if minimum is None:
minimum = VInt()
if builder is not None:
minimum.set_const_value(builder, 0)
if step is None:
step = VInt()
if builder is not None:
step.set_const_value(builder, 1)
self._counter = minimum.new()
self._counter.merge(maximum)
self._counter.merge(step)
self._minimum = self._counter.new()
self._maximum = self._counter.new()
self._step = self._counter.new()
if builder is not None:
self._minimum.alloca(builder, "irange_min")
self._maximum.alloca(builder, "irange_max")
self._step.alloca(builder, "irange_step")
self._counter.alloca(builder, "irange_count")
self._minimum.set_value(builder, minimum)
self._maximum.set_value(builder, maximum)
self._step.set_value(builder, step)
counter_init = operators.sub(self._minimum, self._step, builder)
self._counter.set_value(builder, counter_init)
# must be a pointer value that can be dereferenced anytime
# to get the current value of the iterator
def get_value_ptr(self):
return self._counter
def o_next(self, builder):
self._counter.set_value(
builder,
operators.add(self._counter, self._step, builder))
return operators.lt(self._counter, self._maximum, builder)

View File

@ -14,9 +14,8 @@ class CompilerTest(AutoContext):
@kernel
def run(self):
self.led.set(1)
x = 1
m = self.get_max()
while x < m:
for x in range(1, m):
d = 2
prime = True
while d*d <= x:
@ -25,7 +24,6 @@ class CompilerTest(AutoContext):
d += 1
if prime:
self.output(x)
x += 1
self.led.set(0)

View File

@ -8,8 +8,7 @@ class DDSTest(AutoContext):
@kernel
def run(self):
i = 0
while i < 10000:
for i in range(10000):
if i & 0x200:
self.led.set(1)
else:
@ -21,7 +20,6 @@ class DDSTest(AutoContext):
with sequential:
self.c.pulse(200*MHz, 100*us)
self.d.pulse(250*MHz, 200*us)
i += 1
self.led.set(0)

View File

@ -57,7 +57,8 @@ class FunctionBaseTypesCase(unittest.TestCase):
def test_array_types():
a = array(0, 5)
a[3] = int64(8)
for i in range(2):
a[i] = int64(8)
return a
@ -70,6 +71,8 @@ class FunctionArrayTypesCase(unittest.TestCase):
self.assertIsInstance(self.ns["a"].el_init, base_types.VInt)
self.assertEqual(self.ns["a"].el_init.nbits, 64)
self.assertEqual(self.ns["a"].count, 5)
self.assertIsInstance(self.ns["i"], base_types.VInt)
self.assertEqual(self.ns["i"].nbits, 32)
class CompiledFunction:
@ -127,13 +130,9 @@ def array_test():
a[0][0] += 6
acc = 0
i = 0
while i < 5:
j = 0
while j < 5:
for i in range(5):
for j in range(5):
acc += a[i][j]
j += 1
i += 1
return acc