forked from M-Labs/artiq
compiler: do not associate SSA values with iodelay even when inlining.
Fixes #201.
This commit is contained in:
parent
33c3b3377e
commit
082e9e20dd
|
@ -51,20 +51,19 @@ def inline(call_insn):
|
||||||
elif isinstance(source_insn, ir.Phi):
|
elif isinstance(source_insn, ir.Phi):
|
||||||
target_insn = ir.Phi()
|
target_insn = ir.Phi()
|
||||||
elif isinstance(source_insn, ir.Delay):
|
elif isinstance(source_insn, ir.Delay):
|
||||||
substs = source_insn.substs()
|
target_insn = source_insn.copy(mapper)
|
||||||
mapped_substs = {var: value_map[substs[var]] for var in substs}
|
target_insn.interval = source_insn.interval.fold(call_insn.arg_exprs)
|
||||||
const_substs = {var: iodelay.Const(mapped_substs[var].value)
|
elif isinstance(source_insn, ir.Loop):
|
||||||
for var in mapped_substs
|
target_insn = source_insn.copy(mapper)
|
||||||
if isinstance(mapped_substs[var], ir.Constant)}
|
target_insn.trip_count = source_insn.trip_count.fold(call_insn.arg_exprs)
|
||||||
other_substs = {var: mapped_substs[var]
|
elif isinstance(source_insn, ir.Call):
|
||||||
for var in mapped_substs
|
target_insn = source_insn.copy(mapper)
|
||||||
if not isinstance(mapped_substs[var], ir.Constant)}
|
target_insn.arg_exprs = \
|
||||||
target_insn = ir.Delay(source_insn.interval.fold(const_substs), other_substs,
|
{ arg: source_insn.arg_exprs[arg].fold(call_insn.arg_exprs)
|
||||||
value_map[source_insn.decomposition()],
|
for arg in source_insn.arg_exprs }
|
||||||
value_map[source_insn.target()])
|
|
||||||
else:
|
else:
|
||||||
target_insn = source_insn.copy(mapper)
|
target_insn = source_insn.copy(mapper)
|
||||||
target_insn.name = "i." + source_insn.name
|
target_insn.name = "i." + source_insn.name
|
||||||
value_map[source_insn] = target_insn
|
value_map[source_insn] = target_insn
|
||||||
target_block.append(target_insn)
|
target_block.append(target_insn)
|
||||||
|
|
||||||
|
|
|
@ -54,6 +54,7 @@ class BoolOpT(ast.BoolOp, commontyped):
|
||||||
class CallT(ast.Call, commontyped):
|
class CallT(ast.Call, commontyped):
|
||||||
"""
|
"""
|
||||||
:ivar iodelay: (:class:`iodelay.Expr`)
|
:ivar iodelay: (:class:`iodelay.Expr`)
|
||||||
|
:ivar arg_exprs: (dict of str to :class:`iodelay.Expr`)
|
||||||
"""
|
"""
|
||||||
class CompareT(ast.Compare, commontyped):
|
class CompareT(ast.Compare, commontyped):
|
||||||
pass
|
pass
|
||||||
|
|
|
@ -5,7 +5,7 @@ of the ARTIQ compiler.
|
||||||
|
|
||||||
from collections import OrderedDict
|
from collections import OrderedDict
|
||||||
from pythonparser import ast
|
from pythonparser import ast
|
||||||
from . import types, builtins
|
from . import types, builtins, iodelay
|
||||||
|
|
||||||
# Generic SSA IR classes
|
# Generic SSA IR classes
|
||||||
|
|
||||||
|
@ -939,6 +939,8 @@ class Call(Instruction):
|
||||||
"""
|
"""
|
||||||
A function call operation.
|
A function call operation.
|
||||||
|
|
||||||
|
:ivar arg_exprs: (dict of str to `iodelay.Expr`)
|
||||||
|
iodelay expressions for values of arguments
|
||||||
:ivar static_target_function: (:class:`Function` or None)
|
:ivar static_target_function: (:class:`Function` or None)
|
||||||
statically resolved callee
|
statically resolved callee
|
||||||
"""
|
"""
|
||||||
|
@ -946,15 +948,21 @@ class Call(Instruction):
|
||||||
"""
|
"""
|
||||||
:param func: (:class:`Value`) function to call
|
:param func: (:class:`Value`) function to call
|
||||||
:param args: (list of :class:`Value`) function arguments
|
:param args: (list of :class:`Value`) function arguments
|
||||||
|
:param arg_exprs: (dict of str to `iodelay.Expr`)
|
||||||
"""
|
"""
|
||||||
def __init__(self, func, args, name=""):
|
def __init__(self, func, args, arg_exprs, name=""):
|
||||||
assert isinstance(func, Value)
|
assert isinstance(func, Value)
|
||||||
for arg in args: assert isinstance(arg, Value)
|
for arg in args: assert isinstance(arg, Value)
|
||||||
|
for arg in arg_exprs:
|
||||||
|
assert isinstance(arg, str)
|
||||||
|
assert isinstance(arg_exprs[arg], iodelay.Expr)
|
||||||
super().__init__([func] + args, func.type.ret, name)
|
super().__init__([func] + args, func.type.ret, name)
|
||||||
|
self.arg_exprs = arg_exprs
|
||||||
self.static_target_function = None
|
self.static_target_function = None
|
||||||
|
|
||||||
def copy(self, mapper):
|
def copy(self, mapper):
|
||||||
self_copy = super().copy(mapper)
|
self_copy = super().copy(mapper)
|
||||||
|
self_copy.arg_exprs = self.arg_exprs
|
||||||
self_copy.static_target_function = self.static_target_function
|
self_copy.static_target_function = self.static_target_function
|
||||||
return self_copy
|
return self_copy
|
||||||
|
|
||||||
|
@ -1195,6 +1203,8 @@ class Invoke(Terminator):
|
||||||
"""
|
"""
|
||||||
A function call operation that supports exception handling.
|
A function call operation that supports exception handling.
|
||||||
|
|
||||||
|
:ivar arg_exprs: (dict of str to `iodelay.Expr`)
|
||||||
|
iodelay expressions for values of arguments
|
||||||
:ivar static_target_function: (:class:`Function` or None)
|
:ivar static_target_function: (:class:`Function` or None)
|
||||||
statically resolved callee
|
statically resolved callee
|
||||||
"""
|
"""
|
||||||
|
@ -1204,17 +1214,23 @@ class Invoke(Terminator):
|
||||||
:param args: (list of :class:`Value`) function arguments
|
:param args: (list of :class:`Value`) function arguments
|
||||||
:param normal: (:class:`BasicBlock`) normal target
|
:param normal: (:class:`BasicBlock`) normal target
|
||||||
:param exn: (:class:`BasicBlock`) exceptional target
|
:param exn: (:class:`BasicBlock`) exceptional target
|
||||||
|
:param arg_exprs: (dict of str to `iodelay.Expr`)
|
||||||
"""
|
"""
|
||||||
def __init__(self, func, args, normal, exn, name=""):
|
def __init__(self, func, args, arg_exprs, normal, exn, name=""):
|
||||||
assert isinstance(func, Value)
|
assert isinstance(func, Value)
|
||||||
for arg in args: assert isinstance(arg, Value)
|
for arg in args: assert isinstance(arg, Value)
|
||||||
assert isinstance(normal, BasicBlock)
|
assert isinstance(normal, BasicBlock)
|
||||||
assert isinstance(exn, BasicBlock)
|
assert isinstance(exn, BasicBlock)
|
||||||
|
for arg in arg_exprs:
|
||||||
|
assert isinstance(arg, str)
|
||||||
|
assert isinstance(arg_exprs[arg], iodelay.Expr)
|
||||||
super().__init__([func] + args + [normal, exn], func.type.ret, name)
|
super().__init__([func] + args + [normal, exn], func.type.ret, name)
|
||||||
|
self.arg_exprs = arg_exprs
|
||||||
self.static_target_function = None
|
self.static_target_function = None
|
||||||
|
|
||||||
def copy(self, mapper):
|
def copy(self, mapper):
|
||||||
self_copy = super().copy(mapper)
|
self_copy = super().copy(mapper)
|
||||||
|
self_copy.arg_exprs = self.arg_exprs
|
||||||
self_copy.static_target_function = self.static_target_function
|
self_copy.static_target_function = self.static_target_function
|
||||||
return self_copy
|
return self_copy
|
||||||
|
|
||||||
|
@ -1299,31 +1315,24 @@ class Delay(Terminator):
|
||||||
inlining could lead to the expression folding to a constant.
|
inlining could lead to the expression folding to a constant.
|
||||||
|
|
||||||
:ivar interval: (:class:`iodelay.Expr`) expression
|
:ivar interval: (:class:`iodelay.Expr`) expression
|
||||||
:ivar var_names: (list of string)
|
|
||||||
iodelay variable names corresponding to SSA values
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
"""
|
"""
|
||||||
:param interval: (:class:`iodelay.Expr`) expression
|
:param interval: (:class:`iodelay.Expr`) expression
|
||||||
:param substs: (dict of str to :class:`Value`)
|
|
||||||
SSA values corresponding to iodelay variable names
|
|
||||||
:param call: (:class:`Call` or ``Constant(None, builtins.TNone())``)
|
:param call: (:class:`Call` or ``Constant(None, builtins.TNone())``)
|
||||||
the call instruction that caused this delay, if any
|
the call instruction that caused this delay, if any
|
||||||
:param target: (:class:`BasicBlock`) branch target
|
:param target: (:class:`BasicBlock`) branch target
|
||||||
"""
|
"""
|
||||||
def __init__(self, interval, substs, decomposition, target, name=""):
|
def __init__(self, interval, decomposition, target, name=""):
|
||||||
for var_name in substs: assert isinstance(var_name, str)
|
|
||||||
assert isinstance(decomposition, Call) or \
|
assert isinstance(decomposition, Call) or \
|
||||||
isinstance(decomposition, Builtin) and decomposition.op in ("delay", "delay_mu")
|
isinstance(decomposition, Builtin) and decomposition.op in ("delay", "delay_mu")
|
||||||
assert isinstance(target, BasicBlock)
|
assert isinstance(target, BasicBlock)
|
||||||
super().__init__([decomposition, target, *substs.values()], builtins.TNone(), name)
|
super().__init__([decomposition, target], builtins.TNone(), name)
|
||||||
self.interval = interval
|
self.interval = interval
|
||||||
self.var_names = list(substs.keys())
|
|
||||||
|
|
||||||
def copy(self, mapper):
|
def copy(self, mapper):
|
||||||
self_copy = super().copy(mapper)
|
self_copy = super().copy(mapper)
|
||||||
self_copy.interval = self.interval
|
self_copy.interval = self.interval
|
||||||
self_copy.var_names = list(self.var_names)
|
|
||||||
return self_copy
|
return self_copy
|
||||||
|
|
||||||
def decomposition(self):
|
def decomposition(self):
|
||||||
|
@ -1342,17 +1351,9 @@ class Delay(Terminator):
|
||||||
self.operands[1] = new_target
|
self.operands[1] = new_target
|
||||||
self.operands[1].uses.add(self)
|
self.operands[1].uses.add(self)
|
||||||
|
|
||||||
def substs(self):
|
|
||||||
return {key: value for key, value in zip(self.var_names, self.operands[2:])}
|
|
||||||
|
|
||||||
def _operands_as_string(self, type_printer):
|
def _operands_as_string(self, type_printer):
|
||||||
substs = self.substs()
|
result = "decomp {}, to {}".format(self.decomposition().as_operand(type_printer),
|
||||||
substs_as_strings = []
|
self.target().as_operand(type_printer))
|
||||||
for var_name in substs:
|
|
||||||
substs_as_strings.append("{} = {}".format(var_name, substs[var_name]))
|
|
||||||
result = "[{}]".format(", ".join(substs_as_strings))
|
|
||||||
result += ", decomp {}, to {}".format(self.decomposition().as_operand(type_printer),
|
|
||||||
self.target().as_operand(type_printer))
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
def opcode(self):
|
def opcode(self):
|
||||||
|
@ -1367,14 +1368,10 @@ class Loop(Terminator):
|
||||||
|
|
||||||
:ivar trip_count: (:class:`iodelay.Expr`)
|
:ivar trip_count: (:class:`iodelay.Expr`)
|
||||||
expression for trip count
|
expression for trip count
|
||||||
:ivar var_names: (list of string)
|
|
||||||
iodelay variable names corresponding to ``trip_count`` operands
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
"""
|
"""
|
||||||
:param trip_count: (:class:`iodelay.Expr`) expression
|
:param trip_count: (:class:`iodelay.Expr`) expression
|
||||||
:param substs: (dict of str to :class:`Value`)
|
|
||||||
SSA values corresponding to iodelay variable names
|
|
||||||
:param indvar: (:class:`Phi`)
|
:param indvar: (:class:`Phi`)
|
||||||
phi node corresponding to the induction SSA value,
|
phi node corresponding to the induction SSA value,
|
||||||
which advances from ``0`` to ``trip_count - 1``
|
which advances from ``0`` to ``trip_count - 1``
|
||||||
|
@ -1382,21 +1379,18 @@ class Loop(Terminator):
|
||||||
:param if_true: (:class:`BasicBlock`) branch target if condition is truthful
|
:param if_true: (:class:`BasicBlock`) branch target if condition is truthful
|
||||||
:param if_false: (:class:`BasicBlock`) branch target if condition is falseful
|
:param if_false: (:class:`BasicBlock`) branch target if condition is falseful
|
||||||
"""
|
"""
|
||||||
def __init__(self, trip_count, substs, indvar, cond, if_true, if_false, name=""):
|
def __init__(self, trip_count, indvar, cond, if_true, if_false, name=""):
|
||||||
for var_name in substs: assert isinstance(var_name, str)
|
|
||||||
assert isinstance(indvar, Phi)
|
assert isinstance(indvar, Phi)
|
||||||
assert isinstance(cond, Value)
|
assert isinstance(cond, Value)
|
||||||
assert builtins.is_bool(cond.type)
|
assert builtins.is_bool(cond.type)
|
||||||
assert isinstance(if_true, BasicBlock)
|
assert isinstance(if_true, BasicBlock)
|
||||||
assert isinstance(if_false, BasicBlock)
|
assert isinstance(if_false, BasicBlock)
|
||||||
super().__init__([indvar, cond, if_true, if_false, *substs.values()], builtins.TNone(), name)
|
super().__init__([indvar, cond, if_true, if_false], builtins.TNone(), name)
|
||||||
self.trip_count = trip_count
|
self.trip_count = trip_count
|
||||||
self.var_names = list(substs.keys())
|
|
||||||
|
|
||||||
def copy(self, mapper):
|
def copy(self, mapper):
|
||||||
self_copy = super().copy(mapper)
|
self_copy = super().copy(mapper)
|
||||||
self_copy.trip_count = self.trip_count
|
self_copy.trip_count = self.trip_count
|
||||||
self_copy.var_names = list(self.var_names)
|
|
||||||
return self_copy
|
return self_copy
|
||||||
|
|
||||||
def induction_variable(self):
|
def induction_variable(self):
|
||||||
|
@ -1411,17 +1405,9 @@ class Loop(Terminator):
|
||||||
def if_false(self):
|
def if_false(self):
|
||||||
return self.operands[3]
|
return self.operands[3]
|
||||||
|
|
||||||
def substs(self):
|
|
||||||
return {key: value for key, value in zip(self.var_names, self.operands[4:])}
|
|
||||||
|
|
||||||
def _operands_as_string(self, type_printer):
|
def _operands_as_string(self, type_printer):
|
||||||
substs = self.substs()
|
result = "indvar {}, if {}, {}, {}".format(
|
||||||
substs_as_strings = []
|
*list(map(lambda value: value.as_operand(type_printer), self.operands)))
|
||||||
for var_name in substs:
|
|
||||||
substs_as_strings.append("{} = {}".format(var_name, substs[var_name]))
|
|
||||||
result = "[{}]".format(", ".join(substs_as_strings))
|
|
||||||
result += ", indvar {}, if {}, {}, {}".format(
|
|
||||||
*list(map(lambda value: value.as_operand(type_printer), self.operands[0:4])))
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
def opcode(self):
|
def opcode(self):
|
||||||
|
|
|
@ -526,9 +526,7 @@ class ARTIQIRGenerator(algorithm.Visitor):
|
||||||
else_tail = tail
|
else_tail = tail
|
||||||
|
|
||||||
if node.trip_count is not None:
|
if node.trip_count is not None:
|
||||||
substs = {var_name: self.current_args[var_name]
|
head.append(ir.Loop(node.trip_count, phi, cond, body, else_tail))
|
||||||
for var_name in node.trip_count.free_vars()}
|
|
||||||
head.append(ir.Loop(node.trip_count, substs, phi, cond, body, else_tail))
|
|
||||||
else:
|
else:
|
||||||
head.append(ir.BranchIf(cond, body, else_tail))
|
head.append(ir.BranchIf(cond, body, else_tail))
|
||||||
if not post_body.is_terminated():
|
if not post_body.is_terminated():
|
||||||
|
@ -1659,10 +1657,11 @@ class ARTIQIRGenerator(algorithm.Visitor):
|
||||||
assert None not in args
|
assert None not in args
|
||||||
|
|
||||||
if self.unwind_target is None:
|
if self.unwind_target is None:
|
||||||
insn = self.append(ir.Call(func, args))
|
insn = self.append(ir.Call(func, args, node.arg_exprs))
|
||||||
else:
|
else:
|
||||||
after_invoke = self.add_block()
|
after_invoke = self.add_block()
|
||||||
insn = self.append(ir.Invoke(func, args, after_invoke, self.unwind_target))
|
insn = self.append(ir.Invoke(func, args, node.arg_exprs,
|
||||||
|
after_invoke, self.unwind_target))
|
||||||
self.current_block = after_invoke
|
self.current_block = after_invoke
|
||||||
|
|
||||||
method_key = None
|
method_key = None
|
||||||
|
@ -1672,9 +1671,7 @@ class ARTIQIRGenerator(algorithm.Visitor):
|
||||||
|
|
||||||
if node.iodelay is not None and not iodelay.is_const(node.iodelay, 0):
|
if node.iodelay is not None and not iodelay.is_const(node.iodelay, 0):
|
||||||
after_delay = self.add_block()
|
after_delay = self.add_block()
|
||||||
substs = {var_name: self.current_args[var_name]
|
self.append(ir.Delay(node.iodelay, insn, after_delay))
|
||||||
for var_name in node.iodelay.free_vars()}
|
|
||||||
self.append(ir.Delay(node.iodelay, substs, insn, after_delay))
|
|
||||||
self.current_block = after_delay
|
self.current_block = after_delay
|
||||||
|
|
||||||
return insn
|
return insn
|
||||||
|
|
|
@ -426,7 +426,7 @@ class ASTTypedRewriter(algorithm.Transformer):
|
||||||
|
|
||||||
def visit_Call(self, node):
|
def visit_Call(self, node):
|
||||||
node = self.generic_visit(node)
|
node = self.generic_visit(node)
|
||||||
node = asttyped.CallT(type=types.TVar(), iodelay=None,
|
node = asttyped.CallT(type=types.TVar(), iodelay=None, arg_exprs={},
|
||||||
func=node.func, args=node.args, keywords=node.keywords,
|
func=node.func, args=node.args, keywords=node.keywords,
|
||||||
starargs=node.starargs, kwargs=node.kwargs,
|
starargs=node.starargs, kwargs=node.kwargs,
|
||||||
star_loc=node.star_loc, dstar_loc=node.dstar_loc,
|
star_loc=node.star_loc, dstar_loc=node.dstar_loc,
|
||||||
|
|
|
@ -297,8 +297,8 @@ class IODelayEstimator(algorithm.Visitor):
|
||||||
args[arg_name] = arg_node
|
args[arg_name] = arg_node
|
||||||
|
|
||||||
free_vars = delay.duration.free_vars()
|
free_vars = delay.duration.free_vars()
|
||||||
call_delay = delay.duration.fold(
|
node.arg_exprs = { arg: self.evaluate(args[arg], abort=abort) for arg in free_vars }
|
||||||
{ arg: self.evaluate(args[arg], abort=abort) for arg in free_vars })
|
call_delay = delay.duration.fold(node.arg_exprs)
|
||||||
else:
|
else:
|
||||||
assert False
|
assert False
|
||||||
else:
|
else:
|
||||||
|
|
Loading…
Reference in New Issue