Implement receiving exceptions from RPCs.

This commit is contained in:
whitequark 2015-08-09 16:16:41 +03:00
parent 8b7d38d203
commit 02b1543c63
6 changed files with 177 additions and 74 deletions

View File

@ -146,6 +146,10 @@ class LLVMIRGenerator:
llty = ll.FunctionType(ll.DoubleType(), [ll.DoubleType(), ll.IntType(32)])
elif name == "llvm.copysign.f64":
llty = ll.FunctionType(ll.DoubleType(), [ll.DoubleType(), ll.DoubleType()])
elif name == "llvm.stacksave":
llty = ll.FunctionType(ll.IntType(8).as_pointer(), [])
elif name == "llvm.stackrestore":
llty = ll.FunctionType(ll.VoidType(), [ll.IntType(8).as_pointer()])
elif name == self.target.print_function:
llty = ll.FunctionType(ll.VoidType(), [ll.IntType(8).as_pointer()], var_arg=True)
elif name == "__artiq_personality":
@ -155,8 +159,10 @@ class LLVMIRGenerator:
elif name == "__artiq_reraise":
llty = ll.FunctionType(ll.VoidType(), [])
elif name == "send_rpc":
llty = ll.FunctionType(ll.IntType(32), [ll.IntType(32), ll.IntType(8).as_pointer()],
llty = ll.FunctionType(ll.VoidType(), [ll.IntType(32), ll.IntType(8).as_pointer()],
var_arg=True)
elif name == "recv_rpc":
llty = ll.FunctionType(ll.IntType(32), [ll.IntType(8).as_pointer().as_pointer()])
else:
assert False
@ -559,12 +565,18 @@ class LLVMIRGenerator:
name=insn.name)
return llvalue
# See session.c:send_rpc_value.
def _rpc_tag(self, typ, root_type, root_loc):
def _prepare_closure_call(self, insn):
llclosure, llargs = self.map(insn.target_function()), map(self.map, insn.arguments())
llenv = self.llbuilder.extract_value(llclosure, 0)
llfun = self.llbuilder.extract_value(llclosure, 1)
return llfun, [llenv] + list(llargs)
# See session.c:send_rpc_value and session.c:recv_rpc_value.
def _rpc_tag(self, typ, error_handler):
if types.is_tuple(typ):
assert len(typ.elts) < 256
return b"t" + bytes([len(typ.elts)]) + \
b"".join([self._rpc_tag(elt_type, root_type, root_loc)
b"".join([self._rpc_tag(elt_type, error_handler)
for elt_type in typ.elts])
elif builtins.is_none(typ):
return b"n"
@ -580,38 +592,53 @@ class LLVMIRGenerator:
return b"s"
elif builtins.is_list(typ):
return b"l" + self._rpc_tag(builtins.get_iterable_elt(typ),
root_type, root_loc)
error_handler)
elif builtins.is_range(typ):
return b"r" + self._rpc_tag(builtins.get_iterable_elt(typ),
root_type, root_loc)
error_handler)
elif ir.is_option(typ):
return b"o" + self._rpc_tag(typ.params["inner"],
root_type, root_loc)
error_handler)
else:
error_handler(typ)
def _build_rpc(self, fun_loc, fun_type, args, llnormalblock, llunwindblock):
llservice = ll.Constant(ll.IntType(32), fun_type.service)
tag = b""
for arg in args:
def arg_error_handler(typ):
printer = types.TypePrinter()
note = diagnostic.Diagnostic("note",
"value of type {type}",
{"type": printer.name(root_type)},
root_loc)
{"type": printer.name(typ)},
arg.loc)
diag = diagnostic.Diagnostic("error",
"type {type} is not supported in remote procedure calls",
{"type": printer.name(typ)},
root_loc)
{"type": printer.name(arg.typ)},
arg.loc)
self.engine.process(diag)
tag += self._rpc_tag(arg.type, arg_error_handler)
tag += b":"
def _build_rpc(self, service, args, return_type):
llservice = ll.Constant(ll.IntType(32), service)
tag = b""
for arg in args:
if isinstance(arg, ir.Constant):
# Constants don't have locations, but conveniently
# they also never fail to serialize.
tag += self._rpc_tag(arg.type, arg.type, None)
else:
tag += self._rpc_tag(arg.type, arg.type, arg.loc)
def ret_error_handler(typ):
printer = types.TypePrinter()
note = diagnostic.Diagnostic("note",
"value of type {type}",
{"type": printer.name(typ)},
fun_loc)
diag = diagnostic.Diagnostic("error",
"return type {type} is not supported in remote procedure calls",
{"type": printer.name(fun_type.ret)},
fun_loc)
self.engine.process(diag)
tag += self._rpc_tag(fun_type.ret, ret_error_handler)
tag += b"\x00"
lltag = self.llconst_of_const(ir.Constant(tag, builtins.TStr()))
lltag = self.llconst_of_const(ir.Constant(tag + b"\x00", builtins.TStr()))
llstackptr = self.llbuilder.call(self.llbuiltin("llvm.stacksave"), [])
llargs = []
for arg in args:
@ -620,28 +647,77 @@ class LLVMIRGenerator:
self.llbuilder.store(llarg, llargslot)
llargs.append(llargslot)
return self.llbuiltin("send_rpc"), [llservice, lltag] + llargs
self.llbuilder.call(self.llbuiltin("send_rpc"),
[llservice, lltag] + llargs)
def prepare_call(self, insn):
if types.is_rpc_function(insn.target_function().type):
return self._build_rpc(insn.target_function().type.service,
insn.arguments(),
insn.target_function().type.ret)
# Don't waste stack space on saved arguments.
self.llbuilder.call(self.llbuiltin("llvm.stackrestore"), [llstackptr])
# T result = {
# void *ptr = NULL;
# loop: int size = rpc_recv("tag", ptr);
# if(size) { ptr = alloca(size); goto loop; }
# else *(T*)ptr
# }
llprehead = self.llbuilder.basic_block
llhead = self.llbuilder.append_basic_block(name=llprehead.name + ".rpc.head")
if llunwindblock:
llheadu = self.llbuilder.append_basic_block(name=llprehead.name + ".rpc.head.unwind")
llalloc = self.llbuilder.append_basic_block(name=llprehead.name + ".rpc.alloc")
lltail = self.llbuilder.append_basic_block(name=llprehead.name + ".rpc.tail")
llslot = self.llbuilder.alloca(ll.IntType(8).as_pointer())
self.llbuilder.store(ll.Constant(ll.IntType(8).as_pointer(), None), llslot)
self.llbuilder.branch(llhead)
self.llbuilder.position_at_end(llhead)
if llunwindblock:
llsize = self.llbuilder.invoke(self.llbuiltin("recv_rpc"), [llslot],
llheadu, llunwindblock)
self.llbuilder.position_at_end(llheadu)
else:
llclosure, llargs = self.map(insn.target_function()), map(self.map, insn.arguments())
llenv = self.llbuilder.extract_value(llclosure, 0)
llfun = self.llbuilder.extract_value(llclosure, 1)
return llfun, [llenv] + list(llargs)
llsize = self.llbuilder.call(self.llbuiltin("recv_rpc"), [llslot])
lldone = self.llbuilder.icmp_unsigned('==', llsize, ll.Constant(llsize.type, 0))
self.llbuilder.cbranch(lldone, lltail, llalloc)
self.llbuilder.position_at_end(llalloc)
llalloca = self.llbuilder.alloca(ll.IntType(8), llsize)
self.llbuilder.store(llalloca, llslot)
self.llbuilder.branch(llhead)
self.llbuilder.position_at_end(lltail)
llretty = self.llty_of_type(fun_type.ret, for_return=True)
llretptr = self.llbuilder.bitcast(llslot, llretty.as_pointer())
llret = self.llbuilder.load(llretptr)
if not builtins.is_allocated(fun_type.ret):
# We didn't allocate anything except the slot for the value itself.
# Don't waste stack space.
self.llbuilder.call(self.llbuiltin("llvm.stackrestore"), [llstackptr])
if llnormalblock:
self.llbuilder.branch(llnormalblock)
return llret
def process_Call(self, insn):
llfun, llargs = self.prepare_call(insn)
if types.is_rpc_function(insn.target_function().type):
return self._build_rpc(insn.target_function().loc,
insn.target_function().type,
insn.arguments(),
llnormalblock=None, llunwindblock=None)
else:
llfun, llargs = self._prepare_closure_call(insn)
return self.llbuilder.call(llfun, llargs,
name=insn.name)
def process_Invoke(self, insn):
llfun, llargs = self.prepare_call(insn)
llnormalblock = self.map(insn.normal_target())
llunwindblock = self.map(insn.exception_target())
if types.is_rpc_function(insn.target_function().type):
return self._build_rpc(insn.target_function().loc,
insn.target_function().type,
insn.arguments(),
llnormalblock, llunwindblock)
else:
llfun, llargs = self._prepare_closure_call(insn)
return self.llbuilder.invoke(llfun, llargs, llnormalblock, llunwindblock,
name=insn.name)

View File

@ -8,6 +8,7 @@ from artiq.language import core as core_language
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)
class _H2DMsgType(Enum):
@ -325,13 +326,14 @@ class CommGeneric:
def _serve_rpc(self, rpc_map):
service = self._read_int32()
args = self._receive_rpc_args(rpc_map)
logger.debug("rpc service: %d %r", service, args)
return_tag = self._read_string()
logger.debug("rpc service: %d %r -> %s", service, args, return_tag)
try:
result = rpc_map[service](*args)
if not isinstance(result, int) or not (-2**31 < result < 2**31-1):
raise ValueError("An RPC must return an int(width=32)")
except ARTIQException as exn:
except core_language.ARTIQException as exn:
logger.debug("rpc service: %d %r ! %r", service, args, exn)
self._write_header(_H2DMsgType.RPC_EXCEPTION)
@ -355,7 +357,7 @@ class CommGeneric:
for index in range(3):
self._write_int64(0)
((filename, line, function, _), ) = traceback.extract_tb(exn.__traceback__)
(_, (filename, line, function, _), ) = traceback.extract_tb(exn.__traceback__, 2)
self._write_string(filename)
self._write_int32(line)
self._write_int32(-1) # column not known

View File

@ -93,6 +93,7 @@ static const struct symbol runtime_exports[] = {
{"log", &log},
{"lognonl", &lognonl},
{"send_rpc", &send_rpc},
{"recv_rpc", &recv_rpc},
/* direct syscalls */
{"rtio_get_counter", &rtio_get_counter},
@ -301,7 +302,7 @@ void watchdog_clear(int id)
mailbox_send_and_wait(&request);
}
int send_rpc(int service, const char *tag, ...)
void send_rpc(int service, const char *tag, ...)
{
struct msg_rpc_send request;
@ -311,24 +312,34 @@ int send_rpc(int service, const char *tag, ...)
va_start(request.args, tag);
mailbox_send_and_wait(&request);
va_end(request.args);
}
// struct msg_base *reply;
// reply = mailbox_wait_and_receive();
// if(reply->type == MESSAGE_TYPE_RPC_REPLY) {
// int result = ((struct msg_rpc_reply *)reply)->result;
// mailbox_acknowledge();
// return result;
// } else if(reply->type == MESSAGE_TYPE_RPC_EXCEPTION) {
// struct artiq_exception exception;
// memcpy(&exception, ((struct msg_rpc_exception *)reply)->exception,
// sizeof(struct artiq_exception));
// mailbox_acknowledge();
// __artiq_raise(&exception);
// } else {
// log("Malformed MESSAGE_TYPE_RPC_REQUEST reply type %d",
// reply->type);
int recv_rpc(void **slot) {
struct msg_rpc_recv_request request;
struct msg_rpc_recv_reply *reply;
request.type = MESSAGE_TYPE_RPC_RECV_REQUEST;
request.slot = slot;
mailbox_send_and_wait(&request);
reply = mailbox_wait_and_receive();
if(reply->type != MESSAGE_TYPE_RPC_RECV_REPLY) {
log("Malformed MESSAGE_TYPE_RPC_RECV_REQUEST reply type %d",
reply->type);
while(1);
// }
}
if(reply->exception) {
struct artiq_exception exception;
memcpy(&exception, reply->exception,
sizeof(struct artiq_exception));
mailbox_acknowledge();
__artiq_raise(&exception);
} else {
int alloc_size = reply->alloc_size;
mailbox_acknowledge();
return alloc_size;
}
}
void lognonl(const char *fmt, ...)

View File

@ -5,7 +5,8 @@ long long int now_init(void);
void now_save(long long int now);
int watchdog_set(int ms);
void watchdog_clear(int id);
int send_rpc(int service, const char *tag, ...);
void send_rpc(int service, const char *tag, ...);
int recv_rpc(void **slot);
void lognonl(const char *fmt, ...);
void log(const char *fmt, ...);

View File

@ -17,7 +17,6 @@ enum {
MESSAGE_TYPE_RPC_SEND,
MESSAGE_TYPE_RPC_RECV_REQUEST,
MESSAGE_TYPE_RPC_RECV_REPLY,
MESSAGE_TYPE_RPC_EXCEPTION,
MESSAGE_TYPE_LOG,
MESSAGE_TYPE_BRG_READY,
@ -90,16 +89,12 @@ struct msg_rpc_send {
struct msg_rpc_recv_request {
int type;
// TODO ???
void **slot;
};
struct msg_rpc_recv_reply {
int type;
// TODO ???
};
struct msg_rpc_exception {
int type;
int alloc_size;
struct artiq_exception *exception;
};

View File

@ -476,7 +476,8 @@ static int process_input(void)
// }
case REMOTEMSG_TYPE_RPC_EXCEPTION: {
struct msg_rpc_exception reply;
struct msg_rpc_recv_request *request;
struct msg_rpc_recv_reply reply;
struct artiq_exception exception;
exception.name = in_packet_string();
@ -494,7 +495,15 @@ static int process_input(void)
return 0; // restart session
}
reply.type = MESSAGE_TYPE_RPC_EXCEPTION;
request = mailbox_wait_and_receive();
if(request->type != MESSAGE_TYPE_RPC_RECV_REQUEST) {
log("Expected MESSAGE_TYPE_RPC_RECV_REQUEST, got %d",
request->type);
return 0; // restart session
}
reply.type = MESSAGE_TYPE_RPC_RECV_REPLY;
reply.alloc_size = 0;
reply.exception = &exception;
mailbox_send_and_wait(&reply);
@ -650,15 +659,17 @@ static int send_rpc_request(int service, const char *tag, va_list args)
out_packet_start(REMOTEMSG_TYPE_RPC_REQUEST);
out_packet_int32(service);
while(*tag) {
while(*tag != ':') {
void *value = va_arg(args, void*);
if(!kloader_validate_kpointer(value))
return 0;
if(!send_rpc_value(&tag, &value))
return 0;
}
out_packet_int8(0);
out_packet_string(tag + 1);
out_packet_finish();
return 1;
}
@ -670,6 +681,12 @@ static int process_kmsg(struct msg_base *umsg)
return 0;
if(kloader_is_essential_kmsg(umsg->type))
return 1; /* handled elsewhere */
if(user_kernel_state == USER_KERNEL_WAIT_RPC &&
umsg->type == MESSAGE_TYPE_RPC_RECV_REQUEST) {
// Handled and acknowledged when we receive
// REMOTEMSG_TYPE_RPC_{EXCEPTION,REPLY}.
return 1;
}
if(user_kernel_state != USER_KERNEL_RUNNING) {
log("Received unexpected message from kernel CPU while not in running state");
return 0;
@ -739,7 +756,8 @@ static int process_kmsg(struct msg_base *umsg)
struct msg_rpc_send *msg = (struct msg_rpc_send *)umsg;
if(!send_rpc_request(msg->service, msg->tag, msg->args)) {
log("Failed to send RPC request");
log("Failed to send RPC request (service %d, tag %s)",
msg->service, msg->tag);
return 0; // restart session
}