From 8d9a22f8da6b87f583c83fd3605f61e398dd13bf Mon Sep 17 00:00:00 2001 From: whitequark Date: Mon, 25 Apr 2016 22:05:32 +0000 Subject: [PATCH] compiler: don't typecheck RPCs except for return type. Fixes #260. --- artiq/compiler/builtins.py | 2 +- artiq/compiler/embedding.py | 69 ++++---- artiq/compiler/ir.py | 13 +- .../compiler/transforms/artiq_ir_generator.py | 74 +++++---- artiq/compiler/transforms/inferencer.py | 18 ++- .../compiler/transforms/iodelay_estimator.py | 61 ++++---- .../compiler/transforms/llvm_ir_generator.py | 22 +-- artiq/compiler/types.py | 147 ++++++------------ artiq/compiler/validators/monomorphism.py | 2 +- artiq/coredevice/comm_generic.py | 48 +++--- artiq/runtime/session.c | 15 +- artiq/test/coredevice/test_embedding.py | 50 ++++++ .../lit/embedding/error_rpc_default_unify.py | 15 -- 13 files changed, 269 insertions(+), 267 deletions(-) delete mode 100644 artiq/test/lit/embedding/error_rpc_default_unify.py diff --git a/artiq/compiler/builtins.py b/artiq/compiler/builtins.py index 47a89641c..c06f600d3 100644 --- a/artiq/compiler/builtins.py +++ b/artiq/compiler/builtins.py @@ -255,6 +255,6 @@ def is_allocated(typ): return not (is_none(typ) or is_bool(typ) or is_int(typ) or is_float(typ) or is_range(typ) or types._is_pointer(typ) or types.is_function(typ) or - types.is_c_function(typ) or types.is_rpc_function(typ) or + types.is_c_function(typ) or types.is_rpc(typ) or types.is_method(typ) or types.is_tuple(typ) or types.is_value(typ)) diff --git a/artiq/compiler/embedding.py b/artiq/compiler/embedding.py index 006d274ad..426728fd5 100644 --- a/artiq/compiler/embedding.py +++ b/artiq/compiler/embedding.py @@ -428,9 +428,8 @@ class StitchingInferencer(Inferencer): if attr_name not in attributes: # We just figured out what the type should be. Add it. attributes[attr_name] = attr_value_type - elif not types.is_rpc_function(attr_value_type): + else: # Does this conflict with an earlier guess? - # RPC function types are exempt because RPCs are dynamically typed. try: attributes[attr_name].unify(attr_value_type) except types.UnificationError as e: @@ -694,29 +693,22 @@ class Stitcher: # Let the rest of the program decide. return types.TVar() - def _quote_foreign_function(self, function, loc, syscall, flags): + def _quote_syscall(self, function, loc): signature = inspect.signature(function) arg_types = OrderedDict() optarg_types = OrderedDict() for param in signature.parameters.values(): - if param.kind not in (inspect.Parameter.POSITIONAL_ONLY, - inspect.Parameter.POSITIONAL_OR_KEYWORD): - # We pretend we don't see *args, kwpostargs=..., **kwargs. - # Since every method can be still invoked without any arguments - # going into *args and the slots after it, this is always safe, - # if sometimes constraining. - # - # Accepting POSITIONAL_ONLY is OK, because the compiler - # desugars the keyword arguments into positional ones internally. - continue + if param.kind != inspect.Parameter.POSITIONAL_OR_KEYWORD: + diag = diagnostic.Diagnostic("error", + "system calls must only use positional arguments; '{argument}' isn't", + {"argument": param.name}, + self._function_loc(function), + notes=self._call_site_note(loc, is_syscall=True)) + self.engine.process(diag) if param.default is inspect.Parameter.empty: - arg_types[param.name] = self._type_of_param(function, loc, param, - is_syscall=syscall is not None) - elif syscall is None: - optarg_types[param.name] = self._type_of_param(function, loc, param, - is_syscall=False) + arg_types[param.name] = self._type_of_param(function, loc, param, is_syscall=True) else: diag = diagnostic.Diagnostic("error", "system call argument '{argument}' must not have a default value", @@ -727,10 +719,8 @@ class Stitcher: if signature.return_annotation is not inspect.Signature.empty: ret_type = self._extract_annot(function, signature.return_annotation, - "return type", loc, is_syscall=syscall is not None) - elif syscall is None: - ret_type = builtins.TNone() - else: # syscall is not None + "return type", loc, is_syscall=True) + else: diag = diagnostic.Diagnostic("error", "system call must have a return type annotation", {}, self._function_loc(function), @@ -738,15 +728,23 @@ class Stitcher: self.engine.process(diag) ret_type = types.TVar() - if syscall is None: - function_type = types.TRPCFunction(arg_types, optarg_types, ret_type, - service=self.object_map.store(function)) - else: - function_type = types.TCFunction(arg_types, ret_type, - name=syscall, flags=flags) - + function_type = types.TCFunction(arg_types, ret_type, + name=function.artiq_embedded.syscall, + flags=function.artiq_embedded.flags) self.functions[function] = function_type + return function_type + def _quote_rpc(self, function, loc): + signature = inspect.signature(function) + + if signature.return_annotation is not inspect.Signature.empty: + ret_type = self._extract_annot(function, signature.return_annotation, + "return type", loc, is_syscall=False) + else: + ret_type = builtins.TNone() + + function_type = types.TRPC(ret_type, service=self.object_map.store(function)) + self.functions[function] = function_type return function_type def _quote_function(self, function, loc): @@ -780,9 +778,7 @@ class Stitcher: elif function.artiq_embedded.syscall is not None: # Insert a storage-less global whose type instructs the compiler # to perform a system call instead of a regular call. - self._quote_foreign_function(function, loc, - syscall=function.artiq_embedded.syscall, - flags=function.artiq_embedded.flags) + self._quote_syscall(function, loc) elif function.artiq_embedded.forbidden is not None: diag = diagnostic.Diagnostic("fatal", "this function cannot be called as an RPC", {}, @@ -792,14 +788,9 @@ class Stitcher: else: assert False else: - # Insert a storage-less global whose type instructs the compiler - # to perform an RPC instead of a regular call. - self._quote_foreign_function(function, loc, syscall=None, flags=None) + self._quote_rpc(function, loc) - function_type = self.functions[function] - if types.is_rpc_function(function_type): - function_type = types.instantiate(function_type) - return function_type + return self.functions[function] def _quote(self, value, loc): synthesizer = self._synthesizer(loc) diff --git a/artiq/compiler/ir.py b/artiq/compiler/ir.py index 8414c45c9..57936c48c 100644 --- a/artiq/compiler/ir.py +++ b/artiq/compiler/ir.py @@ -23,12 +23,19 @@ def is_basic_block(typ): return isinstance(typ, TBasicBlock) class TOption(types.TMono): - def __init__(self, inner): - super().__init__("option", {"inner": inner}) + def __init__(self, value): + super().__init__("option", {"value": value}) def is_option(typ): return isinstance(typ, TOption) +class TKeyword(types.TMono): + def __init__(self, value): + super().__init__("keyword", {"value": value}) + +def is_keyword(typ): + return isinstance(typ, TKeyword) + class TExceptionTypeInfo(types.TMono): def __init__(self): super().__init__("exntypeinfo") @@ -678,7 +685,7 @@ class GetAttr(Instruction): typ = obj.type.attributes[attr] else: typ = obj.type.constructor.attributes[attr] - if types.is_function(typ): + if types.is_function(typ) or types.is_rpc(typ): typ = types.TMethod(obj.type, typ) super().__init__([obj], typ, name) self.attr = attr diff --git a/artiq/compiler/transforms/artiq_ir_generator.py b/artiq/compiler/transforms/artiq_ir_generator.py index a42cbe213..c5b25de7a 100644 --- a/artiq/compiler/transforms/artiq_ir_generator.py +++ b/artiq/compiler/transforms/artiq_ir_generator.py @@ -922,9 +922,6 @@ class ARTIQIRGenerator(algorithm.Visitor): if self.current_assign is None: return self.append(ir.GetAttr(obj, node.attr, name="{}.FLD.{}".format(_readable_name(obj), node.attr))) - elif types.is_rpc_function(self.current_assign.type): - # RPC functions are just type-level markers - return self.append(ir.Builtin("nop", [], builtins.TNone())) else: return self.append(ir.SetAttr(obj, node.attr, self.current_assign)) @@ -1719,7 +1716,7 @@ class ARTIQIRGenerator(algorithm.Visitor): self.engine.process(diag) def _user_call(self, callee, positional, keywords, arg_exprs={}): - if types.is_function(callee.type): + if types.is_function(callee.type) or types.is_rpc(callee.type): func = callee self_arg = None fn_typ = callee.type @@ -1734,40 +1731,51 @@ class ARTIQIRGenerator(algorithm.Visitor): else: assert False - args = [None] * (len(fn_typ.args) + len(fn_typ.optargs)) - - for index, arg in enumerate(positional): - if index + offset < len(fn_typ.args): - args[index + offset] = arg + if types.is_rpc(fn_typ): + if self_arg is None: + args = positional else: - args[index + offset] = self.append(ir.Alloc([arg], ir.TOption(arg.type))) + args = [self_arg] + positional - for keyword in keywords: - arg = keywords[keyword] - if keyword in fn_typ.args: - for index, arg_name in enumerate(fn_typ.args): - if keyword == arg_name: - assert args[index] is None - args[index] = arg - break - elif keyword in fn_typ.optargs: - for index, optarg_name in enumerate(fn_typ.optargs): - if keyword == optarg_name: - assert args[len(fn_typ.args) + index] is None - args[len(fn_typ.args) + index] = \ - self.append(ir.Alloc([arg], ir.TOption(arg.type))) - break + for keyword in keywords: + arg = keywords[keyword] + args.append(self.append(ir.Alloc([ir.Constant(keyword, builtins.TStr()), arg], + ir.TKeyword(arg.type)))) + else: + args = [None] * (len(fn_typ.args) + len(fn_typ.optargs)) - for index, optarg_name in enumerate(fn_typ.optargs): - if args[len(fn_typ.args) + index] is None: - args[len(fn_typ.args) + index] = \ - self.append(ir.Alloc([], ir.TOption(fn_typ.optargs[optarg_name]))) + for index, arg in enumerate(positional): + if index + offset < len(fn_typ.args): + args[index + offset] = arg + else: + args[index + offset] = self.append(ir.Alloc([arg], ir.TOption(arg.type))) - if self_arg is not None: - assert args[0] is None - args[0] = self_arg + for keyword in keywords: + arg = keywords[keyword] + if keyword in fn_typ.args: + for index, arg_name in enumerate(fn_typ.args): + if keyword == arg_name: + assert args[index] is None + args[index] = arg + break + elif keyword in fn_typ.optargs: + for index, optarg_name in enumerate(fn_typ.optargs): + if keyword == optarg_name: + assert args[len(fn_typ.args) + index] is None + args[len(fn_typ.args) + index] = \ + self.append(ir.Alloc([arg], ir.TOption(arg.type))) + break - assert None not in args + for index, optarg_name in enumerate(fn_typ.optargs): + if args[len(fn_typ.args) + index] is None: + args[len(fn_typ.args) + index] = \ + self.append(ir.Alloc([], ir.TOption(fn_typ.optargs[optarg_name]))) + + if self_arg is not None: + assert args[0] is None + args[0] = self_arg + + assert None not in args if self.unwind_target is None: insn = self.append(ir.Call(func, args, arg_exprs)) diff --git a/artiq/compiler/transforms/inferencer.py b/artiq/compiler/transforms/inferencer.py index 460be6767..3d2bb0541 100644 --- a/artiq/compiler/transforms/inferencer.py +++ b/artiq/compiler/transforms/inferencer.py @@ -109,17 +109,11 @@ class Inferencer(algorithm.Visitor): ] attr_type = object_type.attributes[attr_name] - if types.is_rpc_function(attr_type): - attr_type = types.instantiate(attr_type) - self._unify(result_type, attr_type, loc, None, makenotes=makenotes, when=" for attribute '{}'".format(attr_name)) elif types.is_instance(object_type) and \ attr_name in object_type.constructor.attributes: attr_type = object_type.constructor.attributes[attr_name].find() - if types.is_rpc_function(attr_type): - attr_type = types.instantiate(attr_type) - if types.is_function(attr_type): # Convert to a method. if len(attr_type.args) < 1: @@ -155,6 +149,10 @@ class Inferencer(algorithm.Visitor): when=" while inferring the type for self argument") attr_type = types.TMethod(object_type, attr_type) + elif types.is_rpc(attr_type): + # Convert to a method. We don't have to bother typechecking + # the self argument, since for RPCs anything goes. + attr_type = types.TMethod(object_type, attr_type) if not types.is_var(attr_type): self._unify(result_type, attr_type, @@ -871,6 +869,10 @@ class Inferencer(algorithm.Visitor): return # not enough info yet elif types.is_builtin(typ): return self.visit_builtin_call(node) + elif types.is_rpc(typ): + self._unify(node.type, typ.ret, + node.loc, None) + return elif not (types.is_function(typ) or types.is_method(typ)): diag = diagnostic.Diagnostic("error", "cannot call this expression of type {type}", @@ -888,6 +890,10 @@ class Inferencer(algorithm.Visitor): typ = types.get_method_function(typ) if types.is_var(typ): return # not enough info yet + elif types.is_rpc(typ): + self._unify(node.type, typ.ret, + node.loc, None) + return typ_arity = typ.arity() - 1 typ_args = OrderedDict(list(typ.args.items())[1:]) diff --git a/artiq/compiler/transforms/iodelay_estimator.py b/artiq/compiler/transforms/iodelay_estimator.py index fff7470de..90bfefdb3 100644 --- a/artiq/compiler/transforms/iodelay_estimator.py +++ b/artiq/compiler/transforms/iodelay_estimator.py @@ -280,7 +280,7 @@ class IODelayEstimator(algorithm.Visitor): context="as an argument for delay_mu()") call_delay = value elif not types.is_builtin(typ): - if types.is_function(typ): + if types.is_function(typ) or types.is_rpc(typ): offset = 0 elif types.is_method(typ): offset = 1 @@ -288,35 +288,38 @@ class IODelayEstimator(algorithm.Visitor): else: assert False - delay = typ.find().delay.find() - if types.is_var(delay): - raise _UnknownDelay() - elif delay.is_indeterminate(): - note = diagnostic.Diagnostic("note", - "function called here", {}, - node.loc) - cause = delay.cause - cause = diagnostic.Diagnostic(cause.level, cause.reason, cause.arguments, - cause.location, cause.highlights, - cause.notes + [note]) - raise _IndeterminateDelay(cause) - elif delay.is_fixed(): - args = {} - for kw_node in node.keywords: - args[kw_node.arg] = kw_node.value - for arg_name, arg_node in zip(list(typ.args)[offset:], node.args): - args[arg_name] = arg_node - - free_vars = delay.duration.free_vars() - node.arg_exprs = { - arg: self.evaluate(args[arg], abort=abort, - context="in the expression for argument '{}' " - "that affects I/O delay".format(arg)) - for arg in free_vars - } - call_delay = delay.duration.fold(node.arg_exprs) + if types.is_rpc(typ): + call_delay = iodelay.Const(0) else: - assert False + delay = typ.find().delay.find() + if types.is_var(delay): + raise _UnknownDelay() + elif delay.is_indeterminate(): + note = diagnostic.Diagnostic("note", + "function called here", {}, + node.loc) + cause = delay.cause + cause = diagnostic.Diagnostic(cause.level, cause.reason, cause.arguments, + cause.location, cause.highlights, + cause.notes + [note]) + raise _IndeterminateDelay(cause) + elif delay.is_fixed(): + args = {} + for kw_node in node.keywords: + args[kw_node.arg] = kw_node.value + for arg_name, arg_node in zip(list(typ.args)[offset:], node.args): + args[arg_name] = arg_node + + free_vars = delay.duration.free_vars() + node.arg_exprs = { + arg: self.evaluate(args[arg], abort=abort, + context="in the expression for argument '{}' " + "that affects I/O delay".format(arg)) + for arg in free_vars + } + call_delay = delay.duration.fold(node.arg_exprs) + else: + assert False else: call_delay = iodelay.Const(0) diff --git a/artiq/compiler/transforms/llvm_ir_generator.py b/artiq/compiler/transforms/llvm_ir_generator.py index 816a7d6e8..acec6953a 100644 --- a/artiq/compiler/transforms/llvm_ir_generator.py +++ b/artiq/compiler/transforms/llvm_ir_generator.py @@ -177,7 +177,7 @@ class LLVMIRGenerator: typ = typ.find() if types.is_tuple(typ): return ll.LiteralStructType([self.llty_of_type(eltty) for eltty in typ.elts]) - elif types.is_rpc_function(typ) or types.is_c_function(typ): + elif types.is_rpc(typ) or types.is_c_function(typ): if for_return: return llvoid else: @@ -229,7 +229,9 @@ class LLVMIRGenerator: elif ir.is_basic_block(typ): return llptr elif ir.is_option(typ): - return ll.LiteralStructType([lli1, self.llty_of_type(typ.params["inner"])]) + return ll.LiteralStructType([lli1, self.llty_of_type(typ.params["value"])]) + elif ir.is_keyword(typ): + return ll.LiteralStructType([llptr, self.llty_of_type(typ.params["value"])]) elif ir.is_environment(typ): llty = self.llcontext.get_identified_type("env.{}".format(typ.env_name)) if llty.elements is None: @@ -618,7 +620,7 @@ class LLVMIRGenerator: size=llsize) llvalue = self.llbuilder.insert_value(llvalue, llalloc, 1, name=insn.name) return llvalue - elif not builtins.is_allocated(insn.type): + elif not builtins.is_allocated(insn.type) or ir.is_keyword(insn.type): llvalue = ll.Constant(self.llty_of_type(insn.type), ll.Undefined) for index, elt in enumerate(insn.operands): llvalue = self.llbuilder.insert_value(llvalue, self.map(elt), index) @@ -707,8 +709,8 @@ class LLVMIRGenerator: def get_global_closure(self, typ, attr): closure_type = typ.attributes[attr] assert types.is_constructor(typ) - assert types.is_function(closure_type) - if types.is_c_function(closure_type) or types.is_rpc_function(closure_type): + assert types.is_function(closure_type) or types.is_rpc(closure_type) + if types.is_c_function(closure_type) or types.is_rpc(closure_type): return None llty = self.llty_of_type(typ.attributes[attr]) @@ -1156,8 +1158,8 @@ class LLVMIRGenerator: elif builtins.is_range(typ): return b"r" + self._rpc_tag(builtins.get_iterable_elt(typ), error_handler) - elif ir.is_option(typ): - return b"o" + self._rpc_tag(typ.params["inner"], + elif ir.is_keyword(typ): + return b"k" + self._rpc_tag(typ.params["value"], error_handler) elif '__objectid__' in typ.attributes: return b"O" @@ -1271,7 +1273,7 @@ class LLVMIRGenerator: def process_Call(self, insn): functiontyp = insn.target_function().type - if types.is_rpc_function(functiontyp): + if types.is_rpc(functiontyp): return self._build_rpc(insn.target_function().loc, functiontyp, insn.arguments(), @@ -1303,7 +1305,7 @@ class LLVMIRGenerator: functiontyp = insn.target_function().type llnormalblock = self.map(insn.normal_target()) llunwindblock = self.map(insn.exception_target()) - if types.is_rpc_function(functiontyp): + if types.is_rpc(functiontyp): return self._build_rpc(insn.target_function().loc, functiontyp, insn.arguments(), @@ -1392,7 +1394,7 @@ class LLVMIRGenerator: lleltsptr = llglobal.bitcast(lleltsary.type.element.as_pointer()) llconst = ll.Constant(llty, [ll.Constant(lli32, len(llelts)), lleltsptr]) return llconst - elif types.is_function(typ): + elif types.is_rpc(typ) or types.is_function(typ): # RPC and C functions have no runtime representation. # We only get down this codepath for ARTIQ Python functions when they're # referenced from a constructor, and the value inside the constructor diff --git a/artiq/compiler/types.py b/artiq/compiler/types.py index 26fc24d95..cca1d15a9 100644 --- a/artiq/compiler/types.py +++ b/artiq/compiler/types.py @@ -94,9 +94,6 @@ class TVar(Type): else: return self.find().fold(accum, fn) - def map(self, fn): - return fn(self) - def __repr__(self): if self.parent is self: return "" % id(self) @@ -141,21 +138,6 @@ class TMono(Type): accum = self.params[param].fold(accum, fn) return fn(accum, self) - def map(self, fn): - params = OrderedDict() - for param in self.params: - params[param] = self.params[param].map(fn) - - attributes = OrderedDict() - for attr in self.attributes: - attributes[attr] = self.attributes[attr].map(fn) - - self_copy = self.__class__.__new__(self.__class__) - self_copy.name = self.name - self_copy.params = params - self_copy.attributes = attributes - return fn(self_copy) - def __repr__(self): return "artiq.compiler.types.TMono(%s, %s)" % (repr(self.name), repr(self.params)) @@ -202,9 +184,6 @@ class TTuple(Type): accum = elt.fold(accum, fn) return fn(accum, self) - def map(self, fn): - return fn(TTuple(list(map(lambda elt: elt.map(fn), self.elts)))) - def __repr__(self): return "artiq.compiler.types.TTuple(%s)" % repr(self.elts) @@ -276,23 +255,6 @@ class TFunction(Type): accum = self.ret.fold(accum, fn) return fn(accum, self) - def _map_args(self, fn): - args = OrderedDict() - for arg in self.args: - args[arg] = self.args[arg].map(fn) - - optargs = OrderedDict() - for optarg in self.optargs: - optargs[optarg] = self.optargs[optarg].map(fn) - - return args, optargs, self.ret.map(fn) - - def map(self, fn): - args, optargs, ret = self._map_args(fn) - self_copy = TFunction(args, optargs, ret) - self_copy.delay = self.delay.map(fn) - return fn(self_copy) - def __repr__(self): return "artiq.compiler.types.TFunction({}, {}, {})".format( repr(self.args), repr(self.optargs), repr(self.ret)) @@ -308,35 +270,6 @@ class TFunction(Type): def __hash__(self): return hash((_freeze(self.args), _freeze(self.optargs), self.ret)) -class TRPCFunction(TFunction): - """ - A function type of a remote function. - - :ivar service: (int) RPC service number - """ - - attributes = OrderedDict() - - def __init__(self, args, optargs, ret, service): - super().__init__(args, optargs, ret) - self.service = service - self.delay = TFixedDelay(iodelay.Const(0)) - - def unify(self, other): - if isinstance(other, TRPCFunction) and \ - self.service == other.service: - super().unify(other) - elif isinstance(other, TVar): - other.unify(self) - else: - raise UnificationError(self, other) - - def map(self, fn): - args, optargs, ret = self._map_args(fn) - self_copy = TRPCFunction(args, optargs, ret, self.service) - self_copy.delay = self.delay.map(fn) - return fn(self_copy) - class TCFunction(TFunction): """ A function type of a runtime-provided C function. @@ -368,11 +301,49 @@ class TCFunction(TFunction): else: raise UnificationError(self, other) - def map(self, fn): - args, _optargs, ret = self._map_args(fn) - self_copy = TCFunction(args, ret, self.name) - self_copy.delay = self.delay.map(fn) - return fn(self_copy) +class TRPC(Type): + """ + A type of a remote call. + + :ivar ret: (:class:`Type`) + return type + :ivar service: (int) RPC service number + """ + + attributes = OrderedDict() + + def __init__(self, ret, service): + assert isinstance(ret, Type) + self.ret, self.service = ret, service + + def find(self): + return self + + def unify(self, other): + if isinstance(other, TRPC) and \ + self.service == other.service: + self.ret.unify(other.ret) + elif isinstance(other, TVar): + other.unify(self) + else: + raise UnificationError(self, other) + + def fold(self, accum, fn): + accum = self.ret.fold(accum, fn) + return fn(accum, self) + + def __repr__(self): + return "artiq.compiler.types.TRPC({})".format(repr(self.ret)) + + def __eq__(self, other): + return isinstance(other, TRPC) and \ + self.service == other.service + + def __ne__(self, other): + return not (self == other) + + def __hash__(self): + return hash(self.service) class TBuiltin(Type): """ @@ -395,9 +366,6 @@ class TBuiltin(Type): def fold(self, accum, fn): return fn(accum, self) - def map(self, fn): - return fn(self) - def __repr__(self): return "artiq.compiler.types.{}({})".format(type(self).__name__, repr(self.name)) @@ -490,9 +458,6 @@ class TValue(Type): def fold(self, accum, fn): return fn(accum, self) - def map(self, fn): - return fn(self) - def __repr__(self): return "artiq.compiler.types.TValue(%s)" % repr(self.value) @@ -543,10 +508,6 @@ class TDelay(Type): # delay types do not participate in folding pass - def map(self, fn): - # or mapping - return self - def __eq__(self, other): return isinstance(other, TDelay) and \ (self.duration == other.duration and \ @@ -570,18 +531,6 @@ def TFixedDelay(duration): return TDelay(duration, None) -def instantiate(typ): - tvar_map = dict() - def mapper(typ): - typ = typ.find() - if is_var(typ): - if typ not in tvar_map: - tvar_map[typ] = TVar() - return tvar_map[typ] - return typ - - return typ.map(mapper) - def is_var(typ): return isinstance(typ.find(), TVar) @@ -616,8 +565,8 @@ def _is_pointer(typ): def is_function(typ): return isinstance(typ.find(), TFunction) -def is_rpc_function(typ): - return isinstance(typ.find(), TRPCFunction) +def is_rpc(typ): + return isinstance(typ.find(), TRPC) def is_c_function(typ, name=None): typ = typ.find() @@ -732,7 +681,7 @@ class TypePrinter(object): return "(%s,)" % self.name(typ.elts[0], depth + 1) else: return "(%s)" % ", ".join([self.name(typ, depth + 1) for typ in typ.elts]) - elif isinstance(typ, (TFunction, TRPCFunction, TCFunction)): + elif isinstance(typ, (TFunction, TCFunction)): args = [] args += [ "%s:%s" % (arg, self.name(typ.args[arg], depth + 1)) for arg in typ.args] @@ -746,12 +695,12 @@ class TypePrinter(object): elif not (delay.is_fixed() and iodelay.is_zero(delay.duration)): signature += " " + self.name(delay, depth + 1) - if isinstance(typ, TRPCFunction): - return "[rpc #{}]{}".format(typ.service, signature) if isinstance(typ, TCFunction): return "[ffi {}]{}".format(repr(typ.name), signature) elif isinstance(typ, TFunction): return signature + elif isinstance(typ, TRPC): + return "[rpc #{}](...)->{}".format(typ.service, self.name(typ.ret, depth + 1)) elif isinstance(typ, TBuiltinFunction): return "".format(typ.name) elif isinstance(typ, (TConstructor, TExceptionConstructor)): diff --git a/artiq/compiler/validators/monomorphism.py b/artiq/compiler/validators/monomorphism.py index f30ac5288..0911deb40 100644 --- a/artiq/compiler/validators/monomorphism.py +++ b/artiq/compiler/validators/monomorphism.py @@ -30,7 +30,7 @@ class MonomorphismValidator(algorithm.Visitor): super().generic_visit(node) if isinstance(node, asttyped.commontyped): - if types.is_polymorphic(node.type) and not types.is_rpc_function(node.type): + if types.is_polymorphic(node.type): note = diagnostic.Diagnostic("note", "the expression has type {type}", {"type": types.TypePrinter().name(node.type)}, diff --git a/artiq/coredevice/comm_generic.py b/artiq/coredevice/comm_generic.py index b099acec0..35ef9b3ce 100644 --- a/artiq/coredevice/comm_generic.py +++ b/artiq/coredevice/comm_generic.py @@ -3,6 +3,7 @@ import logging import traceback from enum import Enum from fractions import Fraction +from collections import namedtuple from artiq.coredevice import exceptions from artiq.language.core import int as wrapping_int @@ -62,6 +63,9 @@ class RPCReturnValueError(ValueError): pass +RPCKeyword = namedtuple('RPCKeyword', ['name', 'value']) + + class CommGeneric: def __init__(self): self._read_type = self._write_type = None @@ -229,7 +233,8 @@ class CommGeneric: raise UnsupportedDevice("Unsupported runtime ID: {}" .format(runtime_id)) gateware_version = self._read_chunk(self._read_length).decode("utf-8") - if gateware_version != software_version: + if gateware_version != software_version and \ + gateware_version + ".dirty" != software_version: logger.warning("Mismatch between gateware (%s) " "and software (%s) versions", gateware_version, software_version) @@ -298,7 +303,6 @@ class CommGeneric: logger.debug("running kernel") _rpc_sentinel = object() - _rpc_undefined = object() # See session.c:{send,receive}_rpc_value and llvm_ir_generator.py:_rpc_tag. def _receive_rpc_value(self, object_map): @@ -332,27 +336,23 @@ class CommGeneric: stop = self._receive_rpc_value(object_map) step = self._receive_rpc_value(object_map) return range(start, stop, step) - elif tag == "o": - present = self._read_int8() - if present: - return self._receive_rpc_value(object_map) - else: - return self._rpc_undefined + elif tag == "k": + name = self._read_string() + value = self._receive_rpc_value(object_map) + return RPCKeyword(name, value) elif tag == "O": return object_map.retrieve(self._read_int32()) else: raise IOError("Unknown RPC value tag: {}".format(repr(tag))) def _receive_rpc_args(self, object_map, defaults): - args = [] - default_arg_num = 0 + args, kwargs = [], {} while True: value = self._receive_rpc_value(object_map) if value is self._rpc_sentinel: - return args - elif value is self._rpc_undefined: - args.append(defaults[default_arg_num]) - default_arg_num += 1 + return args, kwargs + elif isinstance(value, RPCKeyword): + kwargs[value.name] = value.value else: args.append(value) @@ -443,13 +443,13 @@ class CommGeneric: else: service = object_map.retrieve(service_id) - arguments = self._receive_rpc_args(object_map, service.__defaults__) - return_tags = self._read_bytes() - logger.debug("rpc service: [%d]%r %r -> %s", service_id, service, arguments, return_tags) + args, kwargs = self._receive_rpc_args(object_map, service.__defaults__) + return_tags = self._read_bytes() + logger.debug("rpc service: [%d]%r %r %r -> %s", service_id, service, args, kwargs, return_tags) try: - result = service(*arguments) - logger.debug("rpc service: %d %r == %r", service_id, arguments, result) + result = service(*args, **kwargs) + logger.debug("rpc service: %d %r %r == %r", service_id, args, kwargs, result) if service_id != 0: self._write_header(_H2DMsgType.RPC_REPLY) @@ -457,7 +457,7 @@ class CommGeneric: self._send_rpc_value(bytearray(return_tags), result, result, service) self._write_flush() except Exception as exn: - logger.debug("rpc service: %d %r ! %r", service_id, arguments, exn) + logger.debug("rpc service: %d %r %r ! %r", service_id, args, kwargs, exn) self._write_header(_H2DMsgType.RPC_EXCEPTION) @@ -486,7 +486,13 @@ class CommGeneric: for index in range(3): self._write_int64(0) - (_, (filename, line, function, _), ) = traceback.extract_tb(exn.__traceback__, 2) + tb = traceback.extract_tb(exn.__traceback__, 2) + if len(tb) == 2: + (_, (filename, line, function, _), ) = tb + elif len(tb) == 1: + ((filename, line, function, _), ) = tb + else: + assert False self._write_string(filename) self._write_int32(line) self._write_int32(-1) # column not known diff --git a/artiq/runtime/session.c b/artiq/runtime/session.c index f2f8a8881..517614e90 100644 --- a/artiq/runtime/session.c +++ b/artiq/runtime/session.c @@ -862,21 +862,16 @@ static int send_rpc_value(const char **tag, void **value) break; } - case 'o': { // option(inner='a) - struct { int8_t present; struct {} contents; } *option = *value; + case 'k': { // keyword(value='a) + struct { const char *name; struct {} contents; } *option = *value; void *contents = &option->contents; - if(!out_packet_int8(option->present)) + if(!out_packet_string(option->name)) return 0; - // option never appears in composite types, so we don't have + // keyword never appears in composite types, so we don't have // to accurately advance *value. - if(option->present) { - return send_rpc_value(tag, &contents); - } else { - skip_rpc_value(tag); - break; - } + return send_rpc_value(tag, &contents); } case 'O': { // host object diff --git a/artiq/test/coredevice/test_embedding.py b/artiq/test/coredevice/test_embedding.py index 5debdb2cf..867dd5fa3 100644 --- a/artiq/test/coredevice/test_embedding.py +++ b/artiq/test/coredevice/test_embedding.py @@ -59,6 +59,56 @@ class DefaultArgTest(ExperimentCase): self.assertEqual(exp.run(), 42) +class _RPC(EnvExperiment): + def build(self): + self.setattr_device("core") + + def args(self, *args) -> TInt32: + return len(args) + + def kwargs(self, x="", **kwargs) -> TInt32: + return len(kwargs) + + @kernel + def args0(self): + return self.args() + + @kernel + def args1(self): + return self.args("A") + + @kernel + def args2(self): + return self.args("A", 1) + + @kernel + def kwargs0(self): + return self.kwargs() + + @kernel + def kwargs1(self): + return self.kwargs(a="A") + + @kernel + def kwargs2(self): + return self.kwargs(a="A", b=1) + + @kernel + def args1kwargs2(self): + return self.kwargs("X", a="A", b=1) + +class RPCTest(ExperimentCase): + def test_args(self): + exp = self.create(_RPC) + self.assertEqual(exp.args0(), 0) + self.assertEqual(exp.args1(), 1) + self.assertEqual(exp.args2(), 2) + self.assertEqual(exp.kwargs0(), 0) + self.assertEqual(exp.kwargs1(), 1) + self.assertEqual(exp.kwargs2(), 2) + self.assertEqual(exp.args1kwargs2(), 2) + + class _Payload1MB(EnvExperiment): def build(self): self.setattr_device("core") diff --git a/artiq/test/lit/embedding/error_rpc_default_unify.py b/artiq/test/lit/embedding/error_rpc_default_unify.py deleted file mode 100644 index 92da137b7..000000000 --- a/artiq/test/lit/embedding/error_rpc_default_unify.py +++ /dev/null @@ -1,15 +0,0 @@ -# RUN: %python -m artiq.compiler.testbench.embedding +diag %s >%t -# RUN: OutputCheck %s --file-to-check=%t - -from artiq.language.core import * -from artiq.language.types import * - -# CHECK-L: :1: error: cannot unify int(width='a) with str -# CHECK-L: ${LINE:+1}: note: expanded from here while trying to infer a type for an unannotated optional argument 'x' from its default value -def foo(x=[1,"x"]): - pass - -@kernel -def entrypoint(): - # CHECK-L: ${LINE:+1}: note: in function called remotely here - foo()