mirror of https://github.com/m-labs/artiq.git
embedding: refactor some more.
This commit is contained in:
parent
d085d5a372
commit
640022122b
|
@ -22,37 +22,61 @@ from .transforms.asttyped_rewriter import LocalExtractor
|
|||
def coredevice_print(x): print(x)
|
||||
|
||||
|
||||
class ObjectMap:
|
||||
class EmbeddingMap:
|
||||
def __init__(self):
|
||||
self.current_key = 0
|
||||
self.forward_map = {}
|
||||
self.reverse_map = {}
|
||||
self.object_current_key = 0
|
||||
self.object_forward_map = {}
|
||||
self.object_reverse_map = {}
|
||||
self.type_map = {}
|
||||
self.function_map = {}
|
||||
|
||||
def store(self, obj_ref):
|
||||
# Types
|
||||
def store_type(self, typ, instance_type, constructor_type):
|
||||
self.type_map[typ] = (instance_type, constructor_type)
|
||||
|
||||
def retrieve_type(self, typ):
|
||||
return self.type_map[typ]
|
||||
|
||||
def has_type(self, typ):
|
||||
return typ in self.type_map
|
||||
|
||||
def iter_types(self):
|
||||
return self.type_map.values()
|
||||
|
||||
# Functions
|
||||
def store_function(self, function, ir_function_name):
|
||||
self.function_map[function] = ir_function_name
|
||||
|
||||
def retrieve_function(self, function):
|
||||
return self.function_map[function]
|
||||
|
||||
# Objects
|
||||
def store_object(self, obj_ref):
|
||||
obj_id = id(obj_ref)
|
||||
if obj_id in self.reverse_map:
|
||||
return self.reverse_map[obj_id]
|
||||
if obj_id in self.object_reverse_map:
|
||||
return self.object_reverse_map[obj_id]
|
||||
|
||||
self.current_key += 1
|
||||
self.forward_map[self.current_key] = obj_ref
|
||||
self.reverse_map[obj_id] = self.current_key
|
||||
return self.current_key
|
||||
self.object_current_key += 1
|
||||
self.object_forward_map[self.object_current_key] = obj_ref
|
||||
self.object_reverse_map[obj_id] = self.object_current_key
|
||||
return self.object_current_key
|
||||
|
||||
def retrieve(self, obj_key):
|
||||
return self.forward_map[obj_key]
|
||||
def retrieve_object(self, obj_key):
|
||||
return self.object_forward_map[obj_key]
|
||||
|
||||
def iter_objects(self):
|
||||
return self.object_forward_map.keys()
|
||||
|
||||
def has_rpc(self):
|
||||
return any(filter(lambda x: inspect.isfunction(x) or inspect.ismethod(x),
|
||||
self.forward_map.values()))
|
||||
|
||||
def __iter__(self):
|
||||
return iter(self.forward_map.keys())
|
||||
self.object_forward_map.values()))
|
||||
|
||||
class ASTSynthesizer:
|
||||
def __init__(self, object_map, type_map, value_map, quote_function=None, expanded_from=None):
|
||||
def __init__(self, embedding_map, value_map, quote_function=None, expanded_from=None):
|
||||
self.source = ""
|
||||
self.source_buffer = source.Buffer(self.source, "<synthesized>")
|
||||
self.object_map, self.type_map, self.value_map = object_map, type_map, value_map
|
||||
self.embedding_map = embedding_map
|
||||
self.value_map = value_map
|
||||
self.quote_function = quote_function
|
||||
self.expanded_from = expanded_from
|
||||
self.diagnostics = []
|
||||
|
@ -134,8 +158,8 @@ class ASTSynthesizer:
|
|||
else:
|
||||
typ = type(value)
|
||||
|
||||
if typ in self.type_map:
|
||||
instance_type, constructor_type = self.type_map[typ]
|
||||
if self.embedding_map.has_type(typ):
|
||||
instance_type, constructor_type = self.embedding_map.retrieve_type(typ)
|
||||
|
||||
if hasattr(value, 'kernel_invariants') and \
|
||||
value.kernel_invariants != instance_type.constant_attributes:
|
||||
|
@ -170,7 +194,7 @@ class ASTSynthesizer:
|
|||
if hasattr(typ, 'artiq_builtin'):
|
||||
exception_id = 0
|
||||
else:
|
||||
exception_id = self.object_map.store(typ)
|
||||
exception_id = self.embedding_map.store_object(typ)
|
||||
instance_type = builtins.TException("{}.{}".format(typ.__module__,
|
||||
typ.__qualname__),
|
||||
id=exception_id)
|
||||
|
@ -183,7 +207,7 @@ class ASTSynthesizer:
|
|||
constructor_type.attributes['__objectid__'] = builtins.TInt32()
|
||||
instance_type.constructor = constructor_type
|
||||
|
||||
self.type_map[typ] = instance_type, constructor_type
|
||||
self.embedding_map.store_type(typ, instance_type, constructor_type)
|
||||
|
||||
if hasattr(value, 'kernel_invariants'):
|
||||
assert isinstance(value.kernel_invariants, set)
|
||||
|
@ -531,9 +555,7 @@ class Stitcher:
|
|||
|
||||
self.functions = {}
|
||||
|
||||
self.function_map = {}
|
||||
self.object_map = ObjectMap()
|
||||
self.type_map = {}
|
||||
self.embedding_map = EmbeddingMap()
|
||||
self.value_map = defaultdict(lambda: [])
|
||||
|
||||
def stitch_call(self, function, args, kwargs, callback=None):
|
||||
|
@ -562,7 +584,7 @@ class Stitcher:
|
|||
|
||||
# For every host class we embed, fill in the function slots
|
||||
# with their corresponding closures.
|
||||
for instance_type, constructor_type in list(self.type_map.values()):
|
||||
for instance_type, constructor_type in self.embedding_map.iter_types():
|
||||
# Do we have any direct reference to a constructor?
|
||||
if len(self.value_map[constructor_type]) > 0:
|
||||
# Yes, use it.
|
||||
|
@ -592,8 +614,7 @@ class Stitcher:
|
|||
|
||||
def _synthesizer(self, expanded_from=None):
|
||||
return ASTSynthesizer(expanded_from=expanded_from,
|
||||
object_map=self.object_map,
|
||||
type_map=self.type_map,
|
||||
embedding_map=self.embedding_map,
|
||||
value_map=self.value_map,
|
||||
quote_function=self._quote_function)
|
||||
|
||||
|
@ -635,7 +656,7 @@ class Stitcher:
|
|||
|
||||
# Record the function in the function map so that LLVM IR generator
|
||||
# can handle quoting it.
|
||||
self.function_map[function] = function_node.name
|
||||
self.embedding_map.store_function(function, function_node.name)
|
||||
|
||||
# Memoize the function type before typing it to handle recursive
|
||||
# invocations.
|
||||
|
@ -803,7 +824,7 @@ class Stitcher:
|
|||
else:
|
||||
assert False
|
||||
|
||||
function_type = types.TRPC(ret_type, service=self.object_map.store(function))
|
||||
function_type = types.TRPC(ret_type, service=self.embedding_map.store_object(function))
|
||||
self.functions[function] = function_type
|
||||
return function_type
|
||||
|
||||
|
|
|
@ -18,11 +18,7 @@ class Source:
|
|||
self.engine = diagnostic.Engine(all_errors_are_fatal=True)
|
||||
else:
|
||||
self.engine = engine
|
||||
|
||||
self.function_map = {}
|
||||
self.object_map = None
|
||||
self.type_map = {}
|
||||
|
||||
self.embedding_map = None
|
||||
self.name, _ = os.path.splitext(os.path.basename(source_buffer.name))
|
||||
|
||||
asttyped_rewriter = transforms.ASTTypedRewriter(engine=engine,
|
||||
|
@ -46,9 +42,9 @@ class Source:
|
|||
class Module:
|
||||
def __init__(self, src, ref_period=1e-6):
|
||||
self.engine = src.engine
|
||||
self.function_map = src.function_map
|
||||
self.object_map = src.object_map
|
||||
self.type_map = src.type_map
|
||||
self.embedding_map = src.embedding_map
|
||||
self.name = src.name
|
||||
self.globals = src.globals
|
||||
|
||||
int_monomorphizer = transforms.IntMonomorphizer(engine=self.engine)
|
||||
inferencer = transforms.Inferencer(engine=self.engine)
|
||||
|
@ -65,8 +61,6 @@ class Module:
|
|||
devirtualization = analyses.Devirtualization()
|
||||
interleaver = transforms.Interleaver(engine=self.engine)
|
||||
|
||||
self.name = src.name
|
||||
self.globals = src.globals
|
||||
int_monomorphizer.visit(src.typedtree)
|
||||
inferencer.visit(src.typedtree)
|
||||
monomorphism_validator.visit(src.typedtree)
|
||||
|
@ -84,7 +78,7 @@ class Module:
|
|||
"""Compile the module to LLVM IR for the specified target."""
|
||||
llvm_ir_generator = transforms.LLVMIRGenerator(
|
||||
engine=self.engine, module_name=self.name, target=target,
|
||||
function_map=self.function_map, object_map=self.object_map, type_map=self.type_map)
|
||||
embedding_map=self.embedding_map)
|
||||
return llvm_ir_generator.process(self.artiq_ir, attribute_writeback=True)
|
||||
|
||||
def entry_point(self):
|
||||
|
|
|
@ -122,12 +122,10 @@ class DebugInfoEmitter:
|
|||
|
||||
|
||||
class LLVMIRGenerator:
|
||||
def __init__(self, engine, module_name, target, function_map, object_map, type_map):
|
||||
def __init__(self, engine, module_name, target, embedding_map):
|
||||
self.engine = engine
|
||||
self.target = target
|
||||
self.function_map = function_map
|
||||
self.object_map = object_map
|
||||
self.type_map = type_map
|
||||
self.embedding_map = embedding_map
|
||||
self.llcontext = target.llcontext
|
||||
self.llmodule = ll.Module(context=self.llcontext, name=module_name)
|
||||
self.llmodule.triple = target.triple
|
||||
|
@ -402,7 +400,7 @@ class LLVMIRGenerator:
|
|||
if any(functions):
|
||||
self.debug_info_emitter.finalize(functions[0].loc.source_buffer)
|
||||
|
||||
if attribute_writeback and self.object_map is not None:
|
||||
if attribute_writeback and self.embedding_map is not None:
|
||||
self.emit_attribute_writeback()
|
||||
|
||||
return self.llmodule
|
||||
|
@ -410,15 +408,15 @@ class LLVMIRGenerator:
|
|||
def emit_attribute_writeback(self):
|
||||
llobjects = defaultdict(lambda: [])
|
||||
|
||||
for obj_id in self.object_map:
|
||||
obj_ref = self.object_map.retrieve(obj_id)
|
||||
for obj_id in self.embedding_map.iter_objects():
|
||||
obj_ref = self.embedding_map.retrieve_object(obj_id)
|
||||
if isinstance(obj_ref, (pytypes.FunctionType, pytypes.MethodType,
|
||||
pytypes.BuiltinFunctionType)):
|
||||
continue
|
||||
elif isinstance(obj_ref, type):
|
||||
_, typ = self.type_map[obj_ref]
|
||||
_, typ = self.embedding_map.retrieve_type(obj_ref)
|
||||
else:
|
||||
typ, _ = self.type_map[type(obj_ref)]
|
||||
typ, _ = self.embedding_map.retrieve_type(type(obj_ref))
|
||||
|
||||
llobject = self.llmodule.get_global("O.{}".format(obj_id))
|
||||
if llobject is not None:
|
||||
|
@ -1091,8 +1089,11 @@ class LLVMIRGenerator:
|
|||
llargs = [self.map(arg) for arg in insn.arguments()]
|
||||
llclosure = self.map(insn.target_function())
|
||||
if insn.static_target_function is None:
|
||||
llfun = self.llbuilder.extract_value(llclosure, 1,
|
||||
name="fun.{}".format(llclosure.name))
|
||||
if isinstance(llclosure, ll.Constant):
|
||||
name = "fun.{}".format(llclosure.constant[1].name)
|
||||
else:
|
||||
name = "fun.{}".format(llclosure.name)
|
||||
llfun = self.llbuilder.extract_value(llclosure, 1, name=name)
|
||||
else:
|
||||
llfun = self.map(insn.static_target_function)
|
||||
llenv = self.llbuilder.extract_value(llclosure, 0, name="env.fun")
|
||||
|
@ -1346,7 +1347,7 @@ class LLVMIRGenerator:
|
|||
llfields = []
|
||||
for attr in typ.attributes:
|
||||
if attr == "__objectid__":
|
||||
objectid = self.object_map.store(value)
|
||||
objectid = self.embedding_map.store_object(value)
|
||||
llfields.append(ll.Constant(lli32, objectid))
|
||||
|
||||
assert llglobal is None
|
||||
|
@ -1399,7 +1400,7 @@ class LLVMIRGenerator:
|
|||
# RPC and C functions have no runtime representation.
|
||||
return ll.Constant(llty, ll.Undefined)
|
||||
elif types.is_function(typ):
|
||||
llfun = self.get_function(typ.find(), self.function_map[value])
|
||||
llfun = self.get_function(typ.find(), self.embedding_map.retrieve_function(value))
|
||||
llclosure = ll.Constant(self.llty_of_type(typ), [
|
||||
ll.Constant(llptr, ll.Undefined),
|
||||
llfun
|
||||
|
@ -1416,14 +1417,8 @@ class LLVMIRGenerator:
|
|||
assert False
|
||||
|
||||
def process_Quote(self, insn):
|
||||
if insn.value in self.function_map and types.is_function(insn.type):
|
||||
llfun = self.get_function(insn.type.find(), self.function_map[insn.value])
|
||||
llclosure = ll.Constant(self.llty_of_type(insn.type), ll.Undefined)
|
||||
llclosure = self.llbuilder.insert_value(llclosure, llfun, 1, name=insn.name)
|
||||
return llclosure
|
||||
else:
|
||||
assert self.object_map is not None
|
||||
return self._quote(insn.value, insn.type, lambda: [repr(insn.value)])
|
||||
assert self.embedding_map is not None
|
||||
return self._quote(insn.value, insn.type, lambda: [repr(insn.value)])
|
||||
|
||||
def process_Select(self, insn):
|
||||
return self.llbuilder.select(self.map(insn.condition()),
|
||||
|
|
|
@ -14,7 +14,7 @@ class Comm:
|
|||
def run(self):
|
||||
pass
|
||||
|
||||
def serve(self, object_map, symbolizer):
|
||||
def serve(self, embedding_map, symbolizer, demangler):
|
||||
pass
|
||||
|
||||
def check_ident(self):
|
||||
|
|
|
@ -309,13 +309,13 @@ class CommGeneric:
|
|||
_rpc_sentinel = object()
|
||||
|
||||
# See session.c:{send,receive}_rpc_value and llvm_ir_generator.py:_rpc_tag.
|
||||
def _receive_rpc_value(self, object_map):
|
||||
def _receive_rpc_value(self, embedding_map):
|
||||
tag = chr(self._read_int8())
|
||||
if tag == "\x00":
|
||||
return self._rpc_sentinel
|
||||
elif tag == "t":
|
||||
length = self._read_int8()
|
||||
return tuple(self._receive_rpc_value(object_map) for _ in range(length))
|
||||
return tuple(self._receive_rpc_value(embedding_map) for _ in range(length))
|
||||
elif tag == "n":
|
||||
return None
|
||||
elif tag == "b":
|
||||
|
@ -334,25 +334,25 @@ class CommGeneric:
|
|||
return self._read_string()
|
||||
elif tag == "l":
|
||||
length = self._read_int32()
|
||||
return [self._receive_rpc_value(object_map) for _ in range(length)]
|
||||
return [self._receive_rpc_value(embedding_map) for _ in range(length)]
|
||||
elif tag == "r":
|
||||
start = self._receive_rpc_value(object_map)
|
||||
stop = self._receive_rpc_value(object_map)
|
||||
step = self._receive_rpc_value(object_map)
|
||||
start = self._receive_rpc_value(embedding_map)
|
||||
stop = self._receive_rpc_value(embedding_map)
|
||||
step = self._receive_rpc_value(embedding_map)
|
||||
return range(start, stop, step)
|
||||
elif tag == "k":
|
||||
name = self._read_string()
|
||||
value = self._receive_rpc_value(object_map)
|
||||
value = self._receive_rpc_value(embedding_map)
|
||||
return RPCKeyword(name, value)
|
||||
elif tag == "O":
|
||||
return object_map.retrieve(self._read_int32())
|
||||
return embedding_map.retrieve_object(self._read_int32())
|
||||
else:
|
||||
raise IOError("Unknown RPC value tag: {}".format(repr(tag)))
|
||||
|
||||
def _receive_rpc_args(self, object_map):
|
||||
def _receive_rpc_args(self, embedding_map):
|
||||
args, kwargs = [], {}
|
||||
while True:
|
||||
value = self._receive_rpc_value(object_map)
|
||||
value = self._receive_rpc_value(embedding_map)
|
||||
if value is self._rpc_sentinel:
|
||||
return args, kwargs
|
||||
elif isinstance(value, RPCKeyword):
|
||||
|
@ -440,14 +440,14 @@ class CommGeneric:
|
|||
else:
|
||||
raise IOError("Unknown RPC value tag: {}".format(repr(tag)))
|
||||
|
||||
def _serve_rpc(self, object_map):
|
||||
def _serve_rpc(self, embedding_map):
|
||||
service_id = self._read_int32()
|
||||
if service_id == 0:
|
||||
service = lambda obj, attr, value: setattr(obj, attr, value)
|
||||
else:
|
||||
service = object_map.retrieve(service_id)
|
||||
service = embedding_map.retrieve_object(service_id)
|
||||
|
||||
args, kwargs = self._receive_rpc_args(object_map)
|
||||
args, kwargs = self._receive_rpc_args(embedding_map)
|
||||
return_tags = self._read_bytes()
|
||||
logger.debug("rpc service: [%d]%r %r %r -> %s", service_id, service, args, kwargs, return_tags)
|
||||
|
||||
|
@ -483,7 +483,7 @@ class CommGeneric:
|
|||
hasattr(exn, 'artiq_builtin'):
|
||||
self._write_string("0:{}".format(exn_type.__name__))
|
||||
else:
|
||||
exn_id = object_map.store(exn_type)
|
||||
exn_id = embedding_map.store_object(exn_type)
|
||||
self._write_string("{}:{}.{}".format(exn_id,
|
||||
exn_type.__module__, exn_type.__qualname__))
|
||||
self._write_string(str(exn))
|
||||
|
@ -504,7 +504,7 @@ class CommGeneric:
|
|||
|
||||
self._write_flush()
|
||||
|
||||
def _serve_exception(self, object_map, symbolizer, demangler):
|
||||
def _serve_exception(self, embedding_map, symbolizer, demangler):
|
||||
name = self._read_string()
|
||||
message = self._read_string()
|
||||
params = [self._read_int64() for _ in range(3)]
|
||||
|
@ -523,19 +523,19 @@ class CommGeneric:
|
|||
if core_exn.id == 0:
|
||||
python_exn_type = getattr(exceptions, core_exn.name.split('.')[-1])
|
||||
else:
|
||||
python_exn_type = object_map.retrieve(core_exn.id)
|
||||
python_exn_type = embedding_map.retrieve_object(core_exn.id)
|
||||
|
||||
python_exn = python_exn_type(message.format(*params))
|
||||
python_exn.artiq_core_exception = core_exn
|
||||
raise python_exn
|
||||
|
||||
def serve(self, object_map, symbolizer, demangler):
|
||||
def serve(self, embedding_map, symbolizer, demangler):
|
||||
while True:
|
||||
self._read_header()
|
||||
if self._read_type == _D2HMsgType.RPC_REQUEST:
|
||||
self._serve_rpc(object_map)
|
||||
self._serve_rpc(embedding_map)
|
||||
elif self._read_type == _D2HMsgType.KERNEL_EXCEPTION:
|
||||
self._serve_exception(object_map, symbolizer, demangler)
|
||||
self._serve_exception(embedding_map, symbolizer, demangler)
|
||||
elif self._read_type == _D2HMsgType.WATCHDOG_EXPIRED:
|
||||
raise exceptions.WatchdogExpired
|
||||
elif self._read_type == _D2HMsgType.CLOCK_FAILURE:
|
||||
|
|
|
@ -91,7 +91,7 @@ class Core:
|
|||
library = target.compile_and_link([module])
|
||||
stripped_library = target.strip(library)
|
||||
|
||||
return stitcher.object_map, stripped_library, \
|
||||
return stitcher.embedding_map, stripped_library, \
|
||||
lambda addresses: target.symbolize(library, addresses), \
|
||||
lambda symbols: target.demangle(symbols)
|
||||
except diagnostic.Error as error:
|
||||
|
@ -103,7 +103,7 @@ class Core:
|
|||
nonlocal result
|
||||
result = new_result
|
||||
|
||||
object_map, kernel_library, symbolizer, demangler = \
|
||||
embedding_map, kernel_library, symbolizer, demangler = \
|
||||
self.compile(function, args, kwargs, set_result)
|
||||
|
||||
if self.first_run:
|
||||
|
@ -113,7 +113,7 @@ class Core:
|
|||
|
||||
self.comm.load(kernel_library)
|
||||
self.comm.run()
|
||||
self.comm.serve(object_map, symbolizer, demangler)
|
||||
self.comm.serve(embedding_map, symbolizer, demangler)
|
||||
|
||||
return result
|
||||
|
||||
|
|
|
@ -16,7 +16,7 @@ from artiq.language.environment import EnvExperiment, ProcessArgumentManager
|
|||
from artiq.master.databases import DeviceDB, DatasetDB
|
||||
from artiq.master.worker_db import DeviceManager, DatasetManager
|
||||
from artiq.coredevice.core import CompileError, host_only
|
||||
from artiq.compiler.embedding import ObjectMap
|
||||
from artiq.compiler.embedding import EmbeddingMap
|
||||
from artiq.compiler.targets import OR1KTarget
|
||||
from artiq.tools import *
|
||||
|
||||
|
@ -29,19 +29,19 @@ class StubObject:
|
|||
pass
|
||||
|
||||
|
||||
class StubObjectMap:
|
||||
class StubEmbeddingMap:
|
||||
def __init__(self):
|
||||
stub_object = StubObject()
|
||||
self.forward_map = defaultdict(lambda: stub_object)
|
||||
self.forward_map[1] = lambda _: None # return RPC
|
||||
self.next_id = -1
|
||||
self.object_forward_map = defaultdict(lambda: stub_object)
|
||||
self.object_forward_map[1] = lambda _: None # return RPC
|
||||
self.object_current_id = -1
|
||||
|
||||
def retrieve(self, object_id):
|
||||
return self.forward_map[object_id]
|
||||
def retrieve_object(self, object_id):
|
||||
return self.object_forward_map[object_id]
|
||||
|
||||
def store(self, value):
|
||||
self.forward_map[self.next_id] = value
|
||||
self.next_id -= 1
|
||||
def store_object(self, value):
|
||||
self.object_forward_map[self.object_current_id] = value
|
||||
self.object_current_id -= 1
|
||||
|
||||
|
||||
class FileRunner(EnvExperiment):
|
||||
|
@ -55,7 +55,7 @@ class FileRunner(EnvExperiment):
|
|||
|
||||
self.core.comm.load(kernel_library)
|
||||
self.core.comm.run()
|
||||
self.core.comm.serve(StubObjectMap(),
|
||||
self.core.comm.serve(StubEmbeddingMap(),
|
||||
lambda addresses: self.target.symbolize(kernel_library, addresses))
|
||||
|
||||
|
||||
|
|
Loading…
Reference in New Issue