mirror of
https://github.com/m-labs/artiq.git
synced 2025-01-26 18:38:13 +08:00
py2llvm: for loop and range support
This commit is contained in:
parent
a580d44007
commit
d66448a486
@ -1,6 +1,6 @@
|
||||
import ast
|
||||
|
||||
from artiq.py2llvm import values, base_types, fractions, arrays
|
||||
from artiq.py2llvm import values, base_types, fractions, arrays, iterators
|
||||
from artiq.py2llvm.tools import is_terminated
|
||||
|
||||
|
||||
@ -127,6 +127,10 @@ class Visitor:
|
||||
else:
|
||||
raise ValueError("Array size must be integer and constant")
|
||||
return arrays.VArray(element, count)
|
||||
elif fn == "range":
|
||||
return iterators.IRange(
|
||||
self.builder,
|
||||
[self.visit_expression(arg) for arg in node.args])
|
||||
elif fn == "syscall":
|
||||
return self.env.syscall(
|
||||
node.args[0].s,
|
||||
@ -221,6 +225,35 @@ class Visitor:
|
||||
|
||||
self.builder.position_at_end(merge_block)
|
||||
|
||||
def _visit_stmt_For(self, node):
|
||||
function = self.builder.basic_block.function
|
||||
body_block = function.append_basic_block("f_body")
|
||||
else_block = function.append_basic_block("f_else")
|
||||
merge_block = function.append_basic_block("f_merge")
|
||||
|
||||
it = self.visit_expression(node.iter)
|
||||
target = self.visit_expression(node.target)
|
||||
itval = it.get_value_ptr()
|
||||
|
||||
cont = it.o_next(self.builder)
|
||||
self.builder.cbranch(
|
||||
cont.auto_load(self.builder), body_block, else_block)
|
||||
|
||||
self.builder.position_at_end(body_block)
|
||||
target.set_value(self.builder, itval)
|
||||
self.visit_statements(node.body)
|
||||
if not is_terminated(self.builder.basic_block):
|
||||
cont = it.o_next(self.builder)
|
||||
self.builder.cbranch(
|
||||
cont.auto_load(self.builder), body_block, merge_block)
|
||||
|
||||
self.builder.position_at_end(else_block)
|
||||
self.visit_statements(node.orelse)
|
||||
if not is_terminated(self.builder.basic_block):
|
||||
self.builder.branch(merge_block)
|
||||
|
||||
self.builder.position_at_end(merge_block)
|
||||
|
||||
def _visit_stmt_Return(self, node):
|
||||
if node.value is None:
|
||||
val = base_types.VNone()
|
||||
|
@ -42,6 +42,11 @@ class _TypeScanner(ast.NodeVisitor):
|
||||
op=node.op, left=node.target, right=node.value))
|
||||
self._update_target(node.target, val)
|
||||
|
||||
def visit_For(self, node):
|
||||
it = self.exprv.visit_expression(node.iter)
|
||||
self._update_target(node.target, it.get_value_ptr())
|
||||
self.generic_visit(node)
|
||||
|
||||
def visit_Return(self, node):
|
||||
if node.value is None:
|
||||
val = base_types.VNone()
|
||||
|
53
artiq/py2llvm/iterators.py
Normal file
53
artiq/py2llvm/iterators.py
Normal file
@ -0,0 +1,53 @@
|
||||
from llvm import core as lc
|
||||
|
||||
from artiq.py2llvm.values import operators
|
||||
from artiq.py2llvm.base_types import VBool, VInt
|
||||
|
||||
class IRange:
|
||||
def __init__(self, builder, args):
|
||||
minimum, step = None, None
|
||||
if len(args) == 1:
|
||||
maximum = args[0]
|
||||
elif len(args) == 2:
|
||||
minimum, maximum = args
|
||||
else:
|
||||
minimum, maximum, step = args
|
||||
if minimum is None:
|
||||
minimum = VInt()
|
||||
if builder is not None:
|
||||
minimum.set_const_value(builder, 0)
|
||||
if step is None:
|
||||
step = VInt()
|
||||
if builder is not None:
|
||||
step.set_const_value(builder, 1)
|
||||
|
||||
self._counter = minimum.new()
|
||||
self._counter.merge(maximum)
|
||||
self._counter.merge(step)
|
||||
self._minimum = self._counter.new()
|
||||
self._maximum = self._counter.new()
|
||||
self._step = self._counter.new()
|
||||
|
||||
if builder is not None:
|
||||
self._minimum.alloca(builder, "irange_min")
|
||||
self._maximum.alloca(builder, "irange_max")
|
||||
self._step.alloca(builder, "irange_step")
|
||||
self._counter.alloca(builder, "irange_count")
|
||||
|
||||
self._minimum.set_value(builder, minimum)
|
||||
self._maximum.set_value(builder, maximum)
|
||||
self._step.set_value(builder, step)
|
||||
|
||||
counter_init = operators.sub(self._minimum, self._step, builder)
|
||||
self._counter.set_value(builder, counter_init)
|
||||
|
||||
# must be a pointer value that can be dereferenced anytime
|
||||
# to get the current value of the iterator
|
||||
def get_value_ptr(self):
|
||||
return self._counter
|
||||
|
||||
def o_next(self, builder):
|
||||
self._counter.set_value(
|
||||
builder,
|
||||
operators.add(self._counter, self._step, builder))
|
||||
return operators.lt(self._counter, self._maximum, builder)
|
@ -14,9 +14,8 @@ class CompilerTest(AutoContext):
|
||||
@kernel
|
||||
def run(self):
|
||||
self.led.set(1)
|
||||
x = 1
|
||||
m = self.get_max()
|
||||
while x < m:
|
||||
for x in range(1, m):
|
||||
d = 2
|
||||
prime = True
|
||||
while d*d <= x:
|
||||
@ -25,7 +24,6 @@ class CompilerTest(AutoContext):
|
||||
d += 1
|
||||
if prime:
|
||||
self.output(x)
|
||||
x += 1
|
||||
self.led.set(0)
|
||||
|
||||
|
||||
|
@ -8,8 +8,7 @@ class DDSTest(AutoContext):
|
||||
|
||||
@kernel
|
||||
def run(self):
|
||||
i = 0
|
||||
while i < 10000:
|
||||
for i in range(10000):
|
||||
if i & 0x200:
|
||||
self.led.set(1)
|
||||
else:
|
||||
@ -21,7 +20,6 @@ class DDSTest(AutoContext):
|
||||
with sequential:
|
||||
self.c.pulse(200*MHz, 100*us)
|
||||
self.d.pulse(250*MHz, 200*us)
|
||||
i += 1
|
||||
self.led.set(0)
|
||||
|
||||
|
||||
|
@ -57,7 +57,8 @@ class FunctionBaseTypesCase(unittest.TestCase):
|
||||
|
||||
def test_array_types():
|
||||
a = array(0, 5)
|
||||
a[3] = int64(8)
|
||||
for i in range(2):
|
||||
a[i] = int64(8)
|
||||
return a
|
||||
|
||||
|
||||
@ -70,6 +71,8 @@ class FunctionArrayTypesCase(unittest.TestCase):
|
||||
self.assertIsInstance(self.ns["a"].el_init, base_types.VInt)
|
||||
self.assertEqual(self.ns["a"].el_init.nbits, 64)
|
||||
self.assertEqual(self.ns["a"].count, 5)
|
||||
self.assertIsInstance(self.ns["i"], base_types.VInt)
|
||||
self.assertEqual(self.ns["i"].nbits, 32)
|
||||
|
||||
|
||||
class CompiledFunction:
|
||||
@ -127,13 +130,9 @@ def array_test():
|
||||
a[0][0] += 6
|
||||
|
||||
acc = 0
|
||||
i = 0
|
||||
while i < 5:
|
||||
j = 0
|
||||
while j < 5:
|
||||
for i in range(5):
|
||||
for j in range(5):
|
||||
acc += a[i][j]
|
||||
j += 1
|
||||
i += 1
|
||||
return acc
|
||||
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user