compiler: do not associate SSA values with iodelay even when inlining.

Fixes #201.
This commit is contained in:
whitequark 2015-12-25 15:02:33 +08:00
parent 33c3b3377e
commit 082e9e20dd
6 changed files with 47 additions and 64 deletions

View File

@ -51,20 +51,19 @@ def inline(call_insn):
elif isinstance(source_insn, ir.Phi): elif isinstance(source_insn, ir.Phi):
target_insn = ir.Phi() target_insn = ir.Phi()
elif isinstance(source_insn, ir.Delay): elif isinstance(source_insn, ir.Delay):
substs = source_insn.substs() target_insn = source_insn.copy(mapper)
mapped_substs = {var: value_map[substs[var]] for var in substs} target_insn.interval = source_insn.interval.fold(call_insn.arg_exprs)
const_substs = {var: iodelay.Const(mapped_substs[var].value) elif isinstance(source_insn, ir.Loop):
for var in mapped_substs target_insn = source_insn.copy(mapper)
if isinstance(mapped_substs[var], ir.Constant)} target_insn.trip_count = source_insn.trip_count.fold(call_insn.arg_exprs)
other_substs = {var: mapped_substs[var] elif isinstance(source_insn, ir.Call):
for var in mapped_substs target_insn = source_insn.copy(mapper)
if not isinstance(mapped_substs[var], ir.Constant)} target_insn.arg_exprs = \
target_insn = ir.Delay(source_insn.interval.fold(const_substs), other_substs, { arg: source_insn.arg_exprs[arg].fold(call_insn.arg_exprs)
value_map[source_insn.decomposition()], for arg in source_insn.arg_exprs }
value_map[source_insn.target()])
else: else:
target_insn = source_insn.copy(mapper) target_insn = source_insn.copy(mapper)
target_insn.name = "i." + source_insn.name target_insn.name = "i." + source_insn.name
value_map[source_insn] = target_insn value_map[source_insn] = target_insn
target_block.append(target_insn) target_block.append(target_insn)

View File

@ -54,6 +54,7 @@ class BoolOpT(ast.BoolOp, commontyped):
class CallT(ast.Call, commontyped): class CallT(ast.Call, commontyped):
""" """
:ivar iodelay: (:class:`iodelay.Expr`) :ivar iodelay: (:class:`iodelay.Expr`)
:ivar arg_exprs: (dict of str to :class:`iodelay.Expr`)
""" """
class CompareT(ast.Compare, commontyped): class CompareT(ast.Compare, commontyped):
pass pass

View File

@ -5,7 +5,7 @@ of the ARTIQ compiler.
from collections import OrderedDict from collections import OrderedDict
from pythonparser import ast from pythonparser import ast
from . import types, builtins from . import types, builtins, iodelay
# Generic SSA IR classes # Generic SSA IR classes
@ -939,6 +939,8 @@ class Call(Instruction):
""" """
A function call operation. A function call operation.
:ivar arg_exprs: (dict of str to `iodelay.Expr`)
iodelay expressions for values of arguments
:ivar static_target_function: (:class:`Function` or None) :ivar static_target_function: (:class:`Function` or None)
statically resolved callee statically resolved callee
""" """
@ -946,15 +948,21 @@ class Call(Instruction):
""" """
:param func: (:class:`Value`) function to call :param func: (:class:`Value`) function to call
:param args: (list of :class:`Value`) function arguments :param args: (list of :class:`Value`) function arguments
:param arg_exprs: (dict of str to `iodelay.Expr`)
""" """
def __init__(self, func, args, name=""): def __init__(self, func, args, arg_exprs, name=""):
assert isinstance(func, Value) assert isinstance(func, Value)
for arg in args: assert isinstance(arg, Value) for arg in args: assert isinstance(arg, Value)
for arg in arg_exprs:
assert isinstance(arg, str)
assert isinstance(arg_exprs[arg], iodelay.Expr)
super().__init__([func] + args, func.type.ret, name) super().__init__([func] + args, func.type.ret, name)
self.arg_exprs = arg_exprs
self.static_target_function = None self.static_target_function = None
def copy(self, mapper): def copy(self, mapper):
self_copy = super().copy(mapper) self_copy = super().copy(mapper)
self_copy.arg_exprs = self.arg_exprs
self_copy.static_target_function = self.static_target_function self_copy.static_target_function = self.static_target_function
return self_copy return self_copy
@ -1195,6 +1203,8 @@ class Invoke(Terminator):
""" """
A function call operation that supports exception handling. A function call operation that supports exception handling.
:ivar arg_exprs: (dict of str to `iodelay.Expr`)
iodelay expressions for values of arguments
:ivar static_target_function: (:class:`Function` or None) :ivar static_target_function: (:class:`Function` or None)
statically resolved callee statically resolved callee
""" """
@ -1204,17 +1214,23 @@ class Invoke(Terminator):
:param args: (list of :class:`Value`) function arguments :param args: (list of :class:`Value`) function arguments
:param normal: (:class:`BasicBlock`) normal target :param normal: (:class:`BasicBlock`) normal target
:param exn: (:class:`BasicBlock`) exceptional target :param exn: (:class:`BasicBlock`) exceptional target
:param arg_exprs: (dict of str to `iodelay.Expr`)
""" """
def __init__(self, func, args, normal, exn, name=""): def __init__(self, func, args, arg_exprs, normal, exn, name=""):
assert isinstance(func, Value) assert isinstance(func, Value)
for arg in args: assert isinstance(arg, Value) for arg in args: assert isinstance(arg, Value)
assert isinstance(normal, BasicBlock) assert isinstance(normal, BasicBlock)
assert isinstance(exn, BasicBlock) assert isinstance(exn, BasicBlock)
for arg in arg_exprs:
assert isinstance(arg, str)
assert isinstance(arg_exprs[arg], iodelay.Expr)
super().__init__([func] + args + [normal, exn], func.type.ret, name) super().__init__([func] + args + [normal, exn], func.type.ret, name)
self.arg_exprs = arg_exprs
self.static_target_function = None self.static_target_function = None
def copy(self, mapper): def copy(self, mapper):
self_copy = super().copy(mapper) self_copy = super().copy(mapper)
self_copy.arg_exprs = self.arg_exprs
self_copy.static_target_function = self.static_target_function self_copy.static_target_function = self.static_target_function
return self_copy return self_copy
@ -1299,31 +1315,24 @@ class Delay(Terminator):
inlining could lead to the expression folding to a constant. inlining could lead to the expression folding to a constant.
:ivar interval: (:class:`iodelay.Expr`) expression :ivar interval: (:class:`iodelay.Expr`) expression
:ivar var_names: (list of string)
iodelay variable names corresponding to SSA values
""" """
""" """
:param interval: (:class:`iodelay.Expr`) expression :param interval: (:class:`iodelay.Expr`) expression
:param substs: (dict of str to :class:`Value`)
SSA values corresponding to iodelay variable names
:param call: (:class:`Call` or ``Constant(None, builtins.TNone())``) :param call: (:class:`Call` or ``Constant(None, builtins.TNone())``)
the call instruction that caused this delay, if any the call instruction that caused this delay, if any
:param target: (:class:`BasicBlock`) branch target :param target: (:class:`BasicBlock`) branch target
""" """
def __init__(self, interval, substs, decomposition, target, name=""): def __init__(self, interval, decomposition, target, name=""):
for var_name in substs: assert isinstance(var_name, str)
assert isinstance(decomposition, Call) or \ assert isinstance(decomposition, Call) or \
isinstance(decomposition, Builtin) and decomposition.op in ("delay", "delay_mu") isinstance(decomposition, Builtin) and decomposition.op in ("delay", "delay_mu")
assert isinstance(target, BasicBlock) assert isinstance(target, BasicBlock)
super().__init__([decomposition, target, *substs.values()], builtins.TNone(), name) super().__init__([decomposition, target], builtins.TNone(), name)
self.interval = interval self.interval = interval
self.var_names = list(substs.keys())
def copy(self, mapper): def copy(self, mapper):
self_copy = super().copy(mapper) self_copy = super().copy(mapper)
self_copy.interval = self.interval self_copy.interval = self.interval
self_copy.var_names = list(self.var_names)
return self_copy return self_copy
def decomposition(self): def decomposition(self):
@ -1342,17 +1351,9 @@ class Delay(Terminator):
self.operands[1] = new_target self.operands[1] = new_target
self.operands[1].uses.add(self) self.operands[1].uses.add(self)
def substs(self):
return {key: value for key, value in zip(self.var_names, self.operands[2:])}
def _operands_as_string(self, type_printer): def _operands_as_string(self, type_printer):
substs = self.substs() result = "decomp {}, to {}".format(self.decomposition().as_operand(type_printer),
substs_as_strings = [] self.target().as_operand(type_printer))
for var_name in substs:
substs_as_strings.append("{} = {}".format(var_name, substs[var_name]))
result = "[{}]".format(", ".join(substs_as_strings))
result += ", decomp {}, to {}".format(self.decomposition().as_operand(type_printer),
self.target().as_operand(type_printer))
return result return result
def opcode(self): def opcode(self):
@ -1367,14 +1368,10 @@ class Loop(Terminator):
:ivar trip_count: (:class:`iodelay.Expr`) :ivar trip_count: (:class:`iodelay.Expr`)
expression for trip count expression for trip count
:ivar var_names: (list of string)
iodelay variable names corresponding to ``trip_count`` operands
""" """
""" """
:param trip_count: (:class:`iodelay.Expr`) expression :param trip_count: (:class:`iodelay.Expr`) expression
:param substs: (dict of str to :class:`Value`)
SSA values corresponding to iodelay variable names
:param indvar: (:class:`Phi`) :param indvar: (:class:`Phi`)
phi node corresponding to the induction SSA value, phi node corresponding to the induction SSA value,
which advances from ``0`` to ``trip_count - 1`` which advances from ``0`` to ``trip_count - 1``
@ -1382,21 +1379,18 @@ class Loop(Terminator):
:param if_true: (:class:`BasicBlock`) branch target if condition is truthful :param if_true: (:class:`BasicBlock`) branch target if condition is truthful
:param if_false: (:class:`BasicBlock`) branch target if condition is falseful :param if_false: (:class:`BasicBlock`) branch target if condition is falseful
""" """
def __init__(self, trip_count, substs, indvar, cond, if_true, if_false, name=""): def __init__(self, trip_count, indvar, cond, if_true, if_false, name=""):
for var_name in substs: assert isinstance(var_name, str)
assert isinstance(indvar, Phi) assert isinstance(indvar, Phi)
assert isinstance(cond, Value) assert isinstance(cond, Value)
assert builtins.is_bool(cond.type) assert builtins.is_bool(cond.type)
assert isinstance(if_true, BasicBlock) assert isinstance(if_true, BasicBlock)
assert isinstance(if_false, BasicBlock) assert isinstance(if_false, BasicBlock)
super().__init__([indvar, cond, if_true, if_false, *substs.values()], builtins.TNone(), name) super().__init__([indvar, cond, if_true, if_false], builtins.TNone(), name)
self.trip_count = trip_count self.trip_count = trip_count
self.var_names = list(substs.keys())
def copy(self, mapper): def copy(self, mapper):
self_copy = super().copy(mapper) self_copy = super().copy(mapper)
self_copy.trip_count = self.trip_count self_copy.trip_count = self.trip_count
self_copy.var_names = list(self.var_names)
return self_copy return self_copy
def induction_variable(self): def induction_variable(self):
@ -1411,17 +1405,9 @@ class Loop(Terminator):
def if_false(self): def if_false(self):
return self.operands[3] return self.operands[3]
def substs(self):
return {key: value for key, value in zip(self.var_names, self.operands[4:])}
def _operands_as_string(self, type_printer): def _operands_as_string(self, type_printer):
substs = self.substs() result = "indvar {}, if {}, {}, {}".format(
substs_as_strings = [] *list(map(lambda value: value.as_operand(type_printer), self.operands)))
for var_name in substs:
substs_as_strings.append("{} = {}".format(var_name, substs[var_name]))
result = "[{}]".format(", ".join(substs_as_strings))
result += ", indvar {}, if {}, {}, {}".format(
*list(map(lambda value: value.as_operand(type_printer), self.operands[0:4])))
return result return result
def opcode(self): def opcode(self):

