embedding: refactor some more.

This commit is contained in:
whitequark 2016-05-16 14:30:21 +00:00
parent d085d5a372
commit 640022122b
7 changed files with 107 additions and 97 deletions

View File

@ -22,37 +22,61 @@ from .transforms.asttyped_rewriter import LocalExtractor
def coredevice_print(x): print(x) def coredevice_print(x): print(x)
class ObjectMap: class EmbeddingMap:
def __init__(self): def __init__(self):
self.current_key = 0 self.object_current_key = 0
self.forward_map = {} self.object_forward_map = {}
self.reverse_map = {} self.object_reverse_map = {}
self.type_map = {}
self.function_map = {}
def store(self, obj_ref): # Types
def store_type(self, typ, instance_type, constructor_type):
self.type_map[typ] = (instance_type, constructor_type)
def retrieve_type(self, typ):
return self.type_map[typ]
def has_type(self, typ):
return typ in self.type_map
def iter_types(self):
return self.type_map.values()
# Functions
def store_function(self, function, ir_function_name):
self.function_map[function] = ir_function_name
def retrieve_function(self, function):
return self.function_map[function]
# Objects
def store_object(self, obj_ref):
obj_id = id(obj_ref) obj_id = id(obj_ref)
if obj_id in self.reverse_map: if obj_id in self.object_reverse_map:
return self.reverse_map[obj_id] return self.object_reverse_map[obj_id]
self.current_key += 1 self.object_current_key += 1
self.forward_map[self.current_key] = obj_ref self.object_forward_map[self.object_current_key] = obj_ref
self.reverse_map[obj_id] = self.current_key self.object_reverse_map[obj_id] = self.object_current_key
return self.current_key return self.object_current_key
def retrieve(self, obj_key): def retrieve_object(self, obj_key):
return self.forward_map[obj_key] return self.object_forward_map[obj_key]
def iter_objects(self):
return self.object_forward_map.keys()
def has_rpc(self): def has_rpc(self):
return any(filter(lambda x: inspect.isfunction(x) or inspect.ismethod(x), return any(filter(lambda x: inspect.isfunction(x) or inspect.ismethod(x),
self.forward_map.values())) self.object_forward_map.values()))
def __iter__(self):
return iter(self.forward_map.keys())
class ASTSynthesizer: class ASTSynthesizer:
def __init__(self, object_map, type_map, value_map, quote_function=None, expanded_from=None): def __init__(self, embedding_map, value_map, quote_function=None, expanded_from=None):
self.source = "" self.source = ""
self.source_buffer = source.Buffer(self.source, "<synthesized>") self.source_buffer = source.Buffer(self.source, "<synthesized>")
self.object_map, self.type_map, self.value_map = object_map, type_map, value_map self.embedding_map = embedding_map
self.value_map = value_map
self.quote_function = quote_function self.quote_function = quote_function
self.expanded_from = expanded_from self.expanded_from = expanded_from
self.diagnostics = [] self.diagnostics = []
@ -134,8 +158,8 @@ class ASTSynthesizer:
else: else:
typ = type(value) typ = type(value)
if typ in self.type_map: if self.embedding_map.has_type(typ):
instance_type, constructor_type = self.type_map[typ] instance_type, constructor_type = self.embedding_map.retrieve_type(typ)
if hasattr(value, 'kernel_invariants') and \ if hasattr(value, 'kernel_invariants') and \
value.kernel_invariants != instance_type.constant_attributes: value.kernel_invariants != instance_type.constant_attributes:
@ -170,7 +194,7 @@ class ASTSynthesizer:
if hasattr(typ, 'artiq_builtin'): if hasattr(typ, 'artiq_builtin'):
exception_id = 0 exception_id = 0
else: else:
exception_id = self.object_map.store(typ) exception_id = self.embedding_map.store_object(typ)
instance_type = builtins.TException("{}.{}".format(typ.__module__, instance_type = builtins.TException("{}.{}".format(typ.__module__,
typ.__qualname__), typ.__qualname__),
id=exception_id) id=exception_id)
@ -183,7 +207,7 @@ class ASTSynthesizer:
constructor_type.attributes['__objectid__'] = builtins.TInt32() constructor_type.attributes['__objectid__'] = builtins.TInt32()
instance_type.constructor = constructor_type instance_type.constructor = constructor_type
self.type_map[typ] = instance_type, constructor_type self.embedding_map.store_type(typ, instance_type, constructor_type)
if hasattr(value, 'kernel_invariants'): if hasattr(value, 'kernel_invariants'):
assert isinstance(value.kernel_invariants, set) assert isinstance(value.kernel_invariants, set)
@ -531,9 +555,7 @@ class Stitcher:
self.functions = {} self.functions = {}
self.function_map = {} self.embedding_map = EmbeddingMap()
self.object_map = ObjectMap()
self.type_map = {}
self.value_map = defaultdict(lambda: []) self.value_map = defaultdict(lambda: [])
def stitch_call(self, function, args, kwargs, callback=None): def stitch_call(self, function, args, kwargs, callback=None):
@ -562,7 +584,7 @@ class Stitcher:
# For every host class we embed, fill in the function slots # For every host class we embed, fill in the function slots
# with their corresponding closures. # with their corresponding closures.
for instance_type, constructor_type in list(self.type_map.values()): for instance_type, constructor_type in self.embedding_map.iter_types():
# Do we have any direct reference to a constructor? # Do we have any direct reference to a constructor?
if len(self.value_map[constructor_type]) > 0: if len(self.value_map[constructor_type]) > 0:
# Yes, use it. # Yes, use it.
@ -592,8 +614,7 @@ class Stitcher:
def _synthesizer(self, expanded_from=None): def _synthesizer(self, expanded_from=None):
return ASTSynthesizer(expanded_from=expanded_from, return ASTSynthesizer(expanded_from=expanded_from,
object_map=self.object_map, embedding_map=self.embedding_map,
type_map=self.type_map,
value_map=self.value_map, value_map=self.value_map,
quote_function=self._quote_function) quote_function=self._quote_function)
@ -635,7 +656,7 @@ class Stitcher:
# Record the function in the function map so that LLVM IR generator # Record the function in the function map so that LLVM IR generator
# can handle quoting it. # can handle quoting it.
self.function_map[function] = function_node.name self.embedding_map.store_function(function, function_node.name)
# Memoize the function type before typing it to handle recursive # Memoize the function type before typing it to handle recursive
# invocations. # invocations.
@ -803,7 +824,7 @@ class Stitcher:
else: else:
assert False assert False
function_type = types.TRPC(ret_type, service=self.object_map.store(function)) function_type = types.TRPC(ret_type, service=self.embedding_map.store_object(function))
self.functions[function] = function_type self.functions[function] = function_type
return function_type return function_type

View File

@ -18,11 +18,7 @@ class Source:
self.engine = diagnostic.Engine(all_errors_are_fatal=True) self.engine = diagnostic.Engine(all_errors_are_fatal=True)
else: else:
self.engine = engine self.engine = engine
self.embedding_map = None
self.function_map = {}
self.object_map = None
self.type_map = {}
self.name, _ = os.path.splitext(os.path.basename(source_buffer.name)) self.name, _ = os.path.splitext(os.path.basename(source_buffer.name))
asttyped_rewriter = transforms.ASTTypedRewriter(engine=engine, asttyped_rewriter = transforms.ASTTypedRewriter(engine=engine,
@ -46,9 +42,9 @@ class Source:
class Module: class Module:
def __init__(self, src, ref_period=1e-6): def __init__(self, src, ref_period=1e-6):
self.engine = src.engine self.engine = src.engine
self.function_map = src.function_map self.embedding_map = src.embedding_map
self.object_map = src.object_map self.name = src.name
self.type_map = src.type_map self.globals = src.globals
int_monomorphizer = transforms.IntMonomorphizer(engine=self.engine) int_monomorphizer = transforms.IntMonomorphizer(engine=self.engine)
inferencer = transforms.Inferencer(engine=self.engine) inferencer = transforms.Inferencer(engine=self.engine)
@ -65,8 +61,6 @@ class Module:
devirtualization = analyses.Devirtualization() devirtualization = analyses.Devirtualization()
interleaver = transforms.Interleaver(engine=self.engine) interleaver = transforms.Interleaver(engine=self.engine)
self.name = src.name
self.globals = src.globals
int_monomorphizer.visit(src.typedtree) int_monomorphizer.visit(src.typedtree)
inferencer.visit(src.typedtree) inferencer.visit(src.typedtree)
monomorphism_validator.visit(src.typedtree) monomorphism_validator.visit(src.typedtree)
@ -84,7 +78,7 @@ class Module:
"""Compile the module to LLVM IR for the specified target.""" """Compile the module to LLVM IR for the specified target."""
llvm_ir_generator = transforms.LLVMIRGenerator( llvm_ir_generator = transforms.LLVMIRGenerator(
engine=self.engine, module_name=self.name, target=target, engine=self.engine, module_name=self.name, target=target,
function_map=self.function_map, object_map=self.object_map, type_map=self.type_map) embedding_map=self.embedding_map)
return llvm_ir_generator.process(self.artiq_ir, attribute_writeback=True) return llvm_ir_generator.process(self.artiq_ir, attribute_writeback=True)
def entry_point(self): def entry_point(self):

View File

@ -122,12 +122,10 @@ class DebugInfoEmitter:
class LLVMIRGenerator: class LLVMIRGenerator:
def __init__(self, engine, module_name, target, function_map, object_map, type_map): def __init__(self, engine, module_name, target, embedding_map):
self.engine = engine self.engine = engine
self.target = target self.target = target
self.function_map = function_map self.embedding_map = embedding_map
self.object_map = object_map
self.type_map = type_map
self.llcontext = target.llcontext self.llcontext = target.llcontext
self.llmodule = ll.Module(context=self.llcontext, name=module_name) self.llmodule = ll.Module(context=self.llcontext, name=module_name)
self.llmodule.triple = target.triple self.llmodule.triple = target.triple
@ -402,7 +400,7 @@ class LLVMIRGenerator:
if any(functions): if any(functions):
self.debug_info_emitter.finalize(functions[0].loc.source_buffer) self.debug_info_emitter.finalize(functions[0].loc.source_buffer)
if attribute_writeback and self.object_map is not None: if attribute_writeback and self.embedding_map is not None:
self.emit_attribute_writeback() self.emit_attribute_writeback()
return self.llmodule return self.llmodule
@ -410,15 +408,15 @@ class LLVMIRGenerator:
def emit_attribute_writeback(self): def emit_attribute_writeback(self):
llobjects = defaultdict(lambda: []) llobjects = defaultdict(lambda: [])
for obj_id in self.object_map: for obj_id in self.embedding_map.iter_objects():
obj_ref = self.object_map.retrieve(obj_id) obj_ref = self.embedding_map.retrieve_object(obj_id)
if isinstance(obj_ref, (pytypes.FunctionType, pytypes.MethodType, if isinstance(obj_ref, (pytypes.FunctionType, pytypes.MethodType,
pytypes.BuiltinFunctionType)): pytypes.BuiltinFunctionType)):
continue continue
elif isinstance(obj_ref, type): elif isinstance(obj_ref, type):
_, typ = self.type_map[obj_ref] _, typ = self.embedding_map.retrieve_type(obj_ref)
else: else:
typ, _ = self.type_map[type(obj_ref)] typ, _ = self.embedding_map.retrieve_type(type(obj_ref))
llobject = self.llmodule.get_global("O.{}".format(obj_id)) llobject = self.llmodule.get_global("O.{}".format(obj_id))
if llobject is not None: if llobject is not None:
@ -1091,8 +1089,11 @@ class LLVMIRGenerator:
llargs = [self.map(arg) for arg in insn.arguments()] llargs = [self.map(arg) for arg in insn.arguments()]
llclosure = self.map(insn.target_function()) llclosure = self.map(insn.target_function())
if insn.static_target_function is None: if insn.static_target_function is None:
llfun = self.llbuilder.extract_value(llclosure, 1, if isinstance(llclosure, ll.Constant):
name="fun.{}".format(llclosure.name)) name = "fun.{}".format(llclosure.constant[1].name)
else:
name = "fun.{}".format(llclosure.name)
llfun = self.llbuilder.extract_value(llclosure, 1, name=name)
else: else:
llfun = self.map(insn.static_target_function) llfun = self.map(insn.static_target_function)
llenv = self.llbuilder.extract_value(llclosure, 0, name="env.fun") llenv = self.llbuilder.extract_value(llclosure, 0, name="env.fun")
@ -1346,7 +1347,7 @@ class LLVMIRGenerator:
llfields = [] llfields = []
for attr in typ.attributes: for attr in typ.attributes:
if attr == "__objectid__": if attr == "__objectid__":
objectid = self.object_map.store(value) objectid = self.embedding_map.store_object(value)
llfields.append(ll.Constant(lli32, objectid)) llfields.append(ll.Constant(lli32, objectid))
assert llglobal is None assert llglobal is None
@ -1399,7 +1400,7 @@ class LLVMIRGenerator:
# RPC and C functions have no runtime representation. # RPC and C functions have no runtime representation.
return ll.Constant(llty, ll.Undefined) return ll.Constant(llty, ll.Undefined)
elif types.is_function(typ): elif types.is_function(typ):
llfun = self.get_function(typ.find(), self.function_map[value]) llfun = self.get_function(typ.find(), self.embedding_map.retrieve_function(value))
llclosure = ll.Constant(self.llty_of_type(typ), [ llclosure = ll.Constant(self.llty_of_type(typ), [
ll.Constant(llptr, ll.Undefined), ll.Constant(llptr, ll.Undefined),
llfun llfun
@ -1416,14 +1417,8 @@ class LLVMIRGenerator:
assert False assert False
def process_Quote(self, insn): def process_Quote(self, insn):
if insn.value in self.function_map and types.is_function(insn.type): assert self.embedding_map is not None
llfun = self.get_function(insn.type.find(), self.function_map[insn.value]) return self._quote(insn.value, insn.type, lambda: [repr(insn.value)])
llclosure = ll.Constant(self.llty_of_type(insn.type), ll.Undefined)
llclosure = self.llbuilder.insert_value(llclosure, llfun, 1, name=insn.name)
return llclosure
else:
assert self.object_map is not None
return self._quote(insn.value, insn.type, lambda: [repr(insn.value)])
def process_Select(self, insn): def process_Select(self, insn):
return self.llbuilder.select(self.map(insn.condition()), return self.llbuilder.select(self.map(insn.condition()),

View File

@ -14,7 +14,7 @@ class Comm:
def run(self): def run(self):
pass pass
def serve(self, object_map, symbolizer): def serve(self, embedding_map, symbolizer, demangler):
pass pass
def check_ident(self): def check_ident(self):

View File

@ -309,13 +309,13 @@ class CommGeneric:
_rpc_sentinel = object() _rpc_sentinel = object()
# See session.c:{send,receive}_rpc_value and llvm_ir_generator.py:_rpc_tag. # See session.c:{send,receive}_rpc_value and llvm_ir_generator.py:_rpc_tag.
def _receive_rpc_value(self, object_map): def _receive_rpc_value(self, embedding_map):
tag = chr(self._read_int8()) tag = chr(self._read_int8())
if tag == "\x00": if tag == "\x00":
return self._rpc_sentinel return self._rpc_sentinel
elif tag == "t": elif tag == "t":
length = self._read_int8() length = self._read_int8()
return tuple(self._receive_rpc_value(object_map) for _ in range(length)) return tuple(self._receive_rpc_value(embedding_map) for _ in range(length))
elif tag == "n": elif tag == "n":
return None return None
elif tag == "b": elif tag == "b":
@ -334,25 +334,25 @@ class CommGeneric:
return self._read_string() return self._read_string()
elif tag == "l": elif tag == "l":
length = self._read_int32() length = self._read_int32()
return [self._receive_rpc_value(object_map) for _ in range(length)] return [self._receive_rpc_value(embedding_map) for _ in range(length)]
elif tag == "r": elif tag == "r":
start = self._receive_rpc_value(object_map) start = self._receive_rpc_value(embedding_map)
stop = self._receive_rpc_value(object_map) stop = self._receive_rpc_value(embedding_map)
step = self._receive_rpc_value(object_map) step = self._receive_rpc_value(embedding_map)
return range(start, stop, step) return range(start, stop, step)
elif tag == "k": elif tag == "k":
name = self._read_string() name = self._read_string()
value = self._receive_rpc_value(object_map) value = self._receive_rpc_value(embedding_map)
return RPCKeyword(name, value) return RPCKeyword(name, value)
elif tag == "O": elif tag == "O":
return object_map.retrieve(self._read_int32()) return embedding_map.retrieve_object(self._read_int32())
else: else:
raise IOError("Unknown RPC value tag: {}".format(repr(tag))) raise IOError("Unknown RPC value tag: {}".format(repr(tag)))
def _receive_rpc_args(self, object_map): def _receive_rpc_args(self, embedding_map):
args, kwargs = [], {} args, kwargs = [], {}
while True: while True:
value = self._receive_rpc_value(object_map) value = self._receive_rpc_value(embedding_map)
if value is self._rpc_sentinel: if value is self._rpc_sentinel:
return args, kwargs return args, kwargs
elif isinstance(value, RPCKeyword): elif isinstance(value, RPCKeyword):
@ -440,14 +440,14 @@ class CommGeneric:
else: else:
raise IOError("Unknown RPC value tag: {}".format(repr(tag))) raise IOError("Unknown RPC value tag: {}".format(repr(tag)))
def _serve_rpc(self, object_map): def _serve_rpc(self, embedding_map):
service_id = self._read_int32() service_id = self._read_int32()
if service_id == 0: if service_id == 0:
service = lambda obj, attr, value: setattr(obj, attr, value) service = lambda obj, attr, value: setattr(obj, attr, value)
else: else:
service = object_map.retrieve(service_id) service = embedding_map.retrieve_object(service_id)
args, kwargs = self._receive_rpc_args(object_map) args, kwargs = self._receive_rpc_args(embedding_map)
return_tags = self._read_bytes() return_tags = self._read_bytes()
logger.debug("rpc service: [%d]%r %r %r -> %s", service_id, service, args, kwargs, return_tags) logger.debug("rpc service: [%d]%r %r %r -> %s", service_id, service, args, kwargs, return_tags)
@ -483,7 +483,7 @@ class CommGeneric:
hasattr(exn, 'artiq_builtin'): hasattr(exn, 'artiq_builtin'):
self._write_string("0:{}".format(exn_type.__name__)) self._write_string("0:{}".format(exn_type.__name__))
else: else:
exn_id = object_map.store(exn_type) exn_id = embedding_map.store_object(exn_type)
self._write_string("{}:{}.{}".format(exn_id, self._write_string("{}:{}.{}".format(exn_id,
exn_type.__module__, exn_type.__qualname__)) exn_type.__module__, exn_type.__qualname__))
self._write_string(str(exn)) self._write_string(str(exn))
@ -504,7 +504,7 @@ class CommGeneric:
self._write_flush() self._write_flush()
def _serve_exception(self, object_map, symbolizer, demangler): def _serve_exception(self, embedding_map, symbolizer, demangler):
name = self._read_string() name = self._read_string()
message = self._read_string() message = self._read_string()
params = [self._read_int64() for _ in range(3)] params = [self._read_int64() for _ in range(3)]
@ -523,19 +523,19 @@ class CommGeneric:
if core_exn.id == 0: if core_exn.id == 0:
python_exn_type = getattr(exceptions, core_exn.name.split('.')[-1]) python_exn_type = getattr(exceptions, core_exn.name.split('.')[-1])
else: else:
python_exn_type = object_map.retrieve(core_exn.id) python_exn_type = embedding_map.retrieve_object(core_exn.id)
python_exn = python_exn_type(message.format(*params)) python_exn = python_exn_type(message.format(*params))
python_exn.artiq_core_exception = core_exn python_exn.artiq_core_exception = core_exn
raise python_exn raise python_exn
def serve(self, object_map, symbolizer, demangler): def serve(self, embedding_map, symbolizer, demangler):
while True: while True:
self._read_header() self._read_header()
if self._read_type == _D2HMsgType.RPC_REQUEST: if self._read_type == _D2HMsgType.RPC_REQUEST:
self._serve_rpc(object_map) self._serve_rpc(embedding_map)
elif self._read_type == _D2HMsgType.KERNEL_EXCEPTION: elif self._read_type == _D2HMsgType.KERNEL_EXCEPTION:
self._serve_exception(object_map, symbolizer, demangler) self._serve_exception(embedding_map, symbolizer, demangler)
elif self._read_type == _D2HMsgType.WATCHDOG_EXPIRED: elif self._read_type == _D2HMsgType.WATCHDOG_EXPIRED:
raise exceptions.WatchdogExpired raise exceptions.WatchdogExpired
elif self._read_type == _D2HMsgType.CLOCK_FAILURE: elif self._read_type == _D2HMsgType.CLOCK_FAILURE:

View File

@ -91,7 +91,7 @@ class Core:
library = target.compile_and_link([module]) library = target.compile_and_link([module])
stripped_library = target.strip(library) stripped_library = target.strip(library)
return stitcher.object_map, stripped_library, \ return stitcher.embedding_map, stripped_library, \
lambda addresses: target.symbolize(library, addresses), \ lambda addresses: target.symbolize(library, addresses), \
lambda symbols: target.demangle(symbols) lambda symbols: target.demangle(symbols)
except diagnostic.Error as error: except diagnostic.Error as error:
@ -103,7 +103,7 @@ class Core:
nonlocal result nonlocal result
result = new_result result = new_result
object_map, kernel_library, symbolizer, demangler = \ embedding_map, kernel_library, symbolizer, demangler = \
self.compile(function, args, kwargs, set_result) self.compile(function, args, kwargs, set_result)
if self.first_run: if self.first_run:
@ -113,7 +113,7 @@ class Core:
self.comm.load(kernel_library) self.comm.load(kernel_library)
self.comm.run() self.comm.run()
self.comm.serve(object_map, symbolizer, demangler) self.comm.serve(embedding_map, symbolizer, demangler)
return result return result

View File

@ -16,7 +16,7 @@ from artiq.language.environment import EnvExperiment, ProcessArgumentManager
from artiq.master.databases import DeviceDB, DatasetDB from artiq.master.databases import DeviceDB, DatasetDB
from artiq.master.worker_db import DeviceManager, DatasetManager from artiq.master.worker_db import DeviceManager, DatasetManager
from artiq.coredevice.core import CompileError, host_only from artiq.coredevice.core import CompileError, host_only
from artiq.compiler.embedding import ObjectMap from artiq.compiler.embedding import EmbeddingMap
from artiq.compiler.targets import OR1KTarget from artiq.compiler.targets import OR1KTarget
from artiq.tools import * from artiq.tools import *
@ -29,19 +29,19 @@ class StubObject:
pass pass
class StubObjectMap: class StubEmbeddingMap:
def __init__(self): def __init__(self):
stub_object = StubObject() stub_object = StubObject()
self.forward_map = defaultdict(lambda: stub_object) self.object_forward_map = defaultdict(lambda: stub_object)
self.forward_map[1] = lambda _: None # return RPC self.object_forward_map[1] = lambda _: None # return RPC
self.next_id = -1 self.object_current_id = -1
def retrieve(self, object_id): def retrieve_object(self, object_id):
return self.forward_map[object_id] return self.object_forward_map[object_id]
def store(self, value): def store_object(self, value):
self.forward_map[self.next_id] = value self.object_forward_map[self.object_current_id] = value
self.next_id -= 1 self.object_current_id -= 1
class FileRunner(EnvExperiment): class FileRunner(EnvExperiment):
@ -55,7 +55,7 @@ class FileRunner(EnvExperiment):
self.core.comm.load(kernel_library) self.core.comm.load(kernel_library)
self.core.comm.run() self.core.comm.run()
self.core.comm.serve(StubObjectMap(), self.core.comm.serve(StubEmbeddingMap(),
lambda addresses: self.target.symbolize(kernel_library, addresses)) lambda addresses: self.target.symbolize(kernel_library, addresses))