compiler: support subkernels

This commit is contained in:
mwojcik 2023-10-05 14:35:50 +08:00 committed by Sébastien Bourdeauducq
parent 1a0fc317df
commit 0a750c77e8
15 changed files with 699 additions and 112 deletions

View File

@ -21,13 +21,19 @@ class scoped(object):
set of variables resolved as globals set of variables resolved as globals
""" """
class remote(object):
"""
:ivar remote_fn: (bool) whether function is ran on a remote device,
meaning arguments are received remotely and return is sent remotely
"""
# Typed versions of untyped nodes # Typed versions of untyped nodes
class argT(ast.arg, commontyped): class argT(ast.arg, commontyped):
pass pass
class ClassDefT(ast.ClassDef): class ClassDefT(ast.ClassDef):
_types = ("constructor_type",) _types = ("constructor_type",)
class FunctionDefT(ast.FunctionDef, scoped): class FunctionDefT(ast.FunctionDef, scoped, remote):
_types = ("signature_type",) _types = ("signature_type",)
class QuotedFunctionDefT(FunctionDefT): class QuotedFunctionDefT(FunctionDefT):
""" """
@ -58,7 +64,7 @@ class BinOpT(ast.BinOp, commontyped):
pass pass
class BoolOpT(ast.BoolOp, commontyped): class BoolOpT(ast.BoolOp, commontyped):
pass pass
class CallT(ast.Call, commontyped): class CallT(ast.Call, commontyped, remote):
""" """
:ivar iodelay: (:class:`iodelay.Expr`) :ivar iodelay: (:class:`iodelay.Expr`)
:ivar arg_exprs: (dict of str to :class:`iodelay.Expr`) :ivar arg_exprs: (dict of str to :class:`iodelay.Expr`)

View File

@ -38,6 +38,9 @@ class TInt(types.TMono):
def one(): def one():
return 1 return 1
def TInt8():
return TInt(types.TValue(8))
def TInt32(): def TInt32():
return TInt(types.TValue(32)) return TInt(types.TValue(32))
@ -244,6 +247,12 @@ def fn_at_mu():
def fn_rtio_log(): def fn_rtio_log():
return types.TBuiltinFunction("rtio_log") return types.TBuiltinFunction("rtio_log")
def fn_subkernel_await():
return types.TBuiltinFunction("subkernel_await")
def fn_subkernel_preload():
return types.TBuiltinFunction("subkernel_preload")
# Accessors # Accessors
def is_none(typ): def is_none(typ):
@ -326,7 +335,7 @@ def get_iterable_elt(typ):
# n-dimensional arrays, rather than the n-1 dimensional result of iterating over # n-dimensional arrays, rather than the n-1 dimensional result of iterating over
# the first axis, which makes the name a bit misleading. # the first axis, which makes the name a bit misleading.
if is_str(typ) or is_bytes(typ) or is_bytearray(typ): if is_str(typ) or is_bytes(typ) or is_bytearray(typ):
return TInt(types.TValue(8)) return TInt8()
elif types._is_pointer(typ) or is_iterable(typ): elif types._is_pointer(typ) or is_iterable(typ):
return typ.find()["elt"].find() return typ.find()["elt"].find()
else: else:
@ -342,5 +351,5 @@ def is_allocated(typ):
is_float(typ) or is_range(typ) or is_float(typ) or is_range(typ) or
types._is_pointer(typ) or types.is_function(typ) or types._is_pointer(typ) or types.is_function(typ) or
types.is_external_function(typ) or types.is_rpc(typ) or types.is_external_function(typ) or types.is_rpc(typ) or
types.is_method(typ) or types.is_tuple(typ) or types.is_subkernel(typ) or types.is_method(typ) or
types.is_value(typ)) types.is_tuple(typ) or types.is_value(typ))

View File

@ -74,7 +74,9 @@ class EmbeddingMap:
"CacheError", "CacheError",
"SPIError", "SPIError",
"0:ZeroDivisionError", "0:ZeroDivisionError",
"0:IndexError"]) "0:IndexError",
"UnwrapNoneError",
"SubkernelError"])
def preallocate_runtime_exception_names(self, names): def preallocate_runtime_exception_names(self, names):
for i, name in enumerate(names): for i, name in enumerate(names):
@ -183,7 +185,15 @@ class EmbeddingMap:
obj_typ, _ = self.type_map[type(obj_ref)] obj_typ, _ = self.type_map[type(obj_ref)]
yield obj_id, obj_ref, obj_typ yield obj_id, obj_ref, obj_typ
def has_rpc(self): def subkernels(self):
subkernels = {}
for k, v in self.object_forward_map.items():
if hasattr(v, "artiq_embedded"):
if v.artiq_embedded.destination is not None:
subkernels[k] = v
return subkernels
def has_rpc_or_subkernel(self):
return any(filter(lambda x: inspect.isfunction(x) or inspect.ismethod(x), return any(filter(lambda x: inspect.isfunction(x) or inspect.ismethod(x),
self.object_forward_map.values())) self.object_forward_map.values()))
@ -469,7 +479,7 @@ class ASTSynthesizer:
return asttyped.QuoteT(value=value, type=instance_type, return asttyped.QuoteT(value=value, type=instance_type,
loc=loc) loc=loc)
def call(self, callee, args, kwargs, callback=None): def call(self, callee, args, kwargs, callback=None, remote_fn=False):
""" """
Construct an AST fragment calling a function specified by Construct an AST fragment calling a function specified by
an AST node `function_node`, with given arguments. an AST node `function_node`, with given arguments.
@ -513,7 +523,7 @@ class ASTSynthesizer:
starargs=None, kwargs=None, starargs=None, kwargs=None,
type=types.TVar(), iodelay=None, arg_exprs={}, type=types.TVar(), iodelay=None, arg_exprs={},
begin_loc=begin_loc, end_loc=end_loc, star_loc=None, dstar_loc=None, begin_loc=begin_loc, end_loc=end_loc, star_loc=None, dstar_loc=None,
loc=callee_node.loc.join(end_loc)) loc=callee_node.loc.join(end_loc), remote_fn=remote_fn)
if callback is not None: if callback is not None:
node = asttyped.CallT( node = asttyped.CallT(
@ -548,7 +558,7 @@ class StitchingASTTypedRewriter(ASTTypedRewriter):
arg=node.arg, annotation=None, arg=node.arg, annotation=None,
arg_loc=node.arg_loc, colon_loc=node.colon_loc, loc=node.loc) arg_loc=node.arg_loc, colon_loc=node.colon_loc, loc=node.loc)
def visit_quoted_function(self, node, function): def visit_quoted_function(self, node, function, remote_fn):
extractor = LocalExtractor(env_stack=self.env_stack, engine=self.engine) extractor = LocalExtractor(env_stack=self.env_stack, engine=self.engine)
extractor.visit(node) extractor.visit(node)
@ -569,7 +579,7 @@ class StitchingASTTypedRewriter(ASTTypedRewriter):
body=node.body, decorator_list=node.decorator_list, body=node.body, decorator_list=node.decorator_list,
keyword_loc=node.keyword_loc, name_loc=node.name_loc, keyword_loc=node.keyword_loc, name_loc=node.name_loc,
arrow_loc=node.arrow_loc, colon_loc=node.colon_loc, at_locs=node.at_locs, arrow_loc=node.arrow_loc, colon_loc=node.colon_loc, at_locs=node.at_locs,
loc=node.loc) loc=node.loc, remote_fn=remote_fn)
try: try:
self.env_stack.append(node.typing_env) self.env_stack.append(node.typing_env)
@ -777,7 +787,7 @@ class TypedtreeHasher(algorithm.Visitor):
return hash(tuple(freeze(getattr(node, field_name)) for field_name in fields)) return hash(tuple(freeze(getattr(node, field_name)) for field_name in fields))
class Stitcher: class Stitcher:
def __init__(self, core, dmgr, engine=None, print_as_rpc=True): def __init__(self, core, dmgr, engine=None, print_as_rpc=True, destination=0, subkernel_arg_types=[]):
self.core = core self.core = core
self.dmgr = dmgr self.dmgr = dmgr
if engine is None: if engine is None:
@ -803,11 +813,19 @@ class Stitcher:
self.value_map = defaultdict(lambda: []) self.value_map = defaultdict(lambda: [])
self.definitely_changed = False self.definitely_changed = False
self.destination = destination
self.first_call = True
# for non-annotated subkernels:
# main kernel inferencer output with types of arguments
self.subkernel_arg_types = subkernel_arg_types
def stitch_call(self, function, args, kwargs, callback=None): def stitch_call(self, function, args, kwargs, callback=None):
# We synthesize source code for the initial call so that # We synthesize source code for the initial call so that
# diagnostics would have something meaningful to display to the user. # diagnostics would have something meaningful to display to the user.
synthesizer = self._synthesizer(self._function_loc(function.artiq_embedded.function)) synthesizer = self._synthesizer(self._function_loc(function.artiq_embedded.function))
call_node = synthesizer.call(function, args, kwargs, callback) # first call of a subkernel will get its arguments from remote (DRTIO)
remote_fn = self.destination != 0
call_node = synthesizer.call(function, args, kwargs, callback, remote_fn=remote_fn)
synthesizer.finalize() synthesizer.finalize()
self.typedtree.append(call_node) self.typedtree.append(call_node)
@ -919,6 +937,10 @@ class Stitcher:
return [diagnostic.Diagnostic("note", return [diagnostic.Diagnostic("note",
"in kernel function here", {}, "in kernel function here", {},
call_loc)] call_loc)]
elif fn_kind == 'subkernel':
return [diagnostic.Diagnostic("note",
"in subkernel call here", {},
call_loc)]
else: else:
assert False assert False
else: else:
@ -938,7 +960,7 @@ class Stitcher:
self._function_loc(function), self._function_loc(function),
notes=self._call_site_note(loc, fn_kind)) notes=self._call_site_note(loc, fn_kind))
self.engine.process(diag) self.engine.process(diag)
elif fn_kind == 'rpc' and param.default is not inspect.Parameter.empty: elif fn_kind == 'rpc' or fn_kind == 'subkernel' and param.default is not inspect.Parameter.empty:
notes = [] notes = []
notes.append(diagnostic.Diagnostic("note", notes.append(diagnostic.Diagnostic("note",
"expanded from here while trying to infer a type for an" "expanded from here while trying to infer a type for an"
@ -957,11 +979,18 @@ class Stitcher:
Inferencer(engine=self.engine).visit(ast) Inferencer(engine=self.engine).visit(ast)
IntMonomorphizer(engine=self.engine).visit(ast) IntMonomorphizer(engine=self.engine).visit(ast)
return ast.type return ast.type
else: elif fn_kind == 'kernel' and self.first_call and self.destination != 0:
# Let the rest of the program decide. # subkernels do not have access to the main kernel code to infer
return types.TVar() # arg types - so these are cached and passed onto subkernel
# compilation, to avoid having to annotate them fully
for name, typ in self.subkernel_arg_types:
if param.name == name:
return typ
def _quote_embedded_function(self, function, flags): # Let the rest of the program decide.
return types.TVar()
def _quote_embedded_function(self, function, flags, remote_fn=False):
# we are now parsing new functions... definitely changed the type # we are now parsing new functions... definitely changed the type
self.definitely_changed = True self.definitely_changed = True
@ -1060,7 +1089,7 @@ class Stitcher:
engine=self.engine, prelude=self.prelude, engine=self.engine, prelude=self.prelude,
globals=self.globals, host_environment=host_environment, globals=self.globals, host_environment=host_environment,
quote=self._quote) quote=self._quote)
function_node = asttyped_rewriter.visit_quoted_function(function_node, embedded_function) function_node = asttyped_rewriter.visit_quoted_function(function_node, embedded_function, remote_fn)
function_node.flags = flags function_node.flags = flags
# Add it into our typedtree so that it gets inferenced and codegen'd. # Add it into our typedtree so that it gets inferenced and codegen'd.
@ -1174,7 +1203,6 @@ class Stitcher:
signature = inspect.signature(function) signature = inspect.signature(function)
arg_types = OrderedDict() arg_types = OrderedDict()
optarg_types = OrderedDict()
for param in signature.parameters.values(): for param in signature.parameters.values():
if param.kind != inspect.Parameter.POSITIONAL_OR_KEYWORD: if param.kind != inspect.Parameter.POSITIONAL_OR_KEYWORD:
diag = diagnostic.Diagnostic("error", diag = diagnostic.Diagnostic("error",
@ -1212,6 +1240,40 @@ class Stitcher:
self.functions[function] = function_type self.functions[function] = function_type
return function_type return function_type
def _quote_subkernel(self, function, loc):
if isinstance(function, SpecializedFunction):
host_function = function.host_function
else:
host_function = function
ret_type = builtins.TNone()
signature = inspect.signature(host_function)
if signature.return_annotation is not inspect.Signature.empty:
ret_type = self._extract_annot(host_function, signature.return_annotation,
"return type", loc, fn_kind='subkernel')
arg_types = OrderedDict()
optarg_types = OrderedDict()
for param in signature.parameters.values():
if param.kind != inspect.Parameter.POSITIONAL_OR_KEYWORD:
diag = diagnostic.Diagnostic("error",
"subkernels must only use positional arguments; '{argument}' isn't",
{"argument": param.name},
self._function_loc(function),
notes=self._call_site_note(loc, fn_kind='subkernel'))
self.engine.process(diag)
arg_type = self._type_of_param(function, loc, param, fn_kind='subkernel')
if param.default is inspect.Parameter.empty:
arg_types[param.name] = arg_type
else:
optarg_types[param.name] = arg_type
function_type = types.TSubkernel(arg_types, optarg_types, ret_type,
sid=self.embedding_map.store_object(host_function),
destination=host_function.artiq_embedded.destination)
self.functions[function] = function_type
return function_type
def _quote_rpc(self, function, loc): def _quote_rpc(self, function, loc):
if isinstance(function, SpecializedFunction): if isinstance(function, SpecializedFunction):
host_function = function.host_function host_function = function.host_function
@ -1271,8 +1333,18 @@ class Stitcher:
(host_function.artiq_embedded.core_name is None and (host_function.artiq_embedded.core_name is None and
host_function.artiq_embedded.portable is False and host_function.artiq_embedded.portable is False and
host_function.artiq_embedded.syscall is None and host_function.artiq_embedded.syscall is None and
host_function.artiq_embedded.destination is None and
host_function.artiq_embedded.forbidden is False): host_function.artiq_embedded.forbidden is False):
self._quote_rpc(function, loc) self._quote_rpc(function, loc)
elif host_function.artiq_embedded.destination is not None and \
host_function.artiq_embedded.destination != self.destination:
# treat subkernels as kernels if running on the same device
if not 0 < host_function.artiq_embedded.destination <= 255:
diag = diagnostic.Diagnostic("error",
"subkernel destination must be between 1 and 255 (inclusive)", {},
self._function_loc(host_function))
self.engine.process(diag)
self._quote_subkernel(function, loc)
elif host_function.artiq_embedded.function is not None: elif host_function.artiq_embedded.function is not None:
if host_function.__name__ == "<lambda>": if host_function.__name__ == "<lambda>":
note = diagnostic.Diagnostic("note", note = diagnostic.Diagnostic("note",
@ -1296,8 +1368,13 @@ class Stitcher:
notes=[note]) notes=[note])
self.engine.process(diag) self.engine.process(diag)
destination = host_function.artiq_embedded.destination
# remote_fn only for first call in subkernels
remote_fn = destination is not None and self.first_call
self._quote_embedded_function(function, self._quote_embedded_function(function,
flags=host_function.artiq_embedded.flags) flags=host_function.artiq_embedded.flags,
remote_fn=remote_fn)
self.first_call = False
elif host_function.artiq_embedded.syscall is not None: elif host_function.artiq_embedded.syscall is not None:
# Insert a storage-less global whose type instructs the compiler # Insert a storage-less global whose type instructs the compiler
# to perform a system call instead of a regular call. # to perform a system call instead of a regular call.

View File

@ -706,6 +706,64 @@ class SetLocal(Instruction):
def value(self): def value(self):
return self.operands[1] return self.operands[1]
class GetArgFromRemote(Instruction):
"""
An instruction that receives function arguments from remote
(ie. subkernel in DRTIO context)
:ivar arg_name: (string) argument name
:ivar arg_type: argument type
"""
"""
:param arg_name: (string) argument name
:param arg_type: argument type
"""
def __init__(self, arg_name, arg_type, name=""):
assert isinstance(arg_name, str)
super().__init__([], arg_type, name)
self.arg_name = arg_name
self.arg_type = arg_type
def copy(self, mapper):
self_copy = super().copy(mapper)
self_copy.arg_name = self.arg_name
self_copy.arg_type = self.arg_type
return self_copy
def opcode(self):
return "getargfromremote({})".format(repr(self.arg_name))
class GetOptArgFromRemote(GetArgFromRemote):
"""
An instruction that may or may not retrieve an optional function argument
from remote, depending on number of values received by firmware.
:ivar rcv_count: number of received values,
determined by firmware
:ivar index: (integer) index of the current argument,
in reference to remote arguments
"""
"""
:param rcv_count: number of received valuese
:param index: (integer) index of the current argument,
in reference to remote arguments
"""
def __init__(self, arg_name, arg_type, rcv_count, index, name=""):
super().__init__(arg_name, arg_type, name)
self.rcv_count = rcv_count
self.index = index
def copy(self, mapper):
self_copy = super().copy(mapper)
self_copy.rcv_count = self.rcv_count
self_copy.index = self.index
return self_copy
def opcode(self):
return "getoptargfromremote({})".format(repr(self.arg_name))
class GetAttr(Instruction): class GetAttr(Instruction):
""" """
An intruction that loads an attribute from an object, An intruction that loads an attribute from an object,
@ -728,7 +786,7 @@ class GetAttr(Instruction):
typ = obj.type.attributes[attr] typ = obj.type.attributes[attr]
else: else:
typ = obj.type.constructor.attributes[attr] typ = obj.type.constructor.attributes[attr]
if types.is_function(typ) or types.is_rpc(typ): if types.is_function(typ) or types.is_rpc(typ) or types.is_subkernel(typ):
typ = types.TMethod(obj.type, typ) typ = types.TMethod(obj.type, typ)
super().__init__([obj], typ, name) super().__init__([obj], typ, name)
self.attr = attr self.attr = attr
@ -1190,14 +1248,18 @@ class IndirectBranch(Terminator):
class Return(Terminator): class Return(Terminator):
""" """
A return instruction. A return instruction.
:param remote_return: (bool)
marks a return in subkernel context,
where the return value is sent back through DRTIO
""" """
""" """
:param value: (:class:`Value`) return value :param value: (:class:`Value`) return value
""" """
def __init__(self, value, name=""): def __init__(self, value, remote_return=False, name=""):
assert isinstance(value, Value) assert isinstance(value, Value)
super().__init__([value], builtins.TNone(), name) super().__init__([value], builtins.TNone(), name)
self.remote_return = remote_return
def opcode(self): def opcode(self):
return "return" return "return"

View File

@ -84,6 +84,8 @@ class Module:
constant_hoister.process(self.artiq_ir) constant_hoister.process(self.artiq_ir)
if remarks: if remarks:
invariant_detection.process(self.artiq_ir) invariant_detection.process(self.artiq_ir)
# for subkernels: main kernel inferencer output, to be passed to further compilations
self.subkernel_arg_types = inferencer.subkernel_arg_types
def build_llvm_ir(self, target): def build_llvm_ir(self, target):
"""Compile the module to LLVM IR for the specified target.""" """Compile the module to LLVM IR for the specified target."""

View File

@ -37,6 +37,7 @@ def globals():
# ARTIQ decorators # ARTIQ decorators
"kernel": builtins.fn_kernel(), "kernel": builtins.fn_kernel(),
"subkernel": builtins.fn_kernel(),
"portable": builtins.fn_kernel(), "portable": builtins.fn_kernel(),
"rpc": builtins.fn_kernel(), "rpc": builtins.fn_kernel(),
@ -54,4 +55,8 @@ def globals():
# ARTIQ utility functions # ARTIQ utility functions
"rtio_log": builtins.fn_rtio_log(), "rtio_log": builtins.fn_rtio_log(),
"core_log": builtins.fn_print(), "core_log": builtins.fn_print(),
# ARTIQ subkernel utility functions
"subkernel_await": builtins.fn_subkernel_await(),
"subkernel_preload": builtins.fn_subkernel_preload(),
} }

View File

@ -94,8 +94,9 @@ class Target:
tool_symbolizer = "llvm-symbolizer" tool_symbolizer = "llvm-symbolizer"
tool_cxxfilt = "llvm-cxxfilt" tool_cxxfilt = "llvm-cxxfilt"
def __init__(self): def __init__(self, subkernel_id=None):
self.llcontext = ll.Context() self.llcontext = ll.Context()
self.subkernel_id = subkernel_id
def target_machine(self): def target_machine(self):
lltarget = llvm.Target.from_triple(self.triple) lltarget = llvm.Target.from_triple(self.triple)
@ -148,7 +149,8 @@ class Target:
ir.BasicBlock._dump_loc = False ir.BasicBlock._dump_loc = False
type_printer = types.TypePrinter() type_printer = types.TypePrinter()
_dump(os.getenv("ARTIQ_DUMP_IR"), "ARTIQ IR", ".txt", suffix = "_subkernel_{}".format(self.subkernel_id) if self.subkernel_id is not None else ""
_dump(os.getenv("ARTIQ_DUMP_IR"), "ARTIQ IR", suffix + ".txt",
lambda: "\n".join(fn.as_entity(type_printer) for fn in module.artiq_ir)) lambda: "\n".join(fn.as_entity(type_printer) for fn in module.artiq_ir))
llmod = module.build_llvm_ir(self) llmod = module.build_llvm_ir(self)
@ -160,12 +162,12 @@ class Target:
_dump("", "LLVM IR (broken)", ".ll", lambda: str(llmod)) _dump("", "LLVM IR (broken)", ".ll", lambda: str(llmod))
raise raise
_dump(os.getenv("ARTIQ_DUMP_UNOPT_LLVM"), "LLVM IR (generated)", "_unopt.ll", _dump(os.getenv("ARTIQ_DUMP_UNOPT_LLVM"), "LLVM IR (generated)", suffix + "_unopt.ll",
lambda: str(llparsedmod)) lambda: str(llparsedmod))
self.optimize(llparsedmod) self.optimize(llparsedmod)
_dump(os.getenv("ARTIQ_DUMP_LLVM"), "LLVM IR (optimized)", ".ll", _dump(os.getenv("ARTIQ_DUMP_LLVM"), "LLVM IR (optimized)", suffix + ".ll",
lambda: str(llparsedmod)) lambda: str(llparsedmod))
return llparsedmod return llparsedmod

View File

@ -108,6 +108,7 @@ class ARTIQIRGenerator(algorithm.Visitor):
self.current_args = None self.current_args = None
self.current_assign = None self.current_assign = None
self.current_exception = None self.current_exception = None
self.current_remote_fn = False
self.break_target = None self.break_target = None
self.continue_target = None self.continue_target = None
self.return_target = None self.return_target = None
@ -211,7 +212,8 @@ class ARTIQIRGenerator(algorithm.Visitor):
old_priv_env, self.current_private_env = self.current_private_env, priv_env old_priv_env, self.current_private_env = self.current_private_env, priv_env
self.generic_visit(node) self.generic_visit(node)
self.terminate(ir.Return(ir.Constant(None, builtins.TNone()))) self.terminate(ir.Return(ir.Constant(None, builtins.TNone()),
remote_return=self.current_remote_fn))
return self.functions return self.functions
finally: finally:
@ -294,6 +296,8 @@ class ARTIQIRGenerator(algorithm.Visitor):
old_block, self.current_block = self.current_block, entry old_block, self.current_block = self.current_block, entry
old_globals, self.current_globals = self.current_globals, node.globals_in_scope old_globals, self.current_globals = self.current_globals, node.globals_in_scope
old_remote_fn = self.current_remote_fn
self.current_remote_fn = getattr(node, "remote_fn", False)
env_without_globals = \ env_without_globals = \
{var: node.typing_env[var] {var: node.typing_env[var]
@ -326,7 +330,8 @@ class ARTIQIRGenerator(algorithm.Visitor):
self.terminate(ir.Return(result)) self.terminate(ir.Return(result))
elif builtins.is_none(typ.ret): elif builtins.is_none(typ.ret):
if not self.current_block.is_terminated(): if not self.current_block.is_terminated():
self.current_block.append(ir.Return(ir.Constant(None, builtins.TNone()))) self.current_block.append(ir.Return(ir.Constant(None, builtins.TNone()),
remote_return=self.current_remote_fn))
else: else:
if not self.current_block.is_terminated(): if not self.current_block.is_terminated():
if len(self.current_block.predecessors()) != 0: if len(self.current_block.predecessors()) != 0:
@ -345,6 +350,7 @@ class ARTIQIRGenerator(algorithm.Visitor):
self.current_block = old_block self.current_block = old_block
self.current_globals = old_globals self.current_globals = old_globals
self.current_env = old_env self.current_env = old_env
self.current_remote_fn = old_remote_fn
if not is_lambda: if not is_lambda:
self.current_private_env = old_priv_env self.current_private_env = old_priv_env
@ -367,7 +373,8 @@ class ARTIQIRGenerator(algorithm.Visitor):
return_value = self.visit(node.value) return_value = self.visit(node.value)
if self.return_target is None: if self.return_target is None:
self.append(ir.Return(return_value)) self.append(ir.Return(return_value,
remote_return=self.current_remote_fn))
else: else:
self.append(ir.SetLocal(self.current_private_env, "$return", return_value)) self.append(ir.SetLocal(self.current_private_env, "$return", return_value))
self.append(ir.Branch(self.return_target)) self.append(ir.Branch(self.return_target))
@ -2524,6 +2531,33 @@ class ARTIQIRGenerator(algorithm.Visitor):
or types.is_builtin(typ, "at_mu"): or types.is_builtin(typ, "at_mu"):
return self.append(ir.Builtin(typ.name, return self.append(ir.Builtin(typ.name,
[self.visit(arg) for arg in node.args], node.type)) [self.visit(arg) for arg in node.args], node.type))
elif types.is_builtin(typ, "subkernel_await"):
if len(node.args) == 2 and len(node.keywords) == 0:
fn = node.args[0].type
timeout = self.visit(node.args[1])
elif len(node.args) == 1 and len(node.keywords) == 0:
fn = node.args[0].type
timeout = ir.Constant(10_000, builtins.TInt64())
else:
assert False
if types.is_method(fn):
fn = types.get_method_function(fn)
sid = ir.Constant(fn.sid, builtins.TInt32())
if not builtins.is_none(fn.ret):
ret = self.append(ir.Builtin("subkernel_retrieve_return", [sid, timeout], fn.ret))
else:
ret = ir.Constant(None, builtins.TNone())
self.append(ir.Builtin("subkernel_await_finish", [sid, timeout], builtins.TNone()))
return ret
elif types.is_builtin(typ, "subkernel_preload"):
if len(node.args) == 1 and len(node.keywords) == 0:
fn = node.args[0].type
else:
assert False
if types.is_method(fn):
fn = types.get_method_function(fn)
sid = ir.Constant(fn.sid, builtins.TInt32())
return self.append(ir.Builtin("subkernel_preload", [sid], builtins.TNone()))
elif types.is_exn_constructor(typ): elif types.is_exn_constructor(typ):
return self.alloc_exn(node.type, *[self.visit(arg_node) for arg_node in node.args]) return self.alloc_exn(node.type, *[self.visit(arg_node) for arg_node in node.args])
elif types.is_constructor(typ): elif types.is_constructor(typ):
@ -2535,8 +2569,8 @@ class ARTIQIRGenerator(algorithm.Visitor):
node.loc) node.loc)
self.engine.process(diag) self.engine.process(diag)
def _user_call(self, callee, positional, keywords, arg_exprs={}): def _user_call(self, callee, positional, keywords, arg_exprs={}, remote_fn=False):
if types.is_function(callee.type) or types.is_rpc(callee.type): if types.is_function(callee.type) or types.is_rpc(callee.type) or types.is_subkernel(callee.type):
func = callee func = callee
self_arg = None self_arg = None
fn_typ = callee.type fn_typ = callee.type
@ -2551,16 +2585,50 @@ class ARTIQIRGenerator(algorithm.Visitor):
else: else:
assert False assert False
if types.is_rpc(fn_typ): if types.is_rpc(fn_typ) or types.is_subkernel(fn_typ):
if self_arg is None: if self_arg is None or types.is_subkernel(fn_typ):
# self is not passed to subkernels by remote
args = positional args = positional
else: elif self_arg is not None:
args = [self_arg] + positional args = [self_arg] + positional
for keyword in keywords: for keyword in keywords:
arg = keywords[keyword] arg = keywords[keyword]
args.append(self.append(ir.Alloc([ir.Constant(keyword, builtins.TStr()), arg], args.append(self.append(ir.Alloc([ir.Constant(keyword, builtins.TStr()), arg],
ir.TKeyword(arg.type)))) ir.TKeyword(arg.type))))
elif remote_fn:
assert self_arg is None
assert len(fn_typ.args) >= len(positional)
assert len(keywords) == 0 # no keyword support
args = [None] * fn_typ.arity()
index = 0
# fill in first available args
for arg in positional:
args[index] = arg
index += 1
# remaining args are received through DRTIO
if index < len(args):
# min/max args received remotely (minus already filled)
offset = index
min_args = ir.Constant(len(fn_typ.args)-offset, builtins.TInt8())
max_args = ir.Constant(fn_typ.arity()-offset, builtins.TInt8())
rcvd_count = self.append(ir.Builtin("subkernel_await_args", [min_args, max_args], builtins.TNone()))
arg_types = list(fn_typ.args.items())[offset:]
# obligatory arguments
for arg_name, arg_type in arg_types:
args[index] = self.append(ir.GetArgFromRemote(arg_name, arg_type,
name="ARG.{}".format(arg_name)))
index += 1
# optional arguments
for optarg_name, optarg_type in fn_typ.optargs.items():
idx = ir.Constant(index-offset, builtins.TInt8())
args[index] = \
self.append(ir.GetOptArgFromRemote(optarg_name, optarg_type, rcvd_count, idx))
index += 1
else: else:
args = [None] * (len(fn_typ.args) + len(fn_typ.optargs)) args = [None] * (len(fn_typ.args) + len(fn_typ.optargs))
@ -2646,7 +2714,8 @@ class ARTIQIRGenerator(algorithm.Visitor):
else: else:
assert False, "Broadcasting for {} arguments not implemented".format(len) assert False, "Broadcasting for {} arguments not implemented".format(len)
else: else:
insn = self._user_call(callee, args, keywords, node.arg_exprs) remote_fn = getattr(node, "remote_fn", False)
insn = self._user_call(callee, args, keywords, node.arg_exprs, remote_fn)
if isinstance(node.func, asttyped.AttributeT): if isinstance(node.func, asttyped.AttributeT):
attr_node = node.func attr_node = node.func
self.method_map[(attr_node.value.type.find(), self.method_map[(attr_node.value.type.find(),

View File

@ -238,7 +238,7 @@ class ASTTypedRewriter(algorithm.Transformer):
body=node.body, decorator_list=node.decorator_list, body=node.body, decorator_list=node.decorator_list,
keyword_loc=node.keyword_loc, name_loc=node.name_loc, keyword_loc=node.keyword_loc, name_loc=node.name_loc,
arrow_loc=node.arrow_loc, colon_loc=node.colon_loc, at_locs=node.at_locs, arrow_loc=node.arrow_loc, colon_loc=node.colon_loc, at_locs=node.at_locs,
loc=node.loc) loc=node.loc, remote_fn=False)
try: try:
self.env_stack.append(node.typing_env) self.env_stack.append(node.typing_env)
@ -439,8 +439,9 @@ class ASTTypedRewriter(algorithm.Transformer):
def visit_Call(self, node): def visit_Call(self, node):
node = self.generic_visit(node) node = self.generic_visit(node)
node = asttyped.CallT(type=types.TVar(), iodelay=None, arg_exprs={}, node = asttyped.CallT(type=types.TVar(), iodelay=None, arg_exprs={},
func=node.func, args=node.args, keywords=node.keywords, remote_fn=False, func=node.func,
args=node.args, keywords=node.keywords,
starargs=node.starargs, kwargs=node.kwargs, starargs=node.starargs, kwargs=node.kwargs,
star_loc=node.star_loc, dstar_loc=node.dstar_loc, star_loc=node.star_loc, dstar_loc=node.dstar_loc,
begin_loc=node.begin_loc, end_loc=node.end_loc, loc=node.loc) begin_loc=node.begin_loc, end_loc=node.end_loc, loc=node.loc)

View File

@ -46,6 +46,7 @@ class Inferencer(algorithm.Visitor):
self.function = None # currently visited function, for Return inference self.function = None # currently visited function, for Return inference
self.in_loop = False self.in_loop = False
self.has_return = False self.has_return = False
self.subkernel_arg_types = dict()
def _unify(self, typea, typeb, loca, locb, makenotes=None, when=""): def _unify(self, typea, typeb, loca, locb, makenotes=None, when=""):
try: try:
@ -178,7 +179,7 @@ class Inferencer(algorithm.Visitor):
# Convert to a method. # Convert to a method.
attr_type = types.TMethod(object_type, attr_type) attr_type = types.TMethod(object_type, attr_type)
self._unify_method_self(attr_type, attr_name, attr_loc, loc, value_node.loc) self._unify_method_self(attr_type, attr_name, attr_loc, loc, value_node.loc)
elif types.is_rpc(attr_type): elif types.is_rpc(attr_type) or types.is_subkernel(attr_type):
# Convert to a method. We don't have to bother typechecking # Convert to a method. We don't have to bother typechecking
# the self argument, since for RPCs anything goes. # the self argument, since for RPCs anything goes.
attr_type = types.TMethod(object_type, attr_type) attr_type = types.TMethod(object_type, attr_type)
@ -1293,6 +1294,55 @@ class Inferencer(algorithm.Visitor):
# Ignored. # Ignored.
self._unify(node.type, builtins.TNone(), self._unify(node.type, builtins.TNone(),
node.loc, None) node.loc, None)
elif types.is_builtin(typ, "subkernel_await"):
valid_forms = lambda: [
valid_form("subkernel_await(f: subkernel) -> f return type"),
valid_form("subkernel_await(f: subkernel, timeout: numpy.int64) -> f return type")
]
if 1 <= len(node.args) <= 2:
arg0 = node.args[0].type
if types.is_var(arg0):
pass # undetermined yet
else:
if types.is_method(arg0):
fn = types.get_method_function(arg0)
elif types.is_function(arg0) or types.is_subkernel(arg0):
fn = arg0
else:
diagnose(valid_forms())
self._unify(node.type, fn.ret,
node.loc, None)
if len(node.args) == 2:
arg1 = node.args[1]
if types.is_var(arg1.type):
pass
elif builtins.is_int(arg1.type):
# promote to TInt64
self._unify(arg1.type, builtins.TInt64(),
arg1.loc, None)
else:
diagnose(valid_forms())
else:
diagnose(valid_forms())
elif types.is_builtin(typ, "subkernel_preload"):
valid_forms = lambda: [
valid_form("subkernel_preload(f: subkernel) -> None")
]
if len(node.args) == 1:
arg0 = node.args[0].type
if types.is_var(arg0):
pass # undetermined yet
else:
if types.is_method(arg0):
fn = types.get_method_function(arg0)
elif types.is_function(arg0) or types.is_subkernel(arg0):
fn = arg0
else:
diagnose(valid_forms())
self._unify(node.type, fn.ret,
node.loc, None)
else:
diagnose(valid_forms())
else: else:
assert False assert False
@ -1331,6 +1381,7 @@ class Inferencer(algorithm.Visitor):
typ_args = typ.args typ_args = typ.args
typ_optargs = typ.optargs typ_optargs = typ.optargs
typ_ret = typ.ret typ_ret = typ.ret
typ_func = typ
else: else:
typ_self = types.get_method_self(typ) typ_self = types.get_method_self(typ)
typ_func = types.get_method_function(typ) typ_func = types.get_method_function(typ)
@ -1388,12 +1439,23 @@ class Inferencer(algorithm.Visitor):
other_node=node.args[0]) other_node=node.args[0])
self._unify(node.type, ret, node.loc, None) self._unify(node.type, ret, node.loc, None)
return return
if types.is_subkernel(typ_func) and typ_func.sid not in self.subkernel_arg_types:
self.subkernel_arg_types[typ_func.sid] = []
for actualarg, (formalname, formaltyp) in \ for actualarg, (formalname, formaltyp) in \
zip(node.args, list(typ_args.items()) + list(typ_optargs.items())): zip(node.args, list(typ_args.items()) + list(typ_optargs.items())):
self._unify(actualarg.type, formaltyp, self._unify(actualarg.type, formaltyp,
actualarg.loc, None) actualarg.loc, None)
passed_args[formalname] = actualarg.loc passed_args[formalname] = actualarg.loc
if types.is_subkernel(typ_func):
if types.is_instance(actualarg.type):
# objects cannot be passed to subkernels, as rpc code doesn't support them
diag = diagnostic.Diagnostic("error",
"argument '{name}' of type: {typ} is not supported in subkernels",
{"name": formalname, "typ": actualarg.type},
actualarg.loc, [])
self.engine.process(diag)
self.subkernel_arg_types[typ_func.sid].append((formalname, formaltyp))
for keyword in node.keywords: for keyword in node.keywords:
if keyword.arg in passed_args: if keyword.arg in passed_args:
@ -1424,7 +1486,7 @@ class Inferencer(algorithm.Visitor):
passed_args[keyword.arg] = keyword.arg_loc passed_args[keyword.arg] = keyword.arg_loc
for formalname in typ_args: for formalname in typ_args:
if formalname not in passed_args: if formalname not in passed_args and not node.remote_fn:
note = diagnostic.Diagnostic("note", note = diagnostic.Diagnostic("note",
"the called function is of type {type}", "the called function is of type {type}",
{"type": types.TypePrinter().name(node.func.type)}, {"type": types.TypePrinter().name(node.func.type)},

View File

@ -280,7 +280,7 @@ class IODelayEstimator(algorithm.Visitor):
context="as an argument for delay_mu()") context="as an argument for delay_mu()")
call_delay = value call_delay = value
elif not types.is_builtin(typ): elif not types.is_builtin(typ):
if types.is_function(typ) or types.is_rpc(typ): if types.is_function(typ) or types.is_rpc(typ) or types.is_subkernel(typ):
offset = 0 offset = 0
elif types.is_method(typ): elif types.is_method(typ):
offset = 1 offset = 1
@ -288,7 +288,7 @@ class IODelayEstimator(algorithm.Visitor):
else: else:
assert False assert False
if types.is_rpc(typ): if types.is_rpc(typ) or types.is_subkernel(typ):
call_delay = iodelay.Const(0) call_delay = iodelay.Const(0)
else: else:
delay = typ.find().delay.find() delay = typ.find().delay.find()
@ -311,13 +311,20 @@ class IODelayEstimator(algorithm.Visitor):
args[arg_name] = arg_node args[arg_name] = arg_node
free_vars = delay.duration.free_vars() free_vars = delay.duration.free_vars()
node.arg_exprs = { try:
arg: self.evaluate(args[arg], abort=abort, node.arg_exprs = {
context="in the expression for argument '{}' " arg: self.evaluate(args[arg], abort=abort,
"that affects I/O delay".format(arg)) context="in the expression for argument '{}' "
for arg in free_vars "that affects I/O delay".format(arg))
} for arg in free_vars
call_delay = delay.duration.fold(node.arg_exprs) }
call_delay = delay.duration.fold(node.arg_exprs)
except KeyError as e:
if getattr(node, "remote_fn", False):
note = diagnostic.Diagnostic("note",
"function called here", {},
node.loc)
self.abort("due to arguments passed remotely", node.loc, note)
else: else:
assert False assert False
else: else:

View File

@ -215,7 +215,7 @@ class LLVMIRGenerator:
typ = typ.find() typ = typ.find()
if types.is_tuple(typ): if types.is_tuple(typ):
return ll.LiteralStructType([self.llty_of_type(eltty) for eltty in typ.elts]) return ll.LiteralStructType([self.llty_of_type(eltty) for eltty in typ.elts])
elif types.is_rpc(typ) or types.is_external_function(typ): elif types.is_rpc(typ) or types.is_external_function(typ) or types.is_subkernel(typ):
if for_return: if for_return:
return llvoid return llvoid
else: else:
@ -398,6 +398,15 @@ class LLVMIRGenerator:
elif name == "rpc_recv": elif name == "rpc_recv":
llty = ll.FunctionType(lli32, [llptr]) llty = ll.FunctionType(lli32, [llptr])
elif name == "subkernel_send_message":
llty = ll.FunctionType(llvoid, [lli32, lli8, llsliceptr, llptrptr])
elif name == "subkernel_load_run":
llty = ll.FunctionType(llvoid, [lli32, lli1])
elif name == "subkernel_await_finish":
llty = ll.FunctionType(llvoid, [lli32, lli64])
elif name == "subkernel_await_message":
llty = ll.FunctionType(lli8, [lli32, lli64, lli8, lli8])
# with now-pinning # with now-pinning
elif name == "now": elif name == "now":
llty = lli64 llty = lli64
@ -874,6 +883,53 @@ class LLVMIRGenerator:
llvalue = self.llbuilder.bitcast(llvalue, llptr.type.pointee) llvalue = self.llbuilder.bitcast(llvalue, llptr.type.pointee)
return self.llbuilder.store(llvalue, llptr) return self.llbuilder.store(llvalue, llptr)
def process_GetArgFromRemote(self, insn):
llstackptr = self.llbuilder.call(self.llbuiltin("llvm.stacksave"), [],
name="subkernel.arg.stack")
llval = self._build_rpc_recv(insn.arg_type, llstackptr)
return llval
def process_GetOptArgFromRemote(self, insn):
# optarg = index < rcv_count ? Some(rcv_recv()) : None
llhead = self.llbuilder.basic_block
llrcv = self.llbuilder.append_basic_block(name="optarg.get.{}".format(insn.arg_name))
# argument received
self.llbuilder.position_at_end(llrcv)
llstackptr = self.llbuilder.call(self.llbuiltin("llvm.stacksave"), [],
name="subkernel.arg.stack")
llval = self._build_rpc_recv(insn.arg_type, llstackptr)
llrpcretblock = self.llbuilder.basic_block # 'return' from rpc_recv, will be needed later
# create the tail block, needs to be after the rpc recv tail block
lltail = self.llbuilder.append_basic_block(name="optarg.tail.{}".format(insn.arg_name))
self.llbuilder.branch(lltail)
# go back to head to add a branch to the tail
self.llbuilder.position_at_end(llhead)
llargrcvd = self.llbuilder.icmp_unsigned("<", self.map(insn.index), self.map(insn.rcv_count))
self.llbuilder.cbranch(llargrcvd, llrcv, lltail)
# argument not received/after arg recvd
self.llbuilder.position_at_end(lltail)
llargtype = self.llty_of_type(insn.arg_type)
llphi_arg_present = self.llbuilder.phi(lli1, name="optarg.phi.present.{}".format(insn.arg_name))
llphi_arg = self.llbuilder.phi(llargtype, name="optarg.phi.{}".format(insn.arg_name))
llphi_arg_present.add_incoming(ll.Constant(lli1, 0), llhead)
llphi_arg.add_incoming(ll.Constant(llargtype, ll.Undefined), llhead)
llphi_arg_present.add_incoming(ll.Constant(lli1, 1), llrpcretblock)
llphi_arg.add_incoming(llval, llrpcretblock)
lloptarg = ll.Constant(ll.LiteralStructType([lli1, llargtype]), ll.Undefined)
lloptarg = self.llbuilder.insert_value(lloptarg, llphi_arg_present, 0)
lloptarg = self.llbuilder.insert_value(lloptarg, llphi_arg, 1)
return lloptarg
def attr_index(self, typ, attr): def attr_index(self, typ, attr):
return list(typ.attributes.keys()).index(attr) return list(typ.attributes.keys()).index(attr)
@ -898,8 +954,8 @@ class LLVMIRGenerator:
def get_global_closure_ptr(self, typ, attr): def get_global_closure_ptr(self, typ, attr):
closure_type = typ.attributes[attr] closure_type = typ.attributes[attr]
assert types.is_constructor(typ) assert types.is_constructor(typ)
assert types.is_function(closure_type) or types.is_rpc(closure_type) assert types.is_function(closure_type) or types.is_rpc(closure_type) or types.is_subkernel(closure_type)
if types.is_external_function(closure_type) or types.is_rpc(closure_type): if types.is_external_function(closure_type) or types.is_rpc(closure_type) or types.is_subkernel(closure_type):
return None return None
llty = self.llty_of_type(typ.attributes[attr]) llty = self.llty_of_type(typ.attributes[attr])
@ -1344,6 +1400,29 @@ class LLVMIRGenerator:
return self.llbuilder.call(self.llbuiltin("delay_mu"), [llinterval]) return self.llbuilder.call(self.llbuiltin("delay_mu"), [llinterval])
elif insn.op == "end_catch": elif insn.op == "end_catch":
return self.llbuilder.call(self.llbuiltin("__artiq_end_catch"), []) return self.llbuilder.call(self.llbuiltin("__artiq_end_catch"), [])
elif insn.op == "subkernel_await_args":
llmin = self.map(insn.operands[0])
llmax = self.map(insn.operands[1])
return self.llbuilder.call(self.llbuiltin("subkernel_await_message"),
[ll.Constant(lli32, 0), ll.Constant(lli64, 10_000), llmin, llmax],
name="subkernel.await.args")
elif insn.op == "subkernel_await_finish":
llsid = self.map(insn.operands[0])
lltimeout = self.map(insn.operands[1])
return self.llbuilder.call(self.llbuiltin("subkernel_await_finish"), [llsid, lltimeout],
name="subkernel.await.finish")
elif insn.op == "subkernel_retrieve_return":
llsid = self.map(insn.operands[0])
lltimeout = self.map(insn.operands[1])
self.llbuilder.call(self.llbuiltin("subkernel_await_message"), [llsid, lltimeout, ll.Constant(lli8, 1), ll.Constant(lli8, 1)],
name="subkernel.await.message")
llstackptr = self.llbuilder.call(self.llbuiltin("llvm.stacksave"), [],
name="subkernel.arg.stack")
return self._build_rpc_recv(insn.type, llstackptr)
elif insn.op == "subkernel_preload":
llsid = self.map(insn.operands[0])
return self.llbuilder.call(self.llbuiltin("subkernel_load_run"), [llsid, ll.Constant(lli1, 0)],
name="subkernel.preload")
else: else:
assert False assert False
@ -1426,6 +1505,58 @@ class LLVMIRGenerator:
return llfun, list(llargs), llarg_attrs, llcallstackptr return llfun, list(llargs), llarg_attrs, llcallstackptr
def _build_rpc_recv(self, ret, llstackptr, llnormalblock=None, llunwindblock=None):
# T result = {
# void *ret_ptr = alloca(sizeof(T));
# void *ptr = ret_ptr;
# loop: int size = rpc_recv(ptr);
# // Non-zero: Provide `size` bytes of extra storage for variable-length data.
# if(size) { ptr = alloca(size); goto loop; }
# else *(T*)ret_ptr
# }
llprehead = self.llbuilder.basic_block
llhead = self.llbuilder.append_basic_block(name="rpc.head")
if llunwindblock:
llheadu = self.llbuilder.append_basic_block(name="rpc.head.unwind")
llalloc = self.llbuilder.append_basic_block(name="rpc.continue")
lltail = self.llbuilder.append_basic_block(name="rpc.tail")
llretty = self.llty_of_type(ret)
llslot = self.llbuilder.alloca(llretty, name="rpc.ret.alloc")
llslotgen = self.llbuilder.bitcast(llslot, llptr, name="rpc.ret.ptr")
self.llbuilder.branch(llhead)
self.llbuilder.position_at_end(llhead)
llphi = self.llbuilder.phi(llslotgen.type, name="rpc.ptr")
llphi.add_incoming(llslotgen, llprehead)
if llunwindblock:
llsize = self.llbuilder.invoke(self.llbuiltin("rpc_recv"), [llphi],
llheadu, llunwindblock,
name="rpc.size.next")
self.llbuilder.position_at_end(llheadu)
else:
llsize = self.llbuilder.call(self.llbuiltin("rpc_recv"), [llphi],
name="rpc.size.next")
lldone = self.llbuilder.icmp_unsigned('==', llsize, ll.Constant(llsize.type, 0),
name="rpc.done")
self.llbuilder.cbranch(lldone, lltail, llalloc)
self.llbuilder.position_at_end(llalloc)
llalloca = self.llbuilder.alloca(lli8, llsize, name="rpc.alloc")
llalloca.align = self.max_target_alignment
llphi.add_incoming(llalloca, llalloc)
self.llbuilder.branch(llhead)
self.llbuilder.position_at_end(lltail)
llret = self.llbuilder.load(llslot, name="rpc.ret")
if not ret.fold(False, lambda r, t: r or builtins.is_allocated(t)):
# We didn't allocate anything except the slot for the value itself.
# Don't waste stack space.
self.llbuilder.call(self.llbuiltin("llvm.stackrestore"), [llstackptr])
if llnormalblock:
self.llbuilder.branch(llnormalblock)
return llret
def _build_rpc(self, fun_loc, fun_type, args, llnormalblock, llunwindblock): def _build_rpc(self, fun_loc, fun_type, args, llnormalblock, llunwindblock):
llservice = ll.Constant(lli32, fun_type.service) llservice = ll.Constant(lli32, fun_type.service)
@ -1501,57 +1632,103 @@ class LLVMIRGenerator:
return ll.Undefined return ll.Undefined
# T result = { llret = self._build_rpc_recv(fun_type.ret, llstackptr, llnormalblock, llunwindblock)
# void *ret_ptr = alloca(sizeof(T));
# void *ptr = ret_ptr;
# loop: int size = rpc_recv(ptr);
# // Non-zero: Provide `size` bytes of extra storage for variable-length data.
# if(size) { ptr = alloca(size); goto loop; }
# else *(T*)ret_ptr
# }
llprehead = self.llbuilder.basic_block
llhead = self.llbuilder.append_basic_block(name="rpc.head")
if llunwindblock:
llheadu = self.llbuilder.append_basic_block(name="rpc.head.unwind")
llalloc = self.llbuilder.append_basic_block(name="rpc.continue")
lltail = self.llbuilder.append_basic_block(name="rpc.tail")
llretty = self.llty_of_type(fun_type.ret)
llslot = self.llbuilder.alloca(llretty, name="rpc.ret.alloc")
llslotgen = self.llbuilder.bitcast(llslot, llptr, name="rpc.ret.ptr")
self.llbuilder.branch(llhead)
self.llbuilder.position_at_end(llhead)
llphi = self.llbuilder.phi(llslotgen.type, name="rpc.ptr")
llphi.add_incoming(llslotgen, llprehead)
if llunwindblock:
llsize = self.llbuilder.invoke(self.llbuiltin("rpc_recv"), [llphi],
llheadu, llunwindblock,
name="rpc.size.next")
self.llbuilder.position_at_end(llheadu)
else:
llsize = self.llbuilder.call(self.llbuiltin("rpc_recv"), [llphi],
name="rpc.size.next")
lldone = self.llbuilder.icmp_unsigned('==', llsize, ll.Constant(llsize.type, 0),
name="rpc.done")
self.llbuilder.cbranch(lldone, lltail, llalloc)
self.llbuilder.position_at_end(llalloc)
llalloca = self.llbuilder.alloca(lli8, llsize, name="rpc.alloc")
llalloca.align = self.max_target_alignment
llphi.add_incoming(llalloca, llalloc)
self.llbuilder.branch(llhead)
self.llbuilder.position_at_end(lltail)
llret = self.llbuilder.load(llslot, name="rpc.ret")
if not fun_type.ret.fold(False, lambda r, t: r or builtins.is_allocated(t)):
# We didn't allocate anything except the slot for the value itself.
# Don't waste stack space.
self.llbuilder.call(self.llbuiltin("llvm.stackrestore"), [llstackptr])
if llnormalblock:
self.llbuilder.branch(llnormalblock)
return llret return llret
def _build_subkernel_call(self, fun_loc, fun_type, args):
llsid = ll.Constant(lli32, fun_type.sid)
tag = b""
for arg in args:
def arg_error_handler(typ):
printer = types.TypePrinter()
note = diagnostic.Diagnostic("note",
"value of type {type}",
{"type": printer.name(typ)},
arg.loc)
diag = diagnostic.Diagnostic("error",
"type {type} is not supported in subkernel calls",
{"type": printer.name(arg.type)},
arg.loc, notes=[note])
self.engine.process(diag)
tag += ir.rpc_tag(arg.type, arg_error_handler)
tag += b":"
# run the kernel first
self.llbuilder.call(self.llbuiltin("subkernel_load_run"), [llsid, ll.Constant(lli1, 1)])
# arg sent in the same vein as RPC
llstackptr = self.llbuilder.call(self.llbuiltin("llvm.stacksave"), [],
name="subkernel.stack")
lltag = self.llconst_of_const(ir.Constant(tag, builtins.TStr()))
lltagptr = self.llbuilder.alloca(lltag.type)
self.llbuilder.store(lltag, lltagptr)
if args:
# only send args if there's anything to send, 'self' is excluded
llargs = self.llbuilder.alloca(llptr, ll.Constant(lli32, len(args)),
name="subkernel.args")
for index, arg in enumerate(args):
if builtins.is_none(arg.type):
llargslot = self.llbuilder.alloca(llunit,
name="subkernel.arg{}".format(index))
else:
llarg = self.map(arg)
llargslot = self.llbuilder.alloca(llarg.type,
name="subkernel.arg{}".format(index))
self.llbuilder.store(llarg, llargslot)
llargslot = self.llbuilder.bitcast(llargslot, llptr)
llargptr = self.llbuilder.gep(llargs, [ll.Constant(lli32, index)])
self.llbuilder.store(llargslot, llargptr)
llargcount = ll.Constant(lli8, len(args))
self.llbuilder.call(self.llbuiltin("subkernel_send_message"),
[llsid, llargcount, lltagptr, llargs])
self.llbuilder.call(self.llbuiltin("llvm.stackrestore"), [llstackptr])
return llsid
def _build_subkernel_return(self, insn):
# builds a remote return.
# unlike args, return only sends one thing.
if builtins.is_none(insn.value().type):
# do not waste time and bandwidth on Nones
return
def ret_error_handler(typ):
printer = types.TypePrinter()
note = diagnostic.Diagnostic("note",
"value of type {type}",
{"type": printer.name(typ)},
fun_loc)
diag = diagnostic.Diagnostic("error",
"return type {type} is not supported in subkernel returns",
{"type": printer.name(fun_type.ret)},
fun_loc, notes=[note])
self.engine.process(diag)
tag = ir.rpc_tag(insn.value().type, ret_error_handler)
tag += b":"
lltag = self.llconst_of_const(ir.Constant(tag, builtins.TStr()))
lltagptr = self.llbuilder.alloca(lltag.type)
self.llbuilder.store(lltag, lltagptr)
llrets = self.llbuilder.alloca(llptr, ll.Constant(lli32, 1),
name="subkernel.return")
llret = self.map(insn.value())
llretslot = self.llbuilder.alloca(llret.type, name="subkernel.retval")
self.llbuilder.store(llret, llretslot)
llretslot = self.llbuilder.bitcast(llretslot, llptr)
self.llbuilder.store(llretslot, llrets)
llsid = ll.Constant(lli32, 0) # return goes back to master, sid is ignored
lltagcount = ll.Constant(lli8, 1) # only one thing is returned
self.llbuilder.call(self.llbuiltin("subkernel_send_message"),
[llsid, lltagcount, lltagptr, llrets])
def process_Call(self, insn): def process_Call(self, insn):
functiontyp = insn.target_function().type functiontyp = insn.target_function().type
if types.is_rpc(functiontyp): if types.is_rpc(functiontyp):
@ -1559,6 +1736,10 @@ class LLVMIRGenerator:
functiontyp, functiontyp,
insn.arguments(), insn.arguments(),
llnormalblock=None, llunwindblock=None) llnormalblock=None, llunwindblock=None)
elif types.is_subkernel(functiontyp):
return self._build_subkernel_call(insn.target_function().loc,
functiontyp,
insn.arguments())
elif types.is_external_function(functiontyp): elif types.is_external_function(functiontyp):
llfun, llargs, llarg_attrs, llcallstackptr = self._prepare_ffi_call(insn) llfun, llargs, llarg_attrs, llcallstackptr = self._prepare_ffi_call(insn)
else: else:
@ -1595,6 +1776,11 @@ class LLVMIRGenerator:
functiontyp, functiontyp,
insn.arguments(), insn.arguments(),
llnormalblock, llunwindblock) llnormalblock, llunwindblock)
elif types.is_subkernel(functiontyp):
return self._build_subkernel_call(insn.target_function().loc,
functiontyp,
insn.arguments(),
llnormalblock, llunwindblock)
elif types.is_external_function(functiontyp): elif types.is_external_function(functiontyp):
llfun, llargs, llarg_attrs, llcallstackptr = self._prepare_ffi_call(insn) llfun, llargs, llarg_attrs, llcallstackptr = self._prepare_ffi_call(insn)
else: else:
@ -1673,7 +1859,8 @@ class LLVMIRGenerator:
attrvalue = getattr(value, attr) attrvalue = getattr(value, attr)
is_class_function = (types.is_constructor(typ) and is_class_function = (types.is_constructor(typ) and
types.is_function(typ.attributes[attr]) and types.is_function(typ.attributes[attr]) and
not types.is_external_function(typ.attributes[attr])) not types.is_external_function(typ.attributes[attr]) and
not types.is_subkernel(typ.attributes[attr]))
if is_class_function: if is_class_function:
attrvalue = self.embedding_map.specialize_function(typ.instance, attrvalue) attrvalue = self.embedding_map.specialize_function(typ.instance, attrvalue)
if not (types.is_instance(typ) and attr in typ.constant_attributes): if not (types.is_instance(typ) and attr in typ.constant_attributes):
@ -1758,7 +1945,8 @@ class LLVMIRGenerator:
llelts = [self._quote(v, t, lambda: path() + [str(i)]) llelts = [self._quote(v, t, lambda: path() + [str(i)])
for i, (v, t) in enumerate(zip(value, typ.elts))] for i, (v, t) in enumerate(zip(value, typ.elts))]
return ll.Constant(llty, llelts) return ll.Constant(llty, llelts)
elif types.is_rpc(typ) or types.is_external_function(typ) or types.is_builtin_function(typ): elif types.is_rpc(typ) or types.is_external_function(typ) or \
types.is_builtin_function(typ) or types.is_subkernel(typ):
# RPC, C and builtin functions have no runtime representation. # RPC, C and builtin functions have no runtime representation.
return ll.Constant(llty, ll.Undefined) return ll.Constant(llty, ll.Undefined)
elif types.is_function(typ): elif types.is_function(typ):
@ -1813,6 +2001,8 @@ class LLVMIRGenerator:
return llinsn return llinsn
def process_Return(self, insn): def process_Return(self, insn):
if insn.remote_return:
self._build_subkernel_return(insn)
if builtins.is_none(insn.value().type): if builtins.is_none(insn.value().type):
return self.llbuilder.ret_void() return self.llbuilder.ret_void()
else: else:

View File

@ -385,6 +385,50 @@ class TRPC(Type):
def __hash__(self): def __hash__(self):
return hash(self.service) return hash(self.service)
class TSubkernel(TFunction):
"""
A kernel to be run on a satellite.
:ivar args: (:class:`collections.OrderedDict` of string to :class:`Type`)
function arguments
:ivar ret: (:class:`Type`)
return type
:ivar sid: (int) subkernel ID number
:ivar destination: (int) satellite destination number
"""
attributes = OrderedDict()
def __init__(self, args, optargs, ret, sid, destination):
assert isinstance(ret, Type)
super().__init__(args, optargs, ret)
self.sid, self.destination = sid, destination
self.delay = TFixedDelay(iodelay.Const(0))
def unify(self, other):
if other is self:
return
if isinstance(other, TSubkernel) and \
self.sid == other.sid and \
self.destination == other.destination:
self.ret.unify(other.ret)
elif isinstance(other, TVar):
other.unify(self)
else:
raise UnificationError(self, other)
def __repr__(self):
if getattr(builtins, "__in_sphinx__", False):
return str(self)
return "artiq.compiler.types.TSubkernel({})".format(repr(self.ret))
def __eq__(self, other):
return isinstance(other, TSubkernel) and \
self.sid == other.sid
def __hash__(self):
return hash(self.sid)
class TBuiltin(Type): class TBuiltin(Type):
""" """
An instance of builtin type. Every instance of a builtin An instance of builtin type. Every instance of a builtin
@ -644,6 +688,9 @@ def is_function(typ):
def is_rpc(typ): def is_rpc(typ):
return isinstance(typ.find(), TRPC) return isinstance(typ.find(), TRPC)
def is_subkernel(typ):
return isinstance(typ.find(), TSubkernel)
def is_external_function(typ, name=None): def is_external_function(typ, name=None):
typ = typ.find() typ = typ.find()
if name is None: if name is None:
@ -810,6 +857,10 @@ class TypePrinter(object):
return "[rpc{} #{}](...)->{}".format(typ.service, return "[rpc{} #{}](...)->{}".format(typ.service,
" async" if typ.is_async else "", " async" if typ.is_async else "",
self.name(typ.ret, depth + 1)) self.name(typ.ret, depth + 1))
elif isinstance(typ, TSubkernel):
return "<subkernel{} dest#{}>->{}".format(typ.sid,
typ.destination,
self.name(typ.ret, depth + 1))
elif isinstance(typ, TBuiltinFunction): elif isinstance(typ, TBuiltinFunction):
return "<function {}>".format(typ.name) return "<function {}>".format(typ.name)
elif isinstance(typ, (TConstructor, TExceptionConstructor)): elif isinstance(typ, (TConstructor, TExceptionConstructor)):

View File

@ -73,8 +73,8 @@ def main():
finally: finally:
dataset_db.close_db() dataset_db.close_db()
if object_map.has_rpc(): if object_map.has_rpc_or_subkernel():
raise ValueError("Experiment must not use RPC") raise ValueError("Experiment must not use RPC or subkernels")
output = args.output output = args.output
if output is None: if output is None:

View File

@ -7,7 +7,7 @@ from functools import wraps
import numpy import numpy
__all__ = ["kernel", "portable", "rpc", "syscall", "host_only", __all__ = ["kernel", "portable", "rpc", "subkernel", "syscall", "host_only",
"kernel_from_string", "set_time_manager", "set_watchdog_factory", "kernel_from_string", "set_time_manager", "set_watchdog_factory",
"TerminationRequested"] "TerminationRequested"]
@ -21,7 +21,7 @@ __all__.extend(kernel_globals)
_ARTIQEmbeddedInfo = namedtuple("_ARTIQEmbeddedInfo", _ARTIQEmbeddedInfo = namedtuple("_ARTIQEmbeddedInfo",
"core_name portable function syscall forbidden flags") "core_name portable function syscall forbidden destination flags")
def kernel(arg=None, flags={}): def kernel(arg=None, flags={}):
""" """
@ -54,7 +54,7 @@ def kernel(arg=None, flags={}):
return getattr(self, arg).run(run_on_core, ((self,) + k_args), k_kwargs) return getattr(self, arg).run(run_on_core, ((self,) + k_args), k_kwargs)
run_on_core.artiq_embedded = _ARTIQEmbeddedInfo( run_on_core.artiq_embedded = _ARTIQEmbeddedInfo(
core_name=arg, portable=False, function=function, syscall=None, core_name=arg, portable=False, function=function, syscall=None,
forbidden=False, flags=set(flags)) forbidden=False, destination=None, flags=set(flags))
return run_on_core return run_on_core
return inner_decorator return inner_decorator
elif arg is None: elif arg is None:
@ -64,6 +64,50 @@ def kernel(arg=None, flags={}):
else: else:
return kernel("core", flags)(arg) return kernel("core", flags)(arg)
def subkernel(arg=None, destination=0, flags={}):
"""
This decorator marks an object's method or function for execution on a satellite device.
Destination must be given, and it must be between 1 and 255 (inclusive).
Subkernels behave similarly to kernels, with few key differences:
- they are started from main kernels,
- they do not support RPCs, or running subsequent subkernels on other devices,
- but they can call other kernels or subkernels with the same destination.
Subkernels can accept arguments and return values. However, they must be fully
annotated with ARTIQ types.
To call a subkernel, call it like a normal function.
To await its finishing execution, call ``subkernel_await(subkernel, [timeout])``.
The timeout parameter is optional, and by default is equal to 10000 (miliseconds).
This time can be adjusted for subkernels that take a long time to execute.
The compiled subkernel is copied to satellites, but not yet to the kernel core
until it's called. For bigger subkernels it may take some time before they
actually start running. To help with that, subkernels can be preloaded, with
``subkernel_preload(subkernel)`` function. A call to a preloaded subkernel
will take less time, but only one subkernel can be preloaded at a time.
"""
if isinstance(arg, str):
def inner_decorator(function):
@wraps(function)
def run_subkernel(self, *k_args, **k_kwargs):
sid = getattr(self, arg).prepare_subkernel(destination, run_subkernel, ((self,) + k_args), k_kwargs)
getattr(self, arg).run_subkernel(sid)
run_subkernel.artiq_embedded = _ARTIQEmbeddedInfo(
core_name=arg, portable=False, function=function, syscall=None,
forbidden=False, destination=destination, flags=set(flags))
return run_subkernel
return inner_decorator
elif arg is None:
def inner_decorator(function):
return subkernel(function, destination, flags)
return inner_decorator
else:
return subkernel("core", destination, flags)(arg)
def portable(arg=None, flags={}): def portable(arg=None, flags={}):
""" """
This decorator marks a function for execution on the same device as its This decorator marks a function for execution on the same device as its
@ -84,7 +128,7 @@ def portable(arg=None, flags={}):
else: else:
arg.artiq_embedded = \ arg.artiq_embedded = \
_ARTIQEmbeddedInfo(core_name=None, portable=True, function=arg, syscall=None, _ARTIQEmbeddedInfo(core_name=None, portable=True, function=arg, syscall=None,
forbidden=False, flags=set(flags)) forbidden=False, destination=None, flags=set(flags))
return arg return arg
def rpc(arg=None, flags={}): def rpc(arg=None, flags={}):
@ -100,7 +144,7 @@ def rpc(arg=None, flags={}):
else: else:
arg.artiq_embedded = \ arg.artiq_embedded = \
_ARTIQEmbeddedInfo(core_name=None, portable=False, function=arg, syscall=None, _ARTIQEmbeddedInfo(core_name=None, portable=False, function=arg, syscall=None,
forbidden=False, flags=set(flags)) forbidden=False, destination=None, flags=set(flags))
return arg return arg
def syscall(arg=None, flags={}): def syscall(arg=None, flags={}):
@ -118,7 +162,7 @@ def syscall(arg=None, flags={}):
def inner_decorator(function): def inner_decorator(function):
function.artiq_embedded = \ function.artiq_embedded = \
_ARTIQEmbeddedInfo(core_name=None, portable=False, function=None, _ARTIQEmbeddedInfo(core_name=None, portable=False, function=None,
syscall=arg, forbidden=False, syscall=arg, forbidden=False, destination=None,
flags=set(flags)) flags=set(flags))
return function return function
return inner_decorator return inner_decorator
@ -136,7 +180,7 @@ def host_only(function):
""" """
function.artiq_embedded = \ function.artiq_embedded = \
_ARTIQEmbeddedInfo(core_name=None, portable=False, function=None, syscall=None, _ARTIQEmbeddedInfo(core_name=None, portable=False, function=None, syscall=None,
forbidden=True, flags={}) forbidden=True, destination=None, flags={})
return function return function