Implement receiving data from RPCs.

This commit is contained in:
whitequark 2015-08-09 20:17:00 +03:00
parent 02b1543c63
commit d4270cf66e
6 changed files with 337 additions and 75 deletions

View File

@ -162,7 +162,7 @@ class LLVMIRGenerator:
llty = ll.FunctionType(ll.VoidType(), [ll.IntType(32), ll.IntType(8).as_pointer()], llty = ll.FunctionType(ll.VoidType(), [ll.IntType(32), ll.IntType(8).as_pointer()],
var_arg=True) var_arg=True)
elif name == "recv_rpc": elif name == "recv_rpc":
llty = ll.FunctionType(ll.IntType(32), [ll.IntType(8).as_pointer().as_pointer()]) llty = ll.FunctionType(ll.IntType(32), [ll.IntType(8).as_pointer()])
else: else:
assert False assert False
@ -571,7 +571,7 @@ class LLVMIRGenerator:
llfun = self.llbuilder.extract_value(llclosure, 1) llfun = self.llbuilder.extract_value(llclosure, 1)
return llfun, [llenv] + list(llargs) return llfun, [llenv] + list(llargs)
# See session.c:send_rpc_value and session.c:recv_rpc_value. # See session.c:{send,receive}_rpc_value and comm_generic.py:_{send,receive}_rpc_value.
def _rpc_tag(self, typ, error_handler): def _rpc_tag(self, typ, error_handler):
if types.is_tuple(typ): if types.is_tuple(typ):
assert len(typ.elts) < 256 assert len(typ.elts) < 256
@ -666,29 +666,30 @@ class LLVMIRGenerator:
llalloc = self.llbuilder.append_basic_block(name=llprehead.name + ".rpc.alloc") llalloc = self.llbuilder.append_basic_block(name=llprehead.name + ".rpc.alloc")
lltail = self.llbuilder.append_basic_block(name=llprehead.name + ".rpc.tail") lltail = self.llbuilder.append_basic_block(name=llprehead.name + ".rpc.tail")
llslot = self.llbuilder.alloca(ll.IntType(8).as_pointer()) llretty = self.llty_of_type(fun_type.ret)
self.llbuilder.store(ll.Constant(ll.IntType(8).as_pointer(), None), llslot) llslot = self.llbuilder.alloca(llretty)
llslotgen = self.llbuilder.bitcast(llslot, ll.IntType(8).as_pointer())
self.llbuilder.branch(llhead) self.llbuilder.branch(llhead)
self.llbuilder.position_at_end(llhead) self.llbuilder.position_at_end(llhead)
llphi = self.llbuilder.phi(llslotgen.type)
llphi.add_incoming(llslotgen, llprehead)
if llunwindblock: if llunwindblock:
llsize = self.llbuilder.invoke(self.llbuiltin("recv_rpc"), [llslot], llsize = self.llbuilder.invoke(self.llbuiltin("recv_rpc"), [llphi],
llheadu, llunwindblock) llheadu, llunwindblock)
self.llbuilder.position_at_end(llheadu) self.llbuilder.position_at_end(llheadu)
else: else:
llsize = self.llbuilder.call(self.llbuiltin("recv_rpc"), [llslot]) llsize = self.llbuilder.call(self.llbuiltin("recv_rpc"), [llphi])
lldone = self.llbuilder.icmp_unsigned('==', llsize, ll.Constant(llsize.type, 0)) lldone = self.llbuilder.icmp_unsigned('==', llsize, ll.Constant(llsize.type, 0))
self.llbuilder.cbranch(lldone, lltail, llalloc) self.llbuilder.cbranch(lldone, lltail, llalloc)
self.llbuilder.position_at_end(llalloc) self.llbuilder.position_at_end(llalloc)
llalloca = self.llbuilder.alloca(ll.IntType(8), llsize) llalloca = self.llbuilder.alloca(ll.IntType(8), llsize)
self.llbuilder.store(llalloca, llslot) llphi.add_incoming(llalloca, llalloc)
self.llbuilder.branch(llhead) self.llbuilder.branch(llhead)
self.llbuilder.position_at_end(lltail) self.llbuilder.position_at_end(lltail)
llretty = self.llty_of_type(fun_type.ret, for_return=True) llret = self.llbuilder.load(llslot)
llretptr = self.llbuilder.bitcast(llslot, llretty.as_pointer())
llret = self.llbuilder.load(llretptr)
if not builtins.is_allocated(fun_type.ret): if not builtins.is_allocated(fun_type.ret):
# We didn't allocate anything except the slot for the value itself. # We didn't allocate anything except the slot for the value itself.
# Don't waste stack space. # Don't waste stack space.

View File

@ -8,7 +8,6 @@ from artiq.language import core as core_language
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)
class _H2DMsgType(Enum): class _H2DMsgType(Enum):
@ -51,6 +50,9 @@ class _D2HMsgType(Enum):
class UnsupportedDevice(Exception): class UnsupportedDevice(Exception):
pass pass
class RPCReturnValueError(ValueError):
pass
class CommGeneric: class CommGeneric:
def __init__(self): def __init__(self):
@ -279,6 +281,7 @@ class CommGeneric:
_rpc_sentinel = object() _rpc_sentinel = object()
# See session.c:{send,receive}_rpc_value and llvm_ir_generator.py:_rpc_tag.
def _receive_rpc_value(self, rpc_map): def _receive_rpc_value(self, rpc_map):
tag = chr(self._read_int8()) tag = chr(self._read_int8())
if tag == "\x00": if tag == "\x00":
@ -306,11 +309,15 @@ class CommGeneric:
length = self._read_int32() length = self._read_int32()
return [self._receive_rpc_value(rpc_map) for _ in range(length)] return [self._receive_rpc_value(rpc_map) for _ in range(length)]
elif tag == "r": elif tag == "r":
lower = self._receive_rpc_value(rpc_map) start = self._receive_rpc_value(rpc_map)
upper = self._receive_rpc_value(rpc_map) stop = self._receive_rpc_value(rpc_map)
step = self._receive_rpc_value(rpc_map) step = self._receive_rpc_value(rpc_map)
return range(lower, upper, step) return range(start, stop, step)
elif tag == "o": elif tag == "o":
present = self._read_int8()
if present:
return self._receive_rpc_value(rpc_map)
elif tag == "O":
return rpc_map[self._read_int32()] return rpc_map[self._read_int32()]
else: else:
raise IOError("Unknown RPC value tag: {}".format(repr(tag))) raise IOError("Unknown RPC value tag: {}".format(repr(tag)))
@ -323,16 +330,101 @@ class CommGeneric:
return args return args
args.append(value) args.append(value)
def _skip_rpc_value(self, tags):
tag = tags.pop(0)
if tag == "t":
length = tags.pop(0)
for _ in range(length):
self._skip_rpc_value(tags)
elif tag == "l":
self._skip_rpc_value(tags)
elif tag == "r":
self._skip_rpc_value(tags)
else:
pass
def _send_rpc_value(self, tags, value, root, function):
def check(cond, expected):
if not cond:
raise RPCReturnValueError(
"type mismatch: cannot serialize {value} as {type}"
" ({function} has returned {root})".format(
value=repr(value), type=expected(),
function=function, root=root))
tag = chr(tags.pop(0))
if tag == "t":
length = tags.pop(0)
check(isinstance(value, tuple) and length == len(value),
lambda: "tuple of {}".format(length))
for elt in value:
self._send_rpc_value(tags, elt, root, function)
elif tag == "n":
check(value is None,
lambda: "None")
elif tag == "b":
check(isinstance(value, bool),
lambda: "bool")
self._write_int8(value)
elif tag == "i":
check(isinstance(value, int) and (-2**31 < value < 2**31-1),
lambda: "32-bit int")
self._write_int32(value)
elif tag == "I":
check(isinstance(value, int) and (-2**63 < value < 2**63-1),
lambda: "64-bit int")
self._write_int64(value)
elif tag == "f":
check(isinstance(value, float),
lambda: "float")
self._write_float64(value)
elif tag == "F":
check(isinstance(value, Fraction) and
(-2**63 < value.numerator < 2**63-1) and
(-2**63 < value.denominator < 2**63-1),
lambda: "64-bit Fraction")
self._write_int64(value.numerator)
self._write_int64(value.denominator)
elif tag == "s":
check(isinstance(value, str) and "\x00" not in value,
lambda: "str")
self._write_string(value)
elif tag == "l":
check(isinstance(value, list),
lambda: "list")
self._write_int32(len(value))
for elt in value:
tags_copy = bytearray(tags)
self._send_rpc_value(tags_copy, elt, root, function)
self._skip_rpc_value(tags)
elif tag == "r":
check(isinstance(value, range),
lambda: "range")
tags_copy = bytearray(tags)
self._send_rpc_value(tags_copy, value.start, root, function)
tags_copy = bytearray(tags)
self._send_rpc_value(tags_copy, value.stop, root, function)
tags_copy = bytearray(tags)
self._send_rpc_value(tags_copy, value.step, root, function)
tags = tags_copy
else:
raise IOError("Unknown RPC value tag: {}".format(repr(tag)))
def _serve_rpc(self, rpc_map): def _serve_rpc(self, rpc_map):
service = self._read_int32() service = self._read_int32()
args = self._receive_rpc_args(rpc_map) args = self._receive_rpc_args(rpc_map)
return_tag = self._read_string() return_tags = self._read_bytes()
logger.debug("rpc service: %d %r -> %s", service, args, return_tag) logger.debug("rpc service: %d %r -> %s", service, args, return_tags)
try: try:
result = rpc_map[service](*args) result = rpc_map[service](*args)
if not isinstance(result, int) or not (-2**31 < result < 2**31-1): logger.debug("rpc service: %d %r == %r", service, args, result)
raise ValueError("An RPC must return an int(width=32)")
self._write_header(_H2DMsgType.RPC_REPLY)
self._write_bytes(return_tags)
self._send_rpc_value(bytearray(return_tags), result, result,
rpc_map[service])
self._write_flush()
except core_language.ARTIQException as exn: except core_language.ARTIQException as exn:
logger.debug("rpc service: %d %r ! %r", service, args, exn) logger.debug("rpc service: %d %r ! %r", service, args, exn)
@ -364,12 +456,6 @@ class CommGeneric:
self._write_string(function) self._write_string(function)
self._write_flush() self._write_flush()
else:
logger.debug("rpc service: %d %r == %r", service, args, result)
self._write_header(_H2DMsgType.RPC_REPLY)
self._write_int32(result)
self._write_flush()
def _serve_exception(self): def _serve_exception(self):
name = self._read_string() name = self._read_string()

View File

@ -314,7 +314,7 @@ void send_rpc(int service, const char *tag, ...)
va_end(request.args); va_end(request.args);
} }
int recv_rpc(void **slot) { int recv_rpc(void *slot) {
struct msg_rpc_recv_request request; struct msg_rpc_recv_request request;
struct msg_rpc_recv_reply *reply; struct msg_rpc_recv_reply *reply;

View File

@ -6,7 +6,7 @@ void now_save(long long int now);
int watchdog_set(int ms); int watchdog_set(int ms);
void watchdog_clear(int id); void watchdog_clear(int id);
void send_rpc(int service, const char *tag, ...); void send_rpc(int service, const char *tag, ...);
int recv_rpc(void **slot); int recv_rpc(void *slot);
void lognonl(const char *fmt, ...); void lognonl(const char *fmt, ...);
void log(const char *fmt, ...); void log(const char *fmt, ...);

View File

@ -89,7 +89,7 @@ struct msg_rpc_send {
struct msg_rpc_recv_request { struct msg_rpc_recv_request {
int type; int type;
void **slot; void *slot;
}; };
struct msg_rpc_recv_reply { struct msg_rpc_recv_reply {

View File

@ -147,7 +147,7 @@ static const char *in_packet_string()
{ {
int length; int length;
const char *string = in_packet_bytes(&length); const char *string = in_packet_bytes(&length);
if(string[length] != 0) { if(string[length - 1] != 0) {
log("session.c: string is not zero-terminated"); log("session.c: string is not zero-terminated");
return ""; return "";
} }
@ -346,6 +346,8 @@ enum {
REMOTEMSG_TYPE_FLASH_ERROR_REPLY REMOTEMSG_TYPE_FLASH_ERROR_REPLY
}; };
static int receive_rpc_value(const char **tag, void **slot);
static int process_input(void) static int process_input(void)
{ {
switch(buffer_in.header.type) { switch(buffer_in.header.type) {
@ -457,23 +459,37 @@ static int process_input(void)
user_kernel_state = USER_KERNEL_RUNNING; user_kernel_state = USER_KERNEL_RUNNING;
break; break;
// case REMOTEMSG_TYPE_RPC_REPLY: { case REMOTEMSG_TYPE_RPC_REPLY: {
// struct msg_rpc_reply reply; struct msg_rpc_recv_request *request;
struct msg_rpc_recv_reply reply;
// int result = in_packet_int32(); if(user_kernel_state != USER_KERNEL_WAIT_RPC) {
log("Unsolicited RPC reply");
return 0; // restart session
}
// if(user_kernel_state != USER_KERNEL_WAIT_RPC) { request = mailbox_wait_and_receive();
// log("Unsolicited RPC reply"); if(request->type != MESSAGE_TYPE_RPC_RECV_REQUEST) {
// return 0; // restart session log("Expected MESSAGE_TYPE_RPC_RECV_REQUEST, got %d",
// } request->type);
return 0; // restart session
}
// reply.type = MESSAGE_TYPE_RPC_REPLY; const char *tag = in_packet_string();
// reply.result = result; void *slot = request->slot;
// mailbox_send_and_wait(&reply); if(!receive_rpc_value(&tag, &slot)) {
log("Failed to receive RPC reply");
return 0; // restart session
}
// user_kernel_state = USER_KERNEL_RUNNING; reply.type = MESSAGE_TYPE_RPC_RECV_REPLY;
// break; reply.alloc_size = 0;
// } reply.exception = NULL;
mailbox_send_and_wait(&reply);
user_kernel_state = USER_KERNEL_RUNNING;
break;
}
case REMOTEMSG_TYPE_RPC_EXCEPTION: { case REMOTEMSG_TYPE_RPC_EXCEPTION: {
struct msg_rpc_recv_request *request; struct msg_rpc_recv_request *request;
@ -512,13 +528,191 @@ static int process_input(void)
} }
default: default:
log("Received invalid packet type %d from host",
buffer_in.header.type);
return 0;
}
return 1;
}
// See comm_generic.py:_{send,receive}_rpc_value and llvm_ir_generator.py:_rpc_tag.
static void skip_rpc_value(const char **tag) {
switch(*(*tag)++) {
case 't': {
int size = *(*tag)++;
for(int i = 0; i < size; i++)
skip_rpc_value(tag);
break;
}
case 'l':
skip_rpc_value(tag);
break;
case 'r':
skip_rpc_value(tag);
break;
}
}
static int sizeof_rpc_value(const char **tag)
{
switch(*(*tag)++) {
case 't': { // tuple
int size = *(*tag)++;
int32_t length = 0;
for(int i = 0; i < size; i++)
length += sizeof_rpc_value(tag);
return length;
}
case 'n': // None
return 0;
case 'b': // bool
return sizeof(int8_t);
case 'i': // int(width=32)
return sizeof(int32_t);
case 'I': // int(width=64)
return sizeof(int64_t);
case 'f': // float
return sizeof(double);
case 'F': // Fraction
return sizeof(struct { int64_t numerator, denominator; });
case 's': // string
return sizeof(char *);
case 'l': // list(elt='a)
skip_rpc_value(tag);
return sizeof(struct { int32_t length; struct {} *elements; });
case 'r': // range(elt='a)
return sizeof_rpc_value(tag) * 3;
default:
log("sizeof_rpc_value: unknown tag %02x", *((*tag) - 1));
return 0;
}
}
static void *alloc_rpc_value(int size)
{
struct msg_rpc_recv_request *request;
struct msg_rpc_recv_reply reply;
reply.type = MESSAGE_TYPE_RPC_RECV_REPLY;
reply.alloc_size = size;
reply.exception = NULL;
mailbox_send_and_wait(&reply);
request = mailbox_wait_and_receive();
if(request->type != MESSAGE_TYPE_RPC_RECV_REQUEST) {
log("Expected MESSAGE_TYPE_RPC_RECV_REQUEST, got %d",
request->type);
return NULL;
}
return request->slot;
}
static int receive_rpc_value(const char **tag, void **slot)
{
switch(*(*tag)++) {
case 't': { // tuple
int size = *(*tag)++;
for(int i = 0; i < size; i++) {
if(!receive_rpc_value(tag, slot))
return 0;
}
break;
}
case 'n': // None
break;
case 'b': { // bool
*((*(int8_t**)slot)++) = in_packet_int8();
break;
}
case 'i': { // int(width=32)
*((*(int32_t**)slot)++) = in_packet_int32();
break;
}
case 'I': { // int(width=64)
*((*(int64_t**)slot)++) = in_packet_int64();
break;
}
case 'f': { // float
*((*(int64_t**)slot)++) = in_packet_int64();
break;
}
case 'F': { // Fraction
struct { int64_t numerator, denominator; } *fraction = *slot;
fraction->numerator = in_packet_int64();
fraction->denominator = in_packet_int64();
*slot = (void*)((intptr_t)(*slot) + sizeof(*fraction));
break;
}
case 's': { // string
const char *in_string = in_packet_string();
char *out_string = alloc_rpc_value(strlen(in_string) + 1);
memcpy(out_string, in_string, strlen(in_string) + 1);
*((*(char***)slot)++) = out_string;
break;
}
case 'l': { // list(elt='a)
struct { int32_t length; struct {} *elements; } *list = *slot;
list->length = in_packet_int32();
const char *tag_copy = *tag;
list->elements = alloc_rpc_value(sizeof_rpc_value(&tag_copy) * list->length);
void *element = list->elements;
for(int i = 0; i < list->length; i++) {
const char *tag_copy = *tag;
if(!receive_rpc_value(&tag_copy, &element))
return 0;
}
skip_rpc_value(tag);
break;
}
case 'r': { // range(elt='a)
const char *tag_copy;
tag_copy = *tag;
if(!receive_rpc_value(&tag_copy, slot)) // min
return 0;
tag_copy = *tag;
if(!receive_rpc_value(&tag_copy, slot)) // max
return 0;
tag_copy = *tag;
if(!receive_rpc_value(&tag_copy, slot)) // step
return 0;
*tag = tag_copy;
break;
}
default:
log("receive_rpc_value: unknown tag %02x", *((*tag) - 1));
return 0; return 0;
} }
return 1; return 1;
} }
// See llvm_ir_generator.py:_rpc_tag.
static int send_rpc_value(const char **tag, void **value) static int send_rpc_value(const char **tag, void **value)
{ {
if(!out_packet_int8(**tag)) if(!out_packet_int8(**tag))
@ -541,51 +735,33 @@ static int send_rpc_value(const char **tag, void **value)
break; break;
case 'b': { // bool case 'b': { // bool
int size = sizeof(int8_t); return out_packet_int8(*((*(int8_t**)value)++));
if(!out_packet_chunk(*value, size))
return 0;
*value = (void*)((intptr_t)(*value) + size);
break;
} }
case 'i': { // int(width=32) case 'i': { // int(width=32)
int size = sizeof(int32_t); return out_packet_int32(*((*(int32_t**)value)++));
if(!out_packet_chunk(*value, size))
return 0;
*value = (void*)((intptr_t)(*value) + size);
break;
} }
case 'I': { // int(width=64) case 'I': { // int(width=64)
int size = sizeof(int64_t); return out_packet_int64(*((*(int64_t**)value)++));
if(!out_packet_chunk(*value, size))
return 0;
*value = (void*)((intptr_t)(*value) + size);
break;
} }
case 'f': { // float case 'f': { // float
int size = sizeof(double); return out_packet_float64(*((*(double**)value)++));
if(!out_packet_chunk(*value, size))
return 0;
*value = (void*)((intptr_t)(*value) + size);
break;
} }
case 'F': { // Fraction case 'F': { // Fraction
int size = sizeof(int64_t) * 2; struct { int64_t numerator, denominator; } *fraction = *value;
if(!out_packet_chunk(*value, size)) if(!out_packet_int64(fraction->numerator))
return 0; return 0;
*value = (void*)((intptr_t)(*value) + size); if(!out_packet_int64(fraction->denominator))
return 0;
*value = (void*)((intptr_t)(*value) + sizeof(*fraction));
break; break;
} }
case 's': { // string case 's': { // string
const char **string = *value; return out_packet_string(*((*(const char***)value)++));
if(!out_packet_string(*string))
return 0;
*value = (void*)((intptr_t)(*value) + strlen(*string) + 1);
break;
} }
case 'l': { // list(elt='a) case 'l': { // list(elt='a)
@ -595,11 +771,11 @@ static int send_rpc_value(const char **tag, void **value)
if(!out_packet_int32(list->length)) if(!out_packet_int32(list->length))
return 0; return 0;
const char *tag_copy; const char *tag_copy = *tag;
for(int i = 0; i < list->length; i++) { for(int i = 0; i < list->length; i++) {
tag_copy = *tag;
if(!send_rpc_value(&tag_copy, &element)) if(!send_rpc_value(&tag_copy, &element))
return 0; return 0;
tag_copy = *tag;
} }
*tag = tag_copy; *tag = tag_copy;
@ -634,7 +810,7 @@ static int send_rpc_value(const char **tag, void **value)
if(option->present) { if(option->present) {
return send_rpc_value(tag, &contents); return send_rpc_value(tag, &contents);
} else { } else {
(*tag)++; skip_rpc_value(tag);
break; break;
} }
} }
@ -668,8 +844,7 @@ static int send_rpc_request(int service, const char *tag, va_list args)
} }
out_packet_int8(0); out_packet_int8(0);
out_packet_string(tag + 1); out_packet_string(tag + 1); // return tags
out_packet_finish(); out_packet_finish();
return 1; return 1;
} }