From 9b9fa1ab7c6bce174e4b1f521a9c5248e0402e5d Mon Sep 17 00:00:00 2001 From: whitequark Date: Tue, 25 Aug 2015 21:56:01 -0700 Subject: [PATCH] Allow embedding and RPC sending host objects. --- artiq/compiler/embedding.py | 79 +++++++++++++------ artiq/compiler/ir.py | 17 ++++ artiq/compiler/module.py | 6 +- .../compiler/transforms/artiq_ir_generator.py | 3 + .../compiler/transforms/llvm_ir_generator.py | 54 ++++++++++++- artiq/coredevice/comm_generic.py | 32 ++++---- artiq/coredevice/core.py | 6 +- soc/runtime/session.c | 4 +- 8 files changed, 154 insertions(+), 47 deletions(-) diff --git a/artiq/compiler/embedding.py b/artiq/compiler/embedding.py index afbce7820..5232c6293 100644 --- a/artiq/compiler/embedding.py +++ b/artiq/compiler/embedding.py @@ -14,11 +14,30 @@ from . import types, builtins, asttyped, prelude from .transforms import ASTTypedRewriter, Inferencer, IntMonomorphizer +class ObjectMap: + def __init__(self): + self.current_key = 0 + self.forward_map = {} + self.reverse_map = {} + + def store(self, obj_ref): + obj_id = id(obj_ref) + if obj_id in self.reverse_map: + return self.reverse_map[obj_id] + + self.current_key += 1 + self.forward_map[self.current_key] = obj_ref + self.reverse_map[obj_id] = self.current_key + return self.current_key + + def retrieve(self, obj_key): + return self.forward_map[obj_key] + class ASTSynthesizer: - def __init__(self, expanded_from=None): + def __init__(self, type_map, expanded_from=None): self.source = "" self.source_buffer = source.Buffer(self.source, "") - self.expanded_from = expanded_from + self.type_map, self.expanded_from = type_map, expanded_from def finalize(self): self.source_buffer.source = self.source @@ -63,8 +82,32 @@ class ASTSynthesizer: begin_loc=begin_loc, end_loc=end_loc, loc=begin_loc.join(end_loc)) else: - raise "no" - # return asttyped.QuoteT(value=value, type=types.TVar()) + if isinstance(value, type): + typ = value + else: + typ = type(value) + + if typ in self.type_map: + instance_type, constructor_type = self.type_map[typ] + else: + instance_type = types.TInstance("{}.{}".format(typ.__module__, typ.__name__)) + instance_type.attributes['__objectid__'] = builtins.TInt(types.TValue(32)) + + constructor_type = types.TConstructor(instance_type) + constructor_type.attributes['__objectid__'] = builtins.TInt(types.TValue(32)) + + self.type_map[typ] = instance_type, constructor_type + + quote_loc = self._add('`') + repr_loc = self._add(repr(value)) + unquote_loc = self._add('`') + + if isinstance(value, type): + return asttyped.QuoteT(value=value, type=constructor_type, + loc=quote_loc.join(unquote_loc)) + else: + return asttyped.QuoteT(value=value, type=instance_type, + loc=quote_loc.join(unquote_loc)) def call(self, function_node, args, kwargs): """ @@ -108,13 +151,14 @@ class ASTSynthesizer: loc=name_loc.join(end_loc)) class StitchingASTTypedRewriter(ASTTypedRewriter): - def __init__(self, engine, prelude, globals, host_environment, quote_function): + def __init__(self, engine, prelude, globals, host_environment, quote_function, type_map): super().__init__(engine, prelude) self.globals = globals self.env_stack.append(self.globals) self.host_environment = host_environment self.quote_function = quote_function + self.type_map = type_map def visit_Name(self, node): typ = super()._try_find_name(node.id) @@ -136,7 +180,7 @@ class StitchingASTTypedRewriter(ASTTypedRewriter): else: # It's just a value. Quote it. - synthesizer = ASTSynthesizer(expanded_from=node.loc) + synthesizer = ASTSynthesizer(expanded_from=node.loc, type_map=self.type_map) node = synthesizer.quote(value) synthesizer.finalize() return node @@ -160,19 +204,8 @@ class Stitcher: self.functions = {} - self.next_rpc = 0 - self.rpc_map = {} - self.inverse_rpc_map = {} - - def _map(self, obj): - obj_id = id(obj) - if obj_id in self.inverse_rpc_map: - return self.inverse_rpc_map[obj_id] - - self.next_rpc += 1 - self.rpc_map[self.next_rpc] = obj - self.inverse_rpc_map[obj_id] = self.next_rpc - return self.next_rpc + self.object_map = ObjectMap() + self.type_map = {} def finalize(self): inferencer = Inferencer(engine=self.engine) @@ -229,7 +262,7 @@ class Stitcher: asttyped_rewriter = StitchingASTTypedRewriter( engine=self.engine, prelude=self.prelude, globals=self.globals, host_environment=host_environment, - quote_function=self._quote_function) + quote_function=self._quote_function, type_map=self.type_map) return asttyped_rewriter.visit(function_node) def _function_loc(self, function): @@ -291,7 +324,7 @@ class Stitcher: # This is tricky, because the default value might not have # a well-defined type in APython. # In this case, we bail out, but mention why we do it. - synthesizer = ASTSynthesizer() + synthesizer = ASTSynthesizer(type_map=self.type_map) ast = synthesizer.quote(param.default) synthesizer.finalize() @@ -365,7 +398,7 @@ class Stitcher: if syscall is None: function_type = types.TRPCFunction(arg_types, optarg_types, ret_type, - service=self._map(function)) + service=self.object_map.store(function)) function_name = "rpc${}".format(function_type.service) else: function_type = types.TCFunction(arg_types, ret_type, @@ -409,7 +442,7 @@ class Stitcher: # We synthesize source code for the initial call so that # diagnostics would have something meaningful to display to the user. - synthesizer = ASTSynthesizer() + synthesizer = ASTSynthesizer(type_map=self.type_map) call_node = synthesizer.call(function_node, args, kwargs) synthesizer.finalize() self.typedtree.append(call_node) diff --git a/artiq/compiler/ir.py b/artiq/compiler/ir.py index 1861d110e..f20357ed8 100644 --- a/artiq/compiler/ir.py +++ b/artiq/compiler/ir.py @@ -901,6 +901,23 @@ class Select(Instruction): def if_false(self): return self.operands[2] +class Quote(Instruction): + """ + A quote operation. Returns a host interpreter value as a constant. + + :ivar value: (string) operation name + """ + + """ + :param value: (string) operation name + """ + def __init__(self, value, typ, name=""): + super().__init__([], typ, name) + self.value = value + + def opcode(self): + return "quote({})".format(repr(self.value)) + class Branch(Terminator): """ An unconditional branch instruction. diff --git a/artiq/compiler/module.py b/artiq/compiler/module.py index 365760bb7..dbcaf2cd7 100644 --- a/artiq/compiler/module.py +++ b/artiq/compiler/module.py @@ -19,6 +19,8 @@ class Source: else: self.engine = engine + self.object_map = None + self.name, _ = os.path.splitext(os.path.basename(source_buffer.name)) asttyped_rewriter = transforms.ASTTypedRewriter(engine=engine, @@ -42,6 +44,7 @@ class Source: class Module: def __init__(self, src): self.engine = src.engine + self.object_map = src.object_map int_monomorphizer = transforms.IntMonomorphizer(engine=self.engine) inferencer = transforms.Inferencer(engine=self.engine) @@ -65,7 +68,8 @@ class Module: def build_llvm_ir(self, target): """Compile the module to LLVM IR for the specified target.""" llvm_ir_generator = transforms.LLVMIRGenerator(engine=self.engine, - module_name=self.name, target=target) + module_name=self.name, target=target, + object_map=self.object_map) return llvm_ir_generator.process(self.artiq_ir) def entry_point(self): diff --git a/artiq/compiler/transforms/artiq_ir_generator.py b/artiq/compiler/transforms/artiq_ir_generator.py index fd1be222f..6f6b626f3 100644 --- a/artiq/compiler/transforms/artiq_ir_generator.py +++ b/artiq/compiler/transforms/artiq_ir_generator.py @@ -1474,6 +1474,9 @@ class ARTIQIRGenerator(algorithm.Visitor): self.current_block = after_invoke return invoke + def visit_QuoteT(self, node): + return self.append(ir.Quote(node.value, node.type)) + def instrument_assert(self, node, value): if self.current_assert_env is not None: if isinstance(value, ir.Constant): diff --git a/artiq/compiler/transforms/llvm_ir_generator.py b/artiq/compiler/transforms/llvm_ir_generator.py index 80aabdc09..4e5af87a6 100644 --- a/artiq/compiler/transforms/llvm_ir_generator.py +++ b/artiq/compiler/transforms/llvm_ir_generator.py @@ -159,15 +159,17 @@ class DebugInfoEmitter: class LLVMIRGenerator: - def __init__(self, engine, module_name, target): + def __init__(self, engine, module_name, target, object_map): self.engine = engine self.target = target + self.object_map = object_map self.llcontext = target.llcontext self.llmodule = ll.Module(context=self.llcontext, name=module_name) self.llmodule.triple = target.triple self.llmodule.data_layout = target.data_layout self.llfunction = None self.llmap = {} + self.llobject_map = {} self.phis = [] self.debug_info_emitter = DebugInfoEmitter(self.llmodule) @@ -815,6 +817,8 @@ class LLVMIRGenerator: elif ir.is_option(typ): return b"o" + self._rpc_tag(typ.params["inner"], error_handler) + elif '__objectid__' in typ.attributes: + return b"O" else: error_handler(typ) @@ -960,6 +964,54 @@ class LLVMIRGenerator: return self.llbuilder.invoke(llfun, llargs, llnormalblock, llunwindblock, name=insn.name) + def _quote(self, value, typ, path): + value_id = id(value) + if value_id in self.llobject_map: + return self.llobject_map[value_id] + + global_name = "" + llty = self.llty_of_type(typ) + + if types.is_constructor(typ) or types.is_instance(typ): + llfields = [] + for attr in typ.attributes: + if attr == "__objectid__": + objectid = self.object_map.store(value) + llfields.append(ll.Constant(lli32, objectid)) + global_name = "object.{}".format(objectid) + else: + llfields.append(self._quote(getattr(value, attr), typ.attributes[attr], + path + [attr])) + + llvalue = ll.Constant.literal_struct(llfields) + elif builtins.is_none(typ): + assert value is None + return self.llconst_of_const(value) + elif builtins.is_bool(typ): + assert value in (True, False) + return self.llconst_of_const(value) + elif builtins.is_int(typ): + assert isinstance(value, int) + return self.llconst_of_const(value) + elif builtins.is_float(typ): + assert isinstance(value, float) + return self.llconst_of_const(value) + elif builtins.is_str(typ): + assert isinstance(value, (str, bytes)) + return self.llconst_of_const(value) + else: + assert False + + llconst = ll.GlobalVariable(self.llmodule, llvalue.type, global_name) + llconst.initializer = llvalue + llconst.linkage = "private" + self.llobject_map[value_id] = llconst + return llconst + + def process_Quote(self, insn): + assert self.object_map is not None + return self._quote(insn.value, insn.type, lambda: [repr(insn.value)]) + def process_Select(self, insn): return self.llbuilder.select(self.map(insn.condition()), self.map(insn.if_true()), self.map(insn.if_false())) diff --git a/artiq/coredevice/comm_generic.py b/artiq/coredevice/comm_generic.py index 3f1a188d3..b1967b352 100644 --- a/artiq/coredevice/comm_generic.py +++ b/artiq/coredevice/comm_generic.py @@ -282,13 +282,13 @@ class CommGeneric: _rpc_sentinel = object() # See session.c:{send,receive}_rpc_value and llvm_ir_generator.py:_rpc_tag. - def _receive_rpc_value(self, rpc_map): + def _receive_rpc_value(self, object_map): tag = chr(self._read_int8()) if tag == "\x00": return self._rpc_sentinel elif tag == "t": length = self._read_int8() - return tuple(self._receive_rpc_value(rpc_map) for _ in range(length)) + return tuple(self._receive_rpc_value(object_map) for _ in range(length)) elif tag == "n": return None elif tag == "b": @@ -307,25 +307,25 @@ class CommGeneric: return self._read_string() elif tag == "l": length = self._read_int32() - return [self._receive_rpc_value(rpc_map) for _ in range(length)] + return [self._receive_rpc_value(object_map) for _ in range(length)] elif tag == "r": - start = self._receive_rpc_value(rpc_map) - stop = self._receive_rpc_value(rpc_map) - step = self._receive_rpc_value(rpc_map) + start = self._receive_rpc_value(object_map) + stop = self._receive_rpc_value(object_map) + step = self._receive_rpc_value(object_map) return range(start, stop, step) elif tag == "o": present = self._read_int8() if present: - return self._receive_rpc_value(rpc_map) + return self._receive_rpc_value(object_map) elif tag == "O": - return rpc_map[self._read_int32()] + return object_map.retrieve(self._read_int32()) else: raise IOError("Unknown RPC value tag: {}".format(repr(tag))) - def _receive_rpc_args(self, rpc_map): + def _receive_rpc_args(self, object_map): args = [] while True: - value = self._receive_rpc_value(rpc_map) + value = self._receive_rpc_value(object_map) if value is self._rpc_sentinel: return args args.append(value) @@ -410,20 +410,20 @@ class CommGeneric: else: raise IOError("Unknown RPC value tag: {}".format(repr(tag))) - def _serve_rpc(self, rpc_map): + def _serve_rpc(self, object_map): service = self._read_int32() - args = self._receive_rpc_args(rpc_map) + args = self._receive_rpc_args(object_map) return_tags = self._read_bytes() logger.debug("rpc service: %d %r -> %s", service, args, return_tags) try: - result = rpc_map[service](*args) + result = object_map.retrieve(service)(*args) logger.debug("rpc service: %d %r == %r", service, args, result) self._write_header(_H2DMsgType.RPC_REPLY) self._write_bytes(return_tags) self._send_rpc_value(bytearray(return_tags), result, result, - rpc_map[service]) + object_map.retrieve(service)) self._write_flush() except core_language.ARTIQException as exn: logger.debug("rpc service: %d %r ! %r", service, args, exn) @@ -473,11 +473,11 @@ class CommGeneric: [(filename, line, column, function, None)] raise core_language.ARTIQException(name, message, params, traceback) - def serve(self, rpc_map, symbolizer): + def serve(self, object_map, symbolizer): while True: self._read_header() if self._read_type == _D2HMsgType.RPC_REQUEST: - self._serve_rpc(rpc_map) + self._serve_rpc(object_map) elif self._read_type == _D2HMsgType.KERNEL_EXCEPTION: self._serve_exception(symbolizer) else: diff --git a/artiq/coredevice/core.py b/artiq/coredevice/core.py index 6246f6301..e02b1ca47 100644 --- a/artiq/coredevice/core.py +++ b/artiq/coredevice/core.py @@ -45,14 +45,14 @@ class Core: library = target.compile_and_link([module]) stripped_library = target.strip(library) - return stitcher.rpc_map, stripped_library, \ + return stitcher.object_map, stripped_library, \ lambda addresses: target.symbolize(library, addresses) except diagnostic.Error as error: print("\n".join(error.diagnostic.render(colored=True)), file=sys.stderr) raise CompileError() from error def run(self, function, args, kwargs): - rpc_map, kernel_library, symbolizer = self.compile(function, args, kwargs) + object_map, kernel_library, symbolizer = self.compile(function, args, kwargs) if self.first_run: self.comm.check_ident() @@ -61,7 +61,7 @@ class Core: self.comm.load(kernel_library) self.comm.run() - self.comm.serve(rpc_map, symbolizer) + self.comm.serve(object_map, symbolizer) @kernel def get_rtio_counter_mu(self): diff --git a/soc/runtime/session.c b/soc/runtime/session.c index 2ea0c0e8e..ff6629f4f 100644 --- a/soc/runtime/session.c +++ b/soc/runtime/session.c @@ -816,9 +816,7 @@ static int send_rpc_value(const char **tag, void **value) case 'O': { // host object struct { uint32_t id; } **object = *value; - - if(!out_packet_int32((*object)->id)) - return 0; + return out_packet_int32((*object)->id); } default: