Make delay component of function type unifyable.

This commit is contained in:
whitequark 2015-09-30 18:41:14 +03:00
parent 3e1348a084
commit 7a6fc3983c
8 changed files with 128 additions and 96 deletions

View File

@ -4,9 +4,6 @@ the statically inferred RTIO delay arising from executing
a function. a function.
""" """
from pythonparser import diagnostic
class Expr: class Expr:
def __add__(lhs, rhs): def __add__(lhs, rhs):
assert isinstance(rhs, Expr) assert isinstance(rhs, Expr)
@ -227,59 +224,3 @@ def is_const(expr, value=None):
def is_zero(expr): def is_zero(expr):
return is_const(expr, 0) return is_const(expr, 0)
class Delay:
pass
class Unknown(Delay):
"""
Unknown delay, that is, IO delay that we have not
tried to compute yet.
"""
def __repr__(self):
return "{}.Unknown()".format(__name__)
class Indeterminate(Delay):
"""
Indeterminate delay, that is, IO delay that can vary from
invocation to invocation.
:ivar cause: (:class:`pythonparser.diagnostic.Diagnostic`)
reason for the delay not being inferred
"""
def __init__(self, cause):
assert isinstance(cause, diagnostic.Diagnostic)
self.cause = cause
def __repr__(self):
return "<{}.Indeterminate>".format(__name__)
class Fixed(Delay):
"""
Fixed delay, that is, IO delay that is always the same
for every invocation.
:ivar length: (int) delay in machine units
"""
def __init__(self, length):
assert isinstance(length, Expr)
self.length = length
def __repr__(self):
return "{}.Fixed({})".format(__name__, self.length)
def is_unknown(delay):
return isinstance(delay, Unknown)
def is_indeterminate(delay):
return isinstance(delay, Indeterminate)
def is_fixed(delay, length=None):
if length is None:
return isinstance(delay, Fixed)
else:
return isinstance(delay, Fixed) and delay.length == length

View File

@ -50,7 +50,8 @@ class Module:
inferencer = transforms.Inferencer(engine=self.engine) inferencer = transforms.Inferencer(engine=self.engine)
monomorphism_validator = validators.MonomorphismValidator(engine=self.engine) monomorphism_validator = validators.MonomorphismValidator(engine=self.engine)
escape_validator = validators.EscapeValidator(engine=self.engine) escape_validator = validators.EscapeValidator(engine=self.engine)
iodelay_estimator = transforms.IODelayEstimator(ref_period=ref_period) iodelay_estimator = transforms.IODelayEstimator(engine=self.engine,
ref_period=ref_period)
artiq_ir_generator = transforms.ARTIQIRGenerator(engine=self.engine, artiq_ir_generator = transforms.ARTIQIRGenerator(engine=self.engine,
module_name=src.name, module_name=src.name,
ref_period=ref_period) ref_period=ref_period)

View File

@ -32,8 +32,8 @@ def main():
if force_delays: if force_delays:
for var in mod.globals: for var in mod.globals:
typ = mod.globals[var].find() typ = mod.globals[var].find()
if types.is_function(typ) and iodelay.is_indeterminate(typ.delay): if types.is_function(typ) and types.is_indeterminate_delay(typ.delay):
process_diagnostic(typ.delay.cause) process_diagnostic(typ.delay.find().cause)
print(repr(mod)) print(repr(mod))
except: except:

View File

