Implement receiving exceptions from RPCs.

This commit is contained in:
whitequark 2015-08-09 16:16:41 +03:00
parent 8b7d38d203
commit 02b1543c63
6 changed files with 177 additions and 74 deletions

View File

@ -146,6 +146,10 @@ class LLVMIRGenerator:
llty = ll.FunctionType(ll.DoubleType(), [ll.DoubleType(), ll.IntType(32)]) llty = ll.FunctionType(ll.DoubleType(), [ll.DoubleType(), ll.IntType(32)])
elif name == "llvm.copysign.f64": elif name == "llvm.copysign.f64":
llty = ll.FunctionType(ll.DoubleType(), [ll.DoubleType(), ll.DoubleType()]) llty = ll.FunctionType(ll.DoubleType(), [ll.DoubleType(), ll.DoubleType()])
elif name == "llvm.stacksave":
llty = ll.FunctionType(ll.IntType(8).as_pointer(), [])
elif name == "llvm.stackrestore":
llty = ll.FunctionType(ll.VoidType(), [ll.IntType(8).as_pointer()])
elif name == self.target.print_function: elif name == self.target.print_function:
llty = ll.FunctionType(ll.VoidType(), [ll.IntType(8).as_pointer()], var_arg=True) llty = ll.FunctionType(ll.VoidType(), [ll.IntType(8).as_pointer()], var_arg=True)
elif name == "__artiq_personality": elif name == "__artiq_personality":
@ -155,8 +159,10 @@ class LLVMIRGenerator:
elif name == "__artiq_reraise": elif name == "__artiq_reraise":
llty = ll.FunctionType(ll.VoidType(), []) llty = ll.FunctionType(ll.VoidType(), [])
elif name == "send_rpc": elif name == "send_rpc":
llty = ll.FunctionType(ll.IntType(32), [ll.IntType(32), ll.IntType(8).as_pointer()], llty = ll.FunctionType(ll.VoidType(), [ll.IntType(32), ll.IntType(8).as_pointer()],
var_arg=True) var_arg=True)
elif name == "recv_rpc":
llty = ll.FunctionType(ll.IntType(32), [ll.IntType(8).as_pointer().as_pointer()])
else: else:
assert False assert False
@ -559,12 +565,18 @@ class LLVMIRGenerator:
name=insn.name) name=insn.name)
return llvalue return llvalue
# See session.c:send_rpc_value. def _prepare_closure_call(self, insn):
def _rpc_tag(self, typ, root_type, root_loc): llclosure, llargs = self.map(insn.target_function()), map(self.map, insn.arguments())
llenv = self.llbuilder.extract_value(llclosure, 0)
llfun = self.llbuilder.extract_value(llclosure, 1)
return llfun, [llenv] + list(llargs)
# See session.c:send_rpc_value and session.c:recv_rpc_value.
def _rpc_tag(self, typ, error_handler):
if types.is_tuple(typ): if types.is_tuple(typ):
assert len(typ.elts) < 256 assert len(typ.elts) < 256
return b"t" + bytes([len(typ.elts)]) + \ return b"t" + bytes([len(typ.elts)]) + \
b"".join([self._rpc_tag(elt_type, root_type, root_loc) b"".join([self._rpc_tag(elt_type, error_handler)
for elt_type in typ.elts]) for elt_type in typ.elts])
elif builtins.is_none(typ): elif builtins.is_none(typ):
return b"n" return b"n"
@ -580,38 +592,53 @@ class LLVMIRGenerator:
return b"s" return b"s"
elif builtins.is_list(typ): elif builtins.is_list(typ):
return b"l" + self._rpc_tag(builtins.get_iterable_elt(typ), return b"l" + self._rpc_tag(builtins.get_iterable_elt(typ),
root_type, root_loc) error_handler)
elif builtins.is_range(typ): elif builtins.is_range(typ):
return b"r" + self._rpc_tag(builtins.get_iterable_elt(typ), return b"r" + self._rpc_tag(builtins.get_iterable_elt(typ),
root_type, root_loc) error_handler)
elif ir.is_option(typ): elif ir.is_option(typ):
return b"o" + self._rpc_tag(typ.params["inner"], return b"o" + self._rpc_tag(typ.params["inner"],
root_type, root_loc) error_handler)
else: else:
error_handler(typ)
def _build_rpc(self, fun_loc, fun_type, args, llnormalblock, llunwindblock):
llservice = ll.Constant(ll.IntType(32), fun_type.service)
tag = b""
for arg in args:
def arg_error_handler(typ):
printer = types.TypePrinter()
note = diagnostic.Diagnostic("note",
"value of type {type}",
{"type": printer.name(typ)},
arg.loc)
diag = diagnostic.Diagnostic("error",
"type {type} is not supported in remote procedure calls",
{"type": printer.name(arg.typ)},
arg.loc)
self.engine.process(diag)
tag += self._rpc_tag(arg.type, arg_error_handler)
tag += b":"
def ret_error_handler(typ):
printer = types.TypePrinter() printer = types.TypePrinter()
note = diagnostic.Diagnostic("note", note = diagnostic.Diagnostic("note",
"value of type {type}", "value of type {type}",
{"type": printer.name(root_type)},
root_loc)
diag = diagnostic.Diagnostic("error",
"type {type} is not supported in remote procedure calls",
{"type": printer.name(typ)}, {"type": printer.name(typ)},
root_loc) fun_loc)
diag = diagnostic.Diagnostic("error",
"return type {type} is not supported in remote procedure calls",
{"type": printer.name(fun_type.ret)},
fun_loc)
self.engine.process(diag) self.engine.process(diag)
tag += self._rpc_tag(fun_type.ret, ret_error_handler)
def _build_rpc(self, service, args, return_type):
llservice = ll.Constant(ll.IntType(32), service)
tag = b""
for arg in args:
if isinstance(arg, ir.Constant):
# Constants don't have locations, but conveniently
# they also never fail to serialize.
tag += self._rpc_tag(arg.type, arg.type, None)
else:
tag += self._rpc_tag(arg.type, arg.type, arg.loc)
tag += b"\x00" tag += b"\x00"
lltag = self.llconst_of_const(ir.Constant(tag, builtins.TStr()))
lltag = self.llconst_of_const(ir.Constant(tag + b"\x00", builtins.TStr()))
llstackptr = self.llbuilder.call(self.llbuiltin("llvm.stacksave"), [])
llargs = [] llargs = []
for arg in args: for arg in args:
@ -620,30 +647,79 @@ class LLVMIRGenerator:
self.llbuilder.store(llarg, llargslot) self.llbuilder.store(llarg, llargslot)
llargs.append(llargslot) llargs.append(llargslot)
return self.llbuiltin("send_rpc"), [llservice, lltag] + llargs self.llbuilder.call(self.llbuiltin("send_rpc"),
[llservice, lltag] + llargs)
def prepare_call(self, insn): # Don't waste stack space on saved arguments.
if types.is_rpc_function(insn.target_function().type): self.llbuilder.call(self.llbuiltin("llvm.stackrestore"), [llstackptr])
return self._build_rpc(insn.target_function().type.service,
insn.arguments(), # T result = {
insn.target_function().type.ret) # void *ptr = NULL;
# loop: int size = rpc_recv("tag", ptr);
# if(size) { ptr = alloca(size); goto loop; }
# else *(T*)ptr
# }
llprehead = self.llbuilder.basic_block
llhead = self.llbuilder.append_basic_block(name=llprehead.name + ".rpc.head")
if llunwindblock:
llheadu = self.llbuilder.append_basic_block(name=llprehead.name + ".rpc.head.unwind")
llalloc = self.llbuilder.append_basic_block(name=llprehead.name + ".rpc.alloc")
lltail = self.llbuilder.append_basic_block(name=llprehead.name + ".rpc.tail")
llslot = self.llbuilder.alloca(ll.IntType(8).as_pointer())
self.llbuilder.store(ll.Constant(ll.IntType(8).as_pointer(), None), llslot)
self.llbuilder.branch(llhead)
self.llbuilder.position_at_end(llhead)
if llunwindblock:
llsize = self.llbuilder.invoke(self.llbuiltin("recv_rpc"), [llslot],
llheadu, llunwindblock)
self.llbuilder.position_at_end(llheadu)
else: else:
llclosure, llargs = self.map(insn.target_function()), map(self.map, insn.arguments()) llsize = self.llbuilder.call(self.llbuiltin("recv_rpc"), [llslot])
llenv = self.llbuilder.extract_value(llclosure, 0) lldone = self.llbuilder.icmp_unsigned('==', llsize, ll.Constant(llsize.type, 0))
llfun = self.llbuilder.extract_value(llclosure, 1) self.llbuilder.cbranch(lldone, lltail, llalloc)
return llfun, [llenv] + list(llargs)
self.llbuilder.position_at_end(llalloc)
llalloca = self.llbuilder.alloca(ll.IntType(8), llsize)
self.llbuilder.store(llalloca, llslot)
self.llbuilder.branch(llhead)
self.llbuilder.position_at_end(lltail)
llretty = self.llty_of_type(fun_type.ret, for_return=True)
llretptr = self.llbuilder.bitcast(llslot, llretty.as_pointer())
llret = self.llbuilder.load(llretptr)
if not builtins.is_allocated(fun_type.ret):
# We didn't allocate anything except the slot for the value itself.
# Don't waste stack space.
self.llbuilder.call(self.llbuiltin("llvm.stackrestore"), [llstackptr])
if llnormalblock:
self.llbuilder.branch(llnormalblock)
return llret
def process_Call(self, insn): def process_Call(self, insn):
llfun, llargs = self.prepare_call(insn) if types.is_rpc_function(insn.target_function().type):
return self.llbuilder.call(llfun, llargs, return self._build_rpc(insn.target_function().loc,
name=insn.name) insn.target_function().type,
insn.arguments(),
llnormalblock=None, llunwindblock=None)
else:
llfun, llargs = self._prepare_closure_call(insn)
return self.llbuilder.call(llfun, llargs,
name=insn.name)
def process_Invoke(self, insn): def process_Invoke(self, insn):
llfun, llargs = self.prepare_call(insn)
llnormalblock = self.map(insn.normal_target()) llnormalblock = self.map(insn.normal_target())
llunwindblock = self.map(insn.exception_target()) llunwindblock = self.map(insn.exception_target())
return self.llbuilder.invoke(llfun, llargs, llnormalblock, llunwindblock, if types.is_rpc_function(insn.target_function().type):
name=insn.name) return self._build_rpc(insn.target_function().loc,
insn.target_function().type,
insn.arguments(),
llnormalblock, llunwindblock)
else:
llfun, llargs = self._prepare_closure_call(insn)
return self.llbuilder.invoke(llfun, llargs, llnormalblock, llunwindblock,
name=insn.name)
def process_Select(self, insn): def process_Select(self, insn):
return self.llbuilder.select(self.map(insn.condition()), return self.llbuilder.select(self.map(insn.condition()),

View File

@ -8,6 +8,7 @@ from artiq.language import core as core_language
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)
class _H2DMsgType(Enum): class _H2DMsgType(Enum):
@ -325,13 +326,14 @@ class CommGeneric:
def _serve_rpc(self, rpc_map): def _serve_rpc(self, rpc_map):
service = self._read_int32() service = self._read_int32()
args = self._receive_rpc_args(rpc_map) args = self._receive_rpc_args(rpc_map)
logger.debug("rpc service: %d %r", service, args) return_tag = self._read_string()
logger.debug("rpc service: %d %r -> %s", service, args, return_tag)
try: try:
result = rpc_map[service](*args) result = rpc_map[service](*args)
if not isinstance(result, int) or not (-2**31 < result < 2**31-1): if not isinstance(result, int) or not (-2**31 < result < 2**31-1):
raise ValueError("An RPC must return an int(width=32)") raise ValueError("An RPC must return an int(width=32)")
except ARTIQException as exn: except core_language.ARTIQException as exn:
logger.debug("rpc service: %d %r ! %r", service, args, exn) logger.debug("rpc service: %d %r ! %r", service, args, exn)
self._write_header(_H2DMsgType.RPC_EXCEPTION) self._write_header(_H2DMsgType.RPC_EXCEPTION)
@ -355,7 +357,7 @@ class CommGeneric:
for index in range(3): for index in range(3):
self._write_int64(0) self._write_int64(0)
((filename, line, function, _), ) = traceback.extract_tb(exn.__traceback__) (_, (filename, line, function, _), ) = traceback.extract_tb(exn.__traceback__, 2)
self._write_string(filename) self._write_string(filename)
self._write_int32(line) self._write_int32(line)
self._write_int32(-1) # column not known self._write_int32(-1) # column not known

View File

@ -93,6 +93,7 @@ static const struct symbol runtime_exports[] = {
{"log", &log}, {"log", &log},
{"lognonl", &lognonl}, {"lognonl", &lognonl},
{"send_rpc", &send_rpc}, {"send_rpc", &send_rpc},
{"recv_rpc", &recv_rpc},
/* direct syscalls */ /* direct syscalls */
{"rtio_get_counter", &rtio_get_counter}, {"rtio_get_counter", &rtio_get_counter},
@ -301,7 +302,7 @@ void watchdog_clear(int id)
mailbox_send_and_wait(&request); mailbox_send_and_wait(&request);
} }
int send_rpc(int service, const char *tag, ...) void send_rpc(int service, const char *tag, ...)
{ {
struct msg_rpc_send request; struct msg_rpc_send request;
@ -311,24 +312,34 @@ int send_rpc(int service, const char *tag, ...)
va_start(request.args, tag); va_start(request.args, tag);
mailbox_send_and_wait(&request); mailbox_send_and_wait(&request);
va_end(request.args); va_end(request.args);
}
// struct msg_base *reply; int recv_rpc(void **slot) {
// reply = mailbox_wait_and_receive(); struct msg_rpc_recv_request request;
// if(reply->type == MESSAGE_TYPE_RPC_REPLY) { struct msg_rpc_recv_reply *reply;
// int result = ((struct msg_rpc_reply *)reply)->result;
// mailbox_acknowledge(); request.type = MESSAGE_TYPE_RPC_RECV_REQUEST;
// return result; request.slot = slot;
// } else if(reply->type == MESSAGE_TYPE_RPC_EXCEPTION) { mailbox_send_and_wait(&request);
// struct artiq_exception exception;
// memcpy(&exception, ((struct msg_rpc_exception *)reply)->exception, reply = mailbox_wait_and_receive();
// sizeof(struct artiq_exception)); if(reply->type != MESSAGE_TYPE_RPC_RECV_REPLY) {
// mailbox_acknowledge(); log("Malformed MESSAGE_TYPE_RPC_RECV_REQUEST reply type %d",
// __artiq_raise(&exception); reply->type);
// } else {
// log("Malformed MESSAGE_TYPE_RPC_REQUEST reply type %d",
// reply->type);
while(1); while(1);
// } }
if(reply->exception) {
struct artiq_exception exception;
memcpy(&exception, reply->exception,
sizeof(struct artiq_exception));
mailbox_acknowledge();
__artiq_raise(&exception);
} else {
int alloc_size = reply->alloc_size;
mailbox_acknowledge();
return alloc_size;
}
} }
void lognonl(const char *fmt, ...) void lognonl(const char *fmt, ...)

