Implement sending RPCs.

This commit is contained in:
whitequark 2015-08-09 02:17:19 +03:00
parent 22457bc19c
commit b26af5df60
11 changed files with 433 additions and 131 deletions

View File

@ -163,7 +163,7 @@ def is_bool(typ):
def is_int(typ, width=None): def is_int(typ, width=None):
if width is not None: if width is not None:
return types.is_mono(typ, "int", {"width": width}) return types.is_mono(typ, "int", width=width)
else: else:
return types.is_mono(typ, "int") return types.is_mono(typ, "int")
@ -184,7 +184,7 @@ def is_numeric(typ):
def is_list(typ, elt=None): def is_list(typ, elt=None):
if elt is not None: if elt is not None:
return types.is_mono(typ, "list", {"elt": elt}) return types.is_mono(typ, "list", elt=elt)
else: else:
return types.is_mono(typ, "list") return types.is_mono(typ, "list")

View File

@ -5,10 +5,13 @@ the references to the host objects and translates the functions
annotated as ``@kernel`` when they are referenced. annotated as ``@kernel`` when they are referenced.
""" """
import inspect, os import os, re, linecache, inspect
from collections import OrderedDict
from pythonparser import ast, source, diagnostic, parse_buffer from pythonparser import ast, source, diagnostic, parse_buffer
from . import types, builtins, asttyped, prelude from . import types, builtins, asttyped, prelude
from .transforms import ASTTypedRewriter, Inferencer from .transforms import ASTTypedRewriter, Inferencer, IntMonomorphizer
class ASTSynthesizer: class ASTSynthesizer:
@ -45,6 +48,9 @@ class ASTSynthesizer:
typ = builtins.TFloat() typ = builtins.TFloat()
return asttyped.NumT(n=value, ctx=None, type=typ, return asttyped.NumT(n=value, ctx=None, type=typ,
loc=self._add(repr(value))) loc=self._add(repr(value)))
elif isinstance(value, str):
return asttyped.StrT(s=value, ctx=None, type=builtins.TStr(),
loc=self._add(repr(value)))
elif isinstance(value, list): elif isinstance(value, list):
begin_loc = self._add("[") begin_loc = self._add("[")
elts = [] elts = []
@ -123,7 +129,7 @@ class StitchingASTTypedRewriter(ASTTypedRewriter):
if inspect.isfunction(value): if inspect.isfunction(value):
# It's a function. We need to translate the function and insert # It's a function. We need to translate the function and insert
# a reference to it. # a reference to it.
function_name = self.quote_function(value) function_name = self.quote_function(value, node.loc)
return asttyped.NameT(id=function_name, ctx=None, return asttyped.NameT(id=function_name, ctx=None,
type=self.globals[function_name], type=self.globals[function_name],
loc=node.loc) loc=node.loc)
@ -154,7 +160,19 @@ class Stitcher:
self.functions = {} self.functions = {}
self.next_rpc = 0
self.rpc_map = {} self.rpc_map = {}
self.inverse_rpc_map = {}
def _map(self, obj):
obj_id = id(obj)
if obj_id in self.inverse_rpc_map:
return self.inverse_rpc_map[obj_id]
self.next_rpc += 1
self.rpc_map[self.next_rpc] = obj
self.inverse_rpc_map[obj_id] = self.next_rpc
return self.next_rpc
def _iterate(self): def _iterate(self):
inferencer = Inferencer(engine=self.engine) inferencer = Inferencer(engine=self.engine)
@ -213,10 +231,91 @@ class Stitcher:
quote_function=self._quote_function) quote_function=self._quote_function)
return asttyped_rewriter.visit(function_node) return asttyped_rewriter.visit(function_node)
def _quote_function(self, function): def _function_def_note(self, function):
filename = function.__code__.co_filename
line = function.__code__.co_firstlineno
name = function.__code__.co_name
source_line = linecache.getline(filename, line)
column = re.search("def", source_line).start(0)
source_buffer = source.Buffer(source_line, filename, line)
loc = source.Range(source_buffer, column, column)
return diagnostic.Diagnostic("note",
"definition of function '{function}'",
{"function": name},
loc)
def _type_of_param(self, function, loc, param):
if param.default is not inspect.Parameter.empty:
# Try and infer the type from the default value.
# This is tricky, because the default value might not have
# a well-defined type in APython.
# In this case, we bail out, but mention why we do it.
synthesizer = ASTSynthesizer()
ast = synthesizer.quote(param.default)
synthesizer.finalize()
def proxy_diagnostic(diag):
note = diagnostic.Diagnostic("note",
"expanded from here while trying to infer a type for an"
" unannotated optional argument '{param_name}' from its default value",
{"param_name": param.name},
loc)
diag.notes.append(note)
diag.notes.append(self._function_def_note(function))
self.engine.process(diag)
proxy_engine = diagnostic.Engine()
proxy_engine.process = proxy_diagnostic
Inferencer(engine=proxy_engine).visit(ast)
IntMonomorphizer(engine=proxy_engine).visit(ast)
return ast.type
else:
# Let the rest of the program decide.
return types.TVar()
def _quote_rpc_function(self, function, loc):
signature = inspect.signature(function)
arg_types = OrderedDict()
optarg_types = OrderedDict()
for param in signature.parameters.values():
if param.kind not in (inspect.Parameter.POSITIONAL_ONLY,
inspect.Parameter.POSITIONAL_OR_KEYWORD):
# We pretend we don't see *args, kwpostargs=..., **kwargs.
# Since every method can be still invoked without any arguments
# going into *args and the slots after it, this is always safe,
# if sometimes constraining.
#
# Accepting POSITIONAL_ONLY is OK, because the compiler
# desugars the keyword arguments into positional ones internally.
continue
if param.default is inspect.Parameter.empty:
arg_types[param.name] = self._type_of_param(function, loc, param)
else:
optarg_types[param.name] = self._type_of_param(function, loc, param)
# Fixed for now.
ret_type = builtins.TInt(types.TValue(32))
rpc_type = types.TRPCFunction(arg_types, optarg_types, ret_type,
service=self._map(function))
rpc_name = "__rpc_{}__".format(rpc_type.service)
self.globals[rpc_name] = rpc_type
self.functions[function] = rpc_name
return rpc_name
def _quote_function(self, function, loc):
if function in self.functions: if function in self.functions:
return self.functions[function] return self.functions[function]
if hasattr(function, "artiq_embedded"):
# Insert the typed AST for the new function and restart inference. # Insert the typed AST for the new function and restart inference.
# It doesn't really matter where we insert as long as it is before # It doesn't really matter where we insert as long as it is before
# the final call. # the final call.
@ -224,6 +323,10 @@ class Stitcher:
self.typedtree.insert(0, function_node) self.typedtree.insert(0, function_node)
self.inference_finished = False self.inference_finished = False
return function_node.name return function_node.name
else:
# Insert a storage-less global whose type instructs the compiler
# to perform an RPC instead of a regular call.
return self._quote_rpc_function(function, loc)
def stitch_call(self, function, args, kwargs): def stitch_call(self, function, args, kwargs):
function_node = self._quote_embedded_function(function) function_node = self._quote_embedded_function(function)

View File

@ -41,14 +41,16 @@ class Source:
class Module: class Module:
def __init__(self, src): def __init__(self, src):
int_monomorphizer = transforms.IntMonomorphizer(engine=src.engine) self.engine = src.engine
inferencer = transforms.Inferencer(engine=src.engine)
monomorphism_validator = validators.MonomorphismValidator(engine=src.engine) int_monomorphizer = transforms.IntMonomorphizer(engine=self.engine)
escape_validator = validators.EscapeValidator(engine=src.engine) inferencer = transforms.Inferencer(engine=self.engine)
artiq_ir_generator = transforms.ARTIQIRGenerator(engine=src.engine, monomorphism_validator = validators.MonomorphismValidator(engine=self.engine)
escape_validator = validators.EscapeValidator(engine=self.engine)
artiq_ir_generator = transforms.ARTIQIRGenerator(engine=self.engine,
module_name=src.name) module_name=src.name)
dead_code_eliminator = transforms.DeadCodeEliminator(engine=src.engine) dead_code_eliminator = transforms.DeadCodeEliminator(engine=self.engine)
local_access_validator = validators.LocalAccessValidator(engine=src.engine) local_access_validator = validators.LocalAccessValidator(engine=self.engine)
self.name = src.name self.name = src.name
self.globals = src.globals self.globals = src.globals
@ -62,7 +64,8 @@ class Module:
def build_llvm_ir(self, target): def build_llvm_ir(self, target):
"""Compile the module to LLVM IR for the specified target.""" """Compile the module to LLVM IR for the specified target."""
llvm_ir_generator = transforms.LLVMIRGenerator(module_name=self.name, target=target) llvm_ir_generator = transforms.LLVMIRGenerator(engine=self.engine,
module_name=self.name, target=target)
return llvm_ir_generator.process(self.artiq_ir) return llvm_ir_generator.process(self.artiq_ir)
def entry_point(self): def entry_point(self):

View File

@ -3,12 +3,13 @@
into LLVM intermediate representation. into LLVM intermediate representation.
""" """
from pythonparser import ast from pythonparser import ast, diagnostic
from llvmlite_artiq import ir as ll from llvmlite_artiq import ir as ll
from .. import types, builtins, ir from .. import types, builtins, ir
class LLVMIRGenerator: class LLVMIRGenerator:
def __init__(self, module_name, target): def __init__(self, engine, module_name, target):
self.engine = engine
self.target = target self.target = target
self.llcontext = target.llcontext self.llcontext = target.llcontext
self.llmodule = ll.Module(context=self.llcontext, name=module_name) self.llmodule = ll.Module(context=self.llcontext, name=module_name)
@ -21,6 +22,11 @@ class LLVMIRGenerator:
typ = typ.find() typ = typ.find()
if types.is_tuple(typ): if types.is_tuple(typ):
return ll.LiteralStructType([self.llty_of_type(eltty) for eltty in typ.elts]) return ll.LiteralStructType([self.llty_of_type(eltty) for eltty in typ.elts])
elif types.is_rpc_function(typ):
if for_return:
return ll.VoidType()
else:
return ll.LiteralStructType([])
elif types.is_function(typ): elif types.is_function(typ):
envarg = ll.IntType(8).as_pointer() envarg = ll.IntType(8).as_pointer()
llty = ll.FunctionType(args=[envarg] + llty = ll.FunctionType(args=[envarg] +
@ -89,10 +95,13 @@ class LLVMIRGenerator:
return ll.Constant(llty, False) return ll.Constant(llty, False)
elif isinstance(const.value, (int, float)): elif isinstance(const.value, (int, float)):
return ll.Constant(llty, const.value) return ll.Constant(llty, const.value)
elif isinstance(const.value, str): elif isinstance(const.value, (str, bytes)):
if isinstance(const.value, str):
assert "\0" not in const.value assert "\0" not in const.value
as_bytes = (const.value + "\0").encode("utf-8") as_bytes = (const.value + "\0").encode("utf-8")
else:
as_bytes = const.value
if ir.is_exn_typeinfo(const.type): if ir.is_exn_typeinfo(const.type):
# Exception typeinfo; should be merged with identical others # Exception typeinfo; should be merged with identical others
name = "__artiq_exn_" + const.value name = "__artiq_exn_" + const.value
@ -144,6 +153,9 @@ class LLVMIRGenerator:
llty = ll.FunctionType(ll.VoidType(), [self.llty_of_type(builtins.TException())]) llty = ll.FunctionType(ll.VoidType(), [self.llty_of_type(builtins.TException())])
elif name == "__artiq_reraise": elif name == "__artiq_reraise":
llty = ll.FunctionType(ll.VoidType(), []) llty = ll.FunctionType(ll.VoidType(), [])
elif name == "rpc":
llty = ll.FunctionType(ll.IntType(32), [ll.IntType(32), ll.IntType(8).as_pointer()],
var_arg=True)
else: else:
assert False assert False
@ -546,7 +558,75 @@ class LLVMIRGenerator:
name=insn.name) name=insn.name)
return llvalue return llvalue
# See session.c:send_rpc_value.
def _rpc_tag(self, typ, root_type, root_loc):
if types.is_tuple(typ):
assert len(typ.elts) < 256
return b"t" + bytes([len(typ.elts)]) + \
b"".join([self._rpc_tag(elt_type, root_type, root_loc)
for elt_type in typ.elts])
elif builtins.is_none(typ):
return b"n"
elif builtins.is_bool(typ):
return b"b"
elif builtins.is_int(typ, types.TValue(32)):
return b"i"
elif builtins.is_int(typ, types.TValue(64)):
return b"I"
elif builtins.is_float(typ):
return b"f"
elif builtins.is_str(typ):
return b"s"
elif builtins.is_list(typ):
return b"l" + self._rpc_tag(builtins.get_iterable_elt(typ),
root_type, root_loc)
elif builtins.is_range(typ):
return b"r" + self._rpc_tag(builtins.get_iterable_elt(typ),
root_type, root_loc)
elif ir.is_option(typ):
return b"o" + self._rpc_tag(typ.params["inner"],
root_type, root_loc)
else:
printer = types.TypePrinter()
note = diagnostic.Diagnostic("note",
"value of type {type}",
{"type": printer.name(root_type)},
root_loc)
diag = diagnostic.Diagnostic("error",
"type {type} is not supported in remote procedure calls",
{"type": printer.name(typ)},
root_loc)
self.engine.process(diag)
def _build_rpc(self, service, args, return_type):
llservice = ll.Constant(ll.IntType(32), service)
tag = b""
for arg in args:
if isinstance(arg, ir.Constant):
# Constants don't have locations, but conveniently
# they also never fail to serialize.
tag += self._rpc_tag(arg.type, arg.type, None)
else:
tag += self._rpc_tag(arg.type, arg.type, arg.loc)
tag += b":\x00"
lltag = self.llconst_of_const(ir.Constant(tag, builtins.TStr()))
llargs = []
for arg in args:
llarg = self.map(arg)
llargslot = self.llbuilder.alloca(llarg.type)
self.llbuilder.store(llarg, llargslot)
llargs.append(llargslot)
return self.llbuiltin("rpc"), [llservice, lltag] + llargs
def prepare_call(self, insn): def prepare_call(self, insn):
if types.is_rpc_function(insn.target_function().type):
return self._build_rpc(insn.target_function().type.service,
insn.arguments(),
insn.target_function().type.ret)
else:
llclosure, llargs = self.map(insn.target_function()), map(self.map, insn.arguments()) llclosure, llargs = self.map(insn.target_function()), map(self.map, insn.arguments())
llenv = self.llbuilder.extract_value(llclosure, 0) llenv = self.llbuilder.extract_value(llclosure, 0)
llfun = self.llbuilder.extract_value(llclosure, 1) llfun = self.llbuilder.extract_value(llclosure, 1)

View File

@ -222,6 +222,26 @@ class TFunction(Type):
def __ne__(self, other): def __ne__(self, other):
return not (self == other) return not (self == other)
class TRPCFunction(TFunction):
"""
A function type of a remote function.
:ivar service: (int) RPC service number
"""
def __init__(self, args, optargs, ret, service):
super().__init__(args, optargs, ret)
self.service = service
def unify(self, other):
if isinstance(other, TRPCFunction) and \
self.service == other.service:
super().unify(other)
elif isinstance(other, TVar):
other.unify(self)
else:
raise UnificationError(self, other)
class TBuiltin(Type): class TBuiltin(Type):
""" """
An instance of builtin type. Every instance of a builtin An instance of builtin type. Every instance of a builtin
@ -310,6 +330,8 @@ def is_mono(typ, name=None, **params):
typ = typ.find() typ = typ.find()
params_match = True params_match = True
for param in params: for param in params:
if param not in typ.params:
return False
params_match = params_match and \ params_match = params_match and \
typ.params[param].find() == params[param].find() typ.params[param].find() == params[param].find()
return isinstance(typ, TMono) and \ return isinstance(typ, TMono) and \
@ -329,6 +351,9 @@ def is_tuple(typ, elts=None):
def is_function(typ): def is_function(typ):
return isinstance(typ.find(), TFunction) return isinstance(typ.find(), TFunction)
def is_rpc_function(typ):
return isinstance(typ.find(), TRPCFunction)
def is_builtin(typ, name=None): def is_builtin(typ, name=None):
typ = typ.find() typ = typ.find()
if name is None: if name is None:
@ -381,11 +406,16 @@ class TypePrinter(object):
return "(%s,)" % self.name(typ.elts[0]) return "(%s,)" % self.name(typ.elts[0])
else: else:
return "(%s)" % ", ".join(list(map(self.name, typ.elts))) return "(%s)" % ", ".join(list(map(self.name, typ.elts)))
elif isinstance(typ, TFunction): elif isinstance(typ, (TFunction, TRPCFunction)):
args = [] args = []
args += [ "%s:%s" % (arg, self.name(typ.args[arg])) for arg in typ.args] args += [ "%s:%s" % (arg, self.name(typ.args[arg])) for arg in typ.args]
args += ["?%s:%s" % (arg, self.name(typ.optargs[arg])) for arg in typ.optargs] args += ["?%s:%s" % (arg, self.name(typ.optargs[arg])) for arg in typ.optargs]
return "(%s)->%s" % (", ".join(args), self.name(typ.ret)) signature = "(%s)->%s" % (", ".join(args), self.name(typ.ret))
if isinstance(typ, TRPCFunction):
return "rpc({}) {}".format(typ.service, signature)
elif isinstance(typ, TFunction):
return signature
elif isinstance(typ, TBuiltinFunction): elif isinstance(typ, TBuiltinFunction):
return "<function %s>" % typ.name return "<function %s>" % typ.name
elif isinstance(typ, (TConstructor, TExceptionConstructor)): elif isinstance(typ, (TConstructor, TExceptionConstructor)):

View File

@ -276,8 +276,16 @@ class CommGeneric:
self._write_empty(_H2DMsgType.RUN_KERNEL) self._write_empty(_H2DMsgType.RUN_KERNEL)
logger.debug("running kernel") logger.debug("running kernel")
def _receive_rpc_value(self, tag, rpc_map): _rpc_sentinel = object()
if tag == "n":
def _receive_rpc_value(self, rpc_map):
tag = chr(self._read_int8())
if tag == "\x00":
return self._rpc_sentinel
elif tag == "t":
length = self._read_int8()
return tuple(self._receive_rpc_value(rpc_map) for _ in range(length))
elif tag == "n":
return None return None
elif tag == "b": elif tag == "b":
return bool(self._read_int8()) return bool(self._read_int8())
@ -291,31 +299,36 @@ class CommGeneric:
numerator = self._read_int64() numerator = self._read_int64()
denominator = self._read_int64() denominator = self._read_int64()
return Fraction(numerator, denominator) return Fraction(numerator, denominator)
elif tag == "s":
return self._read_string()
elif tag == "l": elif tag == "l":
elt_tag = chr(self._read_int8())
length = self._read_int32() length = self._read_int32()
return [self._receive_rpc_value(elt_tag) for _ in range(length)] return [self._receive_rpc_value(rpc_map) for _ in range(length)]
elif tag == "r":
lower = self._receive_rpc_value(rpc_map)
upper = self._receive_rpc_value(rpc_map)
step = self._receive_rpc_value(rpc_map)
return range(lower, upper, step)
elif tag == "o": elif tag == "o":
return rpc_map[self._read_int32()] return rpc_map[self._read_int32()]
else: else:
raise IOError("Unknown RPC value tag: {}", tag) raise IOError("Unknown RPC value tag: {}".format(repr(tag)))
def _receive_rpc_values(self, rpc_map): def _receive_rpc_args(self, rpc_map):
result = [] args = []
while True: while True:
tag = chr(self._read_int8()) value = self._receive_rpc_value(rpc_map)
if tag == "\x00": if value is self._rpc_sentinel:
return result return args
else: args.append(value)
result.append(self._receive_rpc_value(tag, rpc_map))
def _serve_rpc(self, rpc_map): def _serve_rpc(self, rpc_map):
service = self._read_int32() service = self._read_int32()
args = self._receive_rpc_values(rpc_map) args = self._receive_rpc_args(rpc_map)
logger.debug("rpc service: %d %r", service, args) logger.debug("rpc service: %d %r", service, args)
try: try:
result = rpc_map[rpc_num](args) result = rpc_map[service](*args)
if not isinstance(result, int) or not (-2**31 < result < 2**31-1): if not isinstance(result, int) or not (-2**31 < result < 2**31-1):
raise ValueError("An RPC must return an int(width=32)") raise ValueError("An RPC must return an int(width=32)")
except ARTIQException as exn: except ARTIQException as exn:

View File

@ -50,13 +50,13 @@ class Core:
raise CompileError() from error raise CompileError() from error
def run(self, function, args, kwargs): def run(self, function, args, kwargs):
kernel_library, rpc_map = self.compile(function, args, kwargs)
if self.first_run: if self.first_run:
self.comm.check_ident() self.comm.check_ident()
self.comm.switch_clock(self.external_clock) self.comm.switch_clock(self.external_clock)
self.first_run = False self.first_run = False
kernel_library, rpc_map = self.compile(function, args, kwargs)
try: try:
self.comm.load(kernel_library) self.comm.load(kernel_library)
except Exception as error: except Exception as error:

View File

@ -301,33 +301,34 @@ void watchdog_clear(int id)
mailbox_send_and_wait(&request); mailbox_send_and_wait(&request);
} }
int rpc(int rpc_num, ...) int rpc(int service, const char *tag, ...)
{ {
struct msg_rpc_request request; struct msg_rpc_send_request request;
struct msg_base *reply; struct msg_base *reply;
request.type = MESSAGE_TYPE_RPC_REQUEST; request.type = MESSAGE_TYPE_RPC_SEND_REQUEST;
request.rpc_num = rpc_num; request.service = service;
va_start(request.args, rpc_num); request.tag = tag;
va_start(request.args, tag);
mailbox_send_and_wait(&request); mailbox_send_and_wait(&request);
va_end(request.args); va_end(request.args);
reply = mailbox_wait_and_receive(); reply = mailbox_wait_and_receive();
if(reply->type == MESSAGE_TYPE_RPC_REPLY) { // if(reply->type == MESSAGE_TYPE_RPC_REPLY) {
int result = ((struct msg_rpc_reply *)reply)->result; // int result = ((struct msg_rpc_reply *)reply)->result;
mailbox_acknowledge(); // mailbox_acknowledge();
return result; // return result;
} else if(reply->type == MESSAGE_TYPE_RPC_EXCEPTION) { // } else if(reply->type == MESSAGE_TYPE_RPC_EXCEPTION) {
struct artiq_exception exception; // struct artiq_exception exception;
memcpy(&exception, ((struct msg_rpc_exception *)reply)->exception, // memcpy(&exception, ((struct msg_rpc_exception *)reply)->exception,
sizeof(struct artiq_exception)); // sizeof(struct artiq_exception));
mailbox_acknowledge(); // mailbox_acknowledge();
__artiq_raise(&exception); // __artiq_raise(&exception);
} else { // } else {
log("Malformed MESSAGE_TYPE_RPC_REQUEST reply type %d", log("Malformed MESSAGE_TYPE_RPC_REQUEST reply type %d",
reply->type); reply->type);
while(1); while(1);
} // }
} }
void lognonl(const char *fmt, ...) void lognonl(const char *fmt, ...)

View File

@ -5,7 +5,7 @@ long long int now_init(void);
void now_save(long long int now); void now_save(long long int now);
int watchdog_set(int ms); int watchdog_set(int ms);
void watchdog_clear(int id); void watchdog_clear(int id);
int rpc(int service, ...); int rpc(int service, const char *tag, ...);
void lognonl(const char *fmt, ...); void lognonl(const char *fmt, ...);
void log(const char *fmt, ...); void log(const char *fmt, ...);

View File

