mirror of
https://github.com/m-labs/artiq.git
synced 2024-12-25 11:18:27 +08:00
Implement sending RPCs.
This commit is contained in:
parent
22457bc19c
commit
b26af5df60
@ -163,7 +163,7 @@ def is_bool(typ):
|
||||
|
||||
def is_int(typ, width=None):
|
||||
if width is not None:
|
||||
return types.is_mono(typ, "int", {"width": width})
|
||||
return types.is_mono(typ, "int", width=width)
|
||||
else:
|
||||
return types.is_mono(typ, "int")
|
||||
|
||||
@ -184,7 +184,7 @@ def is_numeric(typ):
|
||||
|
||||
def is_list(typ, elt=None):
|
||||
if elt is not None:
|
||||
return types.is_mono(typ, "list", {"elt": elt})
|
||||
return types.is_mono(typ, "list", elt=elt)
|
||||
else:
|
||||
return types.is_mono(typ, "list")
|
||||
|
||||
|
@ -5,10 +5,13 @@ the references to the host objects and translates the functions
|
||||
annotated as ``@kernel`` when they are referenced.
|
||||
"""
|
||||
|
||||
import inspect, os
|
||||
import os, re, linecache, inspect
|
||||
from collections import OrderedDict
|
||||
|
||||
from pythonparser import ast, source, diagnostic, parse_buffer
|
||||
|
||||
from . import types, builtins, asttyped, prelude
|
||||
from .transforms import ASTTypedRewriter, Inferencer
|
||||
from .transforms import ASTTypedRewriter, Inferencer, IntMonomorphizer
|
||||
|
||||
|
||||
class ASTSynthesizer:
|
||||
@ -45,6 +48,9 @@ class ASTSynthesizer:
|
||||
typ = builtins.TFloat()
|
||||
return asttyped.NumT(n=value, ctx=None, type=typ,
|
||||
loc=self._add(repr(value)))
|
||||
elif isinstance(value, str):
|
||||
return asttyped.StrT(s=value, ctx=None, type=builtins.TStr(),
|
||||
loc=self._add(repr(value)))
|
||||
elif isinstance(value, list):
|
||||
begin_loc = self._add("[")
|
||||
elts = []
|
||||
@ -123,7 +129,7 @@ class StitchingASTTypedRewriter(ASTTypedRewriter):
|
||||
if inspect.isfunction(value):
|
||||
# It's a function. We need to translate the function and insert
|
||||
# a reference to it.
|
||||
function_name = self.quote_function(value)
|
||||
function_name = self.quote_function(value, node.loc)
|
||||
return asttyped.NameT(id=function_name, ctx=None,
|
||||
type=self.globals[function_name],
|
||||
loc=node.loc)
|
||||
@ -154,7 +160,19 @@ class Stitcher:
|
||||
|
||||
self.functions = {}
|
||||
|
||||
self.next_rpc = 0
|
||||
self.rpc_map = {}
|
||||
self.inverse_rpc_map = {}
|
||||
|
||||
def _map(self, obj):
|
||||
obj_id = id(obj)
|
||||
if obj_id in self.inverse_rpc_map:
|
||||
return self.inverse_rpc_map[obj_id]
|
||||
|
||||
self.next_rpc += 1
|
||||
self.rpc_map[self.next_rpc] = obj
|
||||
self.inverse_rpc_map[obj_id] = self.next_rpc
|
||||
return self.next_rpc
|
||||
|
||||
def _iterate(self):
|
||||
inferencer = Inferencer(engine=self.engine)
|
||||
@ -213,17 +231,102 @@ class Stitcher:
|
||||
quote_function=self._quote_function)
|
||||
return asttyped_rewriter.visit(function_node)
|
||||
|
||||
def _quote_function(self, function):
|
||||
def _function_def_note(self, function):
|
||||
filename = function.__code__.co_filename
|
||||
line = function.__code__.co_firstlineno
|
||||
name = function.__code__.co_name
|
||||
|
||||
source_line = linecache.getline(filename, line)
|
||||
column = re.search("def", source_line).start(0)
|
||||
source_buffer = source.Buffer(source_line, filename, line)
|
||||
loc = source.Range(source_buffer, column, column)
|
||||
return diagnostic.Diagnostic("note",
|
||||
"definition of function '{function}'",
|
||||
{"function": name},
|
||||
loc)
|
||||
|
||||
def _type_of_param(self, function, loc, param):
|
||||
if param.default is not inspect.Parameter.empty:
|
||||
# Try and infer the type from the default value.
|
||||
# This is tricky, because the default value might not have
|
||||
# a well-defined type in APython.
|
||||
# In this case, we bail out, but mention why we do it.
|
||||
synthesizer = ASTSynthesizer()
|
||||
ast = synthesizer.quote(param.default)
|
||||
synthesizer.finalize()
|
||||
|
||||
def proxy_diagnostic(diag):
|
||||
note = diagnostic.Diagnostic("note",
|
||||
"expanded from here while trying to infer a type for an"
|
||||
" unannotated optional argument '{param_name}' from its default value",
|
||||
{"param_name": param.name},
|
||||
loc)
|
||||
diag.notes.append(note)
|
||||
|
||||
diag.notes.append(self._function_def_note(function))
|
||||
|
||||
self.engine.process(diag)
|
||||
|
||||
proxy_engine = diagnostic.Engine()
|
||||
proxy_engine.process = proxy_diagnostic
|
||||
Inferencer(engine=proxy_engine).visit(ast)
|
||||
IntMonomorphizer(engine=proxy_engine).visit(ast)
|
||||
|
||||
return ast.type
|
||||
else:
|
||||
# Let the rest of the program decide.
|
||||
return types.TVar()
|
||||
|
||||
def _quote_rpc_function(self, function, loc):
|
||||
signature = inspect.signature(function)
|
||||
|
||||
arg_types = OrderedDict()
|
||||
optarg_types = OrderedDict()
|
||||
for param in signature.parameters.values():
|
||||
if param.kind not in (inspect.Parameter.POSITIONAL_ONLY,
|
||||
inspect.Parameter.POSITIONAL_OR_KEYWORD):
|
||||
# We pretend we don't see *args, kwpostargs=..., **kwargs.
|
||||
# Since every method can be still invoked without any arguments
|
||||
# going into *args and the slots after it, this is always safe,
|
||||
# if sometimes constraining.
|
||||
#
|
||||
# Accepting POSITIONAL_ONLY is OK, because the compiler
|
||||
# desugars the keyword arguments into positional ones internally.
|
||||
continue
|
||||
|
||||
if param.default is inspect.Parameter.empty:
|
||||
arg_types[param.name] = self._type_of_param(function, loc, param)
|
||||
else:
|
||||
optarg_types[param.name] = self._type_of_param(function, loc, param)
|
||||
|
||||
# Fixed for now.
|
||||
ret_type = builtins.TInt(types.TValue(32))
|
||||
|
||||
rpc_type = types.TRPCFunction(arg_types, optarg_types, ret_type,
|
||||
service=self._map(function))
|
||||
|
||||
rpc_name = "__rpc_{}__".format(rpc_type.service)
|
||||
self.globals[rpc_name] = rpc_type
|
||||
self.functions[function] = rpc_name
|
||||
|
||||
return rpc_name
|
||||
|
||||
def _quote_function(self, function, loc):
|
||||
if function in self.functions:
|
||||
return self.functions[function]
|
||||
|
||||
# Insert the typed AST for the new function and restart inference.
|
||||
# It doesn't really matter where we insert as long as it is before
|
||||
# the final call.
|
||||
function_node = self._quote_embedded_function(function)
|
||||
self.typedtree.insert(0, function_node)
|
||||
self.inference_finished = False
|
||||
return function_node.name
|
||||
if hasattr(function, "artiq_embedded"):
|
||||
# Insert the typed AST for the new function and restart inference.
|
||||
# It doesn't really matter where we insert as long as it is before
|
||||
# the final call.
|
||||
function_node = self._quote_embedded_function(function)
|
||||
self.typedtree.insert(0, function_node)
|
||||
self.inference_finished = False
|
||||
return function_node.name
|
||||
else:
|
||||
# Insert a storage-less global whose type instructs the compiler
|
||||
# to perform an RPC instead of a regular call.
|
||||
return self._quote_rpc_function(function, loc)
|
||||
|
||||
def stitch_call(self, function, args, kwargs):
|
||||
function_node = self._quote_embedded_function(function)
|
||||
|
@ -41,14 +41,16 @@ class Source:
|
||||
|
||||
class Module:
|
||||
def __init__(self, src):
|
||||
int_monomorphizer = transforms.IntMonomorphizer(engine=src.engine)
|
||||
inferencer = transforms.Inferencer(engine=src.engine)
|
||||
monomorphism_validator = validators.MonomorphismValidator(engine=src.engine)
|
||||
escape_validator = validators.EscapeValidator(engine=src.engine)
|
||||
artiq_ir_generator = transforms.ARTIQIRGenerator(engine=src.engine,
|
||||
self.engine = src.engine
|
||||
|
||||
int_monomorphizer = transforms.IntMonomorphizer(engine=self.engine)
|
||||
inferencer = transforms.Inferencer(engine=self.engine)
|
||||
monomorphism_validator = validators.MonomorphismValidator(engine=self.engine)
|
||||
escape_validator = validators.EscapeValidator(engine=self.engine)
|
||||
artiq_ir_generator = transforms.ARTIQIRGenerator(engine=self.engine,
|
||||
module_name=src.name)
|
||||
dead_code_eliminator = transforms.DeadCodeEliminator(engine=src.engine)
|
||||
local_access_validator = validators.LocalAccessValidator(engine=src.engine)
|
||||
dead_code_eliminator = transforms.DeadCodeEliminator(engine=self.engine)
|
||||
local_access_validator = validators.LocalAccessValidator(engine=self.engine)
|
||||
|
||||
self.name = src.name
|
||||
self.globals = src.globals
|
||||
@ -62,7 +64,8 @@ class Module:
|
||||
|
||||
def build_llvm_ir(self, target):
|
||||
"""Compile the module to LLVM IR for the specified target."""
|
||||
llvm_ir_generator = transforms.LLVMIRGenerator(module_name=self.name, target=target)
|
||||
llvm_ir_generator = transforms.LLVMIRGenerator(engine=self.engine,
|
||||
module_name=self.name, target=target)
|
||||
return llvm_ir_generator.process(self.artiq_ir)
|
||||
|
||||
def entry_point(self):
|
||||
|
@ -3,12 +3,13 @@
|
||||
into LLVM intermediate representation.
|
||||
"""
|
||||
|
||||
from pythonparser import ast
|
||||
from pythonparser import ast, diagnostic
|
||||
from llvmlite_artiq import ir as ll
|
||||
from .. import types, builtins, ir
|
||||
|
||||
class LLVMIRGenerator:
|
||||
def __init__(self, module_name, target):
|
||||
def __init__(self, engine, module_name, target):
|
||||
self.engine = engine
|
||||
self.target = target
|
||||
self.llcontext = target.llcontext
|
||||
self.llmodule = ll.Module(context=self.llcontext, name=module_name)
|
||||
@ -21,6 +22,11 @@ class LLVMIRGenerator:
|
||||
typ = typ.find()
|
||||
if types.is_tuple(typ):
|
||||
return ll.LiteralStructType([self.llty_of_type(eltty) for eltty in typ.elts])
|
||||
elif types.is_rpc_function(typ):
|
||||
if for_return:
|
||||
return ll.VoidType()
|
||||
else:
|
||||
return ll.LiteralStructType([])
|
||||
elif types.is_function(typ):
|
||||
envarg = ll.IntType(8).as_pointer()
|
||||
llty = ll.FunctionType(args=[envarg] +
|
||||
@ -89,10 +95,13 @@ class LLVMIRGenerator:
|
||||
return ll.Constant(llty, False)
|
||||
elif isinstance(const.value, (int, float)):
|
||||
return ll.Constant(llty, const.value)
|
||||
elif isinstance(const.value, str):
|
||||
assert "\0" not in const.value
|
||||
elif isinstance(const.value, (str, bytes)):
|
||||
if isinstance(const.value, str):
|
||||
assert "\0" not in const.value
|
||||
as_bytes = (const.value + "\0").encode("utf-8")
|
||||
else:
|
||||
as_bytes = const.value
|
||||
|
||||
as_bytes = (const.value + "\0").encode("utf-8")
|
||||
if ir.is_exn_typeinfo(const.type):
|
||||
# Exception typeinfo; should be merged with identical others
|
||||
name = "__artiq_exn_" + const.value
|
||||
@ -144,6 +153,9 @@ class LLVMIRGenerator:
|
||||
llty = ll.FunctionType(ll.VoidType(), [self.llty_of_type(builtins.TException())])
|
||||
elif name == "__artiq_reraise":
|
||||
llty = ll.FunctionType(ll.VoidType(), [])
|
||||
elif name == "rpc":
|
||||
llty = ll.FunctionType(ll.IntType(32), [ll.IntType(32), ll.IntType(8).as_pointer()],
|
||||
var_arg=True)
|
||||
else:
|
||||
assert False
|
||||
|
||||
@ -546,11 +558,79 @@ class LLVMIRGenerator:
|
||||
name=insn.name)
|
||||
return llvalue
|
||||
|
||||
# See session.c:send_rpc_value.
|
||||
def _rpc_tag(self, typ, root_type, root_loc):
|
||||
if types.is_tuple(typ):
|
||||
assert len(typ.elts) < 256
|
||||
return b"t" + bytes([len(typ.elts)]) + \
|
||||
b"".join([self._rpc_tag(elt_type, root_type, root_loc)
|
||||
for elt_type in typ.elts])
|
||||
elif builtins.is_none(typ):
|
||||
return b"n"
|
||||
elif builtins.is_bool(typ):
|
||||
return b"b"
|
||||
elif builtins.is_int(typ, types.TValue(32)):
|
||||
return b"i"
|
||||
elif builtins.is_int(typ, types.TValue(64)):
|
||||
return b"I"
|
||||
elif builtins.is_float(typ):
|
||||
return b"f"
|
||||
elif builtins.is_str(typ):
|
||||
return b"s"
|
||||
elif builtins.is_list(typ):
|
||||
return b"l" + self._rpc_tag(builtins.get_iterable_elt(typ),
|
||||
root_type, root_loc)
|
||||
elif builtins.is_range(typ):
|
||||
return b"r" + self._rpc_tag(builtins.get_iterable_elt(typ),
|
||||
root_type, root_loc)
|
||||
elif ir.is_option(typ):
|
||||
return b"o" + self._rpc_tag(typ.params["inner"],
|
||||
root_type, root_loc)
|
||||
else:
|
||||
printer = types.TypePrinter()
|
||||
note = diagnostic.Diagnostic("note",
|
||||
"value of type {type}",
|
||||
{"type": printer.name(root_type)},
|
||||
root_loc)
|
||||
diag = diagnostic.Diagnostic("error",
|
||||
"type {type} is not supported in remote procedure calls",
|
||||
{"type": printer.name(typ)},
|
||||
root_loc)
|
||||
self.engine.process(diag)
|
||||
|
||||
def _build_rpc(self, service, args, return_type):
|
||||
llservice = ll.Constant(ll.IntType(32), service)
|
||||
|
||||
tag = b""
|
||||
for arg in args:
|
||||
if isinstance(arg, ir.Constant):
|
||||
# Constants don't have locations, but conveniently
|
||||
# they also never fail to serialize.
|
||||
tag += self._rpc_tag(arg.type, arg.type, None)
|
||||
else:
|
||||
tag += self._rpc_tag(arg.type, arg.type, arg.loc)
|
||||
tag += b":\x00"
|
||||
lltag = self.llconst_of_const(ir.Constant(tag, builtins.TStr()))
|
||||
|
||||
llargs = []
|
||||
for arg in args:
|
||||
llarg = self.map(arg)
|
||||
llargslot = self.llbuilder.alloca(llarg.type)
|
||||
self.llbuilder.store(llarg, llargslot)
|
||||
llargs.append(llargslot)
|
||||
|
||||
return self.llbuiltin("rpc"), [llservice, lltag] + llargs
|
||||
|
||||
def prepare_call(self, insn):
|
||||
llclosure, llargs = self.map(insn.target_function()), map(self.map, insn.arguments())
|
||||
llenv = self.llbuilder.extract_value(llclosure, 0)
|
||||
llfun = self.llbuilder.extract_value(llclosure, 1)
|
||||
return llfun, [llenv] + list(llargs)
|
||||
if types.is_rpc_function(insn.target_function().type):
|
||||
return self._build_rpc(insn.target_function().type.service,
|
||||
insn.arguments(),
|
||||
insn.target_function().type.ret)
|
||||
else:
|
||||
llclosure, llargs = self.map(insn.target_function()), map(self.map, insn.arguments())
|
||||
llenv = self.llbuilder.extract_value(llclosure, 0)
|
||||
llfun = self.llbuilder.extract_value(llclosure, 1)
|
||||
return llfun, [llenv] + list(llargs)
|
||||
|
||||
def process_Call(self, insn):
|
||||
llfun, llargs = self.prepare_call(insn)
|
||||
|
@ -222,6 +222,26 @@ class TFunction(Type):
|
||||
def __ne__(self, other):
|
||||
return not (self == other)
|
||||
|
||||
class TRPCFunction(TFunction):
|
||||
"""
|
||||
A function type of a remote function.
|
||||
|
||||
:ivar service: (int) RPC service number
|
||||
"""
|
||||
|
||||
def __init__(self, args, optargs, ret, service):
|
||||
super().__init__(args, optargs, ret)
|
||||
self.service = service
|
||||
|
||||
def unify(self, other):
|
||||
if isinstance(other, TRPCFunction) and \
|
||||
self.service == other.service:
|
||||
super().unify(other)
|
||||
elif isinstance(other, TVar):
|
||||
other.unify(self)
|
||||
else:
|
||||
raise UnificationError(self, other)
|
||||
|
||||
class TBuiltin(Type):
|
||||
"""
|
||||
An instance of builtin type. Every instance of a builtin
|
||||
@ -310,6 +330,8 @@ def is_mono(typ, name=None, **params):
|
||||
typ = typ.find()
|
||||
params_match = True
|
||||
for param in params:
|
||||
if param not in typ.params:
|
||||
return False
|
||||
params_match = params_match and \
|
||||
typ.params[param].find() == params[param].find()
|
||||
return isinstance(typ, TMono) and \
|
||||
@ -329,6 +351,9 @@ def is_tuple(typ, elts=None):
|
||||
def is_function(typ):
|
||||
return isinstance(typ.find(), TFunction)
|
||||
|
||||
def is_rpc_function(typ):
|
||||
return isinstance(typ.find(), TRPCFunction)
|
||||
|
||||
def is_builtin(typ, name=None):
|
||||
typ = typ.find()
|
||||
if name is None:
|
||||
@ -381,11 +406,16 @@ class TypePrinter(object):
|
||||
return "(%s,)" % self.name(typ.elts[0])
|
||||
else:
|
||||
return "(%s)" % ", ".join(list(map(self.name, typ.elts)))
|
||||
elif isinstance(typ, TFunction):
|
||||
elif isinstance(typ, (TFunction, TRPCFunction)):
|
||||
args = []
|
||||
args += [ "%s:%s" % (arg, self.name(typ.args[arg])) for arg in typ.args]
|
||||
args += ["?%s:%s" % (arg, self.name(typ.optargs[arg])) for arg in typ.optargs]
|
||||
return "(%s)->%s" % (", ".join(args), self.name(typ.ret))
|
||||
signature = "(%s)->%s" % (", ".join(args), self.name(typ.ret))
|
||||
|
||||
if isinstance(typ, TRPCFunction):
|
||||
return "rpc({}) {}".format(typ.service, signature)
|
||||
elif isinstance(typ, TFunction):
|
||||
return signature
|
||||
elif isinstance(typ, TBuiltinFunction):
|
||||
return "<function %s>" % typ.name
|
||||
elif isinstance(typ, (TConstructor, TExceptionConstructor)):
|
||||
|
@ -276,8 +276,16 @@ class CommGeneric:
|
||||
self._write_empty(_H2DMsgType.RUN_KERNEL)
|
||||
logger.debug("running kernel")
|
||||
|
||||
def _receive_rpc_value(self, tag, rpc_map):
|
||||
if tag == "n":
|
||||
_rpc_sentinel = object()
|
||||
|
||||
def _receive_rpc_value(self, rpc_map):
|
||||
tag = chr(self._read_int8())
|
||||
if tag == "\x00":
|
||||
return self._rpc_sentinel
|
||||
elif tag == "t":
|
||||
length = self._read_int8()
|
||||
return tuple(self._receive_rpc_value(rpc_map) for _ in range(length))
|
||||
elif tag == "n":
|
||||
return None
|
||||
elif tag == "b":
|
||||
return bool(self._read_int8())
|
||||
@ -291,31 +299,36 @@ class CommGeneric:
|
||||
numerator = self._read_int64()
|
||||
denominator = self._read_int64()
|
||||
return Fraction(numerator, denominator)
|
||||
elif tag == "s":
|
||||
return self._read_string()
|
||||
elif tag == "l":
|
||||
elt_tag = chr(self._read_int8())
|
||||
length = self._read_int32()
|
||||
return [self._receive_rpc_value(elt_tag) for _ in range(length)]
|
||||
return [self._receive_rpc_value(rpc_map) for _ in range(length)]
|
||||
elif tag == "r":
|
||||
lower = self._receive_rpc_value(rpc_map)
|
||||
upper = self._receive_rpc_value(rpc_map)
|
||||
step = self._receive_rpc_value(rpc_map)
|
||||
return range(lower, upper, step)
|
||||
elif tag == "o":
|
||||
return rpc_map[self._read_int32()]
|
||||
else:
|
||||
raise IOError("Unknown RPC value tag: {}", tag)
|
||||
raise IOError("Unknown RPC value tag: {}".format(repr(tag)))
|
||||
|
||||
def _receive_rpc_values(self, rpc_map):
|
||||
result = []
|
||||
def _receive_rpc_args(self, rpc_map):
|
||||
args = []
|
||||
while True:
|
||||
tag = chr(self._read_int8())
|
||||
if tag == "\x00":
|
||||
return result
|
||||
else:
|
||||
result.append(self._receive_rpc_value(tag, rpc_map))
|
||||
value = self._receive_rpc_value(rpc_map)
|
||||
if value is self._rpc_sentinel:
|
||||
return args
|
||||
args.append(value)
|
||||
|
||||
def _serve_rpc(self, rpc_map):
|
||||
service = self._read_int32()
|
||||
args = self._receive_rpc_values(rpc_map)
|
||||
args = self._receive_rpc_args(rpc_map)
|
||||
logger.debug("rpc service: %d %r", service, args)
|
||||
|
||||
try:
|
||||
result = rpc_map[rpc_num](args)
|
||||
result = rpc_map[service](*args)
|
||||
if not isinstance(result, int) or not (-2**31 < result < 2**31-1):
|
||||
raise ValueError("An RPC must return an int(width=32)")
|
||||
except ARTIQException as exn:
|
||||
|
@ -50,13 +50,13 @@ class Core:
|
||||
raise CompileError() from error
|
||||
|
||||
def run(self, function, args, kwargs):
|
||||
kernel_library, rpc_map = self.compile(function, args, kwargs)
|
||||
|
||||
if self.first_run:
|
||||
self.comm.check_ident()
|
||||
self.comm.switch_clock(self.external_clock)
|
||||
self.first_run = False
|
||||
|
||||
kernel_library, rpc_map = self.compile(function, args, kwargs)
|
||||
|
||||
try:
|
||||
self.comm.load(kernel_library)
|
||||
except Exception as error:
|
||||
|
@ -301,33 +301,34 @@ void watchdog_clear(int id)
|
||||
mailbox_send_and_wait(&request);
|
||||
}
|
||||
|
||||
int rpc(int rpc_num, ...)
|
||||
int rpc(int service, const char *tag, ...)
|
||||
{
|
||||
struct msg_rpc_request request;
|
||||
struct msg_rpc_send_request request;
|
||||
struct msg_base *reply;
|
||||
|
||||
request.type = MESSAGE_TYPE_RPC_REQUEST;
|
||||
request.rpc_num = rpc_num;
|
||||
va_start(request.args, rpc_num);
|
||||
request.type = MESSAGE_TYPE_RPC_SEND_REQUEST;
|
||||
request.service = service;
|
||||
request.tag = tag;
|
||||
va_start(request.args, tag);
|
||||
mailbox_send_and_wait(&request);
|
||||
va_end(request.args);
|
||||
|
||||
reply = mailbox_wait_and_receive();
|
||||
if(reply->type == MESSAGE_TYPE_RPC_REPLY) {
|
||||
int result = ((struct msg_rpc_reply *)reply)->result;
|
||||
mailbox_acknowledge();
|
||||
return result;
|
||||
} else if(reply->type == MESSAGE_TYPE_RPC_EXCEPTION) {
|
||||
struct artiq_exception exception;
|
||||
memcpy(&exception, ((struct msg_rpc_exception *)reply)->exception,
|
||||
sizeof(struct artiq_exception));
|
||||
mailbox_acknowledge();
|
||||
__artiq_raise(&exception);
|
||||
} else {
|
||||
// if(reply->type == MESSAGE_TYPE_RPC_REPLY) {
|
||||
// int result = ((struct msg_rpc_reply *)reply)->result;
|
||||
// mailbox_acknowledge();
|
||||
// return result;
|
||||
// } else if(reply->type == MESSAGE_TYPE_RPC_EXCEPTION) {
|
||||
// struct artiq_exception exception;
|
||||
// memcpy(&exception, ((struct msg_rpc_exception *)reply)->exception,
|
||||
// sizeof(struct artiq_exception));
|
||||
// mailbox_acknowledge();
|
||||
// __artiq_raise(&exception);
|
||||
// } else {
|
||||
log("Malformed MESSAGE_TYPE_RPC_REQUEST reply type %d",
|
||||
reply->type);
|
||||
while(1);
|
||||
}
|
||||
// }
|
||||
}
|
||||
|
||||
void lognonl(const char *fmt, ...)
|
||||
|
@ -5,7 +5,7 @@ long long int now_init(void);
|
||||
void now_save(long long int now);
|
||||
int watchdog_set(int ms);
|
||||
void watchdog_clear(int id);
|
||||
int rpc(int service, ...);
|
||||
int rpc(int service, const char *tag, ...);
|
||||
void lognonl(const char *fmt, ...);
|
||||
void log(const char *fmt, ...);
|
||||
|
||||
|
@ -14,8 +14,9 @@ enum {
|
||||
MESSAGE_TYPE_WATCHDOG_SET_REQUEST,
|
||||
MESSAGE_TYPE_WATCHDOG_SET_REPLY,
|
||||
MESSAGE_TYPE_WATCHDOG_CLEAR,
|
||||
MESSAGE_TYPE_RPC_REQUEST,
|
||||
MESSAGE_TYPE_RPC_REPLY,
|
||||
MESSAGE_TYPE_RPC_SEND_REQUEST,
|
||||
MESSAGE_TYPE_RPC_RECV_REQUEST,
|
||||
MESSAGE_TYPE_RPC_RECV_REPLY,
|
||||
MESSAGE_TYPE_RPC_EXCEPTION,
|
||||
MESSAGE_TYPE_LOG,
|
||||
|
||||
@ -80,15 +81,21 @@ struct msg_watchdog_clear {
|
||||
int id;
|
||||
};
|
||||
|
||||
struct msg_rpc_request {
|
||||
struct msg_rpc_send_request {
|
||||
int type;
|
||||
int rpc_num;
|
||||
int service;
|
||||
const char *tag;
|
||||
va_list args;
|
||||
};
|
||||
|
||||
struct msg_rpc_reply {
|
||||
struct msg_rpc_recv_request {
|
||||
int type;
|
||||
int result;
|
||||
// TODO ???
|
||||
};
|
||||
|
||||
struct msg_rpc_recv_reply {
|
||||
int type;
|
||||
// TODO ???
|
||||
};
|
||||
|
||||
struct msg_rpc_exception {
|
||||
|
@ -457,23 +457,23 @@ static int process_input(void)
|
||||
user_kernel_state = USER_KERNEL_RUNNING;
|
||||
break;
|
||||
|
||||
case REMOTEMSG_TYPE_RPC_REPLY: {
|
||||
struct msg_rpc_reply reply;
|
||||
// case REMOTEMSG_TYPE_RPC_REPLY: {
|
||||
// struct msg_rpc_reply reply;
|
||||
|
||||
int result = in_packet_int32();
|
||||
// int result = in_packet_int32();
|
||||
|
||||
if(user_kernel_state != USER_KERNEL_WAIT_RPC) {
|
||||
log("Unsolicited RPC reply");
|
||||
return 0; // restart session
|
||||
}
|
||||
// if(user_kernel_state != USER_KERNEL_WAIT_RPC) {
|
||||
// log("Unsolicited RPC reply");
|
||||
// return 0; // restart session
|
||||
// }
|
||||
|
||||
reply.type = MESSAGE_TYPE_RPC_REPLY;
|
||||
reply.result = result;
|
||||
mailbox_send_and_wait(&reply);
|
||||
// reply.type = MESSAGE_TYPE_RPC_REPLY;
|
||||
// reply.result = result;
|
||||
// mailbox_send_and_wait(&reply);
|
||||
|
||||
user_kernel_state = USER_KERNEL_RUNNING;
|
||||
break;
|
||||
}
|
||||
// user_kernel_state = USER_KERNEL_RUNNING;
|
||||
// break;
|
||||
// }
|
||||
|
||||
case REMOTEMSG_TYPE_RPC_EXCEPTION: {
|
||||
struct msg_rpc_exception reply;
|
||||
@ -509,91 +509,156 @@ static int process_input(void)
|
||||
return 1;
|
||||
}
|
||||
|
||||
static int send_rpc_value(const char **tag, void *value)
|
||||
// See llvm_ir_generator.py:_rpc_tag.
|
||||
static int send_rpc_value(const char **tag, void **value)
|
||||
{
|
||||
if(!out_packet_int8(**tag))
|
||||
return -1;
|
||||
return 0;
|
||||
|
||||
switch(*(*tag)++) {
|
||||
case 't': { // tuple
|
||||
int size = *(*tag)++;
|
||||
if(!out_packet_int8(size))
|
||||
return 0;
|
||||
|
||||
for(int i = 0; i < size; i++) {
|
||||
if(!send_rpc_value(tag, value))
|
||||
return 0;
|
||||
}
|
||||
break;
|
||||
}
|
||||
|
||||
int size = 0;
|
||||
switch(**tag) {
|
||||
case 0: // last tag
|
||||
case 'n': // None
|
||||
break;
|
||||
|
||||
case 'b': // bool
|
||||
size = 1;
|
||||
if(!out_packet_chunk(value, size))
|
||||
return -1;
|
||||
case 'b': { // bool
|
||||
int size = sizeof(int8_t);
|
||||
if(!out_packet_chunk(*value, size))
|
||||
return 0;
|
||||
*value = (void*)((intptr_t)(*value) + size);
|
||||
break;
|
||||
}
|
||||
|
||||
case 'i': // int(width=32)
|
||||
size = 4;
|
||||
if(!out_packet_chunk(value, size))
|
||||
return -1;
|
||||
case 'i': { // int(width=32)
|
||||
int size = sizeof(int32_t);
|
||||
if(!out_packet_chunk(*value, size))
|
||||
return 0;
|
||||
*value = (void*)((intptr_t)(*value) + size);
|
||||
break;
|
||||
}
|
||||
|
||||
case 'I': // int(width=64)
|
||||
case 'f': // float
|
||||
size = 8;
|
||||
if(!out_packet_chunk(value, size))
|
||||
return -1;
|
||||
case 'I': { // int(width=64)
|
||||
int size = sizeof(int64_t);
|
||||
if(!out_packet_chunk(*value, size))
|
||||
return 0;
|
||||
*value = (void*)((intptr_t)(*value) + size);
|
||||
break;
|
||||
}
|
||||
|
||||
case 'F': // Fraction
|
||||
size = 16;
|
||||
if(!out_packet_chunk(value, size))
|
||||
return -1;
|
||||
case 'f': { // float
|
||||
int size = sizeof(double);
|
||||
if(!out_packet_chunk(*value, size))
|
||||
return 0;
|
||||
*value = (void*)((intptr_t)(*value) + size);
|
||||
break;
|
||||
}
|
||||
|
||||
case 'F': { // Fraction
|
||||
int size = sizeof(int64_t) * 2;
|
||||
if(!out_packet_chunk(*value, size))
|
||||
return 0;
|
||||
*value = (void*)((intptr_t)(*value) + size);
|
||||
break;
|
||||
}
|
||||
|
||||
case 's': { // string
|
||||
const char **string = *value;
|
||||
if(!out_packet_string(*string))
|
||||
return 0;
|
||||
*value = (void*)((intptr_t)(*value) + strlen(*string) + 1);
|
||||
break;
|
||||
}
|
||||
|
||||
case 'l': { // list(elt='a)
|
||||
struct { uint32_t length; void *elements; } *list = value;
|
||||
struct { uint32_t length; struct {} *elements; } *list = *value;
|
||||
void *element = list->elements;
|
||||
|
||||
const char *tag_copy = *tag + 1;
|
||||
if(!out_packet_int32(list->length))
|
||||
return 0;
|
||||
|
||||
const char *tag_copy;
|
||||
for(int i = 0; i < list->length; i++) {
|
||||
int element_size = send_rpc_value(&tag_copy, element);
|
||||
if(element_size < 0)
|
||||
return -1;
|
||||
element = (void*)((intptr_t)element + element_size);
|
||||
tag_copy = *tag;
|
||||
if(!send_rpc_value(&tag_copy, &element))
|
||||
return 0;
|
||||
}
|
||||
*tag = tag_copy;
|
||||
|
||||
size = sizeof(list);
|
||||
*value = (void*)((intptr_t)(*value) + sizeof(*list));
|
||||
break;
|
||||
}
|
||||
|
||||
case 'o': { // host object
|
||||
struct { uint32_t id; } *object = value;
|
||||
|
||||
if(!out_packet_int32(object->id))
|
||||
return -1;
|
||||
|
||||
size = sizeof(object);
|
||||
case 'r': { // range(elt='a)
|
||||
const char *tag_copy;
|
||||
tag_copy = *tag;
|
||||
if(!send_rpc_value(&tag_copy, value)) // min
|
||||
return 0;
|
||||
tag_copy = *tag;
|
||||
if(!send_rpc_value(&tag_copy, value)) // max
|
||||
return 0;
|
||||
tag_copy = *tag;
|
||||
if(!send_rpc_value(&tag_copy, value)) // step
|
||||
return 0;
|
||||
*tag = tag_copy;
|
||||
break;
|
||||
}
|
||||
|
||||
case 'o': { // option(inner='a)
|
||||
struct { int8_t present; struct {} contents; } *option = *value;
|
||||
void *contents = &option->contents;
|
||||
|
||||
if(!out_packet_int8(option->present))
|
||||
return 0;
|
||||
|
||||
// option never appears in composite types, so we don't have
|
||||
// to accurately advance *value.
|
||||
if(option->present) {
|
||||
return send_rpc_value(tag, &contents);
|
||||
} else {
|
||||
(*tag)++;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
case 'O': { // host object
|
||||
struct { uint32_t id; } **object = *value;
|
||||
|
||||
if(!out_packet_int32((*object)->id))
|
||||
return 0;
|
||||
}
|
||||
|
||||
default:
|
||||
return -1;
|
||||
log("send_rpc_value: unknown tag %02x", *((*tag) - 1));
|
||||
return 0;
|
||||
}
|
||||
|
||||
(*tag)++;
|
||||
return size;
|
||||
return 1;
|
||||
}
|
||||
|
||||
static int send_rpc_request(int service, va_list args)
|
||||
static int send_rpc_request(int service, const char *tag, va_list args)
|
||||
{
|
||||
out_packet_start(REMOTEMSG_TYPE_RPC_REQUEST);
|
||||
out_packet_int32(service);
|
||||
|
||||
const char *tag = va_arg(args, const char*);
|
||||
while(*tag) {
|
||||
while(*tag != ':') {
|
||||
void *value = va_arg(args, void*);
|
||||
if(!kloader_validate_kpointer(value))
|
||||
return 0;
|
||||
if(send_rpc_value(&tag, &value) < 0)
|
||||
if(!send_rpc_value(&tag, &value))
|
||||
return 0;
|
||||
}
|
||||
|
||||
out_packet_int8(0);
|
||||
out_packet_finish();
|
||||
return 1;
|
||||
}
|
||||
@ -670,10 +735,10 @@ static int process_kmsg(struct msg_base *umsg)
|
||||
break;
|
||||
}
|
||||
|
||||
case MESSAGE_TYPE_RPC_REQUEST: {
|
||||
struct msg_rpc_request *msg = (struct msg_rpc_request *)umsg;
|
||||
case MESSAGE_TYPE_RPC_SEND_REQUEST: {
|
||||
struct msg_rpc_send_request *msg = (struct msg_rpc_send_request *)umsg;
|
||||
|
||||
if(!send_rpc_request(msg->rpc_num, msg->args)) {
|
||||
if(!send_rpc_request(msg->service, msg->tag, msg->args)) {
|
||||
log("Failed to send RPC request");
|
||||
return 0; // restart session
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user