py2llvm: array support

This commit is contained in:
Sebastien Bourdeauducq 2014-09-09 17:13:48 +08:00
parent e2ca571c89
commit eec52a2e29
8 changed files with 256 additions and 108 deletions

70
artiq/py2llvm/arrays.py Normal file
View File

@ -0,0 +1,70 @@
from llvm import core as lc
from artiq.py2llvm.values import VGeneric
from artiq.py2llvm.base_types import VInt
class VArray(VGeneric):
def __init__(self, el_init, count):
VGeneric.__init__(self)
self.el_init = el_init
self.count = count
if not count:
raise TypeError("Arrays must have at least one element")
def get_llvm_type(self):
return lc.Type.array(self.el_init.get_llvm_type(), self.count)
def __repr__(self):
return "<VArray:{} x{}>".format(repr(self.el_init), self.count)
def same_type(self, other):
return (
isinstance(other, VArray)
and self.el_init.same_type(other.el_init)
and self.count == other.count)
def merge(self, other):
if isinstance(other, VArray):
self.el_init.merge(other.el_init)
else:
raise TypeError("Incompatible types: {} and {}"
.format(repr(self), repr(other)))
def merge_subscript(self, other):
self.el_init.merge(other)
def set_value(self, builder, v):
if not isinstance(v, VArray):
raise TypeError
if v.llvm_value is not None:
raise NotImplementedError("Array aliasing is not supported")
i = VInt()
i.alloca(builder, "ai_i")
i.auto_store(builder, lc.Constant.int(lc.Type.int(), 0))
function = builder.basic_block.function
copy_block = function.append_basic_block("ai_copy")
end_block = function.append_basic_block("ai_end")
builder.branch(copy_block)
builder.position_at_end(copy_block)
self.o_subscript(i, builder).set_value(builder, v.el_init)
i.auto_store(builder, builder.add(
i.auto_load(builder), lc.Constant.int(lc.Type.int(), 1)))
cont = builder.icmp(
lc.ICMP_SLT, i.auto_load(builder),
lc.Constant.int(lc.Type.int(), self.count))
builder.cbranch(cont, copy_block, end_block)
builder.position_at_end(end_block)
def o_subscript(self, index, builder):
r = self.el_init.new()
if builder is not None:
index = index.o_int(builder).auto_load(builder)
ssa_r = builder.gep(self.llvm_value, [
lc.Constant.int(lc.Type.int(), 0), index])
r.auto_store(builder, ssa_r)
return r

View File