View File

@ -526,9 +526,7 @@ class ARTIQIRGenerator(algorithm.Visitor):
else_tail = tail else_tail = tail
if node.trip_count is not None: if node.trip_count is not None:
substs = {var_name: self.current_args[var_name] head.append(ir.Loop(node.trip_count, phi, cond, body, else_tail))
for var_name in node.trip_count.free_vars()}
head.append(ir.Loop(node.trip_count, substs, phi, cond, body, else_tail))
else: else:
head.append(ir.BranchIf(cond, body, else_tail)) head.append(ir.BranchIf(cond, body, else_tail))
if not post_body.is_terminated(): if not post_body.is_terminated():
@ -1659,10 +1657,11 @@ class ARTIQIRGenerator(algorithm.Visitor):
assert None not in args assert None not in args
if self.unwind_target is None: if self.unwind_target is None:
insn = self.append(ir.Call(func, args)) insn = self.append(ir.Call(func, args, node.arg_exprs))
else: else:
after_invoke = self.add_block() after_invoke = self.add_block()
insn = self.append(ir.Invoke(func, args, after_invoke, self.unwind_target)) insn = self.append(ir.Invoke(func, args, node.arg_exprs,
after_invoke, self.unwind_target))
self.current_block = after_invoke self.current_block = after_invoke
method_key = None method_key = None
@ -1672,9 +1671,7 @@ class ARTIQIRGenerator(algorithm.Visitor):
if node.iodelay is not None and not iodelay.is_const(node.iodelay, 0): if node.iodelay is not None and not iodelay.is_const(node.iodelay, 0):
after_delay = self.add_block() after_delay = self.add_block()
substs = {var_name: self.current_args[var_name] self.append(ir.Delay(node.iodelay, insn, after_delay))
for var_name in node.iodelay.free_vars()}
self.append(ir.Delay(node.iodelay, substs, insn, after_delay))
self.current_block = after_delay self.current_block = after_delay
return insn return insn

