From 02b1543c630132e402c7e5a27029568ae985c5f0 Mon Sep 17 00:00:00 2001 From: whitequark Date: Sun, 9 Aug 2015 16:16:41 +0300 Subject: [PATCH] Implement receiving exceptions from RPCs. --- .../compiler/transforms/llvm_ir_generator.py | 158 +++++++++++++----- artiq/coredevice/comm_generic.py | 8 +- soc/runtime/ksupport.c | 45 +++-- soc/runtime/ksupport.h | 3 +- soc/runtime/messages.h | 9 +- soc/runtime/session.c | 28 +++- 6 files changed, 177 insertions(+), 74 deletions(-) diff --git a/artiq/compiler/transforms/llvm_ir_generator.py b/artiq/compiler/transforms/llvm_ir_generator.py index 6233c103f..9cda1b803 100644 --- a/artiq/compiler/transforms/llvm_ir_generator.py +++ b/artiq/compiler/transforms/llvm_ir_generator.py @@ -146,6 +146,10 @@ class LLVMIRGenerator: llty = ll.FunctionType(ll.DoubleType(), [ll.DoubleType(), ll.IntType(32)]) elif name == "llvm.copysign.f64": llty = ll.FunctionType(ll.DoubleType(), [ll.DoubleType(), ll.DoubleType()]) + elif name == "llvm.stacksave": + llty = ll.FunctionType(ll.IntType(8).as_pointer(), []) + elif name == "llvm.stackrestore": + llty = ll.FunctionType(ll.VoidType(), [ll.IntType(8).as_pointer()]) elif name == self.target.print_function: llty = ll.FunctionType(ll.VoidType(), [ll.IntType(8).as_pointer()], var_arg=True) elif name == "__artiq_personality": @@ -155,8 +159,10 @@ class LLVMIRGenerator: elif name == "__artiq_reraise": llty = ll.FunctionType(ll.VoidType(), []) elif name == "send_rpc": - llty = ll.FunctionType(ll.IntType(32), [ll.IntType(32), ll.IntType(8).as_pointer()], + llty = ll.FunctionType(ll.VoidType(), [ll.IntType(32), ll.IntType(8).as_pointer()], var_arg=True) + elif name == "recv_rpc": + llty = ll.FunctionType(ll.IntType(32), [ll.IntType(8).as_pointer().as_pointer()]) else: assert False @@ -559,12 +565,18 @@ class LLVMIRGenerator: name=insn.name) return llvalue - # See session.c:send_rpc_value. - def _rpc_tag(self, typ, root_type, root_loc): + def _prepare_closure_call(self, insn): + llclosure, llargs = self.map(insn.target_function()), map(self.map, insn.arguments()) + llenv = self.llbuilder.extract_value(llclosure, 0) + llfun = self.llbuilder.extract_value(llclosure, 1) + return llfun, [llenv] + list(llargs) + + # See session.c:send_rpc_value and session.c:recv_rpc_value. + def _rpc_tag(self, typ, error_handler): if types.is_tuple(typ): assert len(typ.elts) < 256 return b"t" + bytes([len(typ.elts)]) + \ - b"".join([self._rpc_tag(elt_type, root_type, root_loc) + b"".join([self._rpc_tag(elt_type, error_handler) for elt_type in typ.elts]) elif builtins.is_none(typ): return b"n" @@ -580,38 +592,53 @@ class LLVMIRGenerator: return b"s" elif builtins.is_list(typ): return b"l" + self._rpc_tag(builtins.get_iterable_elt(typ), - root_type, root_loc) + error_handler) elif builtins.is_range(typ): return b"r" + self._rpc_tag(builtins.get_iterable_elt(typ), - root_type, root_loc) + error_handler) elif ir.is_option(typ): return b"o" + self._rpc_tag(typ.params["inner"], - root_type, root_loc) + error_handler) else: + error_handler(typ) + + def _build_rpc(self, fun_loc, fun_type, args, llnormalblock, llunwindblock): + llservice = ll.Constant(ll.IntType(32), fun_type.service) + + tag = b"" + + for arg in args: + def arg_error_handler(typ): + printer = types.TypePrinter() + note = diagnostic.Diagnostic("note", + "value of type {type}", + {"type": printer.name(typ)}, + arg.loc) + diag = diagnostic.Diagnostic("error", + "type {type} is not supported in remote procedure calls", + {"type": printer.name(arg.typ)}, + arg.loc) + self.engine.process(diag) + tag += self._rpc_tag(arg.type, arg_error_handler) + tag += b":" + + def ret_error_handler(typ): printer = types.TypePrinter() note = diagnostic.Diagnostic("note", "value of type {type}", - {"type": printer.name(root_type)}, - root_loc) - diag = diagnostic.Diagnostic("error", - "type {type} is not supported in remote procedure calls", {"type": printer.name(typ)}, - root_loc) + fun_loc) + diag = diagnostic.Diagnostic("error", + "return type {type} is not supported in remote procedure calls", + {"type": printer.name(fun_type.ret)}, + fun_loc) self.engine.process(diag) - - def _build_rpc(self, service, args, return_type): - llservice = ll.Constant(ll.IntType(32), service) - - tag = b"" - for arg in args: - if isinstance(arg, ir.Constant): - # Constants don't have locations, but conveniently - # they also never fail to serialize. - tag += self._rpc_tag(arg.type, arg.type, None) - else: - tag += self._rpc_tag(arg.type, arg.type, arg.loc) + tag += self._rpc_tag(fun_type.ret, ret_error_handler) tag += b"\x00" - lltag = self.llconst_of_const(ir.Constant(tag, builtins.TStr())) + + lltag = self.llconst_of_const(ir.Constant(tag + b"\x00", builtins.TStr())) + + llstackptr = self.llbuilder.call(self.llbuiltin("llvm.stacksave"), []) llargs = [] for arg in args: @@ -620,30 +647,79 @@ class LLVMIRGenerator: self.llbuilder.store(llarg, llargslot) llargs.append(llargslot) - return self.llbuiltin("send_rpc"), [llservice, lltag] + llargs + self.llbuilder.call(self.llbuiltin("send_rpc"), + [llservice, lltag] + llargs) - def prepare_call(self, insn): - if types.is_rpc_function(insn.target_function().type): - return self._build_rpc(insn.target_function().type.service, - insn.arguments(), - insn.target_function().type.ret) + # Don't waste stack space on saved arguments. + self.llbuilder.call(self.llbuiltin("llvm.stackrestore"), [llstackptr]) + + # T result = { + # void *ptr = NULL; + # loop: int size = rpc_recv("tag", ptr); + # if(size) { ptr = alloca(size); goto loop; } + # else *(T*)ptr + # } + llprehead = self.llbuilder.basic_block + llhead = self.llbuilder.append_basic_block(name=llprehead.name + ".rpc.head") + if llunwindblock: + llheadu = self.llbuilder.append_basic_block(name=llprehead.name + ".rpc.head.unwind") + llalloc = self.llbuilder.append_basic_block(name=llprehead.name + ".rpc.alloc") + lltail = self.llbuilder.append_basic_block(name=llprehead.name + ".rpc.tail") + + llslot = self.llbuilder.alloca(ll.IntType(8).as_pointer()) + self.llbuilder.store(ll.Constant(ll.IntType(8).as_pointer(), None), llslot) + self.llbuilder.branch(llhead) + + self.llbuilder.position_at_end(llhead) + if llunwindblock: + llsize = self.llbuilder.invoke(self.llbuiltin("recv_rpc"), [llslot], + llheadu, llunwindblock) + self.llbuilder.position_at_end(llheadu) else: - llclosure, llargs = self.map(insn.target_function()), map(self.map, insn.arguments()) - llenv = self.llbuilder.extract_value(llclosure, 0) - llfun = self.llbuilder.extract_value(llclosure, 1) - return llfun, [llenv] + list(llargs) + llsize = self.llbuilder.call(self.llbuiltin("recv_rpc"), [llslot]) + lldone = self.llbuilder.icmp_unsigned('==', llsize, ll.Constant(llsize.type, 0)) + self.llbuilder.cbranch(lldone, lltail, llalloc) + + self.llbuilder.position_at_end(llalloc) + llalloca = self.llbuilder.alloca(ll.IntType(8), llsize) + self.llbuilder.store(llalloca, llslot) + self.llbuilder.branch(llhead) + + self.llbuilder.position_at_end(lltail) + llretty = self.llty_of_type(fun_type.ret, for_return=True) + llretptr = self.llbuilder.bitcast(llslot, llretty.as_pointer()) + llret = self.llbuilder.load(llretptr) + if not builtins.is_allocated(fun_type.ret): + # We didn't allocate anything except the slot for the value itself. + # Don't waste stack space. + self.llbuilder.call(self.llbuiltin("llvm.stackrestore"), [llstackptr]) + if llnormalblock: + self.llbuilder.branch(llnormalblock) + return llret def process_Call(self, insn): - llfun, llargs = self.prepare_call(insn) - return self.llbuilder.call(llfun, llargs, - name=insn.name) + if types.is_rpc_function(insn.target_function().type): + return self._build_rpc(insn.target_function().loc, + insn.target_function().type, + insn.arguments(), + llnormalblock=None, llunwindblock=None) + else: + llfun, llargs = self._prepare_closure_call(insn) + return self.llbuilder.call(llfun, llargs, + name=insn.name) def process_Invoke(self, insn): - llfun, llargs = self.prepare_call(insn) llnormalblock = self.map(insn.normal_target()) llunwindblock = self.map(insn.exception_target()) - return self.llbuilder.invoke(llfun, llargs, llnormalblock, llunwindblock, - name=insn.name) + if types.is_rpc_function(insn.target_function().type): + return self._build_rpc(insn.target_function().loc, + insn.target_function().type, + insn.arguments(), + llnormalblock, llunwindblock) + else: + llfun, llargs = self._prepare_closure_call(insn) + return self.llbuilder.invoke(llfun, llargs, llnormalblock, llunwindblock, + name=insn.name) def process_Select(self, insn): return self.llbuilder.select(self.map(insn.condition()), diff --git a/artiq/coredevice/comm_generic.py b/artiq/coredevice/comm_generic.py index 54246d525..267708afe 100644 --- a/artiq/coredevice/comm_generic.py +++ b/artiq/coredevice/comm_generic.py @@ -8,6 +8,7 @@ from artiq.language import core as core_language logger = logging.getLogger(__name__) +logger.setLevel(logging.DEBUG) class _H2DMsgType(Enum): @@ -325,13 +326,14 @@ class CommGeneric: def _serve_rpc(self, rpc_map): service = self._read_int32() args = self._receive_rpc_args(rpc_map) - logger.debug("rpc service: %d %r", service, args) + return_tag = self._read_string() + logger.debug("rpc service: %d %r -> %s", service, args, return_tag) try: result = rpc_map[service](*args) if not isinstance(result, int) or not (-2**31 < result < 2**31-1): raise ValueError("An RPC must return an int(width=32)") - except ARTIQException as exn: + except core_language.ARTIQException as exn: logger.debug("rpc service: %d %r ! %r", service, args, exn) self._write_header(_H2DMsgType.RPC_EXCEPTION) @@ -355,7 +357,7 @@ class CommGeneric: for index in range(3): self._write_int64(0) - ((filename, line, function, _), ) = traceback.extract_tb(exn.__traceback__) + (_, (filename, line, function, _), ) = traceback.extract_tb(exn.__traceback__, 2) self._write_string(filename) self._write_int32(line) self._write_int32(-1) # column not known diff --git a/soc/runtime/ksupport.c b/soc/runtime/ksupport.c index 187e33220..7a681bc59 100644 --- a/soc/runtime/ksupport.c +++ b/soc/runtime/ksupport.c @@ -93,6 +93,7 @@ static const struct symbol runtime_exports[] = { {"log", &log}, {"lognonl", &lognonl}, {"send_rpc", &send_rpc}, + {"recv_rpc", &recv_rpc}, /* direct syscalls */ {"rtio_get_counter", &rtio_get_counter}, @@ -301,7 +302,7 @@ void watchdog_clear(int id) mailbox_send_and_wait(&request); } -int send_rpc(int service, const char *tag, ...) +void send_rpc(int service, const char *tag, ...) { struct msg_rpc_send request; @@ -311,24 +312,34 @@ int send_rpc(int service, const char *tag, ...) va_start(request.args, tag); mailbox_send_and_wait(&request); va_end(request.args); +} - // struct msg_base *reply; - // reply = mailbox_wait_and_receive(); - // if(reply->type == MESSAGE_TYPE_RPC_REPLY) { - // int result = ((struct msg_rpc_reply *)reply)->result; - // mailbox_acknowledge(); - // return result; - // } else if(reply->type == MESSAGE_TYPE_RPC_EXCEPTION) { - // struct artiq_exception exception; - // memcpy(&exception, ((struct msg_rpc_exception *)reply)->exception, - // sizeof(struct artiq_exception)); - // mailbox_acknowledge(); - // __artiq_raise(&exception); - // } else { - // log("Malformed MESSAGE_TYPE_RPC_REQUEST reply type %d", - // reply->type); +int recv_rpc(void **slot) { + struct msg_rpc_recv_request request; + struct msg_rpc_recv_reply *reply; + + request.type = MESSAGE_TYPE_RPC_RECV_REQUEST; + request.slot = slot; + mailbox_send_and_wait(&request); + + reply = mailbox_wait_and_receive(); + if(reply->type != MESSAGE_TYPE_RPC_RECV_REPLY) { + log("Malformed MESSAGE_TYPE_RPC_RECV_REQUEST reply type %d", + reply->type); while(1); - // } + } + + if(reply->exception) { + struct artiq_exception exception; + memcpy(&exception, reply->exception, + sizeof(struct artiq_exception)); + mailbox_acknowledge(); + __artiq_raise(&exception); + } else { + int alloc_size = reply->alloc_size; + mailbox_acknowledge(); + return alloc_size; + } } void lognonl(const char *fmt, ...) diff --git a/soc/runtime/ksupport.h b/soc/runtime/ksupport.h index 2aa83ce63..2171324a4 100644 --- a/soc/runtime/ksupport.h +++ b/soc/runtime/ksupport.h @@ -5,7 +5,8 @@ long long int now_init(void); void now_save(long long int now); int watchdog_set(int ms); void watchdog_clear(int id); -int send_rpc(int service, const char *tag, ...); +void send_rpc(int service, const char *tag, ...); +int recv_rpc(void **slot); void lognonl(const char *fmt, ...); void log(const char *fmt, ...); diff --git a/soc/runtime/messages.h b/soc/runtime/messages.h index 46294f8e6..533f296ca 100644 --- a/soc/runtime/messages.h +++ b/soc/runtime/messages.h @@ -17,7 +17,6 @@ enum { MESSAGE_TYPE_RPC_SEND, MESSAGE_TYPE_RPC_RECV_REQUEST, MESSAGE_TYPE_RPC_RECV_REPLY, - MESSAGE_TYPE_RPC_EXCEPTION, MESSAGE_TYPE_LOG, MESSAGE_TYPE_BRG_READY, @@ -90,16 +89,12 @@ struct msg_rpc_send { struct msg_rpc_recv_request { int type; - // TODO ??? + void **slot; }; struct msg_rpc_recv_reply { int type; - // TODO ??? -}; - -struct msg_rpc_exception { - int type; + int alloc_size; struct artiq_exception *exception; }; diff --git a/soc/runtime/session.c b/soc/runtime/session.c index f5e846477..84e04265c 100644 --- a/soc/runtime/session.c +++ b/soc/runtime/session.c @@ -476,7 +476,8 @@ static int process_input(void) // } case REMOTEMSG_TYPE_RPC_EXCEPTION: { - struct msg_rpc_exception reply; + struct msg_rpc_recv_request *request; + struct msg_rpc_recv_reply reply; struct artiq_exception exception; exception.name = in_packet_string(); @@ -494,7 +495,15 @@ static int process_input(void) return 0; // restart session } - reply.type = MESSAGE_TYPE_RPC_EXCEPTION; + request = mailbox_wait_and_receive(); + if(request->type != MESSAGE_TYPE_RPC_RECV_REQUEST) { + log("Expected MESSAGE_TYPE_RPC_RECV_REQUEST, got %d", + request->type); + return 0; // restart session + } + + reply.type = MESSAGE_TYPE_RPC_RECV_REPLY; + reply.alloc_size = 0; reply.exception = &exception; mailbox_send_and_wait(&reply); @@ -650,15 +659,17 @@ static int send_rpc_request(int service, const char *tag, va_list args) out_packet_start(REMOTEMSG_TYPE_RPC_REQUEST); out_packet_int32(service); - while(*tag) { + while(*tag != ':') { void *value = va_arg(args, void*); if(!kloader_validate_kpointer(value)) return 0; if(!send_rpc_value(&tag, &value)) return 0; } - out_packet_int8(0); + + out_packet_string(tag + 1); + out_packet_finish(); return 1; } @@ -670,6 +681,12 @@ static int process_kmsg(struct msg_base *umsg) return 0; if(kloader_is_essential_kmsg(umsg->type)) return 1; /* handled elsewhere */ + if(user_kernel_state == USER_KERNEL_WAIT_RPC && + umsg->type == MESSAGE_TYPE_RPC_RECV_REQUEST) { + // Handled and acknowledged when we receive + // REMOTEMSG_TYPE_RPC_{EXCEPTION,REPLY}. + return 1; + } if(user_kernel_state != USER_KERNEL_RUNNING) { log("Received unexpected message from kernel CPU while not in running state"); return 0; @@ -739,7 +756,8 @@ static int process_kmsg(struct msg_base *umsg) struct msg_rpc_send *msg = (struct msg_rpc_send *)umsg; if(!send_rpc_request(msg->service, msg->tag, msg->args)) { - log("Failed to send RPC request"); + log("Failed to send RPC request (service %d, tag %s)", + msg->service, msg->tag); return 0; // restart session }