Implement sending RPCs.

pull/235/head
whitequark 2015-08-09 02:17:19 +03:00
parent 22457bc19c
commit b26af5df60
11 changed files with 433 additions and 131 deletions

View File

@ -163,7 +163,7 @@ def is_bool(typ):
def is_int(typ, width=None):
if width is not None:
return types.is_mono(typ, "int", {"width": width})
return types.is_mono(typ, "int", width=width)
else:
return types.is_mono(typ, "int")
@ -184,7 +184,7 @@ def is_numeric(typ):
def is_list(typ, elt=None):
if elt is not None:
return types.is_mono(typ, "list", {"elt": elt})
return types.is_mono(typ, "list", elt=elt)
else:
return types.is_mono(typ, "list")

View File

@ -5,10 +5,13 @@ the references to the host objects and translates the functions
annotated as ``@kernel`` when they are referenced.
"""
import inspect, os
import os, re, linecache, inspect
from collections import OrderedDict
from pythonparser import ast, source, diagnostic, parse_buffer
from . import types, builtins, asttyped, prelude
from .transforms import ASTTypedRewriter, Inferencer
from .transforms import ASTTypedRewriter, Inferencer, IntMonomorphizer
class ASTSynthesizer:
@ -45,6 +48,9 @@ class ASTSynthesizer:
typ = builtins.TFloat()
return asttyped.NumT(n=value, ctx=None, type=typ,
loc=self._add(repr(value)))
elif isinstance(value, str):
return asttyped.StrT(s=value, ctx=None, type=builtins.TStr(),
loc=self._add(repr(value)))
elif isinstance(value, list):
begin_loc = self._add("[")
elts = []
@ -123,7 +129,7 @@ class StitchingASTTypedRewriter(ASTTypedRewriter):
if inspect.isfunction(value):
# It's a function. We need to translate the function and insert
# a reference to it.
function_name = self.quote_function(value)
function_name = self.quote_function(value, node.loc)
return asttyped.NameT(id=function_name, ctx=None,
type=self.globals[function_name],
loc=node.loc)
@ -154,7 +160,19 @@ class Stitcher:
self.functions = {}
self.next_rpc = 0
self.rpc_map = {}
self.inverse_rpc_map = {}
def _map(self, obj):
obj_id = id(obj)
if obj_id in self.inverse_rpc_map:
return self.inverse_rpc_map[obj_id]
self.next_rpc += 1
self.rpc_map[self.next_rpc] = obj
self.inverse_rpc_map[obj_id] = self.next_rpc
return self.next_rpc
def _iterate(self):
inferencer = Inferencer(engine=self.engine)
@ -213,17 +231,102 @@ class Stitcher:
quote_function=self._quote_function)
return asttyped_rewriter.visit(function_node)
def _quote_function(self, function):
def _function_def_note(self, function):
filename = function.__code__.co_filename
line = function.__code__.co_firstlineno
name = function.__code__.co_name
source_line = linecache.getline(filename, line)
column = re.search("def", source_line).start(0)
source_buffer = source.Buffer(source_line, filename, line)
loc = source.Range(source_buffer, column, column)
return diagnostic.Diagnostic("note",
"definition of function '{function}'",
{"function": name},
loc)
def _type_of_param(self, function, loc, param):
if param.default is not inspect.Parameter.empty:
# Try and infer the type from the default value.
# This is tricky, because the default value might not have
# a well-defined type in APython.
# In this case, we bail out, but mention why we do it.
synthesizer = ASTSynthesizer()
ast = synthesizer.quote(param.default)
synthesizer.finalize()
def proxy_diagnostic(diag):
note = diagnostic.Diagnostic("note",
"expanded from here while trying to infer a type for an"
" unannotated optional argument '{param_name}' from its default value",
{"param_name": param.name},
loc)
diag.notes.append(note)
diag.notes.append(self._function_def_note(function))
self.engine.process(diag)
proxy_engine = diagnostic.Engine()
proxy_engine.process = proxy_diagnostic
Inferencer(engine=proxy_engine).visit(ast)
IntMonomorphizer(engine=proxy_engine).visit(ast)
return ast.type
else:
# Let the rest of the program decide.
return types.TVar()
def _quote_rpc_function(self, function, loc):
signature = inspect.signature(function)
arg_types = OrderedDict()
optarg_types = OrderedDict()
for param in signature.parameters.values():
if param.kind not in (inspect.Parameter.POSITIONAL_ONLY,
inspect.Parameter.POSITIONAL_OR_KEYWORD):
# We pretend we don't see *args, kwpostargs=..., **kwargs.
# Since every method can be still invoked without any arguments
# going into *args and the slots after it, this is always safe,
# if sometimes constraining.
#
# Accepting POSITIONAL_ONLY is OK, because the compiler
# desugars the keyword arguments into positional ones internally.
continue
if param.default is inspect.Parameter.empty:
arg_types[param.name] = self._type_of_param(function, loc, param)
else:
optarg_types[param.name] = self._type_of_param(function, loc, param)
# Fixed for now.
ret_type = builtins.TInt(types.TValue(32))
rpc_type = types.TRPCFunction(arg_types, optarg_types, ret_type,
service=self._map(function))
rpc_name = "__rpc_{}__".format(rpc_type.service)
self.globals[rpc_name] = rpc_type
self.functions[function] = rpc_name
return rpc_name
def _quote_function(self, function, loc):
if function in self.functions:
return self.functions[function]
# Insert the typed AST for the new function and restart inference.
# It doesn't really matter where we insert as long as it is before
# the final call.
function_node = self._quote_embedded_function(function)
self.typedtree.insert(0, function_node)
self.inference_finished = False
return function_node.name
if hasattr(function, "artiq_embedded"):
# Insert the typed AST for the new function and restart inference.
# It doesn't really matter where we insert as long as it is before
# the final call.
function_node = self._quote_embedded_function(function)
self.typedtree.insert(0, function_node)
self.inference_finished = False
return function_node.name
else:
# Insert a storage-less global whose type instructs the compiler
# to perform an RPC instead of a regular call.
return self._quote_rpc_function(function, loc)
def stitch_call(self, function, args, kwargs):
function_node = self._quote_embedded_function(function)

View File

@ -41,14 +41,16 @@ class Source:
class Module:
def __init__(self, src):
int_monomorphizer = transforms.IntMonomorphizer(engine=src.engine)
inferencer = transforms.Inferencer(engine=src.engine)
monomorphism_validator = validators.MonomorphismValidator(engine=src.engine)
escape_validator = validators.EscapeValidator(engine=src.engine)
artiq_ir_generator = transforms.ARTIQIRGenerator(engine=src.engine,
self.engine = src.engine
int_monomorphizer = transforms.IntMonomorphizer(engine=self.engine)
inferencer = transforms.Inferencer(engine=self.engine)
monomorphism_validator = validators.MonomorphismValidator(engine=self.engine)
escape_validator = validators.EscapeValidator(engine=self.engine)
artiq_ir_generator = transforms.ARTIQIRGenerator(engine=self.engine,
module_name=src.name)
dead_code_eliminator = transforms.DeadCodeEliminator(engine=src.engine)
local_access_validator = validators.LocalAccessValidator(engine=src.engine)
dead_code_eliminator = transforms.DeadCodeEliminator(engine=self.engine)
local_access_validator = validators.LocalAccessValidator(engine=self.engine)
self.name = src.name
self.globals = src.globals
@ -62,7 +64,8 @@ class Module:
def build_llvm_ir(self, target):
"""Compile the module to LLVM IR for the specified target."""
llvm_ir_generator = transforms.LLVMIRGenerator(module_name=self.name, target=target)
llvm_ir_generator = transforms.LLVMIRGenerator(engine=self.engine,
module_name=self.name, target=target)
return llvm_ir_generator.process(self.artiq_ir)
def entry_point(self):

View File

@ -3,12 +3,13 @@
into LLVM intermediate representation.
"""
from pythonparser import ast
from pythonparser import ast, diagnostic
from llvmlite_artiq import ir as ll
from .. import types, builtins, ir
class LLVMIRGenerator:
def __init__(self, module_name, target):
def __init__(self, engine, module_name, target):
self.engine = engine
self.target = target
self.llcontext = target.llcontext
self.llmodule = ll.Module(context=self.llcontext, name=module_name)
@ -21,6 +22,11 @@ class LLVMIRGenerator:
typ = typ.find()
if types.is_tuple(typ):
return ll.LiteralStructType([self.llty_of_type(eltty) for eltty in typ.elts])
elif types.is_rpc_function(typ):
if for_return:
return ll.VoidType()
else:
return ll.LiteralStructType([])
elif types.is_function(typ):
envarg = ll.IntType(8).as_pointer()
llty = ll.FunctionType(args=[envarg] +
@ -89,10 +95,13 @@ class LLVMIRGenerator:
return ll.Constant(llty, False)
elif isinstance(const.value, (int, float)):
return ll.Constant(llty, const.value)
elif isinstance(const.value, str):
assert "\0" not in const.value
elif isinstance(const.value, (str, bytes)):
if isinstance(const.value, str):
assert "\0" not in const.value
as_bytes = (const.value + "\0").encode("utf-8")
else:
as_bytes = const.value
as_bytes = (const.value + "\0").encode("utf-8")
if ir.is_exn_typeinfo(const.type):
# Exception typeinfo; should be merged with identical others
name = "__artiq_exn_" + const.value
@ -144,6 +153,9 @@ class LLVMIRGenerator:
llty = ll.FunctionType(ll.VoidType(), [self.llty_of_type(builtins.TException())])
elif name == "__artiq_reraise":
llty = ll.FunctionType(ll.VoidType(), [])
elif name == "rpc":
llty = ll.FunctionType(ll.IntType(32), [ll.IntType(32), ll.IntType(8).as_pointer()],
var_arg=True)
else:
assert False
@ -546,11 +558,79 @@ class LLVMIRGenerator:
name=insn.name)
return llvalue
# See session.c:send_rpc_value.
def _rpc_tag(self, typ, root_type, root_loc):
if types.is_tuple(typ):
assert len(typ.elts) < 256
return b"t" + bytes([len(typ.elts)]) + \
b"".join([self._rpc_tag(elt_type, root_type, root_loc)
for elt_type in typ.elts])
elif builtins.is_none(typ):
return b"n"
elif builtins.is_bool(typ):
return b"b"
elif builtins.is_int(typ, types.TValue(32)):
return b"i"
elif builtins.is_int(typ, types.TValue(64)):
return b"I"
elif builtins.is_float(typ):
return b"f"
elif builtins.is_str(typ):
return b"s"
elif builtins.is_list(typ):
return b"l" + self._rpc_tag(builtins.get_iterable_elt(typ),
root_type, root_loc)
elif builtins.is_range(typ):
return b"r" + self._rpc_tag(builtins.get_iterable_elt(typ),
root_type, root_loc)
elif ir.is_option(typ):
return b"o" + self._rpc_tag(typ.params["inner"],
root_type, root_loc)
else:
printer = types.TypePrinter()
note = diagnostic.Diagnostic("note",
"value of type {type}",
{"type": printer.name(root_type)},
root_loc)
diag = diagnostic.Diagnostic("error",
"type {type} is not supported in remote procedure calls",
{"type": printer.name(typ)},
root_loc)
self.engine.process(diag)
def _build_rpc(self, service, args, return_type):
llservice = ll.Constant(ll.IntType(32), service)
tag = b""
for arg in args:
if isinstance(arg, ir.Constant):
# Constants don't have locations, but conveniently
# they also never fail to serialize.
tag += self._rpc_tag(arg.type, arg.type, None)
else:
tag += self._rpc_tag(arg.type, arg.type, arg.loc)
tag += b":\x00"
lltag = self.llconst_of_const(ir.Constant(tag, builtins.TStr()))
llargs = []
for arg in args:
llarg = self.map(arg)
llargslot = self.llbuilder.alloca(llarg.type)
self.llbuilder.store(llarg, llargslot)
llargs.append(llargslot)
return self.llbuiltin("rpc"), [llservice, lltag] + llargs
def prepare_call(self, insn):
llclosure, llargs = self.map(insn.target_function()), map(self.map, insn.arguments())
llenv = self.llbuilder.extract_value(llclosure, 0)
llfun = self.llbuilder.extract_value(llclosure, 1)
return llfun, [llenv] + list(llargs)
if types.is_rpc_function(insn.target_function().type):
return self._build_rpc(insn.target_function().type.service,
insn.arguments(),
insn.target_function().type.ret)
else:
llclosure, llargs = self.map(insn.target_function()), map(self.map, insn.arguments())
llenv = self.llbuilder.extract_value(llclosure, 0)
llfun = self.llbuilder.extract_value(llclosure, 1)
return llfun, [llenv] + list(llargs)
def process_Call(self, insn):
llfun, llargs = self.prepare_call(insn)

View File

@ -222,6 +222,26 @@ class TFunction(Type):
def __ne__(self, other):
return not (self == other)
class TRPCFunction(TFunction):
"""
A function type of a remote function.
:ivar service: (int) RPC service number
"""
def __init__(self, args, optargs, ret, service):
super().__init__(args, optargs, ret)
self.service = service
def unify(self, other):
if isinstance(other, TRPCFunction) and \
self.service == other.service:
super().unify(other)
elif isinstance(other, TVar):
other.unify(self)
else:
raise UnificationError(self, other)
class TBuiltin(Type):
"""
An instance of builtin type. Every instance of a builtin
@ -310,6 +330,8 @@ def is_mono(typ, name=None, **params):
typ = typ.find()
params_match = True
for param in params:
if param not in typ.params:
return False
params_match = params_match and \
typ.params[param].find() == params[param].find()
return isinstance(typ, TMono) and \
@ -329,6 +351,9 @@ def is_tuple(typ, elts=None):
def is_function(typ):
return isinstance(typ.find(), TFunction)
def is_rpc_function(typ):
return isinstance(typ.find(), TRPCFunction)
def is_builtin(typ, name=None):
typ = typ.find()
if name is None:
@ -381,11 +406,16 @@ class TypePrinter(object):
return "(%s,)" % self.name(typ.elts[0])
else:
return "(%s)" % ", ".join(list(map(self.name, typ.elts)))
elif isinstance(typ, TFunction):
elif isinstance(typ, (TFunction, TRPCFunction)):
args = []
args += [ "%s:%s" % (arg, self.name(typ.args[arg])) for arg in typ.args]
args += ["?%s:%s" % (arg, self.name(typ.optargs[arg])) for arg in typ.optargs]
return "(%s)->%s" % (", ".join(args), self.name(typ.ret))
signature = "(%s)->%s" % (", ".join(args), self.name(typ.ret))
if isinstance(typ, TRPCFunction):
return "rpc({}) {}".format(typ.service, signature)
elif isinstance(typ, TFunction):
return signature
elif isinstance(typ, TBuiltinFunction):
return "<function %s>" % typ.name
elif isinstance(typ, (TConstructor, TExceptionConstructor)):

View File

@ -276,8 +276,16 @@ class CommGeneric:
self._write_empty(_H2DMsgType.RUN_KERNEL)
logger.debug("running kernel")
def _receive_rpc_value(self, tag, rpc_map):
if tag == "n":
_rpc_sentinel = object()
def _receive_rpc_value(self, rpc_map):
tag = chr(self._read_int8())
if tag == "\x00":
return self._rpc_sentinel
elif tag == "t":
length = self._read_int8()
return tuple(self._receive_rpc_value(rpc_map) for _ in range(length))
elif tag == "n":
return None
elif tag == "b":
return bool(self._read_int8())
@ -291,31 +299,36 @@ class CommGeneric:
numerator = self._read_int64()
denominator = self._read_int64()
return Fraction(numerator, denominator)
elif tag == "s":
return self._read_string()
elif tag == "l":
elt_tag = chr(self._read_int8())
length = self._read_int32()
return [self._receive_rpc_value(elt_tag) for _ in range(length)]
return [self._receive_rpc_value(rpc_map) for _ in range(length)]
elif tag == "r":
lower = self._receive_rpc_value(rpc_map)
upper = self._receive_rpc_value(rpc_map)
step = self._receive_rpc_value(rpc_map)
return range(lower, upper, step)
elif tag == "o":
return rpc_map[self._read_int32()]
else:
raise IOError("Unknown RPC value tag: {}", tag)
raise IOError("Unknown RPC value tag: {}".format(repr(tag)))
def _receive_rpc_values(self, rpc_map):
result = []
def _receive_rpc_args(self, rpc_map):
args = []
while True:
tag = chr(self._read_int8())
if tag == "\x00":
return result
else:
result.append(self._receive_rpc_value(tag, rpc_map))
value = self._receive_rpc_value(rpc_map)
if value is self._rpc_sentinel:
return args
args.append(value)
def _serve_rpc(self, rpc_map):
service = self._read_int32()
args = self._receive_rpc_values(rpc_map)
args = self._receive_rpc_args(rpc_map)
logger.debug("rpc service: %d %r", service, args)
try:
result = rpc_map[rpc_num](args)
result = rpc_map[service](*args)
if not isinstance(result, int) or not (-2**31 < result < 2**31-1):
raise ValueError("An RPC must return an int(width=32)")
except ARTIQException as exn:

View File

@ -50,13 +50,13 @@ class Core:
raise CompileError() from error
def run(self, function, args, kwargs):
kernel_library, rpc_map = self.compile(function, args, kwargs)
if self.first_run:
self.comm.check_ident()
self.comm.switch_clock(self.external_clock)
self.first_run = False
kernel_library, rpc_map = self.compile(function, args, kwargs)
try:
self.comm.load(kernel_library)
except Exception as error:

View File

@ -301,33 +301,34 @@ void watchdog_clear(int id)
mailbox_send_and_wait(&request);
}
int rpc(int rpc_num, ...)
int rpc(int service, const char *tag, ...)
{
struct msg_rpc_request request;
struct msg_rpc_send_request request;
struct msg_base *reply;
request.type = MESSAGE_TYPE_RPC_REQUEST;
request.rpc_num = rpc_num;
va_start(request.args, rpc_num);
request.type = MESSAGE_TYPE_RPC_SEND_REQUEST;
request.service = service;
request.tag = tag;
va_start(request.args, tag);
mailbox_send_and_wait(&request);
va_end(request.args);
reply = mailbox_wait_and_receive();
if(reply->type == MESSAGE_TYPE_RPC_REPLY) {
int result = ((struct msg_rpc_reply *)reply)->result;
mailbox_acknowledge();
return result;
} else if(reply->type == MESSAGE_TYPE_RPC_EXCEPTION) {
struct artiq_exception exception;
memcpy(&exception, ((struct msg_rpc_exception *)reply)->exception,
sizeof(struct artiq_exception));
mailbox_acknowledge();
__artiq_raise(&exception);
} else {
// if(reply->type == MESSAGE_TYPE_RPC_REPLY) {
// int result = ((struct msg_rpc_reply *)reply)->result;
// mailbox_acknowledge();
// return result;
// } else if(reply->type == MESSAGE_TYPE_RPC_EXCEPTION) {
// struct artiq_exception exception;
// memcpy(&exception, ((struct msg_rpc_exception *)reply)->exception,
// sizeof(struct artiq_exception));
// mailbox_acknowledge();
// __artiq_raise(&exception);
// } else {
log("Malformed MESSAGE_TYPE_RPC_REQUEST reply type %d",
reply->type);
while(1);
}
// }
}
void lognonl(const char *fmt, ...)

View File

@ -5,7 +5,7 @@ long long int now_init(void);
void now_save(long long int now);
int watchdog_set(int ms);
void watchdog_clear(int id);
int rpc(int service, ...);
int rpc(int service, const char *tag, ...);
void lognonl(const char *fmt, ...);
void log(const char *fmt, ...);

View File

@ -14,8 +14,9 @@ enum {
MESSAGE_TYPE_WATCHDOG_SET_REQUEST,
MESSAGE_TYPE_WATCHDOG_SET_REPLY,
MESSAGE_TYPE_WATCHDOG_CLEAR,
MESSAGE_TYPE_RPC_REQUEST,
MESSAGE_TYPE_RPC_REPLY,
MESSAGE_TYPE_RPC_SEND_REQUEST,
MESSAGE_TYPE_RPC_RECV_REQUEST,
MESSAGE_TYPE_RPC_RECV_REPLY,
MESSAGE_TYPE_RPC_EXCEPTION,
MESSAGE_TYPE_LOG,
@ -80,15 +81,21 @@ struct msg_watchdog_clear {
int id;
};
struct msg_rpc_request {
struct msg_rpc_send_request {
int type;
int rpc_num;
int service;
const char *tag;
va_list args;
};
struct msg_rpc_reply {
struct msg_rpc_recv_request {
int type;
int result;
// TODO ???
};
struct msg_rpc_recv_reply {
int type;
// TODO ???
};
struct msg_rpc_exception {

View File

@ -457,23 +457,23 @@ static int process_input(void)
user_kernel_state = USER_KERNEL_RUNNING;
break;
case REMOTEMSG_TYPE_RPC_REPLY: {
struct msg_rpc_reply reply;
// case REMOTEMSG_TYPE_RPC_REPLY: {
// struct msg_rpc_reply reply;
int result = in_packet_int32();
// int result = in_packet_int32();
if(user_kernel_state != USER_KERNEL_WAIT_RPC) {
log("Unsolicited RPC reply");
return 0; // restart session
}
// if(user_kernel_state != USER_KERNEL_WAIT_RPC) {
// log("Unsolicited RPC reply");
// return 0; // restart session
// }
reply.type = MESSAGE_TYPE_RPC_REPLY;
reply.result = result;
mailbox_send_and_wait(&reply);
// reply.type = MESSAGE_TYPE_RPC_REPLY;
// reply.result = result;
// mailbox_send_and_wait(&reply);
user_kernel_state = USER_KERNEL_RUNNING;
break;
}
// user_kernel_state = USER_KERNEL_RUNNING;
// break;
// }
case REMOTEMSG_TYPE_RPC_EXCEPTION: {
struct msg_rpc_exception reply;
@ -509,91 +509,156 @@ static int process_input(void)
return 1;
}
static int send_rpc_value(const char **tag, void *value)
// See llvm_ir_generator.py:_rpc_tag.
static int send_rpc_value(const char **tag, void **value)
{
if(!out_packet_int8(**tag))
return -1;
return 0;
switch(*(*tag)++) {
case 't': { // tuple
int size = *(*tag)++;
if(!out_packet_int8(size))
return 0;
for(int i = 0; i < size; i++) {
if(!send_rpc_value(tag, value))
return 0;
}
break;
}
int size = 0;
switch(**tag) {
case 0: // last tag
case 'n': // None
break;
case 'b': // bool
size = 1;
if(!out_packet_chunk(value, size))
return -1;
case 'b': { // bool
int size = sizeof(int8_t);
if(!out_packet_chunk(*value, size))
return 0;
*value = (void*)((intptr_t)(*value) + size);
break;
}
case 'i': // int(width=32)
size = 4;
if(!out_packet_chunk(value, size))
return -1;
case 'i': { // int(width=32)
int size = sizeof(int32_t);
if(!out_packet_chunk(*value, size))
return 0;
*value = (void*)((intptr_t)(*value) + size);
break;
}
case 'I': // int(width=64)
case 'f': // float
size = 8;
if(!out_packet_chunk(value, size))
return -1;
case 'I': { // int(width=64)
int size = sizeof(int64_t);
if(!out_packet_chunk(*value, size))
return 0;
*value = (void*)((intptr_t)(*value) + size);
break;
}
case 'F': // Fraction
size = 16;
if(!out_packet_chunk(value, size))
return -1;
case 'f': { // float
int size = sizeof(double);
if(!out_packet_chunk(*value, size))
return 0;
*value = (void*)((intptr_t)(*value) + size);
break;
}
case 'F': { // Fraction
int size = sizeof(int64_t) * 2;
if(!out_packet_chunk(*value, size))
return 0;
*value = (void*)((intptr_t)(*value) + size);
break;
}
case 's': { // string
const char **string = *value;
if(!out_packet_string(*string))
return 0;
*value = (void*)((intptr_t)(*value) + strlen(*string) + 1);
break;
}
case 'l': { // list(elt='a)
struct { uint32_t length; void *elements; } *list = value;
struct { uint32_t length; struct {} *elements; } *list = *value;
void *element = list->elements;
const char *tag_copy = *tag + 1;
if(!out_packet_int32(list->length))
return 0;
const char *tag_copy;
for(int i = 0; i < list->length; i++) {
int element_size = send_rpc_value(&tag_copy, element);
if(element_size < 0)
return -1;
element = (void*)((intptr_t)element + element_size);
tag_copy = *tag;
if(!send_rpc_value(&tag_copy, &element))
return 0;
}
*tag = tag_copy;
size = sizeof(list);
*value = (void*)((intptr_t)(*value) + sizeof(*list));
break;
}
case 'o': { // host object
struct { uint32_t id; } *object = value;
if(!out_packet_int32(object->id))
return -1;
size = sizeof(object);
case 'r': { // range(elt='a)
const char *tag_copy;
tag_copy = *tag;
if(!send_rpc_value(&tag_copy, value)) // min
return 0;
tag_copy = *tag;
if(!send_rpc_value(&tag_copy, value)) // max
return 0;
tag_copy = *tag;
if(!send_rpc_value(&tag_copy, value)) // step
return 0;
*tag = tag_copy;
break;
}
case 'o': { // option(inner='a)
struct { int8_t present; struct {} contents; } *option = *value;
void *contents = &option->contents;
if(!out_packet_int8(option->present))
return 0;
// option never appears in composite types, so we don't have
// to accurately advance *value.
if(option->present) {
return send_rpc_value(tag, &contents);
} else {
(*tag)++;
break;
}
}
case 'O': { // host object
struct { uint32_t id; } **object = *value;
if(!out_packet_int32((*object)->id))
return 0;
}
default:
return -1;
log("send_rpc_value: unknown tag %02x", *((*tag) - 1));
return 0;
}
(*tag)++;
return size;
return 1;
}
static int send_rpc_request(int service, va_list args)
static int send_rpc_request(int service, const char *tag, va_list args)
{
out_packet_start(REMOTEMSG_TYPE_RPC_REQUEST);
out_packet_int32(service);
const char *tag = va_arg(args, const char*);
while(*tag) {
while(*tag != ':') {
void *value = va_arg(args, void*);
if(!kloader_validate_kpointer(value))
return 0;
if(send_rpc_value(&tag, &value) < 0)
if(!send_rpc_value(&tag, &value))
return 0;
}
out_packet_int8(0);
out_packet_finish();
return 1;
}
@ -670,10 +735,10 @@ static int process_kmsg(struct msg_base *umsg)
break;
}
case MESSAGE_TYPE_RPC_REQUEST: {
struct msg_rpc_request *msg = (struct msg_rpc_request *)umsg;
case MESSAGE_TYPE_RPC_SEND_REQUEST: {
struct msg_rpc_send_request *msg = (struct msg_rpc_send_request *)umsg;
if(!send_rpc_request(msg->rpc_num, msg->args)) {
if(!send_rpc_request(msg->service, msg->tag, msg->args)) {
log("Failed to send RPC request");
return 0; // restart session
}