mirror of https://github.com/m-labs/artiq.git
py2llvm: for loop and range support
This commit is contained in:
parent
a580d44007
commit
d66448a486
|
@ -1,6 +1,6 @@
|
||||||
import ast
|
import ast
|
||||||
|
|
||||||
from artiq.py2llvm import values, base_types, fractions, arrays
|
from artiq.py2llvm import values, base_types, fractions, arrays, iterators
|
||||||
from artiq.py2llvm.tools import is_terminated
|
from artiq.py2llvm.tools import is_terminated
|
||||||
|
|
||||||
|
|
||||||
|
@ -127,6 +127,10 @@ class Visitor:
|
||||||
else:
|
else:
|
||||||
raise ValueError("Array size must be integer and constant")
|
raise ValueError("Array size must be integer and constant")
|
||||||
return arrays.VArray(element, count)
|
return arrays.VArray(element, count)
|
||||||
|
elif fn == "range":
|
||||||
|
return iterators.IRange(
|
||||||
|
self.builder,
|
||||||
|
[self.visit_expression(arg) for arg in node.args])
|
||||||
elif fn == "syscall":
|
elif fn == "syscall":
|
||||||
return self.env.syscall(
|
return self.env.syscall(
|
||||||
node.args[0].s,
|
node.args[0].s,
|
||||||
|
@ -221,6 +225,35 @@ class Visitor:
|
||||||
|
|
||||||
self.builder.position_at_end(merge_block)
|
self.builder.position_at_end(merge_block)
|
||||||
|
|
||||||
|
def _visit_stmt_For(self, node):
|
||||||
|
function = self.builder.basic_block.function
|
||||||
|
body_block = function.append_basic_block("f_body")
|
||||||
|
else_block = function.append_basic_block("f_else")
|
||||||
|
merge_block = function.append_basic_block("f_merge")
|
||||||
|
|
||||||
|
it = self.visit_expression(node.iter)
|
||||||
|
target = self.visit_expression(node.target)
|
||||||
|
itval = it.get_value_ptr()
|
||||||
|
|
||||||
|
cont = it.o_next(self.builder)
|
||||||
|
self.builder.cbranch(
|
||||||
|
cont.auto_load(self.builder), body_block, else_block)
|
||||||
|
|
||||||
|
self.builder.position_at_end(body_block)
|
||||||
|
target.set_value(self.builder, itval)
|
||||||
|
self.visit_statements(node.body)
|
||||||
|
if not is_terminated(self.builder.basic_block):
|
||||||
|
cont = it.o_next(self.builder)
|
||||||
|
self.builder.cbranch(
|
||||||
|
cont.auto_load(self.builder), body_block, merge_block)
|
||||||
|
|
||||||
|
self.builder.position_at_end(else_block)
|
||||||
|
self.visit_statements(node.orelse)
|
||||||
|
if not is_terminated(self.builder.basic_block):
|
||||||
|
self.builder.branch(merge_block)
|
||||||
|
|
||||||
|
self.builder.position_at_end(merge_block)
|
||||||
|
|
||||||
def _visit_stmt_Return(self, node):
|
def _visit_stmt_Return(self, node):
|
||||||
if node.value is None:
|
if node.value is None:
|
||||||
val = base_types.VNone()
|
val = base_types.VNone()
|
||||||
|
|
|
@ -42,6 +42,11 @@ class _TypeScanner(ast.NodeVisitor):
|
||||||
op=node.op, left=node.target, right=node.value))
|
op=node.op, left=node.target, right=node.value))
|
||||||
self._update_target(node.target, val)
|
self._update_target(node.target, val)
|
||||||
|
|
||||||
|
def visit_For(self, node):
|
||||||
|
it = self.exprv.visit_expression(node.iter)
|
||||||
|
self._update_target(node.target, it.get_value_ptr())
|
||||||
|
self.generic_visit(node)
|
||||||
|
|
||||||
def visit_Return(self, node):
|
def visit_Return(self, node):
|
||||||
if node.value is None:
|
if node.value is None:
|
||||||
val = base_types.VNone()
|
val = base_types.VNone()
|
||||||
|
|
|
@ -0,0 +1,53 @@
|
||||||
|
from llvm import core as lc
|
||||||
|
|
||||||
|
from artiq.py2llvm.values import operators
|
||||||
|
from artiq.py2llvm.base_types import VBool, VInt
|
||||||
|
|
||||||
|
class IRange:
|
||||||
|
def __init__(self, builder, args):
|
||||||
|
minimum, step = None, None
|
||||||
|
if len(args) == 1:
|
||||||
|
maximum = args[0]
|
||||||
|
elif len(args) == 2:
|
||||||
|
minimum, maximum = args
|
||||||
|
else:
|
||||||
|
minimum, maximum, step = args
|
||||||
|
if minimum is None:
|
||||||
|
minimum = VInt()
|
||||||
|
if builder is not None:
|
||||||
|
minimum.set_const_value(builder, 0)
|
||||||
|
if step is None:
|
||||||
|
step = VInt()
|
||||||
|
if builder is not None:
|
||||||
|
step.set_const_value(builder, 1)
|
||||||
|
|
||||||
|
self._counter = minimum.new()
|
||||||
|
self._counter.merge(maximum)
|
||||||
|
self._counter.merge(step)
|
||||||
|
self._minimum = self._counter.new()
|
||||||
|
self._maximum = self._counter.new()
|
||||||
|
self._step = self._counter.new()
|
||||||
|
|
||||||
|
if builder is not None:
|
||||||
|
self._minimum.alloca(builder, "irange_min")
|
||||||
|
self._maximum.alloca(builder, "irange_max")
|
||||||
|
self._step.alloca(builder, "irange_step")
|
||||||
|
self._counter.alloca(builder, "irange_count")
|
||||||
|
|
||||||
|
self._minimum.set_value(builder, minimum)
|
||||||
|
self._maximum.set_value(builder, maximum)
|
||||||
|
self._step.set_value(builder, step)
|
||||||
|
|
||||||
|
counter_init = operators.sub(self._minimum, self._step, builder)
|
||||||
|
self._counter.set_value(builder, counter_init)
|
||||||
|
|
||||||
|
# must be a pointer value that can be dereferenced anytime
|
||||||
|
# to get the current value of the iterator
|
||||||
|
def get_value_ptr(self):
|
||||||
|
return self._counter
|
||||||
|
|
||||||
|
def o_next(self, builder):
|
||||||
|
self._counter.set_value(
|
||||||
|
builder,
|
||||||
|
operators.add(self._counter, self._step, builder))
|
||||||
|
return operators.lt(self._counter, self._maximum, builder)
|
|
@ -14,9 +14,8 @@ class CompilerTest(AutoContext):
|
||||||
@kernel
|
@kernel
|
||||||
def run(self):
|
def run(self):
|
||||||
self.led.set(1)
|
self.led.set(1)
|
||||||
x = 1
|
|
||||||
m = self.get_max()
|
m = self.get_max()
|
||||||
while x < m:
|
for x in range(1, m):
|
||||||
d = 2
|
d = 2
|
||||||
prime = True
|
prime = True
|
||||||
while d*d <= x:
|
while d*d <= x:
|
||||||
|
@ -25,7 +24,6 @@ class CompilerTest(AutoContext):
|
||||||
d += 1
|
d += 1
|
||||||
if prime:
|
if prime:
|
||||||
self.output(x)
|
self.output(x)
|
||||||
x += 1
|
|
||||||
self.led.set(0)
|
self.led.set(0)
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -8,8 +8,7 @@ class DDSTest(AutoContext):
|
||||||
|
|
||||||
@kernel
|
@kernel
|
||||||
def run(self):
|
def run(self):
|
||||||
i = 0
|
for i in range(10000):
|
||||||
while i < 10000:
|
|
||||||
if i & 0x200:
|
if i & 0x200:
|
||||||
self.led.set(1)
|
self.led.set(1)
|
||||||
else:
|
else:
|
||||||
|
@ -21,7 +20,6 @@ class DDSTest(AutoContext):
|
||||||
with sequential:
|
with sequential:
|
||||||
self.c.pulse(200*MHz, 100*us)
|
self.c.pulse(200*MHz, 100*us)
|
||||||
self.d.pulse(250*MHz, 200*us)
|
self.d.pulse(250*MHz, 200*us)
|
||||||
i += 1
|
|
||||||
self.led.set(0)
|
self.led.set(0)
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -57,7 +57,8 @@ class FunctionBaseTypesCase(unittest.TestCase):
|
||||||
|
|
||||||
def test_array_types():
|
def test_array_types():
|
||||||
a = array(0, 5)
|
a = array(0, 5)
|
||||||
a[3] = int64(8)
|
for i in range(2):
|
||||||
|
a[i] = int64(8)
|
||||||
return a
|
return a
|
||||||
|
|
||||||
|
|
||||||
|
@ -70,6 +71,8 @@ class FunctionArrayTypesCase(unittest.TestCase):
|
||||||
self.assertIsInstance(self.ns["a"].el_init, base_types.VInt)
|
self.assertIsInstance(self.ns["a"].el_init, base_types.VInt)
|
||||||
self.assertEqual(self.ns["a"].el_init.nbits, 64)
|
self.assertEqual(self.ns["a"].el_init.nbits, 64)
|
||||||
self.assertEqual(self.ns["a"].count, 5)
|
self.assertEqual(self.ns["a"].count, 5)
|
||||||
|
self.assertIsInstance(self.ns["i"], base_types.VInt)
|
||||||
|
self.assertEqual(self.ns["i"].nbits, 32)
|
||||||
|
|
||||||
|
|
||||||
class CompiledFunction:
|
class CompiledFunction:
|
||||||
|
@ -127,13 +130,9 @@ def array_test():
|
||||||
a[0][0] += 6
|
a[0][0] += 6
|
||||||
|
|
||||||
acc = 0
|
acc = 0
|
||||||
i = 0
|
for i in range(5):
|
||||||
while i < 5:
|
for j in range(5):
|
||||||
j = 0
|
|
||||||
while j < 5:
|
|
||||||
acc += a[i][j]
|
acc += a[i][j]
|
||||||
j += 1
|
|
||||||
i += 1
|
|
||||||
return acc
|
return acc
|
||||||
|
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue