From 0d10ae75802d8b56262039df8d2b83686d04df76 Mon Sep 17 00:00:00 2001 From: Sebastien Bourdeauducq Date: Fri, 19 Dec 2014 12:46:24 +0800 Subject: [PATCH] rpc: support all data types as parameters --- artiq/coredevice/comm_serial.py | 65 +++++++++++------ artiq/coredevice/runtime.py | 124 ++++++++++++++++++++++---------- artiq/transforms/inline.py | 11 --- soc/runtime/comm.h | 2 +- soc/runtime/comm_serial.c | 52 ++++++++++++-- 5 files changed, 180 insertions(+), 74 deletions(-) diff --git a/artiq/coredevice/comm_serial.py b/artiq/coredevice/comm_serial.py index dbe53b31c..599c17f1b 100644 --- a/artiq/coredevice/comm_serial.py +++ b/artiq/coredevice/comm_serial.py @@ -171,33 +171,58 @@ class Comm(AutoContext): _write_exactly(self.port, struct.pack( ">lbl", 0x5a5a5a5a, _H2DMsgType.RUN_KERNEL.value, len(kname))) for c in kname: - _write_exactly(self.port, struct.pack("B", ord(c))) + _write_exactly(self.port, struct.pack(">B", ord(c))) logger.debug("running kernel: {}".format(kname)) + def _receive_rpc_values(self): + r = [] + while True: + type_tag = chr(struct.unpack(">B", _read_exactly(self.port, 1))[0]) + if type_tag == "\x00": + return r + if type_tag == "n": + r.append(None) + if type_tag == "b": + r.append(bool(struct.unpack(">B", + _read_exactly(self.port, 1))[0])) + if type_tag == "i": + r.append(struct.unpack(">l", _read_exactly(self.port, 4))[0]) + if type_tag == "I": + r.append(struct.unpack(">q", _read_exactly(self.port, 8))[0]) + if type_tag == "f": + r.append(struct.unpack(">d", _read_exactly(self.port, 8))[0]) + if type_tag == "F": + n, d = struct.unpack(">qq", _read_exactly(self.port, 16)) + r.append(Fraction(n, d)) + if type_tag == "l": + r.append(self._receive_rpc_values()) + + def _serve_rpc(self, rpc_map): + (rpc_num, ) = struct.unpack(">h", _read_exactly(self.port, 2)) + args = self._receive_rpc_values() + logger.debug("rpc service: {} ({})".format(rpc_num, args)) + r = rpc_map[rpc_num](*args) + if r is None: + r = 0 + _write_exactly(self.port, struct.pack(">l", r)) + logger.debug("rpc service: {} ({}) == {}".format( + rpc_num, args, r)) + + def _serve_exception(self, user_exception_map): + (eid, ) = struct.unpack(">l", _read_exactly(self.port, 4)) + if eid < core_language.first_user_eid: + exception = runtime_exceptions.exception_map[eid] + else: + exception = user_exception_map[eid] + raise exception + def serve(self, rpc_map, user_exception_map): while True: msg = self._get_device_msg() if msg == _D2HMsgType.RPC_REQUEST: - rpc_num, n_args = struct.unpack(">hB", - _read_exactly(self.port, 3)) - args = [] - for i in range(n_args): - args.append(*struct.unpack(">l", - _read_exactly(self.port, 4))) - logger.debug("rpc service: {} ({})".format(rpc_num, args)) - r = rpc_map[rpc_num](*args) - if r is None: - r = 0 - _write_exactly(self.port, struct.pack(">l", r)) - logger.debug("rpc service: {} ({}) == {}".format( - rpc_num, args, r)) + self._serve_rpc(rpc_map) elif msg == _D2HMsgType.KERNEL_EXCEPTION: - (eid, ) = struct.unpack(">l", _read_exactly(self.port, 4)) - if eid < core_language.first_user_eid: - exception = runtime_exceptions.exception_map[eid] - else: - exception = user_exception_map[eid] - raise exception + self._serve_exception(user_exception_map) elif msg == _D2HMsgType.KERNEL_FINISHED: return else: diff --git a/artiq/coredevice/runtime.py b/artiq/coredevice/runtime.py index c73d9a603..14b31c6cb 100644 --- a/artiq/coredevice/runtime.py +++ b/artiq/coredevice/runtime.py @@ -3,7 +3,7 @@ import os import llvmlite.ir as ll import llvmlite.binding as llvm -from artiq.py2llvm import base_types +from artiq.py2llvm import base_types, fractions, lists from artiq.language import units @@ -12,7 +12,6 @@ llvm.initialize_all_targets() llvm.initialize_all_asmprinters() _syscalls = { - "rpc": "i+:i", "gpio_set": "ib:n", "rtio_oe": "ib:n", "rtio_set": "Iii:n", @@ -23,36 +22,56 @@ _syscalls = { "dds_program": "Iiiiibb:n", } -_chr_to_type = { - "n": lambda: ll.VoidType(), - "b": lambda: ll.IntType(1), - "i": lambda: ll.IntType(32), - "I": lambda: ll.IntType(64) -} -_chr_to_value = { - "n": lambda: base_types.VNone(), - "b": lambda: base_types.VBool(), - "i": lambda: base_types.VInt(), - "I": lambda: base_types.VInt(64) -} +def _chr_to_type(c): + if c == "n": + return ll.VoidType() + if c == "b": + return ll.IntType(1) + if c == "i": + return ll.IntType(32) + if c == "I": + return ll.IntType(64) + raise ValueError def _str_to_functype(s): assert(s[-2] == ":") - type_ret = _chr_to_type[s[-1]]() + type_ret = _chr_to_type(s[-1]) + type_args = [_chr_to_type(c) for c in s[:-2] if c != "n"] + return ll.FunctionType(type_ret, type_args) - var_arg_fixcount = None - type_args = [] - for n, c in enumerate(s[:-2]): - if c == "+": - type_args.append(ll.IntType(32)) - var_arg_fixcount = n - elif c != "n": - type_args.append(_chr_to_type[c]()) - return (var_arg_fixcount, - ll.FunctionType(type_ret, type_args, - var_arg=var_arg_fixcount is not None)) + +def _chr_to_value(c): + if c == "n": + return base_types.VNone() + if c == "b": + return base_types.VBool() + if c == "i": + return base_types.VInt() + if c == "I": + return base_types.VInt(64) + raise ValueError + + +def _value_to_str(v): + if isinstance(v, base_types.VNone): + return "n" + if isinstance(v, base_types.VBool): + return "b" + if isinstance(v, base_types.VInt): + if v.nbits == 32: + return "i" + if v.nbits == 64: + return "I" + raise ValueError + if isinstance(v, base_types.VFloat): + return "f" + if isinstance(v, fractions.VFraction): + return "F" + if isinstance(v, lists.VList): + return "l" + _value_to_str(v.el_type) + raise ValueError class LinkInterface: @@ -60,13 +79,15 @@ class LinkInterface: self.module = module llvm_module = self.module.llvm_module + # RPC + func_type = ll.FunctionType(ll.IntType(32), [ll.IntType(32)], + var_arg=1) + self.rpc = ll.Function(llvm_module, func_type, "__syscall_rpc") + # syscalls self.syscalls = dict() - self.var_arg_fixcount = dict() for func_name, func_type_str in _syscalls.items(): - var_arg_fixcount, func_type = _str_to_functype(func_type_str) - if var_arg_fixcount is not None: - self.var_arg_fixcount[func_name] = var_arg_fixcount + func_type = _str_to_functype(func_type_str) self.syscalls[func_name] = ll.Function( llvm_module, func_type, "__syscall_" + func_name) @@ -91,19 +112,48 @@ class LinkInterface: self.eh_raise = ll.Function(llvm_module, func_type, "__eh_raise") self.eh_raise.attributes.add("noreturn") - def build_syscall(self, syscall_name, args, builder): - r = _chr_to_value[_syscalls[syscall_name][-1]]() + def _build_rpc(self, args, builder): + r = base_types.VInt() + if builder is not None: + new_args = [] + new_args.append(args[0].auto_load(builder)) # RPC number + for arg in args[1:]: + # type tag + arg_type_str = _value_to_str(arg) + arg_type_int = 0 + for c in reversed(arg_type_str): + arg_type_int <<= 8 + arg_type_int |= ord(c) + new_args.append(ll.Constant(ll.IntType(32), arg_type_int)) + + # pointer to value + if not isinstance(arg, base_types.VNone): + if isinstance(arg.llvm_value.type, ll.PointerType): + new_args.append(arg.llvm_value) + else: + arg_ptr = arg.new() + arg_ptr.alloca(builder) + arg_ptr.auto_store(builder, arg.llvm_value) + new_args.append(arg_ptr.llvm_value) + # end marker + new_args.append(ll.Constant(ll.IntType(32), 0)) + r.auto_store(builder, builder.call(self.rpc, new_args)) + return r + + def _build_regular_syscall(self, syscall_name, args, builder): + r = _chr_to_value(_syscalls[syscall_name][-1]) if builder is not None: args = [arg.auto_load(builder) for arg in args] - if syscall_name in self.var_arg_fixcount: - fixcount = self.var_arg_fixcount[syscall_name] - args = args[:fixcount] \ - + [ll.Constant(ll.IntType(32), len(args) - fixcount)] \ - + args[fixcount:] r.auto_store(builder, builder.call(self.syscalls[syscall_name], args)) return r + def build_syscall(self, syscall_name, args, builder): + if syscall_name == "rpc": + return self._build_rpc(args, builder) + else: + return self._build_regular_syscall(syscall_name, args, builder) + def build_catch(self, builder): jmpbuf = builder.call(self.eh_push, []) exception_occured = builder.call(self.eh_setjmp, [jmpbuf]) diff --git a/artiq/transforms/inline.py b/artiq/transforms/inline.py index afb5898cd..40d82163d 100644 --- a/artiq/transforms/inline.py +++ b/artiq/transforms/inline.py @@ -447,17 +447,6 @@ def get_attr_writeback(attribute_namespace, rpc_mapper, loc_node): attr_writeback = [] for (_, attr), attr_info in attribute_namespace.items(): if attr_info.read_write: - # HACK/FIXME: since RPC of non-int is not supported yet, skip - # writeback of other types for now. - # This code breaks if an int is promoted to int64 - if hasattr(attr_info.obj, attr): - val = getattr(attr_info.obj, attr) - if (not isinstance(val, int) - or isinstance(val, core_language.int64) - or isinstance(val, bool)): - continue - # - setter = partial(setattr, attr_info.obj, attr) func = ast.copy_location( ast.Name("syscall", ast.Load()), loc_node) diff --git a/soc/runtime/comm.h b/soc/runtime/comm.h index eb4235b49..4544f420b 100644 --- a/soc/runtime/comm.h +++ b/soc/runtime/comm.h @@ -11,7 +11,7 @@ typedef int (*object_loader)(void *, int); typedef int (*kernel_runner)(const char *, int *); void comm_serve(object_loader load_object, kernel_runner run_kernel); -int comm_rpc(int rpc_num, int n_args, ...); +int comm_rpc(int rpc_num, ...); void comm_log(const char *fmt, ...); #endif /* __COMM_H */ diff --git a/soc/runtime/comm_serial.c b/soc/runtime/comm_serial.c index 7cdceef54..b234555c4 100644 --- a/soc/runtime/comm_serial.c +++ b/soc/runtime/comm_serial.c @@ -181,17 +181,59 @@ void comm_serve(object_loader load_object, kernel_runner run_kernel) } } -int comm_rpc(int rpc_num, int n_args, ...) +static int send_value(int type_tag, void *value) { + char base_type; + int i, p; + int len; + + base_type = type_tag; + send_char(base_type); + switch(base_type) { + case 'n': + return 0; + case 'b': + if(*(char *)value) + send_char(1); + else + send_char(0); + return 1; + case 'i': + send_int(*(int *)value); + return 4; + case 'I': + case 'f': + send_int(*(int *)value); + send_int(*((int *)value + 1)); + return 8; + case 'F': + for(i=0;i<4;i++) + send_int(*((int *)value + i)); + return 16; + case 'l': + len = *(int *)value; + p = 4; + for(i=0;i> 8, (char *)value + p); + send_char(0); + return p; + } + return 0; +} + +int comm_rpc(int rpc_num, ...) +{ + int type_tag; + send_char(MSGTYPE_RPC_REQUEST); send_sint(rpc_num); - send_char(n_args); va_list args; - va_start(args, n_args); - while(n_args--) - send_int(va_arg(args, int)); + va_start(args, rpc_num); + while((type_tag = va_arg(args, int))) + send_value(type_tag, type_tag == 'n' ? NULL : va_arg(args, void *)); va_end(args); + send_char(0); return receive_int(); }