compiler: do not associate SSA values with iodelay even when inlining.

Fixes #201.
This commit is contained in:
whitequark 2015-12-25 15:02:33 +08:00
parent 33c3b3377e
commit 082e9e20dd
6 changed files with 47 additions and 64 deletions

View File

@ -51,20 +51,19 @@ def inline(call_insn):
elif isinstance(source_insn, ir.Phi):
target_insn = ir.Phi()
elif isinstance(source_insn, ir.Delay):
substs = source_insn.substs()
mapped_substs = {var: value_map[substs[var]] for var in substs}
const_substs = {var: iodelay.Const(mapped_substs[var].value)
for var in mapped_substs
if isinstance(mapped_substs[var], ir.Constant)}
other_substs = {var: mapped_substs[var]
for var in mapped_substs
if not isinstance(mapped_substs[var], ir.Constant)}
target_insn = ir.Delay(source_insn.interval.fold(const_substs), other_substs,
value_map[source_insn.decomposition()],
value_map[source_insn.target()])
target_insn = source_insn.copy(mapper)
target_insn.interval = source_insn.interval.fold(call_insn.arg_exprs)
elif isinstance(source_insn, ir.Loop):
target_insn = source_insn.copy(mapper)
target_insn.trip_count = source_insn.trip_count.fold(call_insn.arg_exprs)
elif isinstance(source_insn, ir.Call):
target_insn = source_insn.copy(mapper)
target_insn.arg_exprs = \
{ arg: source_insn.arg_exprs[arg].fold(call_insn.arg_exprs)
for arg in source_insn.arg_exprs }
else:
target_insn = source_insn.copy(mapper)
target_insn.name = "i." + source_insn.name
target_insn.name = "i." + source_insn.name
value_map[source_insn] = target_insn
target_block.append(target_insn)

View File

@ -54,6 +54,7 @@ class BoolOpT(ast.BoolOp, commontyped):
class CallT(ast.Call, commontyped):
"""
:ivar iodelay: (:class:`iodelay.Expr`)
:ivar arg_exprs: (dict of str to :class:`iodelay.Expr`)
"""
class CompareT(ast.Compare, commontyped):
pass

View File

@ -5,7 +5,7 @@ of the ARTIQ compiler.
from collections import OrderedDict
from pythonparser import ast
from . import types, builtins
from . import types, builtins, iodelay
# Generic SSA IR classes
@ -939,6 +939,8 @@ class Call(Instruction):
"""
A function call operation.
:ivar arg_exprs: (dict of str to `iodelay.Expr`)
iodelay expressions for values of arguments
:ivar static_target_function: (:class:`Function` or None)
statically resolved callee
"""
@ -946,15 +948,21 @@ class Call(Instruction):
"""
:param func: (:class:`Value`) function to call
:param args: (list of :class:`Value`) function arguments
:param arg_exprs: (dict of str to `iodelay.Expr`)
"""
def __init__(self, func, args, name=""):
def __init__(self, func, args, arg_exprs, name=""):
assert isinstance(func, Value)
for arg in args: assert isinstance(arg, Value)
for arg in arg_exprs:
assert isinstance(arg, str)
assert isinstance(arg_exprs[arg], iodelay.Expr)
super().__init__([func] + args, func.type.ret, name)
self.arg_exprs = arg_exprs
self.static_target_function = None
def copy(self, mapper):
self_copy = super().copy(mapper)
self_copy.arg_exprs = self.arg_exprs
self_copy.static_target_function = self.static_target_function
return self_copy
@ -1195,6 +1203,8 @@ class Invoke(Terminator):
"""
A function call operation that supports exception handling.
:ivar arg_exprs: (dict of str to `iodelay.Expr`)
iodelay expressions for values of arguments
:ivar static_target_function: (:class:`Function` or None)
statically resolved callee
"""
@ -1204,17 +1214,23 @@ class Invoke(Terminator):
:param args: (list of :class:`Value`) function arguments
:param normal: (:class:`BasicBlock`) normal target
:param exn: (:class:`BasicBlock`) exceptional target
:param arg_exprs: (dict of str to `iodelay.Expr`)
"""
def __init__(self, func, args, normal, exn, name=""):
def __init__(self, func, args, arg_exprs, normal, exn, name=""):
assert isinstance(func, Value)
for arg in args: assert isinstance(arg, Value)
assert isinstance(normal, BasicBlock)
assert isinstance(exn, BasicBlock)
for arg in arg_exprs:
assert isinstance(arg, str)
assert isinstance(arg_exprs[arg], iodelay.Expr)
super().__init__([func] + args + [normal, exn], func.type.ret, name)
self.arg_exprs = arg_exprs
self.static_target_function = None
def copy(self, mapper):
self_copy = super().copy(mapper)
self_copy.arg_exprs = self.arg_exprs
self_copy.static_target_function = self.static_target_function
return self_copy
@ -1299,31 +1315,24 @@ class Delay(Terminator):
inlining could lead to the expression folding to a constant.
:ivar interval: (:class:`iodelay.Expr`) expression
:ivar var_names: (list of string)
iodelay variable names corresponding to SSA values
"""
"""
:param interval: (:class:`iodelay.Expr`) expression
:param substs: (dict of str to :class:`Value`)
SSA values corresponding to iodelay variable names
:param call: (:class:`Call` or ``Constant(None, builtins.TNone())``)
the call instruction that caused this delay, if any
:param target: (:class:`BasicBlock`) branch target
"""
def __init__(self, interval, substs, decomposition, target, name=""):
for var_name in substs: assert isinstance(var_name, str)
def __init__(self, interval, decomposition, target, name=""):
assert isinstance(decomposition, Call) or \
isinstance(decomposition, Builtin) and decomposition.op in ("delay", "delay_mu")
assert isinstance(target, BasicBlock)
super().__init__([decomposition, target, *substs.values()], builtins.TNone(), name)
super().__init__([decomposition, target], builtins.TNone(), name)
self.interval = interval
self.var_names = list(substs.keys())
def copy(self, mapper):
self_copy = super().copy(mapper)
self_copy.interval = self.interval
self_copy.var_names = list(self.var_names)
return self_copy
def decomposition(self):
@ -1342,17 +1351,9 @@ class Delay(Terminator):
self.operands[1] = new_target
self.operands[1].uses.add(self)
def substs(self):
return {key: value for key, value in zip(self.var_names, self.operands[2:])}
def _operands_as_string(self, type_printer):
substs = self.substs()
substs_as_strings = []
for var_name in substs:
substs_as_strings.append("{} = {}".format(var_name, substs[var_name]))
result = "[{}]".format(", ".join(substs_as_strings))
result += ", decomp {}, to {}".format(self.decomposition().as_operand(type_printer),
self.target().as_operand(type_printer))
result = "decomp {}, to {}".format(self.decomposition().as_operand(type_printer),
self.target().as_operand(type_printer))
return result
def opcode(self):
@ -1367,14 +1368,10 @@ class Loop(Terminator):
:ivar trip_count: (:class:`iodelay.Expr`)
expression for trip count
:ivar var_names: (list of string)
iodelay variable names corresponding to ``trip_count`` operands
"""
"""
:param trip_count: (:class:`iodelay.Expr`) expression
:param substs: (dict of str to :class:`Value`)
SSA values corresponding to iodelay variable names
:param indvar: (:class:`Phi`)
phi node corresponding to the induction SSA value,
which advances from ``0`` to ``trip_count - 1``
@ -1382,21 +1379,18 @@ class Loop(Terminator):
:param if_true: (:class:`BasicBlock`) branch target if condition is truthful
:param if_false: (:class:`BasicBlock`) branch target if condition is falseful
"""
def __init__(self, trip_count, substs, indvar, cond, if_true, if_false, name=""):
for var_name in substs: assert isinstance(var_name, str)
def __init__(self, trip_count, indvar, cond, if_true, if_false, name=""):
assert isinstance(indvar, Phi)
assert isinstance(cond, Value)
assert builtins.is_bool(cond.type)
assert isinstance(if_true, BasicBlock)
assert isinstance(if_false, BasicBlock)
super().__init__([indvar, cond, if_true, if_false, *substs.values()], builtins.TNone(), name)
super().__init__([indvar, cond, if_true, if_false], builtins.TNone(), name)
self.trip_count = trip_count
self.var_names = list(substs.keys())
def copy(self, mapper):
self_copy = super().copy(mapper)
self_copy.trip_count = self.trip_count
self_copy.var_names = list(self.var_names)
return self_copy
def induction_variable(self):
@ -1411,17 +1405,9 @@ class Loop(Terminator):
def if_false(self):
return self.operands[3]
def substs(self):
return {key: value for key, value in zip(self.var_names, self.operands[4:])}
def _operands_as_string(self, type_printer):
substs = self.substs()
substs_as_strings = []
for var_name in substs:
substs_as_strings.append("{} = {}".format(var_name, substs[var_name]))
result = "[{}]".format(", ".join(substs_as_strings))
result += ", indvar {}, if {}, {}, {}".format(
*list(map(lambda value: value.as_operand(type_printer), self.operands[0:4])))
result = "indvar {}, if {}, {}, {}".format(
*list(map(lambda value: value.as_operand(type_printer), self.operands)))
return result
def opcode(self):

View File

@ -526,9 +526,7 @@ class ARTIQIRGenerator(algorithm.Visitor):
else_tail = tail
if node.trip_count is not None:
substs = {var_name: self.current_args[var_name]
for var_name in node.trip_count.free_vars()}
head.append(ir.Loop(node.trip_count, substs, phi, cond, body, else_tail))
head.append(ir.Loop(node.trip_count, phi, cond, body, else_tail))
else:
head.append(ir.BranchIf(cond, body, else_tail))
if not post_body.is_terminated():
@ -1659,10 +1657,11 @@ class ARTIQIRGenerator(algorithm.Visitor):
assert None not in args
if self.unwind_target is None:
insn = self.append(ir.Call(func, args))
insn = self.append(ir.Call(func, args, node.arg_exprs))
else:
after_invoke = self.add_block()
insn = self.append(ir.Invoke(func, args, after_invoke, self.unwind_target))
insn = self.append(ir.Invoke(func, args, node.arg_exprs,
after_invoke, self.unwind_target))
self.current_block = after_invoke
method_key = None
@ -1672,9 +1671,7 @@ class ARTIQIRGenerator(algorithm.Visitor):
if node.iodelay is not None and not iodelay.is_const(node.iodelay, 0):
after_delay = self.add_block()
substs = {var_name: self.current_args[var_name]
for var_name in node.iodelay.free_vars()}
self.append(ir.Delay(node.iodelay, substs, insn, after_delay))
self.append(ir.Delay(node.iodelay, insn, after_delay))
self.current_block = after_delay
return insn

View File

@ -426,7 +426,7 @@ class ASTTypedRewriter(algorithm.Transformer):
def visit_Call(self, node):
node = self.generic_visit(node)
node = asttyped.CallT(type=types.TVar(), iodelay=None,
node = asttyped.CallT(type=types.TVar(), iodelay=None, arg_exprs={},
func=node.func, args=node.args, keywords=node.keywords,
starargs=node.starargs, kwargs=node.kwargs,
star_loc=node.star_loc, dstar_loc=node.dstar_loc,

View File

@ -297,8 +297,8 @@ class IODelayEstimator(algorithm.Visitor):
args[arg_name] = arg_node
free_vars = delay.duration.free_vars()
call_delay = delay.duration.fold(
{ arg: self.evaluate(args[arg], abort=abort) for arg in free_vars })
node.arg_exprs = { arg: self.evaluate(args[arg], abort=abort) for arg in free_vars }
call_delay = delay.duration.fold(node.arg_exprs)
else:
assert False
else: