From d4270cf66e8d82e4783a5f63fb8d4ecd6d278fa9 Mon Sep 17 00:00:00 2001 From: whitequark Date: Sun, 9 Aug 2015 20:17:00 +0300 Subject: [PATCH] Implement receiving data from RPCs. --- .../compiler/transforms/llvm_ir_generator.py | 21 +- artiq/coredevice/comm_generic.py | 114 +++++++- soc/runtime/ksupport.c | 2 +- soc/runtime/ksupport.h | 2 +- soc/runtime/messages.h | 2 +- soc/runtime/session.c | 271 ++++++++++++++---- 6 files changed, 337 insertions(+), 75 deletions(-) diff --git a/artiq/compiler/transforms/llvm_ir_generator.py b/artiq/compiler/transforms/llvm_ir_generator.py index 9cda1b803..3667e6474 100644 --- a/artiq/compiler/transforms/llvm_ir_generator.py +++ b/artiq/compiler/transforms/llvm_ir_generator.py @@ -162,7 +162,7 @@ class LLVMIRGenerator: llty = ll.FunctionType(ll.VoidType(), [ll.IntType(32), ll.IntType(8).as_pointer()], var_arg=True) elif name == "recv_rpc": - llty = ll.FunctionType(ll.IntType(32), [ll.IntType(8).as_pointer().as_pointer()]) + llty = ll.FunctionType(ll.IntType(32), [ll.IntType(8).as_pointer()]) else: assert False @@ -571,7 +571,7 @@ class LLVMIRGenerator: llfun = self.llbuilder.extract_value(llclosure, 1) return llfun, [llenv] + list(llargs) - # See session.c:send_rpc_value and session.c:recv_rpc_value. + # See session.c:{send,receive}_rpc_value and comm_generic.py:_{send,receive}_rpc_value. def _rpc_tag(self, typ, error_handler): if types.is_tuple(typ): assert len(typ.elts) < 256 @@ -666,29 +666,30 @@ class LLVMIRGenerator: llalloc = self.llbuilder.append_basic_block(name=llprehead.name + ".rpc.alloc") lltail = self.llbuilder.append_basic_block(name=llprehead.name + ".rpc.tail") - llslot = self.llbuilder.alloca(ll.IntType(8).as_pointer()) - self.llbuilder.store(ll.Constant(ll.IntType(8).as_pointer(), None), llslot) + llretty = self.llty_of_type(fun_type.ret) + llslot = self.llbuilder.alloca(llretty) + llslotgen = self.llbuilder.bitcast(llslot, ll.IntType(8).as_pointer()) self.llbuilder.branch(llhead) self.llbuilder.position_at_end(llhead) + llphi = self.llbuilder.phi(llslotgen.type) + llphi.add_incoming(llslotgen, llprehead) if llunwindblock: - llsize = self.llbuilder.invoke(self.llbuiltin("recv_rpc"), [llslot], + llsize = self.llbuilder.invoke(self.llbuiltin("recv_rpc"), [llphi], llheadu, llunwindblock) self.llbuilder.position_at_end(llheadu) else: - llsize = self.llbuilder.call(self.llbuiltin("recv_rpc"), [llslot]) + llsize = self.llbuilder.call(self.llbuiltin("recv_rpc"), [llphi]) lldone = self.llbuilder.icmp_unsigned('==', llsize, ll.Constant(llsize.type, 0)) self.llbuilder.cbranch(lldone, lltail, llalloc) self.llbuilder.position_at_end(llalloc) llalloca = self.llbuilder.alloca(ll.IntType(8), llsize) - self.llbuilder.store(llalloca, llslot) + llphi.add_incoming(llalloca, llalloc) self.llbuilder.branch(llhead) self.llbuilder.position_at_end(lltail) - llretty = self.llty_of_type(fun_type.ret, for_return=True) - llretptr = self.llbuilder.bitcast(llslot, llretty.as_pointer()) - llret = self.llbuilder.load(llretptr) + llret = self.llbuilder.load(llslot) if not builtins.is_allocated(fun_type.ret): # We didn't allocate anything except the slot for the value itself. # Don't waste stack space. diff --git a/artiq/coredevice/comm_generic.py b/artiq/coredevice/comm_generic.py index 267708afe..5c604779f 100644 --- a/artiq/coredevice/comm_generic.py +++ b/artiq/coredevice/comm_generic.py @@ -8,7 +8,6 @@ from artiq.language import core as core_language logger = logging.getLogger(__name__) -logger.setLevel(logging.DEBUG) class _H2DMsgType(Enum): @@ -51,6 +50,9 @@ class _D2HMsgType(Enum): class UnsupportedDevice(Exception): pass +class RPCReturnValueError(ValueError): + pass + class CommGeneric: def __init__(self): @@ -279,6 +281,7 @@ class CommGeneric: _rpc_sentinel = object() + # See session.c:{send,receive}_rpc_value and llvm_ir_generator.py:_rpc_tag. def _receive_rpc_value(self, rpc_map): tag = chr(self._read_int8()) if tag == "\x00": @@ -306,11 +309,15 @@ class CommGeneric: length = self._read_int32() return [self._receive_rpc_value(rpc_map) for _ in range(length)] elif tag == "r": - lower = self._receive_rpc_value(rpc_map) - upper = self._receive_rpc_value(rpc_map) + start = self._receive_rpc_value(rpc_map) + stop = self._receive_rpc_value(rpc_map) step = self._receive_rpc_value(rpc_map) - return range(lower, upper, step) + return range(start, stop, step) elif tag == "o": + present = self._read_int8() + if present: + return self._receive_rpc_value(rpc_map) + elif tag == "O": return rpc_map[self._read_int32()] else: raise IOError("Unknown RPC value tag: {}".format(repr(tag))) @@ -323,16 +330,101 @@ class CommGeneric: return args args.append(value) + def _skip_rpc_value(self, tags): + tag = tags.pop(0) + if tag == "t": + length = tags.pop(0) + for _ in range(length): + self._skip_rpc_value(tags) + elif tag == "l": + self._skip_rpc_value(tags) + elif tag == "r": + self._skip_rpc_value(tags) + else: + pass + + def _send_rpc_value(self, tags, value, root, function): + def check(cond, expected): + if not cond: + raise RPCReturnValueError( + "type mismatch: cannot serialize {value} as {type}" + " ({function} has returned {root})".format( + value=repr(value), type=expected(), + function=function, root=root)) + + tag = chr(tags.pop(0)) + if tag == "t": + length = tags.pop(0) + check(isinstance(value, tuple) and length == len(value), + lambda: "tuple of {}".format(length)) + for elt in value: + self._send_rpc_value(tags, elt, root, function) + elif tag == "n": + check(value is None, + lambda: "None") + elif tag == "b": + check(isinstance(value, bool), + lambda: "bool") + self._write_int8(value) + elif tag == "i": + check(isinstance(value, int) and (-2**31 < value < 2**31-1), + lambda: "32-bit int") + self._write_int32(value) + elif tag == "I": + check(isinstance(value, int) and (-2**63 < value < 2**63-1), + lambda: "64-bit int") + self._write_int64(value) + elif tag == "f": + check(isinstance(value, float), + lambda: "float") + self._write_float64(value) + elif tag == "F": + check(isinstance(value, Fraction) and + (-2**63 < value.numerator < 2**63-1) and + (-2**63 < value.denominator < 2**63-1), + lambda: "64-bit Fraction") + self._write_int64(value.numerator) + self._write_int64(value.denominator) + elif tag == "s": + check(isinstance(value, str) and "\x00" not in value, + lambda: "str") + self._write_string(value) + elif tag == "l": + check(isinstance(value, list), + lambda: "list") + self._write_int32(len(value)) + for elt in value: + tags_copy = bytearray(tags) + self._send_rpc_value(tags_copy, elt, root, function) + self._skip_rpc_value(tags) + elif tag == "r": + check(isinstance(value, range), + lambda: "range") + tags_copy = bytearray(tags) + self._send_rpc_value(tags_copy, value.start, root, function) + tags_copy = bytearray(tags) + self._send_rpc_value(tags_copy, value.stop, root, function) + tags_copy = bytearray(tags) + self._send_rpc_value(tags_copy, value.step, root, function) + tags = tags_copy + else: + raise IOError("Unknown RPC value tag: {}".format(repr(tag))) + def _serve_rpc(self, rpc_map): service = self._read_int32() args = self._receive_rpc_args(rpc_map) - return_tag = self._read_string() - logger.debug("rpc service: %d %r -> %s", service, args, return_tag) + return_tags = self._read_bytes() + logger.debug("rpc service: %d %r -> %s", service, args, return_tags) try: result = rpc_map[service](*args) - if not isinstance(result, int) or not (-2**31 < result < 2**31-1): - raise ValueError("An RPC must return an int(width=32)") + logger.debug("rpc service: %d %r == %r", service, args, result) + + self._write_header(_H2DMsgType.RPC_REPLY) + self._write_bytes(return_tags) + self._send_rpc_value(bytearray(return_tags), result, result, + rpc_map[service]) + self._write_flush() except core_language.ARTIQException as exn: logger.debug("rpc service: %d %r ! %r", service, args, exn) @@ -364,12 +456,6 @@ class CommGeneric: self._write_string(function) self._write_flush() - else: - logger.debug("rpc service: %d %r == %r", service, args, result) - - self._write_header(_H2DMsgType.RPC_REPLY) - self._write_int32(result) - self._write_flush() def _serve_exception(self): name = self._read_string() diff --git a/soc/runtime/ksupport.c b/soc/runtime/ksupport.c index 7a681bc59..8abee3969 100644 --- a/soc/runtime/ksupport.c +++ b/soc/runtime/ksupport.c @@ -314,7 +314,7 @@ void send_rpc(int service, const char *tag, ...) va_end(request.args); } -int recv_rpc(void **slot) { +int recv_rpc(void *slot) { struct msg_rpc_recv_request request; struct msg_rpc_recv_reply *reply; diff --git a/soc/runtime/ksupport.h b/soc/runtime/ksupport.h index 2171324a4..88dc7e2a0 100644 --- a/soc/runtime/ksupport.h +++ b/soc/runtime/ksupport.h @@ -6,7 +6,7 @@ void now_save(long long int now); int watchdog_set(int ms); void watchdog_clear(int id); void send_rpc(int service, const char *tag, ...); -int recv_rpc(void **slot); +int recv_rpc(void *slot); void lognonl(const char *fmt, ...); void log(const char *fmt, ...); diff --git a/soc/runtime/messages.h b/soc/runtime/messages.h index 533f296ca..3914cdff6 100644 --- a/soc/runtime/messages.h +++ b/soc/runtime/messages.h @@ -89,7 +89,7 @@ struct msg_rpc_send { struct msg_rpc_recv_request { int type; - void **slot; + void *slot; }; struct msg_rpc_recv_reply { diff --git a/soc/runtime/session.c b/soc/runtime/session.c index 84e04265c..88bcfaefa 100644 --- a/soc/runtime/session.c +++ b/soc/runtime/session.c @@ -147,7 +147,7 @@ static const char *in_packet_string() { int length; const char *string = in_packet_bytes(&length); - if(string[length] != 0) { + if(string[length - 1] != 0) { log("session.c: string is not zero-terminated"); return ""; } @@ -346,6 +346,8 @@ enum { REMOTEMSG_TYPE_FLASH_ERROR_REPLY }; +static int receive_rpc_value(const char **tag, void **slot); + static int process_input(void) { switch(buffer_in.header.type) { @@ -457,23 +459,37 @@ static int process_input(void) user_kernel_state = USER_KERNEL_RUNNING; break; - // case REMOTEMSG_TYPE_RPC_REPLY: { - // struct msg_rpc_reply reply; + case REMOTEMSG_TYPE_RPC_REPLY: { + struct msg_rpc_recv_request *request; + struct msg_rpc_recv_reply reply; - // int result = in_packet_int32(); + if(user_kernel_state != USER_KERNEL_WAIT_RPC) { + log("Unsolicited RPC reply"); + return 0; // restart session + } - // if(user_kernel_state != USER_KERNEL_WAIT_RPC) { - // log("Unsolicited RPC reply"); - // return 0; // restart session - // } + request = mailbox_wait_and_receive(); + if(request->type != MESSAGE_TYPE_RPC_RECV_REQUEST) { + log("Expected MESSAGE_TYPE_RPC_RECV_REQUEST, got %d", + request->type); + return 0; // restart session + } - // reply.type = MESSAGE_TYPE_RPC_REPLY; - // reply.result = result; - // mailbox_send_and_wait(&reply); + const char *tag = in_packet_string(); + void *slot = request->slot; + if(!receive_rpc_value(&tag, &slot)) { + log("Failed to receive RPC reply"); + return 0; // restart session + } - // user_kernel_state = USER_KERNEL_RUNNING; - // break; - // } + reply.type = MESSAGE_TYPE_RPC_RECV_REPLY; + reply.alloc_size = 0; + reply.exception = NULL; + mailbox_send_and_wait(&reply); + + user_kernel_state = USER_KERNEL_RUNNING; + break; + } case REMOTEMSG_TYPE_RPC_EXCEPTION: { struct msg_rpc_recv_request *request; @@ -512,13 +528,191 @@ static int process_input(void) } default: + log("Received invalid packet type %d from host", + buffer_in.header.type); + return 0; + } + + return 1; +} + +// See comm_generic.py:_{send,receive}_rpc_value and llvm_ir_generator.py:_rpc_tag. +static void skip_rpc_value(const char **tag) { + switch(*(*tag)++) { + case 't': { + int size = *(*tag)++; + for(int i = 0; i < size; i++) + skip_rpc_value(tag); + break; + } + + case 'l': + skip_rpc_value(tag); + break; + + case 'r': + skip_rpc_value(tag); + break; + } +} + +static int sizeof_rpc_value(const char **tag) +{ + switch(*(*tag)++) { + case 't': { // tuple + int size = *(*tag)++; + + int32_t length = 0; + for(int i = 0; i < size; i++) + length += sizeof_rpc_value(tag); + return length; + } + + case 'n': // None + return 0; + + case 'b': // bool + return sizeof(int8_t); + + case 'i': // int(width=32) + return sizeof(int32_t); + + case 'I': // int(width=64) + return sizeof(int64_t); + + case 'f': // float + return sizeof(double); + + case 'F': // Fraction + return sizeof(struct { int64_t numerator, denominator; }); + + case 's': // string + return sizeof(char *); + + case 'l': // list(elt='a) + skip_rpc_value(tag); + return sizeof(struct { int32_t length; struct {} *elements; }); + + case 'r': // range(elt='a) + return sizeof_rpc_value(tag) * 3; + + default: + log("sizeof_rpc_value: unknown tag %02x", *((*tag) - 1)); + return 0; + } +} + +static void *alloc_rpc_value(int size) +{ + struct msg_rpc_recv_request *request; + struct msg_rpc_recv_reply reply; + + reply.type = MESSAGE_TYPE_RPC_RECV_REPLY; + reply.alloc_size = size; + reply.exception = NULL; + mailbox_send_and_wait(&reply); + + request = mailbox_wait_and_receive(); + if(request->type != MESSAGE_TYPE_RPC_RECV_REQUEST) { + log("Expected MESSAGE_TYPE_RPC_RECV_REQUEST, got %d", + request->type); + return NULL; + } + return request->slot; +} + +static int receive_rpc_value(const char **tag, void **slot) +{ + switch(*(*tag)++) { + case 't': { // tuple + int size = *(*tag)++; + + for(int i = 0; i < size; i++) { + if(!receive_rpc_value(tag, slot)) + return 0; + } + break; + } + + case 'n': // None + break; + + case 'b': { // bool + *((*(int8_t**)slot)++) = in_packet_int8(); + break; + } + + case 'i': { // int(width=32) + *((*(int32_t**)slot)++) = in_packet_int32(); + break; + } + + case 'I': { // int(width=64) + *((*(int64_t**)slot)++) = in_packet_int64(); + break; + } + + case 'f': { // float + *((*(int64_t**)slot)++) = in_packet_int64(); + break; + } + + case 'F': { // Fraction + struct { int64_t numerator, denominator; } *fraction = *slot; + fraction->numerator = in_packet_int64(); + fraction->denominator = in_packet_int64(); + *slot = (void*)((intptr_t)(*slot) + sizeof(*fraction)); + break; + } + + case 's': { // string + const char *in_string = in_packet_string(); + char *out_string = alloc_rpc_value(strlen(in_string) + 1); + memcpy(out_string, in_string, strlen(in_string) + 1); + *((*(char***)slot)++) = out_string; + break; + } + + case 'l': { // list(elt='a) + struct { int32_t length; struct {} *elements; } *list = *slot; + list->length = in_packet_int32(); + + const char *tag_copy = *tag; + list->elements = alloc_rpc_value(sizeof_rpc_value(&tag_copy) * list->length); + + void *element = list->elements; + for(int i = 0; i < list->length; i++) { + const char *tag_copy = *tag; + if(!receive_rpc_value(&tag_copy, &element)) + return 0; + } + skip_rpc_value(tag); + break; + } + + case 'r': { // range(elt='a) + const char *tag_copy; + tag_copy = *tag; + if(!receive_rpc_value(&tag_copy, slot)) // min + return 0; + tag_copy = *tag; + if(!receive_rpc_value(&tag_copy, slot)) // max + return 0; + tag_copy = *tag; + if(!receive_rpc_value(&tag_copy, slot)) // step + return 0; + *tag = tag_copy; + break; + } + + default: + log("receive_rpc_value: unknown tag %02x", *((*tag) - 1)); return 0; } return 1; } -// See llvm_ir_generator.py:_rpc_tag. static int send_rpc_value(const char **tag, void **value) { if(!out_packet_int8(**tag)) @@ -541,51 +735,33 @@ static int send_rpc_value(const char **tag, void **value) break; case 'b': { // bool - int size = sizeof(int8_t); - if(!out_packet_chunk(*value, size)) - return 0; - *value = (void*)((intptr_t)(*value) + size); - break; + return out_packet_int8(*((*(int8_t**)value)++)); } case 'i': { // int(width=32) - int size = sizeof(int32_t); - if(!out_packet_chunk(*value, size)) - return 0; - *value = (void*)((intptr_t)(*value) + size); - break; + return out_packet_int32(*((*(int32_t**)value)++)); } case 'I': { // int(width=64) - int size = sizeof(int64_t); - if(!out_packet_chunk(*value, size)) - return 0; - *value = (void*)((intptr_t)(*value) + size); - break; + return out_packet_int64(*((*(int64_t**)value)++)); } case 'f': { // float - int size = sizeof(double); - if(!out_packet_chunk(*value, size)) - return 0; - *value = (void*)((intptr_t)(*value) + size); - break; + return out_packet_float64(*((*(double**)value)++)); } case 'F': { // Fraction - int size = sizeof(int64_t) * 2; - if(!out_packet_chunk(*value, size)) + struct { int64_t numerator, denominator; } *fraction = *value; + if(!out_packet_int64(fraction->numerator)) return 0; - *value = (void*)((intptr_t)(*value) + size); + if(!out_packet_int64(fraction->denominator)) + return 0; + *value = (void*)((intptr_t)(*value) + sizeof(*fraction)); break; } case 's': { // string - const char **string = *value; - if(!out_packet_string(*string)) - return 0; - *value = (void*)((intptr_t)(*value) + strlen(*string) + 1); - break; + return out_packet_string(*((*(const char***)value)++)); } case 'l': { // list(elt='a) @@ -595,11 +771,11 @@ static int send_rpc_value(const char **tag, void **value) if(!out_packet_int32(list->length)) return 0; - const char *tag_copy; + const char *tag_copy = *tag; for(int i = 0; i < list->length; i++) { - tag_copy = *tag; if(!send_rpc_value(&tag_copy, &element)) return 0; + tag_copy = *tag; } *tag = tag_copy; @@ -634,7 +810,7 @@ static int send_rpc_value(const char **tag, void **value) if(option->present) { return send_rpc_value(tag, &contents); } else { - (*tag)++; + skip_rpc_value(tag); break; } } @@ -668,8 +844,7 @@ static int send_rpc_request(int service, const char *tag, va_list args) } out_packet_int8(0); - out_packet_string(tag + 1); - + out_packet_string(tag + 1); // return tags out_packet_finish(); return 1; }