From 200330a808186f731ab4cc0166b5bf2c4248ef9e Mon Sep 17 00:00:00 2001 From: whitequark Date: Mon, 10 Aug 2015 20:36:39 +0300 Subject: [PATCH] Remove parts of py2llvm that are implemented in the new compiler. --- artiq/py2llvm_old/__init__.py | 6 - artiq/py2llvm_old/ast_body.py | 539 ---------------- artiq/py2llvm_old/base_types.py | 321 ---------- artiq/py2llvm_old/infer_types.py | 75 --- artiq/py2llvm_old/iterators.py | 51 -- artiq/py2llvm_old/lists.py | 72 --- artiq/py2llvm_old/module.py | 62 -- artiq/py2llvm_old/test/py2llvm.py | 169 +++++ artiq/py2llvm_old/tools.py | 5 - artiq/{ => py2llvm_old}/transforms/inline.py | 0 .../transforms/interleave.py | 0 .../transforms/lower_time.py | 0 .../transforms/quantize_time.py | 0 .../transforms/unroll_loops.py | 0 artiq/py2llvm_old/values.py | 94 --- artiq/test/py2llvm.py | 372 ----------- artiq/test/transforms.py | 44 -- artiq/transforms/__init__.py | 0 artiq/transforms/fold_constants.py | 156 ----- artiq/transforms/remove_dead_code.py | 59 -- artiq/transforms/remove_inter_assigns.py | 149 ----- artiq/transforms/tools.py | 141 ---- artiq/transforms/unparse.py | 600 ------------------ 23 files changed, 169 insertions(+), 2746 deletions(-) delete mode 100644 artiq/py2llvm_old/__init__.py delete mode 100644 artiq/py2llvm_old/ast_body.py delete mode 100644 artiq/py2llvm_old/base_types.py delete mode 100644 artiq/py2llvm_old/infer_types.py delete mode 100644 artiq/py2llvm_old/iterators.py delete mode 100644 artiq/py2llvm_old/lists.py delete mode 100644 artiq/py2llvm_old/module.py create mode 100644 artiq/py2llvm_old/test/py2llvm.py delete mode 100644 artiq/py2llvm_old/tools.py rename artiq/{ => py2llvm_old}/transforms/inline.py (100%) rename artiq/{ => py2llvm_old}/transforms/interleave.py (100%) rename artiq/{ => py2llvm_old}/transforms/lower_time.py (100%) rename artiq/{ => py2llvm_old}/transforms/quantize_time.py (100%) rename artiq/{ => py2llvm_old}/transforms/unroll_loops.py (100%) delete mode 100644 artiq/py2llvm_old/values.py delete mode 100644 artiq/test/py2llvm.py delete mode 100644 artiq/test/transforms.py delete mode 100644 artiq/transforms/__init__.py delete mode 100644 artiq/transforms/fold_constants.py delete mode 100644 artiq/transforms/remove_dead_code.py delete mode 100644 artiq/transforms/remove_inter_assigns.py delete mode 100644 artiq/transforms/tools.py delete mode 100644 artiq/transforms/unparse.py diff --git a/artiq/py2llvm_old/__init__.py b/artiq/py2llvm_old/__init__.py deleted file mode 100644 index ebb8a93af..000000000 --- a/artiq/py2llvm_old/__init__.py +++ /dev/null @@ -1,6 +0,0 @@ -from artiq.py2llvm.module import Module - -def get_runtime_binary(runtime, func_def): - module = Module(runtime) - module.compile_function(func_def, dict()) - return module.emit_object() diff --git a/artiq/py2llvm_old/ast_body.py b/artiq/py2llvm_old/ast_body.py deleted file mode 100644 index b310ae78c..000000000 --- a/artiq/py2llvm_old/ast_body.py +++ /dev/null @@ -1,539 +0,0 @@ -from pythonparser import ast - -import llvmlite_or1k.ir as ll - -from artiq.py2llvm import values, base_types, fractions, lists, iterators -from artiq.py2llvm.tools import is_terminated - - -_ast_unops = { - ast.Invert: "o_inv", - ast.Not: "o_not", - ast.UAdd: "o_pos", - ast.USub: "o_neg" -} - -_ast_binops = { - ast.Add: values.operators.add, - ast.Sub: values.operators.sub, - ast.Mult: values.operators.mul, - ast.Div: values.operators.truediv, - ast.FloorDiv: values.operators.floordiv, - ast.Mod: values.operators.mod, - ast.Pow: values.operators.pow, - ast.LShift: values.operators.lshift, - ast.RShift: values.operators.rshift, - ast.BitOr: values.operators.or_, - ast.BitXor: values.operators.xor, - ast.BitAnd: values.operators.and_ -} - -_ast_cmps = { - ast.Eq: values.operators.eq, - ast.NotEq: values.operators.ne, - ast.Lt: values.operators.lt, - ast.LtE: values.operators.le, - ast.Gt: values.operators.gt, - ast.GtE: values.operators.ge -} - - -class Visitor: - def __init__(self, runtime, ns, builder=None): - self.runtime = runtime - self.ns = ns - self.builder = builder - self._break_stack = [] - self._continue_stack = [] - self._active_exception_stack = [] - self._exception_level_stack = [0] - - # builder can be None for visit_expression - def visit_expression(self, node): - method = "_visit_expr_" + node.__class__.__name__ - try: - visitor = getattr(self, method) - except AttributeError: - raise NotImplementedError("Unsupported node '{}' in expression" - .format(node.__class__.__name__)) - return visitor(node) - - def _visit_expr_Name(self, node): - try: - r = self.ns[node.id] - except KeyError: - raise NameError("Name '{}' is not defined".format(node.id)) - return r - - def _visit_expr_NameConstant(self, node): - v = node.value - if v is None: - r = base_types.VNone() - elif isinstance(v, bool): - r = base_types.VBool() - else: - raise NotImplementedError - if self.builder is not None: - r.set_const_value(self.builder, v) - return r - - def _visit_expr_Num(self, node): - n = node.n - if isinstance(n, int): - if abs(n) < 2**31: - r = base_types.VInt() - else: - r = base_types.VInt(64) - elif isinstance(n, float): - r = base_types.VFloat() - else: - raise NotImplementedError - if self.builder is not None: - r.set_const_value(self.builder, n) - return r - - def _visit_expr_UnaryOp(self, node): - value = self.visit_expression(node.operand) - return getattr(value, _ast_unops[type(node.op)])(self.builder) - - def _visit_expr_BinOp(self, node): - return _ast_binops[type(node.op)](self.visit_expression(node.left), - self.visit_expression(node.right), - self.builder) - - def _visit_expr_BoolOp(self, node): - if self.builder is not None: - initial_block = self.builder.basic_block - function = initial_block.function - merge_block = function.append_basic_block("b_merge") - - test_blocks = [] - test_values = [] - for i, value in enumerate(node.values): - if self.builder is not None: - test_block = function.append_basic_block("b_{}_test".format(i)) - test_blocks.append(test_block) - self.builder.position_at_end(test_block) - test_values.append(self.visit_expression(value)) - - result = test_values[0].new() - for value in test_values[1:]: - result.merge(value) - - if self.builder is not None: - self.builder.position_at_end(initial_block) - result.alloca(self.builder, "b_result") - self.builder.branch(test_blocks[0]) - - next_test_blocks = test_blocks[1:] - next_test_blocks.append(None) - for block, next_block, value in zip(test_blocks, - next_test_blocks, - test_values): - self.builder.position_at_end(block) - bval = value.o_bool(self.builder) - result.auto_store(self.builder, - value.auto_load(self.builder)) - if next_block is None: - self.builder.branch(merge_block) - else: - if isinstance(node.op, ast.Or): - self.builder.cbranch(bval.auto_load(self.builder), - merge_block, - next_block) - elif isinstance(node.op, ast.And): - self.builder.cbranch(bval.auto_load(self.builder), - next_block, - merge_block) - else: - raise NotImplementedError - self.builder.position_at_end(merge_block) - - return result - - def _visit_expr_Compare(self, node): - comparisons = [] - old_comparator = self.visit_expression(node.left) - for op, comparator_a in zip(node.ops, node.comparators): - comparator = self.visit_expression(comparator_a) - comparison = _ast_cmps[type(op)](old_comparator, comparator, - self.builder) - comparisons.append(comparison) - old_comparator = comparator - r = comparisons[0] - for comparison in comparisons[1:]: - r = values.operators.and_(r, comparison) - return r - - def _visit_expr_Call(self, node): - fn = node.func.id - if fn in {"bool", "int", "int64", "round", "round64", "float", "len"}: - value = self.visit_expression(node.args[0]) - return getattr(value, "o_" + fn)(self.builder) - elif fn == "Fraction": - r = fractions.VFraction() - if self.builder is not None: - numerator = self.visit_expression(node.args[0]) - denominator = self.visit_expression(node.args[1]) - r.set_value_nd(self.builder, numerator, denominator) - return r - elif fn == "range": - return iterators.IRange( - self.builder, - [self.visit_expression(arg) for arg in node.args]) - elif fn == "syscall": - return self.runtime.build_syscall( - node.args[0].s, - [self.visit_expression(expr) for expr in node.args[1:]], - self.builder) - else: - raise NameError("Function '{}' is not defined".format(fn)) - - def _visit_expr_Attribute(self, node): - value = self.visit_expression(node.value) - return value.o_getattr(node.attr, self.builder) - - def _visit_expr_List(self, node): - elts = [self.visit_expression(elt) for elt in node.elts] - if elts: - el_type = elts[0].new() - for elt in elts[1:]: - el_type.merge(elt) - else: - el_type = base_types.VNone() - count = len(elts) - r = lists.VList(el_type, count) - r.elts = elts - return r - - def _visit_expr_ListComp(self, node): - if len(node.generators) != 1: - raise NotImplementedError - generator = node.generators[0] - if not isinstance(generator, ast.comprehension): - raise NotImplementedError - if not isinstance(generator.target, ast.Name): - raise NotImplementedError - target = generator.target.id - if not isinstance(generator.iter, ast.Call): - raise NotImplementedError - if not isinstance(generator.iter.func, ast.Name): - raise NotImplementedError - if generator.iter.func.id != "range": - raise NotImplementedError - if len(generator.iter.args) != 1: - raise NotImplementedError - if not isinstance(generator.iter.args[0], ast.Num): - raise NotImplementedError - count = generator.iter.args[0].n - - # Prevent incorrect use of the generator target, if it is defined in - # the local function namespace. - if target in self.ns: - old_target_val = self.ns[target] - del self.ns[target] - else: - old_target_val = None - elt = self.visit_expression(node.elt) - if old_target_val is not None: - self.ns[target] = old_target_val - - el_type = elt.new() - r = lists.VList(el_type, count) - r.elt = elt - return r - - def _visit_expr_Subscript(self, node): - value = self.visit_expression(node.value) - if isinstance(node.slice, ast.Index): - index = self.visit_expression(node.slice.value) - else: - raise NotImplementedError - return value.o_subscript(index, self.builder) - - def visit_statements(self, stmts): - for node in stmts: - node_type = node.__class__.__name__ - method = "_visit_stmt_" + node_type - try: - visitor = getattr(self, method) - except AttributeError: - raise NotImplementedError("Unsupported node '{}' in statement" - .format(node_type)) - visitor(node) - if node_type in ("Return", "Break", "Continue"): - break - - def _bb_terminated(self): - return is_terminated(self.builder.basic_block) - - def _visit_stmt_Assign(self, node): - val = self.visit_expression(node.value) - if isinstance(node.value, ast.List): - if len(node.targets) > 1: - raise NotImplementedError - target = self.visit_expression(node.targets[0]) - target.set_count(self.builder, val.alloc_count) - for i, elt in enumerate(val.elts): - idx = base_types.VInt() - idx.set_const_value(self.builder, i) - target.o_subscript(idx, self.builder).set_value(self.builder, - elt) - elif isinstance(node.value, ast.ListComp): - if len(node.targets) > 1: - raise NotImplementedError - target = self.visit_expression(node.targets[0]) - target.set_count(self.builder, val.alloc_count) - - i = base_types.VInt() - i.alloca(self.builder) - i.auto_store(self.builder, ll.Constant(ll.IntType(32), 0)) - - function = self.builder.basic_block.function - copy_block = function.append_basic_block("ai_copy") - end_block = function.append_basic_block("ai_end") - self.builder.branch(copy_block) - - self.builder.position_at_end(copy_block) - target.o_subscript(i, self.builder).set_value(self.builder, - val.elt) - i.auto_store(self.builder, self.builder.add( - i.auto_load(self.builder), - ll.Constant(ll.IntType(32), 1))) - cont = self.builder.icmp_signed( - "<", i.auto_load(self.builder), - ll.Constant(ll.IntType(32), val.alloc_count)) - self.builder.cbranch(cont, copy_block, end_block) - - self.builder.position_at_end(end_block) - else: - for target in node.targets: - target = self.visit_expression(target) - target.set_value(self.builder, val) - - def _visit_stmt_AugAssign(self, node): - target = self.visit_expression(node.target) - right = self.visit_expression(node.value) - val = _ast_binops[type(node.op)](target, right, self.builder) - target.set_value(self.builder, val) - - def _visit_stmt_Expr(self, node): - self.visit_expression(node.value) - - def _visit_stmt_If(self, node): - function = self.builder.basic_block.function - then_block = function.append_basic_block("i_then") - else_block = function.append_basic_block("i_else") - merge_block = function.append_basic_block("i_merge") - - condition = self.visit_expression(node.test).o_bool(self.builder) - self.builder.cbranch(condition.auto_load(self.builder), - then_block, else_block) - - self.builder.position_at_end(then_block) - self.visit_statements(node.body) - if not self._bb_terminated(): - self.builder.branch(merge_block) - - self.builder.position_at_end(else_block) - self.visit_statements(node.orelse) - if not self._bb_terminated(): - self.builder.branch(merge_block) - - self.builder.position_at_end(merge_block) - - def _enter_loop_body(self, break_block, continue_block): - self._break_stack.append(break_block) - self._continue_stack.append(continue_block) - self._exception_level_stack.append(0) - - def _leave_loop_body(self): - self._exception_level_stack.pop() - self._continue_stack.pop() - self._break_stack.pop() - - def _visit_stmt_While(self, node): - function = self.builder.basic_block.function - - body_block = function.append_basic_block("w_body") - else_block = function.append_basic_block("w_else") - condition = self.visit_expression(node.test).o_bool(self.builder) - self.builder.cbranch( - condition.auto_load(self.builder), body_block, else_block) - - continue_block = function.append_basic_block("w_continue") - merge_block = function.append_basic_block("w_merge") - self.builder.position_at_end(body_block) - self._enter_loop_body(merge_block, continue_block) - self.visit_statements(node.body) - self._leave_loop_body() - if not self._bb_terminated(): - self.builder.branch(continue_block) - - self.builder.position_at_end(continue_block) - condition = self.visit_expression(node.test).o_bool(self.builder) - self.builder.cbranch( - condition.auto_load(self.builder), body_block, merge_block) - - self.builder.position_at_end(else_block) - self.visit_statements(node.orelse) - if not self._bb_terminated(): - self.builder.branch(merge_block) - - self.builder.position_at_end(merge_block) - - def _visit_stmt_For(self, node): - function = self.builder.basic_block.function - - it = self.visit_expression(node.iter) - target = self.visit_expression(node.target) - itval = it.get_value_ptr() - - body_block = function.append_basic_block("f_body") - else_block = function.append_basic_block("f_else") - cont = it.o_next(self.builder) - self.builder.cbranch( - cont.auto_load(self.builder), body_block, else_block) - - continue_block = function.append_basic_block("f_continue") - merge_block = function.append_basic_block("f_merge") - self.builder.position_at_end(body_block) - target.set_value(self.builder, itval) - self._enter_loop_body(merge_block, continue_block) - self.visit_statements(node.body) - self._leave_loop_body() - if not self._bb_terminated(): - self.builder.branch(continue_block) - - self.builder.position_at_end(continue_block) - cont = it.o_next(self.builder) - self.builder.cbranch( - cont.auto_load(self.builder), body_block, merge_block) - - self.builder.position_at_end(else_block) - self.visit_statements(node.orelse) - if not self._bb_terminated(): - self.builder.branch(merge_block) - - self.builder.position_at_end(merge_block) - - def _break_loop_body(self, target_block): - exception_levels = self._exception_level_stack[-1] - if exception_levels: - self.runtime.build_pop(self.builder, exception_levels) - self.builder.branch(target_block) - - def _visit_stmt_Break(self, node): - self._break_loop_body(self._break_stack[-1]) - - def _visit_stmt_Continue(self, node): - self._break_loop_body(self._continue_stack[-1]) - - def _visit_stmt_Return(self, node): - if node.value is None: - val = base_types.VNone() - else: - val = self.visit_expression(node.value) - exception_levels = sum(self._exception_level_stack) - if exception_levels: - self.runtime.build_pop(self.builder, exception_levels) - if isinstance(val, base_types.VNone): - self.builder.ret_void() - else: - self.builder.ret(val.auto_load(self.builder)) - - def _visit_stmt_Pass(self, node): - pass - - def _visit_stmt_Raise(self, node): - if self._active_exception_stack: - finally_block, propagate, propagate_eid = ( - self._active_exception_stack[-1]) - self.builder.store(ll.Constant(ll.IntType(1), 1), propagate) - if node.exc is not None: - eid = ll.Constant(ll.IntType(32), node.exc.args[0].n) - self.builder.store(eid, propagate_eid) - self.builder.branch(finally_block) - else: - eid = ll.Constant(ll.IntType(32), node.exc.args[0].n) - self.runtime.build_raise(self.builder, eid) - - def _handle_exception(self, function, finally_block, - propagate, propagate_eid, handlers): - eid = self.runtime.build_getid(self.builder) - self._active_exception_stack.append( - (finally_block, propagate, propagate_eid)) - self.builder.store(ll.Constant(ll.IntType(1), 1), propagate) - self.builder.store(eid, propagate_eid) - - for handler in handlers: - handled_exc_block = function.append_basic_block("try_exc_h") - cont_exc_block = function.append_basic_block("try_exc_c") - if handler.type is None: - self.builder.branch(handled_exc_block) - else: - if isinstance(handler.type, ast.Tuple): - match = self.builder.icmp_signed( - "==", eid, - ll.Constant(ll.IntType(32), - handler.type.elts[0].args[0].n)) - for elt in handler.type.elts[1:]: - match = self.builder.or_( - match, - self.builder.icmp_signed( - "==", eid, - ll.Constant(ll.IntType(32), elt.args[0].n))) - else: - match = self.builder.icmp_signed( - "==", eid, - ll.Constant(ll.IntType(32), handler.type.args[0].n)) - self.builder.cbranch(match, handled_exc_block, cont_exc_block) - self.builder.position_at_end(handled_exc_block) - self.builder.store(ll.Constant(ll.IntType(1), 0), propagate) - self.visit_statements(handler.body) - if not self._bb_terminated(): - self.builder.branch(finally_block) - self.builder.position_at_end(cont_exc_block) - self.builder.branch(finally_block) - - self._active_exception_stack.pop() - - def _visit_stmt_Try(self, node): - function = self.builder.basic_block.function - noexc_block = function.append_basic_block("try_noexc") - exc_block = function.append_basic_block("try_exc") - finally_block = function.append_basic_block("try_finally") - - propagate = self.builder.alloca(ll.IntType(1), - name="propagate") - self.builder.store(ll.Constant(ll.IntType(1), 0), propagate) - propagate_eid = self.builder.alloca(ll.IntType(32), - name="propagate_eid") - exception_occured = self.runtime.build_catch(self.builder) - self.builder.cbranch(exception_occured, exc_block, noexc_block) - - self.builder.position_at_end(noexc_block) - self._exception_level_stack[-1] += 1 - self.visit_statements(node.body) - self._exception_level_stack[-1] -= 1 - if not self._bb_terminated(): - self.runtime.build_pop(self.builder, 1) - self.visit_statements(node.orelse) - if not self._bb_terminated(): - self.builder.branch(finally_block) - self.builder.position_at_end(exc_block) - self._handle_exception(function, finally_block, - propagate, propagate_eid, node.handlers) - - propagate_block = function.append_basic_block("try_propagate") - merge_block = function.append_basic_block("try_merge") - self.builder.position_at_end(finally_block) - self.visit_statements(node.finalbody) - if not self._bb_terminated(): - self.builder.cbranch( - self.builder.load(propagate), - propagate_block, merge_block) - self.builder.position_at_end(propagate_block) - self.runtime.build_raise(self.builder, self.builder.load(propagate_eid)) - self.builder.branch(merge_block) - self.builder.position_at_end(merge_block) diff --git a/artiq/py2llvm_old/base_types.py b/artiq/py2llvm_old/base_types.py deleted file mode 100644 index 3ef472984..000000000 --- a/artiq/py2llvm_old/base_types.py +++ /dev/null @@ -1,321 +0,0 @@ -import llvmlite_or1k.ir as ll - -from artiq.py2llvm.values import VGeneric - - -class VNone(VGeneric): - def get_llvm_type(self): - return ll.VoidType() - - def alloca(self, builder, name): - pass - - def set_const_value(self, builder, v): - assert v is None - - def set_value(self, builder, other): - if not isinstance(other, VNone): - raise TypeError - - def o_bool(self, builder): - r = VBool() - if builder is not None: - r.set_const_value(builder, False) - return r - - def o_not(self, builder): - r = VBool() - if builder is not None: - r.set_const_value(builder, True) - return r - - -class VInt(VGeneric): - def __init__(self, nbits=32): - VGeneric.__init__(self) - self.nbits = nbits - - def get_llvm_type(self): - return ll.IntType(self.nbits) - - def __repr__(self): - return "".format(self.nbits) - - def same_type(self, other): - return isinstance(other, VInt) and other.nbits == self.nbits - - def merge(self, other): - if isinstance(other, VInt) and not isinstance(other, VBool): - if other.nbits > self.nbits: - self.nbits = other.nbits - else: - raise TypeError("Incompatible types: {} and {}" - .format(repr(self), repr(other))) - - def set_value(self, builder, n): - self.auto_store( - builder, n.o_intx(self.nbits, builder).auto_load(builder)) - - def set_const_value(self, builder, n): - self.auto_store(builder, ll.Constant(self.get_llvm_type(), n)) - - def o_bool(self, builder, inv=False): - r = VBool() - if builder is not None: - r.auto_store( - builder, builder.icmp_signed( - "==" if inv else "!=", - self.auto_load(builder), - ll.Constant(self.get_llvm_type(), 0))) - return r - - def o_float(self, builder): - r = VFloat() - if builder is not None: - if isinstance(self, VBool): - cf = builder.uitofp - else: - cf = builder.sitofp - r.auto_store(builder, cf(self.auto_load(builder), - r.get_llvm_type())) - return r - - def o_not(self, builder): - return self.o_bool(builder, inv=True) - - def o_neg(self, builder): - r = VInt(self.nbits) - if builder is not None: - r.auto_store( - builder, builder.mul( - self.auto_load(builder), - ll.Constant(self.get_llvm_type(), -1))) - return r - - def o_intx(self, target_bits, builder): - r = VInt(target_bits) - if builder is not None: - if self.nbits == target_bits: - r.auto_store( - builder, self.auto_load(builder)) - if self.nbits > target_bits: - r.auto_store( - builder, builder.trunc(self.auto_load(builder), - r.get_llvm_type())) - if self.nbits < target_bits: - if isinstance(self, VBool): - ef = builder.zext - else: - ef = builder.sext - r.auto_store( - builder, ef(self.auto_load(builder), - r.get_llvm_type())) - return r - o_roundx = o_intx - - def o_truediv(self, other, builder): - if isinstance(other, VInt): - left = self.o_float(builder) - right = other.o_float(builder) - return left.o_truediv(right, builder) - else: - return NotImplemented - -def _make_vint_binop_method(builder_name, bool_op): - def binop_method(self, other, builder): - if isinstance(other, VInt): - target_bits = max(self.nbits, other.nbits) - if not bool_op and target_bits == 1: - target_bits = 32 - if bool_op and target_bits == 1: - r = VBool() - else: - r = VInt(target_bits) - if builder is not None: - left = self.o_intx(target_bits, builder) - right = other.o_intx(target_bits, builder) - bf = getattr(builder, builder_name) - r.auto_store( - builder, bf(left.auto_load(builder), - right.auto_load(builder))) - return r - else: - return NotImplemented - return binop_method - -for _method_name, _builder_name, _bool_op in (("o_add", "add", False), - ("o_sub", "sub", False), - ("o_mul", "mul", False), - ("o_floordiv", "sdiv", False), - ("o_mod", "srem", False), - ("o_and", "and_", True), - ("o_xor", "xor", True), - ("o_or", "or_", True)): - setattr(VInt, _method_name, _make_vint_binop_method(_builder_name, _bool_op)) - - -def _make_vint_cmp_method(icmp_val): - def cmp_method(self, other, builder): - if isinstance(other, VInt): - r = VBool() - if builder is not None: - target_bits = max(self.nbits, other.nbits) - left = self.o_intx(target_bits, builder) - right = other.o_intx(target_bits, builder) - r.auto_store( - builder, - builder.icmp_signed( - icmp_val, left.auto_load(builder), - right.auto_load(builder))) - return r - else: - return NotImplemented - return cmp_method - -for _method_name, _icmp_val in (("o_eq", "=="), - ("o_ne", "!="), - ("o_lt", "<"), - ("o_le", "<="), - ("o_gt", ">"), - ("o_ge", ">=")): - setattr(VInt, _method_name, _make_vint_cmp_method(_icmp_val)) - - -class VBool(VInt): - def __init__(self): - VInt.__init__(self, 1) - - __repr__ = VGeneric.__repr__ - same_type = VGeneric.same_type - merge = VGeneric.merge - - def set_const_value(self, builder, b): - VInt.set_const_value(self, builder, int(b)) - - -class VFloat(VGeneric): - def get_llvm_type(self): - return ll.DoubleType() - - def set_value(self, builder, v): - if not isinstance(v, VFloat): - raise TypeError - self.auto_store(builder, v.auto_load(builder)) - - def set_const_value(self, builder, n): - self.auto_store(builder, ll.Constant(self.get_llvm_type(), n)) - - def o_float(self, builder): - r = VFloat() - if builder is not None: - r.auto_store(builder, self.auto_load(builder)) - return r - - def o_bool(self, builder, inv=False): - r = VBool() - if builder is not None: - r.auto_store( - builder, builder.fcmp_ordered( - "==" if inv else "!=", - self.auto_load(builder), - ll.Constant(self.get_llvm_type(), 0.0))) - return r - - def o_not(self, builder): - return self.o_bool(builder, True) - - def o_neg(self, builder): - r = VFloat() - if builder is not None: - r.auto_store( - builder, builder.fmul( - self.auto_load(builder), - ll.Constant(self.get_llvm_type(), -1.0))) - return r - - def o_intx(self, target_bits, builder): - r = VInt(target_bits) - if builder is not None: - r.auto_store(builder, builder.fptosi(self.auto_load(builder), - r.get_llvm_type())) - return r - - def o_roundx(self, target_bits, builder): - r = VInt(target_bits) - if builder is not None: - function = builder.basic_block.function - neg_block = function.append_basic_block("fr_neg") - merge_block = function.append_basic_block("fr_merge") - - half = VFloat() - half.alloca(builder, "half") - half.set_const_value(builder, 0.5) - - condition = builder.fcmp_ordered( - "<", - self.auto_load(builder), - ll.Constant(self.get_llvm_type(), 0.0)) - builder.cbranch(condition, neg_block, merge_block) - - builder.position_at_end(neg_block) - half.set_const_value(builder, -0.5) - builder.branch(merge_block) - - builder.position_at_end(merge_block) - s = builder.fadd(self.auto_load(builder), half.auto_load(builder)) - r.auto_store(builder, builder.fptosi(s, r.get_llvm_type())) - return r - - def o_floordiv(self, other, builder): - return self.o_truediv(other, builder).o_int64(builder).o_float(builder) - -def _make_vfloat_binop_method(builder_name, reverse): - def binop_method(self, other, builder): - if not hasattr(other, "o_float"): - return NotImplemented - r = VFloat() - if builder is not None: - left = self.o_float(builder) - right = other.o_float(builder) - if reverse: - left, right = right, left - bf = getattr(builder, builder_name) - r.auto_store( - builder, bf(left.auto_load(builder), - right.auto_load(builder))) - return r - return binop_method - -for _method_name, _builder_name in (("add", "fadd"), - ("sub", "fsub"), - ("mul", "fmul"), - ("truediv", "fdiv")): - setattr(VFloat, "o_" + _method_name, - _make_vfloat_binop_method(_builder_name, False)) - setattr(VFloat, "or_" + _method_name, - _make_vfloat_binop_method(_builder_name, True)) - - -def _make_vfloat_cmp_method(fcmp_val): - def cmp_method(self, other, builder): - if not hasattr(other, "o_float"): - return NotImplemented - r = VBool() - if builder is not None: - left = self.o_float(builder) - right = other.o_float(builder) - r.auto_store( - builder, - builder.fcmp_ordered( - fcmp_val, left.auto_load(builder), - right.auto_load(builder))) - return r - return cmp_method - -for _method_name, _fcmp_val in (("o_eq", "=="), - ("o_ne", "!="), - ("o_lt", "<"), - ("o_le", "<="), - ("o_gt", ">"), - ("o_ge", ">=")): - setattr(VFloat, _method_name, _make_vfloat_cmp_method(_fcmp_val)) diff --git a/artiq/py2llvm_old/infer_types.py b/artiq/py2llvm_old/infer_types.py deleted file mode 100644 index 7de53bab8..000000000 --- a/artiq/py2llvm_old/infer_types.py +++ /dev/null @@ -1,75 +0,0 @@ -import pythonparser.algorithm -from pythonparser import ast -from copy import deepcopy - -from artiq.py2llvm.ast_body import Visitor -from artiq.py2llvm import base_types - - -class _TypeScanner(pythonparser.algorithm.Visitor): - def __init__(self, env, ns): - self.exprv = Visitor(env, ns) - - def _update_target(self, target, val): - ns = self.exprv.ns - if isinstance(target, ast.Name): - if target.id in ns: - ns[target.id].merge(val) - else: - ns[target.id] = deepcopy(val) - elif isinstance(target, ast.Subscript): - target = target.value - levels = 0 - while isinstance(target, ast.Subscript): - target = target.value - levels += 1 - if isinstance(target, ast.Name): - target_value = ns[target.id] - for i in range(levels): - target_value = target_value.o_subscript(None, None) - target_value.merge_subscript(val) - else: - raise NotImplementedError - else: - raise NotImplementedError - - def visit_Assign(self, node): - val = self.exprv.visit_expression(node.value) - for target in node.targets: - self._update_target(target, val) - - def visit_AugAssign(self, node): - val = self.exprv.visit_expression(ast.BinOp( - op=node.op, left=node.target, right=node.value)) - self._update_target(node.target, val) - - def visit_For(self, node): - it = self.exprv.visit_expression(node.iter) - self._update_target(node.target, it.get_value_ptr()) - self.generic_visit(node) - - def visit_Return(self, node): - if node.value is None: - val = base_types.VNone() - else: - val = self.exprv.visit_expression(node.value) - ns = self.exprv.ns - if "return" in ns: - ns["return"].merge(val) - else: - ns["return"] = deepcopy(val) - - -def infer_function_types(env, node, param_types): - ns = deepcopy(param_types) - ts = _TypeScanner(env, ns) - ts.visit(node) - while True: - prev_ns = deepcopy(ns) - ts = _TypeScanner(env, ns) - ts.visit(node) - if all(v.same_type(prev_ns[k]) for k, v in ns.items()): - # no more promotions - completed - if "return" not in ns: - ns["return"] = base_types.VNone() - return ns diff --git a/artiq/py2llvm_old/iterators.py b/artiq/py2llvm_old/iterators.py deleted file mode 100644 index 0e1526319..000000000 --- a/artiq/py2llvm_old/iterators.py +++ /dev/null @@ -1,51 +0,0 @@ -from artiq.py2llvm.values import operators -from artiq.py2llvm.base_types import VInt - -class IRange: - def __init__(self, builder, args): - minimum, step = None, None - if len(args) == 1: - maximum = args[0] - elif len(args) == 2: - minimum, maximum = args - else: - minimum, maximum, step = args - if minimum is None: - minimum = VInt() - if builder is not None: - minimum.set_const_value(builder, 0) - if step is None: - step = VInt() - if builder is not None: - step.set_const_value(builder, 1) - - self._counter = minimum.new() - self._counter.merge(maximum) - self._counter.merge(step) - self._minimum = self._counter.new() - self._maximum = self._counter.new() - self._step = self._counter.new() - - if builder is not None: - self._minimum.alloca(builder, "irange_min") - self._maximum.alloca(builder, "irange_max") - self._step.alloca(builder, "irange_step") - self._counter.alloca(builder, "irange_count") - - self._minimum.set_value(builder, minimum) - self._maximum.set_value(builder, maximum) - self._step.set_value(builder, step) - - counter_init = operators.sub(self._minimum, self._step, builder) - self._counter.set_value(builder, counter_init) - - # must be a pointer value that can be dereferenced anytime - # to get the current value of the iterator - def get_value_ptr(self): - return self._counter - - def o_next(self, builder): - self._counter.set_value( - builder, - operators.add(self._counter, self._step, builder)) - return operators.lt(self._counter, self._maximum, builder) diff --git a/artiq/py2llvm_old/lists.py b/artiq/py2llvm_old/lists.py deleted file mode 100644 index e17ab5348..000000000 --- a/artiq/py2llvm_old/lists.py +++ /dev/null @@ -1,72 +0,0 @@ -import llvmlite_or1k.ir as ll - -from artiq.py2llvm.values import VGeneric -from artiq.py2llvm.base_types import VInt, VNone - - -class VList(VGeneric): - def __init__(self, el_type, alloc_count): - VGeneric.__init__(self) - self.el_type = el_type - self.alloc_count = alloc_count - - def get_llvm_type(self): - count = 0 if self.alloc_count is None else self.alloc_count - if isinstance(self.el_type, VNone): - return ll.LiteralStructType([ll.IntType(32)]) - else: - return ll.LiteralStructType([ - ll.IntType(32), ll.ArrayType(self.el_type.get_llvm_type(), - count)]) - - def __repr__(self): - return "".format( - repr(self.el_type), - "?" if self.alloc_count is None else self.alloc_count) - - def same_type(self, other): - return (isinstance(other, VList) - and self.el_type.same_type(other.el_type)) - - def merge(self, other): - if isinstance(other, VList): - if self.alloc_count: - if other.alloc_count: - self.el_type.merge(other.el_type) - if self.alloc_count < other.alloc_count: - self.alloc_count = other.alloc_count - else: - self.el_type = other.el_type.new() - self.alloc_count = other.alloc_count - else: - raise TypeError("Incompatible types: {} and {}" - .format(repr(self), repr(other))) - - def merge_subscript(self, other): - self.el_type.merge(other) - - def set_count(self, builder, count): - count_ptr = builder.gep(self.llvm_value, [ - ll.Constant(ll.IntType(32), 0), - ll.Constant(ll.IntType(32), 0)]) - builder.store(ll.Constant(ll.IntType(32), count), count_ptr) - - def o_len(self, builder): - r = VInt() - if builder is not None: - count_ptr = builder.gep(self.llvm_value, [ - ll.Constant(ll.IntType(32), 0), - ll.Constant(ll.IntType(32), 0)]) - r.auto_store(builder, builder.load(count_ptr)) - return r - - def o_subscript(self, index, builder): - r = self.el_type.new() - if builder is not None and not isinstance(r, VNone): - index = index.o_int(builder).auto_load(builder) - ssa_r = builder.gep(self.llvm_value, [ - ll.Constant(ll.IntType(32), 0), - ll.Constant(ll.IntType(32), 1), - index]) - r.auto_store(builder, ssa_r) - return r diff --git a/artiq/py2llvm_old/module.py b/artiq/py2llvm_old/module.py deleted file mode 100644 index f4df806e6..000000000 --- a/artiq/py2llvm_old/module.py +++ /dev/null @@ -1,62 +0,0 @@ -import llvmlite_or1k.ir as ll -import llvmlite_or1k.binding as llvm - -from artiq.py2llvm import infer_types, ast_body, base_types, fractions, tools - - -class Module: - def __init__(self, runtime=None): - self.llvm_module = ll.Module("main") - self.runtime = runtime - - if self.runtime is not None: - self.runtime.init_module(self) - fractions.init_module(self) - - def finalize(self): - self.llvm_module_ref = llvm.parse_assembly(str(self.llvm_module)) - pmb = llvm.create_pass_manager_builder() - pmb.opt_level = 2 - pm = llvm.create_module_pass_manager() - pmb.populate(pm) - pm.run(self.llvm_module_ref) - - def get_ee(self): - self.finalize() - tm = llvm.Target.from_default_triple().create_target_machine() - ee = llvm.create_mcjit_compiler(self.llvm_module_ref, tm) - ee.finalize_object() - return ee - - def emit_object(self): - self.finalize() - return self.runtime.emit_object() - - def compile_function(self, func_def, param_types): - ns = infer_types.infer_function_types(self.runtime, func_def, param_types) - retval = ns["return"] - - function_type = ll.FunctionType(retval.get_llvm_type(), - [ns[arg.arg].get_llvm_type() for arg in func_def.args.args]) - function = ll.Function(self.llvm_module, function_type, func_def.name) - bb = function.append_basic_block("entry") - builder = ll.IRBuilder() - builder.position_at_end(bb) - - for arg_ast, arg_llvm in zip(func_def.args.args, function.args): - arg_llvm.name = arg_ast.arg - for k, v in ns.items(): - v.alloca(builder, k) - for arg_ast, arg_llvm in zip(func_def.args.args, function.args): - ns[arg_ast.arg].auto_store(builder, arg_llvm) - - visitor = ast_body.Visitor(self.runtime, ns, builder) - visitor.visit_statements(func_def.body) - - if not tools.is_terminated(builder.basic_block): - if isinstance(retval, base_types.VNone): - builder.ret_void() - else: - builder.ret(retval.auto_load(builder)) - - return function, retval diff --git a/artiq/py2llvm_old/test/py2llvm.py b/artiq/py2llvm_old/test/py2llvm.py new file mode 100644 index 000000000..c6d9f0135 --- /dev/null +++ b/artiq/py2llvm_old/test/py2llvm.py @@ -0,0 +1,169 @@ +import unittest +from pythonparser import parse, ast +import inspect +from fractions import Fraction +from ctypes import CFUNCTYPE, c_int, c_int32, c_int64, c_double +import struct + +import llvmlite_or1k.binding as llvm + +from artiq.language.core import int64 +from artiq.py2llvm.infer_types import infer_function_types +from artiq.py2llvm import base_types, lists +from artiq.py2llvm.module import Module + +def simplify_encode(a, b): + f = Fraction(a, b) + return f.numerator*1000 + f.denominator + + +def frac_arith_encode(op, a, b, c, d): + if op == 0: + f = Fraction(a, b) - Fraction(c, d) + elif op == 1: + f = Fraction(a, b) + Fraction(c, d) + elif op == 2: + f = Fraction(a, b) * Fraction(c, d) + else: + f = Fraction(a, b) / Fraction(c, d) + return f.numerator*1000 + f.denominator + + +def frac_arith_encode_int(op, a, b, x): + if op == 0: + f = Fraction(a, b) - x + elif op == 1: + f = Fraction(a, b) + x + elif op == 2: + f = Fraction(a, b) * x + else: + f = Fraction(a, b) / x + return f.numerator*1000 + f.denominator + + +def frac_arith_encode_int_rev(op, a, b, x): + if op == 0: + f = x - Fraction(a, b) + elif op == 1: + f = x + Fraction(a, b) + elif op == 2: + f = x * Fraction(a, b) + else: + f = x / Fraction(a, b) + return f.numerator*1000 + f.denominator + + +def frac_arith_float(op, a, b, x): + if op == 0: + return Fraction(a, b) - x + elif op == 1: + return Fraction(a, b) + x + elif op == 2: + return Fraction(a, b) * x + else: + return Fraction(a, b) / x + + +def frac_arith_float_rev(op, a, b, x): + if op == 0: + return x - Fraction(a, b) + elif op == 1: + return x + Fraction(a, b) + elif op == 2: + return x * Fraction(a, b) + else: + return x / Fraction(a, b) + + +class CodeGenCase(unittest.TestCase): + def test_frac_simplify(self): + simplify_encode_c = CompiledFunction( + simplify_encode, {"a": base_types.VInt(), "b": base_types.VInt()}) + for a in _test_range(): + for b in _test_range(): + self.assertEqual( + simplify_encode_c(a, b), simplify_encode(a, b)) + + def _test_frac_arith(self, op): + frac_arith_encode_c = CompiledFunction( + frac_arith_encode, { + "op": base_types.VInt(), + "a": base_types.VInt(), "b": base_types.VInt(), + "c": base_types.VInt(), "d": base_types.VInt()}) + for a in _test_range(): + for b in _test_range(): + for c in _test_range(): + for d in _test_range(): + self.assertEqual( + frac_arith_encode_c(op, a, b, c, d), + frac_arith_encode(op, a, b, c, d)) + + def test_frac_add(self): + self._test_frac_arith(0) + + def test_frac_sub(self): + self._test_frac_arith(1) + + def test_frac_mul(self): + self._test_frac_arith(2) + + def test_frac_div(self): + self._test_frac_arith(3) + + def _test_frac_arith_int(self, op, rev): + f = frac_arith_encode_int_rev if rev else frac_arith_encode_int + f_c = CompiledFunction(f, { + "op": base_types.VInt(), + "a": base_types.VInt(), "b": base_types.VInt(), + "x": base_types.VInt()}) + for a in _test_range(): + for b in _test_range(): + for x in _test_range(): + self.assertEqual( + f_c(op, a, b, x), + f(op, a, b, x)) + + def test_frac_add_int(self): + self._test_frac_arith_int(0, False) + self._test_frac_arith_int(0, True) + + def test_frac_sub_int(self): + self._test_frac_arith_int(1, False) + self._test_frac_arith_int(1, True) + + def test_frac_mul_int(self): + self._test_frac_arith_int(2, False) + self._test_frac_arith_int(2, True) + + def test_frac_div_int(self): + self._test_frac_arith_int(3, False) + self._test_frac_arith_int(3, True) + + def _test_frac_arith_float(self, op, rev): + f = frac_arith_float_rev if rev else frac_arith_float + f_c = CompiledFunction(f, { + "op": base_types.VInt(), + "a": base_types.VInt(), "b": base_types.VInt(), + "x": base_types.VFloat()}) + for a in _test_range(): + for b in _test_range(): + for x in _test_range(): + self.assertAlmostEqual( + f_c(op, a, b, x/2), + f(op, a, b, x/2)) + + def test_frac_add_float(self): + self._test_frac_arith_float(0, False) + self._test_frac_arith_float(0, True) + + def test_frac_sub_float(self): + self._test_frac_arith_float(1, False) + self._test_frac_arith_float(1, True) + + def test_frac_mul_float(self): + self._test_frac_arith_float(2, False) + self._test_frac_arith_float(2, True) + + def test_frac_div_float(self): + self._test_frac_arith_float(3, False) + self._test_frac_arith_float(3, True) diff --git a/artiq/py2llvm_old/tools.py b/artiq/py2llvm_old/tools.py deleted file mode 100644 index ba9e76949..000000000 --- a/artiq/py2llvm_old/tools.py +++ /dev/null @@ -1,5 +0,0 @@ -import llvmlite_or1k.ir as ll - -def is_terminated(basic_block): - return (basic_block.instructions - and isinstance(basic_block.instructions[-1], ll.Terminator)) diff --git a/artiq/transforms/inline.py b/artiq/py2llvm_old/transforms/inline.py similarity index 100% rename from artiq/transforms/inline.py rename to artiq/py2llvm_old/transforms/inline.py diff --git a/artiq/transforms/interleave.py b/artiq/py2llvm_old/transforms/interleave.py similarity index 100% rename from artiq/transforms/interleave.py rename to artiq/py2llvm_old/transforms/interleave.py diff --git a/artiq/transforms/lower_time.py b/artiq/py2llvm_old/transforms/lower_time.py similarity index 100% rename from artiq/transforms/lower_time.py rename to artiq/py2llvm_old/transforms/lower_time.py diff --git a/artiq/transforms/quantize_time.py b/artiq/py2llvm_old/transforms/quantize_time.py similarity index 100% rename from artiq/transforms/quantize_time.py rename to artiq/py2llvm_old/transforms/quantize_time.py diff --git a/artiq/transforms/unroll_loops.py b/artiq/py2llvm_old/transforms/unroll_loops.py similarity index 100% rename from artiq/transforms/unroll_loops.py rename to artiq/py2llvm_old/transforms/unroll_loops.py diff --git a/artiq/py2llvm_old/values.py b/artiq/py2llvm_old/values.py deleted file mode 100644 index 6f0b90e2c..000000000 --- a/artiq/py2llvm_old/values.py +++ /dev/null @@ -1,94 +0,0 @@ -from types import SimpleNamespace -from copy import copy - -import llvmlite_or1k.ir as ll - - -class VGeneric: - def __init__(self): - self.llvm_value = None - - def new(self): - r = copy(self) - r.llvm_value = None - return r - - def __repr__(self): - return "<" + self.__class__.__name__ + ">" - - def same_type(self, other): - return isinstance(other, self.__class__) - - def merge(self, other): - if not self.same_type(other): - raise TypeError("Incompatible types: {} and {}" - .format(repr(self), repr(other))) - - def auto_load(self, builder): - if isinstance(self.llvm_value.type, ll.PointerType): - return builder.load(self.llvm_value) - else: - return self.llvm_value - - def auto_store(self, builder, llvm_value): - if self.llvm_value is None: - self.llvm_value = llvm_value - elif isinstance(self.llvm_value.type, ll.PointerType): - builder.store(llvm_value, self.llvm_value) - else: - raise RuntimeError( - "Attempted to set LLVM SSA value multiple times") - - def alloca(self, builder, name=""): - if self.llvm_value is not None: - raise RuntimeError("Attempted to alloca existing LLVM value "+name) - self.llvm_value = builder.alloca(self.get_llvm_type(), name=name) - - def o_int(self, builder): - return self.o_intx(32, builder) - - def o_int64(self, builder): - return self.o_intx(64, builder) - - def o_round(self, builder): - return self.o_roundx(32, builder) - - def o_round64(self, builder): - return self.o_roundx(64, builder) - - -def _make_binary_operator(op_name): - def op(l, r, builder): - try: - opf = getattr(l, "o_" + op_name) - except AttributeError: - result = NotImplemented - else: - result = opf(r, builder) - if result is NotImplemented: - try: - ropf = getattr(r, "or_" + op_name) - except AttributeError: - result = NotImplemented - else: - result = ropf(l, builder) - if result is NotImplemented: - raise TypeError( - "Unsupported operand types for {}: {} and {}" - .format(op_name, type(l).__name__, type(r).__name__)) - return result - return op - - -def _make_operators(): - d = dict() - for op_name in ("add", "sub", "mul", - "truediv", "floordiv", "mod", - "pow", "lshift", "rshift", "xor", - "eq", "ne", "lt", "le", "gt", "ge"): - d[op_name] = _make_binary_operator(op_name) - d["and_"] = _make_binary_operator("and") - d["or_"] = _make_binary_operator("or") - return SimpleNamespace(**d) - -operators = _make_operators() diff --git a/artiq/test/py2llvm.py b/artiq/test/py2llvm.py deleted file mode 100644 index 07250b7d1..000000000 --- a/artiq/test/py2llvm.py +++ /dev/null @@ -1,372 +0,0 @@ -import unittest -from pythonparser import parse, ast -import inspect -from fractions import Fraction -from ctypes import CFUNCTYPE, c_int, c_int32, c_int64, c_double -import struct - -import llvmlite_or1k.binding as llvm - -from artiq.language.core import int64 -from artiq.py2llvm.infer_types import infer_function_types -from artiq.py2llvm import base_types, lists -from artiq.py2llvm.module import Module - - -llvm.initialize() -llvm.initialize_native_target() -llvm.initialize_native_asmprinter() -if struct.calcsize("P") < 8: - from ctypes import _dlopen, RTLD_GLOBAL - _dlopen("libgcc_s.so", RTLD_GLOBAL) - - -def _base_types(choice): - a = 2 # promoted later to int64 - b = a + 1 # initially int32, becomes int64 after a is promoted - c = b//2 # initially int32, becomes int64 after b is promoted - d = 4 and 5 # stays int32 - x = int64(7) - a += x # promotes a to int64 - foo = True | True or False - bar = None - myf = 4.5 - myf2 = myf + x - - if choice and foo and not bar: - return d - elif myf2: - return x + c - else: - return int64(8) - - -def _build_function_types(f): - return infer_function_types( - None, parse(inspect.getsource(f)), - dict()) - - -class FunctionBaseTypesCase(unittest.TestCase): - def setUp(self): - self.ns = _build_function_types(_base_types) - - def test_simple_types(self): - self.assertIsInstance(self.ns["foo"], base_types.VBool) - self.assertIsInstance(self.ns["bar"], base_types.VNone) - self.assertIsInstance(self.ns["d"], base_types.VInt) - self.assertEqual(self.ns["d"].nbits, 32) - self.assertIsInstance(self.ns["x"], base_types.VInt) - self.assertEqual(self.ns["x"].nbits, 64) - self.assertIsInstance(self.ns["myf"], base_types.VFloat) - self.assertIsInstance(self.ns["myf2"], base_types.VFloat) - - def test_promotion(self): - for v in "abc": - self.assertIsInstance(self.ns[v], base_types.VInt) - self.assertEqual(self.ns[v].nbits, 64) - - def test_return(self): - self.assertIsInstance(self.ns["return"], base_types.VInt) - self.assertEqual(self.ns["return"].nbits, 64) - - -def test_list_types(): - a = [0, 0, 0, 0, 0] - for i in range(2): - a[i] = int64(8) - return a - - -class FunctionListTypesCase(unittest.TestCase): - def setUp(self): - self.ns = _build_function_types(test_list_types) - - def test_list_types(self): - self.assertIsInstance(self.ns["a"], lists.VList) - self.assertIsInstance(self.ns["a"].el_type, base_types.VInt) - self.assertEqual(self.ns["a"].el_type.nbits, 64) - self.assertEqual(self.ns["a"].alloc_count, 5) - self.assertIsInstance(self.ns["i"], base_types.VInt) - self.assertEqual(self.ns["i"].nbits, 32) - - -def _value_to_ctype(v): - if isinstance(v, base_types.VBool): - return c_int - elif isinstance(v, base_types.VInt): - if v.nbits == 32: - return c_int32 - elif v.nbits == 64: - return c_int64 - else: - raise NotImplementedError(str(v)) - elif isinstance(v, base_types.VFloat): - return c_double - else: - raise NotImplementedError(str(v)) - - -class CompiledFunction: - def __init__(self, function, param_types): - module = Module() - - func_def = parse(inspect.getsource(function)).body[0] - function, retval = module.compile_function(func_def, param_types) - argvals = [param_types[arg.arg] for arg in func_def.args.args] - - ee = module.get_ee() - cfptr = ee.get_pointer_to_global( - module.llvm_module_ref.get_function(function.name)) - retval_ctype = _value_to_ctype(retval) - argval_ctypes = [_value_to_ctype(argval) for argval in argvals] - self.cfunc = CFUNCTYPE(retval_ctype, *argval_ctypes)(cfptr) - - # HACK: prevent garbage collection of self.cfunc internals - self.ee = ee - - def __call__(self, *args): - return self.cfunc(*args) - - -def arith(op, a, b): - if op == 0: - return a + b - elif op == 1: - return a - b - elif op == 2: - return a * b - else: - return a / b - - -def is_prime(x): - d = 2 - while d*d <= x: - if not x % d: - return False - d += 1 - return True - - -def simplify_encode(a, b): - f = Fraction(a, b) - return f.numerator*1000 + f.denominator - - -def frac_arith_encode(op, a, b, c, d): - if op == 0: - f = Fraction(a, b) - Fraction(c, d) - elif op == 1: - f = Fraction(a, b) + Fraction(c, d) - elif op == 2: - f = Fraction(a, b) * Fraction(c, d) - else: - f = Fraction(a, b) / Fraction(c, d) - return f.numerator*1000 + f.denominator - - -def frac_arith_encode_int(op, a, b, x): - if op == 0: - f = Fraction(a, b) - x - elif op == 1: - f = Fraction(a, b) + x - elif op == 2: - f = Fraction(a, b) * x - else: - f = Fraction(a, b) / x - return f.numerator*1000 + f.denominator - - -def frac_arith_encode_int_rev(op, a, b, x): - if op == 0: - f = x - Fraction(a, b) - elif op == 1: - f = x + Fraction(a, b) - elif op == 2: - f = x * Fraction(a, b) - else: - f = x / Fraction(a, b) - return f.numerator*1000 + f.denominator - - -def frac_arith_float(op, a, b, x): - if op == 0: - return Fraction(a, b) - x - elif op == 1: - return Fraction(a, b) + x - elif op == 2: - return Fraction(a, b) * x - else: - return Fraction(a, b) / x - - -def frac_arith_float_rev(op, a, b, x): - if op == 0: - return x - Fraction(a, b) - elif op == 1: - return x + Fraction(a, b) - elif op == 2: - return x * Fraction(a, b) - else: - return x / Fraction(a, b) - - -def list_test(): - x = 80 - a = [3 for x in range(7)] - b = [1, 2, 4, 5, 4, 0, 5] - a[3] = x - a[0] += 6 - a[1] = b[1] + b[2] - - acc = 0 - for i in range(7): - if i and a[i]: - acc += 1 - acc += a[i] - return acc - - -def corner_cases(): - two = True + True - (not True) - three = two + True//True - False*True - two_float = three - True/True - one_float = two_float - (1.0 == bool(0.1)) - zero = int(one_float) + round(-0.6) - eleven_float = zero + 5.5//0.5 - ten_float = eleven_float + round(Fraction(2, -3)) - return ten_float - - -def _test_range(): - for i in range(5, 10): - yield i - yield -i - - -class CodeGenCase(unittest.TestCase): - def _test_float_arith(self, op): - arith_c = CompiledFunction(arith, { - "op": base_types.VInt(), - "a": base_types.VFloat(), "b": base_types.VFloat()}) - for a in _test_range(): - for b in _test_range(): - self.assertEqual(arith_c(op, a/2, b/2), arith(op, a/2, b/2)) - - def test_float_add(self): - self._test_float_arith(0) - - def test_float_sub(self): - self._test_float_arith(1) - - def test_float_mul(self): - self._test_float_arith(2) - - def test_float_div(self): - self._test_float_arith(3) - - def test_is_prime(self): - is_prime_c = CompiledFunction(is_prime, {"x": base_types.VInt()}) - for i in range(200): - self.assertEqual(is_prime_c(i), is_prime(i)) - - def test_frac_simplify(self): - simplify_encode_c = CompiledFunction( - simplify_encode, {"a": base_types.VInt(), "b": base_types.VInt()}) - for a in _test_range(): - for b in _test_range(): - self.assertEqual( - simplify_encode_c(a, b), simplify_encode(a, b)) - - def _test_frac_arith(self, op): - frac_arith_encode_c = CompiledFunction( - frac_arith_encode, { - "op": base_types.VInt(), - "a": base_types.VInt(), "b": base_types.VInt(), - "c": base_types.VInt(), "d": base_types.VInt()}) - for a in _test_range(): - for b in _test_range(): - for c in _test_range(): - for d in _test_range(): - self.assertEqual( - frac_arith_encode_c(op, a, b, c, d), - frac_arith_encode(op, a, b, c, d)) - - def test_frac_add(self): - self._test_frac_arith(0) - - def test_frac_sub(self): - self._test_frac_arith(1) - - def test_frac_mul(self): - self._test_frac_arith(2) - - def test_frac_div(self): - self._test_frac_arith(3) - - def _test_frac_arith_int(self, op, rev): - f = frac_arith_encode_int_rev if rev else frac_arith_encode_int - f_c = CompiledFunction(f, { - "op": base_types.VInt(), - "a": base_types.VInt(), "b": base_types.VInt(), - "x": base_types.VInt()}) - for a in _test_range(): - for b in _test_range(): - for x in _test_range(): - self.assertEqual( - f_c(op, a, b, x), - f(op, a, b, x)) - - def test_frac_add_int(self): - self._test_frac_arith_int(0, False) - self._test_frac_arith_int(0, True) - - def test_frac_sub_int(self): - self._test_frac_arith_int(1, False) - self._test_frac_arith_int(1, True) - - def test_frac_mul_int(self): - self._test_frac_arith_int(2, False) - self._test_frac_arith_int(2, True) - - def test_frac_div_int(self): - self._test_frac_arith_int(3, False) - self._test_frac_arith_int(3, True) - - def _test_frac_arith_float(self, op, rev): - f = frac_arith_float_rev if rev else frac_arith_float - f_c = CompiledFunction(f, { - "op": base_types.VInt(), - "a": base_types.VInt(), "b": base_types.VInt(), - "x": base_types.VFloat()}) - for a in _test_range(): - for b in _test_range(): - for x in _test_range(): - self.assertAlmostEqual( - f_c(op, a, b, x/2), - f(op, a, b, x/2)) - - def test_frac_add_float(self): - self._test_frac_arith_float(0, False) - self._test_frac_arith_float(0, True) - - def test_frac_sub_float(self): - self._test_frac_arith_float(1, False) - self._test_frac_arith_float(1, True) - - def test_frac_mul_float(self): - self._test_frac_arith_float(2, False) - self._test_frac_arith_float(2, True) - - def test_frac_div_float(self): - self._test_frac_arith_float(3, False) - self._test_frac_arith_float(3, True) - - def test_list(self): - list_test_c = CompiledFunction(list_test, dict()) - self.assertEqual(list_test_c(), list_test()) - - def test_corner_cases(self): - corner_cases_c = CompiledFunction(corner_cases, dict()) - self.assertEqual(corner_cases_c(), corner_cases()) diff --git a/artiq/test/transforms.py b/artiq/test/transforms.py deleted file mode 100644 index dffee41a2..000000000 --- a/artiq/test/transforms.py +++ /dev/null @@ -1,44 +0,0 @@ -import unittest -import ast - -from artiq import ns -from artiq.coredevice import comm_dummy, core -from artiq.transforms.unparse import unparse - - -optimize_in = """ - -def run(): - dds_sysclk = Fraction(1000000000, 1) - n = seconds_to_mu((1.2345 * Fraction(1, 1000000000))) - with sequential: - frequency = 345 * Fraction(1000000, 1) - frequency_to_ftw_return = int((((2 ** 32) * frequency) / dds_sysclk)) - ftw = frequency_to_ftw_return - with sequential: - ftw2 = ftw - ftw_to_frequency_return = ((ftw2 * dds_sysclk) / (2 ** 32)) - f = ftw_to_frequency_return - phi = ((1000 * mu_to_seconds(n)) * f) - do_something(int(phi)) -""" - -optimize_out = """ - -def run(): - now = syscall('now_init') - try: - do_something(344) - finally: - syscall('now_save', now) -""" - - -class OptimizeCase(unittest.TestCase): - def test_optimize(self): - dmgr = dict() - dmgr["comm"] = comm_dummy.Comm(dmgr) - coredev = core.Core(dmgr, ref_period=1*ns) - func_def = ast.parse(optimize_in).body[0] - coredev.transform_stack(func_def, dict(), dict()) - self.assertEqual(unparse(func_def), optimize_out) diff --git a/artiq/transforms/__init__.py b/artiq/transforms/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/artiq/transforms/fold_constants.py b/artiq/transforms/fold_constants.py deleted file mode 100644 index 402fc243b..000000000 --- a/artiq/transforms/fold_constants.py +++ /dev/null @@ -1,156 +0,0 @@ -import ast -import operator -from fractions import Fraction - -from artiq.transforms.tools import * -from artiq.language.core import int64, round64 - - -_ast_unops = { - ast.Invert: operator.inv, - ast.Not: operator.not_, - ast.UAdd: operator.pos, - ast.USub: operator.neg -} - - -_ast_binops = { - ast.Add: operator.add, - ast.Sub: operator.sub, - ast.Mult: operator.mul, - ast.Div: operator.truediv, - ast.FloorDiv: operator.floordiv, - ast.Mod: operator.mod, - ast.Pow: operator.pow, - ast.LShift: operator.lshift, - ast.RShift: operator.rshift, - ast.BitOr: operator.or_, - ast.BitXor: operator.xor, - ast.BitAnd: operator.and_ -} - -_ast_cmpops = { - ast.Eq: operator.eq, - ast.NotEq: operator.ne, - ast.Lt: operator.lt, - ast.LtE: operator.le, - ast.Gt: operator.gt, - ast.GtE: operator.ge -} - -_ast_boolops = { - ast.Or: lambda x, y: x or y, - ast.And: lambda x, y: x and y -} - - -class _ConstantFolder(ast.NodeTransformer): - def visit_UnaryOp(self, node): - self.generic_visit(node) - try: - operand = eval_constant(node.operand) - except NotConstant: - return node - try: - op = _ast_unops[type(node.op)] - except KeyError: - return node - try: - result = value_to_ast(op(operand)) - except: - return node - return ast.copy_location(result, node) - - def visit_BinOp(self, node): - self.generic_visit(node) - try: - left, right = eval_constant(node.left), eval_constant(node.right) - except NotConstant: - return node - try: - op = _ast_binops[type(node.op)] - except KeyError: - return node - try: - result = value_to_ast(op(left, right)) - except: - return node - return ast.copy_location(result, node) - - def visit_Compare(self, node): - self.generic_visit(node) - try: - operands = [eval_constant(node.left)] - except NotConstant: - operands = [node.left] - ops = [] - for op, right_ast in zip(node.ops, node.comparators): - try: - right = eval_constant(right_ast) - except NotConstant: - right = right_ast - if (not isinstance(operands[-1], ast.AST) - and not isinstance(right, ast.AST)): - left = operands.pop() - operands.append(_ast_cmpops[type(op)](left, right)) - else: - ops.append(op) - operands.append(right_ast) - operands = [operand if isinstance(operand, ast.AST) - else ast.copy_location(value_to_ast(operand), node) - for operand in operands] - if len(operands) == 1: - return operands[0] - else: - node.left = operands[0] - node.right = operands[1:] - node.ops = ops - return node - - def visit_BoolOp(self, node): - self.generic_visit(node) - new_values = [] - for value in node.values: - try: - value_c = eval_constant(value) - except NotConstant: - new_values.append(value) - else: - if new_values and not isinstance(new_values[-1], ast.AST): - op = _ast_boolops[type(node.op)] - new_values[-1] = op(new_values[-1], value_c) - else: - new_values.append(value_c) - new_values = [v if isinstance(v, ast.AST) else value_to_ast(v) - for v in new_values] - if len(new_values) > 1: - node.values = new_values - return node - else: - return new_values[0] - - def visit_Call(self, node): - self.generic_visit(node) - fn = node.func.id - constant_ops = { - "int": int, - "int64": int64, - "round": round, - "round64": round64, - "Fraction": Fraction - } - if fn in constant_ops: - args = [] - for arg in node.args: - try: - args.append(eval_constant(arg)) - except NotConstant: - return node - result = value_to_ast(constant_ops[fn](*args)) - return ast.copy_location(result, node) - else: - return node - - -def fold_constants(node): - _ConstantFolder().visit(node) diff --git a/artiq/transforms/remove_dead_code.py b/artiq/transforms/remove_dead_code.py deleted file mode 100644 index 9a58c851d..000000000 --- a/artiq/transforms/remove_dead_code.py +++ /dev/null @@ -1,59 +0,0 @@ -import ast - -from artiq.transforms.tools import is_ref_transparent - - -class _SourceLister(ast.NodeVisitor): - def __init__(self): - self.sources = set() - - def visit_Name(self, node): - if isinstance(node.ctx, ast.Load): - self.sources.add(node.id) - - -class _DeadCodeRemover(ast.NodeTransformer): - def __init__(self, kept_targets): - self.kept_targets = kept_targets - - def visit_Assign(self, node): - new_targets = [] - for target in node.targets: - if (not isinstance(target, ast.Name) - or target.id in self.kept_targets): - new_targets.append(target) - if not new_targets and is_ref_transparent(node.value)[0]: - return None - else: - return node - - def visit_AugAssign(self, node): - if (isinstance(node.target, ast.Name) - and node.target.id not in self.kept_targets - and is_ref_transparent(node.value)[0]): - return None - else: - return node - - def visit_If(self, node): - self.generic_visit(node) - if isinstance(node.test, ast.NameConstant): - if node.test.value: - return node.body - else: - return node.orelse - else: - return node - - def visit_While(self, node): - self.generic_visit(node) - if isinstance(node.test, ast.NameConstant) and not node.test.value: - return node.orelse - else: - return node - - -def remove_dead_code(func_def): - sl = _SourceLister() - sl.visit(func_def) - _DeadCodeRemover(sl.sources).visit(func_def) diff --git a/artiq/transforms/remove_inter_assigns.py b/artiq/transforms/remove_inter_assigns.py deleted file mode 100644 index 56d877215..000000000 --- a/artiq/transforms/remove_inter_assigns.py +++ /dev/null @@ -1,149 +0,0 @@ -import ast -from copy import copy, deepcopy -from collections import defaultdict - -from artiq.transforms.tools import is_ref_transparent, count_all_nodes - - -class _TargetLister(ast.NodeVisitor): - def __init__(self): - self.targets = set() - - def visit_Name(self, node): - if isinstance(node.ctx, ast.Store): - self.targets.add(node.id) - - -class _InterAssignRemover(ast.NodeTransformer): - def __init__(self): - self.replacements = dict() - self.modified_names = set() - # name -> set of names that depend on it - # i.e. when x is modified, dependencies[x] is the set of names that - # cannot be replaced anymore - self.dependencies = defaultdict(set) - - def invalidate(self, name): - try: - del self.replacements[name] - except KeyError: - pass - for d in self.dependencies[name]: - self.invalidate(d) - del self.dependencies[name] - - def visit_Name(self, node): - if isinstance(node.ctx, ast.Load): - try: - return deepcopy(self.replacements[node.id]) - except KeyError: - return node - else: - self.modified_names.add(node.id) - self.invalidate(node.id) - return node - - def visit_Assign(self, node): - node.value = self.visit(node.value) - node.targets = [self.visit(target) for target in node.targets] - rt, depends_on = is_ref_transparent(node.value) - if rt and count_all_nodes(node.value) < 100: - for target in node.targets: - if isinstance(target, ast.Name): - if target.id not in depends_on: - self.replacements[target.id] = node.value - for d in depends_on: - self.dependencies[d].add(target.id) - return node - - def visit_AugAssign(self, node): - left = deepcopy(node.target) - left.ctx = ast.Load() - newnode = ast.copy_location( - ast.Assign( - targets=[node.target], - value=ast.BinOp(left=left, op=node.op, right=node.value) - ), - node - ) - return self.visit_Assign(newnode) - - def modified_names_push(self): - prev_modified_names = self.modified_names - self.modified_names = set() - return prev_modified_names - - def modified_names_pop(self, prev_modified_names): - for name in self.modified_names: - self.invalidate(name) - self.modified_names |= prev_modified_names - - def visit_Try(self, node): - prev_modified_names = self.modified_names_push() - node.body = [self.visit(stmt) for stmt in node.body] - self.modified_names_pop(prev_modified_names) - - prev_modified_names = self.modified_names_push() - prev_replacements = self.replacements - for handler in node.handlers: - self.replacements = copy(prev_replacements) - handler.body = [self.visit(stmt) for stmt in handler.body] - self.replacements = copy(prev_replacements) - node.orelse = [self.visit(stmt) for stmt in node.orelse] - self.modified_names_pop(prev_modified_names) - - prev_modified_names = self.modified_names_push() - node.finalbody = [self.visit(stmt) for stmt in node.finalbody] - self.modified_names_pop(prev_modified_names) - return node - - def visit_If(self, node): - node.test = self.visit(node.test) - - prev_modified_names = self.modified_names_push() - - prev_replacements = self.replacements - self.replacements = copy(prev_replacements) - node.body = [self.visit(n) for n in node.body] - self.replacements = copy(prev_replacements) - node.orelse = [self.visit(n) for n in node.orelse] - self.replacements = prev_replacements - - self.modified_names_pop(prev_modified_names) - - return node - - def visit_loop(self, node): - prev_modified_names = self.modified_names_push() - prev_replacements = self.replacements - - self.replacements = copy(prev_replacements) - tl = _TargetLister() - for n in node.body: - tl.visit(n) - for name in tl.targets: - self.invalidate(name) - node.body = [self.visit(n) for n in node.body] - - self.replacements = copy(prev_replacements) - node.orelse = [self.visit(n) for n in node.orelse] - - self.replacements = prev_replacements - self.modified_names_pop(prev_modified_names) - - def visit_For(self, node): - prev_modified_names = self.modified_names_push() - node.target = self.visit(node.target) - self.modified_names_pop(prev_modified_names) - node.iter = self.visit(node.iter) - self.visit_loop(node) - return node - - def visit_While(self, node): - self.visit_loop(node) - node.test = self.visit(node.test) - return node - - -def remove_inter_assigns(func_def): - _InterAssignRemover().visit(func_def) diff --git a/artiq/transforms/tools.py b/artiq/transforms/tools.py deleted file mode 100644 index 97d596d2b..000000000 --- a/artiq/transforms/tools.py +++ /dev/null @@ -1,141 +0,0 @@ -import ast -from fractions import Fraction - -from artiq.language import core as core_language -from artiq.language import units - - -embeddable_funcs = ( - core_language.delay_mu, core_language.at_mu, core_language.now_mu, - core_language.delay, - core_language.seconds_to_mu, core_language.mu_to_seconds, - core_language.syscall, core_language.watchdog, - range, bool, int, float, round, len, - core_language.int64, core_language.round64, - Fraction, core_language.EncodedException -) -embeddable_func_names = {func.__name__ for func in embeddable_funcs} - - -def is_embeddable(func): - for ef in embeddable_funcs: - if func is ef: - return True - return False - - -def eval_ast(expr, symdict=dict()): - if not isinstance(expr, ast.Expression): - expr = ast.copy_location(ast.Expression(expr), expr) - ast.fix_missing_locations(expr) - code = compile(expr, "", "eval") - return eval(code, symdict) - - -class NotASTRepresentable(Exception): - pass - - -def value_to_ast(value): - if isinstance(value, core_language.int64): # must be before int - return ast.Call( - func=ast.Name("int64", ast.Load()), - args=[ast.Num(int(value))], - keywords=[], starargs=None, kwargs=None) - elif isinstance(value, bool) or value is None: - # must also be before int - # isinstance(True/False, int) == True - return ast.NameConstant(value) - elif isinstance(value, (int, float)): - return ast.Num(value) - elif isinstance(value, Fraction): - return ast.Call( - func=ast.Name("Fraction", ast.Load()), - args=[ast.Num(value.numerator), ast.Num(value.denominator)], - keywords=[], starargs=None, kwargs=None) - elif isinstance(value, str): - return ast.Str(value) - elif isinstance(value, list): - elts = [value_to_ast(elt) for elt in value] - return ast.List(elts, ast.Load()) - else: - for kg in core_language.kernel_globals: - if value is getattr(core_language, kg): - return ast.Name(kg, ast.Load()) - raise NotASTRepresentable(str(value)) - - -class NotConstant(Exception): - pass - - -def eval_constant(node): - if isinstance(node, ast.Num): - return node.n - elif isinstance(node, ast.Str): - return node.s - elif isinstance(node, ast.NameConstant): - return node.value - elif isinstance(node, ast.Call): - funcname = node.func.id - if funcname == "int64": - return core_language.int64(eval_constant(node.args[0])) - elif funcname == "Fraction": - numerator = eval_constant(node.args[0]) - denominator = eval_constant(node.args[1]) - return Fraction(numerator, denominator) - else: - raise NotConstant - else: - raise NotConstant - - -_replaceable_funcs = { - "bool", "int", "float", "round", - "int64", "round64", "Fraction", - "seconds_to_mu", "mu_to_seconds" -} - - -def _is_ref_transparent(dependencies, expr): - if isinstance(expr, (ast.NameConstant, ast.Num, ast.Str)): - return True - elif isinstance(expr, ast.Name): - dependencies.add(expr.id) - return True - elif isinstance(expr, ast.UnaryOp): - return _is_ref_transparent(dependencies, expr.operand) - elif isinstance(expr, ast.BinOp): - return (_is_ref_transparent(dependencies, expr.left) - and _is_ref_transparent(dependencies, expr.right)) - elif isinstance(expr, ast.BoolOp): - return all(_is_ref_transparent(dependencies, v) for v in expr.values) - elif isinstance(expr, ast.Call): - return (expr.func.id in _replaceable_funcs and - all(_is_ref_transparent(dependencies, arg) - for arg in expr.args)) - else: - return False - - -def is_ref_transparent(expr): - dependencies = set() - if _is_ref_transparent(dependencies, expr): - return True, dependencies - else: - return False, None - - -class _NodeCounter(ast.NodeVisitor): - def __init__(self): - self.count = 0 - - def generic_visit(self, node): - self.count += 1 - ast.NodeVisitor.generic_visit(self, node) - - -def count_all_nodes(node): - nc = _NodeCounter() - nc.visit(node) - return nc.count diff --git a/artiq/transforms/unparse.py b/artiq/transforms/unparse.py deleted file mode 100644 index cacd6e73b..000000000 --- a/artiq/transforms/unparse.py +++ /dev/null @@ -1,600 +0,0 @@ -import sys -import ast - - -# Large float and imaginary literals get turned into infinities in the AST. -# We unparse those infinities to INFSTR. -INFSTR = "1e" + repr(sys.float_info.max_10_exp + 1) - - -def _interleave(inter, f, seq): - """Call f on each item in seq, calling inter() in between. - """ - seq = iter(seq) - try: - f(next(seq)) - except StopIteration: - pass - else: - for x in seq: - inter() - f(x) - - -class _Unparser: - """Methods in this class recursively traverse an AST and - output source code for the abstract syntax; original formatting - is disregarded. """ - - def __init__(self, tree): - """Print the source for tree to the "result" string.""" - self.result = "" - self._indent = 0 - self.dispatch(tree) - self.result += "\n" - - def fill(self, text=""): - "Indent a piece of text, according to the current indentation level" - self.result += "\n"+" "*self._indent + text - - def write(self, text): - "Append a piece of text to the current line." - self.result += text - - def enter(self): - "Print ':', and increase the indentation." - self.write(":") - self._indent += 1 - - def leave(self): - "Decrease the indentation level." - self._indent -= 1 - - def dispatch(self, tree): - "Dispatcher function, dispatching tree type T to method _T." - if isinstance(tree, list): - for t in tree: - self.dispatch(t) - return - meth = getattr(self, "_"+tree.__class__.__name__) - meth(tree) - - # Unparsing methods - # - # There should be one method per concrete grammar type - # Constructors should be grouped by sum type. Ideally, - # this would follow the order in the grammar, but - # currently doesn't. - - def _Module(self, tree): - for stmt in tree.body: - self.dispatch(stmt) - - # stmt - def _Expr(self, tree): - self.fill() - self.dispatch(tree.value) - - def _Import(self, t): - self.fill("import ") - _interleave(lambda: self.write(", "), self.dispatch, t.names) - - def _ImportFrom(self, t): - self.fill("from ") - self.write("." * t.level) - if t.module: - self.write(t.module) - self.write(" import ") - _interleave(lambda: self.write(", "), self.dispatch, t.names) - - def _Assign(self, t): - self.fill() - for target in t.targets: - self.dispatch(target) - self.write(" = ") - self.dispatch(t.value) - - def _AugAssign(self, t): - self.fill() - self.dispatch(t.target) - self.write(" "+self.binop[t.op.__class__.__name__]+"= ") - self.dispatch(t.value) - - def _Return(self, t): - self.fill("return") - if t.value: - self.write(" ") - self.dispatch(t.value) - - def _Pass(self, t): - self.fill("pass") - - def _Break(self, t): - self.fill("break") - - def _Continue(self, t): - self.fill("continue") - - def _Delete(self, t): - self.fill("del ") - _interleave(lambda: self.write(", "), self.dispatch, t.targets) - - def _Assert(self, t): - self.fill("assert ") - self.dispatch(t.test) - if t.msg: - self.write(", ") - self.dispatch(t.msg) - - def _Global(self, t): - self.fill("global ") - _interleave(lambda: self.write(", "), self.write, t.names) - - def _Nonlocal(self, t): - self.fill("nonlocal ") - _interleave(lambda: self.write(", "), self.write, t.names) - - def _Yield(self, t): - self.write("(") - self.write("yield") - if t.value: - self.write(" ") - self.dispatch(t.value) - self.write(")") - - def _YieldFrom(self, t): - self.write("(") - self.write("yield from") - if t.value: - self.write(" ") - self.dispatch(t.value) - self.write(")") - - def _Raise(self, t): - self.fill("raise") - if not t.exc: - assert not t.cause - return - self.write(" ") - self.dispatch(t.exc) - if t.cause: - self.write(" from ") - self.dispatch(t.cause) - - def _Try(self, t): - self.fill("try") - self.enter() - self.dispatch(t.body) - self.leave() - for ex in t.handlers: - self.dispatch(ex) - if t.orelse: - self.fill("else") - self.enter() - self.dispatch(t.orelse) - self.leave() - if t.finalbody: - self.fill("finally") - self.enter() - self.dispatch(t.finalbody) - self.leave() - - def _ExceptHandler(self, t): - self.fill("except") - if t.type: - self.write(" ") - self.dispatch(t.type) - if t.name: - self.write(" as ") - self.write(t.name) - self.enter() - self.dispatch(t.body) - self.leave() - - def _ClassDef(self, t): - self.write("\n") - for deco in t.decorator_list: - self.fill("@") - self.dispatch(deco) - self.fill("class "+t.name) - self.write("(") - comma = False - for e in t.bases: - if comma: - self.write(", ") - else: - comma = True - self.dispatch(e) - for e in t.keywords: - if comma: - self.write(", ") - else: - comma = True - self.dispatch(e) - if t.starargs: - if comma: - self.write(", ") - else: - comma = True - self.write("*") - self.dispatch(t.starargs) - if t.kwargs: - if comma: - self.write(", ") - else: - comma = True - self.write("**") - self.dispatch(t.kwargs) - self.write(")") - - self.enter() - self.dispatch(t.body) - self.leave() - - def _FunctionDef(self, t): - self.write("\n") - for deco in t.decorator_list: - self.fill("@") - self.dispatch(deco) - self.fill("def "+t.name + "(") - self.dispatch(t.args) - self.write(")") - if t.returns: - self.write(" -> ") - self.dispatch(t.returns) - self.enter() - self.dispatch(t.body) - self.leave() - - def _For(self, t): - self.fill("for ") - self.dispatch(t.target) - self.write(" in ") - self.dispatch(t.iter) - self.enter() - self.dispatch(t.body) - self.leave() - if t.orelse: - self.fill("else") - self.enter() - self.dispatch(t.orelse) - self.leave() - - def _If(self, t): - self.fill("if ") - self.dispatch(t.test) - self.enter() - self.dispatch(t.body) - self.leave() - # collapse nested ifs into equivalent elifs. - while (t.orelse and len(t.orelse) == 1 and - isinstance(t.orelse[0], ast.If)): - t = t.orelse[0] - self.fill("elif ") - self.dispatch(t.test) - self.enter() - self.dispatch(t.body) - self.leave() - # final else - if t.orelse: - self.fill("else") - self.enter() - self.dispatch(t.orelse) - self.leave() - - def _While(self, t): - self.fill("while ") - self.dispatch(t.test) - self.enter() - self.dispatch(t.body) - self.leave() - if t.orelse: - self.fill("else") - self.enter() - self.dispatch(t.orelse) - self.leave() - - def _With(self, t): - self.fill("with ") - _interleave(lambda: self.write(", "), self.dispatch, t.items) - self.enter() - self.dispatch(t.body) - self.leave() - - # expr - def _Bytes(self, t): - self.write(repr(t.s)) - - def _Str(self, tree): - self.write(repr(tree.s)) - - def _Name(self, t): - self.write(t.id) - - def _NameConstant(self, t): - self.write(repr(t.value)) - - def _Num(self, t): - # Substitute overflowing decimal literal for AST infinities. - self.write(repr(t.n).replace("inf", INFSTR)) - - def _List(self, t): - self.write("[") - _interleave(lambda: self.write(", "), self.dispatch, t.elts) - self.write("]") - - def _ListComp(self, t): - self.write("[") - self.dispatch(t.elt) - for gen in t.generators: - self.dispatch(gen) - self.write("]") - - def _GeneratorExp(self, t): - self.write("(") - self.dispatch(t.elt) - for gen in t.generators: - self.dispatch(gen) - self.write(")") - - def _SetComp(self, t): - self.write("{") - self.dispatch(t.elt) - for gen in t.generators: - self.dispatch(gen) - self.write("}") - - def _DictComp(self, t): - self.write("{") - self.dispatch(t.key) - self.write(": ") - self.dispatch(t.value) - for gen in t.generators: - self.dispatch(gen) - self.write("}") - - def _comprehension(self, t): - self.write(" for ") - self.dispatch(t.target) - self.write(" in ") - self.dispatch(t.iter) - for if_clause in t.ifs: - self.write(" if ") - self.dispatch(if_clause) - - def _IfExp(self, t): - self.write("(") - self.dispatch(t.body) - self.write(" if ") - self.dispatch(t.test) - self.write(" else ") - self.dispatch(t.orelse) - self.write(")") - - def _Set(self, t): - assert(t.elts) # should be at least one element - self.write("{") - _interleave(lambda: self.write(", "), self.dispatch, t.elts) - self.write("}") - - def _Dict(self, t): - self.write("{") - - def write_pair(pair): - (k, v) = pair - self.dispatch(k) - self.write(": ") - self.dispatch(v) - _interleave(lambda: self.write(", "), write_pair, - zip(t.keys, t.values)) - self.write("}") - - def _Tuple(self, t): - self.write("(") - if len(t.elts) == 1: - (elt,) = t.elts - self.dispatch(elt) - self.write(",") - else: - _interleave(lambda: self.write(", "), self.dispatch, t.elts) - self.write(")") - - unop = {"Invert": "~", "Not": "not", "UAdd": "+", "USub": "-"} - - def _UnaryOp(self, t): - self.write("(") - self.write(self.unop[t.op.__class__.__name__]) - self.write(" ") - self.dispatch(t.operand) - self.write(")") - - binop = {"Add": "+", "Sub": "-", "Mult": "*", "Div": "/", "Mod": "%", - "LShift": "<<", "RShift": ">>", - "BitOr": "|", "BitXor": "^", "BitAnd": "&", - "FloorDiv": "//", "Pow": "**"} - - def _BinOp(self, t): - self.write("(") - self.dispatch(t.left) - self.write(" " + self.binop[t.op.__class__.__name__] + " ") - self.dispatch(t.right) - self.write(")") - - cmpops = {"Eq": "==", "NotEq": "!=", - "Lt": "<", "LtE": "<=", "Gt": ">", "GtE": ">=", - "Is": "is", "IsNot": "is not", "In": "in", "NotIn": "not in"} - - def _Compare(self, t): - self.write("(") - self.dispatch(t.left) - for o, e in zip(t.ops, t.comparators): - self.write(" " + self.cmpops[o.__class__.__name__] + " ") - self.dispatch(e) - self.write(")") - - boolops = {ast.And: "and", ast.Or: "or"} - - def _BoolOp(self, t): - self.write("(") - s = " %s " % self.boolops[t.op.__class__] - _interleave(lambda: self.write(s), self.dispatch, t.values) - self.write(")") - - def _Attribute(self, t): - self.dispatch(t.value) - # Special case: 3.__abs__() is a syntax error, so if t.value - # is an integer literal then we need to either parenthesize - # it or add an extra space to get 3 .__abs__(). - if isinstance(t.value, ast.Num) and isinstance(t.value.n, int): - self.write(" ") - self.write(".") - self.write(t.attr) - - def _Call(self, t): - self.dispatch(t.func) - self.write("(") - comma = False - for e in t.args: - if comma: - self.write(", ") - else: - comma = True - self.dispatch(e) - for e in t.keywords: - if comma: - self.write(", ") - else: - comma = True - self.dispatch(e) - if t.starargs: - if comma: - self.write(", ") - else: - comma = True - self.write("*") - self.dispatch(t.starargs) - if t.kwargs: - if comma: - self.write(", ") - else: - comma = True - self.write("**") - self.dispatch(t.kwargs) - self.write(")") - - def _Subscript(self, t): - self.dispatch(t.value) - self.write("[") - self.dispatch(t.slice) - self.write("]") - - def _Starred(self, t): - self.write("*") - self.dispatch(t.value) - - # slice - def _Ellipsis(self, t): - self.write("...") - - def _Index(self, t): - self.dispatch(t.value) - - def _Slice(self, t): - if t.lower: - self.dispatch(t.lower) - self.write(":") - if t.upper: - self.dispatch(t.upper) - if t.step: - self.write(":") - self.dispatch(t.step) - - def _ExtSlice(self, t): - _interleave(lambda: self.write(', '), self.dispatch, t.dims) - - # argument - def _arg(self, t): - self.write(t.arg) - if t.annotation: - self.write(": ") - self.dispatch(t.annotation) - - # others - def _arguments(self, t): - first = True - # normal arguments - defaults = [None] * (len(t.args) - len(t.defaults)) + t.defaults - for a, d in zip(t.args, defaults): - if first: - first = False - else: - self.write(", ") - self.dispatch(a) - if d: - self.write("=") - self.dispatch(d) - - # varargs, or bare '*' if no varargs but keyword-only arguments present - if t.vararg or t.kwonlyargs: - if first: - first = False - else: - self.write(", ") - self.write("*") - if t.vararg: - self.write(t.vararg.arg) - if t.vararg.annotation: - self.write(": ") - self.dispatch(t.vararg.annotation) - - # keyword-only arguments - if t.kwonlyargs: - for a, d in zip(t.kwonlyargs, t.kw_defaults): - if first: - first = False - else: - self.write(", ") - self.dispatch(a), - if d: - self.write("=") - self.dispatch(d) - - # kwargs - if t.kwarg: - if first: - first = False - else: - self.write(", ") - self.write("**"+t.kwarg.arg) - if t.kwarg.annotation: - self.write(": ") - self.dispatch(t.kwarg.annotation) - - def _keyword(self, t): - self.write(t.arg) - self.write("=") - self.dispatch(t.value) - - def _Lambda(self, t): - self.write("(") - self.write("lambda ") - self.dispatch(t.args) - self.write(": ") - self.dispatch(t.body) - self.write(")") - - def _alias(self, t): - self.write(t.name) - if t.asname: - self.write(" as "+t.asname) - - def _withitem(self, t): - self.dispatch(t.context_expr) - if t.optional_vars: - self.write(" as ") - self.dispatch(t.optional_vars) - - -def unparse(tree): - unparser = _Unparser(tree) - return unparser.result