forked from M-Labs/artiq
Make delay component of function type unifyable.
This commit is contained in:
parent
3e1348a084
commit
7a6fc3983c
|
@ -4,9 +4,6 @@ the statically inferred RTIO delay arising from executing
|
|||
a function.
|
||||
"""
|
||||
|
||||
from pythonparser import diagnostic
|
||||
|
||||
|
||||
class Expr:
|
||||
def __add__(lhs, rhs):
|
||||
assert isinstance(rhs, Expr)
|
||||
|
@ -227,59 +224,3 @@ def is_const(expr, value=None):
|
|||
|
||||
def is_zero(expr):
|
||||
return is_const(expr, 0)
|
||||
|
||||
|
||||
class Delay:
|
||||
pass
|
||||
|
||||
class Unknown(Delay):
|
||||
"""
|
||||
Unknown delay, that is, IO delay that we have not
|
||||
tried to compute yet.
|
||||
"""
|
||||
|
||||
def __repr__(self):
|
||||
return "{}.Unknown()".format(__name__)
|
||||
|
||||
class Indeterminate(Delay):
|
||||
"""
|
||||
Indeterminate delay, that is, IO delay that can vary from
|
||||
invocation to invocation.
|
||||
|
||||
:ivar cause: (:class:`pythonparser.diagnostic.Diagnostic`)
|
||||
reason for the delay not being inferred
|
||||
"""
|
||||
|
||||
def __init__(self, cause):
|
||||
assert isinstance(cause, diagnostic.Diagnostic)
|
||||
self.cause = cause
|
||||
|
||||
def __repr__(self):
|
||||
return "<{}.Indeterminate>".format(__name__)
|
||||
|
||||
class Fixed(Delay):
|
||||
"""
|
||||
Fixed delay, that is, IO delay that is always the same
|
||||
for every invocation.
|
||||
|
||||
:ivar length: (int) delay in machine units
|
||||
"""
|
||||
|
||||
def __init__(self, length):
|
||||
assert isinstance(length, Expr)
|
||||
self.length = length
|
||||
|
||||
def __repr__(self):
|
||||
return "{}.Fixed({})".format(__name__, self.length)
|
||||
|
||||
def is_unknown(delay):
|
||||
return isinstance(delay, Unknown)
|
||||
|
||||
def is_indeterminate(delay):
|
||||
return isinstance(delay, Indeterminate)
|
||||
|
||||
def is_fixed(delay, length=None):
|
||||
if length is None:
|
||||
return isinstance(delay, Fixed)
|
||||
else:
|
||||
return isinstance(delay, Fixed) and delay.length == length
|
||||
|
|
|
@ -50,7 +50,8 @@ class Module:
|
|||
inferencer = transforms.Inferencer(engine=self.engine)
|
||||
monomorphism_validator = validators.MonomorphismValidator(engine=self.engine)
|
||||
escape_validator = validators.EscapeValidator(engine=self.engine)
|
||||
iodelay_estimator = transforms.IODelayEstimator(ref_period=ref_period)
|
||||
iodelay_estimator = transforms.IODelayEstimator(engine=self.engine,
|
||||
ref_period=ref_period)
|
||||
artiq_ir_generator = transforms.ARTIQIRGenerator(engine=self.engine,
|
||||
module_name=src.name,
|
||||
ref_period=ref_period)
|
||||
|
|
|
@ -32,8 +32,8 @@ def main():
|
|||
if force_delays:
|
||||
for var in mod.globals:
|
||||
typ = mod.globals[var].find()
|
||||
if types.is_function(typ) and iodelay.is_indeterminate(typ.delay):
|
||||
process_diagnostic(typ.delay.cause)
|
||||
if types.is_function(typ) and types.is_indeterminate_delay(typ.delay):
|
||||
process_diagnostic(typ.delay.find().cause)
|
||||
|
||||
print(repr(mod))
|
||||
except:
|
||||
|
|
|
@ -11,7 +11,8 @@ class _UnknownDelay(Exception):
|
|||
pass
|
||||
|
||||
class IODelayEstimator(algorithm.Visitor):
|
||||
def __init__(self, ref_period):
|
||||
def __init__(self, engine, ref_period):
|
||||
self.engine = engine
|
||||
self.ref_period = ref_period
|
||||
self.changed = False
|
||||
self.current_delay = iodelay.Const(0)
|
||||
|
@ -83,7 +84,7 @@ class IODelayEstimator(algorithm.Visitor):
|
|||
except (diagnostic.Error, _UnknownDelay):
|
||||
pass # we don't care; module-level code is never interleaved
|
||||
|
||||
def visit_function(self, args, body, typ):
|
||||
def visit_function(self, args, body, typ, loc):
|
||||
old_args, self.current_args = self.current_args, args
|
||||
old_return, self.current_return = self.current_return, None
|
||||
old_delay, self.current_delay = self.current_delay, iodelay.Const(0)
|
||||
|
@ -93,19 +94,23 @@ class IODelayEstimator(algorithm.Visitor):
|
|||
self.abort("only return statement at the end of the function "
|
||||
"can be interleaved", self.current_return.loc)
|
||||
|
||||
delay = iodelay.Fixed(self.current_delay.fold())
|
||||
delay = types.TFixedDelay(self.current_delay.fold())
|
||||
except diagnostic.Error as error:
|
||||
delay = iodelay.Indeterminate(error.diagnostic)
|
||||
delay = types.TIndeterminateDelay(error.diagnostic)
|
||||
self.current_delay = old_delay
|
||||
self.current_return = old_return
|
||||
self.current_args = old_args
|
||||
|
||||
if iodelay.is_unknown(typ.delay) or iodelay.is_indeterminate(typ.delay):
|
||||
typ.delay = delay
|
||||
elif iodelay.is_fixed(typ.delay):
|
||||
assert typ.delay.value == delay.value
|
||||
else:
|
||||
assert False
|
||||
try:
|
||||
typ.delay.unify(delay)
|
||||
except types.UnificationError as e:
|
||||
printer = types.TypePrinter()
|
||||
diag = diagnostic.Diagnostic("fatal",
|
||||
"delay {delaya} was inferred for this function, but its delay is already "
|
||||
"constrained externally to {delayb}",
|
||||
{"delaya": printer.name(delay), "delayb": printer.name(typ.delay)},
|
||||
loc)
|
||||
self.engine.process(diag)
|
||||
|
||||
def visit_FunctionDefT(self, node):
|
||||
self.visit(node.args.defaults)
|
||||
|
@ -116,10 +121,10 @@ class IODelayEstimator(algorithm.Visitor):
|
|||
body = node.body[:-1]
|
||||
else:
|
||||
body = node.body
|
||||
self.visit_function(node.args, body, node.signature_type.find())
|
||||
self.visit_function(node.args, body, node.signature_type.find(), node.loc)
|
||||
|
||||
def visit_LambdaT(self, node):
|
||||
self.visit_function(node.args, node.body, node.type.find())
|
||||
self.visit_function(node.args, node.body, node.type.find(), node.loc)
|
||||
|
||||
def get_iterable_length(self, node):
|
||||
def abort(notes):
|
||||
|
@ -225,10 +230,10 @@ class IODelayEstimator(algorithm.Visitor):
|
|||
else:
|
||||
assert False
|
||||
|
||||
delay = typ.find().delay
|
||||
if iodelay.is_unknown(delay):
|
||||
delay = typ.find().delay.find()
|
||||
if types.is_var(delay):
|
||||
raise _UnknownDelay()
|
||||
elif iodelay.is_indeterminate(delay):
|
||||
elif delay.is_indeterminate():
|
||||
cause = delay.cause
|
||||
note = diagnostic.Diagnostic("note",
|
||||
"function called here", {},
|
||||
|
@ -236,15 +241,15 @@ class IODelayEstimator(algorithm.Visitor):
|
|||
diag = diagnostic.Diagnostic(cause.level, cause.reason, cause.arguments,
|
||||
cause.location, cause.highlights, cause.notes + [note])
|
||||
raise diagnostic.Error(diag)
|
||||
elif iodelay.is_fixed(delay):
|
||||
elif delay.is_fixed():
|
||||
args = {}
|
||||
for kw_node in node.keywords:
|
||||
args[kw_node.arg] = kw_node.value
|
||||
for arg_name, arg_node in zip(typ.args, node.args[offset:]):
|
||||
args[arg_name] = arg_node
|
||||
|
||||
free_vars = delay.length.free_vars()
|
||||
self.current_delay += delay.length.fold(
|
||||
free_vars = delay.duration.free_vars()
|
||||
self.current_delay += delay.duration.fold(
|
||||
{ arg: self.evaluate(args[arg], abort=abort) for arg in free_vars })
|
||||
else:
|
||||
assert False
|
||||
|
|
|
@ -5,6 +5,7 @@ in :mod:`asttyped`.
|
|||
|
||||
import string
|
||||
from collections import OrderedDict
|
||||
from pythonparser import diagnostic
|
||||
from . import iodelay
|
||||
|
||||
|
||||
|
@ -192,8 +193,8 @@ class TFunction(Type):
|
|||
optional arguments
|
||||
:ivar ret: (:class:`Type`)
|
||||
return type
|
||||
:ivar delay: (:class:`iodelay.Delay`)
|
||||
RTIO delay expression
|
||||
:ivar delay: (:class:`Type`)
|
||||
RTIO delay
|
||||
"""
|
||||
|
||||
attributes = OrderedDict([
|
||||
|
@ -201,12 +202,12 @@ class TFunction(Type):
|
|||
('__closure__', _TPointer()),
|
||||
])
|
||||
|
||||
def __init__(self, args, optargs, ret, delay=iodelay.Unknown()):
|
||||
def __init__(self, args, optargs, ret):
|
||||
assert isinstance(args, OrderedDict)
|
||||
assert isinstance(optargs, OrderedDict)
|
||||
assert isinstance(ret, Type)
|
||||
assert isinstance(delay, iodelay.Delay)
|
||||
self.args, self.optargs, self.ret, self.delay = args, optargs, ret, delay
|
||||
self.args, self.optargs, self.ret = args, optargs, ret
|
||||
self.delay = TVar()
|
||||
|
||||
def arity(self):
|
||||
return len(self.args) + len(self.optargs)
|
||||
|
@ -222,6 +223,7 @@ class TFunction(Type):
|
|||
list(other.args.values()) + list(other.optargs.values())):
|
||||
selfarg.unify(otherarg)
|
||||
self.ret.unify(other.ret)
|
||||
self.delay.unify(other.delay)
|
||||
elif isinstance(other, TVar):
|
||||
other.unify(self)
|
||||
else:
|
||||
|
@ -261,7 +263,7 @@ class TRPCFunction(TFunction):
|
|||
|
||||
def __init__(self, args, optargs, ret, service):
|
||||
super().__init__(args, optargs, ret,
|
||||
delay=iodelay.Fixed(iodelay.Constant(0)))
|
||||
delay=FixedDelay(iodelay.Constant(0)))
|
||||
self.service = service
|
||||
|
||||
def unify(self, other):
|
||||
|
@ -284,7 +286,7 @@ class TCFunction(TFunction):
|
|||
|
||||
def __init__(self, args, ret, name):
|
||||
super().__init__(args, OrderedDict(), ret,
|
||||
delay=iodelay.Fixed(iodelay.Constant(0)))
|
||||
delay=FixedDelay(iodelay.Constant(0)))
|
||||
self.name = name
|
||||
|
||||
def unify(self, other):
|
||||
|
@ -418,6 +420,63 @@ class TValue(Type):
|
|||
def __ne__(self, other):
|
||||
return not (self == other)
|
||||
|
||||
class TDelay(Type):
|
||||
"""
|
||||
The type-level representation of IO delay.
|
||||
"""
|
||||
|
||||
def __init__(self, duration, cause):
|
||||
assert duration is None or isinstance(duration, iodelay.Expr)
|
||||
assert cause is None or isinstance(cause, diagnostic.Diagnostic)
|
||||
assert (not (duration and cause)) and (duration or cause)
|
||||
self.duration, self.cause = duration, cause
|
||||
|
||||
def is_fixed(self):
|
||||
return self.duration is not None
|
||||
|
||||
def is_indeterminate(self):
|
||||
return self.cause is not None
|
||||
|
||||
def find(self):
|
||||
return self
|
||||
|
||||
def unify(self, other):
|
||||
other = other.find()
|
||||
|
||||
if self.is_fixed() and other.is_fixed() and \
|
||||
self.duration.fold() == other.duration.fold():
|
||||
pass
|
||||
elif isinstance(other, TVar):
|
||||
other.unify(self)
|
||||
else:
|
||||
raise UnificationError(self, other)
|
||||
|
||||
def fold(self, accum, fn):
|
||||
# delay types do not participate in folding
|
||||
pass
|
||||
|
||||
def __eq__(self, other):
|
||||
return isinstance(other, TDelay) and \
|
||||
(self.duration == other.duration and \
|
||||
self.cause == other.cause)
|
||||
|
||||
def __ne__(self, other):
|
||||
return not (self == other)
|
||||
|
||||
def __repr__(self):
|
||||
if self.duration is None:
|
||||
return "<{}.TIndeterminateDelay>".format(__name__)
|
||||
elif self.cause is None:
|
||||
return "{}.TFixedDelay({})".format(__name__, self.duration)
|
||||
else:
|
||||
assert False
|
||||
|
||||
def TIndeterminateDelay(cause):
|
||||
return TDelay(None, cause)
|
||||
|
||||
def TFixedDelay(duration):
|
||||
return TDelay(duration, None)
|
||||
|
||||
|
||||
def is_var(typ):
|
||||
return isinstance(typ.find(), TVar)
|
||||
|
@ -511,6 +570,16 @@ def get_value(typ):
|
|||
else:
|
||||
assert False
|
||||
|
||||
def is_delay(typ):
|
||||
return isinstance(typ.find(), TDelay)
|
||||
|
||||
def is_fixed_delay(typ):
|
||||
return is_delay(typ) and typ.find().is_fixed()
|
||||
|
||||
def is_indeterminate_delay(typ):
|
||||
return is_delay(typ) and typ.find().is_indeterminate()
|
||||
|
||||
|
||||
class TypePrinter(object):
|
||||
"""
|
||||
A class that prints types using Python-like syntax and gives
|
||||
|
@ -553,14 +622,10 @@ class TypePrinter(object):
|
|||
args += ["?%s:%s" % (arg, self.name(typ.optargs[arg])) for arg in typ.optargs]
|
||||
signature = "(%s)->%s" % (", ".join(args), self.name(typ.ret))
|
||||
|
||||
if iodelay.is_unknown(typ.delay) or iodelay.is_fixed(typ.delay, 0):
|
||||
pass
|
||||
elif iodelay.is_fixed(typ.delay):
|
||||
signature += " delay({} mu)".format(typ.delay.length)
|
||||
elif iodelay.is_indeterminate(typ.delay):
|
||||
signature += " delay(?)"
|
||||
else:
|
||||
assert False
|
||||
delay = typ.delay.find()
|
||||
if not (isinstance(delay, TVar) or
|
||||
delay.is_fixed() and iodelay.is_zero(delay.duration)):
|
||||
signature += " " + self.name(delay)
|
||||
|
||||
if isinstance(typ, TRPCFunction):
|
||||
return "rpc({}) {}".format(typ.service, signature)
|
||||
|
@ -580,5 +645,12 @@ class TypePrinter(object):
|
|||
return "<constructor {} {{{}}}>".format(typ.name, attrs)
|
||||
elif isinstance(typ, TValue):
|
||||
return repr(typ.value)
|
||||
elif isinstance(typ, TDelay):
|
||||
if typ.is_fixed():
|
||||
return "delay({} mu)".format(typ.duration)
|
||||
elif typ.is_indeterminate():
|
||||
return "delay(?)"
|
||||
else:
|
||||
assert False
|
||||
else:
|
||||
assert False
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
from artiq.language.core import *
|
||||
from artiq.language.types import *
|
||||
from artiq.language.units import *
|
||||
|
||||
|
||||
|
|
|
@ -0,0 +1,11 @@
|
|||
# RUN: %python -m artiq.compiler.testbench.signature +diag +delay %s >%t
|
||||
# RUN: OutputCheck %s --file-to-check=%t
|
||||
|
||||
def f():
|
||||
delay_mu(10)
|
||||
|
||||
# CHECK-L: ${LINE:+1}: fatal: delay delay(20 mu) was inferred for this function, but its delay is already constrained externally to delay(10 mu)
|
||||
def g():
|
||||
delay_mu(20)
|
||||
|
||||
x = f if True else g
|
|
@ -7,7 +7,8 @@ def f():
|
|||
delay(1.5)
|
||||
return 10
|
||||
|
||||
# CHECK-L: g: (x:float)->int(width=32) delay(0 mu)
|
||||
# CHECK-L: g: (x:float)->int(width=32)
|
||||
# CHECK-NOT-L: delay
|
||||
def g(x):
|
||||
if x > 1.0:
|
||||
return 1
|
||||
|
|
Loading…
Reference in New Issue