@ -14,8 +14,9 @@ enum {
MESSAGE_TYPE_WATCHDOG_SET_REQUEST, MESSAGE_TYPE_WATCHDOG_SET_REQUEST,
MESSAGE_TYPE_WATCHDOG_SET_REPLY, MESSAGE_TYPE_WATCHDOG_SET_REPLY,
MESSAGE_TYPE_WATCHDOG_CLEAR, MESSAGE_TYPE_WATCHDOG_CLEAR,
MESSAGE_TYPE_RPC_REQUEST, MESSAGE_TYPE_RPC_SEND_REQUEST,
MESSAGE_TYPE_RPC_REPLY, MESSAGE_TYPE_RPC_RECV_REQUEST,
MESSAGE_TYPE_RPC_RECV_REPLY,
MESSAGE_TYPE_RPC_EXCEPTION, MESSAGE_TYPE_RPC_EXCEPTION,
MESSAGE_TYPE_LOG, MESSAGE_TYPE_LOG,
@ -80,15 +81,21 @@ struct msg_watchdog_clear {
int id; int id;
}; };
struct msg_rpc_request { struct msg_rpc_send_request {
int type; int type;
int rpc_num; int service;
const char *tag;
va_list args; va_list args;
}; };
struct msg_rpc_reply { struct msg_rpc_recv_request {
int type; int type;
int result; // TODO ???
};
struct msg_rpc_recv_reply {
int type;
// TODO ???
}; };
struct msg_rpc_exception { struct msg_rpc_exception {

View File

@ -457,23 +457,23 @@ static int process_input(void)
user_kernel_state = USER_KERNEL_RUNNING; user_kernel_state = USER_KERNEL_RUNNING;
break; break;
case REMOTEMSG_TYPE_RPC_REPLY: { // case REMOTEMSG_TYPE_RPC_REPLY: {
struct msg_rpc_reply reply; // struct msg_rpc_reply reply;
int result = in_packet_int32(); // int result = in_packet_int32();
if(user_kernel_state != USER_KERNEL_WAIT_RPC) { // if(user_kernel_state != USER_KERNEL_WAIT_RPC) {
log("Unsolicited RPC reply"); // log("Unsolicited RPC reply");
return 0; // restart session // return 0; // restart session
} // }
reply.type = MESSAGE_TYPE_RPC_REPLY; // reply.type = MESSAGE_TYPE_RPC_REPLY;
reply.result = result; // reply.result = result;
mailbox_send_and_wait(&reply); // mailbox_send_and_wait(&reply);
user_kernel_state = USER_KERNEL_RUNNING; // user_kernel_state = USER_KERNEL_RUNNING;
break; // break;
} // }
case REMOTEMSG_TYPE_RPC_EXCEPTION: { case REMOTEMSG_TYPE_RPC_EXCEPTION: {
struct msg_rpc_exception reply; struct msg_rpc_exception reply;
@ -509,91 +509,156 @@ static int process_input(void)
return 1; return 1;
} }
static int send_rpc_value(const char **tag, void *value) // See llvm_ir_generator.py:_rpc_tag.
static int send_rpc_value(const char **tag, void **value)
{ {
if(!out_packet_int8(**tag)) if(!out_packet_int8(**tag))
return -1; return 0;
switch(*(*tag)++) {
case 't': { // tuple
int size = *(*tag)++;
if(!out_packet_int8(size))
return 0;
for(int i = 0; i < size; i++) {
if(!send_rpc_value(tag, value))
return 0;
}
break;
}
int size = 0;
switch(**tag) {
case 0: // last tag
case 'n': // None case 'n': // None
break; break;
case 'b': // bool case 'b': { // bool
size = 1; int size = sizeof(int8_t);
if(!out_packet_chunk(value, size)) if(!out_packet_chunk(*value, size))
return -1; return 0;
*value = (void*)((intptr_t)(*value) + size);
break; break;
}
case 'i': // int(width=32) case 'i': { // int(width=32)
size = 4; int size = sizeof(int32_t);
if(!out_packet_chunk(value, size)) if(!out_packet_chunk(*value, size))
return -1; return 0;
*value = (void*)((intptr_t)(*value) + size);
break; break;
}
case 'I': // int(width=64) case 'I': { // int(width=64)
case 'f': // float int size = sizeof(int64_t);
size = 8; if(!out_packet_chunk(*value, size))
if(!out_packet_chunk(value, size)) return 0;
return -1; *value = (void*)((intptr_t)(*value) + size);
break; break;
}
case 'F': // Fraction case 'f': { // float
size = 16; int size = sizeof(double);
if(!out_packet_chunk(value, size)) if(!out_packet_chunk(*value, size))
return -1; return 0;
*value = (void*)((intptr_t)(*value) + size);
break; break;
}
case 'F': { // Fraction
int size = sizeof(int64_t) * 2;
if(!out_packet_chunk(*value, size))
return 0;
*value = (void*)((intptr_t)(*value) + size);
break;
}
case 's': { // string
const char **string = *value;
if(!out_packet_string(*string))
return 0;
*value = (void*)((intptr_t)(*value) + strlen(*string) + 1);
break;
}
case 'l': { // list(elt='a) case 'l': { // list(elt='a)
struct { uint32_t length; void *elements; } *list = value; struct { uint32_t length; struct {} *elements; } *list = *value;
void *element = list->elements; void *element = list->elements;
const char *tag_copy = *tag + 1; if(!out_packet_int32(list->length))
return 0;
const char *tag_copy;
for(int i = 0; i < list->length; i++) { for(int i = 0; i < list->length; i++) {
int element_size = send_rpc_value(&tag_copy, element); tag_copy = *tag;
if(element_size < 0) if(!send_rpc_value(&tag_copy, &element))
return -1; return 0;
element = (void*)((intptr_t)element + element_size);
} }
*tag = tag_copy; *tag = tag_copy;
size = sizeof(list); *value = (void*)((intptr_t)(*value) + sizeof(*list));
break; break;
} }
case 'o': { // host object case 'r': { // range(elt='a)
struct { uint32_t id; } *object = value; const char *tag_copy;
tag_copy = *tag;
if(!out_packet_int32(object->id)) if(!send_rpc_value(&tag_copy, value)) // min
return -1; return 0;
tag_copy = *tag;
size = sizeof(object); if(!send_rpc_value(&tag_copy, value)) // max
return 0;
tag_copy = *tag;
if(!send_rpc_value(&tag_copy, value)) // step
return 0;
*tag = tag_copy;
break; break;
} }
case 'o': { // option(inner='a)
struct { int8_t present; struct {} contents; } *option = *value;
void *contents = &option->contents;
if(!out_packet_int8(option->present))
return 0;
// option never appears in composite types, so we don't have
// to accurately advance *value.
if(option->present) {
return send_rpc_value(tag, &contents);
} else {
(*tag)++;
break;
}
}
case 'O': { // host object
struct { uint32_t id; } **object = *value;
if(!out_packet_int32((*object)->id))
return 0;
}
default: default:
return -1; log("send_rpc_value: unknown tag %02x", *((*tag) - 1));
return 0;
} }
(*tag)++; return 1;
return size;
} }
static int send_rpc_request(int service, va_list args) static int send_rpc_request(int service, const char *tag, va_list args)
{ {
out_packet_start(REMOTEMSG_TYPE_RPC_REQUEST); out_packet_start(REMOTEMSG_TYPE_RPC_REQUEST);
out_packet_int32(service); out_packet_int32(service);
const char *tag = va_arg(args, const char*); while(*tag != ':') {
while(*tag) {
void *value = va_arg(args, void*); void *value = va_arg(args, void*);
if(!kloader_validate_kpointer(value)) if(!kloader_validate_kpointer(value))
return 0; return 0;
if(send_rpc_value(&tag, &value) < 0) if(!send_rpc_value(&tag, &value))
return 0; return 0;
} }
out_packet_int8(0);
out_packet_finish(); out_packet_finish();
return 1; return 1;
} }
@ -670,10 +735,10 @@ static int process_kmsg(struct msg_base *umsg)
break; break;
} }
case MESSAGE_TYPE_RPC_REQUEST: { case MESSAGE_TYPE_RPC_SEND_REQUEST: {
struct msg_rpc_request *msg = (struct msg_rpc_request *)umsg; struct msg_rpc_send_request *msg = (struct msg_rpc_send_request *)umsg;
if(!send_rpc_request(msg->rpc_num, msg->args)) { if(!send_rpc_request(msg->service, msg->tag, msg->args)) {
log("Failed to send RPC request"); log("Failed to send RPC request");
return 0; // restart session return 0; // restart session
} }