forked from M-Labs/artiq
py2llvm: array support
This commit is contained in:
parent
e2ca571c89
commit
eec52a2e29
|
@ -0,0 +1,70 @@
|
||||||
|
from llvm import core as lc
|
||||||
|
|
||||||
|
from artiq.py2llvm.values import VGeneric
|
||||||
|
from artiq.py2llvm.base_types import VInt
|
||||||
|
|
||||||
|
|
||||||
|
class VArray(VGeneric):
|
||||||
|
def __init__(self, el_init, count):
|
||||||
|
VGeneric.__init__(self)
|
||||||
|
self.el_init = el_init
|
||||||
|
self.count = count
|
||||||
|
if not count:
|
||||||
|
raise TypeError("Arrays must have at least one element")
|
||||||
|
|
||||||
|
def get_llvm_type(self):
|
||||||
|
return lc.Type.array(self.el_init.get_llvm_type(), self.count)
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
return "<VArray:{} x{}>".format(repr(self.el_init), self.count)
|
||||||
|
|
||||||
|
def same_type(self, other):
|
||||||
|
return (
|
||||||
|
isinstance(other, VArray)
|
||||||
|
and self.el_init.same_type(other.el_init)
|
||||||
|
and self.count == other.count)
|
||||||
|
|
||||||
|
def merge(self, other):
|
||||||
|
if isinstance(other, VArray):
|
||||||
|
self.el_init.merge(other.el_init)
|
||||||
|
else:
|
||||||
|
raise TypeError("Incompatible types: {} and {}"
|
||||||
|
.format(repr(self), repr(other)))
|
||||||
|
|
||||||
|
def merge_subscript(self, other):
|
||||||
|
self.el_init.merge(other)
|
||||||
|
|
||||||
|
def set_value(self, builder, v):
|
||||||
|
if not isinstance(v, VArray):
|
||||||
|
raise TypeError
|
||||||
|
if v.llvm_value is not None:
|
||||||
|
raise NotImplementedError("Array aliasing is not supported")
|
||||||
|
|
||||||
|
i = VInt()
|
||||||
|
i.alloca(builder, "ai_i")
|
||||||
|
i.auto_store(builder, lc.Constant.int(lc.Type.int(), 0))
|
||||||
|
|
||||||
|
function = builder.basic_block.function
|
||||||
|
copy_block = function.append_basic_block("ai_copy")
|
||||||
|
end_block = function.append_basic_block("ai_end")
|
||||||
|
builder.branch(copy_block)
|
||||||
|
|
||||||
|
builder.position_at_end(copy_block)
|
||||||
|
self.o_subscript(i, builder).set_value(builder, v.el_init)
|
||||||
|
i.auto_store(builder, builder.add(
|
||||||
|
i.auto_load(builder), lc.Constant.int(lc.Type.int(), 1)))
|
||||||
|
cont = builder.icmp(
|
||||||
|
lc.ICMP_SLT, i.auto_load(builder),
|
||||||
|
lc.Constant.int(lc.Type.int(), self.count))
|
||||||
|
builder.cbranch(cont, copy_block, end_block)
|
||||||
|
|
||||||
|
builder.position_at_end(end_block)
|
||||||
|
|
||||||
|
def o_subscript(self, index, builder):
|
||||||
|
r = self.el_init.new()
|
||||||
|
if builder is not None:
|
||||||
|
index = index.o_int(builder).auto_load(builder)
|
||||||
|
ssa_r = builder.gep(self.llvm_value, [
|
||||||
|
lc.Constant.int(lc.Type.int(), 0), index])
|
||||||
|
r.auto_store(builder, ssa_r)
|
||||||
|
return r
|
|
@ -1,9 +1,41 @@
|
||||||
import ast
|
import ast
|
||||||
|
|
||||||
from artiq.py2llvm import values, base_types, fractions
|
from artiq.py2llvm import values, base_types, fractions, arrays
|
||||||
from artiq.py2llvm.tools import is_terminated
|
from artiq.py2llvm.tools import is_terminated
|
||||||
|
|
||||||
|
|
||||||
|
_ast_unops = {
|
||||||
|
ast.Invert: "o_inv",
|
||||||
|
ast.Not: "o_not",
|
||||||
|
ast.UAdd: "o_pos",
|
||||||
|
ast.USub: "o_neg"
|
||||||
|
}
|
||||||
|
|
||||||
|
_ast_binops = {
|
||||||
|
ast.Add: values.operators.add,
|
||||||
|
ast.Sub: values.operators.sub,
|
||||||
|
ast.Mult: values.operators.mul,
|
||||||
|
ast.Div: values.operators.truediv,
|
||||||
|
ast.FloorDiv: values.operators.floordiv,
|
||||||
|
ast.Mod: values.operators.mod,
|
||||||
|
ast.Pow: values.operators.pow,
|
||||||
|
ast.LShift: values.operators.lshift,
|
||||||
|
ast.RShift: values.operators.rshift,
|
||||||
|
ast.BitOr: values.operators.or_,
|
||||||
|
ast.BitXor: values.operators.xor,
|
||||||
|
ast.BitAnd: values.operators.and_
|
||||||
|
}
|
||||||
|
|
||||||
|
_ast_cmps = {
|
||||||
|
ast.Eq: values.operators.eq,
|
||||||
|
ast.NotEq: values.operators.ne,
|
||||||
|
ast.Lt: values.operators.lt,
|
||||||
|
ast.LtE: values.operators.le,
|
||||||
|
ast.Gt: values.operators.gt,
|
||||||
|
ast.GtE: values.operators.ge
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
class Visitor:
|
class Visitor:
|
||||||
def __init__(self, env, ns, builder=None):
|
def __init__(self, env, ns, builder=None):
|
||||||
self.env = env
|
self.env = env
|
||||||
|
@ -53,48 +85,20 @@ class Visitor:
|
||||||
return r
|
return r
|
||||||
|
|
||||||
def _visit_expr_UnaryOp(self, node):
|
def _visit_expr_UnaryOp(self, node):
|
||||||
ast_unops = {
|
|
||||||
ast.Invert: "o_inv",
|
|
||||||
ast.Not: "o_not",
|
|
||||||
ast.UAdd: "o_pos",
|
|
||||||
ast.USub: "o_neg"
|
|
||||||
}
|
|
||||||
value = self.visit_expression(node.operand)
|
value = self.visit_expression(node.operand)
|
||||||
return getattr(value, ast_unops[type(node.op)])(self.builder)
|
return getattr(value, _ast_unops[type(node.op)])(self.builder)
|
||||||
|
|
||||||
def _visit_expr_BinOp(self, node):
|
def _visit_expr_BinOp(self, node):
|
||||||
ast_binops = {
|
return _ast_binops[type(node.op)](self.visit_expression(node.left),
|
||||||
ast.Add: values.operators.add,
|
|
||||||
ast.Sub: values.operators.sub,
|
|
||||||
ast.Mult: values.operators.mul,
|
|
||||||
ast.Div: values.operators.truediv,
|
|
||||||
ast.FloorDiv: values.operators.floordiv,
|
|
||||||
ast.Mod: values.operators.mod,
|
|
||||||
ast.Pow: values.operators.pow,
|
|
||||||
ast.LShift: values.operators.lshift,
|
|
||||||
ast.RShift: values.operators.rshift,
|
|
||||||
ast.BitOr: values.operators.or_,
|
|
||||||
ast.BitXor: values.operators.xor,
|
|
||||||
ast.BitAnd: values.operators.and_
|
|
||||||
}
|
|
||||||
return ast_binops[type(node.op)](self.visit_expression(node.left),
|
|
||||||
self.visit_expression(node.right),
|
self.visit_expression(node.right),
|
||||||
self.builder)
|
self.builder)
|
||||||
|
|
||||||
def _visit_expr_Compare(self, node):
|
def _visit_expr_Compare(self, node):
|
||||||
ast_cmps = {
|
|
||||||
ast.Eq: values.operators.eq,
|
|
||||||
ast.NotEq: values.operators.ne,
|
|
||||||
ast.Lt: values.operators.lt,
|
|
||||||
ast.LtE: values.operators.le,
|
|
||||||
ast.Gt: values.operators.gt,
|
|
||||||
ast.GtE: values.operators.ge
|
|
||||||
}
|
|
||||||
comparisons = []
|
comparisons = []
|
||||||
old_comparator = self.visit_expression(node.left)
|
old_comparator = self.visit_expression(node.left)
|
||||||
for op, comparator_a in zip(node.ops, node.comparators):
|
for op, comparator_a in zip(node.ops, node.comparators):
|
||||||
comparator = self.visit_expression(comparator_a)
|
comparator = self.visit_expression(comparator_a)
|
||||||
comparison = ast_cmps[type(op)](old_comparator, comparator,
|
comparison = _ast_cmps[type(op)](old_comparator, comparator,
|
||||||
self.builder)
|
self.builder)
|
||||||
comparisons.append(comparison)
|
comparisons.append(comparison)
|
||||||
old_comparator = comparator
|
old_comparator = comparator
|
||||||
|
@ -115,6 +119,14 @@ class Visitor:
|
||||||
denominator = self.visit_expression(node.args[1])
|
denominator = self.visit_expression(node.args[1])
|
||||||
r.set_value_nd(self.builder, numerator, denominator)
|
r.set_value_nd(self.builder, numerator, denominator)
|
||||||
return r
|
return r
|
||||||
|
elif fn == "array":
|
||||||
|
element = self.visit_expression(node.args[0])
|
||||||
|
if (isinstance(node.args[1], ast.Num)
|
||||||
|
and isinstance(node.args[1].n, int)):
|
||||||
|
count = node.args[1].n
|
||||||
|
else:
|
||||||
|
raise ValueError("Array size must be integer and constant")
|
||||||
|
return arrays.VArray(element, count)
|
||||||
elif fn == "syscall":
|
elif fn == "syscall":
|
||||||
return self.env.syscall(
|
return self.env.syscall(
|
||||||
node.args[0].s,
|
node.args[0].s,
|
||||||
|
@ -127,6 +139,14 @@ class Visitor:
|
||||||
value = self.visit_expression(node.value)
|
value = self.visit_expression(node.value)
|
||||||
return value.o_getattr(node.attr, self.builder)
|
return value.o_getattr(node.attr, self.builder)
|
||||||
|
|
||||||
|
def _visit_expr_Subscript(self, node):
|
||||||
|
value = self.visit_expression(node.value)
|
||||||
|
if isinstance(node.slice, ast.Index):
|
||||||
|
index = self.visit_expression(node.slice.value)
|
||||||
|
else:
|
||||||
|
raise NotImplementedError
|
||||||
|
return value.o_subscript(index, self.builder)
|
||||||
|
|
||||||
def visit_statements(self, stmts):
|
def visit_statements(self, stmts):
|
||||||
for node in stmts:
|
for node in stmts:
|
||||||
node_type = node.__class__.__name__
|
node_type = node.__class__.__name__
|
||||||
|
@ -143,18 +163,14 @@ class Visitor:
|
||||||
def _visit_stmt_Assign(self, node):
|
def _visit_stmt_Assign(self, node):
|
||||||
val = self.visit_expression(node.value)
|
val = self.visit_expression(node.value)
|
||||||
for target in node.targets:
|
for target in node.targets:
|
||||||
if isinstance(target, ast.Name):
|
target = self.visit_expression(target)
|
||||||
self.ns[target.id].set_value(self.builder, val)
|
target.set_value(self.builder, val)
|
||||||
else:
|
|
||||||
raise NotImplementedError
|
|
||||||
|
|
||||||
def _visit_stmt_AugAssign(self, node):
|
def _visit_stmt_AugAssign(self, node):
|
||||||
val = self.visit_expression(ast.BinOp(op=node.op, left=node.target,
|
target = self.visit_expression(node.target)
|
||||||
right=node.value))
|
right = self.visit_expression(node.value)
|
||||||
if isinstance(node.target, ast.Name):
|
val = _ast_binops[type(node.op)](target, right, self.builder)
|
||||||
self.ns[node.target.id].set_value(self.builder, val)
|
target.set_value(self.builder, val)
|
||||||
else:
|
|
||||||
raise NotImplementedError
|
|
||||||
|
|
||||||
def _visit_stmt_Expr(self, node):
|
def _visit_stmt_Expr(self, node):
|
||||||
self.visit_expression(node.value)
|
self.visit_expression(node.value)
|
||||||
|
@ -166,7 +182,7 @@ class Visitor:
|
||||||
merge_block = function.append_basic_block("i_merge")
|
merge_block = function.append_basic_block("i_merge")
|
||||||
|
|
||||||
condition = self.visit_expression(node.test).o_bool(self.builder)
|
condition = self.visit_expression(node.test).o_bool(self.builder)
|
||||||
self.builder.cbranch(condition.get_ssa_value(self.builder),
|
self.builder.cbranch(condition.auto_load(self.builder),
|
||||||
then_block, else_block)
|
then_block, else_block)
|
||||||
|
|
||||||
self.builder.position_at_end(then_block)
|
self.builder.position_at_end(then_block)
|
||||||
|
@ -189,14 +205,14 @@ class Visitor:
|
||||||
|
|
||||||
condition = self.visit_expression(node.test).o_bool(self.builder)
|
condition = self.visit_expression(node.test).o_bool(self.builder)
|
||||||
self.builder.cbranch(
|
self.builder.cbranch(
|
||||||
condition.get_ssa_value(self.builder), body_block, else_block)
|
condition.auto_load(self.builder), body_block, else_block)
|
||||||
|
|
||||||
self.builder.position_at_end(body_block)
|
self.builder.position_at_end(body_block)
|
||||||
self.visit_statements(node.body)
|
self.visit_statements(node.body)
|
||||||
if not is_terminated(self.builder.basic_block):
|
if not is_terminated(self.builder.basic_block):
|
||||||
condition = self.visit_expression(node.test).o_bool(self.builder)
|
condition = self.visit_expression(node.test).o_bool(self.builder)
|
||||||
self.builder.cbranch(
|
self.builder.cbranch(
|
||||||
condition.get_ssa_value(self.builder), body_block, merge_block)
|
condition.auto_load(self.builder), body_block, merge_block)
|
||||||
|
|
||||||
self.builder.position_at_end(else_block)
|
self.builder.position_at_end(else_block)
|
||||||
self.visit_statements(node.orelse)
|
self.visit_statements(node.orelse)
|
||||||
|
@ -213,4 +229,4 @@ class Visitor:
|
||||||
if isinstance(val, base_types.VNone):
|
if isinstance(val, base_types.VNone):
|
||||||
self.builder.ret_void()
|
self.builder.ret_void()
|
||||||
else:
|
else:
|
||||||
self.builder.ret(val.get_ssa_value(self.builder))
|
self.builder.ret(val.auto_load(self.builder))
|
||||||
|
|
|
@ -46,19 +46,19 @@ class VInt(VGeneric):
|
||||||
.format(repr(self), repr(other)))
|
.format(repr(self), repr(other)))
|
||||||
|
|
||||||
def set_value(self, builder, n):
|
def set_value(self, builder, n):
|
||||||
self.set_ssa_value(
|
self.auto_store(
|
||||||
builder, n.o_intx(self.nbits, builder).get_ssa_value(builder))
|
builder, n.o_intx(self.nbits, builder).auto_load(builder))
|
||||||
|
|
||||||
def set_const_value(self, builder, n):
|
def set_const_value(self, builder, n):
|
||||||
self.set_ssa_value(builder, lc.Constant.int(self.get_llvm_type(), n))
|
self.auto_store(builder, lc.Constant.int(self.get_llvm_type(), n))
|
||||||
|
|
||||||
def o_bool(self, builder, inv=False):
|
def o_bool(self, builder, inv=False):
|
||||||
r = VBool()
|
r = VBool()
|
||||||
if builder is not None:
|
if builder is not None:
|
||||||
r.set_ssa_value(
|
r.auto_store(
|
||||||
builder, builder.icmp(
|
builder, builder.icmp(
|
||||||
lc.ICMP_EQ if inv else lc.ICMP_NE,
|
lc.ICMP_EQ if inv else lc.ICMP_NE,
|
||||||
self.get_ssa_value(builder),
|
self.auto_load(builder),
|
||||||
lc.Constant.int(self.get_llvm_type(), 0)))
|
lc.Constant.int(self.get_llvm_type(), 0)))
|
||||||
return r
|
return r
|
||||||
|
|
||||||
|
@ -68,9 +68,9 @@ class VInt(VGeneric):
|
||||||
def o_neg(self, builder):
|
def o_neg(self, builder):
|
||||||
r = VInt(self.nbits)
|
r = VInt(self.nbits)
|
||||||
if builder is not None:
|
if builder is not None:
|
||||||
r.set_ssa_value(
|
r.auto_store(
|
||||||
builder, builder.mul(
|
builder, builder.mul(
|
||||||
self.get_ssa_value(builder),
|
self.auto_load(builder),
|
||||||
lc.Constant.int(self.get_llvm_type(), -1)))
|
lc.Constant.int(self.get_llvm_type(), -1)))
|
||||||
return r
|
return r
|
||||||
|
|
||||||
|
@ -78,15 +78,15 @@ class VInt(VGeneric):
|
||||||
r = VInt(target_bits)
|
r = VInt(target_bits)
|
||||||
if builder is not None:
|
if builder is not None:
|
||||||
if self.nbits == target_bits:
|
if self.nbits == target_bits:
|
||||||
r.set_ssa_value(
|
r.auto_store(
|
||||||
builder, self.get_ssa_value(builder))
|
builder, self.auto_load(builder))
|
||||||
if self.nbits > target_bits:
|
if self.nbits > target_bits:
|
||||||
r.set_ssa_value(
|
r.auto_store(
|
||||||
builder, builder.trunc(self.get_ssa_value(builder),
|
builder, builder.trunc(self.auto_load(builder),
|
||||||
r.get_llvm_type()))
|
r.get_llvm_type()))
|
||||||
if self.nbits < target_bits:
|
if self.nbits < target_bits:
|
||||||
r.set_ssa_value(
|
r.auto_store(
|
||||||
builder, builder.sext(self.get_ssa_value(builder),
|
builder, builder.sext(self.auto_load(builder),
|
||||||
r.get_llvm_type()))
|
r.get_llvm_type()))
|
||||||
return r
|
return r
|
||||||
o_roundx = o_intx
|
o_roundx = o_intx
|
||||||
|
@ -101,9 +101,9 @@ def _make_vint_binop_method(builder_name):
|
||||||
left = self.o_intx(target_bits, builder)
|
left = self.o_intx(target_bits, builder)
|
||||||
right = other.o_intx(target_bits, builder)
|
right = other.o_intx(target_bits, builder)
|
||||||
bf = getattr(builder, builder_name)
|
bf = getattr(builder, builder_name)
|
||||||
r.set_ssa_value(
|
r.auto_store(
|
||||||
builder, bf(left.get_ssa_value(builder),
|
builder, bf(left.auto_load(builder),
|
||||||
right.get_ssa_value(builder)))
|
right.auto_load(builder)))
|
||||||
return r
|
return r
|
||||||
else:
|
else:
|
||||||
return NotImplemented
|
return NotImplemented
|
||||||
|
@ -128,11 +128,11 @@ def _make_vint_cmp_method(icmp_val):
|
||||||
target_bits = max(self.nbits, other.nbits)
|
target_bits = max(self.nbits, other.nbits)
|
||||||
left = self.o_intx(target_bits, builder)
|
left = self.o_intx(target_bits, builder)
|
||||||
right = other.o_intx(target_bits, builder)
|
right = other.o_intx(target_bits, builder)
|
||||||
r.set_ssa_value(
|
r.auto_store(
|
||||||
builder,
|
builder,
|
||||||
builder.icmp(
|
builder.icmp(
|
||||||
icmp_val, left.get_ssa_value(builder),
|
icmp_val, left.auto_load(builder),
|
||||||
right.get_ssa_value(builder)))
|
right.auto_load(builder)))
|
||||||
return r
|
return r
|
||||||
else:
|
else:
|
||||||
return NotImplemented
|
return NotImplemented
|
||||||
|
@ -161,5 +161,5 @@ class VBool(VInt):
|
||||||
def o_bool(self, builder):
|
def o_bool(self, builder):
|
||||||
r = VBool()
|
r = VBool()
|
||||||
if builder is not None:
|
if builder is not None:
|
||||||
r.set_ssa_value(builder, self.get_ssa_value(builder))
|
r.auto_store(builder, self.auto_load(builder))
|
||||||
return r
|
return r
|
||||||
|
|
|
@ -71,7 +71,7 @@ class VFraction(VGeneric):
|
||||||
return lc.Type.vector(lc.Type.int(64), 2)
|
return lc.Type.vector(lc.Type.int(64), 2)
|
||||||
|
|
||||||
def _nd(self, builder):
|
def _nd(self, builder):
|
||||||
ssa_value = self.get_ssa_value(builder)
|
ssa_value = self.auto_load(builder)
|
||||||
a = builder.extract_element(
|
a = builder.extract_element(
|
||||||
ssa_value, lc.Constant.int(lc.Type.int(), 0))
|
ssa_value, lc.Constant.int(lc.Type.int(), 0))
|
||||||
b = builder.extract_element(
|
b = builder.extract_element(
|
||||||
|
@ -79,16 +79,16 @@ class VFraction(VGeneric):
|
||||||
return a, b
|
return a, b
|
||||||
|
|
||||||
def set_value_nd(self, builder, a, b):
|
def set_value_nd(self, builder, a, b):
|
||||||
a = a.o_int64(builder).get_ssa_value(builder)
|
a = a.o_int64(builder).auto_load(builder)
|
||||||
b = b.o_int64(builder).get_ssa_value(builder)
|
b = b.o_int64(builder).auto_load(builder)
|
||||||
a, b = _reduce(builder, a, b)
|
a, b = _reduce(builder, a, b)
|
||||||
a, b = _signnum(builder, a, b)
|
a, b = _signnum(builder, a, b)
|
||||||
self.set_ssa_value(builder, _make_ssa(builder, a, b))
|
self.auto_store(builder, _make_ssa(builder, a, b))
|
||||||
|
|
||||||
def set_value(self, builder, v):
|
def set_value(self, builder, v):
|
||||||
if not isinstance(v, VFraction):
|
if not isinstance(v, VFraction):
|
||||||
raise TypeError
|
raise TypeError
|
||||||
self.set_ssa_value(builder, v.get_ssa_value(builder))
|
self.auto_store(builder, v.auto_load(builder))
|
||||||
|
|
||||||
def o_getattr(self, attr, builder):
|
def o_getattr(self, attr, builder):
|
||||||
if attr == "numerator":
|
if attr == "numerator":
|
||||||
|
@ -100,9 +100,9 @@ class VFraction(VGeneric):
|
||||||
r = VInt(64)
|
r = VInt(64)
|
||||||
if builder is not None:
|
if builder is not None:
|
||||||
elt = builder.extract_element(
|
elt = builder.extract_element(
|
||||||
self.get_ssa_value(builder),
|
self.auto_load(builder),
|
||||||
lc.Constant.int(lc.Type.int(), idx))
|
lc.Constant.int(lc.Type.int(), idx))
|
||||||
r.set_ssa_value(builder, elt)
|
r.auto_store(builder, elt)
|
||||||
return r
|
return r
|
||||||
|
|
||||||
def o_bool(self, builder):
|
def o_bool(self, builder):
|
||||||
|
@ -110,8 +110,8 @@ class VFraction(VGeneric):
|
||||||
if builder is not None:
|
if builder is not None:
|
||||||
zero = lc.Constant.int(lc.Type.int(64), 0)
|
zero = lc.Constant.int(lc.Type.int(64), 0)
|
||||||
a = builder.extract_element(
|
a = builder.extract_element(
|
||||||
self.get_ssa_value(builder), lc.Constant.int(lc.Type.int(), 0))
|
self.auto_load(builder), lc.Constant.int(lc.Type.int(), 0))
|
||||||
r.set_ssa_value(builder, builder.icmp(lc.ICMP_NE, a, zero))
|
r.auto_store(builder, builder.icmp(lc.ICMP_NE, a, zero))
|
||||||
return r
|
return r
|
||||||
|
|
||||||
def o_intx(self, target_bits, builder):
|
def o_intx(self, target_bits, builder):
|
||||||
|
@ -120,7 +120,7 @@ class VFraction(VGeneric):
|
||||||
else:
|
else:
|
||||||
r = VInt(64)
|
r = VInt(64)
|
||||||
a, b = self._nd(builder)
|
a, b = self._nd(builder)
|
||||||
r.set_ssa_value(builder, builder.sdiv(a, b))
|
r.auto_store(builder, builder.sdiv(a, b))
|
||||||
return r.o_intx(target_bits, builder)
|
return r.o_intx(target_bits, builder)
|
||||||
|
|
||||||
def o_roundx(self, target_bits, builder):
|
def o_roundx(self, target_bits, builder):
|
||||||
|
@ -131,7 +131,7 @@ class VFraction(VGeneric):
|
||||||
a, b = self._nd(builder)
|
a, b = self._nd(builder)
|
||||||
h_b = builder.ashr(b, lc.Constant.int(lc.Type.int(), 1))
|
h_b = builder.ashr(b, lc.Constant.int(lc.Type.int(), 1))
|
||||||
a = builder.add(a, h_b)
|
a = builder.add(a, h_b)
|
||||||
r.set_ssa_value(builder, builder.sdiv(a, b))
|
r.auto_store(builder, builder.sdiv(a, b))
|
||||||
return r.o_intx(target_bits, builder)
|
return r.o_intx(target_bits, builder)
|
||||||
|
|
||||||
def _o_eq_inv(self, other, builder, ne):
|
def _o_eq_inv(self, other, builder, ne):
|
||||||
|
@ -144,7 +144,7 @@ class VFraction(VGeneric):
|
||||||
a, b = self._nd(builder)
|
a, b = self._nd(builder)
|
||||||
ssa_r = builder.and_(
|
ssa_r = builder.and_(
|
||||||
builder.icmp(lc.ICMP_EQ, a,
|
builder.icmp(lc.ICMP_EQ, a,
|
||||||
other.get_ssa_value()),
|
other.auto_load()),
|
||||||
builder.icmp(lc.ICMP_EQ, b,
|
builder.icmp(lc.ICMP_EQ, b,
|
||||||
lc.Constant.int(lc.Type.int(64), 1)))
|
lc.Constant.int(lc.Type.int(64), 1)))
|
||||||
else:
|
else:
|
||||||
|
@ -156,7 +156,7 @@ class VFraction(VGeneric):
|
||||||
if ne:
|
if ne:
|
||||||
ssa_r = builder.xor(ssa_r,
|
ssa_r = builder.xor(ssa_r,
|
||||||
lc.Constant.int(lc.Type.int(1), 1))
|
lc.Constant.int(lc.Type.int(1), 1))
|
||||||
r.set_ssa_value(builder, ssa_r)
|
r.auto_store(builder, ssa_r)
|
||||||
return r
|
return r
|
||||||
|
|
||||||
def o_eq(self, other, builder):
|
def o_eq(self, other, builder):
|
||||||
|
@ -171,7 +171,7 @@ class VFraction(VGeneric):
|
||||||
r = VFraction()
|
r = VFraction()
|
||||||
if builder is not None:
|
if builder is not None:
|
||||||
if isinstance(other, VInt):
|
if isinstance(other, VInt):
|
||||||
i = other.o_int64(builder).get_ssa_value()
|
i = other.o_int64(builder).auto_load()
|
||||||
x, rd = self._nd(builder)
|
x, rd = self._nd(builder)
|
||||||
y = builder.mul(rd, i)
|
y = builder.mul(rd, i)
|
||||||
else:
|
else:
|
||||||
|
@ -188,7 +188,7 @@ class VFraction(VGeneric):
|
||||||
else:
|
else:
|
||||||
rn = builder.add(x, y)
|
rn = builder.add(x, y)
|
||||||
rn, rd = _reduce(builder, rn, rd) # rd is already > 0
|
rn, rd = _reduce(builder, rn, rd) # rd is already > 0
|
||||||
r.set_ssa_value(builder, _make_ssa(builder, rn, rd))
|
r.auto_store(builder, _make_ssa(builder, rn, rd))
|
||||||
return r
|
return r
|
||||||
|
|
||||||
def o_add(self, other, builder):
|
def o_add(self, other, builder):
|
||||||
|
@ -212,7 +212,7 @@ class VFraction(VGeneric):
|
||||||
if invert:
|
if invert:
|
||||||
a, b = b, a
|
a, b = b, a
|
||||||
if isinstance(other, VInt):
|
if isinstance(other, VInt):
|
||||||
i = other.o_int64(builder).get_ssa_value(builder)
|
i = other.o_int64(builder).auto_load(builder)
|
||||||
if div:
|
if div:
|
||||||
b = builder.mul(b, i)
|
b = builder.mul(b, i)
|
||||||
else:
|
else:
|
||||||
|
@ -228,7 +228,7 @@ class VFraction(VGeneric):
|
||||||
if div or invert:
|
if div or invert:
|
||||||
a, b = _signnum(builder, a, b)
|
a, b = _signnum(builder, a, b)
|
||||||
a, b = _reduce(builder, a, b)
|
a, b = _reduce(builder, a, b)
|
||||||
r.set_ssa_value(builder, _make_ssa(builder, a, b))
|
r.auto_store(builder, _make_ssa(builder, a, b))
|
||||||
return r
|
return r
|
||||||
|
|
||||||
def o_mul(self, other, builder):
|
def o_mul(self, other, builder):
|
||||||
|
|
|
@ -16,6 +16,19 @@ class _TypeScanner(ast.NodeVisitor):
|
||||||
ns[target.id].merge(val)
|
ns[target.id].merge(val)
|
||||||
else:
|
else:
|
||||||
ns[target.id] = deepcopy(val)
|
ns[target.id] = deepcopy(val)
|
||||||
|
elif isinstance(target, ast.Subscript):
|
||||||
|
target = target.value
|
||||||
|
levels = 0
|
||||||
|
while isinstance(target, ast.Subscript):
|
||||||
|
target = target.value
|
||||||
|
levels += 1
|
||||||
|
if isinstance(target, ast.Name):
|
||||||
|
target_value = ns[target.id]
|
||||||
|
for i in range(levels):
|
||||||
|
target_value = target_value.o_subscript(None, None)
|
||||||
|
target_value.merge_subscript(val)
|
||||||
|
else:
|
||||||
|
raise NotImplementedError
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
|
@ -40,6 +53,7 @@ class _TypeScanner(ast.NodeVisitor):
|
||||||
else:
|
else:
|
||||||
ns["return"] = deepcopy(val)
|
ns["return"] = deepcopy(val)
|
||||||
|
|
||||||
|
|
||||||
def infer_function_types(env, node, param_types):
|
def infer_function_types(env, node, param_types):
|
||||||
ns = deepcopy(param_types)
|
ns = deepcopy(param_types)
|
||||||
ts = _TypeScanner(env, ns)
|
ts = _TypeScanner(env, ns)
|
||||||
|
|
|
@ -46,7 +46,7 @@ class Module:
|
||||||
for k, v in ns.items():
|
for k, v in ns.items():
|
||||||
v.alloca(builder, k)
|
v.alloca(builder, k)
|
||||||
for arg_ast, arg_llvm in zip(funcdef.args.args, function.args):
|
for arg_ast, arg_llvm in zip(funcdef.args.args, function.args):
|
||||||
ns[arg_ast.arg].set_ssa_value(builder, arg_llvm)
|
ns[arg_ast.arg].auto_store(builder, arg_llvm)
|
||||||
|
|
||||||
visitor = ast_body.Visitor(self.env, ns, builder)
|
visitor = ast_body.Visitor(self.env, ns, builder)
|
||||||
visitor.visit_statements(funcdef.body)
|
visitor.visit_statements(funcdef.body)
|
||||||
|
@ -55,6 +55,6 @@ class Module:
|
||||||
if isinstance(retval, base_types.VNone):
|
if isinstance(retval, base_types.VNone):
|
||||||
builder.ret_void()
|
builder.ret_void()
|
||||||
else:
|
else:
|
||||||
builder.ret(retval.get_ssa_value(builder))
|
builder.ret(retval.auto_load(builder))
|
||||||
|
|
||||||
return function, retval
|
return function, retval
|
||||||
|
|
|
@ -1,11 +1,17 @@
|
||||||
from types import SimpleNamespace
|
from types import SimpleNamespace
|
||||||
|
from copy import copy
|
||||||
|
|
||||||
from llvm import core as lc
|
from llvm import core as lc
|
||||||
|
|
||||||
|
|
||||||
class VGeneric:
|
class VGeneric:
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self._llvm_value = None
|
self.llvm_value = None
|
||||||
|
|
||||||
|
def new(self):
|
||||||
|
r = copy(self)
|
||||||
|
r.llvm_value = None
|
||||||
|
return r
|
||||||
|
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
return "<" + self.__class__.__name__ + ">"
|
return "<" + self.__class__.__name__ + ">"
|
||||||
|
@ -18,25 +24,25 @@ class VGeneric:
|
||||||
raise TypeError("Incompatible types: {} and {}"
|
raise TypeError("Incompatible types: {} and {}"
|
||||||
.format(repr(self), repr(other)))
|
.format(repr(self), repr(other)))
|
||||||
|
|
||||||
def get_ssa_value(self, builder):
|
def auto_load(self, builder):
|
||||||
if isinstance(self._llvm_value, lc.AllocaInstruction):
|
if isinstance(self.llvm_value.type, lc.PointerType):
|
||||||
return builder.load(self._llvm_value)
|
return builder.load(self.llvm_value)
|
||||||
else:
|
else:
|
||||||
return self._llvm_value
|
return self.llvm_value
|
||||||
|
|
||||||
def set_ssa_value(self, builder, value):
|
def auto_store(self, builder, llvm_value):
|
||||||
if self._llvm_value is None:
|
if self.llvm_value is None:
|
||||||
self._llvm_value = value
|
self.llvm_value = llvm_value
|
||||||
elif isinstance(self._llvm_value, lc.AllocaInstruction):
|
elif isinstance(self.llvm_value.type, lc.PointerType):
|
||||||
builder.store(value, self._llvm_value)
|
builder.store(llvm_value, self.llvm_value)
|
||||||
else:
|
else:
|
||||||
raise RuntimeError(
|
raise RuntimeError(
|
||||||
"Attempted to set LLVM SSA value multiple times")
|
"Attempted to set LLVM SSA value multiple times")
|
||||||
|
|
||||||
def alloca(self, builder, name):
|
def alloca(self, builder, name):
|
||||||
if self._llvm_value is not None:
|
if self.llvm_value is not None:
|
||||||
raise RuntimeError("Attempted to alloca existing LLVM value "+name)
|
raise RuntimeError("Attempted to alloca existing LLVM value "+name)
|
||||||
self._llvm_value = builder.alloca(self.get_llvm_type(), name=name)
|
self.llvm_value = builder.alloca(self.get_llvm_type(), name=name)
|
||||||
|
|
||||||
def o_int(self, builder):
|
def o_int(self, builder):
|
||||||
return self.o_intx(32, builder)
|
return self.o_intx(32, builder)
|
||||||
|
|
|
@ -5,13 +5,13 @@ from fractions import Fraction
|
||||||
|
|
||||||
from llvm import ee as le
|
from llvm import ee as le
|
||||||
|
|
||||||
from artiq.language.core import int64
|
from artiq.language.core import int64, array
|
||||||
from artiq.py2llvm.infer_types import infer_function_types
|
from artiq.py2llvm.infer_types import infer_function_types
|
||||||
from artiq.py2llvm import base_types
|
from artiq.py2llvm import base_types, arrays
|
||||||
from artiq.py2llvm.module import Module
|
from artiq.py2llvm.module import Module
|
||||||
|
|
||||||
|
|
||||||
def test_types(choice):
|
def test_base_types(choice):
|
||||||
a = 2 # promoted later to int64
|
a = 2 # promoted later to int64
|
||||||
b = a + 1 # initially int32, becomes int64 after a is promoted
|
b = a + 1 # initially int32, becomes int64 after a is promoted
|
||||||
c = b//2 # initially int32, becomes int64 after b is promoted
|
c = b//2 # initially int32, becomes int64 after b is promoted
|
||||||
|
@ -27,13 +27,17 @@ def test_types(choice):
|
||||||
return x + c
|
return x + c
|
||||||
|
|
||||||
|
|
||||||
class FunctionTypesCase(unittest.TestCase):
|
def _build_function_types(f):
|
||||||
def setUp(self):
|
return infer_function_types(
|
||||||
self.ns = infer_function_types(
|
None, ast.parse(inspect.getsource(f)),
|
||||||
None, ast.parse(inspect.getsource(test_types)),
|
|
||||||
dict())
|
dict())
|
||||||
|
|
||||||
def test_base_types(self):
|
|
||||||
|
class FunctionBaseTypesCase(unittest.TestCase):
|
||||||
|
def setUp(self):
|
||||||
|
self.ns = _build_function_types(test_base_types)
|
||||||
|
|
||||||
|
def test_simple_types(self):
|
||||||
self.assertIsInstance(self.ns["foo"], base_types.VBool)
|
self.assertIsInstance(self.ns["foo"], base_types.VBool)
|
||||||
self.assertIsInstance(self.ns["bar"], base_types.VNone)
|
self.assertIsInstance(self.ns["bar"], base_types.VNone)
|
||||||
self.assertIsInstance(self.ns["d"], base_types.VInt)
|
self.assertIsInstance(self.ns["d"], base_types.VInt)
|
||||||
|
@ -51,6 +55,23 @@ class FunctionTypesCase(unittest.TestCase):
|
||||||
self.assertEqual(self.ns["return"].nbits, 64)
|
self.assertEqual(self.ns["return"].nbits, 64)
|
||||||
|
|
||||||
|
|
||||||
|
def test_array_types():
|
||||||
|
a = array(0, 5)
|
||||||
|
a[3] = int64(8)
|
||||||
|
return a
|
||||||
|
|
||||||
|
|
||||||
|
class FunctionArrayTypesCase(unittest.TestCase):
|
||||||
|
def setUp(self):
|
||||||
|
self.ns = _build_function_types(test_array_types)
|
||||||
|
|
||||||
|
def test_array_types(self):
|
||||||
|
self.assertIsInstance(self.ns["a"], arrays.VArray)
|
||||||
|
self.assertIsInstance(self.ns["a"].el_init, base_types.VInt)
|
||||||
|
self.assertEqual(self.ns["a"].el_init.nbits, 64)
|
||||||
|
self.assertEqual(self.ns["a"].count, 5)
|
||||||
|
|
||||||
|
|
||||||
class CompiledFunction:
|
class CompiledFunction:
|
||||||
def __init__(self, function, param_types):
|
def __init__(self, function, param_types):
|
||||||
module = Module()
|
module = Module()
|
||||||
|
@ -99,6 +120,23 @@ def arith_encode(op, a, b, c, d):
|
||||||
return f.numerator*1000 + f.denominator
|
return f.numerator*1000 + f.denominator
|
||||||
|
|
||||||
|
|
||||||
|
def array_test():
|
||||||
|
a = array(array(2, 5), 5)
|
||||||
|
a[3][2] = 11
|
||||||
|
a[4][1] = 42
|
||||||
|
a[0][0] += 6
|
||||||
|
|
||||||
|
acc = 0
|
||||||
|
i = 0
|
||||||
|
while i < 5:
|
||||||
|
j = 0
|
||||||
|
while j < 5:
|
||||||
|
acc += a[i][j]
|
||||||
|
j += 1
|
||||||
|
i += 1
|
||||||
|
return acc
|
||||||
|
|
||||||
|
|
||||||
class CodeGenCase(unittest.TestCase):
|
class CodeGenCase(unittest.TestCase):
|
||||||
def test_is_prime(self):
|
def test_is_prime(self):
|
||||||
is_prime_c = CompiledFunction(is_prime, {"x": base_types.VInt()})
|
is_prime_c = CompiledFunction(is_prime, {"x": base_types.VInt()})
|
||||||
|
@ -138,3 +176,7 @@ class CodeGenCase(unittest.TestCase):
|
||||||
|
|
||||||
def test_frac_div(self):
|
def test_frac_div(self):
|
||||||
self._test_frac_arith(3)
|
self._test_frac_arith(3)
|
||||||
|
|
||||||
|
def test_array(self):
|
||||||
|
array_test_c = CompiledFunction(array_test, dict())
|
||||||
|
self.assertEqual(array_test_c(), array_test())
|
||||||
|
|
Loading…
Reference in New Issue