View File

@ -426,7 +426,7 @@ class ASTTypedRewriter(algorithm.Transformer):
def visit_Call(self, node): def visit_Call(self, node):
node = self.generic_visit(node) node = self.generic_visit(node)
node = asttyped.CallT(type=types.TVar(), iodelay=None, node = asttyped.CallT(type=types.TVar(), iodelay=None, arg_exprs={},
func=node.func, args=node.args, keywords=node.keywords, func=node.func, args=node.args, keywords=node.keywords,
starargs=node.starargs, kwargs=node.kwargs, starargs=node.starargs, kwargs=node.kwargs,
star_loc=node.star_loc, dstar_loc=node.dstar_loc, star_loc=node.star_loc, dstar_loc=node.dstar_loc,

View File

@ -297,8 +297,8 @@ class IODelayEstimator(algorithm.Visitor):
args[arg_name] = arg_node args[arg_name] = arg_node
free_vars = delay.duration.free_vars() free_vars = delay.duration.free_vars()
call_delay = delay.duration.fold( node.arg_exprs = { arg: self.evaluate(args[arg], abort=abort) for arg in free_vars }
{ arg: self.evaluate(args[arg], abort=abort) for arg in free_vars }) call_delay = delay.duration.fold(node.arg_exprs)
else: else:
assert False assert False
else: else: