forked from M-Labs/artiq
compiler: do not associate SSA values with iodelay even when inlining.
Fixes #201.
This commit is contained in:
parent
33c3b3377e
commit
082e9e20dd
@ -51,20 +51,19 @@ def inline(call_insn):
|
||||
elif isinstance(source_insn, ir.Phi):
|
||||
target_insn = ir.Phi()
|
||||
elif isinstance(source_insn, ir.Delay):
|
||||
substs = source_insn.substs()
|
||||
mapped_substs = {var: value_map[substs[var]] for var in substs}
|
||||
const_substs = {var: iodelay.Const(mapped_substs[var].value)
|
||||
for var in mapped_substs
|
||||
if isinstance(mapped_substs[var], ir.Constant)}
|
||||
other_substs = {var: mapped_substs[var]
|
||||
for var in mapped_substs
|
||||
if not isinstance(mapped_substs[var], ir.Constant)}
|
||||
target_insn = ir.Delay(source_insn.interval.fold(const_substs), other_substs,
|
||||
value_map[source_insn.decomposition()],
|
||||
value_map[source_insn.target()])
|
||||
target_insn = source_insn.copy(mapper)
|
||||
target_insn.interval = source_insn.interval.fold(call_insn.arg_exprs)
|
||||
elif isinstance(source_insn, ir.Loop):
|
||||
target_insn = source_insn.copy(mapper)
|
||||
target_insn.trip_count = source_insn.trip_count.fold(call_insn.arg_exprs)
|
||||
elif isinstance(source_insn, ir.Call):
|
||||
target_insn = source_insn.copy(mapper)
|
||||
target_insn.arg_exprs = \
|
||||
{ arg: source_insn.arg_exprs[arg].fold(call_insn.arg_exprs)
|
||||
for arg in source_insn.arg_exprs }
|
||||
else:
|
||||
target_insn = source_insn.copy(mapper)
|
||||
target_insn.name = "i." + source_insn.name
|
||||
target_insn.name = "i." + source_insn.name
|
||||
value_map[source_insn] = target_insn
|
||||
target_block.append(target_insn)
|
||||
|
||||
|
@ -54,6 +54,7 @@ class BoolOpT(ast.BoolOp, commontyped):
|
||||
class CallT(ast.Call, commontyped):
|
||||
"""
|
||||
:ivar iodelay: (:class:`iodelay.Expr`)
|
||||
:ivar arg_exprs: (dict of str to :class:`iodelay.Expr`)
|
||||
"""
|
||||
class CompareT(ast.Compare, commontyped):
|
||||
pass
|
||||
|
@ -5,7 +5,7 @@ of the ARTIQ compiler.
|
||||
|
||||
from collections import OrderedDict
|
||||
from pythonparser import ast
|
||||
from . import types, builtins
|
||||
from . import types, builtins, iodelay
|
||||
|
||||
# Generic SSA IR classes
|
||||
|
||||
@ -939,6 +939,8 @@ class Call(Instruction):
|
||||
"""
|
||||
A function call operation.
|
||||
|
||||
:ivar arg_exprs: (dict of str to `iodelay.Expr`)
|
||||
iodelay expressions for values of arguments
|
||||
:ivar static_target_function: (:class:`Function` or None)
|
||||
statically resolved callee
|
||||
"""
|
||||
@ -946,15 +948,21 @@ class Call(Instruction):
|
||||
"""
|
||||
:param func: (:class:`Value`) function to call
|
||||
:param args: (list of :class:`Value`) function arguments
|
||||
:param arg_exprs: (dict of str to `iodelay.Expr`)
|
||||
"""
|
||||
def __init__(self, func, args, name=""):
|
||||
def __init__(self, func, args, arg_exprs, name=""):
|
||||
assert isinstance(func, Value)
|
||||
for arg in args: assert isinstance(arg, Value)
|
||||
for arg in arg_exprs:
|
||||
assert isinstance(arg, str)
|
||||
assert isinstance(arg_exprs[arg], iodelay.Expr)
|
||||
super().__init__([func] + args, func.type.ret, name)
|
||||
self.arg_exprs = arg_exprs
|
||||
self.static_target_function = None
|
||||
|
||||
def copy(self, mapper):
|
||||
self_copy = super().copy(mapper)
|
||||
self_copy.arg_exprs = self.arg_exprs
|
||||
self_copy.static_target_function = self.static_target_function
|
||||
return self_copy
|
||||
|
||||
@ -1195,6 +1203,8 @@ class Invoke(Terminator):
|
||||
"""
|
||||
A function call operation that supports exception handling.
|
||||
|
||||
:ivar arg_exprs: (dict of str to `iodelay.Expr`)
|
||||
iodelay expressions for values of arguments
|
||||
:ivar static_target_function: (:class:`Function` or None)
|
||||
statically resolved callee
|
||||
"""
|
||||
@ -1204,17 +1214,23 @@ class Invoke(Terminator):
|
||||
:param args: (list of :class:`Value`) function arguments
|
||||
:param normal: (:class:`BasicBlock`) normal target
|
||||
:param exn: (:class:`BasicBlock`) exceptional target
|
||||
:param arg_exprs: (dict of str to `iodelay.Expr`)
|
||||
"""
|
||||
def __init__(self, func, args, normal, exn, name=""):
|
||||
def __init__(self, func, args, arg_exprs, normal, exn, name=""):
|
||||
assert isinstance(func, Value)
|
||||
for arg in args: assert isinstance(arg, Value)
|
||||
assert isinstance(normal, BasicBlock)
|
||||
assert isinstance(exn, BasicBlock)
|
||||
for arg in arg_exprs:
|
||||
assert isinstance(arg, str)
|
||||
assert isinstance(arg_exprs[arg], iodelay.Expr)
|
||||
super().__init__([func] + args + [normal, exn], func.type.ret, name)
|
||||
self.arg_exprs = arg_exprs
|
||||
self.static_target_function = None
|
||||
|
||||
def copy(self, mapper):
|
||||
self_copy = super().copy(mapper)
|
||||
self_copy.arg_exprs = self.arg_exprs
|
||||
self_copy.static_target_function = self.static_target_function
|
||||
return self_copy
|
||||
|
||||
@ -1299,31 +1315,24 @@ class Delay(Terminator):
|
||||
inlining could lead to the expression folding to a constant.
|
||||
|
||||
:ivar interval: (:class:`iodelay.Expr`) expression
|
||||
:ivar var_names: (list of string)
|
||||
iodelay variable names corresponding to SSA values
|
||||
"""
|
||||
|
||||
"""
|
||||
:param interval: (:class:`iodelay.Expr`) expression
|
||||
:param substs: (dict of str to :class:`Value`)
|
||||
SSA values corresponding to iodelay variable names
|
||||
:param call: (:class:`Call` or ``Constant(None, builtins.TNone())``)
|
||||
the call instruction that caused this delay, if any
|
||||
:param target: (:class:`BasicBlock`) branch target
|
||||
"""
|
||||
def __init__(self, interval, substs, decomposition, target, name=""):
|
||||
for var_name in substs: assert isinstance(var_name, str)
|
||||
def __init__(self, interval, decomposition, target, name=""):
|
||||
assert isinstance(decomposition, Call) or \
|
||||
isinstance(decomposition, Builtin) and decomposition.op in ("delay", "delay_mu")
|
||||
assert isinstance(target, BasicBlock)
|
||||
super().__init__([decomposition, target, *substs.values()], builtins.TNone(), name)
|
||||
super().__init__([decomposition, target], builtins.TNone(), name)
|
||||
self.interval = interval
|
||||
self.var_names = list(substs.keys())
|
||||
|
||||
def copy(self, mapper):
|
||||
self_copy = super().copy(mapper)
|
||||
self_copy.interval = self.interval
|
||||
self_copy.var_names = list(self.var_names)
|
||||
return self_copy
|
||||
|
||||
def decomposition(self):
|
||||
@ -1342,17 +1351,9 @@ class Delay(Terminator):
|
||||
self.operands[1] = new_target
|
||||
self.operands[1].uses.add(self)
|
||||
|
||||
def substs(self):
|
||||
return {key: value for key, value in zip(self.var_names, self.operands[2:])}
|
||||
|
||||
def _operands_as_string(self, type_printer):
|
||||
substs = self.substs()
|
||||
substs_as_strings = []
|
||||
for var_name in substs:
|
||||
substs_as_strings.append("{} = {}".format(var_name, substs[var_name]))
|
||||
result = "[{}]".format(", ".join(substs_as_strings))
|
||||
result += ", decomp {}, to {}".format(self.decomposition().as_operand(type_printer),
|
||||
self.target().as_operand(type_printer))
|
||||
result = "decomp {}, to {}".format(self.decomposition().as_operand(type_printer),
|
||||
self.target().as_operand(type_printer))
|
||||
return result
|
||||
|
||||
def opcode(self):
|
||||
@ -1367,14 +1368,10 @@ class Loop(Terminator):
|
||||
|
||||
:ivar trip_count: (:class:`iodelay.Expr`)
|
||||
expression for trip count
|
||||
:ivar var_names: (list of string)
|
||||
iodelay variable names corresponding to ``trip_count`` operands
|
||||
"""
|
||||
|
||||
"""
|
||||
:param trip_count: (:class:`iodelay.Expr`) expression
|
||||
:param substs: (dict of str to :class:`Value`)
|
||||
SSA values corresponding to iodelay variable names
|
||||
:param indvar: (:class:`Phi`)
|
||||
phi node corresponding to the induction SSA value,
|
||||
which advances from ``0`` to ``trip_count - 1``
|
||||
@ -1382,21 +1379,18 @@ class Loop(Terminator):
|
||||
:param if_true: (:class:`BasicBlock`) branch target if condition is truthful
|
||||
:param if_false: (:class:`BasicBlock`) branch target if condition is falseful
|
||||
"""
|
||||
def __init__(self, trip_count, substs, indvar, cond, if_true, if_false, name=""):
|
||||
for var_name in substs: assert isinstance(var_name, str)
|
||||
def __init__(self, trip_count, indvar, cond, if_true, if_false, name=""):
|
||||
assert isinstance(indvar, Phi)
|
||||
assert isinstance(cond, Value)
|
||||
assert builtins.is_bool(cond.type)
|
||||
assert isinstance(if_true, BasicBlock)
|
||||
assert isinstance(if_false, BasicBlock)
|
||||
super().__init__([indvar, cond, if_true, if_false, *substs.values()], builtins.TNone(), name)
|
||||
super().__init__([indvar, cond, if_true, if_false], builtins.TNone(), name)
|
||||
self.trip_count = trip_count
|
||||
self.var_names = list(substs.keys())
|
||||
|
||||
def copy(self, mapper):
|
||||
self_copy = super().copy(mapper)
|
||||
self_copy.trip_count = self.trip_count
|
||||
self_copy.var_names = list(self.var_names)
|
||||
return self_copy
|
||||
|
||||
def induction_variable(self):
|
||||
@ -1411,17 +1405,9 @@ class Loop(Terminator):
|
||||
def if_false(self):
|
||||
return self.operands[3]
|
||||
|
||||
def substs(self):
|
||||
return {key: value for key, value in zip(self.var_names, self.operands[4:])}
|
||||
|
||||
def _operands_as_string(self, type_printer):
|
||||
substs = self.substs()
|
||||
substs_as_strings = []
|
||||
for var_name in substs:
|
||||
substs_as_strings.append("{} = {}".format(var_name, substs[var_name]))
|
||||
result = "[{}]".format(", ".join(substs_as_strings))
|
||||
result += ", indvar {}, if {}, {}, {}".format(
|
||||
*list(map(lambda value: value.as_operand(type_printer), self.operands[0:4])))
|
||||
result = "indvar {}, if {}, {}, {}".format(
|
||||
*list(map(lambda value: value.as_operand(type_printer), self.operands)))
|
||||
return result
|
||||
|
||||
def opcode(self):
|
||||
|
@ -526,9 +526,7 @@ class ARTIQIRGenerator(algorithm.Visitor):
|
||||
else_tail = tail
|
||||
|
||||
if node.trip_count is not None:
|
||||
substs = {var_name: self.current_args[var_name]
|
||||
for var_name in node.trip_count.free_vars()}
|
||||
head.append(ir.Loop(node.trip_count, substs, phi, cond, body, else_tail))
|
||||
head.append(ir.Loop(node.trip_count, phi, cond, body, else_tail))
|
||||
else:
|
||||
head.append(ir.BranchIf(cond, body, else_tail))
|
||||
if not post_body.is_terminated():
|
||||
@ -1659,10 +1657,11 @@ class ARTIQIRGenerator(algorithm.Visitor):
|
||||
assert None not in args
|
||||
|
||||
if self.unwind_target is None:
|
||||
insn = self.append(ir.Call(func, args))
|
||||
insn = self.append(ir.Call(func, args, node.arg_exprs))
|
||||
else:
|
||||
after_invoke = self.add_block()
|
||||
insn = self.append(ir.Invoke(func, args, after_invoke, self.unwind_target))
|
||||
insn = self.append(ir.Invoke(func, args, node.arg_exprs,
|
||||
after_invoke, self.unwind_target))
|
||||
self.current_block = after_invoke
|
||||
|
||||
method_key = None
|
||||
@ -1672,9 +1671,7 @@ class ARTIQIRGenerator(algorithm.Visitor):
|
||||
|
||||
if node.iodelay is not None and not iodelay.is_const(node.iodelay, 0):
|
||||
after_delay = self.add_block()
|
||||
substs = {var_name: self.current_args[var_name]
|
||||
for var_name in node.iodelay.free_vars()}
|
||||
self.append(ir.Delay(node.iodelay, substs, insn, after_delay))
|
||||
self.append(ir.Delay(node.iodelay, insn, after_delay))
|
||||
self.current_block = after_delay
|
||||
|
||||
return insn
|
||||
|
@ -426,7 +426,7 @@ class ASTTypedRewriter(algorithm.Transformer):
|
||||
|
||||
def visit_Call(self, node):
|
||||
node = self.generic_visit(node)
|
||||
node = asttyped.CallT(type=types.TVar(), iodelay=None,
|
||||
node = asttyped.CallT(type=types.TVar(), iodelay=None, arg_exprs={},
|
||||
func=node.func, args=node.args, keywords=node.keywords,
|
||||
starargs=node.starargs, kwargs=node.kwargs,
|
||||
star_loc=node.star_loc, dstar_loc=node.dstar_loc,
|
||||
|
@ -297,8 +297,8 @@ class IODelayEstimator(algorithm.Visitor):
|
||||
args[arg_name] = arg_node
|
||||
|
||||
free_vars = delay.duration.free_vars()
|
||||
call_delay = delay.duration.fold(
|
||||
{ arg: self.evaluate(args[arg], abort=abort) for arg in free_vars })
|
||||
node.arg_exprs = { arg: self.evaluate(args[arg], abort=abort) for arg in free_vars }
|
||||
call_delay = delay.duration.fold(node.arg_exprs)
|
||||
else:
|
||||
assert False
|
||||
else:
|
||||
|
Loading…
Reference in New Issue
Block a user