mirror of
https://github.com/m-labs/artiq.git
synced 2024-12-28 20:53:35 +08:00
py2llvm: array support
This commit is contained in:
parent
e2ca571c89
commit
eec52a2e29
70
artiq/py2llvm/arrays.py
Normal file
70
artiq/py2llvm/arrays.py
Normal file
@ -0,0 +1,70 @@
|
||||
from llvm import core as lc
|
||||
|
||||
from artiq.py2llvm.values import VGeneric
|
||||
from artiq.py2llvm.base_types import VInt
|
||||
|
||||
|
||||
class VArray(VGeneric):
|
||||
def __init__(self, el_init, count):
|
||||
VGeneric.__init__(self)
|
||||
self.el_init = el_init
|
||||
self.count = count
|
||||
if not count:
|
||||
raise TypeError("Arrays must have at least one element")
|
||||
|
||||
def get_llvm_type(self):
|
||||
return lc.Type.array(self.el_init.get_llvm_type(), self.count)
|
||||
|
||||
def __repr__(self):
|
||||
return "<VArray:{} x{}>".format(repr(self.el_init), self.count)
|
||||
|
||||
def same_type(self, other):
|
||||
return (
|
||||
isinstance(other, VArray)
|
||||
and self.el_init.same_type(other.el_init)
|
||||
and self.count == other.count)
|
||||
|
||||
def merge(self, other):
|
||||
if isinstance(other, VArray):
|
||||
self.el_init.merge(other.el_init)
|
||||
else:
|
||||
raise TypeError("Incompatible types: {} and {}"
|
||||
.format(repr(self), repr(other)))
|
||||
|
||||
def merge_subscript(self, other):
|
||||
self.el_init.merge(other)
|
||||
|
||||
def set_value(self, builder, v):
|
||||
if not isinstance(v, VArray):
|
||||
raise TypeError
|
||||
if v.llvm_value is not None:
|
||||
raise NotImplementedError("Array aliasing is not supported")
|
||||
|
||||
i = VInt()
|
||||
i.alloca(builder, "ai_i")
|
||||
i.auto_store(builder, lc.Constant.int(lc.Type.int(), 0))
|
||||
|
||||
function = builder.basic_block.function
|
||||
copy_block = function.append_basic_block("ai_copy")
|
||||
end_block = function.append_basic_block("ai_end")
|
||||
builder.branch(copy_block)
|
||||
|
||||
builder.position_at_end(copy_block)
|
||||
self.o_subscript(i, builder).set_value(builder, v.el_init)
|
||||
i.auto_store(builder, builder.add(
|
||||
i.auto_load(builder), lc.Constant.int(lc.Type.int(), 1)))
|
||||
cont = builder.icmp(
|
||||
lc.ICMP_SLT, i.auto_load(builder),
|
||||
lc.Constant.int(lc.Type.int(), self.count))
|
||||
builder.cbranch(cont, copy_block, end_block)
|
||||
|
||||
builder.position_at_end(end_block)
|
||||
|
||||
def o_subscript(self, index, builder):
|
||||
r = self.el_init.new()
|
||||
if builder is not None:
|
||||
index = index.o_int(builder).auto_load(builder)
|
||||
ssa_r = builder.gep(self.llvm_value, [
|
||||
lc.Constant.int(lc.Type.int(), 0), index])
|
||||
r.auto_store(builder, ssa_r)
|
||||
return r
|
@ -1,9 +1,41 @@
|
||||
import ast
|
||||
|
||||
from artiq.py2llvm import values, base_types, fractions
|
||||
from artiq.py2llvm import values, base_types, fractions, arrays
|
||||
from artiq.py2llvm.tools import is_terminated
|
||||
|
||||
|
||||
_ast_unops = {
|
||||
ast.Invert: "o_inv",
|
||||
ast.Not: "o_not",
|
||||
ast.UAdd: "o_pos",
|
||||
ast.USub: "o_neg"
|
||||
}
|
||||
|
||||
_ast_binops = {
|
||||
ast.Add: values.operators.add,
|
||||
ast.Sub: values.operators.sub,
|
||||
ast.Mult: values.operators.mul,
|
||||
ast.Div: values.operators.truediv,
|
||||
ast.FloorDiv: values.operators.floordiv,
|
||||
ast.Mod: values.operators.mod,
|
||||
ast.Pow: values.operators.pow,
|
||||
ast.LShift: values.operators.lshift,
|
||||
ast.RShift: values.operators.rshift,
|
||||
ast.BitOr: values.operators.or_,
|
||||
ast.BitXor: values.operators.xor,
|
||||
ast.BitAnd: values.operators.and_
|
||||
}
|
||||
|
||||
_ast_cmps = {
|
||||
ast.Eq: values.operators.eq,
|
||||
ast.NotEq: values.operators.ne,
|
||||
ast.Lt: values.operators.lt,
|
||||
ast.LtE: values.operators.le,
|
||||
ast.Gt: values.operators.gt,
|
||||
ast.GtE: values.operators.ge
|
||||
}
|
||||
|
||||
|
||||
class Visitor:
|
||||
def __init__(self, env, ns, builder=None):
|
||||
self.env = env
|
||||
@ -53,48 +85,20 @@ class Visitor:
|
||||
return r
|
||||
|
||||
def _visit_expr_UnaryOp(self, node):
|
||||
ast_unops = {
|
||||
ast.Invert: "o_inv",
|
||||
ast.Not: "o_not",
|
||||
ast.UAdd: "o_pos",
|
||||
ast.USub: "o_neg"
|
||||
}
|
||||
value = self.visit_expression(node.operand)
|
||||
return getattr(value, ast_unops[type(node.op)])(self.builder)
|
||||
return getattr(value, _ast_unops[type(node.op)])(self.builder)
|
||||
|
||||
def _visit_expr_BinOp(self, node):
|
||||
ast_binops = {
|
||||
ast.Add: values.operators.add,
|
||||
ast.Sub: values.operators.sub,
|
||||
ast.Mult: values.operators.mul,
|
||||
ast.Div: values.operators.truediv,
|
||||
ast.FloorDiv: values.operators.floordiv,
|
||||
ast.Mod: values.operators.mod,
|
||||
ast.Pow: values.operators.pow,
|
||||
ast.LShift: values.operators.lshift,
|
||||
ast.RShift: values.operators.rshift,
|
||||
ast.BitOr: values.operators.or_,
|
||||
ast.BitXor: values.operators.xor,
|
||||
ast.BitAnd: values.operators.and_
|
||||
}
|
||||
return ast_binops[type(node.op)](self.visit_expression(node.left),
|
||||
self.visit_expression(node.right),
|
||||
self.builder)
|
||||
return _ast_binops[type(node.op)](self.visit_expression(node.left),
|
||||
self.visit_expression(node.right),
|
||||
self.builder)
|
||||
|
||||
def _visit_expr_Compare(self, node):
|
||||
ast_cmps = {
|
||||
ast.Eq: values.operators.eq,
|
||||
ast.NotEq: values.operators.ne,
|
||||
ast.Lt: values.operators.lt,
|
||||
ast.LtE: values.operators.le,
|
||||
ast.Gt: values.operators.gt,
|
||||
ast.GtE: values.operators.ge
|
||||
}
|
||||
comparisons = []
|
||||
old_comparator = self.visit_expression(node.left)
|
||||
for op, comparator_a in zip(node.ops, node.comparators):
|
||||
comparator = self.visit_expression(comparator_a)
|
||||
comparison = ast_cmps[type(op)](old_comparator, comparator,
|
||||
comparison = _ast_cmps[type(op)](old_comparator, comparator,
|
||||
self.builder)
|
||||
comparisons.append(comparison)
|
||||
old_comparator = comparator
|
||||
@ -115,6 +119,14 @@ class Visitor:
|
||||
denominator = self.visit_expression(node.args[1])
|
||||
r.set_value_nd(self.builder, numerator, denominator)
|
||||
return r
|
||||
elif fn == "array":
|
||||
element = self.visit_expression(node.args[0])
|
||||
if (isinstance(node.args[1], ast.Num)
|
||||
and isinstance(node.args[1].n, int)):
|
||||
count = node.args[1].n
|
||||
else:
|
||||
raise ValueError("Array size must be integer and constant")
|
||||
return arrays.VArray(element, count)
|
||||
elif fn == "syscall":
|
||||
return self.env.syscall(
|
||||
node.args[0].s,
|
||||
@ -127,6 +139,14 @@ class Visitor:
|
||||
value = self.visit_expression(node.value)
|
||||
return value.o_getattr(node.attr, self.builder)
|
||||
|
||||
def _visit_expr_Subscript(self, node):
|
||||
value = self.visit_expression(node.value)
|
||||
if isinstance(node.slice, ast.Index):
|
||||
index = self.visit_expression(node.slice.value)
|
||||
else:
|
||||
raise NotImplementedError
|
||||
return value.o_subscript(index, self.builder)
|
||||
|
||||
def visit_statements(self, stmts):
|
||||
for node in stmts:
|
||||
node_type = node.__class__.__name__
|
||||
@ -143,18 +163,14 @@ class Visitor:
|
||||
def _visit_stmt_Assign(self, node):
|
||||
val = self.visit_expression(node.value)
|
||||
for target in node.targets:
|
||||
if isinstance(target, ast.Name):
|
||||
self.ns[target.id].set_value(self.builder, val)
|
||||
else:
|
||||
raise NotImplementedError
|
||||
target = self.visit_expression(target)
|
||||
target.set_value(self.builder, val)
|
||||
|
||||
def _visit_stmt_AugAssign(self, node):
|
||||
val = self.visit_expression(ast.BinOp(op=node.op, left=node.target,
|
||||
right=node.value))
|
||||
if isinstance(node.target, ast.Name):
|
||||
self.ns[node.target.id].set_value(self.builder, val)
|
||||
else:
|
||||
raise NotImplementedError
|
||||
target = self.visit_expression(node.target)
|
||||
right = self.visit_expression(node.value)
|
||||
val = _ast_binops[type(node.op)](target, right, self.builder)
|
||||
target.set_value(self.builder, val)
|
||||
|
||||
def _visit_stmt_Expr(self, node):
|
||||
self.visit_expression(node.value)
|
||||
@ -166,7 +182,7 @@ class Visitor:
|
||||
merge_block = function.append_basic_block("i_merge")
|
||||
|
||||
condition = self.visit_expression(node.test).o_bool(self.builder)
|
||||
self.builder.cbranch(condition.get_ssa_value(self.builder),
|
||||
self.builder.cbranch(condition.auto_load(self.builder),
|
||||
then_block, else_block)
|
||||
|
||||
self.builder.position_at_end(then_block)
|
||||
@ -189,14 +205,14 @@ class Visitor:
|
||||
|
||||
condition = self.visit_expression(node.test).o_bool(self.builder)
|
||||
self.builder.cbranch(
|
||||
condition.get_ssa_value(self.builder), body_block, else_block)
|
||||
condition.auto_load(self.builder), body_block, else_block)
|
||||
|
||||
self.builder.position_at_end(body_block)
|
||||
self.visit_statements(node.body)
|
||||
if not is_terminated(self.builder.basic_block):
|
||||
condition = self.visit_expression(node.test).o_bool(self.builder)
|
||||
self.builder.cbranch(
|
||||
condition.get_ssa_value(self.builder), body_block, merge_block)
|
||||
condition.auto_load(self.builder), body_block, merge_block)
|
||||
|
||||
self.builder.position_at_end(else_block)
|
||||
self.visit_statements(node.orelse)
|
||||
@ -213,4 +229,4 @@ class Visitor:
|
||||
if isinstance(val, base_types.VNone):
|
||||
self.builder.ret_void()
|
||||
else:
|
||||
self.builder.ret(val.get_ssa_value(self.builder))
|
||||
self.builder.ret(val.auto_load(self.builder))
|
||||
|
@ -46,19 +46,19 @@ class VInt(VGeneric):
|
||||
.format(repr(self), repr(other)))
|
||||
|
||||
def set_value(self, builder, n):
|
||||
self.set_ssa_value(
|
||||
builder, n.o_intx(self.nbits, builder).get_ssa_value(builder))
|
||||
self.auto_store(
|
||||
builder, n.o_intx(self.nbits, builder).auto_load(builder))
|
||||
|
||||
def set_const_value(self, builder, n):
|
||||
self.set_ssa_value(builder, lc.Constant.int(self.get_llvm_type(), n))
|
||||
self.auto_store(builder, lc.Constant.int(self.get_llvm_type(), n))
|
||||
|
||||
def o_bool(self, builder, inv=False):
|
||||
r = VBool()
|
||||
if builder is not None:
|
||||
r.set_ssa_value(
|
||||
r.auto_store(
|
||||
builder, builder.icmp(
|
||||
lc.ICMP_EQ if inv else lc.ICMP_NE,
|
||||
self.get_ssa_value(builder),
|
||||
self.auto_load(builder),
|
||||
lc.Constant.int(self.get_llvm_type(), 0)))
|
||||
return r
|
||||
|
||||
@ -68,9 +68,9 @@ class VInt(VGeneric):
|
||||
def o_neg(self, builder):
|
||||
r = VInt(self.nbits)
|
||||
if builder is not None:
|
||||
r.set_ssa_value(
|
||||
r.auto_store(
|
||||
builder, builder.mul(
|
||||
self.get_ssa_value(builder),
|
||||
self.auto_load(builder),
|
||||
lc.Constant.int(self.get_llvm_type(), -1)))
|
||||
return r
|
||||
|
||||
@ -78,15 +78,15 @@ class VInt(VGeneric):
|
||||
r = VInt(target_bits)
|
||||
if builder is not None:
|
||||
if self.nbits == target_bits:
|
||||
r.set_ssa_value(
|
||||
builder, self.get_ssa_value(builder))
|
||||
r.auto_store(
|
||||
builder, self.auto_load(builder))
|
||||
if self.nbits > target_bits:
|
||||
r.set_ssa_value(
|
||||
builder, builder.trunc(self.get_ssa_value(builder),
|
||||
r.auto_store(
|
||||
builder, builder.trunc(self.auto_load(builder),
|
||||
r.get_llvm_type()))
|
||||
if self.nbits < target_bits:
|
||||
r.set_ssa_value(
|
||||
builder, builder.sext(self.get_ssa_value(builder),
|
||||
r.auto_store(
|
||||
builder, builder.sext(self.auto_load(builder),
|
||||
r.get_llvm_type()))
|
||||
return r
|
||||
o_roundx = o_intx
|
||||
@ -101,9 +101,9 @@ def _make_vint_binop_method(builder_name):
|
||||
left = self.o_intx(target_bits, builder)
|
||||
right = other.o_intx(target_bits, builder)
|
||||
bf = getattr(builder, builder_name)
|
||||
r.set_ssa_value(
|
||||
builder, bf(left.get_ssa_value(builder),
|
||||
right.get_ssa_value(builder)))
|
||||
r.auto_store(
|
||||
builder, bf(left.auto_load(builder),
|
||||
right.auto_load(builder)))
|
||||
return r
|
||||
else:
|
||||
return NotImplemented
|
||||
@ -128,11 +128,11 @@ def _make_vint_cmp_method(icmp_val):
|
||||
target_bits = max(self.nbits, other.nbits)
|
||||
left = self.o_intx(target_bits, builder)
|
||||
right = other.o_intx(target_bits, builder)
|
||||
r.set_ssa_value(
|
||||
r.auto_store(
|
||||
builder,
|
||||
builder.icmp(
|
||||
icmp_val, left.get_ssa_value(builder),
|
||||
right.get_ssa_value(builder)))
|
||||
icmp_val, left.auto_load(builder),
|
||||
right.auto_load(builder)))
|
||||
return r
|
||||
else:
|
||||
return NotImplemented
|
||||
@ -161,5 +161,5 @@ class VBool(VInt):
|
||||
def o_bool(self, builder):
|
||||
r = VBool()
|
||||
if builder is not None:
|
||||
r.set_ssa_value(builder, self.get_ssa_value(builder))
|
||||
r.auto_store(builder, self.auto_load(builder))
|
||||
return r
|
||||
|
@ -71,7 +71,7 @@ class VFraction(VGeneric):
|
||||
return lc.Type.vector(lc.Type.int(64), 2)
|
||||
|
||||
def _nd(self, builder):
|
||||
ssa_value = self.get_ssa_value(builder)
|
||||
ssa_value = self.auto_load(builder)
|
||||
a = builder.extract_element(
|
||||
ssa_value, lc.Constant.int(lc.Type.int(), 0))
|
||||
b = builder.extract_element(
|
||||
@ -79,16 +79,16 @@ class VFraction(VGeneric):
|
||||
return a, b
|
||||
|
||||
def set_value_nd(self, builder, a, b):
|
||||
a = a.o_int64(builder).get_ssa_value(builder)
|
||||
b = b.o_int64(builder).get_ssa_value(builder)
|
||||
a = a.o_int64(builder).auto_load(builder)
|
||||
b = b.o_int64(builder).auto_load(builder)
|
||||
a, b = _reduce(builder, a, b)
|
||||
a, b = _signnum(builder, a, b)
|
||||
self.set_ssa_value(builder, _make_ssa(builder, a, b))
|
||||
self.auto_store(builder, _make_ssa(builder, a, b))
|
||||
|
||||
def set_value(self, builder, v):
|
||||
if not isinstance(v, VFraction):
|
||||
raise TypeError
|
||||
self.set_ssa_value(builder, v.get_ssa_value(builder))
|
||||
self.auto_store(builder, v.auto_load(builder))
|
||||
|
||||
def o_getattr(self, attr, builder):
|
||||
if attr == "numerator":
|
||||
@ -100,9 +100,9 @@ class VFraction(VGeneric):
|
||||
r = VInt(64)
|
||||
if builder is not None:
|
||||
elt = builder.extract_element(
|
||||
self.get_ssa_value(builder),
|
||||
self.auto_load(builder),
|
||||
lc.Constant.int(lc.Type.int(), idx))
|
||||
r.set_ssa_value(builder, elt)
|
||||
r.auto_store(builder, elt)
|
||||
return r
|
||||
|
||||
def o_bool(self, builder):
|
||||
@ -110,8 +110,8 @@ class VFraction(VGeneric):
|
||||
if builder is not None:
|
||||
zero = lc.Constant.int(lc.Type.int(64), 0)
|
||||
a = builder.extract_element(
|
||||
self.get_ssa_value(builder), lc.Constant.int(lc.Type.int(), 0))
|
||||
r.set_ssa_value(builder, builder.icmp(lc.ICMP_NE, a, zero))
|
||||
self.auto_load(builder), lc.Constant.int(lc.Type.int(), 0))
|
||||
r.auto_store(builder, builder.icmp(lc.ICMP_NE, a, zero))
|
||||
return r
|
||||
|
||||
def o_intx(self, target_bits, builder):
|
||||
@ -120,7 +120,7 @@ class VFraction(VGeneric):
|
||||
else:
|
||||
r = VInt(64)
|
||||
a, b = self._nd(builder)
|
||||
r.set_ssa_value(builder, builder.sdiv(a, b))
|
||||
r.auto_store(builder, builder.sdiv(a, b))
|
||||
return r.o_intx(target_bits, builder)
|
||||
|
||||
def o_roundx(self, target_bits, builder):
|
||||
@ -131,7 +131,7 @@ class VFraction(VGeneric):
|
||||
a, b = self._nd(builder)
|
||||
h_b = builder.ashr(b, lc.Constant.int(lc.Type.int(), 1))
|
||||
a = builder.add(a, h_b)
|
||||
r.set_ssa_value(builder, builder.sdiv(a, b))
|
||||
r.auto_store(builder, builder.sdiv(a, b))
|
||||
return r.o_intx(target_bits, builder)
|
||||
|
||||
def _o_eq_inv(self, other, builder, ne):
|
||||
@ -144,7 +144,7 @@ class VFraction(VGeneric):
|
||||
a, b = self._nd(builder)
|
||||
ssa_r = builder.and_(
|
||||
builder.icmp(lc.ICMP_EQ, a,
|
||||
other.get_ssa_value()),
|
||||
other.auto_load()),
|
||||
builder.icmp(lc.ICMP_EQ, b,
|
||||
lc.Constant.int(lc.Type.int(64), 1)))
|
||||
else:
|
||||
@ -156,7 +156,7 @@ class VFraction(VGeneric):
|
||||
if ne:
|
||||
ssa_r = builder.xor(ssa_r,
|
||||
lc.Constant.int(lc.Type.int(1), 1))
|
||||
r.set_ssa_value(builder, ssa_r)
|
||||
r.auto_store(builder, ssa_r)
|
||||
return r
|
||||
|
||||
def o_eq(self, other, builder):
|
||||
@ -171,7 +171,7 @@ class VFraction(VGeneric):
|
||||
r = VFraction()
|
||||
if builder is not None:
|
||||
if isinstance(other, VInt):
|
||||
i = other.o_int64(builder).get_ssa_value()
|
||||
i = other.o_int64(builder).auto_load()
|
||||
x, rd = self._nd(builder)
|
||||
y = builder.mul(rd, i)
|
||||
else:
|
||||
@ -188,7 +188,7 @@ class VFraction(VGeneric):
|
||||
else:
|
||||
rn = builder.add(x, y)
|
||||
rn, rd = _reduce(builder, rn, rd) # rd is already > 0
|
||||
r.set_ssa_value(builder, _make_ssa(builder, rn, rd))
|
||||
r.auto_store(builder, _make_ssa(builder, rn, rd))
|
||||
return r
|
||||
|
||||
def o_add(self, other, builder):
|
||||
@ -212,7 +212,7 @@ class VFraction(VGeneric):
|
||||
if invert:
|
||||
a, b = b, a
|
||||
if isinstance(other, VInt):
|
||||
i = other.o_int64(builder).get_ssa_value(builder)
|
||||
i = other.o_int64(builder).auto_load(builder)
|
||||
if div:
|
||||
b = builder.mul(b, i)
|
||||
else:
|
||||
@ -228,7 +228,7 @@ class VFraction(VGeneric):
|
||||
if div or invert:
|
||||
a, b = _signnum(builder, a, b)
|
||||
a, b = _reduce(builder, a, b)
|
||||
r.set_ssa_value(builder, _make_ssa(builder, a, b))
|
||||
r.auto_store(builder, _make_ssa(builder, a, b))
|
||||
return r
|
||||
|
||||
def o_mul(self, other, builder):
|
||||
|
@ -16,6 +16,19 @@ class _TypeScanner(ast.NodeVisitor):
|
||||
ns[target.id].merge(val)
|
||||
else:
|
||||
ns[target.id] = deepcopy(val)
|
||||
elif isinstance(target, ast.Subscript):
|
||||
target = target.value
|
||||
levels = 0
|
||||
while isinstance(target, ast.Subscript):
|
||||
target = target.value
|
||||
levels += 1
|
||||
if isinstance(target, ast.Name):
|
||||
target_value = ns[target.id]
|
||||
for i in range(levels):
|
||||
target_value = target_value.o_subscript(None, None)
|
||||
target_value.merge_subscript(val)
|
||||
else:
|
||||
raise NotImplementedError
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
@ -40,6 +53,7 @@ class _TypeScanner(ast.NodeVisitor):
|
||||
else:
|
||||
ns["return"] = deepcopy(val)
|
||||
|
||||
|
||||
def infer_function_types(env, node, param_types):
|
||||
ns = deepcopy(param_types)
|
||||
ts = _TypeScanner(env, ns)
|
||||
|
@ -46,7 +46,7 @@ class Module:
|
||||
for k, v in ns.items():
|
||||
v.alloca(builder, k)
|
||||
for arg_ast, arg_llvm in zip(funcdef.args.args, function.args):
|
||||
ns[arg_ast.arg].set_ssa_value(builder, arg_llvm)
|
||||
ns[arg_ast.arg].auto_store(builder, arg_llvm)
|
||||
|
||||
visitor = ast_body.Visitor(self.env, ns, builder)
|
||||
visitor.visit_statements(funcdef.body)
|
||||
@ -55,6 +55,6 @@ class Module:
|
||||
if isinstance(retval, base_types.VNone):
|
||||
builder.ret_void()
|
||||
else:
|
||||
builder.ret(retval.get_ssa_value(builder))
|
||||
builder.ret(retval.auto_load(builder))
|
||||
|
||||
return function, retval
|
||||
|
@ -1,11 +1,17 @@
|
||||
from types import SimpleNamespace
|
||||
from copy import copy
|
||||
|
||||
from llvm import core as lc
|
||||
|
||||
|
||||
class VGeneric:
|
||||
def __init__(self):
|
||||
self._llvm_value = None
|
||||
self.llvm_value = None
|
||||
|
||||
def new(self):
|
||||
r = copy(self)
|
||||
r.llvm_value = None
|
||||
return r
|
||||
|
||||
def __repr__(self):
|
||||
return "<" + self.__class__.__name__ + ">"
|
||||
@ -18,25 +24,25 @@ class VGeneric:
|
||||
raise TypeError("Incompatible types: {} and {}"
|
||||
.format(repr(self), repr(other)))
|
||||
|
||||
def get_ssa_value(self, builder):
|
||||
if isinstance(self._llvm_value, lc.AllocaInstruction):
|
||||
return builder.load(self._llvm_value)
|
||||
def auto_load(self, builder):
|
||||
if isinstance(self.llvm_value.type, lc.PointerType):
|
||||
return builder.load(self.llvm_value)
|
||||
else:
|
||||
return self._llvm_value
|
||||
return self.llvm_value
|
||||
|
||||
def set_ssa_value(self, builder, value):
|
||||
if self._llvm_value is None:
|
||||
self._llvm_value = value
|
||||
elif isinstance(self._llvm_value, lc.AllocaInstruction):
|
||||
builder.store(value, self._llvm_value)
|
||||
def auto_store(self, builder, llvm_value):
|
||||
if self.llvm_value is None:
|
||||
self.llvm_value = llvm_value
|
||||
elif isinstance(self.llvm_value.type, lc.PointerType):
|
||||
builder.store(llvm_value, self.llvm_value)
|
||||
else:
|
||||
raise RuntimeError(
|
||||
"Attempted to set LLVM SSA value multiple times")
|
||||
|
||||
def alloca(self, builder, name):
|
||||
if self._llvm_value is not None:
|
||||
if self.llvm_value is not None:
|
||||
raise RuntimeError("Attempted to alloca existing LLVM value "+name)
|
||||
self._llvm_value = builder.alloca(self.get_llvm_type(), name=name)
|
||||
self.llvm_value = builder.alloca(self.get_llvm_type(), name=name)
|
||||
|
||||
def o_int(self, builder):
|
||||
return self.o_intx(32, builder)
|
||||
|
@ -5,13 +5,13 @@ from fractions import Fraction
|
||||
|
||||
from llvm import ee as le
|
||||
|
||||
from artiq.language.core import int64
|
||||
from artiq.language.core import int64, array
|
||||
from artiq.py2llvm.infer_types import infer_function_types
|
||||
from artiq.py2llvm import base_types
|
||||
from artiq.py2llvm import base_types, arrays
|
||||
from artiq.py2llvm.module import Module
|
||||
|
||||
|
||||
def test_types(choice):
|
||||
def test_base_types(choice):
|
||||
a = 2 # promoted later to int64
|
||||
b = a + 1 # initially int32, becomes int64 after a is promoted
|
||||
c = b//2 # initially int32, becomes int64 after b is promoted
|
||||
@ -27,13 +27,17 @@ def test_types(choice):
|
||||
return x + c
|
||||
|
||||
|
||||
class FunctionTypesCase(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.ns = infer_function_types(
|
||||
None, ast.parse(inspect.getsource(test_types)),
|
||||
dict())
|
||||
def _build_function_types(f):
|
||||
return infer_function_types(
|
||||
None, ast.parse(inspect.getsource(f)),
|
||||
dict())
|
||||
|
||||
def test_base_types(self):
|
||||
|
||||
class FunctionBaseTypesCase(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.ns = _build_function_types(test_base_types)
|
||||
|
||||
def test_simple_types(self):
|
||||
self.assertIsInstance(self.ns["foo"], base_types.VBool)
|
||||
self.assertIsInstance(self.ns["bar"], base_types.VNone)
|
||||
self.assertIsInstance(self.ns["d"], base_types.VInt)
|
||||
@ -51,6 +55,23 @@ class FunctionTypesCase(unittest.TestCase):
|
||||
self.assertEqual(self.ns["return"].nbits, 64)
|
||||
|
||||
|
||||
def test_array_types():
|
||||
a = array(0, 5)
|
||||
a[3] = int64(8)
|
||||
return a
|
||||
|
||||
|
||||
class FunctionArrayTypesCase(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.ns = _build_function_types(test_array_types)
|
||||
|
||||
def test_array_types(self):
|
||||
self.assertIsInstance(self.ns["a"], arrays.VArray)
|
||||
self.assertIsInstance(self.ns["a"].el_init, base_types.VInt)
|
||||
self.assertEqual(self.ns["a"].el_init.nbits, 64)
|
||||
self.assertEqual(self.ns["a"].count, 5)
|
||||
|
||||
|
||||
class CompiledFunction:
|
||||
def __init__(self, function, param_types):
|
||||
module = Module()
|
||||
@ -99,6 +120,23 @@ def arith_encode(op, a, b, c, d):
|
||||
return f.numerator*1000 + f.denominator
|
||||
|
||||
|
||||
def array_test():
|
||||
a = array(array(2, 5), 5)
|
||||
a[3][2] = 11
|
||||
a[4][1] = 42
|
||||
a[0][0] += 6
|
||||
|
||||
acc = 0
|
||||
i = 0
|
||||
while i < 5:
|
||||
j = 0
|
||||
while j < 5:
|
||||
acc += a[i][j]
|
||||
j += 1
|
||||
i += 1
|
||||
return acc
|
||||
|
||||
|
||||
class CodeGenCase(unittest.TestCase):
|
||||
def test_is_prime(self):
|
||||
is_prime_c = CompiledFunction(is_prime, {"x": base_types.VInt()})
|
||||
@ -138,3 +176,7 @@ class CodeGenCase(unittest.TestCase):
|
||||
|
||||
def test_frac_div(self):
|
||||
self._test_frac_arith(3)
|
||||
|
||||
def test_array(self):
|
||||
array_test_c = CompiledFunction(array_test, dict())
|
||||
self.assertEqual(array_test_c(), array_test())
|
||||
|
Loading…
Reference in New Issue
Block a user