@ -1,9 +1,41 @@
import ast
from artiq.py2llvm import values, base_types, fractions
from artiq.py2llvm import values, base_types, fractions, arrays
from artiq.py2llvm.tools import is_terminated
_ast_unops = {
ast.Invert: "o_inv",
ast.Not: "o_not",
ast.UAdd: "o_pos",
ast.USub: "o_neg"
}
_ast_binops = {
ast.Add: values.operators.add,
ast.Sub: values.operators.sub,
ast.Mult: values.operators.mul,
ast.Div: values.operators.truediv,
ast.FloorDiv: values.operators.floordiv,
ast.Mod: values.operators.mod,
ast.Pow: values.operators.pow,
ast.LShift: values.operators.lshift,
ast.RShift: values.operators.rshift,
ast.BitOr: values.operators.or_,
ast.BitXor: values.operators.xor,
ast.BitAnd: values.operators.and_
}
_ast_cmps = {
ast.Eq: values.operators.eq,
ast.NotEq: values.operators.ne,
ast.Lt: values.operators.lt,
ast.LtE: values.operators.le,
ast.Gt: values.operators.gt,
ast.GtE: values.operators.ge
}
class Visitor:
def __init__(self, env, ns, builder=None):
self.env = env
@ -53,48 +85,20 @@ class Visitor:
return r
def _visit_expr_UnaryOp(self, node):
ast_unops = {
ast.Invert: "o_inv",
ast.Not: "o_not",
ast.UAdd: "o_pos",
ast.USub: "o_neg"
}
value = self.visit_expression(node.operand)
return getattr(value, ast_unops[type(node.op)])(self.builder)
return getattr(value, _ast_unops[type(node.op)])(self.builder)
def _visit_expr_BinOp(self, node):
ast_binops = {
ast.Add: values.operators.add,
ast.Sub: values.operators.sub,
ast.Mult: values.operators.mul,
ast.Div: values.operators.truediv,
ast.FloorDiv: values.operators.floordiv,
ast.Mod: values.operators.mod,
ast.Pow: values.operators.pow,
ast.LShift: values.operators.lshift,
ast.RShift: values.operators.rshift,
ast.BitOr: values.operators.or_,
ast.BitXor: values.operators.xor,
ast.BitAnd: values.operators.and_
}
return ast_binops[type(node.op)](self.visit_expression(node.left),
return _ast_binops[type(node.op)](self.visit_expression(node.left),
self.visit_expression(node.right),
self.builder)
def _visit_expr_Compare(self, node):
ast_cmps = {
ast.Eq: values.operators.eq,
ast.NotEq: values.operators.ne,
ast.Lt: values.operators.lt,
ast.LtE: values.operators.le,
ast.Gt: values.operators.gt,
ast.GtE: values.operators.ge
}
comparisons = []
old_comparator = self.visit_expression(node.left)
for op, comparator_a in zip(node.ops, node.comparators):
comparator = self.visit_expression(comparator_a)
comparison = ast_cmps[type(op)](old_comparator, comparator,
comparison = _ast_cmps[type(op)](old_comparator, comparator,
self.builder)
comparisons.append(comparison)
old_comparator = comparator
@ -115,6 +119,14 @@ class Visitor:
denominator = self.visit_expression(node.args[1])
r.set_value_nd(self.builder, numerator, denominator)
return r
elif fn == "array":
element = self.visit_expression(node.args[0])
if (isinstance(node.args[1], ast.Num)
and isinstance(node.args[1].n, int)):
count = node.args[1].n
else:
raise ValueError("Array size must be integer and constant")
return arrays.VArray(element, count)
elif fn == "syscall":
return self.env.syscall(
node.args[0].s,
@ -127,6 +139,14 @@ class Visitor:
value = self.visit_expression(node.value)
return value.o_getattr(node.attr, self.builder)
def _visit_expr_Subscript(self, node):
value = self.visit_expression(node.value)
if isinstance(node.slice, ast.Index):
index = self.visit_expression(node.slice.value)
else:
raise NotImplementedError
return value.o_subscript(index, self.builder)
def visit_statements(self, stmts):
for node in stmts:
node_type = node.__class__.__name__
@ -143,18 +163,14 @@ class Visitor:
def _visit_stmt_Assign(self, node):
val = self.visit_expression(node.value)
for target in node.targets:
if isinstance(target, ast.Name):
self.ns[target.id].set_value(self.builder, val)
else:
raise NotImplementedError
target = self.visit_expression(target)
target.set_value(self.builder, val)
def _visit_stmt_AugAssign(self, node):
val = self.visit_expression(ast.BinOp(op=node.op, left=node.target,
right=node.value))
if isinstance(node.target, ast.Name):
self.ns[node.target.id].set_value(self.builder, val)
else:
raise NotImplementedError
target = self.visit_expression(node.target)
right = self.visit_expression(node.value)
val = _ast_binops[type(node.op)](target, right, self.builder)
target.set_value(self.builder, val)
def _visit_stmt_Expr(self, node):
self.visit_expression(node.value)
@ -166,7 +182,7 @@ class Visitor:
merge_block = function.append_basic_block("i_merge")
condition = self.visit_expression(node.test).o_bool(self.builder)
self.builder.cbranch(condition.get_ssa_value(self.builder),
self.builder.cbranch(condition.auto_load(self.builder),
then_block, else_block)
self.builder.position_at_end(then_block)
@ -189,14 +205,14 @@ class Visitor:
condition = self.visit_expression(node.test).o_bool(self.builder)
self.builder.cbranch(
condition.get_ssa_value(self.builder), body_block, else_block)
condition.auto_load(self.builder), body_block, else_block)
self.builder.position_at_end(body_block)
self.visit_statements(node.body)
if not is_terminated(self.builder.basic_block):
condition = self.visit_expression(node.test).o_bool(self.builder)
self.builder.cbranch(
condition.get_ssa_value(self.builder), body_block, merge_block)
condition.auto_load(self.builder), body_block, merge_block)
self.builder.position_at_end(else_block)
self.visit_statements(node.orelse)
@ -213,4 +229,4 @@ class Visitor:
if isinstance(val, base_types.VNone):
self.builder.ret_void()
else:
self.builder.ret(val.get_ssa_value(self.builder))
self.builder.ret(val.auto_load(self.builder))

View File

@ -46,19 +46,19 @@ class VInt(VGeneric):
.format(repr(self), repr(other)))
def set_value(self, builder, n):
self.set_ssa_value(
builder, n.o_intx(self.nbits, builder).get_ssa_value(builder))
self.auto_store(
builder, n.o_intx(self.nbits, builder).auto_load(builder))
def set_const_value(self, builder, n):
self.set_ssa_value(builder, lc.Constant.int(self.get_llvm_type(), n))
self.auto_store(builder, lc.Constant.int(self.get_llvm_type(), n))
def o_bool(self, builder, inv=False):
r = VBool()
if builder is not None:
r.set_ssa_value(
r.auto_store(
builder, builder.icmp(
lc.ICMP_EQ if inv else lc.ICMP_NE,
self.get_ssa_value(builder),
self.auto_load(builder),
lc.Constant.int(self.get_llvm_type(), 0)))
return r
@ -68,9 +68,9 @@ class VInt(VGeneric):
def o_neg(self, builder):
r = VInt(self.nbits)
if builder is not None:
r.set_ssa_value(
r.auto_store(
builder, builder.mul(
self.get_ssa_value(builder),
self.auto_load(builder),
lc.Constant.int(self.get_llvm_type(), -1)))
return r
@ -78,15 +78,15 @@ class VInt(VGeneric):
r = VInt(target_bits)
if builder is not None:
if self.nbits == target_bits:
r.set_ssa_value(
builder, self.get_ssa_value(builder))
r.auto_store(
builder, self.auto_load(builder))
if self.nbits > target_bits:
r.set_ssa_value(
builder, builder.trunc(self.get_ssa_value(builder),
r.auto_store(
builder, builder.trunc(self.auto_load(builder),
r.get_llvm_type()))
if self.nbits < target_bits:
r.set_ssa_value(
builder, builder.sext(self.get_ssa_value(builder),
r.auto_store(
builder, builder.sext(self.auto_load(builder),
r.get_llvm_type()))
return r
o_roundx = o_intx
@ -101,9 +101,9 @@ def _make_vint_binop_method(builder_name):
left = self.o_intx(target_bits, builder)
right = other.o_intx(target_bits, builder)
bf = getattr(builder, builder_name)
r.set_ssa_value(
builder, bf(left.get_ssa_value(builder),
right.get_ssa_value(builder)))
r.auto_store(
builder, bf(left.auto_load(builder),
right.auto_load(builder)))
return r
else:
return NotImplemented
@ -128,11 +128,11 @@ def _make_vint_cmp_method(icmp_val):
target_bits = max(self.nbits, other.nbits)
left = self.o_intx(target_bits, builder)
right = other.o_intx(target_bits, builder)
r.set_ssa_value(
r.auto_store(
builder,
builder.icmp(
icmp_val, left.get_ssa_value(builder),
right.get_ssa_value(builder)))
icmp_val, left.auto_load(builder),
right.auto_load(builder)))
return r
else:
return NotImplemented
@ -161,5 +161,5 @@ class VBool(VInt):
def o_bool(self, builder):
r = VBool()
if builder is not None:
r.set_ssa_value(builder, self.get_ssa_value(builder))
r.auto_store(builder, self.auto_load(builder))
return r

View File

@ -71,7 +71,7 @@ class VFraction(VGeneric):
return lc.Type.vector(lc.Type.int(64), 2)
def _nd(self, builder):
ssa_value = self.get_ssa_value(builder)
ssa_value = self.auto_load(builder)
a = builder.extract_element(
ssa_value, lc.Constant.int(lc.Type.int(), 0))
b = builder.extract_element(
@ -79,16 +79,16 @@ class VFraction(VGeneric):
return a, b
def set_value_nd(self, builder, a, b):
a = a.o_int64(builder).get_ssa_value(builder)
b = b.o_int64(builder).get_ssa_value(builder)
a = a.o_int64(builder).auto_load(builder)
b = b.o_int64(builder).auto_load(builder)
a, b = _reduce(builder, a, b)
a, b = _signnum(builder, a, b)
self.set_ssa_value(builder, _make_ssa(builder, a, b))
self.auto_store(builder, _make_ssa(builder, a, b))
def set_value(self, builder, v):
if not isinstance(v, VFraction):
raise TypeError
self.set_ssa_value(builder, v.get_ssa_value(builder))
self.auto_store(builder, v.auto_load(builder))
def o_getattr(self, attr, builder):
if attr == "numerator":
@ -100,9 +100,9 @@ class VFraction(VGeneric):
r = VInt(64)
if builder is not None:
elt = builder.extract_element(
self.get_ssa_value(builder),
self.auto_load(builder),
lc.Constant.int(lc.Type.int(), idx))
r.set_ssa_value(builder, elt)
r.auto_store(builder, elt)
return r
def o_bool(self, builder):
@ -110,8 +110,8 @@ class VFraction(VGeneric):
if builder is not None:
zero = lc.Constant.int(lc.Type.int(64), 0)
a = builder.extract_element(
self.get_ssa_value(builder), lc.Constant.int(lc.Type.int(), 0))
r.set_ssa_value(builder, builder.icmp(lc.ICMP_NE, a, zero))
self.auto_load(builder), lc.Constant.int(lc.Type.int(), 0))
r.auto_store(builder, builder.icmp(lc.ICMP_NE, a, zero))
return r
def o_intx(self, target_bits, builder):
@ -120,7 +120,7 @@ class VFraction(VGeneric):
else:
r = VInt(64)
a, b = self._nd(builder)
r.set_ssa_value(builder, builder.sdiv(a, b))
r.auto_store(builder, builder.sdiv(a, b))
return r.o_intx(target_bits, builder)
def o_roundx(self, target_bits, builder):
@ -131,7 +131,7 @@ class VFraction(VGeneric):
a, b = self._nd(builder)
h_b = builder.ashr(b, lc.Constant.int(lc.Type.int(), 1))
a = builder.add(a, h_b)
r.set_ssa_value(builder, builder.sdiv(a, b))
r.auto_store(builder, builder.sdiv(a, b))
return r.o_intx(target_bits, builder)
def _o_eq_inv(self, other, builder, ne):
@ -144,7 +144,7 @@ class VFraction(VGeneric):
a, b = self._nd(builder)
ssa_r = builder.and_(
builder.icmp(lc.ICMP_EQ, a,
other.get_ssa_value()),
other.auto_load()),
builder.icmp(lc.ICMP_EQ, b,
lc.Constant.int(lc.Type.int(64), 1)))
else:
@ -156,7 +156,7 @@ class VFraction(VGeneric):
if ne:
ssa_r = builder.xor(ssa_r,
lc.Constant.int(lc.Type.int(1), 1))
r.set_ssa_value(builder, ssa_r)
r.auto_store(builder, ssa_r)
return r
def o_eq(self, other, builder):
@ -171,7 +171,7 @@ class VFraction(VGeneric):
r = VFraction()
if builder is not None:
if isinstance(other, VInt):
i = other.o_int64(builder).get_ssa_value()
i = other.o_int64(builder).auto_load()
x, rd = self._nd(builder)
y = builder.mul(rd, i)
else:
@ -188,7 +188,7 @@ class VFraction(VGeneric):
else:
rn = builder.add(x, y)
rn, rd = _reduce(builder, rn, rd) # rd is already > 0
r.set_ssa_value(builder, _make_ssa(builder, rn, rd))
r.auto_store(builder, _make_ssa(builder, rn, rd))
return r
def o_add(self, other, builder):
@ -212,7 +212,7 @@ class VFraction(VGeneric):
if invert:
a, b = b, a
if isinstance(other, VInt):
i = other.o_int64(builder).get_ssa_value(builder)
i = other.o_int64(builder).auto_load(builder)
if div:
b = builder.mul(b, i)
else:
@ -228,7 +228,7 @@ class VFraction(VGeneric):
if div or invert:
a, b = _signnum(builder, a, b)
a, b = _reduce(builder, a, b)
r.set_ssa_value(builder, _make_ssa(builder, a, b))
r.auto_store(builder, _make_ssa(builder, a, b))
return r
def o_mul(self, other, builder):

View File

@ -16,6 +16,19 @@ class _TypeScanner(ast.NodeVisitor):
ns[target.id].merge(val)
else:
ns[target.id] = deepcopy(val)
elif isinstance(target, ast.Subscript):
target = target.value
levels = 0
while isinstance(target, ast.Subscript):
target = target.value
levels += 1
if isinstance(target, ast.Name):
target_value = ns[target.id]
for i in range(levels):
target_value = target_value.o_subscript(None, None)
target_value.merge_subscript(val)
else:
raise NotImplementedError
else:
raise NotImplementedError
@ -40,6 +53,7 @@ class _TypeScanner(ast.NodeVisitor):
else:
ns["return"] = deepcopy(val)
def infer_function_types(env, node, param_types):
ns = deepcopy(param_types)
ts = _TypeScanner(env, ns)

View File

@ -46,7 +46,7 @@ class Module:
for k, v in ns.items():
v.alloca(builder, k)
for arg_ast, arg_llvm in zip(funcdef.args.args, function.args):
ns[arg_ast.arg].set_ssa_value(builder, arg_llvm)
ns[arg_ast.arg].auto_store(builder, arg_llvm)
visitor = ast_body.Visitor(self.env, ns, builder)
visitor.visit_statements(funcdef.body)
@ -55,6 +55,6 @@ class Module:
if isinstance(retval, base_types.VNone):
builder.ret_void()
else:
builder.ret(retval.get_ssa_value(builder))
builder.ret(retval.auto_load(builder))
return function, retval

View File

@ -1,11 +1,17 @@
from types import SimpleNamespace
from copy import copy
from llvm import core as lc
class VGeneric:
def __init__(self):
self._llvm_value = None
self.llvm_value = None
def new(self):
r = copy(self)
r.llvm_value = None
return r
def __repr__(self):
return "<" + self.__class__.__name__ + ">"
@ -18,25 +24,25 @@ class VGeneric:
raise TypeError("Incompatible types: {} and {}"
.format(repr(self), repr(other)))
def get_ssa_value(self, builder):
if isinstance(self._llvm_value, lc.AllocaInstruction):
return builder.load(self._llvm_value)
def auto_load(self, builder):
if isinstance(self.llvm_value.type, lc.PointerType):
return builder.load(self.llvm_value)
else:
return self._llvm_value
return self.llvm_value
def set_ssa_value(self, builder, value):
if self._llvm_value is None:
self._llvm_value = value
elif isinstance(self._llvm_value, lc.AllocaInstruction):
builder.store(value, self._llvm_value)
def auto_store(self, builder, llvm_value):
if self.llvm_value is None:
self.llvm_value = llvm_value
elif isinstance(self.llvm_value.type, lc.PointerType):
builder.store(llvm_value, self.llvm_value)
else:
raise RuntimeError(
"Attempted to set LLVM SSA value multiple times")
def alloca(self, builder, name):
if self._llvm_value is not None:
if self.llvm_value is not None:
raise RuntimeError("Attempted to alloca existing LLVM value "+name)
self._llvm_value = builder.alloca(self.get_llvm_type(), name=name)
self.llvm_value = builder.alloca(self.get_llvm_type(), name=name)
def o_int(self, builder):
return self.o_intx(32, builder)

View File

@ -5,13 +5,13 @@ from fractions import Fraction
from llvm import ee as le
from artiq.language.core import int64
from artiq.language.core import int64, array
from artiq.py2llvm.infer_types import infer_function_types
from artiq.py2llvm import base_types
from artiq.py2llvm import base_types, arrays
from artiq.py2llvm.module import Module
def test_types(choice):
def test_base_types(choice):
a = 2 # promoted later to int64
b = a + 1 # initially int32, becomes int64 after a is promoted
c = b//2 # initially int32, becomes int64 after b is promoted
@ -27,13 +27,17 @@ def test_types(choice):
return x + c
class FunctionTypesCase(unittest.TestCase):
def setUp(self):
self.ns = infer_function_types(
None, ast.parse(inspect.getsource(test_types)),
def _build_function_types(f):
return infer_function_types(
None, ast.parse(inspect.getsource(f)),
dict())
def test_base_types(self):
class FunctionBaseTypesCase(unittest.TestCase):
def setUp(self):
self.ns = _build_function_types(test_base_types)
def test_simple_types(self):
self.assertIsInstance(self.ns["foo"], base_types.VBool)
self.assertIsInstance(self.ns["bar"], base_types.VNone)
self.assertIsInstance(self.ns["d"], base_types.VInt)
@ -51,6 +55,23 @@ class FunctionTypesCase(unittest.TestCase):
self.assertEqual(self.ns["return"].nbits, 64)
def test_array_types():
a = array(0, 5)
a[3] = int64(8)
return a
class FunctionArrayTypesCase(unittest.TestCase):
def setUp(self):
self.ns = _build_function_types(test_array_types)
def test_array_types(self):
self.assertIsInstance(self.ns["a"], arrays.VArray)
self.assertIsInstance(self.ns["a"].el_init, base_types.VInt)
self.assertEqual(self.ns["a"].el_init.nbits, 64)
self.assertEqual(self.ns["a"].count, 5)
class CompiledFunction:
def __init__(self, function, param_types):
module = Module()
@ -99,6 +120,23 @@ def arith_encode(op, a, b, c, d):
return f.numerator*1000 + f.denominator
def array_test():
a = array(array(2, 5), 5)
a[3][2] = 11
a[4][1] = 42
a[0][0] += 6
acc = 0
i = 0
while i < 5:
j = 0
while j < 5:
acc += a[i][j]
j += 1
i += 1
return acc
class CodeGenCase(unittest.TestCase):
def test_is_prime(self):
is_prime_c = CompiledFunction(is_prime, {"x": base_types.VInt()})
@ -138,3 +176,7 @@ class CodeGenCase(unittest.TestCase):
def test_frac_div(self):
self._test_frac_arith(3)
def test_array(self):
array_test_c = CompiledFunction(array_test, dict())
self.assertEqual(array_test_c(), array_test())