remove old compiler code

This commit is contained in:
Sebastien Bourdeauducq 2015-11-24 15:52:36 +08:00
parent 2503dcd837
commit e5b58b50aa
6 changed files with 0 additions and 1330 deletions

View File

@ -1,358 +0,0 @@
import inspect
from pythonparser import parse, ast
import llvmlite_artiq.ir as ll
from artiq.py2llvm.values import VGeneric, operators
from artiq.py2llvm.base_types import VBool, VInt, VFloat
def _gcd(a, b):
if a < 0:
a = -a
while a:
c = a
a = b % a
b = c
return b
def init_module(module):
func_def = parse(inspect.getsource(_gcd)).body[0]
function, _ = module.compile_function(func_def,
{"a": VInt(64), "b": VInt(64)})
function.linkage = "internal"
def _reduce(builder, a, b):
module = builder.basic_block.function.module
for f in module.functions:
if f.name == "_gcd":
gcd_f = f
break
gcd = builder.call(gcd_f, [a, b])
a = builder.sdiv(a, gcd)
b = builder.sdiv(b, gcd)
return a, b
def _signnum(builder, a, b):
function = builder.basic_block.function
orig_block = builder.basic_block
swap_block = function.append_basic_block("sn_swap")
merge_block = function.append_basic_block("sn_merge")
condition = builder.icmp_signed(
"<", b, ll.Constant(ll.IntType(64), 0))
builder.cbranch(condition, swap_block, merge_block)
builder.position_at_end(swap_block)
minusone = ll.Constant(ll.IntType(64), -1)
a_swp = builder.mul(minusone, a)
b_swp = builder.mul(minusone, b)
builder.branch(merge_block)
builder.position_at_end(merge_block)
a_phi = builder.phi(ll.IntType(64))
a_phi.add_incoming(a, orig_block)
a_phi.add_incoming(a_swp, swap_block)
b_phi = builder.phi(ll.IntType(64))
b_phi.add_incoming(b, orig_block)
b_phi.add_incoming(b_swp, swap_block)
return a_phi, b_phi
def _make_ssa(builder, n, d):
value = ll.Constant(ll.ArrayType(ll.IntType(64), 2), ll.Undefined)
value = builder.insert_value(value, n, 0)
value = builder.insert_value(value, d, 1)
return value
class VFraction(VGeneric):
def get_llvm_type(self):
return ll.ArrayType(ll.IntType(64), 2)
def _nd(self, builder):
ssa_value = self.auto_load(builder)
a = builder.extract_value(ssa_value, 0)
b = builder.extract_value(ssa_value, 1)
return a, b
def set_value_nd(self, builder, a, b):
a = a.o_int64(builder).auto_load(builder)
b = b.o_int64(builder).auto_load(builder)
a, b = _reduce(builder, a, b)
a, b = _signnum(builder, a, b)
self.auto_store(builder, _make_ssa(builder, a, b))
def set_value(self, builder, v):
if not isinstance(v, VFraction):
raise TypeError
self.auto_store(builder, v.auto_load(builder))
def o_getattr(self, attr, builder):
if attr == "numerator":
idx = 0
elif attr == "denominator":
idx = 1
else:
raise AttributeError
r = VInt(64)
if builder is not None:
elt = builder.extract_value(self.auto_load(builder), idx)
r.auto_store(builder, elt)
return r
def o_bool(self, builder):
r = VBool()
if builder is not None:
zero = ll.Constant(ll.IntType(64), 0)
a = builder.extract_element(self.auto_load(builder), 0)
r.auto_store(builder, builder.icmp_signed("!=", a, zero))
return r
def o_intx(self, target_bits, builder):
if builder is None:
return VInt(target_bits)
else:
r = VInt(64)
a, b = self._nd(builder)
r.auto_store(builder, builder.sdiv(a, b))
return r.o_intx(target_bits, builder)
def o_roundx(self, target_bits, builder):
if builder is None:
return VInt(target_bits)
else:
r = VInt(64)
a, b = self._nd(builder)
h_b = builder.ashr(b, ll.Constant(ll.IntType(64), 1))
function = builder.basic_block.function
add_block = function.append_basic_block("fr_add")
sub_block = function.append_basic_block("fr_sub")
merge_block = function.append_basic_block("fr_merge")
condition = builder.icmp_signed(
"<", a, ll.Constant(ll.IntType(64), 0))
builder.cbranch(condition, sub_block, add_block)
builder.position_at_end(add_block)
a_add = builder.add(a, h_b)
builder.branch(merge_block)
builder.position_at_end(sub_block)
a_sub = builder.sub(a, h_b)
builder.branch(merge_block)
builder.position_at_end(merge_block)
a = builder.phi(ll.IntType(64))
a.add_incoming(a_add, add_block)
a.add_incoming(a_sub, sub_block)
r.auto_store(builder, builder.sdiv(a, b))
return r.o_intx(target_bits, builder)
def o_float(self, builder):
r = VFloat()
if builder is not None:
a, b = self._nd(builder)
af = builder.sitofp(a, r.get_llvm_type())
bf = builder.sitofp(b, r.get_llvm_type())
r.auto_store(builder, builder.fdiv(af, bf))
return r
def _o_eq_inv(self, other, builder, ne):
if not isinstance(other, (VInt, VFraction)):
return NotImplemented
r = VBool()
if builder is not None:
if isinstance(other, VInt):
other = other.o_int64(builder)
a, b = self._nd(builder)
ssa_r = builder.and_(
builder.icmp_signed("==", a,
other.auto_load()),
builder.icmp_signed("==", b,
ll.Constant(ll.IntType(64), 1)))
else:
a, b = self._nd(builder)
c, d = other._nd(builder)
ssa_r = builder.and_(
builder.icmp_signed("==", a, c),
builder.icmp_signed("==", b, d))
if ne:
ssa_r = builder.xor(ssa_r,
ll.Constant(ll.IntType(1), 1))
r.auto_store(builder, ssa_r)
return r
def o_eq(self, other, builder):
return self._o_eq_inv(other, builder, False)
def o_ne(self, other, builder):
return self._o_eq_inv(other, builder, True)
def _o_cmp(self, other, icmp, builder):
diff = self.o_sub(other, builder)
if diff is NotImplemented:
return NotImplemented
r = VBool()
if builder is not None:
diff = diff.auto_load(builder)
a = builder.extract_value(diff, 0)
zero = ll.Constant(ll.IntType(64), 0)
ssa_r = builder.icmp_signed(icmp, a, zero)
r.auto_store(builder, ssa_r)
return r
def o_lt(self, other, builder):
return self._o_cmp(other, "<", builder)
def o_le(self, other, builder):
return self._o_cmp(other, "<=", builder)
def o_gt(self, other, builder):
return self._o_cmp(other, ">", builder)
def o_ge(self, other, builder):
return self._o_cmp(other, ">=", builder)
def _o_addsub(self, other, builder, sub, invert=False):
if isinstance(other, VFloat):
a = self.o_getattr("numerator", builder)
b = self.o_getattr("denominator", builder)
if sub:
if invert:
return operators.truediv(
operators.sub(operators.mul(other,
b,
builder),
a,
builder),
b,
builder)
else:
return operators.truediv(
operators.sub(a,
operators.mul(other,
b,
builder),
builder),
b,
builder)
else:
return operators.truediv(
operators.add(operators.mul(other,
b,
builder),
a,
builder),
b,
builder)
else:
if not isinstance(other, (VFraction, VInt)):
return NotImplemented
r = VFraction()
if builder is not None:
if isinstance(other, VInt):
i = other.o_int64(builder).auto_load(builder)
x, rd = self._nd(builder)
y = builder.mul(rd, i)
else:
a, b = self._nd(builder)
c, d = other._nd(builder)
rd = builder.mul(b, d)
x = builder.mul(a, d)
y = builder.mul(c, b)
if sub:
if invert:
rn = builder.sub(y, x)
else:
rn = builder.sub(x, y)
else:
rn = builder.add(x, y)
rn, rd = _reduce(builder, rn, rd) # rd is already > 0
r.auto_store(builder, _make_ssa(builder, rn, rd))
return r
def o_add(self, other, builder):
return self._o_addsub(other, builder, False)
def o_sub(self, other, builder):
return self._o_addsub(other, builder, True)
def or_add(self, other, builder):
return self._o_addsub(other, builder, False)
def or_sub(self, other, builder):
return self._o_addsub(other, builder, True, True)
def _o_muldiv(self, other, builder, div, invert=False):
if isinstance(other, VFloat):
a = self.o_getattr("numerator", builder)
b = self.o_getattr("denominator", builder)
if invert:
a, b = b, a
if div:
return operators.truediv(a,
operators.mul(b, other, builder),
builder)
else:
return operators.truediv(operators.mul(a, other, builder),
b,
builder)
else:
if not isinstance(other, (VFraction, VInt)):
return NotImplemented
r = VFraction()
if builder is not None:
a, b = self._nd(builder)
if invert:
a, b = b, a
if isinstance(other, VInt):
i = other.o_int64(builder).auto_load(builder)
if div:
b = builder.mul(b, i)
else:
a = builder.mul(a, i)
else:
c, d = other._nd(builder)
if div:
a = builder.mul(a, d)
b = builder.mul(b, c)
else:
a = builder.mul(a, c)
b = builder.mul(b, d)
if div or invert:
a, b = _signnum(builder, a, b)
a, b = _reduce(builder, a, b)
r.auto_store(builder, _make_ssa(builder, a, b))
return r
def o_mul(self, other, builder):
return self._o_muldiv(other, builder, False)
def o_truediv(self, other, builder):
return self._o_muldiv(other, builder, True)
def or_mul(self, other, builder):
return self._o_muldiv(other, builder, False)
def or_truediv(self, other, builder):
# multiply by the inverse
return self._o_muldiv(other, builder, False, True)
def o_floordiv(self, other, builder):
r = self.o_truediv(other, builder)
if r is NotImplemented:
return r
else:
return r.o_int(builder)
def or_floordiv(self, other, builder):
r = self.or_truediv(other, builder)
if r is NotImplemented:
return r
else:
return r.o_int(builder)