View File

@ -5,7 +5,8 @@ long long int now_init(void);
void now_save(long long int now); void now_save(long long int now);
int watchdog_set(int ms); int watchdog_set(int ms);
void watchdog_clear(int id); void watchdog_clear(int id);
int send_rpc(int service, const char *tag, ...); void send_rpc(int service, const char *tag, ...);
int recv_rpc(void **slot);
void lognonl(const char *fmt, ...); void lognonl(const char *fmt, ...);
void log(const char *fmt, ...); void log(const char *fmt, ...);

View File

@ -17,7 +17,6 @@ enum {
MESSAGE_TYPE_RPC_SEND, MESSAGE_TYPE_RPC_SEND,
MESSAGE_TYPE_RPC_RECV_REQUEST, MESSAGE_TYPE_RPC_RECV_REQUEST,
MESSAGE_TYPE_RPC_RECV_REPLY, MESSAGE_TYPE_RPC_RECV_REPLY,
MESSAGE_TYPE_RPC_EXCEPTION,
MESSAGE_TYPE_LOG, MESSAGE_TYPE_LOG,
MESSAGE_TYPE_BRG_READY, MESSAGE_TYPE_BRG_READY,
@ -90,16 +89,12 @@ struct msg_rpc_send {
struct msg_rpc_recv_request { struct msg_rpc_recv_request {
int type; int type;
// TODO ??? void **slot;
}; };
struct msg_rpc_recv_reply { struct msg_rpc_recv_reply {
int type; int type;
// TODO ??? int alloc_size;
};
struct msg_rpc_exception {
int type;
struct artiq_exception *exception; struct artiq_exception *exception;
}; };

View File

@ -476,7 +476,8 @@ static int process_input(void)
// } // }
case REMOTEMSG_TYPE_RPC_EXCEPTION: { case REMOTEMSG_TYPE_RPC_EXCEPTION: {
struct msg_rpc_exception reply; struct msg_rpc_recv_request *request;
struct msg_rpc_recv_reply reply;
struct artiq_exception exception; struct artiq_exception exception;
exception.name = in_packet_string(); exception.name = in_packet_string();
@ -494,7 +495,15 @@ static int process_input(void)
return 0; // restart session return 0; // restart session
} }
reply.type = MESSAGE_TYPE_RPC_EXCEPTION; request = mailbox_wait_and_receive();
if(request->type != MESSAGE_TYPE_RPC_RECV_REQUEST) {
log("Expected MESSAGE_TYPE_RPC_RECV_REQUEST, got %d",
request->type);
return 0; // restart session
}
reply.type = MESSAGE_TYPE_RPC_RECV_REPLY;
reply.alloc_size = 0;
reply.exception = &exception; reply.exception = &exception;
mailbox_send_and_wait(&reply); mailbox_send_and_wait(&reply);
@ -650,15 +659,17 @@ static int send_rpc_request(int service, const char *tag, va_list args)
out_packet_start(REMOTEMSG_TYPE_RPC_REQUEST); out_packet_start(REMOTEMSG_TYPE_RPC_REQUEST);
out_packet_int32(service); out_packet_int32(service);
while(*tag) { while(*tag != ':') {
void *value = va_arg(args, void*); void *value = va_arg(args, void*);
if(!kloader_validate_kpointer(value)) if(!kloader_validate_kpointer(value))
return 0; return 0;
if(!send_rpc_value(&tag, &value)) if(!send_rpc_value(&tag, &value))
return 0; return 0;
} }
out_packet_int8(0); out_packet_int8(0);
out_packet_string(tag + 1);
out_packet_finish(); out_packet_finish();
return 1; return 1;
} }
@ -670,6 +681,12 @@ static int process_kmsg(struct msg_base *umsg)
return 0; return 0;
if(kloader_is_essential_kmsg(umsg->type)) if(kloader_is_essential_kmsg(umsg->type))
return 1; /* handled elsewhere */ return 1; /* handled elsewhere */
if(user_kernel_state == USER_KERNEL_WAIT_RPC &&
umsg->type == MESSAGE_TYPE_RPC_RECV_REQUEST) {
// Handled and acknowledged when we receive
// REMOTEMSG_TYPE_RPC_{EXCEPTION,REPLY}.
return 1;
}
if(user_kernel_state != USER_KERNEL_RUNNING) { if(user_kernel_state != USER_KERNEL_RUNNING) {
log("Received unexpected message from kernel CPU while not in running state"); log("Received unexpected message from kernel CPU while not in running state");
return 0; return 0;
@ -739,7 +756,8 @@ static int process_kmsg(struct msg_base *umsg)
struct msg_rpc_send *msg = (struct msg_rpc_send *)umsg; struct msg_rpc_send *msg = (struct msg_rpc_send *)umsg;
if(!send_rpc_request(msg->service, msg->tag, msg->args)) { if(!send_rpc_request(msg->service, msg->tag, msg->args)) {
log("Failed to send RPC request"); log("Failed to send RPC request (service %d, tag %s)",
msg->service, msg->tag);
return 0; // restart session return 0; // restart session
} }