@ -11,7 +11,8 @@ class _UnknownDelay(Exception):
pass pass
class IODelayEstimator(algorithm.Visitor): class IODelayEstimator(algorithm.Visitor):
def __init__(self, ref_period): def __init__(self, engine, ref_period):
self.engine = engine
self.ref_period = ref_period self.ref_period = ref_period
self.changed = False self.changed = False
self.current_delay = iodelay.Const(0) self.current_delay = iodelay.Const(0)
@ -83,7 +84,7 @@ class IODelayEstimator(algorithm.Visitor):
except (diagnostic.Error, _UnknownDelay): except (diagnostic.Error, _UnknownDelay):
pass # we don't care; module-level code is never interleaved pass # we don't care; module-level code is never interleaved
def visit_function(self, args, body, typ): def visit_function(self, args, body, typ, loc):
old_args, self.current_args = self.current_args, args old_args, self.current_args = self.current_args, args
old_return, self.current_return = self.current_return, None old_return, self.current_return = self.current_return, None
old_delay, self.current_delay = self.current_delay, iodelay.Const(0) old_delay, self.current_delay = self.current_delay, iodelay.Const(0)
@ -93,19 +94,23 @@ class IODelayEstimator(algorithm.Visitor):
self.abort("only return statement at the end of the function " self.abort("only return statement at the end of the function "
"can be interleaved", self.current_return.loc) "can be interleaved", self.current_return.loc)
delay = iodelay.Fixed(self.current_delay.fold()) delay = types.TFixedDelay(self.current_delay.fold())
except diagnostic.Error as error: except diagnostic.Error as error:
delay = iodelay.Indeterminate(error.diagnostic) delay = types.TIndeterminateDelay(error.diagnostic)
self.current_delay = old_delay self.current_delay = old_delay
self.current_return = old_return self.current_return = old_return
self.current_args = old_args self.current_args = old_args
if iodelay.is_unknown(typ.delay) or iodelay.is_indeterminate(typ.delay): try:
typ.delay = delay typ.delay.unify(delay)
elif iodelay.is_fixed(typ.delay): except types.UnificationError as e:
assert typ.delay.value == delay.value printer = types.TypePrinter()
else: diag = diagnostic.Diagnostic("fatal",
assert False "delay {delaya} was inferred for this function, but its delay is already "
"constrained externally to {delayb}",
{"delaya": printer.name(delay), "delayb": printer.name(typ.delay)},
loc)
self.engine.process(diag)
def visit_FunctionDefT(self, node): def visit_FunctionDefT(self, node):
self.visit(node.args.defaults) self.visit(node.args.defaults)
@ -116,10 +121,10 @@ class IODelayEstimator(algorithm.Visitor):
body = node.body[:-1] body = node.body[:-1]
else: else:
body = node.body body = node.body
self.visit_function(node.args, body, node.signature_type.find()) self.visit_function(node.args, body, node.signature_type.find(), node.loc)
def visit_LambdaT(self, node): def visit_LambdaT(self, node):
self.visit_function(node.args, node.body, node.type.find()) self.visit_function(node.args, node.body, node.type.find(), node.loc)
def get_iterable_length(self, node): def get_iterable_length(self, node):
def abort(notes): def abort(notes):
@ -225,10 +230,10 @@ class IODelayEstimator(algorithm.Visitor):
else: else:
assert False assert False
delay = typ.find().delay delay = typ.find().delay.find()
if iodelay.is_unknown(delay): if types.is_var(delay):
raise _UnknownDelay() raise _UnknownDelay()
elif iodelay.is_indeterminate(delay): elif delay.is_indeterminate():
cause = delay.cause cause = delay.cause
note = diagnostic.Diagnostic("note", note = diagnostic.Diagnostic("note",
"function called here", {}, "function called here", {},
@ -236,15 +241,15 @@ class IODelayEstimator(algorithm.Visitor):
diag = diagnostic.Diagnostic(cause.level, cause.reason, cause.arguments, diag = diagnostic.Diagnostic(cause.level, cause.reason, cause.arguments,
cause.location, cause.highlights, cause.notes + [note]) cause.location, cause.highlights, cause.notes + [note])
raise diagnostic.Error(diag) raise diagnostic.Error(diag)
elif iodelay.is_fixed(delay): elif delay.is_fixed():
args = {} args = {}
for kw_node in node.keywords: for kw_node in node.keywords:
args[kw_node.arg] = kw_node.value args[kw_node.arg] = kw_node.value
for arg_name, arg_node in zip(typ.args, node.args[offset:]): for arg_name, arg_node in zip(typ.args, node.args[offset:]):
args[arg_name] = arg_node args[arg_name] = arg_node
free_vars = delay.length.free_vars() free_vars = delay.duration.free_vars()
self.current_delay += delay.length.fold( self.current_delay += delay.duration.fold(
{ arg: self.evaluate(args[arg], abort=abort) for arg in free_vars }) { arg: self.evaluate(args[arg], abort=abort) for arg in free_vars })
else: else:
assert False assert False

View File

@ -5,6 +5,7 @@ in :mod:`asttyped`.
import string import string
from collections import OrderedDict from collections import OrderedDict
from pythonparser import diagnostic
from . import iodelay from . import iodelay
@ -192,8 +193,8 @@ class TFunction(Type):
optional arguments optional arguments
:ivar ret: (:class:`Type`) :ivar ret: (:class:`Type`)
return type return type
:ivar delay: (:class:`iodelay.Delay`) :ivar delay: (:class:`Type`)
RTIO delay expression RTIO delay
""" """
attributes = OrderedDict([ attributes = OrderedDict([
@ -201,12 +202,12 @@ class TFunction(Type):
('__closure__', _TPointer()), ('__closure__', _TPointer()),
]) ])
def __init__(self, args, optargs, ret, delay=iodelay.Unknown()): def __init__(self, args, optargs, ret):
assert isinstance(args, OrderedDict) assert isinstance(args, OrderedDict)
assert isinstance(optargs, OrderedDict) assert isinstance(optargs, OrderedDict)
assert isinstance(ret, Type) assert isinstance(ret, Type)
assert isinstance(delay, iodelay.Delay) self.args, self.optargs, self.ret = args, optargs, ret
self.args, self.optargs, self.ret, self.delay = args, optargs, ret, delay self.delay = TVar()
def arity(self): def arity(self):
return len(self.args) + len(self.optargs) return len(self.args) + len(self.optargs)
@ -222,6 +223,7 @@ class TFunction(Type):
list(other.args.values()) + list(other.optargs.values())): list(other.args.values()) + list(other.optargs.values())):
selfarg.unify(otherarg) selfarg.unify(otherarg)
self.ret.unify(other.ret) self.ret.unify(other.ret)
self.delay.unify(other.delay)
elif isinstance(other, TVar): elif isinstance(other, TVar):
other.unify(self) other.unify(self)
else: else:
@ -261,7 +263,7 @@ class TRPCFunction(TFunction):
def __init__(self, args, optargs, ret, service): def __init__(self, args, optargs, ret, service):
super().__init__(args, optargs, ret, super().__init__(args, optargs, ret,
delay=iodelay.Fixed(iodelay.Constant(0))) delay=FixedDelay(iodelay.Constant(0)))
self.service = service self.service = service
def unify(self, other): def unify(self, other):
@ -284,7 +286,7 @@ class TCFunction(TFunction):
def __init__(self, args, ret, name): def __init__(self, args, ret, name):
super().__init__(args, OrderedDict(), ret, super().__init__(args, OrderedDict(), ret,
delay=iodelay.Fixed(iodelay.Constant(0))) delay=FixedDelay(iodelay.Constant(0)))
self.name = name self.name = name
def unify(self, other): def unify(self, other):
@ -418,6 +420,63 @@ class TValue(Type):
def __ne__(self, other): def __ne__(self, other):
return not (self == other) return not (self == other)
class TDelay(Type):
"""
The type-level representation of IO delay.
"""
def __init__(self, duration, cause):
assert duration is None or isinstance(duration, iodelay.Expr)
assert cause is None or isinstance(cause, diagnostic.Diagnostic)
assert (not (duration and cause)) and (duration or cause)
self.duration, self.cause = duration, cause
def is_fixed(self):
return self.duration is not None
def is_indeterminate(self):
return self.cause is not None
def find(self):
return self
def unify(self, other):
other = other.find()
if self.is_fixed() and other.is_fixed() and \
self.duration.fold() == other.duration.fold():
pass
elif isinstance(other, TVar):
other.unify(self)
else:
raise UnificationError(self, other)
def fold(self, accum, fn):
# delay types do not participate in folding
pass
def __eq__(self, other):
return isinstance(other, TDelay) and \
(self.duration == other.duration and \
self.cause == other.cause)
def __ne__(self, other):
return not (self == other)
def __repr__(self):
if self.duration is None:
return "<{}.TIndeterminateDelay>".format(__name__)
elif self.cause is None:
return "{}.TFixedDelay({})".format(__name__, self.duration)
else:
assert False
def TIndeterminateDelay(cause):
return TDelay(None, cause)
def TFixedDelay(duration):
return TDelay(duration, None)
def is_var(typ): def is_var(typ):
return isinstance(typ.find(), TVar) return isinstance(typ.find(), TVar)
@ -511,6 +570,16 @@ def get_value(typ):
else: else:
assert False assert False
def is_delay(typ):
return isinstance(typ.find(), TDelay)
def is_fixed_delay(typ):
return is_delay(typ) and typ.find().is_fixed()
def is_indeterminate_delay(typ):
return is_delay(typ) and typ.find().is_indeterminate()
class TypePrinter(object): class TypePrinter(object):
""" """
A class that prints types using Python-like syntax and gives A class that prints types using Python-like syntax and gives
@ -553,14 +622,10 @@ class TypePrinter(object):
args += ["?%s:%s" % (arg, self.name(typ.optargs[arg])) for arg in typ.optargs] args += ["?%s:%s" % (arg, self.name(typ.optargs[arg])) for arg in typ.optargs]
signature = "(%s)->%s" % (", ".join(args), self.name(typ.ret)) signature = "(%s)->%s" % (", ".join(args), self.name(typ.ret))
if iodelay.is_unknown(typ.delay) or iodelay.is_fixed(typ.delay, 0): delay = typ.delay.find()
pass if not (isinstance(delay, TVar) or
elif iodelay.is_fixed(typ.delay): delay.is_fixed() and iodelay.is_zero(delay.duration)):
signature += " delay({} mu)".format(typ.delay.length) signature += " " + self.name(delay)
elif iodelay.is_indeterminate(typ.delay):
signature += " delay(?)"
else:
assert False
if isinstance(typ, TRPCFunction): if isinstance(typ, TRPCFunction):
return "rpc({}) {}".format(typ.service, signature) return "rpc({}) {}".format(typ.service, signature)
@ -580,5 +645,12 @@ class TypePrinter(object):
return "<constructor {} {{{}}}>".format(typ.name, attrs) return "<constructor {} {{{}}}>".format(typ.name, attrs)
elif isinstance(typ, TValue): elif isinstance(typ, TValue):
return repr(typ.value) return repr(typ.value)
elif isinstance(typ, TDelay):
if typ.is_fixed():
return "delay({} mu)".format(typ.duration)
elif typ.is_indeterminate():
return "delay(?)"
else:
assert False
else: else:
assert False assert False

View File

@ -1,4 +1,5 @@
from artiq.language.core import * from artiq.language.core import *
from artiq.language.types import *
from artiq.language.units import * from artiq.language.units import *

View File

@ -0,0 +1,11 @@
# RUN: %python -m artiq.compiler.testbench.signature +diag +delay %s >%t
# RUN: OutputCheck %s --file-to-check=%t
def f():
delay_mu(10)
# CHECK-L: ${LINE:+1}: fatal: delay delay(20 mu) was inferred for this function, but its delay is already constrained externally to delay(10 mu)
def g():
delay_mu(20)
x = f if True else g

View File

@ -7,7 +7,8 @@ def f():
delay(1.5) delay(1.5)
return 10 return 10
# CHECK-L: g: (x:float)->int(width=32) delay(0 mu) # CHECK-L: g: (x:float)->int(width=32)
# CHECK-NOT-L: delay
def g(x): def g(x):
if x > 1.0: if x > 1.0:
return 1 return 1