View File

@ -1,169 +0,0 @@
import unittest
from pythonparser import parse, ast
import inspect
from fractions import Fraction
from ctypes import CFUNCTYPE, c_int, c_int32, c_int64, c_double
import struct
import llvmlite_or1k.binding as llvm
from artiq.language.core import int64
from artiq.py2llvm.infer_types import infer_function_types
from artiq.py2llvm import base_types, lists
from artiq.py2llvm.module import Module
def simplify_encode(a, b):
f = Fraction(a, b)
return f.numerator*1000 + f.denominator
def frac_arith_encode(op, a, b, c, d):
if op == 0:
f = Fraction(a, b) - Fraction(c, d)
elif op == 1:
f = Fraction(a, b) + Fraction(c, d)
elif op == 2:
f = Fraction(a, b) * Fraction(c, d)
else:
f = Fraction(a, b) / Fraction(c, d)
return f.numerator*1000 + f.denominator
def frac_arith_encode_int(op, a, b, x):
if op == 0:
f = Fraction(a, b) - x
elif op == 1:
f = Fraction(a, b) + x
elif op == 2:
f = Fraction(a, b) * x
else:
f = Fraction(a, b) / x
return f.numerator*1000 + f.denominator
def frac_arith_encode_int_rev(op, a, b, x):
if op == 0:
f = x - Fraction(a, b)
elif op == 1:
f = x + Fraction(a, b)
elif op == 2:
f = x * Fraction(a, b)
else:
f = x / Fraction(a, b)
return f.numerator*1000 + f.denominator
def frac_arith_float(op, a, b, x):
if op == 0:
return Fraction(a, b) - x
elif op == 1:
return Fraction(a, b) + x
elif op == 2:
return Fraction(a, b) * x
else:
return Fraction(a, b) / x
def frac_arith_float_rev(op, a, b, x):
if op == 0:
return x - Fraction(a, b)
elif op == 1:
return x + Fraction(a, b)
elif op == 2:
return x * Fraction(a, b)
else:
return x / Fraction(a, b)
class CodeGenCase(unittest.TestCase):
def test_frac_simplify(self):
simplify_encode_c = CompiledFunction(
simplify_encode, {"a": base_types.VInt(), "b": base_types.VInt()})
for a in _test_range():
for b in _test_range():
self.assertEqual(
simplify_encode_c(a, b), simplify_encode(a, b))
def _test_frac_arith(self, op):
frac_arith_encode_c = CompiledFunction(
frac_arith_encode, {
"op": base_types.VInt(),
"a": base_types.VInt(), "b": base_types.VInt(),
"c": base_types.VInt(), "d": base_types.VInt()})
for a in _test_range():
for b in _test_range():
for c in _test_range():
for d in _test_range():
self.assertEqual(
frac_arith_encode_c(op, a, b, c, d),
frac_arith_encode(op, a, b, c, d))
def test_frac_add(self):
self._test_frac_arith(0)
def test_frac_sub(self):
self._test_frac_arith(1)
def test_frac_mul(self):
self._test_frac_arith(2)
def test_frac_div(self):
self._test_frac_arith(3)
def _test_frac_arith_int(self, op, rev):
f = frac_arith_encode_int_rev if rev else frac_arith_encode_int
f_c = CompiledFunction(f, {
"op": base_types.VInt(),
"a": base_types.VInt(), "b": base_types.VInt(),
"x": base_types.VInt()})
for a in _test_range():
for b in _test_range():
for x in _test_range():
self.assertEqual(
f_c(op, a, b, x),
f(op, a, b, x))
def test_frac_add_int(self):
self._test_frac_arith_int(0, False)
self._test_frac_arith_int(0, True)
def test_frac_sub_int(self):
self._test_frac_arith_int(1, False)
self._test_frac_arith_int(1, True)
def test_frac_mul_int(self):
self._test_frac_arith_int(2, False)
self._test_frac_arith_int(2, True)
def test_frac_div_int(self):
self._test_frac_arith_int(3, False)
self._test_frac_arith_int(3, True)
def _test_frac_arith_float(self, op, rev):
f = frac_arith_float_rev if rev else frac_arith_float
f_c = CompiledFunction(f, {
"op": base_types.VInt(),
"a": base_types.VInt(), "b": base_types.VInt(),
"x": base_types.VFloat()})
for a in _test_range():
for b in _test_range():
for x in _test_range():
self.assertAlmostEqual(
f_c(op, a, b, x/2),
f(op, a, b, x/2))
def test_frac_add_float(self):
self._test_frac_arith_float(0, False)
self._test_frac_arith_float(0, True)
def test_frac_sub_float(self):
self._test_frac_arith_float(1, False)
self._test_frac_arith_float(1, True)
def test_frac_mul_float(self):
self._test_frac_arith_float(2, False)
self._test_frac_arith_float(2, True)
def test_frac_div_float(self):
self._test_frac_arith_float(3, False)
self._test_frac_arith_float(3, True)

View File

@ -1,548 +0,0 @@
import inspect
import textwrap
import ast
import types
import builtins
from fractions import Fraction
from collections import OrderedDict
from functools import partial
from itertools import zip_longest, chain
from artiq.language import core as core_language
from artiq.language import units
from artiq.transforms.tools import *
def new_mangled_name(in_use_names, name):
mangled_name = name
i = 2
while mangled_name in in_use_names:
mangled_name = name + str(i)
i += 1
in_use_names.add(mangled_name)
return mangled_name
class MangledName:
def __init__(self, s):
self.s = s
class AttributeInfo:
def __init__(self, obj, mangled_name, read_write):
self.obj = obj
self.mangled_name = mangled_name
self.read_write = read_write
def is_inlinable(core, func):
if hasattr(func, "k_function_info"):
if func.k_function_info.core_name == "":
return True # portable function
if getattr(func.__self__, func.k_function_info.core_name) is core:
return True # kernel function for the same core device
return False
class GlobalNamespace:
def __init__(self, func):
self.func_gd = inspect.getmodule(func).__dict__
def __getitem__(self, item):
try:
return self.func_gd[item]
except KeyError:
return getattr(builtins, item)
class UndefinedArg:
pass
def get_function_args(func_args, func_tr, args, kwargs):
# OrderedDict prevents non-determinism in argument init
r = OrderedDict()
# Process positional arguments. Any missing positional argument values
# are set to UndefinedArg.
for arg, arg_value in zip_longest(func_args.args, args,
fillvalue=UndefinedArg):
if arg is UndefinedArg:
raise TypeError("Got too many positional arguments")
if arg.arg in r:
raise SyntaxError("Duplicate argument '{}' in function definition"
.format(arg.arg))
r[arg.arg] = arg_value
# Process keyword arguments. Any missing keyword-only argument values
# are set to UndefinedArg.
valid_arg_names = {arg.arg for arg in
chain(func_args.args, func_args.kwonlyargs)}
for arg in func_args.kwonlyargs:
if arg.arg in r:
raise SyntaxError("Duplicate argument '{}' in function definition"
.format(arg.arg))
r[arg.arg] = UndefinedArg
for arg_name, arg_value in kwargs.items():
if arg_name not in valid_arg_names:
raise TypeError("Got unexpected keyword argument '{}'"
.format(arg_name))
if r[arg_name] is not UndefinedArg:
raise TypeError("Got multiple values for argument '{}'"
.format(arg_name))
r[arg_name] = arg_value
# Replace any UndefinedArg positional arguments with the default value,
# when provided.
for arg, default in zip(func_args.args[-len(func_args.defaults):],
func_args.defaults):
if r[arg.arg] is UndefinedArg:
r[arg.arg] = func_tr.code_visit(default)
# Same with keyword-only arguments.
for arg, default in zip(func_args.kwonlyargs, func_args.kw_defaults):
if default is not None and r[arg.arg] is UndefinedArg:
r[arg.arg] = func_tr.code_visit(default)
# Check that no argument was left undefined.
missing_arguments = ["'"+arg+"'" for arg, value in r.items()
if value is UndefinedArg]
if missing_arguments:
raise TypeError("Missing argument(s): " + " ".join(missing_arguments))
return r
# args/kwargs can contain values or AST nodes
def get_inline(core, attribute_namespace, in_use_names, retval_name, mappers,
func, args, kwargs):
global_namespace = GlobalNamespace(func)
func_tr = Function(core,
global_namespace, attribute_namespace, in_use_names,
retval_name, mappers)
func_def = ast.parse(textwrap.dedent(inspect.getsource(func))).body[0]
# Initialize arguments.
# The local namespace is empty so code_visit will always resolve
# using the global namespace.
arg_init = []
arg_name_map = []
arg_dict = get_function_args(func_def.args, func_tr, args, kwargs)
for arg_name, arg_value in arg_dict.items():
if isinstance(arg_value, ast.AST):
value = arg_value
else:
try:
value = ast.copy_location(value_to_ast(arg_value), func_def)
except NotASTRepresentable:
value = None
if value is None:
# static object
func_tr.local_namespace[arg_name] = arg_value
else:
# set parameter value with "name = value"
# assignment at beginning of function
new_name = new_mangled_name(in_use_names, arg_name)
arg_name_map.append((arg_name, new_name))
target = ast.copy_location(ast.Name(new_name, ast.Store()),
func_def)
assign = ast.copy_location(ast.Assign([target], value),
func_def)
arg_init.append(assign)
# Commit arguments to the local namespace at the end to handle cases
# such as f(x, y=x) (for the default value of y, x must be resolved
# using the global namespace).
for arg_name, mangled_name in arg_name_map:
func_tr.local_namespace[arg_name] = MangledName(mangled_name)
func_def = func_tr.code_visit(func_def)
func_def.body[0:0] = arg_init
return func_def
class Function:
def __init__(self, core,
global_namespace, attribute_namespace, in_use_names,
retval_name, mappers):
# The core device on which this function is executing.
self.core = core
# Local and global namespaces:
# original name -> MangledName or static object
self.local_namespace = dict()
self.global_namespace = global_namespace
# (id(static object), attribute) -> AttributeInfo
self.attribute_namespace = attribute_namespace
# All names currently in use, in the namespace of the combined
# function.
# When creating a name for a new object, check that it is not
# already in this set.
self.in_use_names = in_use_names
# Name of the variable to store the return value to, or None
# to keep the return statement.
self.retval_name = retval_name
# Host object mappers, for RPC and exception numbers
self.mappers = mappers
self._insertion_point = None
# This is ast.NodeVisitor/NodeTransformer from CPython, modified
# to add code_ prefix.
def code_visit(self, node):
method = "code_visit_" + node.__class__.__name__
visitor = getattr(self, method, self.code_generic_visit)
return visitor(node)
# This is ast.NodeTransformer.generic_visit from CPython, modified
# to update self._insertion_point.
def code_generic_visit(self, node, exclude_fields=set()):
for field, old_value in ast.iter_fields(node):
if field in exclude_fields:
continue
old_value = getattr(node, field, None)
if isinstance(old_value, list):
prev_insertion_point = self._insertion_point
new_values = []
if field in ("body", "orelse", "finalbody"):
self._insertion_point = new_values
for value in old_value:
if isinstance(value, ast.AST):
value = self.code_visit(value)
if value is None:
continue
elif not isinstance(value, ast.AST):
new_values.extend(value)
continue
new_values.append(value)
old_value[:] = new_values
self._insertion_point = prev_insertion_point
elif isinstance(old_value, ast.AST):
new_node = self.code_visit(old_value)
if new_node is None:
delattr(node, field)
else:
setattr(node, field, new_node)
return node
def code_visit_Name(self, node):
if isinstance(node.ctx, ast.Store):
if (node.id in self.local_namespace
and isinstance(self.local_namespace[node.id],
MangledName)):
new_name = self.local_namespace[node.id].s
else:
new_name = new_mangled_name(self.in_use_names, node.id)
self.local_namespace[node.id] = MangledName(new_name)
node.id = new_name
return node
else:
try:
obj = self.local_namespace[node.id]
except KeyError:
try:
obj = self.global_namespace[node.id]
except KeyError:
raise NameError("name '{}' is not defined".format(node.id))
if isinstance(obj, MangledName):
node.id = obj.s
return node
else:
try:
return value_to_ast(obj)
except NotASTRepresentable:
raise NotImplementedError(
"Static object cannot be used here")
def code_visit_Attribute(self, node):
# There are two cases of attributes:
# 1. static object attributes, e.g. self.foo
# 2. dynamic expression attributes, e.g.
# (Fraction(1, 2) + x).numerator
# Static object resolution has no side effects so we try it first.
try:
obj = self.static_visit(node.value)
except:
self.code_generic_visit(node)
return node
else:
key = (id(obj), node.attr)
try:
attr_info = self.attribute_namespace[key]
except KeyError:
new_name = new_mangled_name(self.in_use_names, node.attr)
attr_info = AttributeInfo(obj, new_name, False)
self.attribute_namespace[key] = attr_info
if isinstance(node.ctx, ast.Store):
attr_info.read_write = True
return ast.copy_location(
ast.Name(attr_info.mangled_name, node.ctx),
node)
def code_visit_Call(self, node):
func = self.static_visit(node.func)
node.args = [self.code_visit(arg) for arg in node.args]
for kw in node.keywords:
kw.value = self.code_visit(kw.value)
if is_embeddable(func):
node.func = ast.copy_location(
ast.Name(func.__name__, ast.Load()),
node)
return node
elif is_inlinable(self.core, func):
retval_name = func.k_function_info.k_function.__name__ + "_return"
retval_name_m = new_mangled_name(self.in_use_names, retval_name)
args = [func.__self__] + node.args
kwargs = {kw.arg: kw.value for kw in node.keywords}
inlined = get_inline(self.core,
self.attribute_namespace, self.in_use_names,
retval_name_m, self.mappers,
func.k_function_info.k_function,
args, kwargs)
seq = ast.copy_location(
ast.With(
items=[ast.withitem(context_expr=ast.Name(id="sequential",
ctx=ast.Load()),
optional_vars=None)],
body=inlined.body),
node)
self._insertion_point.append(seq)
return ast.copy_location(ast.Name(retval_name_m, ast.Load()),
node)
else:
arg1 = ast.copy_location(ast.Str("rpc"), node)
arg2 = ast.copy_location(
value_to_ast(self.mappers.rpc.encode(func)), node)
node.args[0:0] = [arg1, arg2]
node.func = ast.copy_location(
ast.Name("syscall", ast.Load()), node)
return node
def code_visit_Return(self, node):
self.code_generic_visit(node)
if self.retval_name is None:
return node
else:
return ast.copy_location(
ast.Assign(targets=[ast.Name(self.retval_name, ast.Store())],
value=node.value),
node)
def code_visit_Expr(self, node):
if isinstance(node.value, ast.Str):
# Strip docstrings. This also removes strings appearing in the
# middle of the code, but they are nops.
return None
self.code_generic_visit(node)
if isinstance(node.value, ast.Name):
# Remove Expr nodes that contain only a name, likely due to
# function call inlining. Such nodes that were originally in the
# code are also removed, but this does not affect the semantics of
# the code as they are nops.
return None
else:
return node
def encode_exception(self, e):
exception_class = self.static_visit(e)
if not inspect.isclass(exception_class):
raise NotImplementedError("Exception type must be a class")
if issubclass(exception_class, core_language.RuntimeException):
exception_id = exception_class.eid
else:
exception_id = self.mappers.exception.encode(exception_class)
return ast.copy_location(
ast.Call(func=ast.Name("EncodedException", ast.Load()),
args=[value_to_ast(exception_id)], keywords=[]),
e)
def code_visit_Raise(self, node):
if node.cause is not None:
raise NotImplementedError("Exception causes are not supported")
if node.exc is not None:
node.exc = self.encode_exception(node.exc)
return node
def code_visit_ExceptHandler(self, node):
if node.name is not None:
raise NotImplementedError("'as target' is not supported")
if node.type is not None:
if isinstance(node.type, ast.Tuple):
node.type.elts = [self.encode_exception(e)
for e in node.type.elts]
else:
node.type = self.encode_exception(node.type)
self.code_generic_visit(node)
return node
def get_user_ctxm(self, context_expr):
try:
ctxm = self.static_visit(context_expr)
except:
# this also catches watchdog()
return None
else:
if (ctxm is core_language.sequential
or ctxm is core_language.parallel):
return None
return ctxm
def code_visit_With(self, node):
if len(node.items) != 1:
raise NotImplementedError
item = node.items[0]
if item.optional_vars is not None:
raise NotImplementedError
ctxm = self.get_user_ctxm(item.context_expr)
if ctxm is None:
self.code_generic_visit(node)
return node
# user context manager
self.code_generic_visit(node, {"items"})
if (not hasattr(ctxm, "__enter__")
or not hasattr(ctxm.__enter__, "k_function_info")):
raise NotImplementedError
enter = get_inline(self.core,
self.attribute_namespace, self.in_use_names,
None, self.mappers,
ctxm.__enter__.k_function_info.k_function,
[ctxm], dict())
if (not hasattr(ctxm, "__exit__")
or not hasattr(ctxm.__exit__, "k_function_info")):
raise NotImplementedError
exit = get_inline(self.core,
self.attribute_namespace, self.in_use_names,
None, self.mappers,
ctxm.__exit__.k_function_info.k_function,
[ctxm, None, None, None], dict())
try_stmt = ast.copy_location(
ast.Try(body=node.body,
handlers=[],
orelse=[],
finalbody=exit.body), node)
return ast.copy_location(
ast.With(
items=[ast.withitem(context_expr=ast.Name(id="sequential",
ctx=ast.Load()),
optional_vars=None)],
body=enter.body + [try_stmt]),
node)
def code_visit_FunctionDef(self, node):
node.args = ast.arguments(args=[], vararg=None, kwonlyargs=[],
kw_defaults=[], kwarg=None, defaults=[])
node.decorator_list = []
self.code_generic_visit(node)
return node
def static_visit(self, node):
method = "static_visit_" + node.__class__.__name__
visitor = getattr(self, method)
return visitor(node)
def static_visit_Name(self, node):
try:
obj = self.local_namespace[node.id]
except KeyError:
try:
obj = self.global_namespace[node.id]
except KeyError:
raise NameError("name '{}' is not defined".format(node.id))
if isinstance(obj, MangledName):
raise NotImplementedError(
"Only a static object can be used here")
return obj
def static_visit_Attribute(self, node):
value = self.static_visit(node.value)
return getattr(value, node.attr)
class HostObjectMapper:
def __init__(self, first_encoding=0):
self._next_encoding = first_encoding
# id(object) -> (encoding, object)
# this format is required to support non-hashable host objects.
self._d = dict()
def encode(self, obj):
try:
return self._d[id(obj)][0]
except KeyError:
encoding = self._next_encoding
self._d[id(obj)] = (encoding, obj)
self._next_encoding += 1
return encoding
def get_map(self):
return {encoding: obj for i, (encoding, obj) in self._d.items()}
def get_attr_init(attribute_namespace, loc_node):
attr_init = []
for (_, attr), attr_info in attribute_namespace.items():
if hasattr(attr_info.obj, attr):
value = getattr(attr_info.obj, attr)
if (hasattr(value, "kernel_attr_init")
and not value.kernel_attr_init):
continue
value = ast.copy_location(value_to_ast(value), loc_node)
target = ast.copy_location(ast.Name(attr_info.mangled_name,
ast.Store()),
loc_node)
assign = ast.copy_location(ast.Assign([target], value),
loc_node)
attr_init.append(assign)
return attr_init
def get_attr_writeback(attribute_namespace, rpc_mapper, loc_node):
attr_writeback = []
for (_, attr), attr_info in attribute_namespace.items():
if attr_info.read_write:
setter = partial(setattr, attr_info.obj, attr)
func = ast.copy_location(
ast.Name("syscall", ast.Load()), loc_node)
arg1 = ast.copy_location(ast.Str("rpc"), loc_node)
arg2 = ast.copy_location(
value_to_ast(rpc_mapper.encode(setter)), loc_node)
arg3 = ast.copy_location(
ast.Name(attr_info.mangled_name, ast.Load()), loc_node)
call = ast.copy_location(
ast.Call(func=func, args=[arg1, arg2, arg3], keywords=[]),
loc_node)
expr = ast.copy_location(ast.Expr(call), loc_node)
attr_writeback.append(expr)
return attr_writeback
def inline(core, k_function, k_args, k_kwargs, with_attr_writeback):
# OrderedDict prevents non-determinism in attribute init
attribute_namespace = OrderedDict()
# NOTE: in_use_names will be mutated. Do not mutate embeddable_func_names!
in_use_names = embeddable_func_names | {"sequential", "parallel",
"watchdog"}
mappers = types.SimpleNamespace(
rpc=HostObjectMapper(),
exception=HostObjectMapper(core_language.first_user_eid)
)
func_def = get_inline(
core=core,
attribute_namespace=attribute_namespace,
in_use_names=in_use_names,
retval_name=None,
mappers=mappers,
func=k_function,
args=k_args,
kwargs=k_kwargs)
func_def.body[0:0] = get_attr_init(attribute_namespace, func_def)
if with_attr_writeback:
func_def.body += get_attr_writeback(attribute_namespace, mappers.rpc,
func_def)
return func_def, mappers.rpc.get_map(), mappers.exception.get_map()

View File

@ -1,130 +0,0 @@
import ast
import types
from artiq.transforms.tools import *
# -1 statement duration could not be pre-determined
# 0 statement has no effect on timeline
# >0 statement is a static delay that advances the timeline
# by the given amount
def _get_duration(stmt):
if isinstance(stmt, (ast.Expr, ast.Assign)):
return _get_duration(stmt.value)
elif isinstance(stmt, ast.If):
if (all(_get_duration(s) == 0 for s in stmt.body)
and all(_get_duration(s) == 0 for s in stmt.orelse)):
return 0
else:
return -1
elif isinstance(stmt, ast.Try):
if (all(_get_duration(s) == 0 for s in stmt.body)
and all(_get_duration(s) == 0 for s in stmt.orelse)
and all(_get_duration(s) == 0 for s in stmt.finalbody)
and all(_get_duration(s) == 0 for s in handler.body
for handler in stmt.handlers)):
return 0
else:
return -1
elif isinstance(stmt, ast.Call):
name = stmt.func.id
assert(name != "delay")
if name == "delay_mu":
try:
da = eval_constant(stmt.args[0])
except NotConstant:
da = -1
return da
elif name == "at_mu":
return -1
else:
return 0
else:
return 0
def _interleave_timelines(timelines):
r = []
current_stmts = []
for stmts in timelines:
it = iter(stmts)
try:
stmt = next(it)
except StopIteration:
pass
else:
current_stmts.append(types.SimpleNamespace(
delay=_get_duration(stmt), stmt=stmt, it=it))
while current_stmts:
dt = min(stmt.delay for stmt in current_stmts)
if dt < 0:
# contains statement(s) with indeterminate duration
return None
if dt > 0:
# advance timeline by dt
for stmt in current_stmts:
stmt.delay -= dt
if stmt.delay == 0:
ref_stmt = stmt.stmt
delay_stmt = ast.copy_location(
ast.Expr(ast.Call(
func=ast.Name("delay_mu", ast.Load()),
args=[value_to_ast(dt)], keywords=[])),
ref_stmt)
r.append(delay_stmt)
else:
for stmt in current_stmts:
if stmt.delay == 0:
r.append(stmt.stmt)
# discard executed statements
exhausted_list = []
for stmt_i, stmt in enumerate(current_stmts):
if stmt.delay == 0:
try:
stmt.stmt = next(stmt.it)
except StopIteration:
exhausted_list.append(stmt_i)
else:
stmt.delay = _get_duration(stmt.stmt)
for offset, i in enumerate(exhausted_list):
current_stmts.pop(i-offset)
return r
def _interleave_stmts(stmts):
replacements = []
for stmt_i, stmt in enumerate(stmts):
if isinstance(stmt, (ast.For, ast.While, ast.If)):
_interleave_stmts(stmt.body)
_interleave_stmts(stmt.orelse)
elif isinstance(stmt, ast.Try):
_interleave_stmts(stmt.body)
_interleave_stmts(stmt.orelse)
_interleave_stmts(stmt.finalbody)
for handler in stmt.handlers:
_interleave_stmts(handler.body)
elif isinstance(stmt, ast.With):
btype = stmt.items[0].context_expr.id
if btype == "sequential":
_interleave_stmts(stmt.body)
replacements.append((stmt_i, stmt.body))
elif btype == "parallel":
timelines = [[s] for s in stmt.body]
for timeline in timelines:
_interleave_stmts(timeline)
merged = _interleave_timelines(timelines)
if merged is not None:
replacements.append((stmt_i, merged))
else:
raise ValueError("Unknown block type: " + btype)
offset = 0
for location, new_stmts in replacements:
stmts[offset+location:offset+location+1] = new_stmts
offset += len(new_stmts) - 1
def interleave(func_def):
_interleave_stmts(func_def.body)

View File

@ -1,43 +0,0 @@
def visit_With(self, node):
self.generic_visit(node)
if (isinstance(node.items[0].context_expr, ast.Call)
and node.items[0].context_expr.func.id == "watchdog"):
idname = "__watchdog_id_" + str(self.watchdog_id_counter)
self.watchdog_id_counter += 1
time = ast.BinOp(left=node.items[0].context_expr.args[0],
op=ast.Mult(),
right=ast.Num(1000))
time_int = ast.Call(
func=ast.Name("round", ast.Load()),
args=[time],
keywords=[], starargs=None, kwargs=None)
syscall_set = ast.Call(
func=ast.Name("syscall", ast.Load()),
args=[ast.Str("watchdog_set"), time_int],
keywords=[], starargs=None, kwargs=None)
stmt_set = ast.copy_location(
ast.Assign(targets=[ast.Name(idname, ast.Store())],
value=syscall_set),
node)
syscall_clear = ast.Call(
func=ast.Name("syscall", ast.Load()),
args=[ast.Str("watchdog_clear"),
ast.Name(idname, ast.Load())],
keywords=[], starargs=None, kwargs=None)
stmt_clear = ast.copy_location(ast.Expr(syscall_clear), node)
node.items[0] = ast.withitem(
context_expr=ast.Name(id="sequential",
ctx=ast.Load()),
optional_vars=None)
node.body = [
stmt_set,
ast.Try(body=node.body,
handlers=[],
orelse=[],
finalbody=[stmt_clear])
]
return node

View File

@ -1,82 +0,0 @@
import ast
from copy import deepcopy
from artiq.transforms.tools import eval_ast, value_to_ast
def _count_stmts(node):
if isinstance(node, list):
return sum(map(_count_stmts, node))
elif isinstance(node, ast.With):
return 1 + _count_stmts(node.body)
elif isinstance(node, (ast.For, ast.While, ast.If)):
return 1 + _count_stmts(node.body) + _count_stmts(node.orelse)
elif isinstance(node, ast.Try):
r = 1 + _count_stmts(node.body) \
+ _count_stmts(node.orelse) \
+ _count_stmts(node.finalbody)
for handler in node.handlers:
r += 1 + _count_stmts(handler.body)
return r
else:
return 1
def _loop_breakable(node):
if isinstance(node, list):
return any(map(_loop_breakable, node))
elif isinstance(node, (ast.Break, ast.Continue)):
return True
elif isinstance(node, ast.With):
return _loop_breakable(node.body)
elif isinstance(node, ast.If):
return _loop_breakable(node.body) or _loop_breakable(node.orelse)
elif isinstance(node, ast.Try):
if (_loop_breakable(node.body)
or _loop_breakable(node.orelse)
or _loop_breakable(node.finalbody)):
return True
for handler in node.handlers:
if _loop_breakable(handler.body):
return True
return False
else:
return False
class _LoopUnroller(ast.NodeTransformer):
def __init__(self, limit):
self.limit = limit
def visit_For(self, node):
self.generic_visit(node)
try:
it = eval_ast(node.iter)
except:
return node
l_it = len(it)
if l_it:
if (not _loop_breakable(node.body)
and l_it*_count_stmts(node.body) < self.limit):
replacement = []
for i in it:
if not isinstance(i, int):
replacement = None
break
replacement.append(ast.copy_location(
ast.Assign(targets=[node.target],
value=value_to_ast(i)),
node))
replacement += deepcopy(node.body)
if replacement is not None:
return replacement
else:
return node
else:
return node
else:
return node.orelse
def unroll_loops(node, limit):
_LoopUnroller(limit).visit(node)