Allow embedding and RPC sending host objects.

This commit is contained in:
whitequark 2015-08-25 21:56:01 -07:00
parent 526d7c4e46
commit 9b9fa1ab7c
8 changed files with 154 additions and 47 deletions

View File

@ -14,11 +14,30 @@ from . import types, builtins, asttyped, prelude
from .transforms import ASTTypedRewriter, Inferencer, IntMonomorphizer from .transforms import ASTTypedRewriter, Inferencer, IntMonomorphizer
class ObjectMap:
def __init__(self):
self.current_key = 0
self.forward_map = {}
self.reverse_map = {}
def store(self, obj_ref):
obj_id = id(obj_ref)
if obj_id in self.reverse_map:
return self.reverse_map[obj_id]
self.current_key += 1
self.forward_map[self.current_key] = obj_ref
self.reverse_map[obj_id] = self.current_key
return self.current_key
def retrieve(self, obj_key):
return self.forward_map[obj_key]
class ASTSynthesizer: class ASTSynthesizer:
def __init__(self, expanded_from=None): def __init__(self, type_map, expanded_from=None):
self.source = "" self.source = ""
self.source_buffer = source.Buffer(self.source, "<synthesized>") self.source_buffer = source.Buffer(self.source, "<synthesized>")
self.expanded_from = expanded_from self.type_map, self.expanded_from = type_map, expanded_from
def finalize(self): def finalize(self):
self.source_buffer.source = self.source self.source_buffer.source = self.source
@ -63,8 +82,32 @@ class ASTSynthesizer:
begin_loc=begin_loc, end_loc=end_loc, begin_loc=begin_loc, end_loc=end_loc,
loc=begin_loc.join(end_loc)) loc=begin_loc.join(end_loc))
else: else:
raise "no" if isinstance(value, type):
# return asttyped.QuoteT(value=value, type=types.TVar()) typ = value
else:
typ = type(value)
if typ in self.type_map:
instance_type, constructor_type = self.type_map[typ]
else:
instance_type = types.TInstance("{}.{}".format(typ.__module__, typ.__name__))
instance_type.attributes['__objectid__'] = builtins.TInt(types.TValue(32))
constructor_type = types.TConstructor(instance_type)
constructor_type.attributes['__objectid__'] = builtins.TInt(types.TValue(32))
self.type_map[typ] = instance_type, constructor_type
quote_loc = self._add('`')
repr_loc = self._add(repr(value))
unquote_loc = self._add('`')
if isinstance(value, type):
return asttyped.QuoteT(value=value, type=constructor_type,
loc=quote_loc.join(unquote_loc))
else:
return asttyped.QuoteT(value=value, type=instance_type,
loc=quote_loc.join(unquote_loc))
def call(self, function_node, args, kwargs): def call(self, function_node, args, kwargs):
""" """
@ -108,13 +151,14 @@ class ASTSynthesizer:
loc=name_loc.join(end_loc)) loc=name_loc.join(end_loc))
class StitchingASTTypedRewriter(ASTTypedRewriter): class StitchingASTTypedRewriter(ASTTypedRewriter):
def __init__(self, engine, prelude, globals, host_environment, quote_function): def __init__(self, engine, prelude, globals, host_environment, quote_function, type_map):
super().__init__(engine, prelude) super().__init__(engine, prelude)
self.globals = globals self.globals = globals
self.env_stack.append(self.globals) self.env_stack.append(self.globals)
self.host_environment = host_environment self.host_environment = host_environment
self.quote_function = quote_function self.quote_function = quote_function
self.type_map = type_map
def visit_Name(self, node): def visit_Name(self, node):
typ = super()._try_find_name(node.id) typ = super()._try_find_name(node.id)
@ -136,7 +180,7 @@ class StitchingASTTypedRewriter(ASTTypedRewriter):
else: else:
# It's just a value. Quote it. # It's just a value. Quote it.
synthesizer = ASTSynthesizer(expanded_from=node.loc) synthesizer = ASTSynthesizer(expanded_from=node.loc, type_map=self.type_map)
node = synthesizer.quote(value) node = synthesizer.quote(value)
synthesizer.finalize() synthesizer.finalize()
return node return node
@ -160,19 +204,8 @@ class Stitcher:
self.functions = {} self.functions = {}
self.next_rpc = 0 self.object_map = ObjectMap()
self.rpc_map = {} self.type_map = {}
self.inverse_rpc_map = {}
def _map(self, obj):
obj_id = id(obj)
if obj_id in self.inverse_rpc_map:
return self.inverse_rpc_map[obj_id]
self.next_rpc += 1
self.rpc_map[self.next_rpc] = obj
self.inverse_rpc_map[obj_id] = self.next_rpc
return self.next_rpc
def finalize(self): def finalize(self):
inferencer = Inferencer(engine=self.engine) inferencer = Inferencer(engine=self.engine)
@ -229,7 +262,7 @@ class Stitcher:
asttyped_rewriter = StitchingASTTypedRewriter( asttyped_rewriter = StitchingASTTypedRewriter(
engine=self.engine, prelude=self.prelude, engine=self.engine, prelude=self.prelude,
globals=self.globals, host_environment=host_environment, globals=self.globals, host_environment=host_environment,
quote_function=self._quote_function) quote_function=self._quote_function, type_map=self.type_map)
return asttyped_rewriter.visit(function_node) return asttyped_rewriter.visit(function_node)
def _function_loc(self, function): def _function_loc(self, function):
@ -291,7 +324,7 @@ class Stitcher:
# This is tricky, because the default value might not have # This is tricky, because the default value might not have
# a well-defined type in APython. # a well-defined type in APython.
# In this case, we bail out, but mention why we do it. # In this case, we bail out, but mention why we do it.
synthesizer = ASTSynthesizer() synthesizer = ASTSynthesizer(type_map=self.type_map)
ast = synthesizer.quote(param.default) ast = synthesizer.quote(param.default)
synthesizer.finalize() synthesizer.finalize()
@ -365,7 +398,7 @@ class Stitcher:
if syscall is None: if syscall is None:
function_type = types.TRPCFunction(arg_types, optarg_types, ret_type, function_type = types.TRPCFunction(arg_types, optarg_types, ret_type,
service=self._map(function)) service=self.object_map.store(function))
function_name = "rpc${}".format(function_type.service) function_name = "rpc${}".format(function_type.service)
else: else:
function_type = types.TCFunction(arg_types, ret_type, function_type = types.TCFunction(arg_types, ret_type,
@ -409,7 +442,7 @@ class Stitcher:
# We synthesize source code for the initial call so that # We synthesize source code for the initial call so that
# diagnostics would have something meaningful to display to the user. # diagnostics would have something meaningful to display to the user.
synthesizer = ASTSynthesizer() synthesizer = ASTSynthesizer(type_map=self.type_map)
call_node = synthesizer.call(function_node, args, kwargs) call_node = synthesizer.call(function_node, args, kwargs)
synthesizer.finalize() synthesizer.finalize()
self.typedtree.append(call_node) self.typedtree.append(call_node)

View File

@ -901,6 +901,23 @@ class Select(Instruction):
def if_false(self): def if_false(self):
return self.operands[2] return self.operands[2]
class Quote(Instruction):
"""
A quote operation. Returns a host interpreter value as a constant.
:ivar value: (string) operation name
"""
"""
:param value: (string) operation name
"""
def __init__(self, value, typ, name=""):
super().__init__([], typ, name)
self.value = value
def opcode(self):
return "quote({})".format(repr(self.value))
class Branch(Terminator): class Branch(Terminator):
""" """
An unconditional branch instruction. An unconditional branch instruction.

View File

@ -19,6 +19,8 @@ class Source:
else: else:
self.engine = engine self.engine = engine
self.object_map = None
self.name, _ = os.path.splitext(os.path.basename(source_buffer.name)) self.name, _ = os.path.splitext(os.path.basename(source_buffer.name))
asttyped_rewriter = transforms.ASTTypedRewriter(engine=engine, asttyped_rewriter = transforms.ASTTypedRewriter(engine=engine,
@ -42,6 +44,7 @@ class Source:
class Module: class Module:
def __init__(self, src): def __init__(self, src):
self.engine = src.engine self.engine = src.engine
self.object_map = src.object_map
int_monomorphizer = transforms.IntMonomorphizer(engine=self.engine) int_monomorphizer = transforms.IntMonomorphizer(engine=self.engine)
inferencer = transforms.Inferencer(engine=self.engine) inferencer = transforms.Inferencer(engine=self.engine)
@ -65,7 +68,8 @@ class Module:
def build_llvm_ir(self, target): def build_llvm_ir(self, target):
"""Compile the module to LLVM IR for the specified target.""" """Compile the module to LLVM IR for the specified target."""
llvm_ir_generator = transforms.LLVMIRGenerator(engine=self.engine, llvm_ir_generator = transforms.LLVMIRGenerator(engine=self.engine,
module_name=self.name, target=target) module_name=self.name, target=target,
object_map=self.object_map)
return llvm_ir_generator.process(self.artiq_ir) return llvm_ir_generator.process(self.artiq_ir)
def entry_point(self): def entry_point(self):

View File

@ -1474,6 +1474,9 @@ class ARTIQIRGenerator(algorithm.Visitor):
self.current_block = after_invoke self.current_block = after_invoke
return invoke return invoke
def visit_QuoteT(self, node):
return self.append(ir.Quote(node.value, node.type))
def instrument_assert(self, node, value): def instrument_assert(self, node, value):
if self.current_assert_env is not None: if self.current_assert_env is not None:
if isinstance(value, ir.Constant): if isinstance(value, ir.Constant):

View File

@ -159,15 +159,17 @@ class DebugInfoEmitter:
class LLVMIRGenerator: class LLVMIRGenerator:
def __init__(self, engine, module_name, target): def __init__(self, engine, module_name, target, object_map):
self.engine = engine self.engine = engine
self.target = target self.target = target
self.object_map = object_map
self.llcontext = target.llcontext self.llcontext = target.llcontext
self.llmodule = ll.Module(context=self.llcontext, name=module_name) self.llmodule = ll.Module(context=self.llcontext, name=module_name)
self.llmodule.triple = target.triple self.llmodule.triple = target.triple
self.llmodule.data_layout = target.data_layout self.llmodule.data_layout = target.data_layout
self.llfunction = None self.llfunction = None
self.llmap = {} self.llmap = {}
self.llobject_map = {}
self.phis = [] self.phis = []
self.debug_info_emitter = DebugInfoEmitter(self.llmodule) self.debug_info_emitter = DebugInfoEmitter(self.llmodule)
@ -815,6 +817,8 @@ class LLVMIRGenerator:
elif ir.is_option(typ): elif ir.is_option(typ):
return b"o" + self._rpc_tag(typ.params["inner"], return b"o" + self._rpc_tag(typ.params["inner"],
error_handler) error_handler)
elif '__objectid__' in typ.attributes:
return b"O"
else: else:
error_handler(typ) error_handler(typ)
@ -960,6 +964,54 @@ class LLVMIRGenerator:
return self.llbuilder.invoke(llfun, llargs, llnormalblock, llunwindblock, return self.llbuilder.invoke(llfun, llargs, llnormalblock, llunwindblock,
name=insn.name) name=insn.name)
def _quote(self, value, typ, path):
value_id = id(value)
if value_id in self.llobject_map:
return self.llobject_map[value_id]
global_name = ""
llty = self.llty_of_type(typ)
if types.is_constructor(typ) or types.is_instance(typ):
llfields = []
for attr in typ.attributes:
if attr == "__objectid__":
objectid = self.object_map.store(value)
llfields.append(ll.Constant(lli32, objectid))
global_name = "object.{}".format(objectid)
else:
llfields.append(self._quote(getattr(value, attr), typ.attributes[attr],
path + [attr]))
llvalue = ll.Constant.literal_struct(llfields)
elif builtins.is_none(typ):
assert value is None
return self.llconst_of_const(value)
elif builtins.is_bool(typ):
assert value in (True, False)
return self.llconst_of_const(value)
elif builtins.is_int(typ):
assert isinstance(value, int)
return self.llconst_of_const(value)
elif builtins.is_float(typ):
assert isinstance(value, float)
return self.llconst_of_const(value)
elif builtins.is_str(typ):
assert isinstance(value, (str, bytes))
return self.llconst_of_const(value)
else:
assert False
llconst = ll.GlobalVariable(self.llmodule, llvalue.type, global_name)
llconst.initializer = llvalue
llconst.linkage = "private"
self.llobject_map[value_id] = llconst
return llconst
def process_Quote(self, insn):
assert self.object_map is not None
return self._quote(insn.value, insn.type, lambda: [repr(insn.value)])
def process_Select(self, insn): def process_Select(self, insn):
return self.llbuilder.select(self.map(insn.condition()), return self.llbuilder.select(self.map(insn.condition()),
self.map(insn.if_true()), self.map(insn.if_false())) self.map(insn.if_true()), self.map(insn.if_false()))

View File

@ -282,13 +282,13 @@ class CommGeneric:
_rpc_sentinel = object() _rpc_sentinel = object()
# See session.c:{send,receive}_rpc_value and llvm_ir_generator.py:_rpc_tag. # See session.c:{send,receive}_rpc_value and llvm_ir_generator.py:_rpc_tag.
def _receive_rpc_value(self, rpc_map): def _receive_rpc_value(self, object_map):
tag = chr(self._read_int8()) tag = chr(self._read_int8())
if tag == "\x00": if tag == "\x00":
return self._rpc_sentinel return self._rpc_sentinel
elif tag == "t": elif tag == "t":
length = self._read_int8() length = self._read_int8()
return tuple(self._receive_rpc_value(rpc_map) for _ in range(length)) return tuple(self._receive_rpc_value(object_map) for _ in range(length))
elif tag == "n": elif tag == "n":
return None return None
elif tag == "b": elif tag == "b":
@ -307,25 +307,25 @@ class CommGeneric:
return self._read_string() return self._read_string()
elif tag == "l": elif tag == "l":
length = self._read_int32() length = self._read_int32()
return [self._receive_rpc_value(rpc_map) for _ in range(length)] return [self._receive_rpc_value(object_map) for _ in range(length)]
elif tag == "r": elif tag == "r":
start = self._receive_rpc_value(rpc_map) start = self._receive_rpc_value(object_map)
stop = self._receive_rpc_value(rpc_map) stop = self._receive_rpc_value(object_map)
step = self._receive_rpc_value(rpc_map) step = self._receive_rpc_value(object_map)
return range(start, stop, step) return range(start, stop, step)
elif tag == "o": elif tag == "o":
present = self._read_int8() present = self._read_int8()
if present: if present:
return self._receive_rpc_value(rpc_map) return self._receive_rpc_value(object_map)
elif tag == "O": elif tag == "O":
return rpc_map[self._read_int32()] return object_map.retrieve(self._read_int32())
else: else:
raise IOError("Unknown RPC value tag: {}".format(repr(tag))) raise IOError("Unknown RPC value tag: {}".format(repr(tag)))
def _receive_rpc_args(self, rpc_map): def _receive_rpc_args(self, object_map):
args = [] args = []
while True: while True:
value = self._receive_rpc_value(rpc_map) value = self._receive_rpc_value(object_map)
if value is self._rpc_sentinel: if value is self._rpc_sentinel:
return args return args
args.append(value) args.append(value)
@ -410,20 +410,20 @@ class CommGeneric:
else: else:
raise IOError("Unknown RPC value tag: {}".format(repr(tag))) raise IOError("Unknown RPC value tag: {}".format(repr(tag)))
def _serve_rpc(self, rpc_map): def _serve_rpc(self, object_map):
service = self._read_int32() service = self._read_int32()
args = self._receive_rpc_args(rpc_map) args = self._receive_rpc_args(object_map)
return_tags = self._read_bytes() return_tags = self._read_bytes()
logger.debug("rpc service: %d %r -> %s", service, args, return_tags) logger.debug("rpc service: %d %r -> %s", service, args, return_tags)
try: try:
result = rpc_map[service](*args) result = object_map.retrieve(service)(*args)
logger.debug("rpc service: %d %r == %r", service, args, result) logger.debug("rpc service: %d %r == %r", service, args, result)
self._write_header(_H2DMsgType.RPC_REPLY) self._write_header(_H2DMsgType.RPC_REPLY)
self._write_bytes(return_tags) self._write_bytes(return_tags)
self._send_rpc_value(bytearray(return_tags), result, result, self._send_rpc_value(bytearray(return_tags), result, result,
rpc_map[service]) object_map.retrieve(service))
self._write_flush() self._write_flush()
except core_language.ARTIQException as exn: except core_language.ARTIQException as exn:
logger.debug("rpc service: %d %r ! %r", service, args, exn) logger.debug("rpc service: %d %r ! %r", service, args, exn)
@ -473,11 +473,11 @@ class CommGeneric:
[(filename, line, column, function, None)] [(filename, line, column, function, None)]
raise core_language.ARTIQException(name, message, params, traceback) raise core_language.ARTIQException(name, message, params, traceback)
def serve(self, rpc_map, symbolizer): def serve(self, object_map, symbolizer):
while True: while True:
self._read_header() self._read_header()
if self._read_type == _D2HMsgType.RPC_REQUEST: if self._read_type == _D2HMsgType.RPC_REQUEST:
self._serve_rpc(rpc_map) self._serve_rpc(object_map)
elif self._read_type == _D2HMsgType.KERNEL_EXCEPTION: elif self._read_type == _D2HMsgType.KERNEL_EXCEPTION:
self._serve_exception(symbolizer) self._serve_exception(symbolizer)
else: else:

View File

@ -45,14 +45,14 @@ class Core:
library = target.compile_and_link([module]) library = target.compile_and_link([module])
stripped_library = target.strip(library) stripped_library = target.strip(library)
return stitcher.rpc_map, stripped_library, \ return stitcher.object_map, stripped_library, \
lambda addresses: target.symbolize(library, addresses) lambda addresses: target.symbolize(library, addresses)
except diagnostic.Error as error: except diagnostic.Error as error:
print("\n".join(error.diagnostic.render(colored=True)), file=sys.stderr) print("\n".join(error.diagnostic.render(colored=True)), file=sys.stderr)
raise CompileError() from error raise CompileError() from error
def run(self, function, args, kwargs): def run(self, function, args, kwargs):
rpc_map, kernel_library, symbolizer = self.compile(function, args, kwargs) object_map, kernel_library, symbolizer = self.compile(function, args, kwargs)
if self.first_run: if self.first_run:
self.comm.check_ident() self.comm.check_ident()
@ -61,7 +61,7 @@ class Core:
self.comm.load(kernel_library) self.comm.load(kernel_library)
self.comm.run() self.comm.run()
self.comm.serve(rpc_map, symbolizer) self.comm.serve(object_map, symbolizer)
@kernel @kernel
def get_rtio_counter_mu(self): def get_rtio_counter_mu(self):

View File

@ -816,9 +816,7 @@ static int send_rpc_value(const char **tag, void **value)
case 'O': { // host object case 'O': { // host object
struct { uint32_t id; } **object = *value; struct { uint32_t id; } **object = *value;
return out_packet_int32((*object)->id);
if(!out_packet_int32((*object)->id))
return 0;
} }
default: default: