From 082e9e20dd799feff422e354bdc71e8baed44109 Mon Sep 17 00:00:00 2001 From: whitequark Date: Fri, 25 Dec 2015 15:02:33 +0800 Subject: [PATCH] compiler: do not associate SSA values with iodelay even when inlining. Fixes #201. --- artiq/compiler/algorithms/inline.py | 23 +++---- artiq/compiler/asttyped.py | 1 + artiq/compiler/ir.py | 68 ++++++++----------- .../compiler/transforms/artiq_ir_generator.py | 13 ++-- .../compiler/transforms/asttyped_rewriter.py | 2 +- .../compiler/transforms/iodelay_estimator.py | 4 +- 6 files changed, 47 insertions(+), 64 deletions(-) diff --git a/artiq/compiler/algorithms/inline.py b/artiq/compiler/algorithms/inline.py index fc7cb6a95..ce3e3315f 100644 --- a/artiq/compiler/algorithms/inline.py +++ b/artiq/compiler/algorithms/inline.py @@ -51,20 +51,19 @@ def inline(call_insn): elif isinstance(source_insn, ir.Phi): target_insn = ir.Phi() elif isinstance(source_insn, ir.Delay): - substs = source_insn.substs() - mapped_substs = {var: value_map[substs[var]] for var in substs} - const_substs = {var: iodelay.Const(mapped_substs[var].value) - for var in mapped_substs - if isinstance(mapped_substs[var], ir.Constant)} - other_substs = {var: mapped_substs[var] - for var in mapped_substs - if not isinstance(mapped_substs[var], ir.Constant)} - target_insn = ir.Delay(source_insn.interval.fold(const_substs), other_substs, - value_map[source_insn.decomposition()], - value_map[source_insn.target()]) + target_insn = source_insn.copy(mapper) + target_insn.interval = source_insn.interval.fold(call_insn.arg_exprs) + elif isinstance(source_insn, ir.Loop): + target_insn = source_insn.copy(mapper) + target_insn.trip_count = source_insn.trip_count.fold(call_insn.arg_exprs) + elif isinstance(source_insn, ir.Call): + target_insn = source_insn.copy(mapper) + target_insn.arg_exprs = \ + { arg: source_insn.arg_exprs[arg].fold(call_insn.arg_exprs) + for arg in source_insn.arg_exprs } else: target_insn = source_insn.copy(mapper) - target_insn.name = "i." + source_insn.name + target_insn.name = "i." + source_insn.name value_map[source_insn] = target_insn target_block.append(target_insn) diff --git a/artiq/compiler/asttyped.py b/artiq/compiler/asttyped.py index 26df00ad4..9d4274470 100644 --- a/artiq/compiler/asttyped.py +++ b/artiq/compiler/asttyped.py @@ -54,6 +54,7 @@ class BoolOpT(ast.BoolOp, commontyped): class CallT(ast.Call, commontyped): """ :ivar iodelay: (:class:`iodelay.Expr`) + :ivar arg_exprs: (dict of str to :class:`iodelay.Expr`) """ class CompareT(ast.Compare, commontyped): pass diff --git a/artiq/compiler/ir.py b/artiq/compiler/ir.py index c517787d3..645a7f3c4 100644 --- a/artiq/compiler/ir.py +++ b/artiq/compiler/ir.py @@ -5,7 +5,7 @@ of the ARTIQ compiler. from collections import OrderedDict from pythonparser import ast -from . import types, builtins +from . import types, builtins, iodelay # Generic SSA IR classes @@ -939,6 +939,8 @@ class Call(Instruction): """ A function call operation. + :ivar arg_exprs: (dict of str to `iodelay.Expr`) + iodelay expressions for values of arguments :ivar static_target_function: (:class:`Function` or None) statically resolved callee """ @@ -946,15 +948,21 @@ class Call(Instruction): """ :param func: (:class:`Value`) function to call :param args: (list of :class:`Value`) function arguments + :param arg_exprs: (dict of str to `iodelay.Expr`) """ - def __init__(self, func, args, name=""): + def __init__(self, func, args, arg_exprs, name=""): assert isinstance(func, Value) for arg in args: assert isinstance(arg, Value) + for arg in arg_exprs: + assert isinstance(arg, str) + assert isinstance(arg_exprs[arg], iodelay.Expr) super().__init__([func] + args, func.type.ret, name) + self.arg_exprs = arg_exprs self.static_target_function = None def copy(self, mapper): self_copy = super().copy(mapper) + self_copy.arg_exprs = self.arg_exprs self_copy.static_target_function = self.static_target_function return self_copy @@ -1195,6 +1203,8 @@ class Invoke(Terminator): """ A function call operation that supports exception handling. + :ivar arg_exprs: (dict of str to `iodelay.Expr`) + iodelay expressions for values of arguments :ivar static_target_function: (:class:`Function` or None) statically resolved callee """ @@ -1204,17 +1214,23 @@ class Invoke(Terminator): :param args: (list of :class:`Value`) function arguments :param normal: (:class:`BasicBlock`) normal target :param exn: (:class:`BasicBlock`) exceptional target + :param arg_exprs: (dict of str to `iodelay.Expr`) """ - def __init__(self, func, args, normal, exn, name=""): + def __init__(self, func, args, arg_exprs, normal, exn, name=""): assert isinstance(func, Value) for arg in args: assert isinstance(arg, Value) assert isinstance(normal, BasicBlock) assert isinstance(exn, BasicBlock) + for arg in arg_exprs: + assert isinstance(arg, str) + assert isinstance(arg_exprs[arg], iodelay.Expr) super().__init__([func] + args + [normal, exn], func.type.ret, name) + self.arg_exprs = arg_exprs self.static_target_function = None def copy(self, mapper): self_copy = super().copy(mapper) + self_copy.arg_exprs = self.arg_exprs self_copy.static_target_function = self.static_target_function return self_copy @@ -1299,31 +1315,24 @@ class Delay(Terminator): inlining could lead to the expression folding to a constant. :ivar interval: (:class:`iodelay.Expr`) expression - :ivar var_names: (list of string) - iodelay variable names corresponding to SSA values """ """ :param interval: (:class:`iodelay.Expr`) expression - :param substs: (dict of str to :class:`Value`) - SSA values corresponding to iodelay variable names :param call: (:class:`Call` or ``Constant(None, builtins.TNone())``) the call instruction that caused this delay, if any :param target: (:class:`BasicBlock`) branch target """ - def __init__(self, interval, substs, decomposition, target, name=""): - for var_name in substs: assert isinstance(var_name, str) + def __init__(self, interval, decomposition, target, name=""): assert isinstance(decomposition, Call) or \ isinstance(decomposition, Builtin) and decomposition.op in ("delay", "delay_mu") assert isinstance(target, BasicBlock) - super().__init__([decomposition, target, *substs.values()], builtins.TNone(), name) + super().__init__([decomposition, target], builtins.TNone(), name) self.interval = interval - self.var_names = list(substs.keys()) def copy(self, mapper): self_copy = super().copy(mapper) self_copy.interval = self.interval - self_copy.var_names = list(self.var_names) return self_copy def decomposition(self): @@ -1342,17 +1351,9 @@ class Delay(Terminator): self.operands[1] = new_target self.operands[1].uses.add(self) - def substs(self): - return {key: value for key, value in zip(self.var_names, self.operands[2:])} - def _operands_as_string(self, type_printer): - substs = self.substs() - substs_as_strings = [] - for var_name in substs: - substs_as_strings.append("{} = {}".format(var_name, substs[var_name])) - result = "[{}]".format(", ".join(substs_as_strings)) - result += ", decomp {}, to {}".format(self.decomposition().as_operand(type_printer), - self.target().as_operand(type_printer)) + result = "decomp {}, to {}".format(self.decomposition().as_operand(type_printer), + self.target().as_operand(type_printer)) return result def opcode(self): @@ -1367,14 +1368,10 @@ class Loop(Terminator): :ivar trip_count: (:class:`iodelay.Expr`) expression for trip count - :ivar var_names: (list of string) - iodelay variable names corresponding to ``trip_count`` operands """ """ :param trip_count: (:class:`iodelay.Expr`) expression - :param substs: (dict of str to :class:`Value`) - SSA values corresponding to iodelay variable names :param indvar: (:class:`Phi`) phi node corresponding to the induction SSA value, which advances from ``0`` to ``trip_count - 1`` @@ -1382,21 +1379,18 @@ class Loop(Terminator): :param if_true: (:class:`BasicBlock`) branch target if condition is truthful :param if_false: (:class:`BasicBlock`) branch target if condition is falseful """ - def __init__(self, trip_count, substs, indvar, cond, if_true, if_false, name=""): - for var_name in substs: assert isinstance(var_name, str) + def __init__(self, trip_count, indvar, cond, if_true, if_false, name=""): assert isinstance(indvar, Phi) assert isinstance(cond, Value) assert builtins.is_bool(cond.type) assert isinstance(if_true, BasicBlock) assert isinstance(if_false, BasicBlock) - super().__init__([indvar, cond, if_true, if_false, *substs.values()], builtins.TNone(), name) + super().__init__([indvar, cond, if_true, if_false], builtins.TNone(), name) self.trip_count = trip_count - self.var_names = list(substs.keys()) def copy(self, mapper): self_copy = super().copy(mapper) self_copy.trip_count = self.trip_count - self_copy.var_names = list(self.var_names) return self_copy def induction_variable(self): @@ -1411,17 +1405,9 @@ class Loop(Terminator): def if_false(self): return self.operands[3] - def substs(self): - return {key: value for key, value in zip(self.var_names, self.operands[4:])} - def _operands_as_string(self, type_printer): - substs = self.substs() - substs_as_strings = [] - for var_name in substs: - substs_as_strings.append("{} = {}".format(var_name, substs[var_name])) - result = "[{}]".format(", ".join(substs_as_strings)) - result += ", indvar {}, if {}, {}, {}".format( - *list(map(lambda value: value.as_operand(type_printer), self.operands[0:4]))) + result = "indvar {}, if {}, {}, {}".format( + *list(map(lambda value: value.as_operand(type_printer), self.operands))) return result def opcode(self): diff --git a/artiq/compiler/transforms/artiq_ir_generator.py b/artiq/compiler/transforms/artiq_ir_generator.py index de15e20f4..a76fd76a9 100644 --- a/artiq/compiler/transforms/artiq_ir_generator.py +++ b/artiq/compiler/transforms/artiq_ir_generator.py @@ -526,9 +526,7 @@ class ARTIQIRGenerator(algorithm.Visitor): else_tail = tail if node.trip_count is not None: - substs = {var_name: self.current_args[var_name] - for var_name in node.trip_count.free_vars()} - head.append(ir.Loop(node.trip_count, substs, phi, cond, body, else_tail)) + head.append(ir.Loop(node.trip_count, phi, cond, body, else_tail)) else: head.append(ir.BranchIf(cond, body, else_tail)) if not post_body.is_terminated(): @@ -1659,10 +1657,11 @@ class ARTIQIRGenerator(algorithm.Visitor): assert None not in args if self.unwind_target is None: - insn = self.append(ir.Call(func, args)) + insn = self.append(ir.Call(func, args, node.arg_exprs)) else: after_invoke = self.add_block() - insn = self.append(ir.Invoke(func, args, after_invoke, self.unwind_target)) + insn = self.append(ir.Invoke(func, args, node.arg_exprs, + after_invoke, self.unwind_target)) self.current_block = after_invoke method_key = None @@ -1672,9 +1671,7 @@ class ARTIQIRGenerator(algorithm.Visitor): if node.iodelay is not None and not iodelay.is_const(node.iodelay, 0): after_delay = self.add_block() - substs = {var_name: self.current_args[var_name] - for var_name in node.iodelay.free_vars()} - self.append(ir.Delay(node.iodelay, substs, insn, after_delay)) + self.append(ir.Delay(node.iodelay, insn, after_delay)) self.current_block = after_delay return insn diff --git a/artiq/compiler/transforms/asttyped_rewriter.py b/artiq/compiler/transforms/asttyped_rewriter.py index b7fa6f7a1..a14d55c5d 100644 --- a/artiq/compiler/transforms/asttyped_rewriter.py +++ b/artiq/compiler/transforms/asttyped_rewriter.py @@ -426,7 +426,7 @@ class ASTTypedRewriter(algorithm.Transformer): def visit_Call(self, node): node = self.generic_visit(node) - node = asttyped.CallT(type=types.TVar(), iodelay=None, + node = asttyped.CallT(type=types.TVar(), iodelay=None, arg_exprs={}, func=node.func, args=node.args, keywords=node.keywords, starargs=node.starargs, kwargs=node.kwargs, star_loc=node.star_loc, dstar_loc=node.dstar_loc, diff --git a/artiq/compiler/transforms/iodelay_estimator.py b/artiq/compiler/transforms/iodelay_estimator.py index df76c74cd..163093bb9 100644 --- a/artiq/compiler/transforms/iodelay_estimator.py +++ b/artiq/compiler/transforms/iodelay_estimator.py @@ -297,8 +297,8 @@ class IODelayEstimator(algorithm.Visitor): args[arg_name] = arg_node free_vars = delay.duration.free_vars() - call_delay = delay.duration.fold( - { arg: self.evaluate(args[arg], abort=abort) for arg in free_vars }) + node.arg_exprs = { arg: self.evaluate(args[arg], abort=abort) for arg in free_vars } + call_delay = delay.duration.fold(node.arg_exprs) else: assert False else: