compiler: don't typecheck RPCs except for return type.

Fixes #260.
This commit is contained in:
whitequark 2016-04-25 22:05:32 +00:00
parent 063639662e
commit 1464bae6b7
13 changed files with 269 additions and 267 deletions

View File

@ -255,6 +255,6 @@ def is_allocated(typ):
return not (is_none(typ) or is_bool(typ) or is_int(typ) or return not (is_none(typ) or is_bool(typ) or is_int(typ) or
is_float(typ) or is_range(typ) or is_float(typ) or is_range(typ) or
types._is_pointer(typ) or types.is_function(typ) or types._is_pointer(typ) or types.is_function(typ) or
types.is_c_function(typ) or types.is_rpc_function(typ) or types.is_c_function(typ) or types.is_rpc(typ) or
types.is_method(typ) or types.is_tuple(typ) or types.is_method(typ) or types.is_tuple(typ) or
types.is_value(typ)) types.is_value(typ))

View File

@ -428,9 +428,8 @@ class StitchingInferencer(Inferencer):
if attr_name not in attributes: if attr_name not in attributes:
# We just figured out what the type should be. Add it. # We just figured out what the type should be. Add it.
attributes[attr_name] = attr_value_type attributes[attr_name] = attr_value_type
elif not types.is_rpc_function(attr_value_type): else:
# Does this conflict with an earlier guess? # Does this conflict with an earlier guess?
# RPC function types are exempt because RPCs are dynamically typed.
try: try:
attributes[attr_name].unify(attr_value_type) attributes[attr_name].unify(attr_value_type)
except types.UnificationError as e: except types.UnificationError as e:
@ -694,29 +693,22 @@ class Stitcher:
# Let the rest of the program decide. # Let the rest of the program decide.
return types.TVar() return types.TVar()
def _quote_foreign_function(self, function, loc, syscall, flags): def _quote_syscall(self, function, loc):
signature = inspect.signature(function) signature = inspect.signature(function)
arg_types = OrderedDict() arg_types = OrderedDict()
optarg_types = OrderedDict() optarg_types = OrderedDict()
for param in signature.parameters.values(): for param in signature.parameters.values():
if param.kind not in (inspect.Parameter.POSITIONAL_ONLY, if param.kind != inspect.Parameter.POSITIONAL_OR_KEYWORD:
inspect.Parameter.POSITIONAL_OR_KEYWORD): diag = diagnostic.Diagnostic("error",
# We pretend we don't see *args, kwpostargs=..., **kwargs. "system calls must only use positional arguments; '{argument}' isn't",
# Since every method can be still invoked without any arguments {"argument": param.name},
# going into *args and the slots after it, this is always safe, self._function_loc(function),
# if sometimes constraining. notes=self._call_site_note(loc, is_syscall=True))
# self.engine.process(diag)
# Accepting POSITIONAL_ONLY is OK, because the compiler
# desugars the keyword arguments into positional ones internally.
continue
if param.default is inspect.Parameter.empty: if param.default is inspect.Parameter.empty:
arg_types[param.name] = self._type_of_param(function, loc, param, arg_types[param.name] = self._type_of_param(function, loc, param, is_syscall=True)
is_syscall=syscall is not None)
elif syscall is None:
optarg_types[param.name] = self._type_of_param(function, loc, param,
is_syscall=False)
else: else:
diag = diagnostic.Diagnostic("error", diag = diagnostic.Diagnostic("error",
"system call argument '{argument}' must not have a default value", "system call argument '{argument}' must not have a default value",
@ -727,10 +719,8 @@ class Stitcher:
if signature.return_annotation is not inspect.Signature.empty: if signature.return_annotation is not inspect.Signature.empty:
ret_type = self._extract_annot(function, signature.return_annotation, ret_type = self._extract_annot(function, signature.return_annotation,
"return type", loc, is_syscall=syscall is not None) "return type", loc, is_syscall=True)
elif syscall is None: else:
ret_type = builtins.TNone()
else: # syscall is not None
diag = diagnostic.Diagnostic("error", diag = diagnostic.Diagnostic("error",
"system call must have a return type annotation", {}, "system call must have a return type annotation", {},
self._function_loc(function), self._function_loc(function),
@ -738,15 +728,23 @@ class Stitcher:
self.engine.process(diag) self.engine.process(diag)
ret_type = types.TVar() ret_type = types.TVar()
if syscall is None: function_type = types.TCFunction(arg_types, ret_type,
function_type = types.TRPCFunction(arg_types, optarg_types, ret_type, name=function.artiq_embedded.syscall,
service=self.object_map.store(function)) flags=function.artiq_embedded.flags)
else:
function_type = types.TCFunction(arg_types, ret_type,
name=syscall, flags=flags)
self.functions[function] = function_type self.functions[function] = function_type
return function_type
def _quote_rpc(self, function, loc):
signature = inspect.signature(function)
if signature.return_annotation is not inspect.Signature.empty:
ret_type = self._extract_annot(function, signature.return_annotation,
"return type", loc, is_syscall=False)
else:
ret_type = builtins.TNone()
function_type = types.TRPC(ret_type, service=self.object_map.store(function))
self.functions[function] = function_type
return function_type return function_type
def _quote_function(self, function, loc): def _quote_function(self, function, loc):
@ -780,9 +778,7 @@ class Stitcher:
elif function.artiq_embedded.syscall is not None: elif function.artiq_embedded.syscall is not None:
# Insert a storage-less global whose type instructs the compiler # Insert a storage-less global whose type instructs the compiler
# to perform a system call instead of a regular call. # to perform a system call instead of a regular call.
self._quote_foreign_function(function, loc, self._quote_syscall(function, loc)
syscall=function.artiq_embedded.syscall,
flags=function.artiq_embedded.flags)
elif function.artiq_embedded.forbidden is not None: elif function.artiq_embedded.forbidden is not None:
diag = diagnostic.Diagnostic("fatal", diag = diagnostic.Diagnostic("fatal",
"this function cannot be called as an RPC", {}, "this function cannot be called as an RPC", {},
@ -792,14 +788,9 @@ class Stitcher:
else: else:
assert False assert False
else: else:
# Insert a storage-less global whose type instructs the compiler self._quote_rpc(function, loc)
# to perform an RPC instead of a regular call.
self._quote_foreign_function(function, loc, syscall=None, flags=None)
function_type = self.functions[function] return self.functions[function]
if types.is_rpc_function(function_type):
function_type = types.instantiate(function_type)
return function_type
def _quote(self, value, loc): def _quote(self, value, loc):
synthesizer = self._synthesizer(loc) synthesizer = self._synthesizer(loc)

View File

@ -23,12 +23,19 @@ def is_basic_block(typ):
return isinstance(typ, TBasicBlock) return isinstance(typ, TBasicBlock)
class TOption(types.TMono): class TOption(types.TMono):
def __init__(self, inner): def __init__(self, value):
super().__init__("option", {"inner": inner}) super().__init__("option", {"value": value})
def is_option(typ): def is_option(typ):
return isinstance(typ, TOption) return isinstance(typ, TOption)
class TKeyword(types.TMono):
def __init__(self, value):
super().__init__("keyword", {"value": value})
def is_keyword(typ):
return isinstance(typ, TKeyword)
class TExceptionTypeInfo(types.TMono): class TExceptionTypeInfo(types.TMono):
def __init__(self): def __init__(self):
super().__init__("exntypeinfo") super().__init__("exntypeinfo")
@ -678,7 +685,7 @@ class GetAttr(Instruction):
typ = obj.type.attributes[attr] typ = obj.type.attributes[attr]
else: else:
typ = obj.type.constructor.attributes[attr] typ = obj.type.constructor.attributes[attr]
if types.is_function(typ): if types.is_function(typ) or types.is_rpc(typ):
typ = types.TMethod(obj.type, typ) typ = types.TMethod(obj.type, typ)
super().__init__([obj], typ, name) super().__init__([obj], typ, name)
self.attr = attr self.attr = attr

View File

@ -922,9 +922,6 @@ class ARTIQIRGenerator(algorithm.Visitor):
if self.current_assign is None: if self.current_assign is None:
return self.append(ir.GetAttr(obj, node.attr, return self.append(ir.GetAttr(obj, node.attr,
name="{}.FLD.{}".format(_readable_name(obj), node.attr))) name="{}.FLD.{}".format(_readable_name(obj), node.attr)))
elif types.is_rpc_function(self.current_assign.type):
# RPC functions are just type-level markers
return self.append(ir.Builtin("nop", [], builtins.TNone()))
else: else:
return self.append(ir.SetAttr(obj, node.attr, self.current_assign)) return self.append(ir.SetAttr(obj, node.attr, self.current_assign))
@ -1719,7 +1716,7 @@ class ARTIQIRGenerator(algorithm.Visitor):
self.engine.process(diag) self.engine.process(diag)
def _user_call(self, callee, positional, keywords, arg_exprs={}): def _user_call(self, callee, positional, keywords, arg_exprs={}):
if types.is_function(callee.type): if types.is_function(callee.type) or types.is_rpc(callee.type):
func = callee func = callee
self_arg = None self_arg = None
fn_typ = callee.type fn_typ = callee.type
@ -1734,40 +1731,51 @@ class ARTIQIRGenerator(algorithm.Visitor):
else: else:
assert False assert False
args = [None] * (len(fn_typ.args) + len(fn_typ.optargs)) if types.is_rpc(fn_typ):
if self_arg is None:
for index, arg in enumerate(positional): args = positional
if index + offset < len(fn_typ.args):
args[index + offset] = arg
else: else:
args[index + offset] = self.append(ir.Alloc([arg], ir.TOption(arg.type))) args = [self_arg] + positional
for keyword in keywords: for keyword in keywords:
arg = keywords[keyword] arg = keywords[keyword]
if keyword in fn_typ.args: args.append(self.append(ir.Alloc([ir.Constant(keyword, builtins.TStr()), arg],
for index, arg_name in enumerate(fn_typ.args): ir.TKeyword(arg.type))))
if keyword == arg_name: else:
assert args[index] is None args = [None] * (len(fn_typ.args) + len(fn_typ.optargs))
args[index] = arg
break
elif keyword in fn_typ.optargs:
for index, optarg_name in enumerate(fn_typ.optargs):
if keyword == optarg_name:
assert args[len(fn_typ.args) + index] is None
args[len(fn_typ.args) + index] = \
self.append(ir.Alloc([arg], ir.TOption(arg.type)))
break
for index, optarg_name in enumerate(fn_typ.optargs): for index, arg in enumerate(positional):
if args[len(fn_typ.args) + index] is None: if index + offset < len(fn_typ.args):
args[len(fn_typ.args) + index] = \ args[index + offset] = arg
self.append(ir.Alloc([], ir.TOption(fn_typ.optargs[optarg_name]))) else:
args[index + offset] = self.append(ir.Alloc([arg], ir.TOption(arg.type)))
if self_arg is not None: for keyword in keywords:
assert args[0] is None arg = keywords[keyword]
args[0] = self_arg if keyword in fn_typ.args:
for index, arg_name in enumerate(fn_typ.args):
if keyword == arg_name:
assert args[index] is None
args[index] = arg
break
elif keyword in fn_typ.optargs:
for index, optarg_name in enumerate(fn_typ.optargs):
if keyword == optarg_name:
assert args[len(fn_typ.args) + index] is None
args[len(fn_typ.args) + index] = \
self.append(ir.Alloc([arg], ir.TOption(arg.type)))
break
assert None not in args for index, optarg_name in enumerate(fn_typ.optargs):
if args[len(fn_typ.args) + index] is None:
args[len(fn_typ.args) + index] = \
self.append(ir.Alloc([], ir.TOption(fn_typ.optargs[optarg_name])))
if self_arg is not None:
assert args[0] is None
args[0] = self_arg
assert None not in args
if self.unwind_target is None: if self.unwind_target is None:
insn = self.append(ir.Call(func, args, arg_exprs)) insn = self.append(ir.Call(func, args, arg_exprs))

View File

@ -109,17 +109,11 @@ class Inferencer(algorithm.Visitor):
] ]
attr_type = object_type.attributes[attr_name] attr_type = object_type.attributes[attr_name]
if types.is_rpc_function(attr_type):
attr_type = types.instantiate(attr_type)
self._unify(result_type, attr_type, loc, None, self._unify(result_type, attr_type, loc, None,
makenotes=makenotes, when=" for attribute '{}'".format(attr_name)) makenotes=makenotes, when=" for attribute '{}'".format(attr_name))
elif types.is_instance(object_type) and \ elif types.is_instance(object_type) and \
attr_name in object_type.constructor.attributes: attr_name in object_type.constructor.attributes:
attr_type = object_type.constructor.attributes[attr_name].find() attr_type = object_type.constructor.attributes[attr_name].find()
if types.is_rpc_function(attr_type):
attr_type = types.instantiate(attr_type)
if types.is_function(attr_type): if types.is_function(attr_type):
# Convert to a method. # Convert to a method.
if len(attr_type.args) < 1: if len(attr_type.args) < 1:
@ -155,6 +149,10 @@ class Inferencer(algorithm.Visitor):
when=" while inferring the type for self argument") when=" while inferring the type for self argument")
attr_type = types.TMethod(object_type, attr_type) attr_type = types.TMethod(object_type, attr_type)
elif types.is_rpc(attr_type):
# Convert to a method. We don't have to bother typechecking
# the self argument, since for RPCs anything goes.
attr_type = types.TMethod(object_type, attr_type)
if not types.is_var(attr_type): if not types.is_var(attr_type):
self._unify(result_type, attr_type, self._unify(result_type, attr_type,
@ -871,6 +869,10 @@ class Inferencer(algorithm.Visitor):
return # not enough info yet return # not enough info yet
elif types.is_builtin(typ): elif types.is_builtin(typ):
return self.visit_builtin_call(node) return self.visit_builtin_call(node)
elif types.is_rpc(typ):
self._unify(node.type, typ.ret,
node.loc, None)
return
elif not (types.is_function(typ) or types.is_method(typ)): elif not (types.is_function(typ) or types.is_method(typ)):
diag = diagnostic.Diagnostic("error", diag = diagnostic.Diagnostic("error",
"cannot call this expression of type {type}", "cannot call this expression of type {type}",
@ -888,6 +890,10 @@ class Inferencer(algorithm.Visitor):
typ = types.get_method_function(typ) typ = types.get_method_function(typ)
if types.is_var(typ): if types.is_var(typ):
return # not enough info yet return # not enough info yet
elif types.is_rpc(typ):
self._unify(node.type, typ.ret,
node.loc, None)
return
typ_arity = typ.arity() - 1 typ_arity = typ.arity() - 1
typ_args = OrderedDict(list(typ.args.items())[1:]) typ_args = OrderedDict(list(typ.args.items())[1:])

View File

@ -280,7 +280,7 @@ class IODelayEstimator(algorithm.Visitor):
context="as an argument for delay_mu()") context="as an argument for delay_mu()")
call_delay = value call_delay = value
elif not types.is_builtin(typ): elif not types.is_builtin(typ):
if types.is_function(typ): if types.is_function(typ) or types.is_rpc(typ):
offset = 0 offset = 0
elif types.is_method(typ): elif types.is_method(typ):
offset = 1 offset = 1
@ -288,35 +288,38 @@ class IODelayEstimator(algorithm.Visitor):
else: else:
assert False assert False
delay = typ.find().delay.find() if types.is_rpc(typ):
if types.is_var(delay): call_delay = iodelay.Const(0)
raise _UnknownDelay()
elif delay.is_indeterminate():
note = diagnostic.Diagnostic("note",
"function called here", {},
node.loc)
cause = delay.cause
cause = diagnostic.Diagnostic(cause.level, cause.reason, cause.arguments,
cause.location, cause.highlights,
cause.notes + [note])
raise _IndeterminateDelay(cause)
elif delay.is_fixed():
args = {}
for kw_node in node.keywords:
args[kw_node.arg] = kw_node.value
for arg_name, arg_node in zip(list(typ.args)[offset:], node.args):
args[arg_name] = arg_node
free_vars = delay.duration.free_vars()
node.arg_exprs = {
arg: self.evaluate(args[arg], abort=abort,
context="in the expression for argument '{}' "
"that affects I/O delay".format(arg))
for arg in free_vars
}
call_delay = delay.duration.fold(node.arg_exprs)
else: else:
assert False delay = typ.find().delay.find()
if types.is_var(delay):
raise _UnknownDelay()
elif delay.is_indeterminate():
note = diagnostic.Diagnostic("note",
"function called here", {},
node.loc)
cause = delay.cause
cause = diagnostic.Diagnostic(cause.level, cause.reason, cause.arguments,
cause.location, cause.highlights,
cause.notes + [note])
raise _IndeterminateDelay(cause)
elif delay.is_fixed():
args = {}
for kw_node in node.keywords:
args[kw_node.arg] = kw_node.value
for arg_name, arg_node in zip(list(typ.args)[offset:], node.args):
args[arg_name] = arg_node
free_vars = delay.duration.free_vars()
node.arg_exprs = {
arg: self.evaluate(args[arg], abort=abort,
context="in the expression for argument '{}' "
"that affects I/O delay".format(arg))
for arg in free_vars
}
call_delay = delay.duration.fold(node.arg_exprs)
else:
assert False
else: else:
call_delay = iodelay.Const(0) call_delay = iodelay.Const(0)

View File

@ -177,7 +177,7 @@ class LLVMIRGenerator:
typ = typ.find() typ = typ.find()
if types.is_tuple(typ): if types.is_tuple(typ):
return ll.LiteralStructType([self.llty_of_type(eltty) for eltty in typ.elts]) return ll.LiteralStructType([self.llty_of_type(eltty) for eltty in typ.elts])
elif types.is_rpc_function(typ) or types.is_c_function(typ): elif types.is_rpc(typ) or types.is_c_function(typ):
if for_return: if for_return:
return llvoid return llvoid
else: else:
@ -229,7 +229,9 @@ class LLVMIRGenerator:
elif ir.is_basic_block(typ): elif ir.is_basic_block(typ):
return llptr return llptr
elif ir.is_option(typ): elif ir.is_option(typ):
return ll.LiteralStructType([lli1, self.llty_of_type(typ.params["inner"])]) return ll.LiteralStructType([lli1, self.llty_of_type(typ.params["value"])])
elif ir.is_keyword(typ):
return ll.LiteralStructType([llptr, self.llty_of_type(typ.params["value"])])
elif ir.is_environment(typ): elif ir.is_environment(typ):
llty = self.llcontext.get_identified_type("env.{}".format(typ.env_name)) llty = self.llcontext.get_identified_type("env.{}".format(typ.env_name))
if llty.elements is None: if llty.elements is None:
@ -618,7 +620,7 @@ class LLVMIRGenerator:
size=llsize) size=llsize)
llvalue = self.llbuilder.insert_value(llvalue, llalloc, 1, name=insn.name) llvalue = self.llbuilder.insert_value(llvalue, llalloc, 1, name=insn.name)
return llvalue return llvalue
elif not builtins.is_allocated(insn.type): elif not builtins.is_allocated(insn.type) or ir.is_keyword(insn.type):
llvalue = ll.Constant(self.llty_of_type(insn.type), ll.Undefined) llvalue = ll.Constant(self.llty_of_type(insn.type), ll.Undefined)
for index, elt in enumerate(insn.operands): for index, elt in enumerate(insn.operands):
llvalue = self.llbuilder.insert_value(llvalue, self.map(elt), index) llvalue = self.llbuilder.insert_value(llvalue, self.map(elt), index)
@ -707,8 +709,8 @@ class LLVMIRGenerator:
def get_global_closure(self, typ, attr): def get_global_closure(self, typ, attr):
closure_type = typ.attributes[attr] closure_type = typ.attributes[attr]
assert types.is_constructor(typ) assert types.is_constructor(typ)
assert types.is_function(closure_type) assert types.is_function(closure_type) or types.is_rpc(closure_type)
if types.is_c_function(closure_type) or types.is_rpc_function(closure_type): if types.is_c_function(closure_type) or types.is_rpc(closure_type):
return None return None
llty = self.llty_of_type(typ.attributes[attr]) llty = self.llty_of_type(typ.attributes[attr])
@ -1156,8 +1158,8 @@ class LLVMIRGenerator:
elif builtins.is_range(typ): elif builtins.is_range(typ):
return b"r" + self._rpc_tag(builtins.get_iterable_elt(typ), return b"r" + self._rpc_tag(builtins.get_iterable_elt(typ),
error_handler) error_handler)
elif ir.is_option(typ): elif ir.is_keyword(typ):
return b"o" + self._rpc_tag(typ.params["inner"], return b"k" + self._rpc_tag(typ.params["value"],
error_handler) error_handler)
elif '__objectid__' in typ.attributes: elif '__objectid__' in typ.attributes:
return b"O" return b"O"
@ -1271,7 +1273,7 @@ class LLVMIRGenerator:
def process_Call(self, insn): def process_Call(self, insn):
functiontyp = insn.target_function().type functiontyp = insn.target_function().type
if types.is_rpc_function(functiontyp): if types.is_rpc(functiontyp):
return self._build_rpc(insn.target_function().loc, return self._build_rpc(insn.target_function().loc,
functiontyp, functiontyp,
insn.arguments(), insn.arguments(),
@ -1303,7 +1305,7 @@ class LLVMIRGenerator:
functiontyp = insn.target_function().type functiontyp = insn.target_function().type
llnormalblock = self.map(insn.normal_target()) llnormalblock = self.map(insn.normal_target())
llunwindblock = self.map(insn.exception_target()) llunwindblock = self.map(insn.exception_target())
if types.is_rpc_function(functiontyp): if types.is_rpc(functiontyp):
return self._build_rpc(insn.target_function().loc, return self._build_rpc(insn.target_function().loc,
functiontyp, functiontyp,
insn.arguments(), insn.arguments(),
@ -1392,7 +1394,7 @@ class LLVMIRGenerator:
lleltsptr = llglobal.bitcast(lleltsary.type.element.as_pointer()) lleltsptr = llglobal.bitcast(lleltsary.type.element.as_pointer())
llconst = ll.Constant(llty, [ll.Constant(lli32, len(llelts)), lleltsptr]) llconst = ll.Constant(llty, [ll.Constant(lli32, len(llelts)), lleltsptr])
return llconst return llconst
elif types.is_function(typ): elif types.is_rpc(typ) or types.is_function(typ):
# RPC and C functions have no runtime representation. # RPC and C functions have no runtime representation.
# We only get down this codepath for ARTIQ Python functions when they're # We only get down this codepath for ARTIQ Python functions when they're
# referenced from a constructor, and the value inside the constructor # referenced from a constructor, and the value inside the constructor

View File

@ -94,9 +94,6 @@ class TVar(Type):
else: else:
return self.find().fold(accum, fn) return self.find().fold(accum, fn)
def map(self, fn):
return fn(self)
def __repr__(self): def __repr__(self):
if self.parent is self: if self.parent is self:
return "<artiq.compiler.types.TVar %d>" % id(self) return "<artiq.compiler.types.TVar %d>" % id(self)
@ -141,21 +138,6 @@ class TMono(Type):
accum = self.params[param].fold(accum, fn) accum = self.params[param].fold(accum, fn)
return fn(accum, self) return fn(accum, self)
def map(self, fn):
params = OrderedDict()
for param in self.params:
params[param] = self.params[param].map(fn)
attributes = OrderedDict()
for attr in self.attributes:
attributes[attr] = self.attributes[attr].map(fn)
self_copy = self.__class__.__new__(self.__class__)
self_copy.name = self.name
self_copy.params = params
self_copy.attributes = attributes
return fn(self_copy)
def __repr__(self): def __repr__(self):
return "artiq.compiler.types.TMono(%s, %s)" % (repr(self.name), repr(self.params)) return "artiq.compiler.types.TMono(%s, %s)" % (repr(self.name), repr(self.params))
@ -202,9 +184,6 @@ class TTuple(Type):
accum = elt.fold(accum, fn) accum = elt.fold(accum, fn)
return fn(accum, self) return fn(accum, self)
def map(self, fn):
return fn(TTuple(list(map(lambda elt: elt.map(fn), self.elts))))
def __repr__(self): def __repr__(self):
return "artiq.compiler.types.TTuple(%s)" % repr(self.elts) return "artiq.compiler.types.TTuple(%s)" % repr(self.elts)
@ -276,23 +255,6 @@ class TFunction(Type):
accum = self.ret.fold(accum, fn) accum = self.ret.fold(accum, fn)
return fn(accum, self) return fn(accum, self)
def _map_args(self, fn):
args = OrderedDict()
for arg in self.args:
args[arg] = self.args[arg].map(fn)
optargs = OrderedDict()
for optarg in self.optargs:
optargs[optarg] = self.optargs[optarg].map(fn)
return args, optargs, self.ret.map(fn)
def map(self, fn):
args, optargs, ret = self._map_args(fn)
self_copy = TFunction(args, optargs, ret)
self_copy.delay = self.delay.map(fn)
return fn(self_copy)
def __repr__(self): def __repr__(self):
return "artiq.compiler.types.TFunction({}, {}, {})".format( return "artiq.compiler.types.TFunction({}, {}, {})".format(
repr(self.args), repr(self.optargs), repr(self.ret)) repr(self.args), repr(self.optargs), repr(self.ret))
@ -308,35 +270,6 @@ class TFunction(Type):
def __hash__(self): def __hash__(self):
return hash((_freeze(self.args), _freeze(self.optargs), self.ret)) return hash((_freeze(self.args), _freeze(self.optargs), self.ret))
class TRPCFunction(TFunction):
"""
A function type of a remote function.
:ivar service: (int) RPC service number
"""
attributes = OrderedDict()
def __init__(self, args, optargs, ret, service):
super().__init__(args, optargs, ret)
self.service = service
self.delay = TFixedDelay(iodelay.Const(0))
def unify(self, other):
if isinstance(other, TRPCFunction) and \
self.service == other.service:
super().unify(other)
elif isinstance(other, TVar):
other.unify(self)
else:
raise UnificationError(self, other)
def map(self, fn):
args, optargs, ret = self._map_args(fn)
self_copy = TRPCFunction(args, optargs, ret, self.service)
self_copy.delay = self.delay.map(fn)
return fn(self_copy)
class TCFunction(TFunction): class TCFunction(TFunction):
""" """
A function type of a runtime-provided C function. A function type of a runtime-provided C function.
@ -368,11 +301,49 @@ class TCFunction(TFunction):
else: else:
raise UnificationError(self, other) raise UnificationError(self, other)
def map(self, fn): class TRPC(Type):
args, _optargs, ret = self._map_args(fn) """
self_copy = TCFunction(args, ret, self.name) A type of a remote call.
self_copy.delay = self.delay.map(fn)
return fn(self_copy) :ivar ret: (:class:`Type`)
return type
:ivar service: (int) RPC service number
"""
attributes = OrderedDict()
def __init__(self, ret, service):
assert isinstance(ret, Type)
self.ret, self.service = ret, service
def find(self):
return self
def unify(self, other):
if isinstance(other, TRPC) and \
self.service == other.service:
self.ret.unify(other.ret)
elif isinstance(other, TVar):
other.unify(self)
else:
raise UnificationError(self, other)
def fold(self, accum, fn):
accum = self.ret.fold(accum, fn)
return fn(accum, self)
def __repr__(self):
return "artiq.compiler.types.TRPC({})".format(repr(self.ret))
def __eq__(self, other):
return isinstance(other, TRPC) and \
self.service == other.service
def __ne__(self, other):
return not (self == other)
def __hash__(self):
return hash(self.service)
class TBuiltin(Type): class TBuiltin(Type):
""" """
@ -395,9 +366,6 @@ class TBuiltin(Type):
def fold(self, accum, fn): def fold(self, accum, fn):
return fn(accum, self) return fn(accum, self)
def map(self, fn):
return fn(self)
def __repr__(self): def __repr__(self):
return "artiq.compiler.types.{}({})".format(type(self).__name__, repr(self.name)) return "artiq.compiler.types.{}({})".format(type(self).__name__, repr(self.name))
@ -490,9 +458,6 @@ class TValue(Type):
def fold(self, accum, fn): def fold(self, accum, fn):
return fn(accum, self) return fn(accum, self)
def map(self, fn):
return fn(self)
def __repr__(self): def __repr__(self):
return "artiq.compiler.types.TValue(%s)" % repr(self.value) return "artiq.compiler.types.TValue(%s)" % repr(self.value)
@ -543,10 +508,6 @@ class TDelay(Type):
# delay types do not participate in folding # delay types do not participate in folding
pass pass
def map(self, fn):
# or mapping
return self
def __eq__(self, other): def __eq__(self, other):
return isinstance(other, TDelay) and \ return isinstance(other, TDelay) and \
(self.duration == other.duration and \ (self.duration == other.duration and \
@ -570,18 +531,6 @@ def TFixedDelay(duration):
return TDelay(duration, None) return TDelay(duration, None)
def instantiate(typ):
tvar_map = dict()
def mapper(typ):
typ = typ.find()
if is_var(typ):
if typ not in tvar_map:
tvar_map[typ] = TVar()
return tvar_map[typ]
return typ
return typ.map(mapper)
def is_var(typ): def is_var(typ):
return isinstance(typ.find(), TVar) return isinstance(typ.find(), TVar)
@ -616,8 +565,8 @@ def _is_pointer(typ):
def is_function(typ): def is_function(typ):
return isinstance(typ.find(), TFunction) return isinstance(typ.find(), TFunction)
def is_rpc_function(typ): def is_rpc(typ):
return isinstance(typ.find(), TRPCFunction) return isinstance(typ.find(), TRPC)
def is_c_function(typ, name=None): def is_c_function(typ, name=None):
typ = typ.find() typ = typ.find()
@ -732,7 +681,7 @@ class TypePrinter(object):
return "(%s,)" % self.name(typ.elts[0], depth + 1) return "(%s,)" % self.name(typ.elts[0], depth + 1)
else: else:
return "(%s)" % ", ".join([self.name(typ, depth + 1) for typ in typ.elts]) return "(%s)" % ", ".join([self.name(typ, depth + 1) for typ in typ.elts])
elif isinstance(typ, (TFunction, TRPCFunction, TCFunction)): elif isinstance(typ, (TFunction, TCFunction)):
args = [] args = []
args += [ "%s:%s" % (arg, self.name(typ.args[arg], depth + 1)) args += [ "%s:%s" % (arg, self.name(typ.args[arg], depth + 1))
for arg in typ.args] for arg in typ.args]
@ -746,12 +695,12 @@ class TypePrinter(object):
elif not (delay.is_fixed() and iodelay.is_zero(delay.duration)): elif not (delay.is_fixed() and iodelay.is_zero(delay.duration)):
signature += " " + self.name(delay, depth + 1) signature += " " + self.name(delay, depth + 1)
if isinstance(typ, TRPCFunction):
return "[rpc #{}]{}".format(typ.service, signature)
if isinstance(typ, TCFunction): if isinstance(typ, TCFunction):
return "[ffi {}]{}".format(repr(typ.name), signature) return "[ffi {}]{}".format(repr(typ.name), signature)
elif isinstance(typ, TFunction): elif isinstance(typ, TFunction):
return signature return signature
elif isinstance(typ, TRPC):
return "[rpc #{}](...)->{}".format(typ.service, self.name(typ.ret, depth + 1))
elif isinstance(typ, TBuiltinFunction): elif isinstance(typ, TBuiltinFunction):
return "<function {}>".format(typ.name) return "<function {}>".format(typ.name)
elif isinstance(typ, (TConstructor, TExceptionConstructor)): elif isinstance(typ, (TConstructor, TExceptionConstructor)):

View File

@ -30,7 +30,7 @@ class MonomorphismValidator(algorithm.Visitor):
super().generic_visit(node) super().generic_visit(node)
if isinstance(node, asttyped.commontyped): if isinstance(node, asttyped.commontyped):
if types.is_polymorphic(node.type) and not types.is_rpc_function(node.type): if types.is_polymorphic(node.type):
note = diagnostic.Diagnostic("note", note = diagnostic.Diagnostic("note",
"the expression has type {type}", "the expression has type {type}",
{"type": types.TypePrinter().name(node.type)}, {"type": types.TypePrinter().name(node.type)},

View File

@ -3,6 +3,7 @@ import logging
import traceback import traceback
from enum import Enum from enum import Enum
from fractions import Fraction from fractions import Fraction
from collections import namedtuple
from artiq.coredevice import exceptions from artiq.coredevice import exceptions
from artiq.language.core import int as wrapping_int from artiq.language.core import int as wrapping_int
@ -62,6 +63,9 @@ class RPCReturnValueError(ValueError):
pass pass
RPCKeyword = namedtuple('RPCKeyword', ['name', 'value'])
class CommGeneric: class CommGeneric:
def __init__(self): def __init__(self):
self._read_type = self._write_type = None self._read_type = self._write_type = None
@ -233,7 +237,8 @@ class CommGeneric:
raise UnsupportedDevice("Unsupported runtime ID: {}" raise UnsupportedDevice("Unsupported runtime ID: {}"
.format(runtime_id)) .format(runtime_id))
gateware_version = self._read_chunk(self._read_length).decode("utf-8") gateware_version = self._read_chunk(self._read_length).decode("utf-8")
if gateware_version != software_version: if gateware_version != software_version and \
gateware_version + ".dirty" != software_version:
logger.warning("Mismatch between gateware (%s) " logger.warning("Mismatch between gateware (%s) "
"and software (%s) versions", "and software (%s) versions",
gateware_version, software_version) gateware_version, software_version)
@ -302,7 +307,6 @@ class CommGeneric:
logger.debug("running kernel") logger.debug("running kernel")
_rpc_sentinel = object() _rpc_sentinel = object()
_rpc_undefined = object()
# See session.c:{send,receive}_rpc_value and llvm_ir_generator.py:_rpc_tag. # See session.c:{send,receive}_rpc_value and llvm_ir_generator.py:_rpc_tag.
def _receive_rpc_value(self, object_map): def _receive_rpc_value(self, object_map):
@ -336,27 +340,23 @@ class CommGeneric:
stop = self._receive_rpc_value(object_map) stop = self._receive_rpc_value(object_map)
step = self._receive_rpc_value(object_map) step = self._receive_rpc_value(object_map)
return range(start, stop, step) return range(start, stop, step)
elif tag == "o": elif tag == "k":
present = self._read_int8() name = self._read_string()
if present: value = self._receive_rpc_value(object_map)
return self._receive_rpc_value(object_map) return RPCKeyword(name, value)
else:
return self._rpc_undefined
elif tag == "O": elif tag == "O":
return object_map.retrieve(self._read_int32()) return object_map.retrieve(self._read_int32())
else: else:
raise IOError("Unknown RPC value tag: {}".format(repr(tag))) raise IOError("Unknown RPC value tag: {}".format(repr(tag)))
def _receive_rpc_args(self, object_map, defaults): def _receive_rpc_args(self, object_map, defaults):
args = [] args, kwargs = [], {}
default_arg_num = 0
while True: while True:
value = self._receive_rpc_value(object_map) value = self._receive_rpc_value(object_map)
if value is self._rpc_sentinel: if value is self._rpc_sentinel:
return args return args, kwargs
elif value is self._rpc_undefined: elif isinstance(value, RPCKeyword):
args.append(defaults[default_arg_num]) kwargs[value.name] = value.value
default_arg_num += 1
else: else:
args.append(value) args.append(value)
@ -447,13 +447,13 @@ class CommGeneric:
else: else:
service = object_map.retrieve(service_id) service = object_map.retrieve(service_id)
arguments = self._receive_rpc_args(object_map, service.__defaults__) args, kwargs = self._receive_rpc_args(object_map, service.__defaults__)
return_tags = self._read_bytes() return_tags = self._read_bytes()
logger.debug("rpc service: [%d]%r %r -> %s", service_id, service, arguments, return_tags) logger.debug("rpc service: [%d]%r %r %r -> %s", service_id, service, args, kwargs, return_tags)
try: try:
result = service(*arguments) result = service(*args, **kwargs)
logger.debug("rpc service: %d %r == %r", service_id, arguments, result) logger.debug("rpc service: %d %r %r == %r", service_id, args, kwargs, result)
if service_id != 0: if service_id != 0:
self._write_header(_H2DMsgType.RPC_REPLY) self._write_header(_H2DMsgType.RPC_REPLY)
@ -461,7 +461,7 @@ class CommGeneric:
self._send_rpc_value(bytearray(return_tags), result, result, service) self._send_rpc_value(bytearray(return_tags), result, result, service)
self._write_flush() self._write_flush()
except Exception as exn: except Exception as exn:
logger.debug("rpc service: %d %r ! %r", service_id, arguments, exn) logger.debug("rpc service: %d %r %r ! %r", service_id, args, kwargs, exn)
self._write_header(_H2DMsgType.RPC_EXCEPTION) self._write_header(_H2DMsgType.RPC_EXCEPTION)
@ -490,7 +490,13 @@ class CommGeneric:
for index in range(3): for index in range(3):
self._write_int64(0) self._write_int64(0)
(_, (filename, line, function, _), ) = traceback.extract_tb(exn.__traceback__, 2) tb = traceback.extract_tb(exn.__traceback__, 2)
if len(tb) == 2:
(_, (filename, line, function, _), ) = tb
elif len(tb) == 1:
((filename, line, function, _), ) = tb
else:
assert False
self._write_string(filename) self._write_string(filename)
self._write_int32(line) self._write_int32(line)
self._write_int32(-1) # column not known self._write_int32(-1) # column not known

View File

@ -862,21 +862,16 @@ static int send_rpc_value(const char **tag, void **value)
break; break;
} }
case 'o': { // option(inner='a) case 'k': { // keyword(value='a)
struct { int8_t present; struct {} contents; } *option = *value; struct { const char *name; struct {} contents; } *option = *value;
void *contents = &option->contents; void *contents = &option->contents;
if(!out_packet_int8(option->present)) if(!out_packet_string(option->name))
return 0; return 0;
// option never appears in composite types, so we don't have // keyword never appears in composite types, so we don't have
// to accurately advance *value. // to accurately advance *value.
if(option->present) { return send_rpc_value(tag, &contents);
return send_rpc_value(tag, &contents);
} else {
skip_rpc_value(tag);
break;
}
} }
case 'O': { // host object case 'O': { // host object

View File

@ -59,6 +59,56 @@ class DefaultArgTest(ExperimentCase):
self.assertEqual(exp.run(), 42) self.assertEqual(exp.run(), 42)
class _RPC(EnvExperiment):
def build(self):
self.setattr_device("core")
def args(self, *args) -> TInt32:
return len(args)
def kwargs(self, x="", **kwargs) -> TInt32:
return len(kwargs)
@kernel
def args0(self):
return self.args()
@kernel
def args1(self):
return self.args("A")
@kernel
def args2(self):
return self.args("A", 1)
@kernel
def kwargs0(self):
return self.kwargs()
@kernel
def kwargs1(self):
return self.kwargs(a="A")
@kernel
def kwargs2(self):
return self.kwargs(a="A", b=1)
@kernel
def args1kwargs2(self):
return self.kwargs("X", a="A", b=1)
class RPCTest(ExperimentCase):
def test_args(self):
exp = self.create(_RPC)
self.assertEqual(exp.args0(), 0)
self.assertEqual(exp.args1(), 1)
self.assertEqual(exp.args2(), 2)
self.assertEqual(exp.kwargs0(), 0)
self.assertEqual(exp.kwargs1(), 1)
self.assertEqual(exp.kwargs2(), 2)
self.assertEqual(exp.args1kwargs2(), 2)
class _Payload1MB(EnvExperiment): class _Payload1MB(EnvExperiment):
def build(self): def build(self):
self.setattr_device("core") self.setattr_device("core")

View File

@ -1,15 +0,0 @@
# RUN: %python -m artiq.compiler.testbench.embedding +diag %s >%t
# RUN: OutputCheck %s --file-to-check=%t
from artiq.language.core import *
from artiq.language.types import *
# CHECK-L: <synthesized>:1: error: cannot unify int(width='a) with str
# CHECK-L: ${LINE:+1}: note: expanded from here while trying to infer a type for an unannotated optional argument 'x' from its default value
def foo(x=[1,"x"]):
pass
@kernel
def entrypoint():
# CHECK-L: ${LINE:+1}: note: in function called remotely here
foo()