diff --git a/artiq/py2llvm_old/fractions.py b/artiq/py2llvm_old/fractions.py deleted file mode 100644 index a2895107b..000000000 --- a/artiq/py2llvm_old/fractions.py +++ /dev/null @@ -1,358 +0,0 @@ -import inspect -from pythonparser import parse, ast - -import llvmlite_artiq.ir as ll - -from artiq.py2llvm.values import VGeneric, operators -from artiq.py2llvm.base_types import VBool, VInt, VFloat - - -def _gcd(a, b): - if a < 0: - a = -a - while a: - c = a - a = b % a - b = c - return b - - -def init_module(module): - func_def = parse(inspect.getsource(_gcd)).body[0] - function, _ = module.compile_function(func_def, - {"a": VInt(64), "b": VInt(64)}) - function.linkage = "internal" - - -def _reduce(builder, a, b): - module = builder.basic_block.function.module - for f in module.functions: - if f.name == "_gcd": - gcd_f = f - break - gcd = builder.call(gcd_f, [a, b]) - a = builder.sdiv(a, gcd) - b = builder.sdiv(b, gcd) - return a, b - - -def _signnum(builder, a, b): - function = builder.basic_block.function - orig_block = builder.basic_block - swap_block = function.append_basic_block("sn_swap") - merge_block = function.append_basic_block("sn_merge") - - condition = builder.icmp_signed( - "<", b, ll.Constant(ll.IntType(64), 0)) - builder.cbranch(condition, swap_block, merge_block) - - builder.position_at_end(swap_block) - minusone = ll.Constant(ll.IntType(64), -1) - a_swp = builder.mul(minusone, a) - b_swp = builder.mul(minusone, b) - builder.branch(merge_block) - - builder.position_at_end(merge_block) - a_phi = builder.phi(ll.IntType(64)) - a_phi.add_incoming(a, orig_block) - a_phi.add_incoming(a_swp, swap_block) - b_phi = builder.phi(ll.IntType(64)) - b_phi.add_incoming(b, orig_block) - b_phi.add_incoming(b_swp, swap_block) - - return a_phi, b_phi - - -def _make_ssa(builder, n, d): - value = ll.Constant(ll.ArrayType(ll.IntType(64), 2), ll.Undefined) - value = builder.insert_value(value, n, 0) - value = builder.insert_value(value, d, 1) - return value - - -class VFraction(VGeneric): - def get_llvm_type(self): - return ll.ArrayType(ll.IntType(64), 2) - - def _nd(self, builder): - ssa_value = self.auto_load(builder) - a = builder.extract_value(ssa_value, 0) - b = builder.extract_value(ssa_value, 1) - return a, b - - def set_value_nd(self, builder, a, b): - a = a.o_int64(builder).auto_load(builder) - b = b.o_int64(builder).auto_load(builder) - a, b = _reduce(builder, a, b) - a, b = _signnum(builder, a, b) - self.auto_store(builder, _make_ssa(builder, a, b)) - - def set_value(self, builder, v): - if not isinstance(v, VFraction): - raise TypeError - self.auto_store(builder, v.auto_load(builder)) - - def o_getattr(self, attr, builder): - if attr == "numerator": - idx = 0 - elif attr == "denominator": - idx = 1 - else: - raise AttributeError - r = VInt(64) - if builder is not None: - elt = builder.extract_value(self.auto_load(builder), idx) - r.auto_store(builder, elt) - return r - - def o_bool(self, builder): - r = VBool() - if builder is not None: - zero = ll.Constant(ll.IntType(64), 0) - a = builder.extract_element(self.auto_load(builder), 0) - r.auto_store(builder, builder.icmp_signed("!=", a, zero)) - return r - - def o_intx(self, target_bits, builder): - if builder is None: - return VInt(target_bits) - else: - r = VInt(64) - a, b = self._nd(builder) - r.auto_store(builder, builder.sdiv(a, b)) - return r.o_intx(target_bits, builder) - - def o_roundx(self, target_bits, builder): - if builder is None: - return VInt(target_bits) - else: - r = VInt(64) - a, b = self._nd(builder) - h_b = builder.ashr(b, ll.Constant(ll.IntType(64), 1)) - - function = builder.basic_block.function - add_block = function.append_basic_block("fr_add") - sub_block = function.append_basic_block("fr_sub") - merge_block = function.append_basic_block("fr_merge") - - condition = builder.icmp_signed( - "<", a, ll.Constant(ll.IntType(64), 0)) - builder.cbranch(condition, sub_block, add_block) - - builder.position_at_end(add_block) - a_add = builder.add(a, h_b) - builder.branch(merge_block) - builder.position_at_end(sub_block) - a_sub = builder.sub(a, h_b) - builder.branch(merge_block) - - builder.position_at_end(merge_block) - a = builder.phi(ll.IntType(64)) - a.add_incoming(a_add, add_block) - a.add_incoming(a_sub, sub_block) - r.auto_store(builder, builder.sdiv(a, b)) - return r.o_intx(target_bits, builder) - - def o_float(self, builder): - r = VFloat() - if builder is not None: - a, b = self._nd(builder) - af = builder.sitofp(a, r.get_llvm_type()) - bf = builder.sitofp(b, r.get_llvm_type()) - r.auto_store(builder, builder.fdiv(af, bf)) - return r - - def _o_eq_inv(self, other, builder, ne): - if not isinstance(other, (VInt, VFraction)): - return NotImplemented - r = VBool() - if builder is not None: - if isinstance(other, VInt): - other = other.o_int64(builder) - a, b = self._nd(builder) - ssa_r = builder.and_( - builder.icmp_signed("==", a, - other.auto_load()), - builder.icmp_signed("==", b, - ll.Constant(ll.IntType(64), 1))) - else: - a, b = self._nd(builder) - c, d = other._nd(builder) - ssa_r = builder.and_( - builder.icmp_signed("==", a, c), - builder.icmp_signed("==", b, d)) - if ne: - ssa_r = builder.xor(ssa_r, - ll.Constant(ll.IntType(1), 1)) - r.auto_store(builder, ssa_r) - return r - - def o_eq(self, other, builder): - return self._o_eq_inv(other, builder, False) - - def o_ne(self, other, builder): - return self._o_eq_inv(other, builder, True) - - def _o_cmp(self, other, icmp, builder): - diff = self.o_sub(other, builder) - if diff is NotImplemented: - return NotImplemented - r = VBool() - if builder is not None: - diff = diff.auto_load(builder) - a = builder.extract_value(diff, 0) - zero = ll.Constant(ll.IntType(64), 0) - ssa_r = builder.icmp_signed(icmp, a, zero) - r.auto_store(builder, ssa_r) - return r - - def o_lt(self, other, builder): - return self._o_cmp(other, "<", builder) - - def o_le(self, other, builder): - return self._o_cmp(other, "<=", builder) - - def o_gt(self, other, builder): - return self._o_cmp(other, ">", builder) - - def o_ge(self, other, builder): - return self._o_cmp(other, ">=", builder) - - def _o_addsub(self, other, builder, sub, invert=False): - if isinstance(other, VFloat): - a = self.o_getattr("numerator", builder) - b = self.o_getattr("denominator", builder) - if sub: - if invert: - return operators.truediv( - operators.sub(operators.mul(other, - b, - builder), - a, - builder), - b, - builder) - else: - return operators.truediv( - operators.sub(a, - operators.mul(other, - b, - builder), - builder), - b, - builder) - else: - return operators.truediv( - operators.add(operators.mul(other, - b, - builder), - a, - builder), - b, - builder) - else: - if not isinstance(other, (VFraction, VInt)): - return NotImplemented - r = VFraction() - if builder is not None: - if isinstance(other, VInt): - i = other.o_int64(builder).auto_load(builder) - x, rd = self._nd(builder) - y = builder.mul(rd, i) - else: - a, b = self._nd(builder) - c, d = other._nd(builder) - rd = builder.mul(b, d) - x = builder.mul(a, d) - y = builder.mul(c, b) - if sub: - if invert: - rn = builder.sub(y, x) - else: - rn = builder.sub(x, y) - else: - rn = builder.add(x, y) - rn, rd = _reduce(builder, rn, rd) # rd is already > 0 - r.auto_store(builder, _make_ssa(builder, rn, rd)) - return r - - def o_add(self, other, builder): - return self._o_addsub(other, builder, False) - - def o_sub(self, other, builder): - return self._o_addsub(other, builder, True) - - def or_add(self, other, builder): - return self._o_addsub(other, builder, False) - - def or_sub(self, other, builder): - return self._o_addsub(other, builder, True, True) - - def _o_muldiv(self, other, builder, div, invert=False): - if isinstance(other, VFloat): - a = self.o_getattr("numerator", builder) - b = self.o_getattr("denominator", builder) - if invert: - a, b = b, a - if div: - return operators.truediv(a, - operators.mul(b, other, builder), - builder) - else: - return operators.truediv(operators.mul(a, other, builder), - b, - builder) - else: - if not isinstance(other, (VFraction, VInt)): - return NotImplemented - r = VFraction() - if builder is not None: - a, b = self._nd(builder) - if invert: - a, b = b, a - if isinstance(other, VInt): - i = other.o_int64(builder).auto_load(builder) - if div: - b = builder.mul(b, i) - else: - a = builder.mul(a, i) - else: - c, d = other._nd(builder) - if div: - a = builder.mul(a, d) - b = builder.mul(b, c) - else: - a = builder.mul(a, c) - b = builder.mul(b, d) - if div or invert: - a, b = _signnum(builder, a, b) - a, b = _reduce(builder, a, b) - r.auto_store(builder, _make_ssa(builder, a, b)) - return r - - def o_mul(self, other, builder): - return self._o_muldiv(other, builder, False) - - def o_truediv(self, other, builder): - return self._o_muldiv(other, builder, True) - - def or_mul(self, other, builder): - return self._o_muldiv(other, builder, False) - - def or_truediv(self, other, builder): - # multiply by the inverse - return self._o_muldiv(other, builder, False, True) - - def o_floordiv(self, other, builder): - r = self.o_truediv(other, builder) - if r is NotImplemented: - return r - else: - return r.o_int(builder) - - def or_floordiv(self, other, builder): - r = self.or_truediv(other, builder) - if r is NotImplemented: - return r - else: - return r.o_int(builder) diff --git a/artiq/py2llvm_old/test/py2llvm.py b/artiq/py2llvm_old/test/py2llvm.py deleted file mode 100644 index c6d9f0135..000000000 --- a/artiq/py2llvm_old/test/py2llvm.py +++ /dev/null @@ -1,169 +0,0 @@ -import unittest -from pythonparser import parse, ast -import inspect -from fractions import Fraction -from ctypes import CFUNCTYPE, c_int, c_int32, c_int64, c_double -import struct - -import llvmlite_or1k.binding as llvm - -from artiq.language.core import int64 -from artiq.py2llvm.infer_types import infer_function_types -from artiq.py2llvm import base_types, lists -from artiq.py2llvm.module import Module - -def simplify_encode(a, b): - f = Fraction(a, b) - return f.numerator*1000 + f.denominator - - -def frac_arith_encode(op, a, b, c, d): - if op == 0: - f = Fraction(a, b) - Fraction(c, d) - elif op == 1: - f = Fraction(a, b) + Fraction(c, d) - elif op == 2: - f = Fraction(a, b) * Fraction(c, d) - else: - f = Fraction(a, b) / Fraction(c, d) - return f.numerator*1000 + f.denominator - - -def frac_arith_encode_int(op, a, b, x): - if op == 0: - f = Fraction(a, b) - x - elif op == 1: - f = Fraction(a, b) + x - elif op == 2: - f = Fraction(a, b) * x - else: - f = Fraction(a, b) / x - return f.numerator*1000 + f.denominator - - -def frac_arith_encode_int_rev(op, a, b, x): - if op == 0: - f = x - Fraction(a, b) - elif op == 1: - f = x + Fraction(a, b) - elif op == 2: - f = x * Fraction(a, b) - else: - f = x / Fraction(a, b) - return f.numerator*1000 + f.denominator - - -def frac_arith_float(op, a, b, x): - if op == 0: - return Fraction(a, b) - x - elif op == 1: - return Fraction(a, b) + x - elif op == 2: - return Fraction(a, b) * x - else: - return Fraction(a, b) / x - - -def frac_arith_float_rev(op, a, b, x): - if op == 0: - return x - Fraction(a, b) - elif op == 1: - return x + Fraction(a, b) - elif op == 2: - return x * Fraction(a, b) - else: - return x / Fraction(a, b) - - -class CodeGenCase(unittest.TestCase): - def test_frac_simplify(self): - simplify_encode_c = CompiledFunction( - simplify_encode, {"a": base_types.VInt(), "b": base_types.VInt()}) - for a in _test_range(): - for b in _test_range(): - self.assertEqual( - simplify_encode_c(a, b), simplify_encode(a, b)) - - def _test_frac_arith(self, op): - frac_arith_encode_c = CompiledFunction( - frac_arith_encode, { - "op": base_types.VInt(), - "a": base_types.VInt(), "b": base_types.VInt(), - "c": base_types.VInt(), "d": base_types.VInt()}) - for a in _test_range(): - for b in _test_range(): - for c in _test_range(): - for d in _test_range(): - self.assertEqual( - frac_arith_encode_c(op, a, b, c, d), - frac_arith_encode(op, a, b, c, d)) - - def test_frac_add(self): - self._test_frac_arith(0) - - def test_frac_sub(self): - self._test_frac_arith(1) - - def test_frac_mul(self): - self._test_frac_arith(2) - - def test_frac_div(self): - self._test_frac_arith(3) - - def _test_frac_arith_int(self, op, rev): - f = frac_arith_encode_int_rev if rev else frac_arith_encode_int - f_c = CompiledFunction(f, { - "op": base_types.VInt(), - "a": base_types.VInt(), "b": base_types.VInt(), - "x": base_types.VInt()}) - for a in _test_range(): - for b in _test_range(): - for x in _test_range(): - self.assertEqual( - f_c(op, a, b, x), - f(op, a, b, x)) - - def test_frac_add_int(self): - self._test_frac_arith_int(0, False) - self._test_frac_arith_int(0, True) - - def test_frac_sub_int(self): - self._test_frac_arith_int(1, False) - self._test_frac_arith_int(1, True) - - def test_frac_mul_int(self): - self._test_frac_arith_int(2, False) - self._test_frac_arith_int(2, True) - - def test_frac_div_int(self): - self._test_frac_arith_int(3, False) - self._test_frac_arith_int(3, True) - - def _test_frac_arith_float(self, op, rev): - f = frac_arith_float_rev if rev else frac_arith_float - f_c = CompiledFunction(f, { - "op": base_types.VInt(), - "a": base_types.VInt(), "b": base_types.VInt(), - "x": base_types.VFloat()}) - for a in _test_range(): - for b in _test_range(): - for x in _test_range(): - self.assertAlmostEqual( - f_c(op, a, b, x/2), - f(op, a, b, x/2)) - - def test_frac_add_float(self): - self._test_frac_arith_float(0, False) - self._test_frac_arith_float(0, True) - - def test_frac_sub_float(self): - self._test_frac_arith_float(1, False) - self._test_frac_arith_float(1, True) - - def test_frac_mul_float(self): - self._test_frac_arith_float(2, False) - self._test_frac_arith_float(2, True) - - def test_frac_div_float(self): - self._test_frac_arith_float(3, False) - self._test_frac_arith_float(3, True) diff --git a/artiq/py2llvm_old/transforms/inline.py b/artiq/py2llvm_old/transforms/inline.py deleted file mode 100644 index 4d444bbe8..000000000 --- a/artiq/py2llvm_old/transforms/inline.py +++ /dev/null @@ -1,548 +0,0 @@ -import inspect -import textwrap -import ast -import types -import builtins -from fractions import Fraction -from collections import OrderedDict -from functools import partial -from itertools import zip_longest, chain - -from artiq.language import core as core_language -from artiq.language import units -from artiq.transforms.tools import * - - -def new_mangled_name(in_use_names, name): - mangled_name = name - i = 2 - while mangled_name in in_use_names: - mangled_name = name + str(i) - i += 1 - in_use_names.add(mangled_name) - return mangled_name - - -class MangledName: - def __init__(self, s): - self.s = s - - -class AttributeInfo: - def __init__(self, obj, mangled_name, read_write): - self.obj = obj - self.mangled_name = mangled_name - self.read_write = read_write - - -def is_inlinable(core, func): - if hasattr(func, "k_function_info"): - if func.k_function_info.core_name == "": - return True # portable function - if getattr(func.__self__, func.k_function_info.core_name) is core: - return True # kernel function for the same core device - return False - - -class GlobalNamespace: - def __init__(self, func): - self.func_gd = inspect.getmodule(func).__dict__ - - def __getitem__(self, item): - try: - return self.func_gd[item] - except KeyError: - return getattr(builtins, item) - - -class UndefinedArg: - pass - - -def get_function_args(func_args, func_tr, args, kwargs): - # OrderedDict prevents non-determinism in argument init - r = OrderedDict() - - # Process positional arguments. Any missing positional argument values - # are set to UndefinedArg. - for arg, arg_value in zip_longest(func_args.args, args, - fillvalue=UndefinedArg): - if arg is UndefinedArg: - raise TypeError("Got too many positional arguments") - if arg.arg in r: - raise SyntaxError("Duplicate argument '{}' in function definition" - .format(arg.arg)) - r[arg.arg] = arg_value - - # Process keyword arguments. Any missing keyword-only argument values - # are set to UndefinedArg. - valid_arg_names = {arg.arg for arg in - chain(func_args.args, func_args.kwonlyargs)} - for arg in func_args.kwonlyargs: - if arg.arg in r: - raise SyntaxError("Duplicate argument '{}' in function definition" - .format(arg.arg)) - r[arg.arg] = UndefinedArg - for arg_name, arg_value in kwargs.items(): - if arg_name not in valid_arg_names: - raise TypeError("Got unexpected keyword argument '{}'" - .format(arg_name)) - if r[arg_name] is not UndefinedArg: - raise TypeError("Got multiple values for argument '{}'" - .format(arg_name)) - r[arg_name] = arg_value - - # Replace any UndefinedArg positional arguments with the default value, - # when provided. - for arg, default in zip(func_args.args[-len(func_args.defaults):], - func_args.defaults): - if r[arg.arg] is UndefinedArg: - r[arg.arg] = func_tr.code_visit(default) - # Same with keyword-only arguments. - for arg, default in zip(func_args.kwonlyargs, func_args.kw_defaults): - if default is not None and r[arg.arg] is UndefinedArg: - r[arg.arg] = func_tr.code_visit(default) - - # Check that no argument was left undefined. - missing_arguments = ["'"+arg+"'" for arg, value in r.items() - if value is UndefinedArg] - if missing_arguments: - raise TypeError("Missing argument(s): " + " ".join(missing_arguments)) - - return r - - -# args/kwargs can contain values or AST nodes -def get_inline(core, attribute_namespace, in_use_names, retval_name, mappers, - func, args, kwargs): - global_namespace = GlobalNamespace(func) - func_tr = Function(core, - global_namespace, attribute_namespace, in_use_names, - retval_name, mappers) - func_def = ast.parse(textwrap.dedent(inspect.getsource(func))).body[0] - - # Initialize arguments. - # The local namespace is empty so code_visit will always resolve - # using the global namespace. - arg_init = [] - arg_name_map = [] - arg_dict = get_function_args(func_def.args, func_tr, args, kwargs) - for arg_name, arg_value in arg_dict.items(): - if isinstance(arg_value, ast.AST): - value = arg_value - else: - try: - value = ast.copy_location(value_to_ast(arg_value), func_def) - except NotASTRepresentable: - value = None - if value is None: - # static object - func_tr.local_namespace[arg_name] = arg_value - else: - # set parameter value with "name = value" - # assignment at beginning of function - new_name = new_mangled_name(in_use_names, arg_name) - arg_name_map.append((arg_name, new_name)) - target = ast.copy_location(ast.Name(new_name, ast.Store()), - func_def) - assign = ast.copy_location(ast.Assign([target], value), - func_def) - arg_init.append(assign) - # Commit arguments to the local namespace at the end to handle cases - # such as f(x, y=x) (for the default value of y, x must be resolved - # using the global namespace). - for arg_name, mangled_name in arg_name_map: - func_tr.local_namespace[arg_name] = MangledName(mangled_name) - - func_def = func_tr.code_visit(func_def) - func_def.body[0:0] = arg_init - return func_def - - -class Function: - def __init__(self, core, - global_namespace, attribute_namespace, in_use_names, - retval_name, mappers): - # The core device on which this function is executing. - self.core = core - - # Local and global namespaces: - # original name -> MangledName or static object - self.local_namespace = dict() - self.global_namespace = global_namespace - - # (id(static object), attribute) -> AttributeInfo - self.attribute_namespace = attribute_namespace - - # All names currently in use, in the namespace of the combined - # function. - # When creating a name for a new object, check that it is not - # already in this set. - self.in_use_names = in_use_names - - # Name of the variable to store the return value to, or None - # to keep the return statement. - self.retval_name = retval_name - - # Host object mappers, for RPC and exception numbers - self.mappers = mappers - - self._insertion_point = None - - # This is ast.NodeVisitor/NodeTransformer from CPython, modified - # to add code_ prefix. - def code_visit(self, node): - method = "code_visit_" + node.__class__.__name__ - visitor = getattr(self, method, self.code_generic_visit) - return visitor(node) - - # This is ast.NodeTransformer.generic_visit from CPython, modified - # to update self._insertion_point. - def code_generic_visit(self, node, exclude_fields=set()): - for field, old_value in ast.iter_fields(node): - if field in exclude_fields: - continue - old_value = getattr(node, field, None) - if isinstance(old_value, list): - prev_insertion_point = self._insertion_point - new_values = [] - if field in ("body", "orelse", "finalbody"): - self._insertion_point = new_values - for value in old_value: - if isinstance(value, ast.AST): - value = self.code_visit(value) - if value is None: - continue - elif not isinstance(value, ast.AST): - new_values.extend(value) - continue - new_values.append(value) - old_value[:] = new_values - self._insertion_point = prev_insertion_point - elif isinstance(old_value, ast.AST): - new_node = self.code_visit(old_value) - if new_node is None: - delattr(node, field) - else: - setattr(node, field, new_node) - return node - - def code_visit_Name(self, node): - if isinstance(node.ctx, ast.Store): - if (node.id in self.local_namespace - and isinstance(self.local_namespace[node.id], - MangledName)): - new_name = self.local_namespace[node.id].s - else: - new_name = new_mangled_name(self.in_use_names, node.id) - self.local_namespace[node.id] = MangledName(new_name) - node.id = new_name - return node - else: - try: - obj = self.local_namespace[node.id] - except KeyError: - try: - obj = self.global_namespace[node.id] - except KeyError: - raise NameError("name '{}' is not defined".format(node.id)) - if isinstance(obj, MangledName): - node.id = obj.s - return node - else: - try: - return value_to_ast(obj) - except NotASTRepresentable: - raise NotImplementedError( - "Static object cannot be used here") - - def code_visit_Attribute(self, node): - # There are two cases of attributes: - # 1. static object attributes, e.g. self.foo - # 2. dynamic expression attributes, e.g. - # (Fraction(1, 2) + x).numerator - # Static object resolution has no side effects so we try it first. - try: - obj = self.static_visit(node.value) - except: - self.code_generic_visit(node) - return node - else: - key = (id(obj), node.attr) - try: - attr_info = self.attribute_namespace[key] - except KeyError: - new_name = new_mangled_name(self.in_use_names, node.attr) - attr_info = AttributeInfo(obj, new_name, False) - self.attribute_namespace[key] = attr_info - if isinstance(node.ctx, ast.Store): - attr_info.read_write = True - return ast.copy_location( - ast.Name(attr_info.mangled_name, node.ctx), - node) - - def code_visit_Call(self, node): - func = self.static_visit(node.func) - node.args = [self.code_visit(arg) for arg in node.args] - for kw in node.keywords: - kw.value = self.code_visit(kw.value) - - if is_embeddable(func): - node.func = ast.copy_location( - ast.Name(func.__name__, ast.Load()), - node) - return node - elif is_inlinable(self.core, func): - retval_name = func.k_function_info.k_function.__name__ + "_return" - retval_name_m = new_mangled_name(self.in_use_names, retval_name) - args = [func.__self__] + node.args - kwargs = {kw.arg: kw.value for kw in node.keywords} - inlined = get_inline(self.core, - self.attribute_namespace, self.in_use_names, - retval_name_m, self.mappers, - func.k_function_info.k_function, - args, kwargs) - seq = ast.copy_location( - ast.With( - items=[ast.withitem(context_expr=ast.Name(id="sequential", - ctx=ast.Load()), - optional_vars=None)], - body=inlined.body), - node) - self._insertion_point.append(seq) - return ast.copy_location(ast.Name(retval_name_m, ast.Load()), - node) - else: - arg1 = ast.copy_location(ast.Str("rpc"), node) - arg2 = ast.copy_location( - value_to_ast(self.mappers.rpc.encode(func)), node) - node.args[0:0] = [arg1, arg2] - node.func = ast.copy_location( - ast.Name("syscall", ast.Load()), node) - return node - - def code_visit_Return(self, node): - self.code_generic_visit(node) - if self.retval_name is None: - return node - else: - return ast.copy_location( - ast.Assign(targets=[ast.Name(self.retval_name, ast.Store())], - value=node.value), - node) - - def code_visit_Expr(self, node): - if isinstance(node.value, ast.Str): - # Strip docstrings. This also removes strings appearing in the - # middle of the code, but they are nops. - return None - self.code_generic_visit(node) - if isinstance(node.value, ast.Name): - # Remove Expr nodes that contain only a name, likely due to - # function call inlining. Such nodes that were originally in the - # code are also removed, but this does not affect the semantics of - # the code as they are nops. - return None - else: - return node - - def encode_exception(self, e): - exception_class = self.static_visit(e) - if not inspect.isclass(exception_class): - raise NotImplementedError("Exception type must be a class") - if issubclass(exception_class, core_language.RuntimeException): - exception_id = exception_class.eid - else: - exception_id = self.mappers.exception.encode(exception_class) - return ast.copy_location( - ast.Call(func=ast.Name("EncodedException", ast.Load()), - args=[value_to_ast(exception_id)], keywords=[]), - e) - - def code_visit_Raise(self, node): - if node.cause is not None: - raise NotImplementedError("Exception causes are not supported") - if node.exc is not None: - node.exc = self.encode_exception(node.exc) - return node - - def code_visit_ExceptHandler(self, node): - if node.name is not None: - raise NotImplementedError("'as target' is not supported") - if node.type is not None: - if isinstance(node.type, ast.Tuple): - node.type.elts = [self.encode_exception(e) - for e in node.type.elts] - else: - node.type = self.encode_exception(node.type) - self.code_generic_visit(node) - return node - - def get_user_ctxm(self, context_expr): - try: - ctxm = self.static_visit(context_expr) - except: - # this also catches watchdog() - return None - else: - if (ctxm is core_language.sequential - or ctxm is core_language.parallel): - return None - return ctxm - - def code_visit_With(self, node): - if len(node.items) != 1: - raise NotImplementedError - item = node.items[0] - if item.optional_vars is not None: - raise NotImplementedError - ctxm = self.get_user_ctxm(item.context_expr) - if ctxm is None: - self.code_generic_visit(node) - return node - - # user context manager - self.code_generic_visit(node, {"items"}) - if (not hasattr(ctxm, "__enter__") - or not hasattr(ctxm.__enter__, "k_function_info")): - raise NotImplementedError - enter = get_inline(self.core, - self.attribute_namespace, self.in_use_names, - None, self.mappers, - ctxm.__enter__.k_function_info.k_function, - [ctxm], dict()) - if (not hasattr(ctxm, "__exit__") - or not hasattr(ctxm.__exit__, "k_function_info")): - raise NotImplementedError - exit = get_inline(self.core, - self.attribute_namespace, self.in_use_names, - None, self.mappers, - ctxm.__exit__.k_function_info.k_function, - [ctxm, None, None, None], dict()) - try_stmt = ast.copy_location( - ast.Try(body=node.body, - handlers=[], - orelse=[], - finalbody=exit.body), node) - return ast.copy_location( - ast.With( - items=[ast.withitem(context_expr=ast.Name(id="sequential", - ctx=ast.Load()), - optional_vars=None)], - body=enter.body + [try_stmt]), - node) - - def code_visit_FunctionDef(self, node): - node.args = ast.arguments(args=[], vararg=None, kwonlyargs=[], - kw_defaults=[], kwarg=None, defaults=[]) - node.decorator_list = [] - self.code_generic_visit(node) - return node - - def static_visit(self, node): - method = "static_visit_" + node.__class__.__name__ - visitor = getattr(self, method) - return visitor(node) - - def static_visit_Name(self, node): - try: - obj = self.local_namespace[node.id] - except KeyError: - try: - obj = self.global_namespace[node.id] - except KeyError: - raise NameError("name '{}' is not defined".format(node.id)) - if isinstance(obj, MangledName): - raise NotImplementedError( - "Only a static object can be used here") - return obj - - def static_visit_Attribute(self, node): - value = self.static_visit(node.value) - return getattr(value, node.attr) - - -class HostObjectMapper: - def __init__(self, first_encoding=0): - self._next_encoding = first_encoding - # id(object) -> (encoding, object) - # this format is required to support non-hashable host objects. - self._d = dict() - - def encode(self, obj): - try: - return self._d[id(obj)][0] - except KeyError: - encoding = self._next_encoding - self._d[id(obj)] = (encoding, obj) - self._next_encoding += 1 - return encoding - - def get_map(self): - return {encoding: obj for i, (encoding, obj) in self._d.items()} - - -def get_attr_init(attribute_namespace, loc_node): - attr_init = [] - for (_, attr), attr_info in attribute_namespace.items(): - if hasattr(attr_info.obj, attr): - value = getattr(attr_info.obj, attr) - if (hasattr(value, "kernel_attr_init") - and not value.kernel_attr_init): - continue - value = ast.copy_location(value_to_ast(value), loc_node) - target = ast.copy_location(ast.Name(attr_info.mangled_name, - ast.Store()), - loc_node) - assign = ast.copy_location(ast.Assign([target], value), - loc_node) - attr_init.append(assign) - return attr_init - - -def get_attr_writeback(attribute_namespace, rpc_mapper, loc_node): - attr_writeback = [] - for (_, attr), attr_info in attribute_namespace.items(): - if attr_info.read_write: - setter = partial(setattr, attr_info.obj, attr) - func = ast.copy_location( - ast.Name("syscall", ast.Load()), loc_node) - arg1 = ast.copy_location(ast.Str("rpc"), loc_node) - arg2 = ast.copy_location( - value_to_ast(rpc_mapper.encode(setter)), loc_node) - arg3 = ast.copy_location( - ast.Name(attr_info.mangled_name, ast.Load()), loc_node) - call = ast.copy_location( - ast.Call(func=func, args=[arg1, arg2, arg3], keywords=[]), - loc_node) - expr = ast.copy_location(ast.Expr(call), loc_node) - attr_writeback.append(expr) - return attr_writeback - - -def inline(core, k_function, k_args, k_kwargs, with_attr_writeback): - # OrderedDict prevents non-determinism in attribute init - attribute_namespace = OrderedDict() - # NOTE: in_use_names will be mutated. Do not mutate embeddable_func_names! - in_use_names = embeddable_func_names | {"sequential", "parallel", - "watchdog"} - mappers = types.SimpleNamespace( - rpc=HostObjectMapper(), - exception=HostObjectMapper(core_language.first_user_eid) - ) - func_def = get_inline( - core=core, - attribute_namespace=attribute_namespace, - in_use_names=in_use_names, - retval_name=None, - mappers=mappers, - func=k_function, - args=k_args, - kwargs=k_kwargs) - - func_def.body[0:0] = get_attr_init(attribute_namespace, func_def) - if with_attr_writeback: - func_def.body += get_attr_writeback(attribute_namespace, mappers.rpc, - func_def) - - return func_def, mappers.rpc.get_map(), mappers.exception.get_map() diff --git a/artiq/py2llvm_old/transforms/interleave.py b/artiq/py2llvm_old/transforms/interleave.py deleted file mode 100644 index 7d1f733ff..000000000 --- a/artiq/py2llvm_old/transforms/interleave.py +++ /dev/null @@ -1,130 +0,0 @@ -import ast -import types - -from artiq.transforms.tools import * - - -# -1 statement duration could not be pre-determined -# 0 statement has no effect on timeline -# >0 statement is a static delay that advances the timeline -# by the given amount -def _get_duration(stmt): - if isinstance(stmt, (ast.Expr, ast.Assign)): - return _get_duration(stmt.value) - elif isinstance(stmt, ast.If): - if (all(_get_duration(s) == 0 for s in stmt.body) - and all(_get_duration(s) == 0 for s in stmt.orelse)): - return 0 - else: - return -1 - elif isinstance(stmt, ast.Try): - if (all(_get_duration(s) == 0 for s in stmt.body) - and all(_get_duration(s) == 0 for s in stmt.orelse) - and all(_get_duration(s) == 0 for s in stmt.finalbody) - and all(_get_duration(s) == 0 for s in handler.body - for handler in stmt.handlers)): - return 0 - else: - return -1 - elif isinstance(stmt, ast.Call): - name = stmt.func.id - assert(name != "delay") - if name == "delay_mu": - try: - da = eval_constant(stmt.args[0]) - except NotConstant: - da = -1 - return da - elif name == "at_mu": - return -1 - else: - return 0 - else: - return 0 - - -def _interleave_timelines(timelines): - r = [] - - current_stmts = [] - for stmts in timelines: - it = iter(stmts) - try: - stmt = next(it) - except StopIteration: - pass - else: - current_stmts.append(types.SimpleNamespace( - delay=_get_duration(stmt), stmt=stmt, it=it)) - - while current_stmts: - dt = min(stmt.delay for stmt in current_stmts) - if dt < 0: - # contains statement(s) with indeterminate duration - return None - if dt > 0: - # advance timeline by dt - for stmt in current_stmts: - stmt.delay -= dt - if stmt.delay == 0: - ref_stmt = stmt.stmt - delay_stmt = ast.copy_location( - ast.Expr(ast.Call( - func=ast.Name("delay_mu", ast.Load()), - args=[value_to_ast(dt)], keywords=[])), - ref_stmt) - r.append(delay_stmt) - else: - for stmt in current_stmts: - if stmt.delay == 0: - r.append(stmt.stmt) - # discard executed statements - exhausted_list = [] - for stmt_i, stmt in enumerate(current_stmts): - if stmt.delay == 0: - try: - stmt.stmt = next(stmt.it) - except StopIteration: - exhausted_list.append(stmt_i) - else: - stmt.delay = _get_duration(stmt.stmt) - for offset, i in enumerate(exhausted_list): - current_stmts.pop(i-offset) - - return r - - -def _interleave_stmts(stmts): - replacements = [] - for stmt_i, stmt in enumerate(stmts): - if isinstance(stmt, (ast.For, ast.While, ast.If)): - _interleave_stmts(stmt.body) - _interleave_stmts(stmt.orelse) - elif isinstance(stmt, ast.Try): - _interleave_stmts(stmt.body) - _interleave_stmts(stmt.orelse) - _interleave_stmts(stmt.finalbody) - for handler in stmt.handlers: - _interleave_stmts(handler.body) - elif isinstance(stmt, ast.With): - btype = stmt.items[0].context_expr.id - if btype == "sequential": - _interleave_stmts(stmt.body) - replacements.append((stmt_i, stmt.body)) - elif btype == "parallel": - timelines = [[s] for s in stmt.body] - for timeline in timelines: - _interleave_stmts(timeline) - merged = _interleave_timelines(timelines) - if merged is not None: - replacements.append((stmt_i, merged)) - else: - raise ValueError("Unknown block type: " + btype) - offset = 0 - for location, new_stmts in replacements: - stmts[offset+location:offset+location+1] = new_stmts - offset += len(new_stmts) - 1 - - -def interleave(func_def): - _interleave_stmts(func_def.body) diff --git a/artiq/py2llvm_old/transforms/quantize_time.py b/artiq/py2llvm_old/transforms/quantize_time.py deleted file mode 100644 index 42e04f564..000000000 --- a/artiq/py2llvm_old/transforms/quantize_time.py +++ /dev/null @@ -1,43 +0,0 @@ - def visit_With(self, node): - self.generic_visit(node) - if (isinstance(node.items[0].context_expr, ast.Call) - and node.items[0].context_expr.func.id == "watchdog"): - - idname = "__watchdog_id_" + str(self.watchdog_id_counter) - self.watchdog_id_counter += 1 - - time = ast.BinOp(left=node.items[0].context_expr.args[0], - op=ast.Mult(), - right=ast.Num(1000)) - time_int = ast.Call( - func=ast.Name("round", ast.Load()), - args=[time], - keywords=[], starargs=None, kwargs=None) - syscall_set = ast.Call( - func=ast.Name("syscall", ast.Load()), - args=[ast.Str("watchdog_set"), time_int], - keywords=[], starargs=None, kwargs=None) - stmt_set = ast.copy_location( - ast.Assign(targets=[ast.Name(idname, ast.Store())], - value=syscall_set), - node) - - syscall_clear = ast.Call( - func=ast.Name("syscall", ast.Load()), - args=[ast.Str("watchdog_clear"), - ast.Name(idname, ast.Load())], - keywords=[], starargs=None, kwargs=None) - stmt_clear = ast.copy_location(ast.Expr(syscall_clear), node) - - node.items[0] = ast.withitem( - context_expr=ast.Name(id="sequential", - ctx=ast.Load()), - optional_vars=None) - node.body = [ - stmt_set, - ast.Try(body=node.body, - handlers=[], - orelse=[], - finalbody=[stmt_clear]) - ] - return node diff --git a/artiq/py2llvm_old/transforms/unroll_loops.py b/artiq/py2llvm_old/transforms/unroll_loops.py deleted file mode 100644 index 1840e7248..000000000 --- a/artiq/py2llvm_old/transforms/unroll_loops.py +++ /dev/null @@ -1,82 +0,0 @@ -import ast -from copy import deepcopy - -from artiq.transforms.tools import eval_ast, value_to_ast - - -def _count_stmts(node): - if isinstance(node, list): - return sum(map(_count_stmts, node)) - elif isinstance(node, ast.With): - return 1 + _count_stmts(node.body) - elif isinstance(node, (ast.For, ast.While, ast.If)): - return 1 + _count_stmts(node.body) + _count_stmts(node.orelse) - elif isinstance(node, ast.Try): - r = 1 + _count_stmts(node.body) \ - + _count_stmts(node.orelse) \ - + _count_stmts(node.finalbody) - for handler in node.handlers: - r += 1 + _count_stmts(handler.body) - return r - else: - return 1 - - -def _loop_breakable(node): - if isinstance(node, list): - return any(map(_loop_breakable, node)) - elif isinstance(node, (ast.Break, ast.Continue)): - return True - elif isinstance(node, ast.With): - return _loop_breakable(node.body) - elif isinstance(node, ast.If): - return _loop_breakable(node.body) or _loop_breakable(node.orelse) - elif isinstance(node, ast.Try): - if (_loop_breakable(node.body) - or _loop_breakable(node.orelse) - or _loop_breakable(node.finalbody)): - return True - for handler in node.handlers: - if _loop_breakable(handler.body): - return True - return False - else: - return False - - -class _LoopUnroller(ast.NodeTransformer): - def __init__(self, limit): - self.limit = limit - - def visit_For(self, node): - self.generic_visit(node) - try: - it = eval_ast(node.iter) - except: - return node - l_it = len(it) - if l_it: - if (not _loop_breakable(node.body) - and l_it*_count_stmts(node.body) < self.limit): - replacement = [] - for i in it: - if not isinstance(i, int): - replacement = None - break - replacement.append(ast.copy_location( - ast.Assign(targets=[node.target], - value=value_to_ast(i)), - node)) - replacement += deepcopy(node.body) - if replacement is not None: - return replacement - else: - return node - else: - return node - else: - return node.orelse - - -def unroll_loops(node, limit): - _LoopUnroller(limit).visit(node)