From b26af5df60c8e131546a9e14886aa7af78d9779b Mon Sep 17 00:00:00 2001 From: whitequark Date: Sun, 9 Aug 2015 02:17:19 +0300 Subject: [PATCH] Implement sending RPCs. --- artiq/compiler/builtins.py | 4 +- artiq/compiler/embedding.py | 125 ++++++++++-- artiq/compiler/module.py | 19 +- .../compiler/transforms/llvm_ir_generator.py | 98 +++++++++- artiq/compiler/types.py | 34 +++- artiq/coredevice/comm_generic.py | 41 ++-- artiq/coredevice/core.py | 4 +- soc/runtime/ksupport.c | 35 ++-- soc/runtime/ksupport.h | 2 +- soc/runtime/messages.h | 19 +- soc/runtime/session.c | 183 ++++++++++++------ 11 files changed, 433 insertions(+), 131 deletions(-) diff --git a/artiq/compiler/builtins.py b/artiq/compiler/builtins.py index 86a356ef3..4a8280631 100644 --- a/artiq/compiler/builtins.py +++ b/artiq/compiler/builtins.py @@ -163,7 +163,7 @@ def is_bool(typ): def is_int(typ, width=None): if width is not None: - return types.is_mono(typ, "int", {"width": width}) + return types.is_mono(typ, "int", width=width) else: return types.is_mono(typ, "int") @@ -184,7 +184,7 @@ def is_numeric(typ): def is_list(typ, elt=None): if elt is not None: - return types.is_mono(typ, "list", {"elt": elt}) + return types.is_mono(typ, "list", elt=elt) else: return types.is_mono(typ, "list") diff --git a/artiq/compiler/embedding.py b/artiq/compiler/embedding.py index 4a744131d..bf1ed49d0 100644 --- a/artiq/compiler/embedding.py +++ b/artiq/compiler/embedding.py @@ -5,10 +5,13 @@ the references to the host objects and translates the functions annotated as ``@kernel`` when they are referenced. """ -import inspect, os +import os, re, linecache, inspect +from collections import OrderedDict + from pythonparser import ast, source, diagnostic, parse_buffer + from . import types, builtins, asttyped, prelude -from .transforms import ASTTypedRewriter, Inferencer +from .transforms import ASTTypedRewriter, Inferencer, IntMonomorphizer class ASTSynthesizer: @@ -45,6 +48,9 @@ class ASTSynthesizer: typ = builtins.TFloat() return asttyped.NumT(n=value, ctx=None, type=typ, loc=self._add(repr(value))) + elif isinstance(value, str): + return asttyped.StrT(s=value, ctx=None, type=builtins.TStr(), + loc=self._add(repr(value))) elif isinstance(value, list): begin_loc = self._add("[") elts = [] @@ -123,7 +129,7 @@ class StitchingASTTypedRewriter(ASTTypedRewriter): if inspect.isfunction(value): # It's a function. We need to translate the function and insert # a reference to it. - function_name = self.quote_function(value) + function_name = self.quote_function(value, node.loc) return asttyped.NameT(id=function_name, ctx=None, type=self.globals[function_name], loc=node.loc) @@ -154,7 +160,19 @@ class Stitcher: self.functions = {} + self.next_rpc = 0 self.rpc_map = {} + self.inverse_rpc_map = {} + + def _map(self, obj): + obj_id = id(obj) + if obj_id in self.inverse_rpc_map: + return self.inverse_rpc_map[obj_id] + + self.next_rpc += 1 + self.rpc_map[self.next_rpc] = obj + self.inverse_rpc_map[obj_id] = self.next_rpc + return self.next_rpc def _iterate(self): inferencer = Inferencer(engine=self.engine) @@ -213,17 +231,102 @@ class Stitcher: quote_function=self._quote_function) return asttyped_rewriter.visit(function_node) - def _quote_function(self, function): + def _function_def_note(self, function): + filename = function.__code__.co_filename + line = function.__code__.co_firstlineno + name = function.__code__.co_name + + source_line = linecache.getline(filename, line) + column = re.search("def", source_line).start(0) + source_buffer = source.Buffer(source_line, filename, line) + loc = source.Range(source_buffer, column, column) + return diagnostic.Diagnostic("note", + "definition of function '{function}'", + {"function": name}, + loc) + + def _type_of_param(self, function, loc, param): + if param.default is not inspect.Parameter.empty: + # Try and infer the type from the default value. + # This is tricky, because the default value might not have + # a well-defined type in APython. + # In this case, we bail out, but mention why we do it. + synthesizer = ASTSynthesizer() + ast = synthesizer.quote(param.default) + synthesizer.finalize() + + def proxy_diagnostic(diag): + note = diagnostic.Diagnostic("note", + "expanded from here while trying to infer a type for an" + " unannotated optional argument '{param_name}' from its default value", + {"param_name": param.name}, + loc) + diag.notes.append(note) + + diag.notes.append(self._function_def_note(function)) + + self.engine.process(diag) + + proxy_engine = diagnostic.Engine() + proxy_engine.process = proxy_diagnostic + Inferencer(engine=proxy_engine).visit(ast) + IntMonomorphizer(engine=proxy_engine).visit(ast) + + return ast.type + else: + # Let the rest of the program decide. + return types.TVar() + + def _quote_rpc_function(self, function, loc): + signature = inspect.signature(function) + + arg_types = OrderedDict() + optarg_types = OrderedDict() + for param in signature.parameters.values(): + if param.kind not in (inspect.Parameter.POSITIONAL_ONLY, + inspect.Parameter.POSITIONAL_OR_KEYWORD): + # We pretend we don't see *args, kwpostargs=..., **kwargs. + # Since every method can be still invoked without any arguments + # going into *args and the slots after it, this is always safe, + # if sometimes constraining. + # + # Accepting POSITIONAL_ONLY is OK, because the compiler + # desugars the keyword arguments into positional ones internally. + continue + + if param.default is inspect.Parameter.empty: + arg_types[param.name] = self._type_of_param(function, loc, param) + else: + optarg_types[param.name] = self._type_of_param(function, loc, param) + + # Fixed for now. + ret_type = builtins.TInt(types.TValue(32)) + + rpc_type = types.TRPCFunction(arg_types, optarg_types, ret_type, + service=self._map(function)) + + rpc_name = "__rpc_{}__".format(rpc_type.service) + self.globals[rpc_name] = rpc_type + self.functions[function] = rpc_name + + return rpc_name + + def _quote_function(self, function, loc): if function in self.functions: return self.functions[function] - # Insert the typed AST for the new function and restart inference. - # It doesn't really matter where we insert as long as it is before - # the final call. - function_node = self._quote_embedded_function(function) - self.typedtree.insert(0, function_node) - self.inference_finished = False - return function_node.name + if hasattr(function, "artiq_embedded"): + # Insert the typed AST for the new function and restart inference. + # It doesn't really matter where we insert as long as it is before + # the final call. + function_node = self._quote_embedded_function(function) + self.typedtree.insert(0, function_node) + self.inference_finished = False + return function_node.name + else: + # Insert a storage-less global whose type instructs the compiler + # to perform an RPC instead of a regular call. + return self._quote_rpc_function(function, loc) def stitch_call(self, function, args, kwargs): function_node = self._quote_embedded_function(function) diff --git a/artiq/compiler/module.py b/artiq/compiler/module.py index 7f77c04dd..f6285540a 100644 --- a/artiq/compiler/module.py +++ b/artiq/compiler/module.py @@ -41,14 +41,16 @@ class Source: class Module: def __init__(self, src): - int_monomorphizer = transforms.IntMonomorphizer(engine=src.engine) - inferencer = transforms.Inferencer(engine=src.engine) - monomorphism_validator = validators.MonomorphismValidator(engine=src.engine) - escape_validator = validators.EscapeValidator(engine=src.engine) - artiq_ir_generator = transforms.ARTIQIRGenerator(engine=src.engine, + self.engine = src.engine + + int_monomorphizer = transforms.IntMonomorphizer(engine=self.engine) + inferencer = transforms.Inferencer(engine=self.engine) + monomorphism_validator = validators.MonomorphismValidator(engine=self.engine) + escape_validator = validators.EscapeValidator(engine=self.engine) + artiq_ir_generator = transforms.ARTIQIRGenerator(engine=self.engine, module_name=src.name) - dead_code_eliminator = transforms.DeadCodeEliminator(engine=src.engine) - local_access_validator = validators.LocalAccessValidator(engine=src.engine) + dead_code_eliminator = transforms.DeadCodeEliminator(engine=self.engine) + local_access_validator = validators.LocalAccessValidator(engine=self.engine) self.name = src.name self.globals = src.globals @@ -62,7 +64,8 @@ class Module: def build_llvm_ir(self, target): """Compile the module to LLVM IR for the specified target.""" - llvm_ir_generator = transforms.LLVMIRGenerator(module_name=self.name, target=target) + llvm_ir_generator = transforms.LLVMIRGenerator(engine=self.engine, + module_name=self.name, target=target) return llvm_ir_generator.process(self.artiq_ir) def entry_point(self): diff --git a/artiq/compiler/transforms/llvm_ir_generator.py b/artiq/compiler/transforms/llvm_ir_generator.py index bccc2397c..ca6e8da35 100644 --- a/artiq/compiler/transforms/llvm_ir_generator.py +++ b/artiq/compiler/transforms/llvm_ir_generator.py @@ -3,12 +3,13 @@ into LLVM intermediate representation. """ -from pythonparser import ast +from pythonparser import ast, diagnostic from llvmlite_artiq import ir as ll from .. import types, builtins, ir class LLVMIRGenerator: - def __init__(self, module_name, target): + def __init__(self, engine, module_name, target): + self.engine = engine self.target = target self.llcontext = target.llcontext self.llmodule = ll.Module(context=self.llcontext, name=module_name) @@ -21,6 +22,11 @@ class LLVMIRGenerator: typ = typ.find() if types.is_tuple(typ): return ll.LiteralStructType([self.llty_of_type(eltty) for eltty in typ.elts]) + elif types.is_rpc_function(typ): + if for_return: + return ll.VoidType() + else: + return ll.LiteralStructType([]) elif types.is_function(typ): envarg = ll.IntType(8).as_pointer() llty = ll.FunctionType(args=[envarg] + @@ -89,10 +95,13 @@ class LLVMIRGenerator: return ll.Constant(llty, False) elif isinstance(const.value, (int, float)): return ll.Constant(llty, const.value) - elif isinstance(const.value, str): - assert "\0" not in const.value + elif isinstance(const.value, (str, bytes)): + if isinstance(const.value, str): + assert "\0" not in const.value + as_bytes = (const.value + "\0").encode("utf-8") + else: + as_bytes = const.value - as_bytes = (const.value + "\0").encode("utf-8") if ir.is_exn_typeinfo(const.type): # Exception typeinfo; should be merged with identical others name = "__artiq_exn_" + const.value @@ -144,6 +153,9 @@ class LLVMIRGenerator: llty = ll.FunctionType(ll.VoidType(), [self.llty_of_type(builtins.TException())]) elif name == "__artiq_reraise": llty = ll.FunctionType(ll.VoidType(), []) + elif name == "rpc": + llty = ll.FunctionType(ll.IntType(32), [ll.IntType(32), ll.IntType(8).as_pointer()], + var_arg=True) else: assert False @@ -546,11 +558,79 @@ class LLVMIRGenerator: name=insn.name) return llvalue + # See session.c:send_rpc_value. + def _rpc_tag(self, typ, root_type, root_loc): + if types.is_tuple(typ): + assert len(typ.elts) < 256 + return b"t" + bytes([len(typ.elts)]) + \ + b"".join([self._rpc_tag(elt_type, root_type, root_loc) + for elt_type in typ.elts]) + elif builtins.is_none(typ): + return b"n" + elif builtins.is_bool(typ): + return b"b" + elif builtins.is_int(typ, types.TValue(32)): + return b"i" + elif builtins.is_int(typ, types.TValue(64)): + return b"I" + elif builtins.is_float(typ): + return b"f" + elif builtins.is_str(typ): + return b"s" + elif builtins.is_list(typ): + return b"l" + self._rpc_tag(builtins.get_iterable_elt(typ), + root_type, root_loc) + elif builtins.is_range(typ): + return b"r" + self._rpc_tag(builtins.get_iterable_elt(typ), + root_type, root_loc) + elif ir.is_option(typ): + return b"o" + self._rpc_tag(typ.params["inner"], + root_type, root_loc) + else: + printer = types.TypePrinter() + note = diagnostic.Diagnostic("note", + "value of type {type}", + {"type": printer.name(root_type)}, + root_loc) + diag = diagnostic.Diagnostic("error", + "type {type} is not supported in remote procedure calls", + {"type": printer.name(typ)}, + root_loc) + self.engine.process(diag) + + def _build_rpc(self, service, args, return_type): + llservice = ll.Constant(ll.IntType(32), service) + + tag = b"" + for arg in args: + if isinstance(arg, ir.Constant): + # Constants don't have locations, but conveniently + # they also never fail to serialize. + tag += self._rpc_tag(arg.type, arg.type, None) + else: + tag += self._rpc_tag(arg.type, arg.type, arg.loc) + tag += b":\x00" + lltag = self.llconst_of_const(ir.Constant(tag, builtins.TStr())) + + llargs = [] + for arg in args: + llarg = self.map(arg) + llargslot = self.llbuilder.alloca(llarg.type) + self.llbuilder.store(llarg, llargslot) + llargs.append(llargslot) + + return self.llbuiltin("rpc"), [llservice, lltag] + llargs + def prepare_call(self, insn): - llclosure, llargs = self.map(insn.target_function()), map(self.map, insn.arguments()) - llenv = self.llbuilder.extract_value(llclosure, 0) - llfun = self.llbuilder.extract_value(llclosure, 1) - return llfun, [llenv] + list(llargs) + if types.is_rpc_function(insn.target_function().type): + return self._build_rpc(insn.target_function().type.service, + insn.arguments(), + insn.target_function().type.ret) + else: + llclosure, llargs = self.map(insn.target_function()), map(self.map, insn.arguments()) + llenv = self.llbuilder.extract_value(llclosure, 0) + llfun = self.llbuilder.extract_value(llclosure, 1) + return llfun, [llenv] + list(llargs) def process_Call(self, insn): llfun, llargs = self.prepare_call(insn) diff --git a/artiq/compiler/types.py b/artiq/compiler/types.py index a614ce637..4c145720c 100644 --- a/artiq/compiler/types.py +++ b/artiq/compiler/types.py @@ -222,6 +222,26 @@ class TFunction(Type): def __ne__(self, other): return not (self == other) +class TRPCFunction(TFunction): + """ + A function type of a remote function. + + :ivar service: (int) RPC service number + """ + + def __init__(self, args, optargs, ret, service): + super().__init__(args, optargs, ret) + self.service = service + + def unify(self, other): + if isinstance(other, TRPCFunction) and \ + self.service == other.service: + super().unify(other) + elif isinstance(other, TVar): + other.unify(self) + else: + raise UnificationError(self, other) + class TBuiltin(Type): """ An instance of builtin type. Every instance of a builtin @@ -310,6 +330,8 @@ def is_mono(typ, name=None, **params): typ = typ.find() params_match = True for param in params: + if param not in typ.params: + return False params_match = params_match and \ typ.params[param].find() == params[param].find() return isinstance(typ, TMono) and \ @@ -329,6 +351,9 @@ def is_tuple(typ, elts=None): def is_function(typ): return isinstance(typ.find(), TFunction) +def is_rpc_function(typ): + return isinstance(typ.find(), TRPCFunction) + def is_builtin(typ, name=None): typ = typ.find() if name is None: @@ -381,11 +406,16 @@ class TypePrinter(object): return "(%s,)" % self.name(typ.elts[0]) else: return "(%s)" % ", ".join(list(map(self.name, typ.elts))) - elif isinstance(typ, TFunction): + elif isinstance(typ, (TFunction, TRPCFunction)): args = [] args += [ "%s:%s" % (arg, self.name(typ.args[arg])) for arg in typ.args] args += ["?%s:%s" % (arg, self.name(typ.optargs[arg])) for arg in typ.optargs] - return "(%s)->%s" % (", ".join(args), self.name(typ.ret)) + signature = "(%s)->%s" % (", ".join(args), self.name(typ.ret)) + + if isinstance(typ, TRPCFunction): + return "rpc({}) {}".format(typ.service, signature) + elif isinstance(typ, TFunction): + return signature elif isinstance(typ, TBuiltinFunction): return "" % typ.name elif isinstance(typ, (TConstructor, TExceptionConstructor)): diff --git a/artiq/coredevice/comm_generic.py b/artiq/coredevice/comm_generic.py index df0cbb4cd..54246d525 100644 --- a/artiq/coredevice/comm_generic.py +++ b/artiq/coredevice/comm_generic.py @@ -276,8 +276,16 @@ class CommGeneric: self._write_empty(_H2DMsgType.RUN_KERNEL) logger.debug("running kernel") - def _receive_rpc_value(self, tag, rpc_map): - if tag == "n": + _rpc_sentinel = object() + + def _receive_rpc_value(self, rpc_map): + tag = chr(self._read_int8()) + if tag == "\x00": + return self._rpc_sentinel + elif tag == "t": + length = self._read_int8() + return tuple(self._receive_rpc_value(rpc_map) for _ in range(length)) + elif tag == "n": return None elif tag == "b": return bool(self._read_int8()) @@ -291,31 +299,36 @@ class CommGeneric: numerator = self._read_int64() denominator = self._read_int64() return Fraction(numerator, denominator) + elif tag == "s": + return self._read_string() elif tag == "l": - elt_tag = chr(self._read_int8()) length = self._read_int32() - return [self._receive_rpc_value(elt_tag) for _ in range(length)] + return [self._receive_rpc_value(rpc_map) for _ in range(length)] + elif tag == "r": + lower = self._receive_rpc_value(rpc_map) + upper = self._receive_rpc_value(rpc_map) + step = self._receive_rpc_value(rpc_map) + return range(lower, upper, step) elif tag == "o": return rpc_map[self._read_int32()] else: - raise IOError("Unknown RPC value tag: {}", tag) + raise IOError("Unknown RPC value tag: {}".format(repr(tag))) - def _receive_rpc_values(self, rpc_map): - result = [] + def _receive_rpc_args(self, rpc_map): + args = [] while True: - tag = chr(self._read_int8()) - if tag == "\x00": - return result - else: - result.append(self._receive_rpc_value(tag, rpc_map)) + value = self._receive_rpc_value(rpc_map) + if value is self._rpc_sentinel: + return args + args.append(value) def _serve_rpc(self, rpc_map): service = self._read_int32() - args = self._receive_rpc_values(rpc_map) + args = self._receive_rpc_args(rpc_map) logger.debug("rpc service: %d %r", service, args) try: - result = rpc_map[rpc_num](args) + result = rpc_map[service](*args) if not isinstance(result, int) or not (-2**31 < result < 2**31-1): raise ValueError("An RPC must return an int(width=32)") except ARTIQException as exn: diff --git a/artiq/coredevice/core.py b/artiq/coredevice/core.py index ce9ae391b..5188de1ae 100644 --- a/artiq/coredevice/core.py +++ b/artiq/coredevice/core.py @@ -50,13 +50,13 @@ class Core: raise CompileError() from error def run(self, function, args, kwargs): + kernel_library, rpc_map = self.compile(function, args, kwargs) + if self.first_run: self.comm.check_ident() self.comm.switch_clock(self.external_clock) self.first_run = False - kernel_library, rpc_map = self.compile(function, args, kwargs) - try: self.comm.load(kernel_library) except Exception as error: diff --git a/soc/runtime/ksupport.c b/soc/runtime/ksupport.c index b2082beef..9a866908c 100644 --- a/soc/runtime/ksupport.c +++ b/soc/runtime/ksupport.c @@ -301,33 +301,34 @@ void watchdog_clear(int id) mailbox_send_and_wait(&request); } -int rpc(int rpc_num, ...) +int rpc(int service, const char *tag, ...) { - struct msg_rpc_request request; + struct msg_rpc_send_request request; struct msg_base *reply; - request.type = MESSAGE_TYPE_RPC_REQUEST; - request.rpc_num = rpc_num; - va_start(request.args, rpc_num); + request.type = MESSAGE_TYPE_RPC_SEND_REQUEST; + request.service = service; + request.tag = tag; + va_start(request.args, tag); mailbox_send_and_wait(&request); va_end(request.args); reply = mailbox_wait_and_receive(); - if(reply->type == MESSAGE_TYPE_RPC_REPLY) { - int result = ((struct msg_rpc_reply *)reply)->result; - mailbox_acknowledge(); - return result; - } else if(reply->type == MESSAGE_TYPE_RPC_EXCEPTION) { - struct artiq_exception exception; - memcpy(&exception, ((struct msg_rpc_exception *)reply)->exception, - sizeof(struct artiq_exception)); - mailbox_acknowledge(); - __artiq_raise(&exception); - } else { + // if(reply->type == MESSAGE_TYPE_RPC_REPLY) { + // int result = ((struct msg_rpc_reply *)reply)->result; + // mailbox_acknowledge(); + // return result; + // } else if(reply->type == MESSAGE_TYPE_RPC_EXCEPTION) { + // struct artiq_exception exception; + // memcpy(&exception, ((struct msg_rpc_exception *)reply)->exception, + // sizeof(struct artiq_exception)); + // mailbox_acknowledge(); + // __artiq_raise(&exception); + // } else { log("Malformed MESSAGE_TYPE_RPC_REQUEST reply type %d", reply->type); while(1); - } + // } } void lognonl(const char *fmt, ...) diff --git a/soc/runtime/ksupport.h b/soc/runtime/ksupport.h index 47b9c3a1d..41561330f 100644 --- a/soc/runtime/ksupport.h +++ b/soc/runtime/ksupport.h @@ -5,7 +5,7 @@ long long int now_init(void); void now_save(long long int now); int watchdog_set(int ms); void watchdog_clear(int id); -int rpc(int service, ...); +int rpc(int service, const char *tag, ...); void lognonl(const char *fmt, ...); void log(const char *fmt, ...); diff --git a/soc/runtime/messages.h b/soc/runtime/messages.h index 651c7a5ef..55d53aca9 100644 --- a/soc/runtime/messages.h +++ b/soc/runtime/messages.h @@ -14,8 +14,9 @@ enum { MESSAGE_TYPE_WATCHDOG_SET_REQUEST, MESSAGE_TYPE_WATCHDOG_SET_REPLY, MESSAGE_TYPE_WATCHDOG_CLEAR, - MESSAGE_TYPE_RPC_REQUEST, - MESSAGE_TYPE_RPC_REPLY, + MESSAGE_TYPE_RPC_SEND_REQUEST, + MESSAGE_TYPE_RPC_RECV_REQUEST, + MESSAGE_TYPE_RPC_RECV_REPLY, MESSAGE_TYPE_RPC_EXCEPTION, MESSAGE_TYPE_LOG, @@ -80,15 +81,21 @@ struct msg_watchdog_clear { int id; }; -struct msg_rpc_request { +struct msg_rpc_send_request { int type; - int rpc_num; + int service; + const char *tag; va_list args; }; -struct msg_rpc_reply { +struct msg_rpc_recv_request { int type; - int result; + // TODO ??? +}; + +struct msg_rpc_recv_reply { + int type; + // TODO ??? }; struct msg_rpc_exception { diff --git a/soc/runtime/session.c b/soc/runtime/session.c index b571c5873..eadaa1990 100644 --- a/soc/runtime/session.c +++ b/soc/runtime/session.c @@ -457,23 +457,23 @@ static int process_input(void) user_kernel_state = USER_KERNEL_RUNNING; break; - case REMOTEMSG_TYPE_RPC_REPLY: { - struct msg_rpc_reply reply; + // case REMOTEMSG_TYPE_RPC_REPLY: { + // struct msg_rpc_reply reply; - int result = in_packet_int32(); + // int result = in_packet_int32(); - if(user_kernel_state != USER_KERNEL_WAIT_RPC) { - log("Unsolicited RPC reply"); - return 0; // restart session - } + // if(user_kernel_state != USER_KERNEL_WAIT_RPC) { + // log("Unsolicited RPC reply"); + // return 0; // restart session + // } - reply.type = MESSAGE_TYPE_RPC_REPLY; - reply.result = result; - mailbox_send_and_wait(&reply); + // reply.type = MESSAGE_TYPE_RPC_REPLY; + // reply.result = result; + // mailbox_send_and_wait(&reply); - user_kernel_state = USER_KERNEL_RUNNING; - break; - } + // user_kernel_state = USER_KERNEL_RUNNING; + // break; + // } case REMOTEMSG_TYPE_RPC_EXCEPTION: { struct msg_rpc_exception reply; @@ -509,91 +509,156 @@ static int process_input(void) return 1; } -static int send_rpc_value(const char **tag, void *value) +// See llvm_ir_generator.py:_rpc_tag. +static int send_rpc_value(const char **tag, void **value) { if(!out_packet_int8(**tag)) - return -1; + return 0; + + switch(*(*tag)++) { + case 't': { // tuple + int size = *(*tag)++; + if(!out_packet_int8(size)) + return 0; + + for(int i = 0; i < size; i++) { + if(!send_rpc_value(tag, value)) + return 0; + } + break; + } - int size = 0; - switch(**tag) { - case 0: // last tag case 'n': // None break; - case 'b': // bool - size = 1; - if(!out_packet_chunk(value, size)) - return -1; + case 'b': { // bool + int size = sizeof(int8_t); + if(!out_packet_chunk(*value, size)) + return 0; + *value = (void*)((intptr_t)(*value) + size); break; + } - case 'i': // int(width=32) - size = 4; - if(!out_packet_chunk(value, size)) - return -1; + case 'i': { // int(width=32) + int size = sizeof(int32_t); + if(!out_packet_chunk(*value, size)) + return 0; + *value = (void*)((intptr_t)(*value) + size); break; + } - case 'I': // int(width=64) - case 'f': // float - size = 8; - if(!out_packet_chunk(value, size)) - return -1; + case 'I': { // int(width=64) + int size = sizeof(int64_t); + if(!out_packet_chunk(*value, size)) + return 0; + *value = (void*)((intptr_t)(*value) + size); break; + } - case 'F': // Fraction - size = 16; - if(!out_packet_chunk(value, size)) - return -1; + case 'f': { // float + int size = sizeof(double); + if(!out_packet_chunk(*value, size)) + return 0; + *value = (void*)((intptr_t)(*value) + size); break; + } + + case 'F': { // Fraction + int size = sizeof(int64_t) * 2; + if(!out_packet_chunk(*value, size)) + return 0; + *value = (void*)((intptr_t)(*value) + size); + break; + } + + case 's': { // string + const char **string = *value; + if(!out_packet_string(*string)) + return 0; + *value = (void*)((intptr_t)(*value) + strlen(*string) + 1); + break; + } case 'l': { // list(elt='a) - struct { uint32_t length; void *elements; } *list = value; + struct { uint32_t length; struct {} *elements; } *list = *value; void *element = list->elements; - const char *tag_copy = *tag + 1; + if(!out_packet_int32(list->length)) + return 0; + + const char *tag_copy; for(int i = 0; i < list->length; i++) { - int element_size = send_rpc_value(&tag_copy, element); - if(element_size < 0) - return -1; - element = (void*)((intptr_t)element + element_size); + tag_copy = *tag; + if(!send_rpc_value(&tag_copy, &element)) + return 0; } *tag = tag_copy; - size = sizeof(list); + *value = (void*)((intptr_t)(*value) + sizeof(*list)); break; } - case 'o': { // host object - struct { uint32_t id; } *object = value; - - if(!out_packet_int32(object->id)) - return -1; - - size = sizeof(object); + case 'r': { // range(elt='a) + const char *tag_copy; + tag_copy = *tag; + if(!send_rpc_value(&tag_copy, value)) // min + return 0; + tag_copy = *tag; + if(!send_rpc_value(&tag_copy, value)) // max + return 0; + tag_copy = *tag; + if(!send_rpc_value(&tag_copy, value)) // step + return 0; + *tag = tag_copy; break; } + case 'o': { // option(inner='a) + struct { int8_t present; struct {} contents; } *option = *value; + void *contents = &option->contents; + + if(!out_packet_int8(option->present)) + return 0; + + // option never appears in composite types, so we don't have + // to accurately advance *value. + if(option->present) { + return send_rpc_value(tag, &contents); + } else { + (*tag)++; + break; + } + } + + case 'O': { // host object + struct { uint32_t id; } **object = *value; + + if(!out_packet_int32((*object)->id)) + return 0; + } + default: - return -1; + log("send_rpc_value: unknown tag %02x", *((*tag) - 1)); + return 0; } - (*tag)++; - return size; + return 1; } -static int send_rpc_request(int service, va_list args) +static int send_rpc_request(int service, const char *tag, va_list args) { out_packet_start(REMOTEMSG_TYPE_RPC_REQUEST); out_packet_int32(service); - const char *tag = va_arg(args, const char*); - while(*tag) { + while(*tag != ':') { void *value = va_arg(args, void*); if(!kloader_validate_kpointer(value)) return 0; - if(send_rpc_value(&tag, &value) < 0) + if(!send_rpc_value(&tag, &value)) return 0; } + out_packet_int8(0); out_packet_finish(); return 1; } @@ -670,10 +735,10 @@ static int process_kmsg(struct msg_base *umsg) break; } - case MESSAGE_TYPE_RPC_REQUEST: { - struct msg_rpc_request *msg = (struct msg_rpc_request *)umsg; + case MESSAGE_TYPE_RPC_SEND_REQUEST: { + struct msg_rpc_send_request *msg = (struct msg_rpc_send_request *)umsg; - if(!send_rpc_request(msg->rpc_num, msg->args)) { + if(!send_rpc_request(msg->service, msg->tag, msg->args)) { log("Failed to send RPC request"); return 0; // restart session }