mirror of https://github.com/m-labs/artiq.git
remove old compiler code
This commit is contained in:
parent
2503dcd837
commit
e5b58b50aa
|
@ -1,358 +0,0 @@
|
||||||
import inspect
|
|
||||||
from pythonparser import parse, ast
|
|
||||||
|
|
||||||
import llvmlite_artiq.ir as ll
|
|
||||||
|
|
||||||
from artiq.py2llvm.values import VGeneric, operators
|
|
||||||
from artiq.py2llvm.base_types import VBool, VInt, VFloat
|
|
||||||
|
|
||||||
|
|
||||||
def _gcd(a, b):
|
|
||||||
if a < 0:
|
|
||||||
a = -a
|
|
||||||
while a:
|
|
||||||
c = a
|
|
||||||
a = b % a
|
|
||||||
b = c
|
|
||||||
return b
|
|
||||||
|
|
||||||
|
|
||||||
def init_module(module):
|
|
||||||
func_def = parse(inspect.getsource(_gcd)).body[0]
|
|
||||||
function, _ = module.compile_function(func_def,
|
|
||||||
{"a": VInt(64), "b": VInt(64)})
|
|
||||||
function.linkage = "internal"
|
|
||||||
|
|
||||||
|
|
||||||
def _reduce(builder, a, b):
|
|
||||||
module = builder.basic_block.function.module
|
|
||||||
for f in module.functions:
|
|
||||||
if f.name == "_gcd":
|
|
||||||
gcd_f = f
|
|
||||||
break
|
|
||||||
gcd = builder.call(gcd_f, [a, b])
|
|
||||||
a = builder.sdiv(a, gcd)
|
|
||||||
b = builder.sdiv(b, gcd)
|
|
||||||
return a, b
|
|
||||||
|
|
||||||
|
|
||||||
def _signnum(builder, a, b):
|
|
||||||
function = builder.basic_block.function
|
|
||||||
orig_block = builder.basic_block
|
|
||||||
swap_block = function.append_basic_block("sn_swap")
|
|
||||||
merge_block = function.append_basic_block("sn_merge")
|
|
||||||
|
|
||||||
condition = builder.icmp_signed(
|
|
||||||
"<", b, ll.Constant(ll.IntType(64), 0))
|
|
||||||
builder.cbranch(condition, swap_block, merge_block)
|
|
||||||
|
|
||||||
builder.position_at_end(swap_block)
|
|
||||||
minusone = ll.Constant(ll.IntType(64), -1)
|
|
||||||
a_swp = builder.mul(minusone, a)
|
|
||||||
b_swp = builder.mul(minusone, b)
|
|
||||||
builder.branch(merge_block)
|
|
||||||
|
|
||||||
builder.position_at_end(merge_block)
|
|
||||||
a_phi = builder.phi(ll.IntType(64))
|
|
||||||
a_phi.add_incoming(a, orig_block)
|
|
||||||
a_phi.add_incoming(a_swp, swap_block)
|
|
||||||
b_phi = builder.phi(ll.IntType(64))
|
|
||||||
b_phi.add_incoming(b, orig_block)
|
|
||||||
b_phi.add_incoming(b_swp, swap_block)
|
|
||||||
|
|
||||||
return a_phi, b_phi
|
|
||||||
|
|
||||||
|
|
||||||
def _make_ssa(builder, n, d):
|
|
||||||
value = ll.Constant(ll.ArrayType(ll.IntType(64), 2), ll.Undefined)
|
|
||||||
value = builder.insert_value(value, n, 0)
|
|
||||||
value = builder.insert_value(value, d, 1)
|
|
||||||
return value
|
|
||||||
|
|
||||||
|
|
||||||
class VFraction(VGeneric):
|
|
||||||
def get_llvm_type(self):
|
|
||||||
return ll.ArrayType(ll.IntType(64), 2)
|
|
||||||
|
|
||||||
def _nd(self, builder):
|
|
||||||
ssa_value = self.auto_load(builder)
|
|
||||||
a = builder.extract_value(ssa_value, 0)
|
|
||||||
b = builder.extract_value(ssa_value, 1)
|
|
||||||
return a, b
|
|
||||||
|
|
||||||
def set_value_nd(self, builder, a, b):
|
|
||||||
a = a.o_int64(builder).auto_load(builder)
|
|
||||||
b = b.o_int64(builder).auto_load(builder)
|
|
||||||
a, b = _reduce(builder, a, b)
|
|
||||||
a, b = _signnum(builder, a, b)
|
|
||||||
self.auto_store(builder, _make_ssa(builder, a, b))
|
|
||||||
|
|
||||||
def set_value(self, builder, v):
|
|
||||||
if not isinstance(v, VFraction):
|
|
||||||
raise TypeError
|
|
||||||
self.auto_store(builder, v.auto_load(builder))
|
|
||||||
|
|
||||||
def o_getattr(self, attr, builder):
|
|
||||||
if attr == "numerator":
|
|
||||||
idx = 0
|
|
||||||
elif attr == "denominator":
|
|
||||||
idx = 1
|
|
||||||
else:
|
|
||||||
raise AttributeError
|
|
||||||
r = VInt(64)
|
|
||||||
if builder is not None:
|
|
||||||
elt = builder.extract_value(self.auto_load(builder), idx)
|
|
||||||
r.auto_store(builder, elt)
|
|
||||||
return r
|
|
||||||
|
|
||||||
def o_bool(self, builder):
|
|
||||||
r = VBool()
|
|
||||||
if builder is not None:
|
|
||||||
zero = ll.Constant(ll.IntType(64), 0)
|
|
||||||
a = builder.extract_element(self.auto_load(builder), 0)
|
|
||||||
r.auto_store(builder, builder.icmp_signed("!=", a, zero))
|
|
||||||
return r
|
|
||||||
|
|
||||||
def o_intx(self, target_bits, builder):
|
|
||||||
if builder is None:
|
|
||||||
return VInt(target_bits)
|
|
||||||
else:
|
|
||||||
r = VInt(64)
|
|
||||||
a, b = self._nd(builder)
|
|
||||||
r.auto_store(builder, builder.sdiv(a, b))
|
|
||||||
return r.o_intx(target_bits, builder)
|
|
||||||
|
|
||||||
def o_roundx(self, target_bits, builder):
|
|
||||||
if builder is None:
|
|
||||||
return VInt(target_bits)
|
|
||||||
else:
|
|
||||||
r = VInt(64)
|
|
||||||
a, b = self._nd(builder)
|
|
||||||
h_b = builder.ashr(b, ll.Constant(ll.IntType(64), 1))
|
|
||||||
|
|
||||||
function = builder.basic_block.function
|
|
||||||
add_block = function.append_basic_block("fr_add")
|
|
||||||
sub_block = function.append_basic_block("fr_sub")
|
|
||||||
merge_block = function.append_basic_block("fr_merge")
|
|
||||||
|
|
||||||
condition = builder.icmp_signed(
|
|
||||||
"<", a, ll.Constant(ll.IntType(64), 0))
|
|
||||||
builder.cbranch(condition, sub_block, add_block)
|
|
||||||
|
|
||||||
builder.position_at_end(add_block)
|
|
||||||
a_add = builder.add(a, h_b)
|
|
||||||
builder.branch(merge_block)
|
|
||||||
builder.position_at_end(sub_block)
|
|
||||||
a_sub = builder.sub(a, h_b)
|
|
||||||
builder.branch(merge_block)
|
|
||||||
|
|
||||||
builder.position_at_end(merge_block)
|
|
||||||
a = builder.phi(ll.IntType(64))
|
|
||||||
a.add_incoming(a_add, add_block)
|
|
||||||
a.add_incoming(a_sub, sub_block)
|
|
||||||
r.auto_store(builder, builder.sdiv(a, b))
|
|
||||||
return r.o_intx(target_bits, builder)
|
|
||||||
|
|
||||||
def o_float(self, builder):
|
|
||||||
r = VFloat()
|
|
||||||
if builder is not None:
|
|
||||||
a, b = self._nd(builder)
|
|
||||||
af = builder.sitofp(a, r.get_llvm_type())
|
|
||||||
bf = builder.sitofp(b, r.get_llvm_type())
|
|
||||||
r.auto_store(builder, builder.fdiv(af, bf))
|
|
||||||
return r
|
|
||||||
|
|
||||||
def _o_eq_inv(self, other, builder, ne):
|
|
||||||
if not isinstance(other, (VInt, VFraction)):
|
|
||||||
return NotImplemented
|
|
||||||
r = VBool()
|
|
||||||
if builder is not None:
|
|
||||||
if isinstance(other, VInt):
|
|
||||||
other = other.o_int64(builder)
|
|
||||||
a, b = self._nd(builder)
|
|
||||||
ssa_r = builder.and_(
|
|
||||||
builder.icmp_signed("==", a,
|
|
||||||
other.auto_load()),
|
|
||||||
builder.icmp_signed("==", b,
|
|
||||||
ll.Constant(ll.IntType(64), 1)))
|
|
||||||
else:
|
|
||||||
a, b = self._nd(builder)
|
|
||||||
c, d = other._nd(builder)
|
|
||||||
ssa_r = builder.and_(
|
|
||||||
builder.icmp_signed("==", a, c),
|
|
||||||
builder.icmp_signed("==", b, d))
|
|
||||||
if ne:
|
|
||||||
ssa_r = builder.xor(ssa_r,
|
|
||||||
ll.Constant(ll.IntType(1), 1))
|
|
||||||
r.auto_store(builder, ssa_r)
|
|
||||||
return r
|
|
||||||
|
|
||||||
def o_eq(self, other, builder):
|
|
||||||
return self._o_eq_inv(other, builder, False)
|
|
||||||
|
|
||||||
def o_ne(self, other, builder):
|
|
||||||
return self._o_eq_inv(other, builder, True)
|
|
||||||
|
|
||||||
def _o_cmp(self, other, icmp, builder):
|
|
||||||
diff = self.o_sub(other, builder)
|
|
||||||
if diff is NotImplemented:
|
|
||||||
return NotImplemented
|
|
||||||
r = VBool()
|
|
||||||
if builder is not None:
|
|
||||||
diff = diff.auto_load(builder)
|
|
||||||
a = builder.extract_value(diff, 0)
|
|
||||||
zero = ll.Constant(ll.IntType(64), 0)
|
|
||||||
ssa_r = builder.icmp_signed(icmp, a, zero)
|
|
||||||
r.auto_store(builder, ssa_r)
|
|
||||||
return r
|
|
||||||
|
|
||||||
def o_lt(self, other, builder):
|
|
||||||
return self._o_cmp(other, "<", builder)
|
|
||||||
|
|
||||||
def o_le(self, other, builder):
|
|
||||||
return self._o_cmp(other, "<=", builder)
|
|
||||||
|
|
||||||
def o_gt(self, other, builder):
|
|
||||||
return self._o_cmp(other, ">", builder)
|
|
||||||
|
|
||||||
def o_ge(self, other, builder):
|
|
||||||
return self._o_cmp(other, ">=", builder)
|
|
||||||
|
|
||||||
def _o_addsub(self, other, builder, sub, invert=False):
|
|
||||||
if isinstance(other, VFloat):
|
|
||||||
a = self.o_getattr("numerator", builder)
|
|
||||||
b = self.o_getattr("denominator", builder)
|
|
||||||
if sub:
|
|
||||||
if invert:
|
|
||||||
return operators.truediv(
|
|
||||||
operators.sub(operators.mul(other,
|
|
||||||
b,
|
|
||||||
builder),
|
|
||||||
a,
|
|
||||||
builder),
|
|
||||||
b,
|
|
||||||
builder)
|
|
||||||
else:
|
|
||||||
return operators.truediv(
|
|
||||||
operators.sub(a,
|
|
||||||
operators.mul(other,
|
|
||||||
b,
|
|
||||||
builder),
|
|
||||||
builder),
|
|
||||||
b,
|
|
||||||
builder)
|
|
||||||
else:
|
|
||||||
return operators.truediv(
|
|
||||||
operators.add(operators.mul(other,
|
|
||||||
b,
|
|
||||||
builder),
|
|
||||||
a,
|
|
||||||
builder),
|
|
||||||
b,
|
|
||||||
builder)
|
|
||||||
else:
|
|
||||||
if not isinstance(other, (VFraction, VInt)):
|
|
||||||
return NotImplemented
|
|
||||||
r = VFraction()
|
|
||||||
if builder is not None:
|
|
||||||
if isinstance(other, VInt):
|
|
||||||
i = other.o_int64(builder).auto_load(builder)
|
|
||||||
x, rd = self._nd(builder)
|
|
||||||
y = builder.mul(rd, i)
|
|
||||||
else:
|
|
||||||
a, b = self._nd(builder)
|
|
||||||
c, d = other._nd(builder)
|
|
||||||
rd = builder.mul(b, d)
|
|
||||||
x = builder.mul(a, d)
|
|
||||||
y = builder.mul(c, b)
|
|
||||||
if sub:
|
|
||||||
if invert:
|
|
||||||
rn = builder.sub(y, x)
|
|
||||||
else:
|
|
||||||
rn = builder.sub(x, y)
|
|
||||||
else:
|
|
||||||
rn = builder.add(x, y)
|
|
||||||
rn, rd = _reduce(builder, rn, rd) # rd is already > 0
|
|
||||||
r.auto_store(builder, _make_ssa(builder, rn, rd))
|
|
||||||
return r
|
|
||||||
|
|
||||||
def o_add(self, other, builder):
|
|
||||||
return self._o_addsub(other, builder, False)
|
|
||||||
|
|
||||||
def o_sub(self, other, builder):
|
|
||||||
return self._o_addsub(other, builder, True)
|
|
||||||
|
|
||||||
def or_add(self, other, builder):
|
|
||||||
return self._o_addsub(other, builder, False)
|
|
||||||
|
|
||||||
def or_sub(self, other, builder):
|
|
||||||
return self._o_addsub(other, builder, True, True)
|
|
||||||
|
|
||||||
def _o_muldiv(self, other, builder, div, invert=False):
|
|
||||||
if isinstance(other, VFloat):
|
|
||||||
a = self.o_getattr("numerator", builder)
|
|
||||||
b = self.o_getattr("denominator", builder)
|
|
||||||
if invert:
|
|
||||||
a, b = b, a
|
|
||||||
if div:
|
|
||||||
return operators.truediv(a,
|
|
||||||
operators.mul(b, other, builder),
|
|
||||||
builder)
|
|
||||||
else:
|
|
||||||
return operators.truediv(operators.mul(a, other, builder),
|
|
||||||
b,
|
|
||||||
builder)
|
|
||||||
else:
|
|
||||||
if not isinstance(other, (VFraction, VInt)):
|
|
||||||
return NotImplemented
|
|
||||||
r = VFraction()
|
|
||||||
if builder is not None:
|
|
||||||
a, b = self._nd(builder)
|
|
||||||
if invert:
|
|
||||||
a, b = b, a
|
|
||||||
if isinstance(other, VInt):
|
|
||||||
i = other.o_int64(builder).auto_load(builder)
|
|
||||||
if div:
|
|
||||||
b = builder.mul(b, i)
|
|
||||||
else:
|
|
||||||
a = builder.mul(a, i)
|
|
||||||
else:
|
|
||||||
c, d = other._nd(builder)
|
|
||||||
if div:
|
|
||||||
a = builder.mul(a, d)
|
|
||||||
b = builder.mul(b, c)
|
|
||||||
else:
|
|
||||||
a = builder.mul(a, c)
|
|
||||||
b = builder.mul(b, d)
|
|
||||||
if div or invert:
|
|
||||||
a, b = _signnum(builder, a, b)
|
|
||||||
a, b = _reduce(builder, a, b)
|
|
||||||
r.auto_store(builder, _make_ssa(builder, a, b))
|
|
||||||
return r
|
|
||||||
|
|
||||||
def o_mul(self, other, builder):
|
|
||||||
return self._o_muldiv(other, builder, False)
|
|
||||||
|
|
||||||
def o_truediv(self, other, builder):
|
|
||||||
return self._o_muldiv(other, builder, True)
|
|
||||||
|
|
||||||
def or_mul(self, other, builder):
|
|
||||||
return self._o_muldiv(other, builder, False)
|
|
||||||
|
|
||||||
def or_truediv(self, other, builder):
|
|
||||||
# multiply by the inverse
|
|
||||||
return self._o_muldiv(other, builder, False, True)
|
|
||||||
|
|
||||||
def o_floordiv(self, other, builder):
|
|
||||||
r = self.o_truediv(other, builder)
|
|
||||||
if r is NotImplemented:
|
|
||||||
return r
|
|
||||||
else:
|
|
||||||
return r.o_int(builder)
|
|
||||||
|
|
||||||
def or_floordiv(self, other, builder):
|
|
||||||
r = self.or_truediv(other, builder)
|
|
||||||
if r is NotImplemented:
|
|
||||||
return r
|
|
||||||
else:
|
|
||||||
return r.o_int(builder)
|
|
|
@ -1,169 +0,0 @@
|
||||||
import unittest
|
|
||||||
from pythonparser import parse, ast
|
|
||||||
import inspect
|
|
||||||
from fractions import Fraction
|
|
||||||
from ctypes import CFUNCTYPE, c_int, c_int32, c_int64, c_double
|
|
||||||
import struct
|
|
||||||
|
|
||||||
import llvmlite_or1k.binding as llvm
|
|
||||||
|
|
||||||
from artiq.language.core import int64
|
|
||||||
from artiq.py2llvm.infer_types import infer_function_types
|
|
||||||
from artiq.py2llvm import base_types, lists
|
|
||||||
from artiq.py2llvm.module import Module
|
|
||||||
|
|
||||||
def simplify_encode(a, b):
|
|
||||||
f = Fraction(a, b)
|
|
||||||
return f.numerator*1000 + f.denominator
|
|
||||||
|
|
||||||
|
|
||||||
def frac_arith_encode(op, a, b, c, d):
|
|
||||||
if op == 0:
|
|
||||||
f = Fraction(a, b) - Fraction(c, d)
|
|
||||||
elif op == 1:
|
|
||||||
f = Fraction(a, b) + Fraction(c, d)
|
|
||||||
elif op == 2:
|
|
||||||
f = Fraction(a, b) * Fraction(c, d)
|
|
||||||
else:
|
|
||||||
f = Fraction(a, b) / Fraction(c, d)
|
|
||||||
return f.numerator*1000 + f.denominator
|
|
||||||
|
|
||||||
|
|
||||||
def frac_arith_encode_int(op, a, b, x):
|
|
||||||
if op == 0:
|
|
||||||
f = Fraction(a, b) - x
|
|
||||||
elif op == 1:
|
|
||||||
f = Fraction(a, b) + x
|
|
||||||
elif op == 2:
|
|
||||||
f = Fraction(a, b) * x
|
|
||||||
else:
|
|
||||||
f = Fraction(a, b) / x
|
|
||||||
return f.numerator*1000 + f.denominator
|
|
||||||
|
|
||||||
|
|
||||||
def frac_arith_encode_int_rev(op, a, b, x):
|
|
||||||
if op == 0:
|
|
||||||
f = x - Fraction(a, b)
|
|
||||||
elif op == 1:
|
|
||||||
f = x + Fraction(a, b)
|
|
||||||
elif op == 2:
|
|
||||||
f = x * Fraction(a, b)
|
|
||||||
else:
|
|
||||||
f = x / Fraction(a, b)
|
|
||||||
return f.numerator*1000 + f.denominator
|
|
||||||
|
|
||||||
|
|
||||||
def frac_arith_float(op, a, b, x):
|
|
||||||
if op == 0:
|
|
||||||
return Fraction(a, b) - x
|
|
||||||
elif op == 1:
|
|
||||||
return Fraction(a, b) + x
|
|
||||||
elif op == 2:
|
|
||||||
return Fraction(a, b) * x
|
|
||||||
else:
|
|
||||||
return Fraction(a, b) / x
|
|
||||||
|
|
||||||
|
|
||||||
def frac_arith_float_rev(op, a, b, x):
|
|
||||||
if op == 0:
|
|
||||||
return x - Fraction(a, b)
|
|
||||||
elif op == 1:
|
|
||||||
return x + Fraction(a, b)
|
|
||||||
elif op == 2:
|
|
||||||
return x * Fraction(a, b)
|
|
||||||
else:
|
|
||||||
return x / Fraction(a, b)
|
|
||||||
|
|
||||||
|
|
||||||
class CodeGenCase(unittest.TestCase):
|
|
||||||
def test_frac_simplify(self):
|
|
||||||
simplify_encode_c = CompiledFunction(
|
|
||||||
simplify_encode, {"a": base_types.VInt(), "b": base_types.VInt()})
|
|
||||||
for a in _test_range():
|
|
||||||
for b in _test_range():
|
|
||||||
self.assertEqual(
|
|
||||||
simplify_encode_c(a, b), simplify_encode(a, b))
|
|
||||||
|
|
||||||
def _test_frac_arith(self, op):
|
|
||||||
frac_arith_encode_c = CompiledFunction(
|
|
||||||
frac_arith_encode, {
|
|
||||||
"op": base_types.VInt(),
|
|
||||||
"a": base_types.VInt(), "b": base_types.VInt(),
|
|
||||||
"c": base_types.VInt(), "d": base_types.VInt()})
|
|
||||||
for a in _test_range():
|
|
||||||
for b in _test_range():
|
|
||||||
for c in _test_range():
|
|
||||||
for d in _test_range():
|
|
||||||
self.assertEqual(
|
|
||||||
frac_arith_encode_c(op, a, b, c, d),
|
|
||||||
frac_arith_encode(op, a, b, c, d))
|
|
||||||
|
|
||||||
def test_frac_add(self):
|
|
||||||
self._test_frac_arith(0)
|
|
||||||
|
|
||||||
def test_frac_sub(self):
|
|
||||||
self._test_frac_arith(1)
|
|
||||||
|
|
||||||
def test_frac_mul(self):
|
|
||||||
self._test_frac_arith(2)
|
|
||||||
|
|
||||||
def test_frac_div(self):
|
|
||||||
self._test_frac_arith(3)
|
|
||||||
|
|
||||||
def _test_frac_arith_int(self, op, rev):
|
|
||||||
f = frac_arith_encode_int_rev if rev else frac_arith_encode_int
|
|
||||||
f_c = CompiledFunction(f, {
|
|
||||||
"op": base_types.VInt(),
|
|
||||||
"a": base_types.VInt(), "b": base_types.VInt(),
|
|
||||||
"x": base_types.VInt()})
|
|
||||||
for a in _test_range():
|
|
||||||
for b in _test_range():
|
|
||||||
for x in _test_range():
|
|
||||||
self.assertEqual(
|
|
||||||
f_c(op, a, b, x),
|
|
||||||
f(op, a, b, x))
|
|
||||||
|
|
||||||
def test_frac_add_int(self):
|
|
||||||
self._test_frac_arith_int(0, False)
|
|
||||||
self._test_frac_arith_int(0, True)
|
|
||||||
|
|
||||||
def test_frac_sub_int(self):
|
|
||||||
self._test_frac_arith_int(1, False)
|
|
||||||
self._test_frac_arith_int(1, True)
|
|
||||||
|
|
||||||
def test_frac_mul_int(self):
|
|
||||||
self._test_frac_arith_int(2, False)
|
|
||||||
self._test_frac_arith_int(2, True)
|
|
||||||
|
|
||||||
def test_frac_div_int(self):
|
|
||||||
self._test_frac_arith_int(3, False)
|
|
||||||
self._test_frac_arith_int(3, True)
|
|
||||||
|
|
||||||
def _test_frac_arith_float(self, op, rev):
|
|
||||||
f = frac_arith_float_rev if rev else frac_arith_float
|
|
||||||
f_c = CompiledFunction(f, {
|
|
||||||
"op": base_types.VInt(),
|
|
||||||
"a": base_types.VInt(), "b": base_types.VInt(),
|
|
||||||
"x": base_types.VFloat()})
|
|
||||||
for a in _test_range():
|
|
||||||
for b in _test_range():
|
|
||||||
for x in _test_range():
|
|
||||||
self.assertAlmostEqual(
|
|
||||||
f_c(op, a, b, x/2),
|
|
||||||
f(op, a, b, x/2))
|
|
||||||
|
|
||||||
def test_frac_add_float(self):
|
|
||||||
self._test_frac_arith_float(0, False)
|
|
||||||
self._test_frac_arith_float(0, True)
|
|
||||||
|
|
||||||
def test_frac_sub_float(self):
|
|
||||||
self._test_frac_arith_float(1, False)
|
|
||||||
self._test_frac_arith_float(1, True)
|
|
||||||
|
|
||||||
def test_frac_mul_float(self):
|
|
||||||
self._test_frac_arith_float(2, False)
|
|
||||||
self._test_frac_arith_float(2, True)
|
|
||||||
|
|
||||||
def test_frac_div_float(self):
|
|
||||||
self._test_frac_arith_float(3, False)
|
|
||||||
self._test_frac_arith_float(3, True)
|
|
|
@ -1,548 +0,0 @@
|
||||||
import inspect
|
|
||||||
import textwrap
|
|
||||||
import ast
|
|
||||||
import types
|
|
||||||
import builtins
|
|
||||||
from fractions import Fraction
|
|
||||||
from collections import OrderedDict
|
|
||||||
from functools import partial
|
|
||||||
from itertools import zip_longest, chain
|
|
||||||
|
|
||||||
from artiq.language import core as core_language
|
|
||||||
from artiq.language import units
|
|
||||||
from artiq.transforms.tools import *
|
|
||||||
|
|
||||||
|
|
||||||
def new_mangled_name(in_use_names, name):
|
|
||||||
mangled_name = name
|
|
||||||
i = 2
|
|
||||||
while mangled_name in in_use_names:
|
|
||||||
mangled_name = name + str(i)
|
|
||||||
i += 1
|
|
||||||
in_use_names.add(mangled_name)
|
|
||||||
return mangled_name
|
|
||||||
|
|
||||||
|
|
||||||
class MangledName:
|
|
||||||
def __init__(self, s):
|
|
||||||
self.s = s
|
|
||||||
|
|
||||||
|
|
||||||
class AttributeInfo:
|
|
||||||
def __init__(self, obj, mangled_name, read_write):
|
|
||||||
self.obj = obj
|
|
||||||
self.mangled_name = mangled_name
|
|
||||||
self.read_write = read_write
|
|
||||||
|
|
||||||
|
|
||||||
def is_inlinable(core, func):
|
|
||||||
if hasattr(func, "k_function_info"):
|
|
||||||
if func.k_function_info.core_name == "":
|
|
||||||
return True # portable function
|
|
||||||
if getattr(func.__self__, func.k_function_info.core_name) is core:
|
|
||||||
return True # kernel function for the same core device
|
|
||||||
return False
|
|
||||||
|
|
||||||
|
|
||||||
class GlobalNamespace:
|
|
||||||
def __init__(self, func):
|
|
||||||
self.func_gd = inspect.getmodule(func).__dict__
|
|
||||||
|
|
||||||
def __getitem__(self, item):
|
|
||||||
try:
|
|
||||||
return self.func_gd[item]
|
|
||||||
except KeyError:
|
|
||||||
return getattr(builtins, item)
|
|
||||||
|
|
||||||
|
|
||||||
class UndefinedArg:
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
def get_function_args(func_args, func_tr, args, kwargs):
|
|
||||||
# OrderedDict prevents non-determinism in argument init
|
|
||||||
r = OrderedDict()
|
|
||||||
|
|
||||||
# Process positional arguments. Any missing positional argument values
|
|
||||||
# are set to UndefinedArg.
|
|
||||||
for arg, arg_value in zip_longest(func_args.args, args,
|
|
||||||
fillvalue=UndefinedArg):
|
|
||||||
if arg is UndefinedArg:
|
|
||||||
raise TypeError("Got too many positional arguments")
|
|
||||||
if arg.arg in r:
|
|
||||||
raise SyntaxError("Duplicate argument '{}' in function definition"
|
|
||||||
.format(arg.arg))
|
|
||||||
r[arg.arg] = arg_value
|
|
||||||
|
|
||||||
# Process keyword arguments. Any missing keyword-only argument values
|
|
||||||
# are set to UndefinedArg.
|
|
||||||
valid_arg_names = {arg.arg for arg in
|
|
||||||
chain(func_args.args, func_args.kwonlyargs)}
|
|
||||||
for arg in func_args.kwonlyargs:
|
|
||||||
if arg.arg in r:
|
|
||||||
raise SyntaxError("Duplicate argument '{}' in function definition"
|
|
||||||
.format(arg.arg))
|
|
||||||
r[arg.arg] = UndefinedArg
|
|
||||||
for arg_name, arg_value in kwargs.items():
|
|
||||||
if arg_name not in valid_arg_names:
|
|
||||||
raise TypeError("Got unexpected keyword argument '{}'"
|
|
||||||
.format(arg_name))
|
|
||||||
if r[arg_name] is not UndefinedArg:
|
|
||||||
raise TypeError("Got multiple values for argument '{}'"
|
|
||||||
.format(arg_name))
|
|
||||||
r[arg_name] = arg_value
|
|
||||||
|
|
||||||
# Replace any UndefinedArg positional arguments with the default value,
|
|
||||||
# when provided.
|
|
||||||
for arg, default in zip(func_args.args[-len(func_args.defaults):],
|
|
||||||
func_args.defaults):
|
|
||||||
if r[arg.arg] is UndefinedArg:
|
|
||||||
r[arg.arg] = func_tr.code_visit(default)
|
|
||||||
# Same with keyword-only arguments.
|
|
||||||
for arg, default in zip(func_args.kwonlyargs, func_args.kw_defaults):
|
|
||||||
if default is not None and r[arg.arg] is UndefinedArg:
|
|
||||||
r[arg.arg] = func_tr.code_visit(default)
|
|
||||||
|
|
||||||
# Check that no argument was left undefined.
|
|
||||||
missing_arguments = ["'"+arg+"'" for arg, value in r.items()
|
|
||||||
if value is UndefinedArg]
|
|
||||||
if missing_arguments:
|
|
||||||
raise TypeError("Missing argument(s): " + " ".join(missing_arguments))
|
|
||||||
|
|
||||||
return r
|
|
||||||
|
|
||||||
|
|
||||||
# args/kwargs can contain values or AST nodes
|
|
||||||
def get_inline(core, attribute_namespace, in_use_names, retval_name, mappers,
|
|
||||||
func, args, kwargs):
|
|
||||||
global_namespace = GlobalNamespace(func)
|
|
||||||
func_tr = Function(core,
|
|
||||||
global_namespace, attribute_namespace, in_use_names,
|
|
||||||
retval_name, mappers)
|
|
||||||
func_def = ast.parse(textwrap.dedent(inspect.getsource(func))).body[0]
|
|
||||||
|
|
||||||
# Initialize arguments.
|
|
||||||
# The local namespace is empty so code_visit will always resolve
|
|
||||||
# using the global namespace.
|
|
||||||
arg_init = []
|
|
||||||
arg_name_map = []
|
|
||||||
arg_dict = get_function_args(func_def.args, func_tr, args, kwargs)
|
|
||||||
for arg_name, arg_value in arg_dict.items():
|
|
||||||
if isinstance(arg_value, ast.AST):
|
|
||||||
value = arg_value
|
|
||||||
else:
|
|
||||||
try:
|
|
||||||
value = ast.copy_location(value_to_ast(arg_value), func_def)
|
|
||||||
except NotASTRepresentable:
|
|
||||||
value = None
|
|
||||||
if value is None:
|
|
||||||
# static object
|
|
||||||
func_tr.local_namespace[arg_name] = arg_value
|
|
||||||
else:
|
|
||||||
# set parameter value with "name = value"
|
|
||||||
# assignment at beginning of function
|
|
||||||
new_name = new_mangled_name(in_use_names, arg_name)
|
|
||||||
arg_name_map.append((arg_name, new_name))
|
|
||||||
target = ast.copy_location(ast.Name(new_name, ast.Store()),
|
|
||||||
func_def)
|
|
||||||
assign = ast.copy_location(ast.Assign([target], value),
|
|
||||||
func_def)
|
|
||||||
arg_init.append(assign)
|
|
||||||
# Commit arguments to the local namespace at the end to handle cases
|
|
||||||
# such as f(x, y=x) (for the default value of y, x must be resolved
|
|
||||||
# using the global namespace).
|
|
||||||
for arg_name, mangled_name in arg_name_map:
|
|
||||||
func_tr.local_namespace[arg_name] = MangledName(mangled_name)
|
|
||||||
|
|
||||||
func_def = func_tr.code_visit(func_def)
|
|
||||||
func_def.body[0:0] = arg_init
|
|
||||||
return func_def
|
|
||||||
|
|
||||||
|
|
||||||
class Function:
|
|
||||||
def __init__(self, core,
|
|
||||||
global_namespace, attribute_namespace, in_use_names,
|
|
||||||
retval_name, mappers):
|
|
||||||
# The core device on which this function is executing.
|
|
||||||
self.core = core
|
|
||||||
|
|
||||||
# Local and global namespaces:
|
|
||||||
# original name -> MangledName or static object
|
|
||||||
self.local_namespace = dict()
|
|
||||||
self.global_namespace = global_namespace
|
|
||||||
|
|
||||||
# (id(static object), attribute) -> AttributeInfo
|
|
||||||
self.attribute_namespace = attribute_namespace
|
|
||||||
|
|
||||||
# All names currently in use, in the namespace of the combined
|
|
||||||
# function.
|
|
||||||
# When creating a name for a new object, check that it is not
|
|
||||||
# already in this set.
|
|
||||||
self.in_use_names = in_use_names
|
|
||||||
|
|
||||||
# Name of the variable to store the return value to, or None
|
|
||||||
# to keep the return statement.
|
|
||||||
self.retval_name = retval_name
|
|
||||||
|
|
||||||
# Host object mappers, for RPC and exception numbers
|
|
||||||
self.mappers = mappers
|
|
||||||
|
|
||||||
self._insertion_point = None
|
|
||||||
|
|
||||||
# This is ast.NodeVisitor/NodeTransformer from CPython, modified
|
|
||||||
# to add code_ prefix.
|
|
||||||
def code_visit(self, node):
|
|
||||||
method = "code_visit_" + node.__class__.__name__
|
|
||||||
visitor = getattr(self, method, self.code_generic_visit)
|
|
||||||
return visitor(node)
|
|
||||||
|
|
||||||
# This is ast.NodeTransformer.generic_visit from CPython, modified
|
|
||||||
# to update self._insertion_point.
|
|
||||||
def code_generic_visit(self, node, exclude_fields=set()):
|
|
||||||
for field, old_value in ast.iter_fields(node):
|
|
||||||
if field in exclude_fields:
|
|
||||||
continue
|
|
||||||
old_value = getattr(node, field, None)
|
|
||||||
if isinstance(old_value, list):
|
|
||||||
prev_insertion_point = self._insertion_point
|
|
||||||
new_values = []
|
|
||||||
if field in ("body", "orelse", "finalbody"):
|
|
||||||
self._insertion_point = new_values
|
|
||||||
for value in old_value:
|
|
||||||
if isinstance(value, ast.AST):
|
|
||||||
value = self.code_visit(value)
|
|
||||||
if value is None:
|
|
||||||
continue
|
|
||||||
elif not isinstance(value, ast.AST):
|
|
||||||
new_values.extend(value)
|
|
||||||
continue
|
|
||||||
new_values.append(value)
|
|
||||||
old_value[:] = new_values
|
|
||||||
self._insertion_point = prev_insertion_point
|
|
||||||
elif isinstance(old_value, ast.AST):
|
|
||||||
new_node = self.code_visit(old_value)
|
|
||||||
if new_node is None:
|
|
||||||
delattr(node, field)
|
|
||||||
else:
|
|
||||||
setattr(node, field, new_node)
|
|
||||||
return node
|
|
||||||
|
|
||||||
def code_visit_Name(self, node):
|
|
||||||
if isinstance(node.ctx, ast.Store):
|
|
||||||
if (node.id in self.local_namespace
|
|
||||||
and isinstance(self.local_namespace[node.id],
|
|
||||||
MangledName)):
|
|
||||||
new_name = self.local_namespace[node.id].s
|
|
||||||
else:
|
|
||||||
new_name = new_mangled_name(self.in_use_names, node.id)
|
|
||||||
self.local_namespace[node.id] = MangledName(new_name)
|
|
||||||
node.id = new_name
|
|
||||||
return node
|
|
||||||
else:
|
|
||||||
try:
|
|
||||||
obj = self.local_namespace[node.id]
|
|
||||||
except KeyError:
|
|
||||||
try:
|
|
||||||
obj = self.global_namespace[node.id]
|
|
||||||
except KeyError:
|
|
||||||
raise NameError("name '{}' is not defined".format(node.id))
|
|
||||||
if isinstance(obj, MangledName):
|
|
||||||
node.id = obj.s
|
|
||||||
return node
|
|
||||||
else:
|
|
||||||
try:
|
|
||||||
return value_to_ast(obj)
|
|
||||||
except NotASTRepresentable:
|
|
||||||
raise NotImplementedError(
|
|
||||||
"Static object cannot be used here")
|
|
||||||
|
|
||||||
def code_visit_Attribute(self, node):
|
|
||||||
# There are two cases of attributes:
|
|
||||||
# 1. static object attributes, e.g. self.foo
|
|
||||||
# 2. dynamic expression attributes, e.g.
|
|
||||||
# (Fraction(1, 2) + x).numerator
|
|
||||||
# Static object resolution has no side effects so we try it first.
|
|
||||||
try:
|
|
||||||
obj = self.static_visit(node.value)
|
|
||||||
except:
|
|
||||||
self.code_generic_visit(node)
|
|
||||||
return node
|
|
||||||
else:
|
|
||||||
key = (id(obj), node.attr)
|
|
||||||
try:
|
|
||||||
attr_info = self.attribute_namespace[key]
|
|
||||||
except KeyError:
|
|
||||||
new_name = new_mangled_name(self.in_use_names, node.attr)
|
|
||||||
attr_info = AttributeInfo(obj, new_name, False)
|
|
||||||
self.attribute_namespace[key] = attr_info
|
|
||||||
if isinstance(node.ctx, ast.Store):
|
|
||||||
attr_info.read_write = True
|
|
||||||
return ast.copy_location(
|
|
||||||
ast.Name(attr_info.mangled_name, node.ctx),
|
|
||||||
node)
|
|
||||||
|
|
||||||
def code_visit_Call(self, node):
|
|
||||||
func = self.static_visit(node.func)
|
|
||||||
node.args = [self.code_visit(arg) for arg in node.args]
|
|
||||||
for kw in node.keywords:
|
|
||||||
kw.value = self.code_visit(kw.value)
|
|
||||||
|
|
||||||
if is_embeddable(func):
|
|
||||||
node.func = ast.copy_location(
|
|
||||||
ast.Name(func.__name__, ast.Load()),
|
|
||||||
node)
|
|
||||||
return node
|
|
||||||
elif is_inlinable(self.core, func):
|
|
||||||
retval_name = func.k_function_info.k_function.__name__ + "_return"
|
|
||||||
retval_name_m = new_mangled_name(self.in_use_names, retval_name)
|
|
||||||
args = [func.__self__] + node.args
|
|
||||||
kwargs = {kw.arg: kw.value for kw in node.keywords}
|
|
||||||
inlined = get_inline(self.core,
|
|
||||||
self.attribute_namespace, self.in_use_names,
|
|
||||||
retval_name_m, self.mappers,
|
|
||||||
func.k_function_info.k_function,
|
|
||||||
args, kwargs)
|
|
||||||
seq = ast.copy_location(
|
|
||||||
ast.With(
|
|
||||||
items=[ast.withitem(context_expr=ast.Name(id="sequential",
|
|
||||||
ctx=ast.Load()),
|
|
||||||
optional_vars=None)],
|
|
||||||
body=inlined.body),
|
|
||||||
node)
|
|
||||||
self._insertion_point.append(seq)
|
|
||||||
return ast.copy_location(ast.Name(retval_name_m, ast.Load()),
|
|
||||||
node)
|
|
||||||
else:
|
|
||||||
arg1 = ast.copy_location(ast.Str("rpc"), node)
|
|
||||||
arg2 = ast.copy_location(
|
|
||||||
value_to_ast(self.mappers.rpc.encode(func)), node)
|
|
||||||
node.args[0:0] = [arg1, arg2]
|
|
||||||
node.func = ast.copy_location(
|
|
||||||
ast.Name("syscall", ast.Load()), node)
|
|
||||||
return node
|
|
||||||
|
|
||||||
def code_visit_Return(self, node):
|
|
||||||
self.code_generic_visit(node)
|
|
||||||
if self.retval_name is None:
|
|
||||||
return node
|
|
||||||
else:
|
|
||||||
return ast.copy_location(
|
|
||||||
ast.Assign(targets=[ast.Name(self.retval_name, ast.Store())],
|
|
||||||
value=node.value),
|
|
||||||
node)
|
|
||||||
|
|
||||||
def code_visit_Expr(self, node):
|
|
||||||
if isinstance(node.value, ast.Str):
|
|
||||||
# Strip docstrings. This also removes strings appearing in the
|
|
||||||
# middle of the code, but they are nops.
|
|
||||||
return None
|
|
||||||
self.code_generic_visit(node)
|
|
||||||
if isinstance(node.value, ast.Name):
|
|
||||||
# Remove Expr nodes that contain only a name, likely due to
|
|
||||||
# function call inlining. Such nodes that were originally in the
|
|
||||||
# code are also removed, but this does not affect the semantics of
|
|
||||||
# the code as they are nops.
|
|
||||||
return None
|
|
||||||
else:
|
|
||||||
return node
|
|
||||||
|
|
||||||
def encode_exception(self, e):
|
|
||||||
exception_class = self.static_visit(e)
|
|
||||||
if not inspect.isclass(exception_class):
|
|
||||||
raise NotImplementedError("Exception type must be a class")
|
|
||||||
if issubclass(exception_class, core_language.RuntimeException):
|
|
||||||
exception_id = exception_class.eid
|
|
||||||
else:
|
|
||||||
exception_id = self.mappers.exception.encode(exception_class)
|
|
||||||
return ast.copy_location(
|
|
||||||
ast.Call(func=ast.Name("EncodedException", ast.Load()),
|
|
||||||
args=[value_to_ast(exception_id)], keywords=[]),
|
|
||||||
e)
|
|
||||||
|
|
||||||
def code_visit_Raise(self, node):
|
|
||||||
if node.cause is not None:
|
|
||||||
raise NotImplementedError("Exception causes are not supported")
|
|
||||||
if node.exc is not None:
|
|
||||||
node.exc = self.encode_exception(node.exc)
|
|
||||||
return node
|
|
||||||
|
|
||||||
def code_visit_ExceptHandler(self, node):
|
|
||||||
if node.name is not None:
|
|
||||||
raise NotImplementedError("'as target' is not supported")
|
|
||||||
if node.type is not None:
|
|
||||||
if isinstance(node.type, ast.Tuple):
|
|
||||||
node.type.elts = [self.encode_exception(e)
|
|
||||||
for e in node.type.elts]
|
|
||||||
else:
|
|
||||||
node.type = self.encode_exception(node.type)
|
|
||||||
self.code_generic_visit(node)
|
|
||||||
return node
|
|
||||||
|
|
||||||
def get_user_ctxm(self, context_expr):
|
|
||||||
try:
|
|
||||||
ctxm = self.static_visit(context_expr)
|
|
||||||
except:
|
|
||||||
# this also catches watchdog()
|
|
||||||
return None
|
|
||||||
else:
|
|
||||||
if (ctxm is core_language.sequential
|
|
||||||
or ctxm is core_language.parallel):
|
|
||||||
return None
|
|
||||||
return ctxm
|
|
||||||
|
|
||||||
def code_visit_With(self, node):
|
|
||||||
if len(node.items) != 1:
|
|
||||||
raise NotImplementedError
|
|
||||||
item = node.items[0]
|
|
||||||
if item.optional_vars is not None:
|
|
||||||
raise NotImplementedError
|
|
||||||
ctxm = self.get_user_ctxm(item.context_expr)
|
|
||||||
if ctxm is None:
|
|
||||||
self.code_generic_visit(node)
|
|
||||||
return node
|
|
||||||
|
|
||||||
# user context manager
|
|
||||||
self.code_generic_visit(node, {"items"})
|
|
||||||
if (not hasattr(ctxm, "__enter__")
|
|
||||||
or not hasattr(ctxm.__enter__, "k_function_info")):
|
|
||||||
raise NotImplementedError
|
|
||||||
enter = get_inline(self.core,
|
|
||||||
self.attribute_namespace, self.in_use_names,
|
|
||||||
None, self.mappers,
|
|
||||||
ctxm.__enter__.k_function_info.k_function,
|
|
||||||
[ctxm], dict())
|
|
||||||
if (not hasattr(ctxm, "__exit__")
|
|
||||||
or not hasattr(ctxm.__exit__, "k_function_info")):
|
|
||||||
raise NotImplementedError
|
|
||||||
exit = get_inline(self.core,
|
|
||||||
self.attribute_namespace, self.in_use_names,
|
|
||||||
None, self.mappers,
|
|
||||||
ctxm.__exit__.k_function_info.k_function,
|
|
||||||
[ctxm, None, None, None], dict())
|
|
||||||
try_stmt = ast.copy_location(
|
|
||||||
ast.Try(body=node.body,
|
|
||||||
handlers=[],
|
|
||||||
orelse=[],
|
|
||||||
finalbody=exit.body), node)
|
|
||||||
return ast.copy_location(
|
|
||||||
ast.With(
|
|
||||||
items=[ast.withitem(context_expr=ast.Name(id="sequential",
|
|
||||||
ctx=ast.Load()),
|
|
||||||
optional_vars=None)],
|
|
||||||
body=enter.body + [try_stmt]),
|
|
||||||
node)
|
|
||||||
|
|
||||||
def code_visit_FunctionDef(self, node):
|
|
||||||
node.args = ast.arguments(args=[], vararg=None, kwonlyargs=[],
|
|
||||||
kw_defaults=[], kwarg=None, defaults=[])
|
|
||||||
node.decorator_list = []
|
|
||||||
self.code_generic_visit(node)
|
|
||||||
return node
|
|
||||||
|
|
||||||
def static_visit(self, node):
|
|
||||||
method = "static_visit_" + node.__class__.__name__
|
|
||||||
visitor = getattr(self, method)
|
|
||||||
return visitor(node)
|
|
||||||
|
|
||||||
def static_visit_Name(self, node):
|
|
||||||
try:
|
|
||||||
obj = self.local_namespace[node.id]
|
|
||||||
except KeyError:
|
|
||||||
try:
|
|
||||||
obj = self.global_namespace[node.id]
|
|
||||||
except KeyError:
|
|
||||||
raise NameError("name '{}' is not defined".format(node.id))
|
|
||||||
if isinstance(obj, MangledName):
|
|
||||||
raise NotImplementedError(
|
|
||||||
"Only a static object can be used here")
|
|
||||||
return obj
|
|
||||||
|
|
||||||
def static_visit_Attribute(self, node):
|
|
||||||
value = self.static_visit(node.value)
|
|
||||||
return getattr(value, node.attr)
|
|
||||||
|
|
||||||
|
|
||||||
class HostObjectMapper:
|
|
||||||
def __init__(self, first_encoding=0):
|
|
||||||
self._next_encoding = first_encoding
|
|
||||||
# id(object) -> (encoding, object)
|
|
||||||
# this format is required to support non-hashable host objects.
|
|
||||||
self._d = dict()
|
|
||||||
|
|
||||||
def encode(self, obj):
|
|
||||||
try:
|
|
||||||
return self._d[id(obj)][0]
|
|
||||||
except KeyError:
|
|
||||||
encoding = self._next_encoding
|
|
||||||
self._d[id(obj)] = (encoding, obj)
|
|
||||||
self._next_encoding += 1
|
|
||||||
return encoding
|
|
||||||
|
|
||||||
def get_map(self):
|
|
||||||
return {encoding: obj for i, (encoding, obj) in self._d.items()}
|
|
||||||
|
|
||||||
|
|
||||||
def get_attr_init(attribute_namespace, loc_node):
|
|
||||||
attr_init = []
|
|
||||||
for (_, attr), attr_info in attribute_namespace.items():
|
|
||||||
if hasattr(attr_info.obj, attr):
|
|
||||||
value = getattr(attr_info.obj, attr)
|
|
||||||
if (hasattr(value, "kernel_attr_init")
|
|
||||||
and not value.kernel_attr_init):
|
|
||||||
continue
|
|
||||||
value = ast.copy_location(value_to_ast(value), loc_node)
|
|
||||||
target = ast.copy_location(ast.Name(attr_info.mangled_name,
|
|
||||||
ast.Store()),
|
|
||||||
loc_node)
|
|
||||||
assign = ast.copy_location(ast.Assign([target], value),
|
|
||||||
loc_node)
|
|
||||||
attr_init.append(assign)
|
|
||||||
return attr_init
|
|
||||||
|
|
||||||
|
|
||||||
def get_attr_writeback(attribute_namespace, rpc_mapper, loc_node):
|
|
||||||
attr_writeback = []
|
|
||||||
for (_, attr), attr_info in attribute_namespace.items():
|
|
||||||
if attr_info.read_write:
|
|
||||||
setter = partial(setattr, attr_info.obj, attr)
|
|
||||||
func = ast.copy_location(
|
|
||||||
ast.Name("syscall", ast.Load()), loc_node)
|
|
||||||
arg1 = ast.copy_location(ast.Str("rpc"), loc_node)
|
|
||||||
arg2 = ast.copy_location(
|
|
||||||
value_to_ast(rpc_mapper.encode(setter)), loc_node)
|
|
||||||
arg3 = ast.copy_location(
|
|
||||||
ast.Name(attr_info.mangled_name, ast.Load()), loc_node)
|
|
||||||
call = ast.copy_location(
|
|
||||||
ast.Call(func=func, args=[arg1, arg2, arg3], keywords=[]),
|
|
||||||
loc_node)
|
|
||||||
expr = ast.copy_location(ast.Expr(call), loc_node)
|
|
||||||
attr_writeback.append(expr)
|
|
||||||
return attr_writeback
|
|
||||||
|
|
||||||
|
|
||||||
def inline(core, k_function, k_args, k_kwargs, with_attr_writeback):
|
|
||||||
# OrderedDict prevents non-determinism in attribute init
|
|
||||||
attribute_namespace = OrderedDict()
|
|
||||||
# NOTE: in_use_names will be mutated. Do not mutate embeddable_func_names!
|
|
||||||
in_use_names = embeddable_func_names | {"sequential", "parallel",
|
|
||||||
"watchdog"}
|
|
||||||
mappers = types.SimpleNamespace(
|
|
||||||
rpc=HostObjectMapper(),
|
|
||||||
exception=HostObjectMapper(core_language.first_user_eid)
|
|
||||||
)
|
|
||||||
func_def = get_inline(
|
|
||||||
core=core,
|
|
||||||
attribute_namespace=attribute_namespace,
|
|
||||||
in_use_names=in_use_names,
|
|
||||||
retval_name=None,
|
|
||||||
mappers=mappers,
|
|
||||||
func=k_function,
|
|
||||||
args=k_args,
|
|
||||||
kwargs=k_kwargs)
|
|
||||||
|
|
||||||
func_def.body[0:0] = get_attr_init(attribute_namespace, func_def)
|
|
||||||
if with_attr_writeback:
|
|
||||||
func_def.body += get_attr_writeback(attribute_namespace, mappers.rpc,
|
|
||||||
func_def)
|
|
||||||
|
|
||||||
return func_def, mappers.rpc.get_map(), mappers.exception.get_map()
|
|
|
@ -1,130 +0,0 @@
|
||||||
import ast
|
|
||||||
import types
|
|
||||||
|
|
||||||
from artiq.transforms.tools import *
|
|
||||||
|
|
||||||
|
|
||||||
# -1 statement duration could not be pre-determined
|
|
||||||
# 0 statement has no effect on timeline
|
|
||||||
# >0 statement is a static delay that advances the timeline
|
|
||||||
# by the given amount
|
|
||||||
def _get_duration(stmt):
|
|
||||||
if isinstance(stmt, (ast.Expr, ast.Assign)):
|
|
||||||
return _get_duration(stmt.value)
|
|
||||||
elif isinstance(stmt, ast.If):
|
|
||||||
if (all(_get_duration(s) == 0 for s in stmt.body)
|
|
||||||
and all(_get_duration(s) == 0 for s in stmt.orelse)):
|
|
||||||
return 0
|
|
||||||
else:
|
|
||||||
return -1
|
|
||||||
elif isinstance(stmt, ast.Try):
|
|
||||||
if (all(_get_duration(s) == 0 for s in stmt.body)
|
|
||||||
and all(_get_duration(s) == 0 for s in stmt.orelse)
|
|
||||||
and all(_get_duration(s) == 0 for s in stmt.finalbody)
|
|
||||||
and all(_get_duration(s) == 0 for s in handler.body
|
|
||||||
for handler in stmt.handlers)):
|
|
||||||
return 0
|
|
||||||
else:
|
|
||||||
return -1
|
|
||||||
elif isinstance(stmt, ast.Call):
|
|
||||||
name = stmt.func.id
|
|
||||||
assert(name != "delay")
|
|
||||||
if name == "delay_mu":
|
|
||||||
try:
|
|
||||||
da = eval_constant(stmt.args[0])
|
|
||||||
except NotConstant:
|
|
||||||
da = -1
|
|
||||||
return da
|
|
||||||
elif name == "at_mu":
|
|
||||||
return -1
|
|
||||||
else:
|
|
||||||
return 0
|
|
||||||
else:
|
|
||||||
return 0
|
|
||||||
|
|
||||||
|
|
||||||
def _interleave_timelines(timelines):
|
|
||||||
r = []
|
|
||||||
|
|
||||||
current_stmts = []
|
|
||||||
for stmts in timelines:
|
|
||||||
it = iter(stmts)
|
|
||||||
try:
|
|
||||||
stmt = next(it)
|
|
||||||
except StopIteration:
|
|
||||||
pass
|
|
||||||
else:
|
|
||||||
current_stmts.append(types.SimpleNamespace(
|
|
||||||
delay=_get_duration(stmt), stmt=stmt, it=it))
|
|
||||||
|
|
||||||
while current_stmts:
|
|
||||||
dt = min(stmt.delay for stmt in current_stmts)
|
|
||||||
if dt < 0:
|
|
||||||
# contains statement(s) with indeterminate duration
|
|
||||||
return None
|
|
||||||
if dt > 0:
|
|
||||||
# advance timeline by dt
|
|
||||||
for stmt in current_stmts:
|
|
||||||
stmt.delay -= dt
|
|
||||||
if stmt.delay == 0:
|
|
||||||
ref_stmt = stmt.stmt
|
|
||||||
delay_stmt = ast.copy_location(
|
|
||||||
ast.Expr(ast.Call(
|
|
||||||
func=ast.Name("delay_mu", ast.Load()),
|
|
||||||
args=[value_to_ast(dt)], keywords=[])),
|
|
||||||
ref_stmt)
|
|
||||||
r.append(delay_stmt)
|
|
||||||
else:
|
|
||||||
for stmt in current_stmts:
|
|
||||||
if stmt.delay == 0:
|
|
||||||
r.append(stmt.stmt)
|
|
||||||
# discard executed statements
|
|
||||||
exhausted_list = []
|
|
||||||
for stmt_i, stmt in enumerate(current_stmts):
|
|
||||||
if stmt.delay == 0:
|
|
||||||
try:
|
|
||||||
stmt.stmt = next(stmt.it)
|
|
||||||
except StopIteration:
|
|
||||||
exhausted_list.append(stmt_i)
|
|
||||||
else:
|
|
||||||
stmt.delay = _get_duration(stmt.stmt)
|
|
||||||
for offset, i in enumerate(exhausted_list):
|
|
||||||
current_stmts.pop(i-offset)
|
|
||||||
|
|
||||||
return r
|
|
||||||
|
|
||||||
|
|
||||||
def _interleave_stmts(stmts):
|
|
||||||
replacements = []
|
|
||||||
for stmt_i, stmt in enumerate(stmts):
|
|
||||||
if isinstance(stmt, (ast.For, ast.While, ast.If)):
|
|
||||||
_interleave_stmts(stmt.body)
|
|
||||||
_interleave_stmts(stmt.orelse)
|
|
||||||
elif isinstance(stmt, ast.Try):
|
|
||||||
_interleave_stmts(stmt.body)
|
|
||||||
_interleave_stmts(stmt.orelse)
|
|
||||||
_interleave_stmts(stmt.finalbody)
|
|
||||||
for handler in stmt.handlers:
|
|
||||||
_interleave_stmts(handler.body)
|
|
||||||
elif isinstance(stmt, ast.With):
|
|
||||||
btype = stmt.items[0].context_expr.id
|
|
||||||
if btype == "sequential":
|
|
||||||
_interleave_stmts(stmt.body)
|
|
||||||
replacements.append((stmt_i, stmt.body))
|
|
||||||
elif btype == "parallel":
|
|
||||||
timelines = [[s] for s in stmt.body]
|
|
||||||
for timeline in timelines:
|
|
||||||
_interleave_stmts(timeline)
|
|
||||||
merged = _interleave_timelines(timelines)
|
|
||||||
if merged is not None:
|
|
||||||
replacements.append((stmt_i, merged))
|
|
||||||
else:
|
|
||||||
raise ValueError("Unknown block type: " + btype)
|
|
||||||
offset = 0
|
|
||||||
for location, new_stmts in replacements:
|
|
||||||
stmts[offset+location:offset+location+1] = new_stmts
|
|
||||||
offset += len(new_stmts) - 1
|
|
||||||
|
|
||||||
|
|
||||||
def interleave(func_def):
|
|
||||||
_interleave_stmts(func_def.body)
|
|
|
@ -1,43 +0,0 @@
|
||||||
def visit_With(self, node):
|
|
||||||
self.generic_visit(node)
|
|
||||||
if (isinstance(node.items[0].context_expr, ast.Call)
|
|
||||||
and node.items[0].context_expr.func.id == "watchdog"):
|
|
||||||
|
|
||||||
idname = "__watchdog_id_" + str(self.watchdog_id_counter)
|
|
||||||
self.watchdog_id_counter += 1
|
|
||||||
|
|
||||||
time = ast.BinOp(left=node.items[0].context_expr.args[0],
|
|
||||||
op=ast.Mult(),
|
|
||||||
right=ast.Num(1000))
|
|
||||||
time_int = ast.Call(
|
|
||||||
func=ast.Name("round", ast.Load()),
|
|
||||||
args=[time],
|
|
||||||
keywords=[], starargs=None, kwargs=None)
|
|
||||||
syscall_set = ast.Call(
|
|
||||||
func=ast.Name("syscall", ast.Load()),
|
|
||||||
args=[ast.Str("watchdog_set"), time_int],
|
|
||||||
keywords=[], starargs=None, kwargs=None)
|
|
||||||
stmt_set = ast.copy_location(
|
|
||||||
ast.Assign(targets=[ast.Name(idname, ast.Store())],
|
|
||||||
value=syscall_set),
|
|
||||||
node)
|
|
||||||
|
|
||||||
syscall_clear = ast.Call(
|
|
||||||
func=ast.Name("syscall", ast.Load()),
|
|
||||||
args=[ast.Str("watchdog_clear"),
|
|
||||||
ast.Name(idname, ast.Load())],
|
|
||||||
keywords=[], starargs=None, kwargs=None)
|
|
||||||
stmt_clear = ast.copy_location(ast.Expr(syscall_clear), node)
|
|
||||||
|
|
||||||
node.items[0] = ast.withitem(
|
|
||||||
context_expr=ast.Name(id="sequential",
|
|
||||||
ctx=ast.Load()),
|
|
||||||
optional_vars=None)
|
|
||||||
node.body = [
|
|
||||||
stmt_set,
|
|
||||||
ast.Try(body=node.body,
|
|
||||||
handlers=[],
|
|
||||||
orelse=[],
|
|
||||||
finalbody=[stmt_clear])
|
|
||||||
]
|
|
||||||
return node
|
|
|
@ -1,82 +0,0 @@
|
||||||
import ast
|
|
||||||
from copy import deepcopy
|
|
||||||
|
|
||||||
from artiq.transforms.tools import eval_ast, value_to_ast
|
|
||||||
|
|
||||||
|
|
||||||
def _count_stmts(node):
|
|
||||||
if isinstance(node, list):
|
|
||||||
return sum(map(_count_stmts, node))
|
|
||||||
elif isinstance(node, ast.With):
|
|
||||||
return 1 + _count_stmts(node.body)
|
|
||||||
elif isinstance(node, (ast.For, ast.While, ast.If)):
|
|
||||||
return 1 + _count_stmts(node.body) + _count_stmts(node.orelse)
|
|
||||||
elif isinstance(node, ast.Try):
|
|
||||||
r = 1 + _count_stmts(node.body) \
|
|
||||||
+ _count_stmts(node.orelse) \
|
|
||||||
+ _count_stmts(node.finalbody)
|
|
||||||
for handler in node.handlers:
|
|
||||||
r += 1 + _count_stmts(handler.body)
|
|
||||||
return r
|
|
||||||
else:
|
|
||||||
return 1
|
|
||||||
|
|
||||||
|
|
||||||
def _loop_breakable(node):
|
|
||||||
if isinstance(node, list):
|
|
||||||
return any(map(_loop_breakable, node))
|
|
||||||
elif isinstance(node, (ast.Break, ast.Continue)):
|
|
||||||
return True
|
|
||||||
elif isinstance(node, ast.With):
|
|
||||||
return _loop_breakable(node.body)
|
|
||||||
elif isinstance(node, ast.If):
|
|
||||||
return _loop_breakable(node.body) or _loop_breakable(node.orelse)
|
|
||||||
elif isinstance(node, ast.Try):
|
|
||||||
if (_loop_breakable(node.body)
|
|
||||||
or _loop_breakable(node.orelse)
|
|
||||||
or _loop_breakable(node.finalbody)):
|
|
||||||
return True
|
|
||||||
for handler in node.handlers:
|
|
||||||
if _loop_breakable(handler.body):
|
|
||||||
return True
|
|
||||||
return False
|
|
||||||
else:
|
|
||||||
return False
|
|
||||||
|
|
||||||
|
|
||||||
class _LoopUnroller(ast.NodeTransformer):
|
|
||||||
def __init__(self, limit):
|
|
||||||
self.limit = limit
|
|
||||||
|
|
||||||
def visit_For(self, node):
|
|
||||||
self.generic_visit(node)
|
|
||||||
try:
|
|
||||||
it = eval_ast(node.iter)
|
|
||||||
except:
|
|
||||||
return node
|
|
||||||
l_it = len(it)
|
|
||||||
if l_it:
|
|
||||||
if (not _loop_breakable(node.body)
|
|
||||||
and l_it*_count_stmts(node.body) < self.limit):
|
|
||||||
replacement = []
|
|
||||||
for i in it:
|
|
||||||
if not isinstance(i, int):
|
|
||||||
replacement = None
|
|
||||||
break
|
|
||||||
replacement.append(ast.copy_location(
|
|
||||||
ast.Assign(targets=[node.target],
|
|
||||||
value=value_to_ast(i)),
|
|
||||||
node))
|
|
||||||
replacement += deepcopy(node.body)
|
|
||||||
if replacement is not None:
|
|
||||||
return replacement
|
|
||||||
else:
|
|
||||||
return node
|
|
||||||
else:
|
|
||||||
return node
|
|
||||||
else:
|
|
||||||
return node.orelse
|
|
||||||
|
|
||||||
|
|
||||||
def unroll_loops(node, limit):
|
|
||||||
_LoopUnroller(limit).visit(node)
|
|
Loading…
Reference in New Issue