embedding: refactor some more.

This commit is contained in:
whitequark 2016-05-16 14:30:21 +00:00
parent d085d5a372
commit 640022122b
7 changed files with 107 additions and 97 deletions

View File

@ -22,37 +22,61 @@ from .transforms.asttyped_rewriter import LocalExtractor
def coredevice_print(x): print(x)
class ObjectMap:
class EmbeddingMap:
def __init__(self):
self.current_key = 0
self.forward_map = {}
self.reverse_map = {}
self.object_current_key = 0
self.object_forward_map = {}
self.object_reverse_map = {}
self.type_map = {}
self.function_map = {}
def store(self, obj_ref):
# Types
def store_type(self, typ, instance_type, constructor_type):
self.type_map[typ] = (instance_type, constructor_type)
def retrieve_type(self, typ):
return self.type_map[typ]
def has_type(self, typ):
return typ in self.type_map
def iter_types(self):
return self.type_map.values()
# Functions
def store_function(self, function, ir_function_name):
self.function_map[function] = ir_function_name
def retrieve_function(self, function):
return self.function_map[function]
# Objects
def store_object(self, obj_ref):
obj_id = id(obj_ref)
if obj_id in self.reverse_map:
return self.reverse_map[obj_id]
if obj_id in self.object_reverse_map:
return self.object_reverse_map[obj_id]
self.current_key += 1
self.forward_map[self.current_key] = obj_ref
self.reverse_map[obj_id] = self.current_key
return self.current_key
self.object_current_key += 1
self.object_forward_map[self.object_current_key] = obj_ref
self.object_reverse_map[obj_id] = self.object_current_key
return self.object_current_key
def retrieve(self, obj_key):
return self.forward_map[obj_key]
def retrieve_object(self, obj_key):
return self.object_forward_map[obj_key]
def iter_objects(self):
return self.object_forward_map.keys()
def has_rpc(self):
return any(filter(lambda x: inspect.isfunction(x) or inspect.ismethod(x),
self.forward_map.values()))
def __iter__(self):
return iter(self.forward_map.keys())
self.object_forward_map.values()))
class ASTSynthesizer:
def __init__(self, object_map, type_map, value_map, quote_function=None, expanded_from=None):
def __init__(self, embedding_map, value_map, quote_function=None, expanded_from=None):
self.source = ""
self.source_buffer = source.Buffer(self.source, "<synthesized>")
self.object_map, self.type_map, self.value_map = object_map, type_map, value_map
self.embedding_map = embedding_map
self.value_map = value_map
self.quote_function = quote_function
self.expanded_from = expanded_from
self.diagnostics = []
@ -134,8 +158,8 @@ class ASTSynthesizer:
else:
typ = type(value)
if typ in self.type_map:
instance_type, constructor_type = self.type_map[typ]
if self.embedding_map.has_type(typ):
instance_type, constructor_type = self.embedding_map.retrieve_type(typ)
if hasattr(value, 'kernel_invariants') and \
value.kernel_invariants != instance_type.constant_attributes:
@ -170,7 +194,7 @@ class ASTSynthesizer:
if hasattr(typ, 'artiq_builtin'):
exception_id = 0
else:
exception_id = self.object_map.store(typ)
exception_id = self.embedding_map.store_object(typ)
instance_type = builtins.TException("{}.{}".format(typ.__module__,
typ.__qualname__),
id=exception_id)
@ -183,7 +207,7 @@ class ASTSynthesizer:
constructor_type.attributes['__objectid__'] = builtins.TInt32()
instance_type.constructor = constructor_type
self.type_map[typ] = instance_type, constructor_type
self.embedding_map.store_type(typ, instance_type, constructor_type)
if hasattr(value, 'kernel_invariants'):
assert isinstance(value.kernel_invariants, set)
@ -531,9 +555,7 @@ class Stitcher:
self.functions = {}
self.function_map = {}
self.object_map = ObjectMap()
self.type_map = {}
self.embedding_map = EmbeddingMap()
self.value_map = defaultdict(lambda: [])
def stitch_call(self, function, args, kwargs, callback=None):
@ -562,7 +584,7 @@ class Stitcher:
# For every host class we embed, fill in the function slots
# with their corresponding closures.
for instance_type, constructor_type in list(self.type_map.values()):
for instance_type, constructor_type in self.embedding_map.iter_types():
# Do we have any direct reference to a constructor?
if len(self.value_map[constructor_type]) > 0:
# Yes, use it.
@ -592,8 +614,7 @@ class Stitcher:
def _synthesizer(self, expanded_from=None):
return ASTSynthesizer(expanded_from=expanded_from,
object_map=self.object_map,
type_map=self.type_map,
embedding_map=self.embedding_map,
value_map=self.value_map,
quote_function=self._quote_function)
@ -635,7 +656,7 @@ class Stitcher:
# Record the function in the function map so that LLVM IR generator
# can handle quoting it.
self.function_map[function] = function_node.name
self.embedding_map.store_function(function, function_node.name)
# Memoize the function type before typing it to handle recursive
# invocations.
@ -803,7 +824,7 @@ class Stitcher:
else:
assert False
function_type = types.TRPC(ret_type, service=self.object_map.store(function))
function_type = types.TRPC(ret_type, service=self.embedding_map.store_object(function))
self.functions[function] = function_type
return function_type

