forked from M-Labs/artiq
Allow embedding and RPC sending host objects.
This commit is contained in:
parent
526d7c4e46
commit
9b9fa1ab7c
@ -14,11 +14,30 @@ from . import types, builtins, asttyped, prelude
|
||||
from .transforms import ASTTypedRewriter, Inferencer, IntMonomorphizer
|
||||
|
||||
|
||||
class ObjectMap:
|
||||
def __init__(self):
|
||||
self.current_key = 0
|
||||
self.forward_map = {}
|
||||
self.reverse_map = {}
|
||||
|
||||
def store(self, obj_ref):
|
||||
obj_id = id(obj_ref)
|
||||
if obj_id in self.reverse_map:
|
||||
return self.reverse_map[obj_id]
|
||||
|
||||
self.current_key += 1
|
||||
self.forward_map[self.current_key] = obj_ref
|
||||
self.reverse_map[obj_id] = self.current_key
|
||||
return self.current_key
|
||||
|
||||
def retrieve(self, obj_key):
|
||||
return self.forward_map[obj_key]
|
||||
|
||||
class ASTSynthesizer:
|
||||
def __init__(self, expanded_from=None):
|
||||
def __init__(self, type_map, expanded_from=None):
|
||||
self.source = ""
|
||||
self.source_buffer = source.Buffer(self.source, "<synthesized>")
|
||||
self.expanded_from = expanded_from
|
||||
self.type_map, self.expanded_from = type_map, expanded_from
|
||||
|
||||
def finalize(self):
|
||||
self.source_buffer.source = self.source
|
||||
@ -63,8 +82,32 @@ class ASTSynthesizer:
|
||||
begin_loc=begin_loc, end_loc=end_loc,
|
||||
loc=begin_loc.join(end_loc))
|
||||
else:
|
||||
raise "no"
|
||||
# return asttyped.QuoteT(value=value, type=types.TVar())
|
||||
if isinstance(value, type):
|
||||
typ = value
|
||||
else:
|
||||
typ = type(value)
|
||||
|
||||
if typ in self.type_map:
|
||||
instance_type, constructor_type = self.type_map[typ]
|
||||
else:
|
||||
instance_type = types.TInstance("{}.{}".format(typ.__module__, typ.__name__))
|
||||
instance_type.attributes['__objectid__'] = builtins.TInt(types.TValue(32))
|
||||
|
||||
constructor_type = types.TConstructor(instance_type)
|
||||
constructor_type.attributes['__objectid__'] = builtins.TInt(types.TValue(32))
|
||||
|
||||
self.type_map[typ] = instance_type, constructor_type
|
||||
|
||||
quote_loc = self._add('`')
|
||||
repr_loc = self._add(repr(value))
|
||||
unquote_loc = self._add('`')
|
||||
|
||||
if isinstance(value, type):
|
||||
return asttyped.QuoteT(value=value, type=constructor_type,
|
||||
loc=quote_loc.join(unquote_loc))
|
||||
else:
|
||||
return asttyped.QuoteT(value=value, type=instance_type,
|
||||
loc=quote_loc.join(unquote_loc))
|
||||
|
||||
def call(self, function_node, args, kwargs):
|
||||
"""
|
||||
@ -108,13 +151,14 @@ class ASTSynthesizer:
|
||||
loc=name_loc.join(end_loc))
|
||||
|
||||
class StitchingASTTypedRewriter(ASTTypedRewriter):
|
||||
def __init__(self, engine, prelude, globals, host_environment, quote_function):
|
||||
def __init__(self, engine, prelude, globals, host_environment, quote_function, type_map):
|
||||
super().__init__(engine, prelude)
|
||||
self.globals = globals
|
||||
self.env_stack.append(self.globals)
|
||||
|
||||
self.host_environment = host_environment
|
||||
self.quote_function = quote_function
|
||||
self.type_map = type_map
|
||||
|
||||
def visit_Name(self, node):
|
||||
typ = super()._try_find_name(node.id)
|
||||
@ -136,7 +180,7 @@ class StitchingASTTypedRewriter(ASTTypedRewriter):
|
||||
|
||||
else:
|
||||
# It's just a value. Quote it.
|
||||
synthesizer = ASTSynthesizer(expanded_from=node.loc)
|
||||
synthesizer = ASTSynthesizer(expanded_from=node.loc, type_map=self.type_map)
|
||||
node = synthesizer.quote(value)
|
||||
synthesizer.finalize()
|
||||
return node
|
||||
@ -160,19 +204,8 @@ class Stitcher:
|
||||
|
||||
self.functions = {}
|
||||
|
||||
self.next_rpc = 0
|
||||
self.rpc_map = {}
|
||||
self.inverse_rpc_map = {}
|
||||
|
||||
def _map(self, obj):
|
||||
obj_id = id(obj)
|
||||
if obj_id in self.inverse_rpc_map:
|
||||
return self.inverse_rpc_map[obj_id]
|
||||
|
||||
self.next_rpc += 1
|
||||
self.rpc_map[self.next_rpc] = obj
|
||||
self.inverse_rpc_map[obj_id] = self.next_rpc
|
||||
return self.next_rpc
|
||||
self.object_map = ObjectMap()
|
||||
self.type_map = {}
|
||||
|
||||
def finalize(self):
|
||||
inferencer = Inferencer(engine=self.engine)
|
||||
@ -229,7 +262,7 @@ class Stitcher:
|
||||
asttyped_rewriter = StitchingASTTypedRewriter(
|
||||
engine=self.engine, prelude=self.prelude,
|
||||
globals=self.globals, host_environment=host_environment,
|
||||
quote_function=self._quote_function)
|
||||
quote_function=self._quote_function, type_map=self.type_map)
|
||||
return asttyped_rewriter.visit(function_node)
|
||||
|
||||
def _function_loc(self, function):
|
||||
@ -291,7 +324,7 @@ class Stitcher:
|
||||
# This is tricky, because the default value might not have
|
||||
# a well-defined type in APython.
|
||||
# In this case, we bail out, but mention why we do it.
|
||||
synthesizer = ASTSynthesizer()
|
||||
synthesizer = ASTSynthesizer(type_map=self.type_map)
|
||||
ast = synthesizer.quote(param.default)
|
||||
synthesizer.finalize()
|
||||
|
||||
@ -365,7 +398,7 @@ class Stitcher:
|
||||
|
||||
if syscall is None:
|
||||
function_type = types.TRPCFunction(arg_types, optarg_types, ret_type,
|
||||
service=self._map(function))
|
||||
service=self.object_map.store(function))
|
||||
function_name = "rpc${}".format(function_type.service)
|
||||
else:
|
||||
function_type = types.TCFunction(arg_types, ret_type,
|
||||
@ -409,7 +442,7 @@ class Stitcher:
|
||||
|
||||
# We synthesize source code for the initial call so that
|
||||
# diagnostics would have something meaningful to display to the user.
|
||||
synthesizer = ASTSynthesizer()
|
||||
synthesizer = ASTSynthesizer(type_map=self.type_map)
|
||||
call_node = synthesizer.call(function_node, args, kwargs)
|
||||
synthesizer.finalize()
|
||||
self.typedtree.append(call_node)
|
||||
|
@ -901,6 +901,23 @@ class Select(Instruction):
|
||||
def if_false(self):
|
||||
return self.operands[2]
|
||||
|
||||
class Quote(Instruction):
|
||||
"""
|
||||
A quote operation. Returns a host interpreter value as a constant.
|
||||
|
||||
:ivar value: (string) operation name
|
||||
"""
|
||||
|
||||
"""
|
||||
:param value: (string) operation name
|
||||
"""
|
||||
def __init__(self, value, typ, name=""):
|
||||
super().__init__([], typ, name)
|
||||
self.value = value
|
||||
|
||||
def opcode(self):
|
||||
return "quote({})".format(repr(self.value))
|
||||
|
||||
class Branch(Terminator):
|
||||
"""
|
||||
An unconditional branch instruction.
|
||||
|
@ -19,6 +19,8 @@ class Source:
|
||||
else:
|
||||
self.engine = engine
|
||||
|
||||
self.object_map = None
|
||||
|
||||
self.name, _ = os.path.splitext(os.path.basename(source_buffer.name))
|
||||
|
||||
asttyped_rewriter = transforms.ASTTypedRewriter(engine=engine,
|
||||
@ -42,6 +44,7 @@ class Source:
|
||||
class Module:
|
||||
def __init__(self, src):
|
||||
self.engine = src.engine
|
||||
self.object_map = src.object_map
|
||||
|
||||
int_monomorphizer = transforms.IntMonomorphizer(engine=self.engine)
|
||||
inferencer = transforms.Inferencer(engine=self.engine)
|
||||
@ -65,7 +68,8 @@ class Module:
|
||||
def build_llvm_ir(self, target):
|
||||
"""Compile the module to LLVM IR for the specified target."""
|
||||
llvm_ir_generator = transforms.LLVMIRGenerator(engine=self.engine,
|
||||
module_name=self.name, target=target)
|
||||
module_name=self.name, target=target,
|
||||
object_map=self.object_map)
|
||||
return llvm_ir_generator.process(self.artiq_ir)
|
||||
|
||||
def entry_point(self):
|
||||
|
@ -1474,6 +1474,9 @@ class ARTIQIRGenerator(algorithm.Visitor):
|
||||
self.current_block = after_invoke
|
||||
return invoke
|
||||
|
||||
def visit_QuoteT(self, node):
|
||||
return self.append(ir.Quote(node.value, node.type))
|
||||
|
||||
def instrument_assert(self, node, value):
|
||||
if self.current_assert_env is not None:
|
||||
if isinstance(value, ir.Constant):
|
||||
|
@ -159,15 +159,17 @@ class DebugInfoEmitter:
|
||||
|
||||
|
||||
class LLVMIRGenerator:
|
||||
def __init__(self, engine, module_name, target):
|
||||
def __init__(self, engine, module_name, target, object_map):
|
||||
self.engine = engine
|
||||
self.target = target
|
||||
self.object_map = object_map
|
||||
self.llcontext = target.llcontext
|
||||
self.llmodule = ll.Module(context=self.llcontext, name=module_name)
|
||||
self.llmodule.triple = target.triple
|
||||
self.llmodule.data_layout = target.data_layout
|
||||
self.llfunction = None
|
||||
self.llmap = {}
|
||||
self.llobject_map = {}
|
||||
self.phis = []
|
||||
self.debug_info_emitter = DebugInfoEmitter(self.llmodule)
|
||||
|
||||
@ -815,6 +817,8 @@ class LLVMIRGenerator:
|
||||
elif ir.is_option(typ):
|
||||
return b"o" + self._rpc_tag(typ.params["inner"],
|
||||
error_handler)
|
||||
elif '__objectid__' in typ.attributes:
|
||||
return b"O"
|
||||
else:
|
||||
error_handler(typ)
|
||||
|
||||
@ -960,6 +964,54 @@ class LLVMIRGenerator:
|
||||
return self.llbuilder.invoke(llfun, llargs, llnormalblock, llunwindblock,
|
||||
name=insn.name)
|
||||
|
||||
def _quote(self, value, typ, path):
|
||||
value_id = id(value)
|
||||
if value_id in self.llobject_map:
|
||||
return self.llobject_map[value_id]
|
||||
|
||||
global_name = ""
|
||||
llty = self.llty_of_type(typ)
|
||||
|
||||
if types.is_constructor(typ) or types.is_instance(typ):
|
||||
llfields = []
|
||||
for attr in typ.attributes:
|
||||
if attr == "__objectid__":
|
||||
objectid = self.object_map.store(value)
|
||||
llfields.append(ll.Constant(lli32, objectid))
|
||||
global_name = "object.{}".format(objectid)
|
||||
else:
|
||||
llfields.append(self._quote(getattr(value, attr), typ.attributes[attr],
|
||||
path + [attr]))
|
||||
|
||||
llvalue = ll.Constant.literal_struct(llfields)
|
||||
elif builtins.is_none(typ):
|
||||
assert value is None
|
||||
return self.llconst_of_const(value)
|
||||
elif builtins.is_bool(typ):
|
||||
assert value in (True, False)
|
||||
return self.llconst_of_const(value)
|
||||
elif builtins.is_int(typ):
|
||||
assert isinstance(value, int)
|
||||
return self.llconst_of_const(value)
|
||||
elif builtins.is_float(typ):
|
||||
assert isinstance(value, float)
|
||||
return self.llconst_of_const(value)
|
||||
elif builtins.is_str(typ):
|
||||
assert isinstance(value, (str, bytes))
|
||||
return self.llconst_of_const(value)
|
||||
else:
|
||||
assert False
|
||||
|
||||
llconst = ll.GlobalVariable(self.llmodule, llvalue.type, global_name)
|
||||
llconst.initializer = llvalue
|
||||
llconst.linkage = "private"
|
||||
self.llobject_map[value_id] = llconst
|
||||
return llconst
|
||||
|
||||
def process_Quote(self, insn):
|
||||
assert self.object_map is not None
|
||||
return self._quote(insn.value, insn.type, lambda: [repr(insn.value)])
|
||||
|
||||
def process_Select(self, insn):
|
||||
return self.llbuilder.select(self.map(insn.condition()),
|
||||
self.map(insn.if_true()), self.map(insn.if_false()))
|
||||
|
@ -282,13 +282,13 @@ class CommGeneric:
|
||||
_rpc_sentinel = object()
|
||||
|
||||
# See session.c:{send,receive}_rpc_value and llvm_ir_generator.py:_rpc_tag.
|
||||
def _receive_rpc_value(self, rpc_map):
|
||||
def _receive_rpc_value(self, object_map):
|
||||
tag = chr(self._read_int8())
|
||||
if tag == "\x00":
|
||||
return self._rpc_sentinel
|
||||
elif tag == "t":
|
||||
length = self._read_int8()
|
||||
return tuple(self._receive_rpc_value(rpc_map) for _ in range(length))
|
||||
return tuple(self._receive_rpc_value(object_map) for _ in range(length))
|
||||
elif tag == "n":
|
||||
return None
|
||||
elif tag == "b":
|
||||
@ -307,25 +307,25 @@ class CommGeneric:
|
||||
return self._read_string()
|
||||
elif tag == "l":
|
||||
length = self._read_int32()
|
||||
return [self._receive_rpc_value(rpc_map) for _ in range(length)]
|
||||
return [self._receive_rpc_value(object_map) for _ in range(length)]
|
||||
elif tag == "r":
|
||||
start = self._receive_rpc_value(rpc_map)
|
||||
stop = self._receive_rpc_value(rpc_map)
|
||||
step = self._receive_rpc_value(rpc_map)
|
||||
start = self._receive_rpc_value(object_map)
|
||||
stop = self._receive_rpc_value(object_map)
|
||||
step = self._receive_rpc_value(object_map)
|
||||
return range(start, stop, step)
|
||||
elif tag == "o":
|
||||
present = self._read_int8()
|
||||
if present:
|
||||
return self._receive_rpc_value(rpc_map)
|
||||
return self._receive_rpc_value(object_map)
|
||||
elif tag == "O":
|
||||
return rpc_map[self._read_int32()]
|
||||
return object_map.retrieve(self._read_int32())
|
||||
else:
|
||||
raise IOError("Unknown RPC value tag: {}".format(repr(tag)))
|
||||
|
||||
def _receive_rpc_args(self, rpc_map):
|
||||
def _receive_rpc_args(self, object_map):
|
||||
args = []
|
||||
while True:
|
||||
value = self._receive_rpc_value(rpc_map)
|
||||
value = self._receive_rpc_value(object_map)
|
||||
if value is self._rpc_sentinel:
|
||||
return args
|
||||
args.append(value)
|
||||
@ -410,20 +410,20 @@ class CommGeneric:
|
||||
else:
|
||||
raise IOError("Unknown RPC value tag: {}".format(repr(tag)))
|
||||
|
||||
def _serve_rpc(self, rpc_map):
|
||||
def _serve_rpc(self, object_map):
|
||||
service = self._read_int32()
|
||||
args = self._receive_rpc_args(rpc_map)
|
||||
args = self._receive_rpc_args(object_map)
|
||||
return_tags = self._read_bytes()
|
||||
logger.debug("rpc service: %d %r -> %s", service, args, return_tags)
|
||||
|
||||
try:
|
||||
result = rpc_map[service](*args)
|
||||
result = object_map.retrieve(service)(*args)
|
||||
logger.debug("rpc service: %d %r == %r", service, args, result)
|
||||
|
||||
self._write_header(_H2DMsgType.RPC_REPLY)
|
||||
self._write_bytes(return_tags)
|
||||
self._send_rpc_value(bytearray(return_tags), result, result,
|
||||
rpc_map[service])
|
||||
object_map.retrieve(service))
|
||||
self._write_flush()
|
||||
except core_language.ARTIQException as exn:
|
||||
logger.debug("rpc service: %d %r ! %r", service, args, exn)
|
||||
@ -473,11 +473,11 @@ class CommGeneric:
|
||||
[(filename, line, column, function, None)]
|
||||
raise core_language.ARTIQException(name, message, params, traceback)
|
||||
|
||||
def serve(self, rpc_map, symbolizer):
|
||||
def serve(self, object_map, symbolizer):
|
||||
while True:
|
||||
self._read_header()
|
||||
if self._read_type == _D2HMsgType.RPC_REQUEST:
|
||||
self._serve_rpc(rpc_map)
|
||||
self._serve_rpc(object_map)
|
||||
elif self._read_type == _D2HMsgType.KERNEL_EXCEPTION:
|
||||
self._serve_exception(symbolizer)
|
||||
else:
|
||||
|
@ -45,14 +45,14 @@ class Core:
|
||||
library = target.compile_and_link([module])
|
||||
stripped_library = target.strip(library)
|
||||
|
||||
return stitcher.rpc_map, stripped_library, \
|
||||
return stitcher.object_map, stripped_library, \
|
||||
lambda addresses: target.symbolize(library, addresses)
|
||||
except diagnostic.Error as error:
|
||||
print("\n".join(error.diagnostic.render(colored=True)), file=sys.stderr)
|
||||
raise CompileError() from error
|
||||
|
||||
def run(self, function, args, kwargs):
|
||||
rpc_map, kernel_library, symbolizer = self.compile(function, args, kwargs)
|
||||
object_map, kernel_library, symbolizer = self.compile(function, args, kwargs)
|
||||
|
||||
if self.first_run:
|
||||
self.comm.check_ident()
|
||||
@ -61,7 +61,7 @@ class Core:
|
||||
|
||||
self.comm.load(kernel_library)
|
||||
self.comm.run()
|
||||
self.comm.serve(rpc_map, symbolizer)
|
||||
self.comm.serve(object_map, symbolizer)
|
||||
|
||||
@kernel
|
||||
def get_rtio_counter_mu(self):
|
||||
|
@ -816,9 +816,7 @@ static int send_rpc_value(const char **tag, void **value)
|
||||
|
||||
case 'O': { // host object
|
||||
struct { uint32_t id; } **object = *value;
|
||||
|
||||
if(!out_packet_int32((*object)->id))
|
||||
return 0;
|
||||
return out_packet_int32((*object)->id);
|
||||
}
|
||||
|
||||
default:
|
||||
|
Loading…
Reference in New Issue
Block a user