From 640022122ba298f2f66f30eb1df5add966b9c7eb Mon Sep 17 00:00:00 2001 From: whitequark Date: Mon, 16 May 2016 14:30:21 +0000 Subject: [PATCH] embedding: refactor some more. --- artiq/compiler/embedding.py | 83 ++++++++++++------- artiq/compiler/module.py | 16 ++-- .../compiler/transforms/llvm_ir_generator.py | 37 ++++----- artiq/coredevice/comm_dummy.py | 2 +- artiq/coredevice/comm_generic.py | 38 ++++----- artiq/coredevice/core.py | 6 +- artiq/frontend/artiq_run.py | 22 ++--- 7 files changed, 107 insertions(+), 97 deletions(-) diff --git a/artiq/compiler/embedding.py b/artiq/compiler/embedding.py index 9f9259641..6727cf5d3 100644 --- a/artiq/compiler/embedding.py +++ b/artiq/compiler/embedding.py @@ -22,37 +22,61 @@ from .transforms.asttyped_rewriter import LocalExtractor def coredevice_print(x): print(x) -class ObjectMap: +class EmbeddingMap: def __init__(self): - self.current_key = 0 - self.forward_map = {} - self.reverse_map = {} + self.object_current_key = 0 + self.object_forward_map = {} + self.object_reverse_map = {} + self.type_map = {} + self.function_map = {} - def store(self, obj_ref): + # Types + def store_type(self, typ, instance_type, constructor_type): + self.type_map[typ] = (instance_type, constructor_type) + + def retrieve_type(self, typ): + return self.type_map[typ] + + def has_type(self, typ): + return typ in self.type_map + + def iter_types(self): + return self.type_map.values() + + # Functions + def store_function(self, function, ir_function_name): + self.function_map[function] = ir_function_name + + def retrieve_function(self, function): + return self.function_map[function] + + # Objects + def store_object(self, obj_ref): obj_id = id(obj_ref) - if obj_id in self.reverse_map: - return self.reverse_map[obj_id] + if obj_id in self.object_reverse_map: + return self.object_reverse_map[obj_id] - self.current_key += 1 - self.forward_map[self.current_key] = obj_ref - self.reverse_map[obj_id] = self.current_key - return self.current_key + self.object_current_key += 1 + self.object_forward_map[self.object_current_key] = obj_ref + self.object_reverse_map[obj_id] = self.object_current_key + return self.object_current_key - def retrieve(self, obj_key): - return self.forward_map[obj_key] + def retrieve_object(self, obj_key): + return self.object_forward_map[obj_key] + + def iter_objects(self): + return self.object_forward_map.keys() def has_rpc(self): return any(filter(lambda x: inspect.isfunction(x) or inspect.ismethod(x), - self.forward_map.values())) - - def __iter__(self): - return iter(self.forward_map.keys()) + self.object_forward_map.values())) class ASTSynthesizer: - def __init__(self, object_map, type_map, value_map, quote_function=None, expanded_from=None): + def __init__(self, embedding_map, value_map, quote_function=None, expanded_from=None): self.source = "" self.source_buffer = source.Buffer(self.source, "") - self.object_map, self.type_map, self.value_map = object_map, type_map, value_map + self.embedding_map = embedding_map + self.value_map = value_map self.quote_function = quote_function self.expanded_from = expanded_from self.diagnostics = [] @@ -134,8 +158,8 @@ class ASTSynthesizer: else: typ = type(value) - if typ in self.type_map: - instance_type, constructor_type = self.type_map[typ] + if self.embedding_map.has_type(typ): + instance_type, constructor_type = self.embedding_map.retrieve_type(typ) if hasattr(value, 'kernel_invariants') and \ value.kernel_invariants != instance_type.constant_attributes: @@ -170,7 +194,7 @@ class ASTSynthesizer: if hasattr(typ, 'artiq_builtin'): exception_id = 0 else: - exception_id = self.object_map.store(typ) + exception_id = self.embedding_map.store_object(typ) instance_type = builtins.TException("{}.{}".format(typ.__module__, typ.__qualname__), id=exception_id) @@ -183,7 +207,7 @@ class ASTSynthesizer: constructor_type.attributes['__objectid__'] = builtins.TInt32() instance_type.constructor = constructor_type - self.type_map[typ] = instance_type, constructor_type + self.embedding_map.store_type(typ, instance_type, constructor_type) if hasattr(value, 'kernel_invariants'): assert isinstance(value.kernel_invariants, set) @@ -531,9 +555,7 @@ class Stitcher: self.functions = {} - self.function_map = {} - self.object_map = ObjectMap() - self.type_map = {} + self.embedding_map = EmbeddingMap() self.value_map = defaultdict(lambda: []) def stitch_call(self, function, args, kwargs, callback=None): @@ -562,7 +584,7 @@ class Stitcher: # For every host class we embed, fill in the function slots # with their corresponding closures. - for instance_type, constructor_type in list(self.type_map.values()): + for instance_type, constructor_type in self.embedding_map.iter_types(): # Do we have any direct reference to a constructor? if len(self.value_map[constructor_type]) > 0: # Yes, use it. @@ -592,8 +614,7 @@ class Stitcher: def _synthesizer(self, expanded_from=None): return ASTSynthesizer(expanded_from=expanded_from, - object_map=self.object_map, - type_map=self.type_map, + embedding_map=self.embedding_map, value_map=self.value_map, quote_function=self._quote_function) @@ -635,7 +656,7 @@ class Stitcher: # Record the function in the function map so that LLVM IR generator # can handle quoting it. - self.function_map[function] = function_node.name + self.embedding_map.store_function(function, function_node.name) # Memoize the function type before typing it to handle recursive # invocations. @@ -803,7 +824,7 @@ class Stitcher: else: assert False - function_type = types.TRPC(ret_type, service=self.object_map.store(function)) + function_type = types.TRPC(ret_type, service=self.embedding_map.store_object(function)) self.functions[function] = function_type return function_type diff --git a/artiq/compiler/module.py b/artiq/compiler/module.py index b3c324fc2..355e3072d 100644 --- a/artiq/compiler/module.py +++ b/artiq/compiler/module.py @@ -18,11 +18,7 @@ class Source: self.engine = diagnostic.Engine(all_errors_are_fatal=True) else: self.engine = engine - - self.function_map = {} - self.object_map = None - self.type_map = {} - + self.embedding_map = None self.name, _ = os.path.splitext(os.path.basename(source_buffer.name)) asttyped_rewriter = transforms.ASTTypedRewriter(engine=engine, @@ -46,9 +42,9 @@ class Source: class Module: def __init__(self, src, ref_period=1e-6): self.engine = src.engine - self.function_map = src.function_map - self.object_map = src.object_map - self.type_map = src.type_map + self.embedding_map = src.embedding_map + self.name = src.name + self.globals = src.globals int_monomorphizer = transforms.IntMonomorphizer(engine=self.engine) inferencer = transforms.Inferencer(engine=self.engine) @@ -65,8 +61,6 @@ class Module: devirtualization = analyses.Devirtualization() interleaver = transforms.Interleaver(engine=self.engine) - self.name = src.name - self.globals = src.globals int_monomorphizer.visit(src.typedtree) inferencer.visit(src.typedtree) monomorphism_validator.visit(src.typedtree) @@ -84,7 +78,7 @@ class Module: """Compile the module to LLVM IR for the specified target.""" llvm_ir_generator = transforms.LLVMIRGenerator( engine=self.engine, module_name=self.name, target=target, - function_map=self.function_map, object_map=self.object_map, type_map=self.type_map) + embedding_map=self.embedding_map) return llvm_ir_generator.process(self.artiq_ir, attribute_writeback=True) def entry_point(self): diff --git a/artiq/compiler/transforms/llvm_ir_generator.py b/artiq/compiler/transforms/llvm_ir_generator.py index e12017fd2..a92378c55 100644 --- a/artiq/compiler/transforms/llvm_ir_generator.py +++ b/artiq/compiler/transforms/llvm_ir_generator.py @@ -122,12 +122,10 @@ class DebugInfoEmitter: class LLVMIRGenerator: - def __init__(self, engine, module_name, target, function_map, object_map, type_map): + def __init__(self, engine, module_name, target, embedding_map): self.engine = engine self.target = target - self.function_map = function_map - self.object_map = object_map - self.type_map = type_map + self.embedding_map = embedding_map self.llcontext = target.llcontext self.llmodule = ll.Module(context=self.llcontext, name=module_name) self.llmodule.triple = target.triple @@ -402,7 +400,7 @@ class LLVMIRGenerator: if any(functions): self.debug_info_emitter.finalize(functions[0].loc.source_buffer) - if attribute_writeback and self.object_map is not None: + if attribute_writeback and self.embedding_map is not None: self.emit_attribute_writeback() return self.llmodule @@ -410,15 +408,15 @@ class LLVMIRGenerator: def emit_attribute_writeback(self): llobjects = defaultdict(lambda: []) - for obj_id in self.object_map: - obj_ref = self.object_map.retrieve(obj_id) + for obj_id in self.embedding_map.iter_objects(): + obj_ref = self.embedding_map.retrieve_object(obj_id) if isinstance(obj_ref, (pytypes.FunctionType, pytypes.MethodType, pytypes.BuiltinFunctionType)): continue elif isinstance(obj_ref, type): - _, typ = self.type_map[obj_ref] + _, typ = self.embedding_map.retrieve_type(obj_ref) else: - typ, _ = self.type_map[type(obj_ref)] + typ, _ = self.embedding_map.retrieve_type(type(obj_ref)) llobject = self.llmodule.get_global("O.{}".format(obj_id)) if llobject is not None: @@ -1091,8 +1089,11 @@ class LLVMIRGenerator: llargs = [self.map(arg) for arg in insn.arguments()] llclosure = self.map(insn.target_function()) if insn.static_target_function is None: - llfun = self.llbuilder.extract_value(llclosure, 1, - name="fun.{}".format(llclosure.name)) + if isinstance(llclosure, ll.Constant): + name = "fun.{}".format(llclosure.constant[1].name) + else: + name = "fun.{}".format(llclosure.name) + llfun = self.llbuilder.extract_value(llclosure, 1, name=name) else: llfun = self.map(insn.static_target_function) llenv = self.llbuilder.extract_value(llclosure, 0, name="env.fun") @@ -1346,7 +1347,7 @@ class LLVMIRGenerator: llfields = [] for attr in typ.attributes: if attr == "__objectid__": - objectid = self.object_map.store(value) + objectid = self.embedding_map.store_object(value) llfields.append(ll.Constant(lli32, objectid)) assert llglobal is None @@ -1399,7 +1400,7 @@ class LLVMIRGenerator: # RPC and C functions have no runtime representation. return ll.Constant(llty, ll.Undefined) elif types.is_function(typ): - llfun = self.get_function(typ.find(), self.function_map[value]) + llfun = self.get_function(typ.find(), self.embedding_map.retrieve_function(value)) llclosure = ll.Constant(self.llty_of_type(typ), [ ll.Constant(llptr, ll.Undefined), llfun @@ -1416,14 +1417,8 @@ class LLVMIRGenerator: assert False def process_Quote(self, insn): - if insn.value in self.function_map and types.is_function(insn.type): - llfun = self.get_function(insn.type.find(), self.function_map[insn.value]) - llclosure = ll.Constant(self.llty_of_type(insn.type), ll.Undefined) - llclosure = self.llbuilder.insert_value(llclosure, llfun, 1, name=insn.name) - return llclosure - else: - assert self.object_map is not None - return self._quote(insn.value, insn.type, lambda: [repr(insn.value)]) + assert self.embedding_map is not None + return self._quote(insn.value, insn.type, lambda: [repr(insn.value)]) def process_Select(self, insn): return self.llbuilder.select(self.map(insn.condition()), diff --git a/artiq/coredevice/comm_dummy.py b/artiq/coredevice/comm_dummy.py index c8626bb17..0d85ba7d8 100644 --- a/artiq/coredevice/comm_dummy.py +++ b/artiq/coredevice/comm_dummy.py @@ -14,7 +14,7 @@ class Comm: def run(self): pass - def serve(self, object_map, symbolizer): + def serve(self, embedding_map, symbolizer, demangler): pass def check_ident(self): diff --git a/artiq/coredevice/comm_generic.py b/artiq/coredevice/comm_generic.py index 2687ef38e..5aa4db687 100644 --- a/artiq/coredevice/comm_generic.py +++ b/artiq/coredevice/comm_generic.py @@ -309,13 +309,13 @@ class CommGeneric: _rpc_sentinel = object() # See session.c:{send,receive}_rpc_value and llvm_ir_generator.py:_rpc_tag. - def _receive_rpc_value(self, object_map): + def _receive_rpc_value(self, embedding_map): tag = chr(self._read_int8()) if tag == "\x00": return self._rpc_sentinel elif tag == "t": length = self._read_int8() - return tuple(self._receive_rpc_value(object_map) for _ in range(length)) + return tuple(self._receive_rpc_value(embedding_map) for _ in range(length)) elif tag == "n": return None elif tag == "b": @@ -334,25 +334,25 @@ class CommGeneric: return self._read_string() elif tag == "l": length = self._read_int32() - return [self._receive_rpc_value(object_map) for _ in range(length)] + return [self._receive_rpc_value(embedding_map) for _ in range(length)] elif tag == "r": - start = self._receive_rpc_value(object_map) - stop = self._receive_rpc_value(object_map) - step = self._receive_rpc_value(object_map) + start = self._receive_rpc_value(embedding_map) + stop = self._receive_rpc_value(embedding_map) + step = self._receive_rpc_value(embedding_map) return range(start, stop, step) elif tag == "k": name = self._read_string() - value = self._receive_rpc_value(object_map) + value = self._receive_rpc_value(embedding_map) return RPCKeyword(name, value) elif tag == "O": - return object_map.retrieve(self._read_int32()) + return embedding_map.retrieve_object(self._read_int32()) else: raise IOError("Unknown RPC value tag: {}".format(repr(tag))) - def _receive_rpc_args(self, object_map): + def _receive_rpc_args(self, embedding_map): args, kwargs = [], {} while True: - value = self._receive_rpc_value(object_map) + value = self._receive_rpc_value(embedding_map) if value is self._rpc_sentinel: return args, kwargs elif isinstance(value, RPCKeyword): @@ -440,14 +440,14 @@ class CommGeneric: else: raise IOError("Unknown RPC value tag: {}".format(repr(tag))) - def _serve_rpc(self, object_map): + def _serve_rpc(self, embedding_map): service_id = self._read_int32() if service_id == 0: service = lambda obj, attr, value: setattr(obj, attr, value) else: - service = object_map.retrieve(service_id) + service = embedding_map.retrieve_object(service_id) - args, kwargs = self._receive_rpc_args(object_map) + args, kwargs = self._receive_rpc_args(embedding_map) return_tags = self._read_bytes() logger.debug("rpc service: [%d]%r %r %r -> %s", service_id, service, args, kwargs, return_tags) @@ -483,7 +483,7 @@ class CommGeneric: hasattr(exn, 'artiq_builtin'): self._write_string("0:{}".format(exn_type.__name__)) else: - exn_id = object_map.store(exn_type) + exn_id = embedding_map.store_object(exn_type) self._write_string("{}:{}.{}".format(exn_id, exn_type.__module__, exn_type.__qualname__)) self._write_string(str(exn)) @@ -504,7 +504,7 @@ class CommGeneric: self._write_flush() - def _serve_exception(self, object_map, symbolizer, demangler): + def _serve_exception(self, embedding_map, symbolizer, demangler): name = self._read_string() message = self._read_string() params = [self._read_int64() for _ in range(3)] @@ -523,19 +523,19 @@ class CommGeneric: if core_exn.id == 0: python_exn_type = getattr(exceptions, core_exn.name.split('.')[-1]) else: - python_exn_type = object_map.retrieve(core_exn.id) + python_exn_type = embedding_map.retrieve_object(core_exn.id) python_exn = python_exn_type(message.format(*params)) python_exn.artiq_core_exception = core_exn raise python_exn - def serve(self, object_map, symbolizer, demangler): + def serve(self, embedding_map, symbolizer, demangler): while True: self._read_header() if self._read_type == _D2HMsgType.RPC_REQUEST: - self._serve_rpc(object_map) + self._serve_rpc(embedding_map) elif self._read_type == _D2HMsgType.KERNEL_EXCEPTION: - self._serve_exception(object_map, symbolizer, demangler) + self._serve_exception(embedding_map, symbolizer, demangler) elif self._read_type == _D2HMsgType.WATCHDOG_EXPIRED: raise exceptions.WatchdogExpired elif self._read_type == _D2HMsgType.CLOCK_FAILURE: diff --git a/artiq/coredevice/core.py b/artiq/coredevice/core.py index 035dac8e8..2d692afd2 100644 --- a/artiq/coredevice/core.py +++ b/artiq/coredevice/core.py @@ -91,7 +91,7 @@ class Core: library = target.compile_and_link([module]) stripped_library = target.strip(library) - return stitcher.object_map, stripped_library, \ + return stitcher.embedding_map, stripped_library, \ lambda addresses: target.symbolize(library, addresses), \ lambda symbols: target.demangle(symbols) except diagnostic.Error as error: @@ -103,7 +103,7 @@ class Core: nonlocal result result = new_result - object_map, kernel_library, symbolizer, demangler = \ + embedding_map, kernel_library, symbolizer, demangler = \ self.compile(function, args, kwargs, set_result) if self.first_run: @@ -113,7 +113,7 @@ class Core: self.comm.load(kernel_library) self.comm.run() - self.comm.serve(object_map, symbolizer, demangler) + self.comm.serve(embedding_map, symbolizer, demangler) return result diff --git a/artiq/frontend/artiq_run.py b/artiq/frontend/artiq_run.py index 15450b907..ef3004ada 100755 --- a/artiq/frontend/artiq_run.py +++ b/artiq/frontend/artiq_run.py @@ -16,7 +16,7 @@ from artiq.language.environment import EnvExperiment, ProcessArgumentManager from artiq.master.databases import DeviceDB, DatasetDB from artiq.master.worker_db import DeviceManager, DatasetManager from artiq.coredevice.core import CompileError, host_only -from artiq.compiler.embedding import ObjectMap +from artiq.compiler.embedding import EmbeddingMap from artiq.compiler.targets import OR1KTarget from artiq.tools import * @@ -29,19 +29,19 @@ class StubObject: pass -class StubObjectMap: +class StubEmbeddingMap: def __init__(self): stub_object = StubObject() - self.forward_map = defaultdict(lambda: stub_object) - self.forward_map[1] = lambda _: None # return RPC - self.next_id = -1 + self.object_forward_map = defaultdict(lambda: stub_object) + self.object_forward_map[1] = lambda _: None # return RPC + self.object_current_id = -1 - def retrieve(self, object_id): - return self.forward_map[object_id] + def retrieve_object(self, object_id): + return self.object_forward_map[object_id] - def store(self, value): - self.forward_map[self.next_id] = value - self.next_id -= 1 + def store_object(self, value): + self.object_forward_map[self.object_current_id] = value + self.object_current_id -= 1 class FileRunner(EnvExperiment): @@ -55,7 +55,7 @@ class FileRunner(EnvExperiment): self.core.comm.load(kernel_library) self.core.comm.run() - self.core.comm.serve(StubObjectMap(), + self.core.comm.serve(StubEmbeddingMap(), lambda addresses: self.target.symbolize(kernel_library, addresses))