View File

@ -18,11 +18,7 @@ class Source:
self.engine = diagnostic.Engine(all_errors_are_fatal=True)
else:
self.engine = engine
self.function_map = {}
self.object_map = None
self.type_map = {}
self.embedding_map = None
self.name, _ = os.path.splitext(os.path.basename(source_buffer.name))
asttyped_rewriter = transforms.ASTTypedRewriter(engine=engine,
@ -46,9 +42,9 @@ class Source:
class Module:
def __init__(self, src, ref_period=1e-6):
self.engine = src.engine
self.function_map = src.function_map
self.object_map = src.object_map
self.type_map = src.type_map
self.embedding_map = src.embedding_map
self.name = src.name
self.globals = src.globals
int_monomorphizer = transforms.IntMonomorphizer(engine=self.engine)
inferencer = transforms.Inferencer(engine=self.engine)
@ -65,8 +61,6 @@ class Module:
devirtualization = analyses.Devirtualization()
interleaver = transforms.Interleaver(engine=self.engine)
self.name = src.name
self.globals = src.globals
int_monomorphizer.visit(src.typedtree)
inferencer.visit(src.typedtree)
monomorphism_validator.visit(src.typedtree)
@ -84,7 +78,7 @@ class Module:
"""Compile the module to LLVM IR for the specified target."""
llvm_ir_generator = transforms.LLVMIRGenerator(
engine=self.engine, module_name=self.name, target=target,
function_map=self.function_map, object_map=self.object_map, type_map=self.type_map)
embedding_map=self.embedding_map)
return llvm_ir_generator.process(self.artiq_ir, attribute_writeback=True)
def entry_point(self):

View File

@ -122,12 +122,10 @@ class DebugInfoEmitter:
class LLVMIRGenerator:
def __init__(self, engine, module_name, target, function_map, object_map, type_map):
def __init__(self, engine, module_name, target, embedding_map):
self.engine = engine
self.target = target
self.function_map = function_map
self.object_map = object_map
self.type_map = type_map
self.embedding_map = embedding_map
self.llcontext = target.llcontext
self.llmodule = ll.Module(context=self.llcontext, name=module_name)
self.llmodule.triple = target.triple
@ -402,7 +400,7 @@ class LLVMIRGenerator:
if any(functions):
self.debug_info_emitter.finalize(functions[0].loc.source_buffer)
if attribute_writeback and self.object_map is not None:
if attribute_writeback and self.embedding_map is not None:
self.emit_attribute_writeback()
return self.llmodule
@ -410,15 +408,15 @@ class LLVMIRGenerator:
def emit_attribute_writeback(self):
llobjects = defaultdict(lambda: [])
for obj_id in self.object_map:
obj_ref = self.object_map.retrieve(obj_id)
for obj_id in self.embedding_map.iter_objects():
obj_ref = self.embedding_map.retrieve_object(obj_id)
if isinstance(obj_ref, (pytypes.FunctionType, pytypes.MethodType,
pytypes.BuiltinFunctionType)):
continue
elif isinstance(obj_ref, type):
_, typ = self.type_map[obj_ref]
_, typ = self.embedding_map.retrieve_type(obj_ref)
else:
typ, _ = self.type_map[type(obj_ref)]
typ, _ = self.embedding_map.retrieve_type(type(obj_ref))
llobject = self.llmodule.get_global("O.{}".format(obj_id))
if llobject is not None:
@ -1091,8 +1089,11 @@ class LLVMIRGenerator:
llargs = [self.map(arg) for arg in insn.arguments()]
llclosure = self.map(insn.target_function())
if insn.static_target_function is None:
llfun = self.llbuilder.extract_value(llclosure, 1,
name="fun.{}".format(llclosure.name))
if isinstance(llclosure, ll.Constant):
name = "fun.{}".format(llclosure.constant[1].name)
else:
name = "fun.{}".format(llclosure.name)
llfun = self.llbuilder.extract_value(llclosure, 1, name=name)
else:
llfun = self.map(insn.static_target_function)
llenv = self.llbuilder.extract_value(llclosure, 0, name="env.fun")
@ -1346,7 +1347,7 @@ class LLVMIRGenerator:
llfields = []
for attr in typ.attributes:
if attr == "__objectid__":
objectid = self.object_map.store(value)
objectid = self.embedding_map.store_object(value)
llfields.append(ll.Constant(lli32, objectid))
assert llglobal is None
@ -1399,7 +1400,7 @@ class LLVMIRGenerator:
# RPC and C functions have no runtime representation.
return ll.Constant(llty, ll.Undefined)
elif types.is_function(typ):
llfun = self.get_function(typ.find(), self.function_map[value])
llfun = self.get_function(typ.find(), self.embedding_map.retrieve_function(value))
llclosure = ll.Constant(self.llty_of_type(typ), [
ll.Constant(llptr, ll.Undefined),
llfun
@ -1416,13 +1417,7 @@ class LLVMIRGenerator:
assert False
def process_Quote(self, insn):
if insn.value in self.function_map and types.is_function(insn.type):
llfun = self.get_function(insn.type.find(), self.function_map[insn.value])
llclosure = ll.Constant(self.llty_of_type(insn.type), ll.Undefined)
llclosure = self.llbuilder.insert_value(llclosure, llfun, 1, name=insn.name)
return llclosure
else:
assert self.object_map is not None
assert self.embedding_map is not None
return self._quote(insn.value, insn.type, lambda: [repr(insn.value)])
def process_Select(self, insn):

View File

@ -14,7 +14,7 @@ class Comm:
def run(self):
pass
def serve(self, object_map, symbolizer):
def serve(self, embedding_map, symbolizer, demangler):
pass
def check_ident(self):

View File

@ -309,13 +309,13 @@ class CommGeneric:
_rpc_sentinel = object()
# See session.c:{send,receive}_rpc_value and llvm_ir_generator.py:_rpc_tag.
def _receive_rpc_value(self, object_map):
def _receive_rpc_value(self, embedding_map):
tag = chr(self._read_int8())
if tag == "\x00":
return self._rpc_sentinel
elif tag == "t":
length = self._read_int8()
return tuple(self._receive_rpc_value(object_map) for _ in range(length))
return tuple(self._receive_rpc_value(embedding_map) for _ in range(length))
elif tag == "n":
return None
elif tag == "b":
@ -334,25 +334,25 @@ class CommGeneric:
return self._read_string()
elif tag == "l":
length = self._read_int32()
return [self._receive_rpc_value(object_map) for _ in range(length)]
return [self._receive_rpc_value(embedding_map) for _ in range(length)]
elif tag == "r":
start = self._receive_rpc_value(object_map)
stop = self._receive_rpc_value(object_map)
step = self._receive_rpc_value(object_map)
start = self._receive_rpc_value(embedding_map)
stop = self._receive_rpc_value(embedding_map)
step = self._receive_rpc_value(embedding_map)
return range(start, stop, step)
elif tag == "k":
name = self._read_string()
value = self._receive_rpc_value(object_map)
value = self._receive_rpc_value(embedding_map)
return RPCKeyword(name, value)
elif tag == "O":
return object_map.retrieve(self._read_int32())
return embedding_map.retrieve_object(self._read_int32())
else:
raise IOError("Unknown RPC value tag: {}".format(repr(tag)))
def _receive_rpc_args(self, object_map):
def _receive_rpc_args(self, embedding_map):
args, kwargs = [], {}
while True:
value = self._receive_rpc_value(object_map)
value = self._receive_rpc_value(embedding_map)
if value is self._rpc_sentinel:
return args, kwargs
elif isinstance(value, RPCKeyword):
@ -440,14 +440,14 @@ class CommGeneric:
else:
raise IOError("Unknown RPC value tag: {}".format(repr(tag)))
def _serve_rpc(self, object_map):
def _serve_rpc(self, embedding_map):
service_id = self._read_int32()
if service_id == 0:
service = lambda obj, attr, value: setattr(obj, attr, value)
else:
service = object_map.retrieve(service_id)
service = embedding_map.retrieve_object(service_id)
args, kwargs = self._receive_rpc_args(object_map)
args, kwargs = self._receive_rpc_args(embedding_map)
return_tags = self._read_bytes()
logger.debug("rpc service: [%d]%r %r %r -> %s", service_id, service, args, kwargs, return_tags)
@ -483,7 +483,7 @@ class CommGeneric:
hasattr(exn, 'artiq_builtin'):
self._write_string("0:{}".format(exn_type.__name__))
else:
exn_id = object_map.store(exn_type)
exn_id = embedding_map.store_object(exn_type)
self._write_string("{}:{}.{}".format(exn_id,
exn_type.__module__, exn_type.__qualname__))
self._write_string(str(exn))
@ -504,7 +504,7 @@ class CommGeneric:
self._write_flush()
def _serve_exception(self, object_map, symbolizer, demangler):
def _serve_exception(self, embedding_map, symbolizer, demangler):
name = self._read_string()
message = self._read_string()
params = [self._read_int64() for _ in range(3)]
@ -523,19 +523,19 @@ class CommGeneric:
if core_exn.id == 0:
python_exn_type = getattr(exceptions, core_exn.name.split('.')[-1])
else:
python_exn_type = object_map.retrieve(core_exn.id)
python_exn_type = embedding_map.retrieve_object(core_exn.id)
python_exn = python_exn_type(message.format(*params))
python_exn.artiq_core_exception = core_exn
raise python_exn
def serve(self, object_map, symbolizer, demangler):
def serve(self, embedding_map, symbolizer, demangler):
while True:
self._read_header()
if self._read_type == _D2HMsgType.RPC_REQUEST:
self._serve_rpc(object_map)
self._serve_rpc(embedding_map)
elif self._read_type == _D2HMsgType.KERNEL_EXCEPTION:
self._serve_exception(object_map, symbolizer, demangler)
self._serve_exception(embedding_map, symbolizer, demangler)
elif self._read_type == _D2HMsgType.WATCHDOG_EXPIRED:
raise exceptions.WatchdogExpired
elif self._read_type == _D2HMsgType.CLOCK_FAILURE:

View File

@ -91,7 +91,7 @@ class Core:
library = target.compile_and_link([module])
stripped_library = target.strip(library)
return stitcher.object_map, stripped_library, \
return stitcher.embedding_map, stripped_library, \
lambda addresses: target.symbolize(library, addresses), \
lambda symbols: target.demangle(symbols)
except diagnostic.Error as error:
@ -103,7 +103,7 @@ class Core:
nonlocal result
result = new_result
object_map, kernel_library, symbolizer, demangler = \
embedding_map, kernel_library, symbolizer, demangler = \
self.compile(function, args, kwargs, set_result)
if self.first_run:
@ -113,7 +113,7 @@ class Core:
self.comm.load(kernel_library)
self.comm.run()
self.comm.serve(object_map, symbolizer, demangler)
self.comm.serve(embedding_map, symbolizer, demangler)
return result

View File

@ -16,7 +16,7 @@ from artiq.language.environment import EnvExperiment, ProcessArgumentManager
from artiq.master.databases import DeviceDB, DatasetDB
from artiq.master.worker_db import DeviceManager, DatasetManager
from artiq.coredevice.core import CompileError, host_only
from artiq.compiler.embedding import ObjectMap
from artiq.compiler.embedding import EmbeddingMap
from artiq.compiler.targets import OR1KTarget
from artiq.tools import *
@ -29,19 +29,19 @@ class StubObject:
pass
class StubObjectMap:
class StubEmbeddingMap:
def __init__(self):
stub_object = StubObject()
self.forward_map = defaultdict(lambda: stub_object)
self.forward_map[1] = lambda _: None # return RPC
self.next_id = -1
self.object_forward_map = defaultdict(lambda: stub_object)
self.object_forward_map[1] = lambda _: None # return RPC
self.object_current_id = -1
def retrieve(self, object_id):
return self.forward_map[object_id]
def retrieve_object(self, object_id):
return self.object_forward_map[object_id]
def store(self, value):
self.forward_map[self.next_id] = value
self.next_id -= 1
def store_object(self, value):
self.object_forward_map[self.object_current_id] = value
self.object_current_id -= 1
class FileRunner(EnvExperiment):
@ -55,7 +55,7 @@ class FileRunner(EnvExperiment):
self.core.comm.load(kernel_library)
self.core.comm.run()
self.core.comm.serve(StubObjectMap(),
self.core.comm.serve(StubEmbeddingMap(),
lambda addresses: self.target.symbolize(kernel_library, addresses))