forked from M-Labs/artiq
remove old compiler code
This commit is contained in:
parent
2503dcd837
commit
e5b58b50aa
|
@ -1,358 +0,0 @@
|
|||
import inspect
|
||||
from pythonparser import parse, ast
|
||||
|
||||
import llvmlite_artiq.ir as ll
|
||||
|
||||
from artiq.py2llvm.values import VGeneric, operators
|
||||
from artiq.py2llvm.base_types import VBool, VInt, VFloat
|
||||
|
||||
|
||||
def _gcd(a, b):
|
||||
if a < 0:
|
||||
a = -a
|
||||
while a:
|
||||
c = a
|
||||
a = b % a
|
||||
b = c
|
||||
return b
|
||||
|
||||
|
||||
def init_module(module):
|
||||
func_def = parse(inspect.getsource(_gcd)).body[0]
|
||||
function, _ = module.compile_function(func_def,
|
||||
{"a": VInt(64), "b": VInt(64)})
|
||||
function.linkage = "internal"
|
||||
|
||||
|
||||
def _reduce(builder, a, b):
|
||||
module = builder.basic_block.function.module
|
||||
for f in module.functions:
|
||||
if f.name == "_gcd":
|
||||
gcd_f = f
|
||||
break
|
||||
gcd = builder.call(gcd_f, [a, b])
|
||||
a = builder.sdiv(a, gcd)
|
||||
b = builder.sdiv(b, gcd)
|
||||
return a, b
|
||||
|
||||
|
||||
def _signnum(builder, a, b):
|
||||
function = builder.basic_block.function
|
||||
orig_block = builder.basic_block
|
||||
swap_block = function.append_basic_block("sn_swap")
|
||||
merge_block = function.append_basic_block("sn_merge")
|
||||
|
||||
condition = builder.icmp_signed(
|
||||
"<", b, ll.Constant(ll.IntType(64), 0))
|
||||
builder.cbranch(condition, swap_block, merge_block)
|
||||
|
||||
builder.position_at_end(swap_block)
|
||||
minusone = ll.Constant(ll.IntType(64), -1)
|
||||
a_swp = builder.mul(minusone, a)
|
||||
b_swp = builder.mul(minusone, b)
|
||||
builder.branch(merge_block)
|
||||
|
||||
builder.position_at_end(merge_block)
|
||||
a_phi = builder.phi(ll.IntType(64))
|
||||
a_phi.add_incoming(a, orig_block)
|
||||
a_phi.add_incoming(a_swp, swap_block)
|
||||
b_phi = builder.phi(ll.IntType(64))
|
||||
b_phi.add_incoming(b, orig_block)
|
||||
b_phi.add_incoming(b_swp, swap_block)
|
||||
|
||||
return a_phi, b_phi
|
||||
|
||||
|
||||
def _make_ssa(builder, n, d):
|
||||
value = ll.Constant(ll.ArrayType(ll.IntType(64), 2), ll.Undefined)
|
||||
value = builder.insert_value(value, n, 0)
|
||||
value = builder.insert_value(value, d, 1)
|
||||
return value
|
||||
|
||||
|
||||
class VFraction(VGeneric):
|
||||
def get_llvm_type(self):
|
||||
return ll.ArrayType(ll.IntType(64), 2)
|
||||
|
||||
def _nd(self, builder):
|
||||
ssa_value = self.auto_load(builder)
|
||||
a = builder.extract_value(ssa_value, 0)
|
||||
b = builder.extract_value(ssa_value, 1)
|
||||
return a, b
|
||||
|
||||
def set_value_nd(self, builder, a, b):
|
||||
a = a.o_int64(builder).auto_load(builder)
|
||||
b = b.o_int64(builder).auto_load(builder)
|
||||
a, b = _reduce(builder, a, b)
|
||||
a, b = _signnum(builder, a, b)
|
||||
self.auto_store(builder, _make_ssa(builder, a, b))
|
||||
|
||||
def set_value(self, builder, v):
|
||||
if not isinstance(v, VFraction):
|
||||
raise TypeError
|
||||
self.auto_store(builder, v.auto_load(builder))
|
||||
|
||||
def o_getattr(self, attr, builder):
|
||||
if attr == "numerator":
|
||||
idx = 0
|
||||
elif attr == "denominator":
|
||||
idx = 1
|
||||
else:
|
||||
raise AttributeError
|
||||
r = VInt(64)
|
||||
if builder is not None:
|
||||
elt = builder.extract_value(self.auto_load(builder), idx)
|
||||
r.auto_store(builder, elt)
|
||||
return r
|
||||
|
||||
def o_bool(self, builder):
|
||||
r = VBool()
|
||||
if builder is not None:
|
||||
zero = ll.Constant(ll.IntType(64), 0)
|
||||
a = builder.extract_element(self.auto_load(builder), 0)
|
||||
r.auto_store(builder, builder.icmp_signed("!=", a, zero))
|
||||
return r
|
||||
|
||||
def o_intx(self, target_bits, builder):
|
||||
if builder is None:
|
||||
return VInt(target_bits)
|
||||
else:
|
||||
r = VInt(64)
|
||||
a, b = self._nd(builder)
|
||||
r.auto_store(builder, builder.sdiv(a, b))
|
||||
return r.o_intx(target_bits, builder)
|
||||
|
||||
def o_roundx(self, target_bits, builder):
|
||||
if builder is None:
|
||||
return VInt(target_bits)
|
||||
else:
|
||||
r = VInt(64)
|
||||
a, b = self._nd(builder)
|
||||
h_b = builder.ashr(b, ll.Constant(ll.IntType(64), 1))
|
||||
|
||||
function = builder.basic_block.function
|
||||
add_block = function.append_basic_block("fr_add")
|
||||
sub_block = function.append_basic_block("fr_sub")
|
||||
merge_block = function.append_basic_block("fr_merge")
|
||||
|
||||
condition = builder.icmp_signed(
|
||||
"<", a, ll.Constant(ll.IntType(64), 0))
|
||||
builder.cbranch(condition, sub_block, add_block)
|
||||
|
||||
builder.position_at_end(add_block)
|
||||
a_add = builder.add(a, h_b)
|
||||
builder.branch(merge_block)
|
||||
builder.position_at_end(sub_block)
|
||||
a_sub = builder.sub(a, h_b)
|
||||
builder.branch(merge_block)
|
||||
|
||||
builder.position_at_end(merge_block)
|
||||
a = builder.phi(ll.IntType(64))
|
||||
a.add_incoming(a_add, add_block)
|
||||
a.add_incoming(a_sub, sub_block)
|
||||
r.auto_store(builder, builder.sdiv(a, b))
|
||||
return r.o_intx(target_bits, builder)
|
||||
|
||||
def o_float(self, builder):
|
||||
r = VFloat()
|
||||
if builder is not None:
|
||||
a, b = self._nd(builder)
|
||||
af = builder.sitofp(a, r.get_llvm_type())
|
||||
bf = builder.sitofp(b, r.get_llvm_type())
|
||||
r.auto_store(builder, builder.fdiv(af, bf))
|
||||
return r
|
||||
|
||||
def _o_eq_inv(self, other, builder, ne):
|
||||
if not isinstance(other, (VInt, VFraction)):
|
||||
return NotImplemented
|
||||
r = VBool()
|
||||
if builder is not None:
|
||||
if isinstance(other, VInt):
|
||||
other = other.o_int64(builder)
|
||||
a, b = self._nd(builder)
|
||||
ssa_r = builder.and_(
|
||||
builder.icmp_signed("==", a,
|
||||
other.auto_load()),
|
||||
builder.icmp_signed("==", b,
|
||||
ll.Constant(ll.IntType(64), 1)))
|
||||
else:
|
||||
a, b = self._nd(builder)
|
||||
c, d = other._nd(builder)
|
||||
ssa_r = builder.and_(
|
||||
builder.icmp_signed("==", a, c),
|
||||
builder.icmp_signed("==", b, d))
|
||||
if ne:
|
||||
ssa_r = builder.xor(ssa_r,
|
||||
ll.Constant(ll.IntType(1), 1))
|
||||
r.auto_store(builder, ssa_r)
|
||||
return r
|
||||
|
||||
def o_eq(self, other, builder):
|
||||
return self._o_eq_inv(other, builder, False)
|
||||
|
||||
def o_ne(self, other, builder):
|
||||
return self._o_eq_inv(other, builder, True)
|
||||
|
||||
def _o_cmp(self, other, icmp, builder):
|
||||
diff = self.o_sub(other, builder)
|
||||
if diff is NotImplemented:
|
||||
return NotImplemented
|
||||
r = VBool()
|
||||
if builder is not None:
|
||||
diff = diff.auto_load(builder)
|
||||
a = builder.extract_value(diff, 0)
|
||||
zero = ll.Constant(ll.IntType(64), 0)
|
||||
ssa_r = builder.icmp_signed(icmp, a, zero)
|
||||
r.auto_store(builder, ssa_r)
|
||||
return r
|
||||
|
||||
def o_lt(self, other, builder):
|
||||
return self._o_cmp(other, "<", builder)
|
||||
|
||||
def o_le(self, other, builder):
|
||||
return self._o_cmp(other, "<=", builder)
|
||||
|
||||
def o_gt(self, other, builder):
|
||||
return self._o_cmp(other, ">", builder)
|
||||
|
||||
def o_ge(self, other, builder):
|
||||
return self._o_cmp(other, ">=", builder)
|
||||
|
||||
def _o_addsub(self, other, builder, sub, invert=False):
|
||||
if isinstance(other, VFloat):
|
||||
a = self.o_getattr("numerator", builder)
|
||||
b = self.o_getattr("denominator", builder)
|
||||
if sub:
|
||||
if invert:
|
||||
return operators.truediv(
|
||||
operators.sub(operators.mul(other,
|
||||
b,
|
||||
builder),
|
||||
a,
|
||||
builder),
|
||||
b,
|
||||
builder)
|
||||
else:
|
||||
return operators.truediv(
|
||||
operators.sub(a,
|
||||
operators.mul(other,
|
||||
b,
|
||||
builder),
|
||||
builder),
|
||||
b,
|
||||
builder)
|
||||
else:
|
||||
return operators.truediv(
|
||||
operators.add(operators.mul(other,
|
||||
b,
|
||||
builder),
|
||||
a,
|
||||
builder),
|
||||
b,
|
||||
builder)
|
||||
else:
|
||||
if not isinstance(other, (VFraction, VInt)):
|
||||
return NotImplemented
|
||||
r = VFraction()
|
||||
if builder is not None:
|
||||
if isinstance(other, VInt):
|
||||
i = other.o_int64(builder).auto_load(builder)
|
||||
x, rd = self._nd(builder)
|
||||
y = builder.mul(rd, i)
|
||||
else:
|
||||
a, b = self._nd(builder)
|
||||
c, d = other._nd(builder)
|
||||
rd = builder.mul(b, d)
|
||||
x = builder.mul(a, d)
|
||||
y = builder.mul(c, b)
|
||||
if sub:
|
||||
if invert:
|
||||
rn = builder.sub(y, x)
|
||||
else:
|
||||
rn = builder.sub(x, y)
|
||||
else:
|
||||
rn = builder.add(x, y)
|
||||
rn, rd = _reduce(builder, rn, rd) # rd is already > 0
|
||||
r.auto_store(builder, _make_ssa(builder, rn, rd))
|
||||
return r
|
||||
|
||||
def o_add(self, other, builder):
|
||||
return self._o_addsub(other, builder, False)
|
||||
|
||||
def o_sub(self, other, builder):
|
||||
return self._o_addsub(other, builder, True)
|
||||
|
||||
def or_add(self, other, builder):
|
||||
return self._o_addsub(other, builder, False)
|
||||
|
||||
def or_sub(self, other, builder):
|
||||
return self._o_addsub(other, builder, True, True)
|
||||
|
||||
def _o_muldiv(self, other, builder, div, invert=False):
|
||||
if isinstance(other, VFloat):
|
||||
a = self.o_getattr("numerator", builder)
|
||||
b = self.o_getattr("denominator", builder)
|
||||
if invert:
|
||||
a, b = b, a
|
||||
if div:
|
||||
return operators.truediv(a,
|
||||
operators.mul(b, other, builder),
|
||||
builder)
|
||||
else:
|
||||
return operators.truediv(operators.mul(a, other, builder),
|
||||
b,
|
||||
builder)
|
||||
else:
|
||||
if not isinstance(other, (VFraction, VInt)):
|
||||
return NotImplemented
|
||||
r = VFraction()
|
||||
if builder is not None:
|
||||
a, b = self._nd(builder)
|
||||
if invert:
|
||||
a, b = b, a
|
||||
if isinstance(other, VInt):
|
||||
i = other.o_int64(builder).auto_load(builder)
|
||||
if div:
|
||||
b = builder.mul(b, i)
|
||||
else:
|
||||
a = builder.mul(a, i)
|
||||
else:
|
||||
c, d = other._nd(builder)
|
||||
if div:
|
||||
a = builder.mul(a, d)
|
||||
b = builder.mul(b, c)
|
||||
else:
|
||||
a = builder.mul(a, c)
|
||||
b = builder.mul(b, d)
|
||||
if div or invert:
|
||||
a, b = _signnum(builder, a, b)
|
||||
a, b = _reduce(builder, a, b)
|
||||
r.auto_store(builder, _make_ssa(builder, a, b))
|
||||
return r
|
||||
|
||||
def o_mul(self, other, builder):
|
||||
return self._o_muldiv(other, builder, False)
|
||||
|
||||
def o_truediv(self, other, builder):
|
||||
return self._o_muldiv(other, builder, True)
|
||||
|
||||
def or_mul(self, other, builder):
|
||||
return self._o_muldiv(other, builder, False)
|
||||
|
||||
def or_truediv(self, other, builder):
|
||||
# multiply by the inverse
|
||||
return self._o_muldiv(other, builder, False, True)
|
||||
|
||||
def o_floordiv(self, other, builder):
|
||||
r = self.o_truediv(other, builder)
|
||||
if r is NotImplemented:
|
||||
return r
|
||||
else:
|
||||
return r.o_int(builder)
|
||||
|
||||
def or_floordiv(self, other, builder):
|
||||
r = self.or_truediv(other, builder)
|
||||
if r is NotImplemented:
|
||||
return r
|
||||
else:
|
||||
return r.o_int(builder)
|
|
@ -1,169 +0,0 @@
|
|||
import unittest
|
||||
from pythonparser import parse, ast
|
||||
import inspect
|
||||
from fractions import Fraction
|
||||
from ctypes import CFUNCTYPE, c_int, c_int32, c_int64, c_double
|
||||
import struct
|
||||
|
||||
import llvmlite_or1k.binding as llvm
|
||||
|
||||
from artiq.language.core import int64
|
||||
from artiq.py2llvm.infer_types import infer_function_types
|
||||
from artiq.py2llvm import base_types, lists
|
||||
from artiq.py2llvm.module import Module
|
||||
|
||||
def simplify_encode(a, b):
|
||||
f = Fraction(a, b)
|
||||
return f.numerator*1000 + f.denominator
|
||||
|
||||
|
||||
def frac_arith_encode(op, a, b, c, d):
|
||||
if op == 0:
|
||||
f = Fraction(a, b) - Fraction(c, d)
|
||||
elif op == 1:
|
||||
f = Fraction(a, b) + Fraction(c, d)
|
||||
elif op == 2:
|
||||
f = Fraction(a, b) * Fraction(c, d)
|
||||
else:
|
||||
f = Fraction(a, b) / Fraction(c, d)
|
||||
return f.numerator*1000 + f.denominator
|
||||
|
||||
|
||||
def frac_arith_encode_int(op, a, b, x):
|
||||
if op == 0:
|
||||
f = Fraction(a, b) - x
|
||||
elif op == 1:
|
||||
f = Fraction(a, b) + x
|
||||
elif op == 2:
|
||||
f = Fraction(a, b) * x
|
||||
else:
|
||||
f = Fraction(a, b) / x
|
||||
return f.numerator*1000 + f.denominator
|
||||
|
||||
|
||||
def frac_arith_encode_int_rev(op, a, b, x):
|
||||
if op == 0:
|
||||
f = x - Fraction(a, b)
|
||||
elif op == 1:
|
||||
f = x + Fraction(a, b)
|
||||
elif op == 2:
|
||||
f = x * Fraction(a, b)
|
||||
else:
|
||||
f = x / Fraction(a, b)
|
||||
return f.numerator*1000 + f.denominator
|
||||
|
||||
|
||||
def frac_arith_float(op, a, b, x):
|
||||
if op == 0:
|
||||
return Fraction(a, b) - x
|
||||
elif op == 1:
|
||||
return Fraction(a, b) + x
|
||||
elif op == 2:
|
||||
return Fraction(a, b) * x
|
||||
else:
|
||||
return Fraction(a, b) / x
|
||||
|
||||
|
||||
def frac_arith_float_rev(op, a, b, x):
|
||||
if op == 0:
|
||||
return x - Fraction(a, b)
|
||||
elif op == 1:
|
||||
return x + Fraction(a, b)
|
||||
elif op == 2:
|
||||
return x * Fraction(a, b)
|
||||
else:
|
||||
return x / Fraction(a, b)
|
||||
|
||||
|
||||
class CodeGenCase(unittest.TestCase):
|
||||
def test_frac_simplify(self):
|
||||
simplify_encode_c = CompiledFunction(
|
||||
simplify_encode, {"a": base_types.VInt(), "b": base_types.VInt()})
|
||||
for a in _test_range():
|
||||
for b in _test_range():
|
||||
self.assertEqual(
|
||||
simplify_encode_c(a, b), simplify_encode(a, b))
|
||||
|
||||
def _test_frac_arith(self, op):
|
||||
frac_arith_encode_c = CompiledFunction(
|
||||
frac_arith_encode, {
|
||||
"op": base_types.VInt(),
|
||||
"a": base_types.VInt(), "b": base_types.VInt(),
|
||||
"c": base_types.VInt(), "d": base_types.VInt()})
|
||||
for a in _test_range():
|
||||
for b in _test_range():
|
||||
for c in _test_range():
|
||||
for d in _test_range():
|
||||
self.assertEqual(
|
||||
frac_arith_encode_c(op, a, b, c, d),
|
||||
frac_arith_encode(op, a, b, c, d))
|
||||
|
||||
def test_frac_add(self):
|
||||
self._test_frac_arith(0)
|
||||
|
||||
def test_frac_sub(self):
|
||||
self._test_frac_arith(1)
|
||||
|
||||
def test_frac_mul(self):
|
||||
self._test_frac_arith(2)
|
||||
|
||||
def test_frac_div(self):
|
||||
self._test_frac_arith(3)
|
||||
|
||||
def _test_frac_arith_int(self, op, rev):
|
||||
f = frac_arith_encode_int_rev if rev else frac_arith_encode_int
|
||||
f_c = CompiledFunction(f, {
|
||||
"op": base_types.VInt(),
|
||||
"a": base_types.VInt(), "b": base_types.VInt(),
|
||||
"x": base_types.VInt()})
|
||||
for a in _test_range():
|
||||
for b in _test_range():
|
||||
for x in _test_range():
|
||||
self.assertEqual(
|
||||
f_c(op, a, b, x),
|
||||
f(op, a, b, x))
|
||||
|
||||
def test_frac_add_int(self):
|
||||
self._test_frac_arith_int(0, False)
|
||||
self._test_frac_arith_int(0, True)
|
||||
|
||||
def test_frac_sub_int(self):
|
||||
self._test_frac_arith_int(1, False)
|
||||
self._test_frac_arith_int(1, True)
|
||||
|
||||
def test_frac_mul_int(self):
|
||||
self._test_frac_arith_int(2, False)
|
||||
self._test_frac_arith_int(2, True)
|
||||
|
||||
def test_frac_div_int(self):
|
||||
self._test_frac_arith_int(3, False)
|
||||
self._test_frac_arith_int(3, True)
|
||||
|
||||
def _test_frac_arith_float(self, op, rev):
|
||||
f = frac_arith_float_rev if rev else frac_arith_float
|
||||
f_c = CompiledFunction(f, {
|
||||
"op": base_types.VInt(),
|
||||
"a": base_types.VInt(), "b": base_types.VInt(),
|
||||
"x": base_types.VFloat()})
|
||||
for a in _test_range():
|
||||
for b in _test_range():
|
||||
for x in _test_range():
|
||||
self.assertAlmostEqual(
|
||||
f_c(op, a, b, x/2),
|
||||
f(op, a, b, x/2))
|
||||
|
||||
def test_frac_add_float(self):
|
||||
self._test_frac_arith_float(0, False)
|
||||
self._test_frac_arith_float(0, True)
|
||||
|
||||
def test_frac_sub_float(self):
|
||||
self._test_frac_arith_float(1, False)
|
||||
self._test_frac_arith_float(1, True)
|
||||
|
||||
def test_frac_mul_float(self):
|
||||
self._test_frac_arith_float(2, False)
|
||||
self._test_frac_arith_float(2, True)
|
||||
|
||||
def test_frac_div_float(self):
|
||||
self._test_frac_arith_float(3, False)
|
||||
self._test_frac_arith_float(3, True)
|
|
@ -1,548 +0,0 @@
|
|||
import inspect
|
||||
import textwrap
|
||||
import ast
|
||||
import types
|
||||
import builtins
|
||||
from fractions import Fraction
|
||||
from collections import OrderedDict
|
||||
from functools import partial
|
||||
from itertools import zip_longest, chain
|
||||
|
||||
from artiq.language import core as core_language
|
||||
from artiq.language import units
|
||||
from artiq.transforms.tools import *
|
||||
|
||||
|
||||
def new_mangled_name(in_use_names, name):
|
||||
mangled_name = name
|
||||
i = 2
|
||||
while mangled_name in in_use_names:
|
||||
mangled_name = name + str(i)
|
||||
i += 1
|
||||
in_use_names.add(mangled_name)
|
||||
return mangled_name
|
||||
|
||||
|
||||
class MangledName:
|
||||
def __init__(self, s):
|
||||
self.s = s
|
||||
|
||||
|
||||
class AttributeInfo:
|
||||
def __init__(self, obj, mangled_name, read_write):
|
||||
self.obj = obj
|
||||
self.mangled_name = mangled_name
|
||||
self.read_write = read_write
|
||||
|
||||
|
||||
def is_inlinable(core, func):
|
||||
if hasattr(func, "k_function_info"):
|
||||
if func.k_function_info.core_name == "":
|
||||
return True # portable function
|
||||
if getattr(func.__self__, func.k_function_info.core_name) is core:
|
||||
return True # kernel function for the same core device
|
||||
return False
|
||||
|
||||
|
||||
class GlobalNamespace:
|
||||
def __init__(self, func):
|
||||
self.func_gd = inspect.getmodule(func).__dict__
|
||||
|
||||
def __getitem__(self, item):
|
||||
try:
|
||||
return self.func_gd[item]
|
||||
except KeyError:
|
||||
return getattr(builtins, item)
|
||||
|
||||
|
||||
class UndefinedArg:
|
||||
pass
|
||||
|
||||
|
||||
def get_function_args(func_args, func_tr, args, kwargs):
|
||||
# OrderedDict prevents non-determinism in argument init
|
||||
r = OrderedDict()
|
||||
|
||||
# Process positional arguments. Any missing positional argument values
|
||||
# are set to UndefinedArg.
|
||||
for arg, arg_value in zip_longest(func_args.args, args,
|
||||
fillvalue=UndefinedArg):
|
||||
if arg is UndefinedArg:
|
||||
raise TypeError("Got too many positional arguments")
|
||||
if arg.arg in r:
|
||||
raise SyntaxError("Duplicate argument '{}' in function definition"
|
||||
.format(arg.arg))
|
||||
r[arg.arg] = arg_value
|
||||
|
||||
# Process keyword arguments. Any missing keyword-only argument values
|
||||
# are set to UndefinedArg.
|
||||
valid_arg_names = {arg.arg for arg in
|
||||
chain(func_args.args, func_args.kwonlyargs)}
|
||||
for arg in func_args.kwonlyargs:
|
||||
if arg.arg in r:
|
||||
raise SyntaxError("Duplicate argument '{}' in function definition"
|
||||
.format(arg.arg))
|
||||
r[arg.arg] = UndefinedArg
|
||||
for arg_name, arg_value in kwargs.items():
|
||||
if arg_name not in valid_arg_names:
|
||||
raise TypeError("Got unexpected keyword argument '{}'"
|
||||
.format(arg_name))
|
||||
if r[arg_name] is not UndefinedArg:
|
||||
raise TypeError("Got multiple values for argument '{}'"
|
||||
.format(arg_name))
|
||||
r[arg_name] = arg_value
|
||||
|
||||
# Replace any UndefinedArg positional arguments with the default value,
|
||||
# when provided.
|
||||
for arg, default in zip(func_args.args[-len(func_args.defaults):],
|
||||
func_args.defaults):
|
||||
if r[arg.arg] is UndefinedArg:
|
||||
r[arg.arg] = func_tr.code_visit(default)
|
||||
# Same with keyword-only arguments.
|
||||
for arg, default in zip(func_args.kwonlyargs, func_args.kw_defaults):
|
||||
if default is not None and r[arg.arg] is UndefinedArg:
|
||||
r[arg.arg] = func_tr.code_visit(default)
|
||||
|
||||
# Check that no argument was left undefined.
|
||||
missing_arguments = ["'"+arg+"'" for arg, value in r.items()
|
||||
if value is UndefinedArg]
|
||||
if missing_arguments:
|
||||
raise TypeError("Missing argument(s): " + " ".join(missing_arguments))
|
||||
|
||||
return r
|
||||
|
||||
|
||||
# args/kwargs can contain values or AST nodes
|
||||
def get_inline(core, attribute_namespace, in_use_names, retval_name, mappers,
|
||||
func, args, kwargs):
|
||||
global_namespace = GlobalNamespace(func)
|
||||
func_tr = Function(core,
|
||||
global_namespace, attribute_namespace, in_use_names,
|
||||
retval_name, mappers)
|
||||
func_def = ast.parse(textwrap.dedent(inspect.getsource(func))).body[0]
|
||||
|
||||
# Initialize arguments.
|
||||
# The local namespace is empty so code_visit will always resolve
|
||||
# using the global namespace.
|
||||
arg_init = []
|
||||
arg_name_map = []
|
||||
arg_dict = get_function_args(func_def.args, func_tr, args, kwargs)
|
||||
for arg_name, arg_value in arg_dict.items():
|
||||
if isinstance(arg_value, ast.AST):
|
||||
value = arg_value
|
||||
else:
|
||||
try:
|
||||
value = ast.copy_location(value_to_ast(arg_value), func_def)
|
||||
except NotASTRepresentable:
|
||||
value = None
|
||||
if value is None:
|
||||
# static object
|
||||
func_tr.local_namespace[arg_name] = arg_value
|
||||
else:
|
||||
# set parameter value with "name = value"
|
||||
# assignment at beginning of function
|
||||
new_name = new_mangled_name(in_use_names, arg_name)
|
||||
arg_name_map.append((arg_name, new_name))
|
||||
target = ast.copy_location(ast.Name(new_name, ast.Store()),
|
||||
func_def)
|
||||
assign = ast.copy_location(ast.Assign([target], value),
|
||||
func_def)
|
||||
arg_init.append(assign)
|
||||
# Commit arguments to the local namespace at the end to handle cases
|
||||
# such as f(x, y=x) (for the default value of y, x must be resolved
|
||||
# using the global namespace).
|
||||
for arg_name, mangled_name in arg_name_map:
|
||||
func_tr.local_namespace[arg_name] = MangledName(mangled_name)
|
||||
|
||||
func_def = func_tr.code_visit(func_def)
|
||||
func_def.body[0:0] = arg_init
|
||||
return func_def
|
||||
|
||||
|
||||
class Function:
|
||||
def __init__(self, core,
|
||||
global_namespace, attribute_namespace, in_use_names,
|
||||
retval_name, mappers):
|
||||
# The core device on which this function is executing.
|
||||
self.core = core
|
||||
|
||||
# Local and global namespaces:
|
||||
# original name -> MangledName or static object
|
||||
self.local_namespace = dict()
|
||||
self.global_namespace = global_namespace
|
||||
|
||||
# (id(static object), attribute) -> AttributeInfo
|
||||
self.attribute_namespace = attribute_namespace
|
||||
|
||||
# All names currently in use, in the namespace of the combined
|
||||
# function.
|
||||
# When creating a name for a new object, check that it is not
|
||||
# already in this set.
|
||||
self.in_use_names = in_use_names
|
||||
|
||||
# Name of the variable to store the return value to, or None
|
||||
# to keep the return statement.
|
||||
self.retval_name = retval_name
|
||||
|
||||
# Host object mappers, for RPC and exception numbers
|
||||
self.mappers = mappers
|
||||
|
||||
self._insertion_point = None
|
||||
|
||||
# This is ast.NodeVisitor/NodeTransformer from CPython, modified
|
||||
# to add code_ prefix.
|
||||
def code_visit(self, node):
|
||||
method = "code_visit_" + node.__class__.__name__
|
||||
visitor = getattr(self, method, self.code_generic_visit)
|
||||
return visitor(node)
|
||||
|
||||
# This is ast.NodeTransformer.generic_visit from CPython, modified
|
||||
# to update self._insertion_point.
|
||||
def code_generic_visit(self, node, exclude_fields=set()):
|
||||
for field, old_value in ast.iter_fields(node):
|
||||
if field in exclude_fields:
|
||||
continue
|
||||
old_value = getattr(node, field, None)
|
||||
if isinstance(old_value, list):
|
||||
prev_insertion_point = self._insertion_point
|
||||
new_values = []
|
||||
if field in ("body", "orelse", "finalbody"):
|
||||
self._insertion_point = new_values
|
||||
for value in old_value:
|
||||
if isinstance(value, ast.AST):
|
||||
value = self.code_visit(value)
|
||||
if value is None:
|
||||
continue
|
||||
elif not isinstance(value, ast.AST):
|
||||
new_values.extend(value)
|
||||
continue
|
||||
new_values.append(value)
|
||||
old_value[:] = new_values
|
||||
self._insertion_point = prev_insertion_point
|
||||
elif isinstance(old_value, ast.AST):
|
||||
new_node = self.code_visit(old_value)
|
||||
if new_node is None:
|
||||
delattr(node, field)
|
||||
else:
|
||||
setattr(node, field, new_node)
|
||||
return node
|
||||
|
||||
def code_visit_Name(self, node):
|
||||
if isinstance(node.ctx, ast.Store):
|
||||
if (node.id in self.local_namespace
|
||||
and isinstance(self.local_namespace[node.id],
|
||||
MangledName)):
|
||||
new_name = self.local_namespace[node.id].s
|
||||
else:
|
||||
new_name = new_mangled_name(self.in_use_names, node.id)
|
||||
self.local_namespace[node.id] = MangledName(new_name)
|
||||
node.id = new_name
|
||||
return node
|
||||
else:
|
||||
try:
|
||||
obj = self.local_namespace[node.id]
|
||||
except KeyError:
|
||||
try:
|
||||
obj = self.global_namespace[node.id]
|
||||
except KeyError:
|
||||
raise NameError("name '{}' is not defined".format(node.id))
|
||||
if isinstance(obj, MangledName):
|
||||
node.id = obj.s
|
||||
return node
|
||||
else:
|
||||
try:
|
||||
return value_to_ast(obj)
|
||||
except NotASTRepresentable:
|
||||
raise NotImplementedError(
|
||||
"Static object cannot be used here")
|
||||
|
||||
def code_visit_Attribute(self, node):
|
||||
# There are two cases of attributes:
|
||||
# 1. static object attributes, e.g. self.foo
|
||||
# 2. dynamic expression attributes, e.g.
|
||||
# (Fraction(1, 2) + x).numerator
|
||||
# Static object resolution has no side effects so we try it first.
|
||||
try:
|
||||
obj = self.static_visit(node.value)
|
||||
except:
|
||||
self.code_generic_visit(node)
|
||||
return node
|
||||
else:
|
||||
key = (id(obj), node.attr)
|
||||
try:
|
||||
attr_info = self.attribute_namespace[key]
|
||||
except KeyError:
|
||||
new_name = new_mangled_name(self.in_use_names, node.attr)
|
||||
attr_info = AttributeInfo(obj, new_name, False)
|
||||
self.attribute_namespace[key] = attr_info
|
||||
if isinstance(node.ctx, ast.Store):
|
||||
attr_info.read_write = True
|
||||
return ast.copy_location(
|
||||
ast.Name(attr_info.mangled_name, node.ctx),
|
||||
node)
|
||||
|
||||
def code_visit_Call(self, node):
|
||||
func = self.static_visit(node.func)
|
||||
node.args = [self.code_visit(arg) for arg in node.args]
|
||||
for kw in node.keywords:
|
||||
kw.value = self.code_visit(kw.value)
|
||||
|
||||
if is_embeddable(func):
|
||||
node.func = ast.copy_location(
|
||||
ast.Name(func.__name__, ast.Load()),
|
||||
node)
|
||||
return node
|
||||
elif is_inlinable(self.core, func):
|
||||
retval_name = func.k_function_info.k_function.__name__ + "_return"
|
||||
retval_name_m = new_mangled_name(self.in_use_names, retval_name)
|
||||
args = [func.__self__] + node.args
|
||||
kwargs = {kw.arg: kw.value for kw in node.keywords}
|
||||
inlined = get_inline(self.core,
|
||||
self.attribute_namespace, self.in_use_names,
|
||||
retval_name_m, self.mappers,
|
||||
func.k_function_info.k_function,
|
||||
args, kwargs)
|
||||
seq = ast.copy_location(
|
||||
ast.With(
|
||||
items=[ast.withitem(context_expr=ast.Name(id="sequential",
|
||||
ctx=ast.Load()),
|
||||
optional_vars=None)],
|
||||
body=inlined.body),
|
||||
node)
|
||||
self._insertion_point.append(seq)
|
||||
return ast.copy_location(ast.Name(retval_name_m, ast.Load()),
|
||||
node)
|
||||
else:
|
||||
arg1 = ast.copy_location(ast.Str("rpc"), node)
|
||||
arg2 = ast.copy_location(
|
||||
value_to_ast(self.mappers.rpc.encode(func)), node)
|
||||
node.args[0:0] = [arg1, arg2]
|
||||
node.func = ast.copy_location(
|
||||
ast.Name("syscall", ast.Load()), node)
|
||||
return node
|
||||
|
||||
def code_visit_Return(self, node):
|
||||
self.code_generic_visit(node)
|
||||
if self.retval_name is None:
|
||||
return node
|
||||
else:
|
||||
return ast.copy_location(
|
||||
ast.Assign(targets=[ast.Name(self.retval_name, ast.Store())],
|
||||
value=node.value),
|
||||
node)
|
||||
|
||||
def code_visit_Expr(self, node):
|
||||
if isinstance(node.value, ast.Str):
|
||||
# Strip docstrings. This also removes strings appearing in the
|
||||
# middle of the code, but they are nops.
|
||||
return None
|
||||
self.code_generic_visit(node)
|
||||
if isinstance(node.value, ast.Name):
|
||||
# Remove Expr nodes that contain only a name, likely due to
|
||||
# function call inlining. Such nodes that were originally in the
|
||||
# code are also removed, but this does not affect the semantics of
|
||||
# the code as they are nops.
|
||||
return None
|
||||
else:
|
||||
return node
|
||||
|
||||
def encode_exception(self, e):
|
||||
exception_class = self.static_visit(e)
|
||||
if not inspect.isclass(exception_class):
|
||||
raise NotImplementedError("Exception type must be a class")
|
||||
if issubclass(exception_class, core_language.RuntimeException):
|
||||
exception_id = exception_class.eid
|
||||
else:
|
||||
exception_id = self.mappers.exception.encode(exception_class)
|
||||
return ast.copy_location(
|
||||
ast.Call(func=ast.Name("EncodedException", ast.Load()),
|
||||
args=[value_to_ast(exception_id)], keywords=[]),
|
||||
e)
|
||||
|
||||
def code_visit_Raise(self, node):
|
||||
if node.cause is not None:
|
||||
raise NotImplementedError("Exception causes are not supported")
|
||||
if node.exc is not None:
|
||||
node.exc = self.encode_exception(node.exc)
|
||||
return node
|
||||
|
||||
def code_visit_ExceptHandler(self, node):
|
||||
if node.name is not None:
|
||||
raise NotImplementedError("'as target' is not supported")
|
||||
if node.type is not None:
|
||||
if isinstance(node.type, ast.Tuple):
|
||||
node.type.elts = [self.encode_exception(e)
|
||||
for e in node.type.elts]
|
||||
else:
|
||||
node.type = self.encode_exception(node.type)
|
||||
self.code_generic_visit(node)
|
||||
return node
|
||||
|
||||
def get_user_ctxm(self, context_expr):
|
||||
try:
|
||||
ctxm = self.static_visit(context_expr)
|
||||
except:
|
||||
# this also catches watchdog()
|
||||
return None
|
||||
else:
|
||||
if (ctxm is core_language.sequential
|
||||
or ctxm is core_language.parallel):
|
||||
return None
|
||||
return ctxm
|
||||
|
||||
def code_visit_With(self, node):
|
||||
if len(node.items) != 1:
|
||||
raise NotImplementedError
|
||||
item = node.items[0]
|
||||
if item.optional_vars is not None:
|
||||
raise NotImplementedError
|
||||
ctxm = self.get_user_ctxm(item.context_expr)
|
||||
if ctxm is None:
|
||||
self.code_generic_visit(node)
|
||||
return node
|
||||
|
||||
# user context manager
|
||||
self.code_generic_visit(node, {"items"})
|
||||
if (not hasattr(ctxm, "__enter__")
|
||||
or not hasattr(ctxm.__enter__, "k_function_info")):
|
||||
raise NotImplementedError
|
||||
enter = get_inline(self.core,
|
||||
self.attribute_namespace, self.in_use_names,
|
||||
None, self.mappers,
|
||||
ctxm.__enter__.k_function_info.k_function,
|
||||
[ctxm], dict())
|
||||
if (not hasattr(ctxm, "__exit__")
|
||||
or not hasattr(ctxm.__exit__, "k_function_info")):
|
||||
raise NotImplementedError
|
||||
exit = get_inline(self.core,
|
||||
self.attribute_namespace, self.in_use_names,
|
||||
None, self.mappers,
|
||||
ctxm.__exit__.k_function_info.k_function,
|
||||
[ctxm, None, None, None], dict())
|
||||
try_stmt = ast.copy_location(
|
||||
ast.Try(body=node.body,
|
||||
handlers=[],
|
||||
orelse=[],
|
||||
finalbody=exit.body), node)
|
||||
return ast.copy_location(
|
||||
ast.With(
|
||||
items=[ast.withitem(context_expr=ast.Name(id="sequential",
|
||||
ctx=ast.Load()),
|
||||
optional_vars=None)],
|
||||
body=enter.body + [try_stmt]),
|
||||
node)
|
||||
|
||||
def code_visit_FunctionDef(self, node):
|
||||
node.args = ast.arguments(args=[], vararg=None, kwonlyargs=[],
|
||||
kw_defaults=[], kwarg=None, defaults=[])
|
||||
node.decorator_list = []
|
||||
self.code_generic_visit(node)
|
||||
return node
|
||||
|
||||
def static_visit(self, node):
|
||||
method = "static_visit_" + node.__class__.__name__
|
||||
visitor = getattr(self, method)
|
||||
return visitor(node)
|
||||
|
||||
def static_visit_Name(self, node):
|
||||
try:
|
||||
obj = self.local_namespace[node.id]
|
||||
except KeyError:
|
||||
try:
|
||||
obj = self.global_namespace[node.id]
|
||||
except KeyError:
|
||||
raise NameError("name '{}' is not defined".format(node.id))
|
||||
if isinstance(obj, MangledName):
|
||||
raise NotImplementedError(
|
||||
"Only a static object can be used here")
|
||||
return obj
|
||||
|
||||
def static_visit_Attribute(self, node):
|
||||
value = self.static_visit(node.value)
|
||||
return getattr(value, node.attr)
|
||||
|
||||
|
||||
class HostObjectMapper:
|
||||
def __init__(self, first_encoding=0):
|
||||
self._next_encoding = first_encoding
|
||||
# id(object) -> (encoding, object)
|
||||
# this format is required to support non-hashable host objects.
|
||||
self._d = dict()
|
||||
|
||||
def encode(self, obj):
|
||||
try:
|
||||
return self._d[id(obj)][0]
|
||||
except KeyError:
|
||||
encoding = self._next_encoding
|
||||
self._d[id(obj)] = (encoding, obj)
|
||||
self._next_encoding += 1
|
||||
return encoding
|
||||
|
||||
def get_map(self):
|
||||
return {encoding: obj for i, (encoding, obj) in self._d.items()}
|
||||
|
||||
|
||||
def get_attr_init(attribute_namespace, loc_node):
|
||||
attr_init = []
|
||||
for (_, attr), attr_info in attribute_namespace.items():
|
||||
if hasattr(attr_info.obj, attr):
|
||||
value = getattr(attr_info.obj, attr)
|
||||
if (hasattr(value, "kernel_attr_init")
|
||||
and not value.kernel_attr_init):
|
||||
continue
|
||||
value = ast.copy_location(value_to_ast(value), loc_node)
|
||||
target = ast.copy_location(ast.Name(attr_info.mangled_name,
|
||||
ast.Store()),
|
||||
loc_node)
|
||||
assign = ast.copy_location(ast.Assign([target], value),
|
||||
loc_node)
|
||||
attr_init.append(assign)
|
||||
return attr_init
|
||||
|
||||
|
||||
def get_attr_writeback(attribute_namespace, rpc_mapper, loc_node):
|
||||
attr_writeback = []
|
||||
for (_, attr), attr_info in attribute_namespace.items():
|
||||
if attr_info.read_write:
|
||||
setter = partial(setattr, attr_info.obj, attr)
|
||||
func = ast.copy_location(
|
||||
ast.Name("syscall", ast.Load()), loc_node)
|
||||
arg1 = ast.copy_location(ast.Str("rpc"), loc_node)
|
||||
arg2 = ast.copy_location(
|
||||
value_to_ast(rpc_mapper.encode(setter)), loc_node)
|
||||
arg3 = ast.copy_location(
|
||||
ast.Name(attr_info.mangled_name, ast.Load()), loc_node)
|
||||
call = ast.copy_location(
|
||||
ast.Call(func=func, args=[arg1, arg2, arg3], keywords=[]),
|
||||
loc_node)
|
||||
expr = ast.copy_location(ast.Expr(call), loc_node)
|
||||
attr_writeback.append(expr)
|
||||
return attr_writeback
|
||||
|
||||
|
||||
def inline(core, k_function, k_args, k_kwargs, with_attr_writeback):
|
||||
# OrderedDict prevents non-determinism in attribute init
|
||||
attribute_namespace = OrderedDict()
|
||||
# NOTE: in_use_names will be mutated. Do not mutate embeddable_func_names!
|
||||
in_use_names = embeddable_func_names | {"sequential", "parallel",
|
||||
"watchdog"}
|
||||
mappers = types.SimpleNamespace(
|
||||
rpc=HostObjectMapper(),
|
||||
exception=HostObjectMapper(core_language.first_user_eid)
|
||||
)
|
||||
func_def = get_inline(
|
||||
core=core,
|
||||
attribute_namespace=attribute_namespace,
|
||||
in_use_names=in_use_names,
|
||||
retval_name=None,
|
||||
mappers=mappers,
|
||||
func=k_function,
|
||||
args=k_args,
|
||||
kwargs=k_kwargs)
|
||||
|
||||
func_def.body[0:0] = get_attr_init(attribute_namespace, func_def)
|
||||
if with_attr_writeback:
|
||||
func_def.body += get_attr_writeback(attribute_namespace, mappers.rpc,
|
||||
func_def)
|
||||
|
||||
return func_def, mappers.rpc.get_map(), mappers.exception.get_map()
|
|
@ -1,130 +0,0 @@
|
|||
import ast
|
||||
import types
|
||||
|
||||
from artiq.transforms.tools import *
|
||||
|
||||
|
||||
# -1 statement duration could not be pre-determined
|
||||
# 0 statement has no effect on timeline
|
||||
# >0 statement is a static delay that advances the timeline
|
||||
# by the given amount
|
||||
def _get_duration(stmt):
|
||||
if isinstance(stmt, (ast.Expr, ast.Assign)):
|
||||
return _get_duration(stmt.value)
|
||||
elif isinstance(stmt, ast.If):
|
||||
if (all(_get_duration(s) == 0 for s in stmt.body)
|
||||
and all(_get_duration(s) == 0 for s in stmt.orelse)):
|
||||
return 0
|
||||
else:
|
||||
return -1
|
||||
elif isinstance(stmt, ast.Try):
|
||||
if (all(_get_duration(s) == 0 for s in stmt.body)
|
||||
and all(_get_duration(s) == 0 for s in stmt.orelse)
|
||||
and all(_get_duration(s) == 0 for s in stmt.finalbody)
|
||||
and all(_get_duration(s) == 0 for s in handler.body
|
||||
for handler in stmt.handlers)):
|
||||
return 0
|
||||
else:
|
||||
return -1
|
||||
elif isinstance(stmt, ast.Call):
|
||||
name = stmt.func.id
|
||||
assert(name != "delay")
|
||||
if name == "delay_mu":
|
||||
try:
|
||||
da = eval_constant(stmt.args[0])
|
||||
except NotConstant:
|
||||
da = -1
|
||||
return da
|
||||
elif name == "at_mu":
|
||||
return -1
|
||||
else:
|
||||
return 0
|
||||
else:
|
||||
return 0
|
||||
|
||||
|
||||
def _interleave_timelines(timelines):
|
||||
r = []
|
||||
|
||||
current_stmts = []
|
||||
for stmts in timelines:
|
||||
it = iter(stmts)
|
||||
try:
|
||||
stmt = next(it)
|
||||
except StopIteration:
|
||||
pass
|
||||
else:
|
||||
current_stmts.append(types.SimpleNamespace(
|
||||
delay=_get_duration(stmt), stmt=stmt, it=it))
|
||||
|
||||
while current_stmts:
|
||||
dt = min(stmt.delay for stmt in current_stmts)
|
||||
if dt < 0:
|
||||
# contains statement(s) with indeterminate duration
|
||||
return None
|
||||
if dt > 0:
|
||||
# advance timeline by dt
|
||||
for stmt in current_stmts:
|
||||
stmt.delay -= dt
|
||||
if stmt.delay == 0:
|
||||
ref_stmt = stmt.stmt
|
||||
delay_stmt = ast.copy_location(
|
||||
ast.Expr(ast.Call(
|
||||
func=ast.Name("delay_mu", ast.Load()),
|
||||
args=[value_to_ast(dt)], keywords=[])),
|
||||
ref_stmt)
|
||||
r.append(delay_stmt)
|
||||
else:
|
||||
for stmt in current_stmts:
|
||||
if stmt.delay == 0:
|
||||
r.append(stmt.stmt)
|
||||
# discard executed statements
|
||||
exhausted_list = []
|
||||
for stmt_i, stmt in enumerate(current_stmts):
|
||||
if stmt.delay == 0:
|
||||
try:
|
||||
stmt.stmt = next(stmt.it)
|
||||
except StopIteration:
|
||||
exhausted_list.append(stmt_i)
|
||||
else:
|
||||
stmt.delay = _get_duration(stmt.stmt)
|
||||
for offset, i in enumerate(exhausted_list):
|
||||
current_stmts.pop(i-offset)
|
||||
|
||||
return r
|
||||
|
||||
|
||||
def _interleave_stmts(stmts):
|
||||
replacements = []
|
||||
for stmt_i, stmt in enumerate(stmts):
|
||||
if isinstance(stmt, (ast.For, ast.While, ast.If)):
|
||||
_interleave_stmts(stmt.body)
|
||||
_interleave_stmts(stmt.orelse)
|
||||
elif isinstance(stmt, ast.Try):
|
||||
_interleave_stmts(stmt.body)
|
||||
_interleave_stmts(stmt.orelse)
|
||||
_interleave_stmts(stmt.finalbody)
|
||||
for handler in stmt.handlers:
|
||||
_interleave_stmts(handler.body)
|
||||
elif isinstance(stmt, ast.With):
|
||||
btype = stmt.items[0].context_expr.id
|
||||
if btype == "sequential":
|
||||
_interleave_stmts(stmt.body)
|
||||
replacements.append((stmt_i, stmt.body))
|
||||
elif btype == "parallel":
|
||||
timelines = [[s] for s in stmt.body]
|
||||
for timeline in timelines:
|
||||
_interleave_stmts(timeline)
|
||||
merged = _interleave_timelines(timelines)
|
||||
if merged is not None:
|
||||
replacements.append((stmt_i, merged))
|
||||
else:
|
||||
raise ValueError("Unknown block type: " + btype)
|
||||
offset = 0
|
||||
for location, new_stmts in replacements:
|
||||
stmts[offset+location:offset+location+1] = new_stmts
|
||||
offset += len(new_stmts) - 1
|
||||
|
||||
|
||||
def interleave(func_def):
|
||||
_interleave_stmts(func_def.body)
|
|
@ -1,43 +0,0 @@
|
|||
def visit_With(self, node):
|
||||
self.generic_visit(node)
|
||||
if (isinstance(node.items[0].context_expr, ast.Call)
|
||||
and node.items[0].context_expr.func.id == "watchdog"):
|
||||
|
||||
idname = "__watchdog_id_" + str(self.watchdog_id_counter)
|
||||
self.watchdog_id_counter += 1
|
||||
|
||||
time = ast.BinOp(left=node.items[0].context_expr.args[0],
|
||||
op=ast.Mult(),
|
||||
right=ast.Num(1000))
|
||||
time_int = ast.Call(
|
||||
func=ast.Name("round", ast.Load()),
|
||||
args=[time],
|
||||
keywords=[], starargs=None, kwargs=None)
|
||||
syscall_set = ast.Call(
|
||||
func=ast.Name("syscall", ast.Load()),
|
||||
args=[ast.Str("watchdog_set"), time_int],
|
||||
keywords=[], starargs=None, kwargs=None)
|
||||
stmt_set = ast.copy_location(
|
||||
ast.Assign(targets=[ast.Name(idname, ast.Store())],
|
||||
value=syscall_set),
|
||||
node)
|
||||
|
||||
syscall_clear = ast.Call(
|
||||
func=ast.Name("syscall", ast.Load()),
|
||||
args=[ast.Str("watchdog_clear"),
|
||||
ast.Name(idname, ast.Load())],
|
||||
keywords=[], starargs=None, kwargs=None)
|
||||
stmt_clear = ast.copy_location(ast.Expr(syscall_clear), node)
|
||||
|
||||
node.items[0] = ast.withitem(
|
||||
context_expr=ast.Name(id="sequential",
|
||||
ctx=ast.Load()),
|
||||
optional_vars=None)
|
||||
node.body = [
|
||||
stmt_set,
|
||||
ast.Try(body=node.body,
|
||||
handlers=[],
|
||||
orelse=[],
|
||||
finalbody=[stmt_clear])
|
||||
]
|
||||
return node
|
|
@ -1,82 +0,0 @@
|
|||
import ast
|
||||
from copy import deepcopy
|
||||
|
||||
from artiq.transforms.tools import eval_ast, value_to_ast
|
||||
|
||||
|
||||
def _count_stmts(node):
|
||||
if isinstance(node, list):
|
||||
return sum(map(_count_stmts, node))
|
||||
elif isinstance(node, ast.With):
|
||||
return 1 + _count_stmts(node.body)
|
||||
elif isinstance(node, (ast.For, ast.While, ast.If)):
|
||||
return 1 + _count_stmts(node.body) + _count_stmts(node.orelse)
|
||||
elif isinstance(node, ast.Try):
|
||||
r = 1 + _count_stmts(node.body) \
|
||||
+ _count_stmts(node.orelse) \
|
||||
+ _count_stmts(node.finalbody)
|
||||
for handler in node.handlers:
|
||||
r += 1 + _count_stmts(handler.body)
|
||||
return r
|
||||
else:
|
||||
return 1
|
||||
|
||||
|
||||
def _loop_breakable(node):
|
||||
if isinstance(node, list):
|
||||
return any(map(_loop_breakable, node))
|
||||
elif isinstance(node, (ast.Break, ast.Continue)):
|
||||
return True
|
||||
elif isinstance(node, ast.With):
|
||||
return _loop_breakable(node.body)
|
||||
elif isinstance(node, ast.If):
|
||||
return _loop_breakable(node.body) or _loop_breakable(node.orelse)
|
||||
elif isinstance(node, ast.Try):
|
||||
if (_loop_breakable(node.body)
|
||||
or _loop_breakable(node.orelse)
|
||||
or _loop_breakable(node.finalbody)):
|
||||
return True
|
||||
for handler in node.handlers:
|
||||
if _loop_breakable(handler.body):
|
||||
return True
|
||||
return False
|
||||
else:
|
||||
return False
|
||||
|
||||
|
||||
class _LoopUnroller(ast.NodeTransformer):
|
||||
def __init__(self, limit):
|
||||
self.limit = limit
|
||||
|
||||
def visit_For(self, node):
|
||||
self.generic_visit(node)
|
||||
try:
|
||||
it = eval_ast(node.iter)
|
||||
except:
|
||||
return node
|
||||
l_it = len(it)
|
||||
if l_it:
|
||||
if (not _loop_breakable(node.body)
|
||||
and l_it*_count_stmts(node.body) < self.limit):
|
||||
replacement = []
|
||||
for i in it:
|
||||
if not isinstance(i, int):
|
||||
replacement = None
|
||||
break
|
||||
replacement.append(ast.copy_location(
|
||||
ast.Assign(targets=[node.target],
|
||||
value=value_to_ast(i)),
|
||||
node))
|
||||
replacement += deepcopy(node.body)
|
||||
if replacement is not None:
|
||||
return replacement
|
||||
else:
|
||||
return node
|
||||
else:
|
||||
return node
|
||||
else:
|
||||
return node.orelse
|
||||
|
||||
|
||||
def unroll_loops(node, limit):
|
||||
_LoopUnroller(limit).visit(node)
|
Loading…
Reference in New Issue