mirror of https://github.com/m-labs/artiq.git
rpc: support all data types as parameters
This commit is contained in:
parent
44e7b99792
commit
0d10ae7580
|
@ -171,19 +171,35 @@ class Comm(AutoContext):
|
||||||
_write_exactly(self.port, struct.pack(
|
_write_exactly(self.port, struct.pack(
|
||||||
">lbl", 0x5a5a5a5a, _H2DMsgType.RUN_KERNEL.value, len(kname)))
|
">lbl", 0x5a5a5a5a, _H2DMsgType.RUN_KERNEL.value, len(kname)))
|
||||||
for c in kname:
|
for c in kname:
|
||||||
_write_exactly(self.port, struct.pack("B", ord(c)))
|
_write_exactly(self.port, struct.pack(">B", ord(c)))
|
||||||
logger.debug("running kernel: {}".format(kname))
|
logger.debug("running kernel: {}".format(kname))
|
||||||
|
|
||||||
def serve(self, rpc_map, user_exception_map):
|
def _receive_rpc_values(self):
|
||||||
|
r = []
|
||||||
while True:
|
while True:
|
||||||
msg = self._get_device_msg()
|
type_tag = chr(struct.unpack(">B", _read_exactly(self.port, 1))[0])
|
||||||
if msg == _D2HMsgType.RPC_REQUEST:
|
if type_tag == "\x00":
|
||||||
rpc_num, n_args = struct.unpack(">hB",
|
return r
|
||||||
_read_exactly(self.port, 3))
|
if type_tag == "n":
|
||||||
args = []
|
r.append(None)
|
||||||
for i in range(n_args):
|
if type_tag == "b":
|
||||||
args.append(*struct.unpack(">l",
|
r.append(bool(struct.unpack(">B",
|
||||||
_read_exactly(self.port, 4)))
|
_read_exactly(self.port, 1))[0]))
|
||||||
|
if type_tag == "i":
|
||||||
|
r.append(struct.unpack(">l", _read_exactly(self.port, 4))[0])
|
||||||
|
if type_tag == "I":
|
||||||
|
r.append(struct.unpack(">q", _read_exactly(self.port, 8))[0])
|
||||||
|
if type_tag == "f":
|
||||||
|
r.append(struct.unpack(">d", _read_exactly(self.port, 8))[0])
|
||||||
|
if type_tag == "F":
|
||||||
|
n, d = struct.unpack(">qq", _read_exactly(self.port, 16))
|
||||||
|
r.append(Fraction(n, d))
|
||||||
|
if type_tag == "l":
|
||||||
|
r.append(self._receive_rpc_values())
|
||||||
|
|
||||||
|
def _serve_rpc(self, rpc_map):
|
||||||
|
(rpc_num, ) = struct.unpack(">h", _read_exactly(self.port, 2))
|
||||||
|
args = self._receive_rpc_values()
|
||||||
logger.debug("rpc service: {} ({})".format(rpc_num, args))
|
logger.debug("rpc service: {} ({})".format(rpc_num, args))
|
||||||
r = rpc_map[rpc_num](*args)
|
r = rpc_map[rpc_num](*args)
|
||||||
if r is None:
|
if r is None:
|
||||||
|
@ -191,13 +207,22 @@ class Comm(AutoContext):
|
||||||
_write_exactly(self.port, struct.pack(">l", r))
|
_write_exactly(self.port, struct.pack(">l", r))
|
||||||
logger.debug("rpc service: {} ({}) == {}".format(
|
logger.debug("rpc service: {} ({}) == {}".format(
|
||||||
rpc_num, args, r))
|
rpc_num, args, r))
|
||||||
elif msg == _D2HMsgType.KERNEL_EXCEPTION:
|
|
||||||
|
def _serve_exception(self, user_exception_map):
|
||||||
(eid, ) = struct.unpack(">l", _read_exactly(self.port, 4))
|
(eid, ) = struct.unpack(">l", _read_exactly(self.port, 4))
|
||||||
if eid < core_language.first_user_eid:
|
if eid < core_language.first_user_eid:
|
||||||
exception = runtime_exceptions.exception_map[eid]
|
exception = runtime_exceptions.exception_map[eid]
|
||||||
else:
|
else:
|
||||||
exception = user_exception_map[eid]
|
exception = user_exception_map[eid]
|
||||||
raise exception
|
raise exception
|
||||||
|
|
||||||
|
def serve(self, rpc_map, user_exception_map):
|
||||||
|
while True:
|
||||||
|
msg = self._get_device_msg()
|
||||||
|
if msg == _D2HMsgType.RPC_REQUEST:
|
||||||
|
self._serve_rpc(rpc_map)
|
||||||
|
elif msg == _D2HMsgType.KERNEL_EXCEPTION:
|
||||||
|
self._serve_exception(user_exception_map)
|
||||||
elif msg == _D2HMsgType.KERNEL_FINISHED:
|
elif msg == _D2HMsgType.KERNEL_FINISHED:
|
||||||
return
|
return
|
||||||
else:
|
else:
|
||||||
|
|
|
@ -3,7 +3,7 @@ import os
|
||||||
import llvmlite.ir as ll
|
import llvmlite.ir as ll
|
||||||
import llvmlite.binding as llvm
|
import llvmlite.binding as llvm
|
||||||
|
|
||||||
from artiq.py2llvm import base_types
|
from artiq.py2llvm import base_types, fractions, lists
|
||||||
from artiq.language import units
|
from artiq.language import units
|
||||||
|
|
||||||
|
|
||||||
|
@ -12,7 +12,6 @@ llvm.initialize_all_targets()
|
||||||
llvm.initialize_all_asmprinters()
|
llvm.initialize_all_asmprinters()
|
||||||
|
|
||||||
_syscalls = {
|
_syscalls = {
|
||||||
"rpc": "i+:i",
|
|
||||||
"gpio_set": "ib:n",
|
"gpio_set": "ib:n",
|
||||||
"rtio_oe": "ib:n",
|
"rtio_oe": "ib:n",
|
||||||
"rtio_set": "Iii:n",
|
"rtio_set": "Iii:n",
|
||||||
|
@ -23,36 +22,56 @@ _syscalls = {
|
||||||
"dds_program": "Iiiiibb:n",
|
"dds_program": "Iiiiibb:n",
|
||||||
}
|
}
|
||||||
|
|
||||||
_chr_to_type = {
|
|
||||||
"n": lambda: ll.VoidType(),
|
|
||||||
"b": lambda: ll.IntType(1),
|
|
||||||
"i": lambda: ll.IntType(32),
|
|
||||||
"I": lambda: ll.IntType(64)
|
|
||||||
}
|
|
||||||
|
|
||||||
_chr_to_value = {
|
def _chr_to_type(c):
|
||||||
"n": lambda: base_types.VNone(),
|
if c == "n":
|
||||||
"b": lambda: base_types.VBool(),
|
return ll.VoidType()
|
||||||
"i": lambda: base_types.VInt(),
|
if c == "b":
|
||||||
"I": lambda: base_types.VInt(64)
|
return ll.IntType(1)
|
||||||
}
|
if c == "i":
|
||||||
|
return ll.IntType(32)
|
||||||
|
if c == "I":
|
||||||
|
return ll.IntType(64)
|
||||||
|
raise ValueError
|
||||||
|
|
||||||
|
|
||||||
def _str_to_functype(s):
|
def _str_to_functype(s):
|
||||||
assert(s[-2] == ":")
|
assert(s[-2] == ":")
|
||||||
type_ret = _chr_to_type[s[-1]]()
|
type_ret = _chr_to_type(s[-1])
|
||||||
|
type_args = [_chr_to_type(c) for c in s[:-2] if c != "n"]
|
||||||
|
return ll.FunctionType(type_ret, type_args)
|
||||||
|
|
||||||
var_arg_fixcount = None
|
|
||||||
type_args = []
|
def _chr_to_value(c):
|
||||||
for n, c in enumerate(s[:-2]):
|
if c == "n":
|
||||||
if c == "+":
|
return base_types.VNone()
|
||||||
type_args.append(ll.IntType(32))
|
if c == "b":
|
||||||
var_arg_fixcount = n
|
return base_types.VBool()
|
||||||
elif c != "n":
|
if c == "i":
|
||||||
type_args.append(_chr_to_type[c]())
|
return base_types.VInt()
|
||||||
return (var_arg_fixcount,
|
if c == "I":
|
||||||
ll.FunctionType(type_ret, type_args,
|
return base_types.VInt(64)
|
||||||
var_arg=var_arg_fixcount is not None))
|
raise ValueError
|
||||||
|
|
||||||
|
|
||||||
|
def _value_to_str(v):
|
||||||
|
if isinstance(v, base_types.VNone):
|
||||||
|
return "n"
|
||||||
|
if isinstance(v, base_types.VBool):
|
||||||
|
return "b"
|
||||||
|
if isinstance(v, base_types.VInt):
|
||||||
|
if v.nbits == 32:
|
||||||
|
return "i"
|
||||||
|
if v.nbits == 64:
|
||||||
|
return "I"
|
||||||
|
raise ValueError
|
||||||
|
if isinstance(v, base_types.VFloat):
|
||||||
|
return "f"
|
||||||
|
if isinstance(v, fractions.VFraction):
|
||||||
|
return "F"
|
||||||
|
if isinstance(v, lists.VList):
|
||||||
|
return "l" + _value_to_str(v.el_type)
|
||||||
|
raise ValueError
|
||||||
|
|
||||||
|
|
||||||
class LinkInterface:
|
class LinkInterface:
|
||||||
|
@ -60,13 +79,15 @@ class LinkInterface:
|
||||||
self.module = module
|
self.module = module
|
||||||
llvm_module = self.module.llvm_module
|
llvm_module = self.module.llvm_module
|
||||||
|
|
||||||
|
# RPC
|
||||||
|
func_type = ll.FunctionType(ll.IntType(32), [ll.IntType(32)],
|
||||||
|
var_arg=1)
|
||||||
|
self.rpc = ll.Function(llvm_module, func_type, "__syscall_rpc")
|
||||||
|
|
||||||
# syscalls
|
# syscalls
|
||||||
self.syscalls = dict()
|
self.syscalls = dict()
|
||||||
self.var_arg_fixcount = dict()
|
|
||||||
for func_name, func_type_str in _syscalls.items():
|
for func_name, func_type_str in _syscalls.items():
|
||||||
var_arg_fixcount, func_type = _str_to_functype(func_type_str)
|
func_type = _str_to_functype(func_type_str)
|
||||||
if var_arg_fixcount is not None:
|
|
||||||
self.var_arg_fixcount[func_name] = var_arg_fixcount
|
|
||||||
self.syscalls[func_name] = ll.Function(
|
self.syscalls[func_name] = ll.Function(
|
||||||
llvm_module, func_type, "__syscall_" + func_name)
|
llvm_module, func_type, "__syscall_" + func_name)
|
||||||
|
|
||||||
|
@ -91,19 +112,48 @@ class LinkInterface:
|
||||||
self.eh_raise = ll.Function(llvm_module, func_type, "__eh_raise")
|
self.eh_raise = ll.Function(llvm_module, func_type, "__eh_raise")
|
||||||
self.eh_raise.attributes.add("noreturn")
|
self.eh_raise.attributes.add("noreturn")
|
||||||
|
|
||||||
def build_syscall(self, syscall_name, args, builder):
|
def _build_rpc(self, args, builder):
|
||||||
r = _chr_to_value[_syscalls[syscall_name][-1]]()
|
r = base_types.VInt()
|
||||||
|
if builder is not None:
|
||||||
|
new_args = []
|
||||||
|
new_args.append(args[0].auto_load(builder)) # RPC number
|
||||||
|
for arg in args[1:]:
|
||||||
|
# type tag
|
||||||
|
arg_type_str = _value_to_str(arg)
|
||||||
|
arg_type_int = 0
|
||||||
|
for c in reversed(arg_type_str):
|
||||||
|
arg_type_int <<= 8
|
||||||
|
arg_type_int |= ord(c)
|
||||||
|
new_args.append(ll.Constant(ll.IntType(32), arg_type_int))
|
||||||
|
|
||||||
|
# pointer to value
|
||||||
|
if not isinstance(arg, base_types.VNone):
|
||||||
|
if isinstance(arg.llvm_value.type, ll.PointerType):
|
||||||
|
new_args.append(arg.llvm_value)
|
||||||
|
else:
|
||||||
|
arg_ptr = arg.new()
|
||||||
|
arg_ptr.alloca(builder)
|
||||||
|
arg_ptr.auto_store(builder, arg.llvm_value)
|
||||||
|
new_args.append(arg_ptr.llvm_value)
|
||||||
|
# end marker
|
||||||
|
new_args.append(ll.Constant(ll.IntType(32), 0))
|
||||||
|
r.auto_store(builder, builder.call(self.rpc, new_args))
|
||||||
|
return r
|
||||||
|
|
||||||
|
def _build_regular_syscall(self, syscall_name, args, builder):
|
||||||
|
r = _chr_to_value(_syscalls[syscall_name][-1])
|
||||||
if builder is not None:
|
if builder is not None:
|
||||||
args = [arg.auto_load(builder) for arg in args]
|
args = [arg.auto_load(builder) for arg in args]
|
||||||
if syscall_name in self.var_arg_fixcount:
|
|
||||||
fixcount = self.var_arg_fixcount[syscall_name]
|
|
||||||
args = args[:fixcount] \
|
|
||||||
+ [ll.Constant(ll.IntType(32), len(args) - fixcount)] \
|
|
||||||
+ args[fixcount:]
|
|
||||||
r.auto_store(builder, builder.call(self.syscalls[syscall_name],
|
r.auto_store(builder, builder.call(self.syscalls[syscall_name],
|
||||||
args))
|
args))
|
||||||
return r
|
return r
|
||||||
|
|
||||||
|
def build_syscall(self, syscall_name, args, builder):
|
||||||
|
if syscall_name == "rpc":
|
||||||
|
return self._build_rpc(args, builder)
|
||||||
|
else:
|
||||||
|
return self._build_regular_syscall(syscall_name, args, builder)
|
||||||
|
|
||||||
def build_catch(self, builder):
|
def build_catch(self, builder):
|
||||||
jmpbuf = builder.call(self.eh_push, [])
|
jmpbuf = builder.call(self.eh_push, [])
|
||||||
exception_occured = builder.call(self.eh_setjmp, [jmpbuf])
|
exception_occured = builder.call(self.eh_setjmp, [jmpbuf])
|
||||||
|
|
|
@ -447,17 +447,6 @@ def get_attr_writeback(attribute_namespace, rpc_mapper, loc_node):
|
||||||
attr_writeback = []
|
attr_writeback = []
|
||||||
for (_, attr), attr_info in attribute_namespace.items():
|
for (_, attr), attr_info in attribute_namespace.items():
|
||||||
if attr_info.read_write:
|
if attr_info.read_write:
|
||||||
# HACK/FIXME: since RPC of non-int is not supported yet, skip
|
|
||||||
# writeback of other types for now.
|
|
||||||
# This code breaks if an int is promoted to int64
|
|
||||||
if hasattr(attr_info.obj, attr):
|
|
||||||
val = getattr(attr_info.obj, attr)
|
|
||||||
if (not isinstance(val, int)
|
|
||||||
or isinstance(val, core_language.int64)
|
|
||||||
or isinstance(val, bool)):
|
|
||||||
continue
|
|
||||||
#
|
|
||||||
|
|
||||||
setter = partial(setattr, attr_info.obj, attr)
|
setter = partial(setattr, attr_info.obj, attr)
|
||||||
func = ast.copy_location(
|
func = ast.copy_location(
|
||||||
ast.Name("syscall", ast.Load()), loc_node)
|
ast.Name("syscall", ast.Load()), loc_node)
|
||||||
|
|
|
@ -11,7 +11,7 @@ typedef int (*object_loader)(void *, int);
|
||||||
typedef int (*kernel_runner)(const char *, int *);
|
typedef int (*kernel_runner)(const char *, int *);
|
||||||
|
|
||||||
void comm_serve(object_loader load_object, kernel_runner run_kernel);
|
void comm_serve(object_loader load_object, kernel_runner run_kernel);
|
||||||
int comm_rpc(int rpc_num, int n_args, ...);
|
int comm_rpc(int rpc_num, ...);
|
||||||
void comm_log(const char *fmt, ...);
|
void comm_log(const char *fmt, ...);
|
||||||
|
|
||||||
#endif /* __COMM_H */
|
#endif /* __COMM_H */
|
||||||
|
|
|
@ -181,17 +181,59 @@ void comm_serve(object_loader load_object, kernel_runner run_kernel)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
int comm_rpc(int rpc_num, int n_args, ...)
|
static int send_value(int type_tag, void *value)
|
||||||
{
|
{
|
||||||
|
char base_type;
|
||||||
|
int i, p;
|
||||||
|
int len;
|
||||||
|
|
||||||
|
base_type = type_tag;
|
||||||
|
send_char(base_type);
|
||||||
|
switch(base_type) {
|
||||||
|
case 'n':
|
||||||
|
return 0;
|
||||||
|
case 'b':
|
||||||
|
if(*(char *)value)
|
||||||
|
send_char(1);
|
||||||
|
else
|
||||||
|
send_char(0);
|
||||||
|
return 1;
|
||||||
|
case 'i':
|
||||||
|
send_int(*(int *)value);
|
||||||
|
return 4;
|
||||||
|
case 'I':
|
||||||
|
case 'f':
|
||||||
|
send_int(*(int *)value);
|
||||||
|
send_int(*((int *)value + 1));
|
||||||
|
return 8;
|
||||||
|
case 'F':
|
||||||
|
for(i=0;i<4;i++)
|
||||||
|
send_int(*((int *)value + i));
|
||||||
|
return 16;
|
||||||
|
case 'l':
|
||||||
|
len = *(int *)value;
|
||||||
|
p = 4;
|
||||||
|
for(i=0;i<len;i++)
|
||||||
|
p += send_value(type_tag >> 8, (char *)value + p);
|
||||||
|
send_char(0);
|
||||||
|
return p;
|
||||||
|
}
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
int comm_rpc(int rpc_num, ...)
|
||||||
|
{
|
||||||
|
int type_tag;
|
||||||
|
|
||||||
send_char(MSGTYPE_RPC_REQUEST);
|
send_char(MSGTYPE_RPC_REQUEST);
|
||||||
send_sint(rpc_num);
|
send_sint(rpc_num);
|
||||||
send_char(n_args);
|
|
||||||
|
|
||||||
va_list args;
|
va_list args;
|
||||||
va_start(args, n_args);
|
va_start(args, rpc_num);
|
||||||
while(n_args--)
|
while((type_tag = va_arg(args, int)))
|
||||||
send_int(va_arg(args, int));
|
send_value(type_tag, type_tag == 'n' ? NULL : va_arg(args, void *));
|
||||||
va_end(args);
|
va_end(args);
|
||||||
|
send_char(0);
|
||||||
|
|
||||||
return receive_int();
|
return receive_int();
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue