mirror of https://github.com/m-labs/artiq.git
compiler: support subkernels
This commit is contained in:
parent
1a0fc317df
commit
0a750c77e8
|
@ -21,13 +21,19 @@ class scoped(object):
|
|||
set of variables resolved as globals
|
||||
"""
|
||||
|
||||
class remote(object):
|
||||
"""
|
||||
:ivar remote_fn: (bool) whether function is ran on a remote device,
|
||||
meaning arguments are received remotely and return is sent remotely
|
||||
"""
|
||||
|
||||
# Typed versions of untyped nodes
|
||||
class argT(ast.arg, commontyped):
|
||||
pass
|
||||
|
||||
class ClassDefT(ast.ClassDef):
|
||||
_types = ("constructor_type",)
|
||||
class FunctionDefT(ast.FunctionDef, scoped):
|
||||
class FunctionDefT(ast.FunctionDef, scoped, remote):
|
||||
_types = ("signature_type",)
|
||||
class QuotedFunctionDefT(FunctionDefT):
|
||||
"""
|
||||
|
@ -58,7 +64,7 @@ class BinOpT(ast.BinOp, commontyped):
|
|||
pass
|
||||
class BoolOpT(ast.BoolOp, commontyped):
|
||||
pass
|
||||
class CallT(ast.Call, commontyped):
|
||||
class CallT(ast.Call, commontyped, remote):
|
||||
"""
|
||||
:ivar iodelay: (:class:`iodelay.Expr`)
|
||||
:ivar arg_exprs: (dict of str to :class:`iodelay.Expr`)
|
||||
|
|
|
@ -38,6 +38,9 @@ class TInt(types.TMono):
|
|||
def one():
|
||||
return 1
|
||||
|
||||
def TInt8():
|
||||
return TInt(types.TValue(8))
|
||||
|
||||
def TInt32():
|
||||
return TInt(types.TValue(32))
|
||||
|
||||
|
@ -244,6 +247,12 @@ def fn_at_mu():
|
|||
def fn_rtio_log():
|
||||
return types.TBuiltinFunction("rtio_log")
|
||||
|
||||
def fn_subkernel_await():
|
||||
return types.TBuiltinFunction("subkernel_await")
|
||||
|
||||
def fn_subkernel_preload():
|
||||
return types.TBuiltinFunction("subkernel_preload")
|
||||
|
||||
# Accessors
|
||||
|
||||
def is_none(typ):
|
||||
|
@ -326,7 +335,7 @@ def get_iterable_elt(typ):
|
|||
# n-dimensional arrays, rather than the n-1 dimensional result of iterating over
|
||||
# the first axis, which makes the name a bit misleading.
|
||||
if is_str(typ) or is_bytes(typ) or is_bytearray(typ):
|
||||
return TInt(types.TValue(8))
|
||||
return TInt8()
|
||||
elif types._is_pointer(typ) or is_iterable(typ):
|
||||
return typ.find()["elt"].find()
|
||||
else:
|
||||
|
@ -342,5 +351,5 @@ def is_allocated(typ):
|
|||
is_float(typ) or is_range(typ) or
|
||||
types._is_pointer(typ) or types.is_function(typ) or
|
||||
types.is_external_function(typ) or types.is_rpc(typ) or
|
||||
types.is_method(typ) or types.is_tuple(typ) or
|
||||
types.is_value(typ))
|
||||
types.is_subkernel(typ) or types.is_method(typ) or
|
||||
types.is_tuple(typ) or types.is_value(typ))
|
||||
|
|
|
@ -74,7 +74,9 @@ class EmbeddingMap:
|
|||
"CacheError",
|
||||
"SPIError",
|
||||
"0:ZeroDivisionError",
|
||||
"0:IndexError"])
|
||||
"0:IndexError",
|
||||
"UnwrapNoneError",
|
||||
"SubkernelError"])
|
||||
|
||||
def preallocate_runtime_exception_names(self, names):
|
||||
for i, name in enumerate(names):
|
||||
|
@ -183,7 +185,15 @@ class EmbeddingMap:
|
|||
obj_typ, _ = self.type_map[type(obj_ref)]
|
||||
yield obj_id, obj_ref, obj_typ
|
||||
|
||||
def has_rpc(self):
|
||||
def subkernels(self):
|
||||
subkernels = {}
|
||||
for k, v in self.object_forward_map.items():
|
||||
if hasattr(v, "artiq_embedded"):
|
||||
if v.artiq_embedded.destination is not None:
|
||||
subkernels[k] = v
|
||||
return subkernels
|
||||
|
||||
def has_rpc_or_subkernel(self):
|
||||
return any(filter(lambda x: inspect.isfunction(x) or inspect.ismethod(x),
|
||||
self.object_forward_map.values()))
|
||||
|
||||
|
@ -469,7 +479,7 @@ class ASTSynthesizer:
|
|||
return asttyped.QuoteT(value=value, type=instance_type,
|
||||
loc=loc)
|
||||
|
||||
def call(self, callee, args, kwargs, callback=None):
|
||||
def call(self, callee, args, kwargs, callback=None, remote_fn=False):
|
||||
"""
|
||||
Construct an AST fragment calling a function specified by
|
||||
an AST node `function_node`, with given arguments.
|
||||
|
@ -513,7 +523,7 @@ class ASTSynthesizer:
|
|||
starargs=None, kwargs=None,
|
||||
type=types.TVar(), iodelay=None, arg_exprs={},
|
||||
begin_loc=begin_loc, end_loc=end_loc, star_loc=None, dstar_loc=None,
|
||||
loc=callee_node.loc.join(end_loc))
|
||||
loc=callee_node.loc.join(end_loc), remote_fn=remote_fn)
|
||||
|
||||
if callback is not None:
|
||||
node = asttyped.CallT(
|
||||
|
@ -548,7 +558,7 @@ class StitchingASTTypedRewriter(ASTTypedRewriter):
|
|||
arg=node.arg, annotation=None,
|
||||
arg_loc=node.arg_loc, colon_loc=node.colon_loc, loc=node.loc)
|
||||
|
||||
def visit_quoted_function(self, node, function):
|
||||
def visit_quoted_function(self, node, function, remote_fn):
|
||||
extractor = LocalExtractor(env_stack=self.env_stack, engine=self.engine)
|
||||
extractor.visit(node)
|
||||
|
||||
|
@ -569,7 +579,7 @@ class StitchingASTTypedRewriter(ASTTypedRewriter):
|
|||
body=node.body, decorator_list=node.decorator_list,
|
||||
keyword_loc=node.keyword_loc, name_loc=node.name_loc,
|
||||
arrow_loc=node.arrow_loc, colon_loc=node.colon_loc, at_locs=node.at_locs,
|
||||
loc=node.loc)
|
||||
loc=node.loc, remote_fn=remote_fn)
|
||||
|
||||
try:
|
||||
self.env_stack.append(node.typing_env)
|
||||
|
@ -777,7 +787,7 @@ class TypedtreeHasher(algorithm.Visitor):
|
|||
return hash(tuple(freeze(getattr(node, field_name)) for field_name in fields))
|
||||
|
||||
class Stitcher:
|
||||
def __init__(self, core, dmgr, engine=None, print_as_rpc=True):
|
||||
def __init__(self, core, dmgr, engine=None, print_as_rpc=True, destination=0, subkernel_arg_types=[]):
|
||||
self.core = core
|
||||
self.dmgr = dmgr
|
||||
if engine is None:
|
||||
|
@ -803,11 +813,19 @@ class Stitcher:
|
|||
self.value_map = defaultdict(lambda: [])
|
||||
self.definitely_changed = False
|
||||
|
||||
self.destination = destination
|
||||
self.first_call = True
|
||||
# for non-annotated subkernels:
|
||||
# main kernel inferencer output with types of arguments
|
||||
self.subkernel_arg_types = subkernel_arg_types
|
||||
|
||||
def stitch_call(self, function, args, kwargs, callback=None):
|
||||
# We synthesize source code for the initial call so that
|
||||
# diagnostics would have something meaningful to display to the user.
|
||||
synthesizer = self._synthesizer(self._function_loc(function.artiq_embedded.function))
|
||||
call_node = synthesizer.call(function, args, kwargs, callback)
|
||||
# first call of a subkernel will get its arguments from remote (DRTIO)
|
||||
remote_fn = self.destination != 0
|
||||
call_node = synthesizer.call(function, args, kwargs, callback, remote_fn=remote_fn)
|
||||
synthesizer.finalize()
|
||||
self.typedtree.append(call_node)
|
||||
|
||||
|
@ -919,6 +937,10 @@ class Stitcher:
|
|||
return [diagnostic.Diagnostic("note",
|
||||
"in kernel function here", {},
|
||||
call_loc)]
|
||||
elif fn_kind == 'subkernel':
|
||||
return [diagnostic.Diagnostic("note",
|
||||
"in subkernel call here", {},
|
||||
call_loc)]
|
||||
else:
|
||||
assert False
|
||||
else:
|
||||
|
@ -938,7 +960,7 @@ class Stitcher:
|
|||
self._function_loc(function),
|
||||
notes=self._call_site_note(loc, fn_kind))
|
||||
self.engine.process(diag)
|
||||
elif fn_kind == 'rpc' and param.default is not inspect.Parameter.empty:
|
||||
elif fn_kind == 'rpc' or fn_kind == 'subkernel' and param.default is not inspect.Parameter.empty:
|
||||
notes = []
|
||||
notes.append(diagnostic.Diagnostic("note",
|
||||
"expanded from here while trying to infer a type for an"
|
||||
|
@ -957,11 +979,18 @@ class Stitcher:
|
|||
Inferencer(engine=self.engine).visit(ast)
|
||||
IntMonomorphizer(engine=self.engine).visit(ast)
|
||||
return ast.type
|
||||
else:
|
||||
# Let the rest of the program decide.
|
||||
return types.TVar()
|
||||
elif fn_kind == 'kernel' and self.first_call and self.destination != 0:
|
||||
# subkernels do not have access to the main kernel code to infer
|
||||
# arg types - so these are cached and passed onto subkernel
|
||||
# compilation, to avoid having to annotate them fully
|
||||
for name, typ in self.subkernel_arg_types:
|
||||
if param.name == name:
|
||||
return typ
|
||||
|
||||
def _quote_embedded_function(self, function, flags):
|
||||
# Let the rest of the program decide.
|
||||
return types.TVar()
|
||||
|
||||
def _quote_embedded_function(self, function, flags, remote_fn=False):
|
||||
# we are now parsing new functions... definitely changed the type
|
||||
self.definitely_changed = True
|
||||
|
||||
|
@ -1060,7 +1089,7 @@ class Stitcher:
|
|||
engine=self.engine, prelude=self.prelude,
|
||||
globals=self.globals, host_environment=host_environment,
|
||||
quote=self._quote)
|
||||
function_node = asttyped_rewriter.visit_quoted_function(function_node, embedded_function)
|
||||
function_node = asttyped_rewriter.visit_quoted_function(function_node, embedded_function, remote_fn)
|
||||
function_node.flags = flags
|
||||
|
||||
# Add it into our typedtree so that it gets inferenced and codegen'd.
|
||||
|
@ -1174,7 +1203,6 @@ class Stitcher:
|
|||
signature = inspect.signature(function)
|
||||
|
||||
arg_types = OrderedDict()
|
||||
optarg_types = OrderedDict()
|
||||
for param in signature.parameters.values():
|
||||
if param.kind != inspect.Parameter.POSITIONAL_OR_KEYWORD:
|
||||
diag = diagnostic.Diagnostic("error",
|
||||
|
@ -1212,6 +1240,40 @@ class Stitcher:
|
|||
self.functions[function] = function_type
|
||||
return function_type
|
||||
|
||||
def _quote_subkernel(self, function, loc):
|
||||
if isinstance(function, SpecializedFunction):
|
||||
host_function = function.host_function
|
||||
else:
|
||||
host_function = function
|
||||
ret_type = builtins.TNone()
|
||||
signature = inspect.signature(host_function)
|
||||
|
||||
if signature.return_annotation is not inspect.Signature.empty:
|
||||
ret_type = self._extract_annot(host_function, signature.return_annotation,
|
||||
"return type", loc, fn_kind='subkernel')
|
||||
arg_types = OrderedDict()
|
||||
optarg_types = OrderedDict()
|
||||
for param in signature.parameters.values():
|
||||
if param.kind != inspect.Parameter.POSITIONAL_OR_KEYWORD:
|
||||
diag = diagnostic.Diagnostic("error",
|
||||
"subkernels must only use positional arguments; '{argument}' isn't",
|
||||
{"argument": param.name},
|
||||
self._function_loc(function),
|
||||
notes=self._call_site_note(loc, fn_kind='subkernel'))
|
||||
self.engine.process(diag)
|
||||
|
||||
arg_type = self._type_of_param(function, loc, param, fn_kind='subkernel')
|
||||
if param.default is inspect.Parameter.empty:
|
||||
arg_types[param.name] = arg_type
|
||||
else:
|
||||
optarg_types[param.name] = arg_type
|
||||
|
||||
function_type = types.TSubkernel(arg_types, optarg_types, ret_type,
|
||||
sid=self.embedding_map.store_object(host_function),
|
||||
destination=host_function.artiq_embedded.destination)
|
||||
self.functions[function] = function_type
|
||||
return function_type
|
||||
|
||||
def _quote_rpc(self, function, loc):
|
||||
if isinstance(function, SpecializedFunction):
|
||||
host_function = function.host_function
|
||||
|
@ -1271,8 +1333,18 @@ class Stitcher:
|
|||
(host_function.artiq_embedded.core_name is None and
|
||||
host_function.artiq_embedded.portable is False and
|
||||
host_function.artiq_embedded.syscall is None and
|
||||
host_function.artiq_embedded.destination is None and
|
||||
host_function.artiq_embedded.forbidden is False):
|
||||
self._quote_rpc(function, loc)
|
||||
elif host_function.artiq_embedded.destination is not None and \
|
||||
host_function.artiq_embedded.destination != self.destination:
|
||||
# treat subkernels as kernels if running on the same device
|
||||
if not 0 < host_function.artiq_embedded.destination <= 255:
|
||||
diag = diagnostic.Diagnostic("error",
|
||||
"subkernel destination must be between 1 and 255 (inclusive)", {},
|
||||
self._function_loc(host_function))
|
||||
self.engine.process(diag)
|
||||
self._quote_subkernel(function, loc)
|
||||
elif host_function.artiq_embedded.function is not None:
|
||||
if host_function.__name__ == "<lambda>":
|
||||
note = diagnostic.Diagnostic("note",
|
||||
|
@ -1296,8 +1368,13 @@ class Stitcher:
|
|||
notes=[note])
|
||||
self.engine.process(diag)
|
||||
|
||||
destination = host_function.artiq_embedded.destination
|
||||
# remote_fn only for first call in subkernels
|
||||
remote_fn = destination is not None and self.first_call
|
||||
self._quote_embedded_function(function,
|
||||
flags=host_function.artiq_embedded.flags)
|
||||
flags=host_function.artiq_embedded.flags,
|
||||
remote_fn=remote_fn)
|
||||
self.first_call = False
|
||||
elif host_function.artiq_embedded.syscall is not None:
|
||||
# Insert a storage-less global whose type instructs the compiler
|
||||
# to perform a system call instead of a regular call.
|
||||
|
|
|
@ -706,6 +706,64 @@ class SetLocal(Instruction):
|
|||
def value(self):
|
||||
return self.operands[1]
|
||||
|
||||
class GetArgFromRemote(Instruction):
|
||||
"""
|
||||
An instruction that receives function arguments from remote
|
||||
(ie. subkernel in DRTIO context)
|
||||
|
||||
:ivar arg_name: (string) argument name
|
||||
:ivar arg_type: argument type
|
||||
"""
|
||||
|
||||
"""
|
||||
:param arg_name: (string) argument name
|
||||
:param arg_type: argument type
|
||||
"""
|
||||
def __init__(self, arg_name, arg_type, name=""):
|
||||
assert isinstance(arg_name, str)
|
||||
super().__init__([], arg_type, name)
|
||||
self.arg_name = arg_name
|
||||
self.arg_type = arg_type
|
||||
|
||||
def copy(self, mapper):
|
||||
self_copy = super().copy(mapper)
|
||||
self_copy.arg_name = self.arg_name
|
||||
self_copy.arg_type = self.arg_type
|
||||
return self_copy
|
||||
|
||||
def opcode(self):
|
||||
return "getargfromremote({})".format(repr(self.arg_name))
|
||||
|
||||
class GetOptArgFromRemote(GetArgFromRemote):
|
||||
"""
|
||||
An instruction that may or may not retrieve an optional function argument
|
||||
from remote, depending on number of values received by firmware.
|
||||
|
||||
:ivar rcv_count: number of received values,
|
||||
determined by firmware
|
||||
:ivar index: (integer) index of the current argument,
|
||||
in reference to remote arguments
|
||||
"""
|
||||
|
||||
"""
|
||||
:param rcv_count: number of received valuese
|
||||
:param index: (integer) index of the current argument,
|
||||
in reference to remote arguments
|
||||
"""
|
||||
def __init__(self, arg_name, arg_type, rcv_count, index, name=""):
|
||||
super().__init__(arg_name, arg_type, name)
|
||||
self.rcv_count = rcv_count
|
||||
self.index = index
|
||||
|
||||
def copy(self, mapper):
|
||||
self_copy = super().copy(mapper)
|
||||
self_copy.rcv_count = self.rcv_count
|
||||
self_copy.index = self.index
|
||||
return self_copy
|
||||
|
||||
def opcode(self):
|
||||
return "getoptargfromremote({})".format(repr(self.arg_name))
|
||||
|
||||
class GetAttr(Instruction):
|
||||
"""
|
||||
An intruction that loads an attribute from an object,
|
||||
|
@ -728,7 +786,7 @@ class GetAttr(Instruction):
|
|||
typ = obj.type.attributes[attr]
|
||||
else:
|
||||
typ = obj.type.constructor.attributes[attr]
|
||||
if types.is_function(typ) or types.is_rpc(typ):
|
||||
if types.is_function(typ) or types.is_rpc(typ) or types.is_subkernel(typ):
|
||||
typ = types.TMethod(obj.type, typ)
|
||||
super().__init__([obj], typ, name)
|
||||
self.attr = attr
|
||||
|
@ -1190,14 +1248,18 @@ class IndirectBranch(Terminator):
|
|||
class Return(Terminator):
|
||||
"""
|
||||
A return instruction.
|
||||
:param remote_return: (bool)
|
||||
marks a return in subkernel context,
|
||||
where the return value is sent back through DRTIO
|
||||
"""
|
||||
|
||||
"""
|
||||
:param value: (:class:`Value`) return value
|
||||
"""
|
||||
def __init__(self, value, name=""):
|
||||
def __init__(self, value, remote_return=False, name=""):
|
||||
assert isinstance(value, Value)
|
||||
super().__init__([value], builtins.TNone(), name)
|
||||
self.remote_return = remote_return
|
||||
|
||||
def opcode(self):
|
||||
return "return"
|
||||
|
|
|
@ -84,6 +84,8 @@ class Module:
|
|||
constant_hoister.process(self.artiq_ir)
|
||||
if remarks:
|
||||
invariant_detection.process(self.artiq_ir)
|
||||
# for subkernels: main kernel inferencer output, to be passed to further compilations
|
||||
self.subkernel_arg_types = inferencer.subkernel_arg_types
|
||||
|
||||
def build_llvm_ir(self, target):
|
||||
"""Compile the module to LLVM IR for the specified target."""
|
||||
|
|
|
@ -37,6 +37,7 @@ def globals():
|
|||
|
||||
# ARTIQ decorators
|
||||
"kernel": builtins.fn_kernel(),
|
||||
"subkernel": builtins.fn_kernel(),
|
||||
"portable": builtins.fn_kernel(),
|
||||
"rpc": builtins.fn_kernel(),
|
||||
|
||||
|
@ -54,4 +55,8 @@ def globals():
|
|||
# ARTIQ utility functions
|
||||
"rtio_log": builtins.fn_rtio_log(),
|
||||
"core_log": builtins.fn_print(),
|
||||
|
||||
# ARTIQ subkernel utility functions
|
||||
"subkernel_await": builtins.fn_subkernel_await(),
|
||||
"subkernel_preload": builtins.fn_subkernel_preload(),
|
||||
}
|
||||
|
|
|
@ -94,8 +94,9 @@ class Target:
|
|||
tool_symbolizer = "llvm-symbolizer"
|
||||
tool_cxxfilt = "llvm-cxxfilt"
|
||||
|
||||
def __init__(self):
|
||||
def __init__(self, subkernel_id=None):
|
||||
self.llcontext = ll.Context()
|
||||
self.subkernel_id = subkernel_id
|
||||
|
||||
def target_machine(self):
|
||||
lltarget = llvm.Target.from_triple(self.triple)
|
||||
|
@ -148,7 +149,8 @@ class Target:
|
|||
ir.BasicBlock._dump_loc = False
|
||||
|
||||
type_printer = types.TypePrinter()
|
||||
_dump(os.getenv("ARTIQ_DUMP_IR"), "ARTIQ IR", ".txt",
|
||||
suffix = "_subkernel_{}".format(self.subkernel_id) if self.subkernel_id is not None else ""
|
||||
_dump(os.getenv("ARTIQ_DUMP_IR"), "ARTIQ IR", suffix + ".txt",
|
||||
lambda: "\n".join(fn.as_entity(type_printer) for fn in module.artiq_ir))
|
||||
|
||||
llmod = module.build_llvm_ir(self)
|
||||
|
@ -160,12 +162,12 @@ class Target:
|
|||
_dump("", "LLVM IR (broken)", ".ll", lambda: str(llmod))
|
||||
raise
|
||||
|
||||
_dump(os.getenv("ARTIQ_DUMP_UNOPT_LLVM"), "LLVM IR (generated)", "_unopt.ll",
|
||||
_dump(os.getenv("ARTIQ_DUMP_UNOPT_LLVM"), "LLVM IR (generated)", suffix + "_unopt.ll",
|
||||
lambda: str(llparsedmod))
|
||||
|
||||
self.optimize(llparsedmod)
|
||||
|
||||
_dump(os.getenv("ARTIQ_DUMP_LLVM"), "LLVM IR (optimized)", ".ll",
|
||||
_dump(os.getenv("ARTIQ_DUMP_LLVM"), "LLVM IR (optimized)", suffix + ".ll",
|
||||
lambda: str(llparsedmod))
|
||||
|
||||
return llparsedmod
|
||||
|
|
|
@ -108,6 +108,7 @@ class ARTIQIRGenerator(algorithm.Visitor):
|
|||
self.current_args = None
|
||||
self.current_assign = None
|
||||
self.current_exception = None
|
||||
self.current_remote_fn = False
|
||||
self.break_target = None
|
||||
self.continue_target = None
|
||||
self.return_target = None
|
||||
|
@ -211,7 +212,8 @@ class ARTIQIRGenerator(algorithm.Visitor):
|
|||
old_priv_env, self.current_private_env = self.current_private_env, priv_env
|
||||
|
||||
self.generic_visit(node)
|
||||
self.terminate(ir.Return(ir.Constant(None, builtins.TNone())))
|
||||
self.terminate(ir.Return(ir.Constant(None, builtins.TNone()),
|
||||
remote_return=self.current_remote_fn))
|
||||
|
||||
return self.functions
|
||||
finally:
|
||||
|
@ -294,6 +296,8 @@ class ARTIQIRGenerator(algorithm.Visitor):
|
|||
old_block, self.current_block = self.current_block, entry
|
||||
|
||||
old_globals, self.current_globals = self.current_globals, node.globals_in_scope
|
||||
old_remote_fn = self.current_remote_fn
|
||||
self.current_remote_fn = getattr(node, "remote_fn", False)
|
||||
|
||||
env_without_globals = \
|
||||
{var: node.typing_env[var]
|
||||
|
@ -326,7 +330,8 @@ class ARTIQIRGenerator(algorithm.Visitor):
|
|||
self.terminate(ir.Return(result))
|
||||
elif builtins.is_none(typ.ret):
|
||||
if not self.current_block.is_terminated():
|
||||
self.current_block.append(ir.Return(ir.Constant(None, builtins.TNone())))
|
||||
self.current_block.append(ir.Return(ir.Constant(None, builtins.TNone()),
|
||||
remote_return=self.current_remote_fn))
|
||||
else:
|
||||
if not self.current_block.is_terminated():
|
||||
if len(self.current_block.predecessors()) != 0:
|
||||
|
@ -345,6 +350,7 @@ class ARTIQIRGenerator(algorithm.Visitor):
|
|||
self.current_block = old_block
|
||||
self.current_globals = old_globals
|
||||
self.current_env = old_env
|
||||
self.current_remote_fn = old_remote_fn
|
||||
if not is_lambda:
|
||||
self.current_private_env = old_priv_env
|
||||
|
||||
|
@ -367,7 +373,8 @@ class ARTIQIRGenerator(algorithm.Visitor):
|
|||
return_value = self.visit(node.value)
|
||||
|
||||
if self.return_target is None:
|
||||
self.append(ir.Return(return_value))
|
||||
self.append(ir.Return(return_value,
|
||||
remote_return=self.current_remote_fn))
|
||||
else:
|
||||
self.append(ir.SetLocal(self.current_private_env, "$return", return_value))
|
||||
self.append(ir.Branch(self.return_target))
|
||||
|
@ -2524,6 +2531,33 @@ class ARTIQIRGenerator(algorithm.Visitor):
|
|||
or types.is_builtin(typ, "at_mu"):
|
||||
return self.append(ir.Builtin(typ.name,
|
||||
[self.visit(arg) for arg in node.args], node.type))
|
||||
elif types.is_builtin(typ, "subkernel_await"):
|
||||
if len(node.args) == 2 and len(node.keywords) == 0:
|
||||
fn = node.args[0].type
|
||||
timeout = self.visit(node.args[1])
|
||||
elif len(node.args) == 1 and len(node.keywords) == 0:
|
||||
fn = node.args[0].type
|
||||
timeout = ir.Constant(10_000, builtins.TInt64())
|
||||
else:
|
||||
assert False
|
||||
if types.is_method(fn):
|
||||
fn = types.get_method_function(fn)
|
||||
sid = ir.Constant(fn.sid, builtins.TInt32())
|
||||
if not builtins.is_none(fn.ret):
|
||||
ret = self.append(ir.Builtin("subkernel_retrieve_return", [sid, timeout], fn.ret))
|
||||
else:
|
||||
ret = ir.Constant(None, builtins.TNone())
|
||||
self.append(ir.Builtin("subkernel_await_finish", [sid, timeout], builtins.TNone()))
|
||||
return ret
|
||||
elif types.is_builtin(typ, "subkernel_preload"):
|
||||
if len(node.args) == 1 and len(node.keywords) == 0:
|
||||
fn = node.args[0].type
|
||||
else:
|
||||
assert False
|
||||
if types.is_method(fn):
|
||||
fn = types.get_method_function(fn)
|
||||
sid = ir.Constant(fn.sid, builtins.TInt32())
|
||||
return self.append(ir.Builtin("subkernel_preload", [sid], builtins.TNone()))
|
||||
elif types.is_exn_constructor(typ):
|
||||
return self.alloc_exn(node.type, *[self.visit(arg_node) for arg_node in node.args])
|
||||
elif types.is_constructor(typ):
|
||||
|
@ -2535,8 +2569,8 @@ class ARTIQIRGenerator(algorithm.Visitor):
|
|||
node.loc)
|
||||
self.engine.process(diag)
|
||||
|
||||
def _user_call(self, callee, positional, keywords, arg_exprs={}):
|
||||
if types.is_function(callee.type) or types.is_rpc(callee.type):
|
||||
def _user_call(self, callee, positional, keywords, arg_exprs={}, remote_fn=False):
|
||||
if types.is_function(callee.type) or types.is_rpc(callee.type) or types.is_subkernel(callee.type):
|
||||
func = callee
|
||||
self_arg = None
|
||||
fn_typ = callee.type
|
||||
|
@ -2551,16 +2585,50 @@ class ARTIQIRGenerator(algorithm.Visitor):
|
|||
else:
|
||||
assert False
|
||||
|
||||
if types.is_rpc(fn_typ):
|
||||
if self_arg is None:
|
||||
if types.is_rpc(fn_typ) or types.is_subkernel(fn_typ):
|
||||
if self_arg is None or types.is_subkernel(fn_typ):
|
||||
# self is not passed to subkernels by remote
|
||||
args = positional
|
||||
else:
|
||||
elif self_arg is not None:
|
||||
args = [self_arg] + positional
|
||||
|
||||
for keyword in keywords:
|
||||
arg = keywords[keyword]
|
||||
args.append(self.append(ir.Alloc([ir.Constant(keyword, builtins.TStr()), arg],
|
||||
ir.TKeyword(arg.type))))
|
||||
elif remote_fn:
|
||||
assert self_arg is None
|
||||
assert len(fn_typ.args) >= len(positional)
|
||||
assert len(keywords) == 0 # no keyword support
|
||||
args = [None] * fn_typ.arity()
|
||||
index = 0
|
||||
# fill in first available args
|
||||
for arg in positional:
|
||||
args[index] = arg
|
||||
index += 1
|
||||
|
||||
# remaining args are received through DRTIO
|
||||
if index < len(args):
|
||||
# min/max args received remotely (minus already filled)
|
||||
offset = index
|
||||
min_args = ir.Constant(len(fn_typ.args)-offset, builtins.TInt8())
|
||||
max_args = ir.Constant(fn_typ.arity()-offset, builtins.TInt8())
|
||||
|
||||
rcvd_count = self.append(ir.Builtin("subkernel_await_args", [min_args, max_args], builtins.TNone()))
|
||||
arg_types = list(fn_typ.args.items())[offset:]
|
||||
# obligatory arguments
|
||||
for arg_name, arg_type in arg_types:
|
||||
args[index] = self.append(ir.GetArgFromRemote(arg_name, arg_type,
|
||||
name="ARG.{}".format(arg_name)))
|
||||
index += 1
|
||||
|
||||
# optional arguments
|
||||
for optarg_name, optarg_type in fn_typ.optargs.items():
|
||||
idx = ir.Constant(index-offset, builtins.TInt8())
|
||||
args[index] = \
|
||||
self.append(ir.GetOptArgFromRemote(optarg_name, optarg_type, rcvd_count, idx))
|
||||
index += 1
|
||||
|
||||
else:
|
||||
args = [None] * (len(fn_typ.args) + len(fn_typ.optargs))
|
||||
|
||||
|
@ -2646,7 +2714,8 @@ class ARTIQIRGenerator(algorithm.Visitor):
|
|||
else:
|
||||
assert False, "Broadcasting for {} arguments not implemented".format(len)
|
||||
else:
|
||||
insn = self._user_call(callee, args, keywords, node.arg_exprs)
|
||||
remote_fn = getattr(node, "remote_fn", False)
|
||||
insn = self._user_call(callee, args, keywords, node.arg_exprs, remote_fn)
|
||||
if isinstance(node.func, asttyped.AttributeT):
|
||||
attr_node = node.func
|
||||
self.method_map[(attr_node.value.type.find(),
|
||||
|
|
|
@ -238,7 +238,7 @@ class ASTTypedRewriter(algorithm.Transformer):
|
|||
body=node.body, decorator_list=node.decorator_list,
|
||||
keyword_loc=node.keyword_loc, name_loc=node.name_loc,
|
||||
arrow_loc=node.arrow_loc, colon_loc=node.colon_loc, at_locs=node.at_locs,
|
||||
loc=node.loc)
|
||||
loc=node.loc, remote_fn=False)
|
||||
|
||||
try:
|
||||
self.env_stack.append(node.typing_env)
|
||||
|
@ -440,7 +440,8 @@ class ASTTypedRewriter(algorithm.Transformer):
|
|||
def visit_Call(self, node):
|
||||
node = self.generic_visit(node)
|
||||
node = asttyped.CallT(type=types.TVar(), iodelay=None, arg_exprs={},
|
||||
func=node.func, args=node.args, keywords=node.keywords,
|
||||
remote_fn=False, func=node.func,
|
||||
args=node.args, keywords=node.keywords,
|
||||
starargs=node.starargs, kwargs=node.kwargs,
|
||||
star_loc=node.star_loc, dstar_loc=node.dstar_loc,
|
||||
begin_loc=node.begin_loc, end_loc=node.end_loc, loc=node.loc)
|
||||
|
|
|
@ -46,6 +46,7 @@ class Inferencer(algorithm.Visitor):
|
|||
self.function = None # currently visited function, for Return inference
|
||||
self.in_loop = False
|
||||
self.has_return = False
|
||||
self.subkernel_arg_types = dict()
|
||||
|
||||
def _unify(self, typea, typeb, loca, locb, makenotes=None, when=""):
|
||||
try:
|
||||
|
@ -178,7 +179,7 @@ class Inferencer(algorithm.Visitor):
|
|||
# Convert to a method.
|
||||
attr_type = types.TMethod(object_type, attr_type)
|
||||
self._unify_method_self(attr_type, attr_name, attr_loc, loc, value_node.loc)
|
||||
elif types.is_rpc(attr_type):
|
||||
elif types.is_rpc(attr_type) or types.is_subkernel(attr_type):
|
||||
# Convert to a method. We don't have to bother typechecking
|
||||
# the self argument, since for RPCs anything goes.
|
||||
attr_type = types.TMethod(object_type, attr_type)
|
||||
|
@ -1293,6 +1294,55 @@ class Inferencer(algorithm.Visitor):
|
|||
# Ignored.
|
||||
self._unify(node.type, builtins.TNone(),
|
||||
node.loc, None)
|
||||
elif types.is_builtin(typ, "subkernel_await"):
|
||||
valid_forms = lambda: [
|
||||
valid_form("subkernel_await(f: subkernel) -> f return type"),
|
||||
valid_form("subkernel_await(f: subkernel, timeout: numpy.int64) -> f return type")
|
||||
]
|
||||
if 1 <= len(node.args) <= 2:
|
||||
arg0 = node.args[0].type
|
||||
if types.is_var(arg0):
|
||||
pass # undetermined yet
|
||||
else:
|
||||
if types.is_method(arg0):
|
||||
fn = types.get_method_function(arg0)
|
||||
elif types.is_function(arg0) or types.is_subkernel(arg0):
|
||||
fn = arg0
|
||||
else:
|
||||
diagnose(valid_forms())
|
||||
self._unify(node.type, fn.ret,
|
||||
node.loc, None)
|
||||
if len(node.args) == 2:
|
||||
arg1 = node.args[1]
|
||||
if types.is_var(arg1.type):
|
||||
pass
|
||||
elif builtins.is_int(arg1.type):
|
||||
# promote to TInt64
|
||||
self._unify(arg1.type, builtins.TInt64(),
|
||||
arg1.loc, None)
|
||||
else:
|
||||
diagnose(valid_forms())
|
||||
else:
|
||||
diagnose(valid_forms())
|
||||
elif types.is_builtin(typ, "subkernel_preload"):
|
||||
valid_forms = lambda: [
|
||||
valid_form("subkernel_preload(f: subkernel) -> None")
|
||||
]
|
||||
if len(node.args) == 1:
|
||||
arg0 = node.args[0].type
|
||||
if types.is_var(arg0):
|
||||
pass # undetermined yet
|
||||
else:
|
||||
if types.is_method(arg0):
|
||||
fn = types.get_method_function(arg0)
|
||||
elif types.is_function(arg0) or types.is_subkernel(arg0):
|
||||
fn = arg0
|
||||
else:
|
||||
diagnose(valid_forms())
|
||||
self._unify(node.type, fn.ret,
|
||||
node.loc, None)
|
||||
else:
|
||||
diagnose(valid_forms())
|
||||
else:
|
||||
assert False
|
||||
|
||||
|
@ -1331,6 +1381,7 @@ class Inferencer(algorithm.Visitor):
|
|||
typ_args = typ.args
|
||||
typ_optargs = typ.optargs
|
||||
typ_ret = typ.ret
|
||||
typ_func = typ
|
||||
else:
|
||||
typ_self = types.get_method_self(typ)
|
||||
typ_func = types.get_method_function(typ)
|
||||
|
@ -1388,12 +1439,23 @@ class Inferencer(algorithm.Visitor):
|
|||
other_node=node.args[0])
|
||||
self._unify(node.type, ret, node.loc, None)
|
||||
return
|
||||
if types.is_subkernel(typ_func) and typ_func.sid not in self.subkernel_arg_types:
|
||||
self.subkernel_arg_types[typ_func.sid] = []
|
||||
|
||||
for actualarg, (formalname, formaltyp) in \
|
||||
zip(node.args, list(typ_args.items()) + list(typ_optargs.items())):
|
||||
self._unify(actualarg.type, formaltyp,
|
||||
actualarg.loc, None)
|
||||
passed_args[formalname] = actualarg.loc
|
||||
if types.is_subkernel(typ_func):
|
||||
if types.is_instance(actualarg.type):
|
||||
# objects cannot be passed to subkernels, as rpc code doesn't support them
|
||||
diag = diagnostic.Diagnostic("error",
|
||||
"argument '{name}' of type: {typ} is not supported in subkernels",
|
||||
{"name": formalname, "typ": actualarg.type},
|
||||
actualarg.loc, [])
|
||||
self.engine.process(diag)
|
||||
self.subkernel_arg_types[typ_func.sid].append((formalname, formaltyp))
|
||||
|
||||
for keyword in node.keywords:
|
||||
if keyword.arg in passed_args:
|
||||
|
@ -1424,7 +1486,7 @@ class Inferencer(algorithm.Visitor):
|
|||
passed_args[keyword.arg] = keyword.arg_loc
|
||||
|
||||
for formalname in typ_args:
|
||||
if formalname not in passed_args:
|
||||
if formalname not in passed_args and not node.remote_fn:
|
||||
note = diagnostic.Diagnostic("note",
|
||||
"the called function is of type {type}",
|
||||
{"type": types.TypePrinter().name(node.func.type)},
|
||||
|
|
|
@ -280,7 +280,7 @@ class IODelayEstimator(algorithm.Visitor):
|
|||
context="as an argument for delay_mu()")
|
||||
call_delay = value
|
||||
elif not types.is_builtin(typ):
|
||||
if types.is_function(typ) or types.is_rpc(typ):
|
||||
if types.is_function(typ) or types.is_rpc(typ) or types.is_subkernel(typ):
|
||||
offset = 0
|
||||
elif types.is_method(typ):
|
||||
offset = 1
|
||||
|
@ -288,7 +288,7 @@ class IODelayEstimator(algorithm.Visitor):
|
|||
else:
|
||||
assert False
|
||||
|
||||
if types.is_rpc(typ):
|
||||
if types.is_rpc(typ) or types.is_subkernel(typ):
|
||||
call_delay = iodelay.Const(0)
|
||||
else:
|
||||
delay = typ.find().delay.find()
|
||||
|
@ -311,13 +311,20 @@ class IODelayEstimator(algorithm.Visitor):
|
|||
args[arg_name] = arg_node
|
||||
|
||||
free_vars = delay.duration.free_vars()
|
||||
node.arg_exprs = {
|
||||
arg: self.evaluate(args[arg], abort=abort,
|
||||
context="in the expression for argument '{}' "
|
||||
"that affects I/O delay".format(arg))
|
||||
for arg in free_vars
|
||||
}
|
||||
call_delay = delay.duration.fold(node.arg_exprs)
|
||||
try:
|
||||
node.arg_exprs = {
|
||||
arg: self.evaluate(args[arg], abort=abort,
|
||||
context="in the expression for argument '{}' "
|
||||
"that affects I/O delay".format(arg))
|
||||
for arg in free_vars
|
||||
}
|
||||
call_delay = delay.duration.fold(node.arg_exprs)
|
||||
except KeyError as e:
|
||||
if getattr(node, "remote_fn", False):
|
||||
note = diagnostic.Diagnostic("note",
|
||||
"function called here", {},
|
||||
node.loc)
|
||||
self.abort("due to arguments passed remotely", node.loc, note)
|
||||
else:
|
||||
assert False
|
||||
else:
|
||||
|
|
|
@ -215,7 +215,7 @@ class LLVMIRGenerator:
|
|||
typ = typ.find()
|
||||
if types.is_tuple(typ):
|
||||
return ll.LiteralStructType([self.llty_of_type(eltty) for eltty in typ.elts])
|
||||
elif types.is_rpc(typ) or types.is_external_function(typ):
|
||||
elif types.is_rpc(typ) or types.is_external_function(typ) or types.is_subkernel(typ):
|
||||
if for_return:
|
||||
return llvoid
|
||||
else:
|
||||
|
@ -398,6 +398,15 @@ class LLVMIRGenerator:
|
|||
elif name == "rpc_recv":
|
||||
llty = ll.FunctionType(lli32, [llptr])
|
||||
|
||||
elif name == "subkernel_send_message":
|
||||
llty = ll.FunctionType(llvoid, [lli32, lli8, llsliceptr, llptrptr])
|
||||
elif name == "subkernel_load_run":
|
||||
llty = ll.FunctionType(llvoid, [lli32, lli1])
|
||||
elif name == "subkernel_await_finish":
|
||||
llty = ll.FunctionType(llvoid, [lli32, lli64])
|
||||
elif name == "subkernel_await_message":
|
||||
llty = ll.FunctionType(lli8, [lli32, lli64, lli8, lli8])
|
||||
|
||||
# with now-pinning
|
||||
elif name == "now":
|
||||
llty = lli64
|
||||
|
@ -874,6 +883,53 @@ class LLVMIRGenerator:
|
|||
llvalue = self.llbuilder.bitcast(llvalue, llptr.type.pointee)
|
||||
return self.llbuilder.store(llvalue, llptr)
|
||||
|
||||
def process_GetArgFromRemote(self, insn):
|
||||
llstackptr = self.llbuilder.call(self.llbuiltin("llvm.stacksave"), [],
|
||||
name="subkernel.arg.stack")
|
||||
llval = self._build_rpc_recv(insn.arg_type, llstackptr)
|
||||
return llval
|
||||
|
||||
def process_GetOptArgFromRemote(self, insn):
|
||||
# optarg = index < rcv_count ? Some(rcv_recv()) : None
|
||||
llhead = self.llbuilder.basic_block
|
||||
llrcv = self.llbuilder.append_basic_block(name="optarg.get.{}".format(insn.arg_name))
|
||||
|
||||
# argument received
|
||||
self.llbuilder.position_at_end(llrcv)
|
||||
llstackptr = self.llbuilder.call(self.llbuiltin("llvm.stacksave"), [],
|
||||
name="subkernel.arg.stack")
|
||||
llval = self._build_rpc_recv(insn.arg_type, llstackptr)
|
||||
llrpcretblock = self.llbuilder.basic_block # 'return' from rpc_recv, will be needed later
|
||||
|
||||
# create the tail block, needs to be after the rpc recv tail block
|
||||
lltail = self.llbuilder.append_basic_block(name="optarg.tail.{}".format(insn.arg_name))
|
||||
self.llbuilder.branch(lltail)
|
||||
|
||||
# go back to head to add a branch to the tail
|
||||
self.llbuilder.position_at_end(llhead)
|
||||
llargrcvd = self.llbuilder.icmp_unsigned("<", self.map(insn.index), self.map(insn.rcv_count))
|
||||
self.llbuilder.cbranch(llargrcvd, llrcv, lltail)
|
||||
|
||||
# argument not received/after arg recvd
|
||||
self.llbuilder.position_at_end(lltail)
|
||||
|
||||
llargtype = self.llty_of_type(insn.arg_type)
|
||||
|
||||
llphi_arg_present = self.llbuilder.phi(lli1, name="optarg.phi.present.{}".format(insn.arg_name))
|
||||
llphi_arg = self.llbuilder.phi(llargtype, name="optarg.phi.{}".format(insn.arg_name))
|
||||
|
||||
llphi_arg_present.add_incoming(ll.Constant(lli1, 0), llhead)
|
||||
llphi_arg.add_incoming(ll.Constant(llargtype, ll.Undefined), llhead)
|
||||
|
||||
llphi_arg_present.add_incoming(ll.Constant(lli1, 1), llrpcretblock)
|
||||
llphi_arg.add_incoming(llval, llrpcretblock)
|
||||
|
||||
lloptarg = ll.Constant(ll.LiteralStructType([lli1, llargtype]), ll.Undefined)
|
||||
lloptarg = self.llbuilder.insert_value(lloptarg, llphi_arg_present, 0)
|
||||
lloptarg = self.llbuilder.insert_value(lloptarg, llphi_arg, 1)
|
||||
|
||||
return lloptarg
|
||||
|
||||
def attr_index(self, typ, attr):
|
||||
return list(typ.attributes.keys()).index(attr)
|
||||
|
||||
|
@ -898,8 +954,8 @@ class LLVMIRGenerator:
|
|||
def get_global_closure_ptr(self, typ, attr):
|
||||
closure_type = typ.attributes[attr]
|
||||
assert types.is_constructor(typ)
|
||||
assert types.is_function(closure_type) or types.is_rpc(closure_type)
|
||||
if types.is_external_function(closure_type) or types.is_rpc(closure_type):
|
||||
assert types.is_function(closure_type) or types.is_rpc(closure_type) or types.is_subkernel(closure_type)
|
||||
if types.is_external_function(closure_type) or types.is_rpc(closure_type) or types.is_subkernel(closure_type):
|
||||
return None
|
||||
|
||||
llty = self.llty_of_type(typ.attributes[attr])
|
||||
|
@ -1344,6 +1400,29 @@ class LLVMIRGenerator:
|
|||
return self.llbuilder.call(self.llbuiltin("delay_mu"), [llinterval])
|
||||
elif insn.op == "end_catch":
|
||||
return self.llbuilder.call(self.llbuiltin("__artiq_end_catch"), [])
|
||||
elif insn.op == "subkernel_await_args":
|
||||
llmin = self.map(insn.operands[0])
|
||||
llmax = self.map(insn.operands[1])
|
||||
return self.llbuilder.call(self.llbuiltin("subkernel_await_message"),
|
||||
[ll.Constant(lli32, 0), ll.Constant(lli64, 10_000), llmin, llmax],
|
||||
name="subkernel.await.args")
|
||||
elif insn.op == "subkernel_await_finish":
|
||||
llsid = self.map(insn.operands[0])
|
||||
lltimeout = self.map(insn.operands[1])
|
||||
return self.llbuilder.call(self.llbuiltin("subkernel_await_finish"), [llsid, lltimeout],
|
||||
name="subkernel.await.finish")
|
||||
elif insn.op == "subkernel_retrieve_return":
|
||||
llsid = self.map(insn.operands[0])
|
||||
lltimeout = self.map(insn.operands[1])
|
||||
self.llbuilder.call(self.llbuiltin("subkernel_await_message"), [llsid, lltimeout, ll.Constant(lli8, 1), ll.Constant(lli8, 1)],
|
||||
name="subkernel.await.message")
|
||||
llstackptr = self.llbuilder.call(self.llbuiltin("llvm.stacksave"), [],
|
||||
name="subkernel.arg.stack")
|
||||
return self._build_rpc_recv(insn.type, llstackptr)
|
||||
elif insn.op == "subkernel_preload":
|
||||
llsid = self.map(insn.operands[0])
|
||||
return self.llbuilder.call(self.llbuiltin("subkernel_load_run"), [llsid, ll.Constant(lli1, 0)],
|
||||
name="subkernel.preload")
|
||||
else:
|
||||
assert False
|
||||
|
||||
|
@ -1426,6 +1505,58 @@ class LLVMIRGenerator:
|
|||
|
||||
return llfun, list(llargs), llarg_attrs, llcallstackptr
|
||||
|
||||
def _build_rpc_recv(self, ret, llstackptr, llnormalblock=None, llunwindblock=None):
|
||||
# T result = {
|
||||
# void *ret_ptr = alloca(sizeof(T));
|
||||
# void *ptr = ret_ptr;
|
||||
# loop: int size = rpc_recv(ptr);
|
||||
# // Non-zero: Provide `size` bytes of extra storage for variable-length data.
|
||||
# if(size) { ptr = alloca(size); goto loop; }
|
||||
# else *(T*)ret_ptr
|
||||
# }
|
||||
llprehead = self.llbuilder.basic_block
|
||||
llhead = self.llbuilder.append_basic_block(name="rpc.head")
|
||||
if llunwindblock:
|
||||
llheadu = self.llbuilder.append_basic_block(name="rpc.head.unwind")
|
||||
llalloc = self.llbuilder.append_basic_block(name="rpc.continue")
|
||||
lltail = self.llbuilder.append_basic_block(name="rpc.tail")
|
||||
|
||||
llretty = self.llty_of_type(ret)
|
||||
llslot = self.llbuilder.alloca(llretty, name="rpc.ret.alloc")
|
||||
llslotgen = self.llbuilder.bitcast(llslot, llptr, name="rpc.ret.ptr")
|
||||
self.llbuilder.branch(llhead)
|
||||
|
||||
self.llbuilder.position_at_end(llhead)
|
||||
llphi = self.llbuilder.phi(llslotgen.type, name="rpc.ptr")
|
||||
llphi.add_incoming(llslotgen, llprehead)
|
||||
if llunwindblock:
|
||||
llsize = self.llbuilder.invoke(self.llbuiltin("rpc_recv"), [llphi],
|
||||
llheadu, llunwindblock,
|
||||
name="rpc.size.next")
|
||||
self.llbuilder.position_at_end(llheadu)
|
||||
else:
|
||||
llsize = self.llbuilder.call(self.llbuiltin("rpc_recv"), [llphi],
|
||||
name="rpc.size.next")
|
||||
lldone = self.llbuilder.icmp_unsigned('==', llsize, ll.Constant(llsize.type, 0),
|
||||
name="rpc.done")
|
||||
self.llbuilder.cbranch(lldone, lltail, llalloc)
|
||||
|
||||
self.llbuilder.position_at_end(llalloc)
|
||||
llalloca = self.llbuilder.alloca(lli8, llsize, name="rpc.alloc")
|
||||
llalloca.align = self.max_target_alignment
|
||||
llphi.add_incoming(llalloca, llalloc)
|
||||
self.llbuilder.branch(llhead)
|
||||
|
||||
self.llbuilder.position_at_end(lltail)
|
||||
llret = self.llbuilder.load(llslot, name="rpc.ret")
|
||||
if not ret.fold(False, lambda r, t: r or builtins.is_allocated(t)):
|
||||
# We didn't allocate anything except the slot for the value itself.
|
||||
# Don't waste stack space.
|
||||
self.llbuilder.call(self.llbuiltin("llvm.stackrestore"), [llstackptr])
|
||||
if llnormalblock:
|
||||
self.llbuilder.branch(llnormalblock)
|
||||
return llret
|
||||
|
||||
def _build_rpc(self, fun_loc, fun_type, args, llnormalblock, llunwindblock):
|
||||
llservice = ll.Constant(lli32, fun_type.service)
|
||||
|
||||
|
@ -1501,57 +1632,103 @@ class LLVMIRGenerator:
|
|||
|
||||
return ll.Undefined
|
||||
|
||||
# T result = {
|
||||
# void *ret_ptr = alloca(sizeof(T));
|
||||
# void *ptr = ret_ptr;
|
||||
# loop: int size = rpc_recv(ptr);
|
||||
# // Non-zero: Provide `size` bytes of extra storage for variable-length data.
|
||||
# if(size) { ptr = alloca(size); goto loop; }
|
||||
# else *(T*)ret_ptr
|
||||
# }
|
||||
llprehead = self.llbuilder.basic_block
|
||||
llhead = self.llbuilder.append_basic_block(name="rpc.head")
|
||||
if llunwindblock:
|
||||
llheadu = self.llbuilder.append_basic_block(name="rpc.head.unwind")
|
||||
llalloc = self.llbuilder.append_basic_block(name="rpc.continue")
|
||||
lltail = self.llbuilder.append_basic_block(name="rpc.tail")
|
||||
llret = self._build_rpc_recv(fun_type.ret, llstackptr, llnormalblock, llunwindblock)
|
||||
|
||||
llretty = self.llty_of_type(fun_type.ret)
|
||||
llslot = self.llbuilder.alloca(llretty, name="rpc.ret.alloc")
|
||||
llslotgen = self.llbuilder.bitcast(llslot, llptr, name="rpc.ret.ptr")
|
||||
self.llbuilder.branch(llhead)
|
||||
|
||||
self.llbuilder.position_at_end(llhead)
|
||||
llphi = self.llbuilder.phi(llslotgen.type, name="rpc.ptr")
|
||||
llphi.add_incoming(llslotgen, llprehead)
|
||||
if llunwindblock:
|
||||
llsize = self.llbuilder.invoke(self.llbuiltin("rpc_recv"), [llphi],
|
||||
llheadu, llunwindblock,
|
||||
name="rpc.size.next")
|
||||
self.llbuilder.position_at_end(llheadu)
|
||||
else:
|
||||
llsize = self.llbuilder.call(self.llbuiltin("rpc_recv"), [llphi],
|
||||
name="rpc.size.next")
|
||||
lldone = self.llbuilder.icmp_unsigned('==', llsize, ll.Constant(llsize.type, 0),
|
||||
name="rpc.done")
|
||||
self.llbuilder.cbranch(lldone, lltail, llalloc)
|
||||
|
||||
self.llbuilder.position_at_end(llalloc)
|
||||
llalloca = self.llbuilder.alloca(lli8, llsize, name="rpc.alloc")
|
||||
llalloca.align = self.max_target_alignment
|
||||
llphi.add_incoming(llalloca, llalloc)
|
||||
self.llbuilder.branch(llhead)
|
||||
|
||||
self.llbuilder.position_at_end(lltail)
|
||||
llret = self.llbuilder.load(llslot, name="rpc.ret")
|
||||
if not fun_type.ret.fold(False, lambda r, t: r or builtins.is_allocated(t)):
|
||||
# We didn't allocate anything except the slot for the value itself.
|
||||
# Don't waste stack space.
|
||||
self.llbuilder.call(self.llbuiltin("llvm.stackrestore"), [llstackptr])
|
||||
if llnormalblock:
|
||||
self.llbuilder.branch(llnormalblock)
|
||||
return llret
|
||||
|
||||
def _build_subkernel_call(self, fun_loc, fun_type, args):
|
||||
llsid = ll.Constant(lli32, fun_type.sid)
|
||||
tag = b""
|
||||
|
||||
for arg in args:
|
||||
def arg_error_handler(typ):
|
||||
printer = types.TypePrinter()
|
||||
note = diagnostic.Diagnostic("note",
|
||||
"value of type {type}",
|
||||
{"type": printer.name(typ)},
|
||||
arg.loc)
|
||||
diag = diagnostic.Diagnostic("error",
|
||||
"type {type} is not supported in subkernel calls",
|
||||
{"type": printer.name(arg.type)},
|
||||
arg.loc, notes=[note])
|
||||
self.engine.process(diag)
|
||||
tag += ir.rpc_tag(arg.type, arg_error_handler)
|
||||
tag += b":"
|
||||
|
||||
# run the kernel first
|
||||
self.llbuilder.call(self.llbuiltin("subkernel_load_run"), [llsid, ll.Constant(lli1, 1)])
|
||||
|
||||
# arg sent in the same vein as RPC
|
||||
llstackptr = self.llbuilder.call(self.llbuiltin("llvm.stacksave"), [],
|
||||
name="subkernel.stack")
|
||||
|
||||
lltag = self.llconst_of_const(ir.Constant(tag, builtins.TStr()))
|
||||
lltagptr = self.llbuilder.alloca(lltag.type)
|
||||
self.llbuilder.store(lltag, lltagptr)
|
||||
|
||||
if args:
|
||||
# only send args if there's anything to send, 'self' is excluded
|
||||
llargs = self.llbuilder.alloca(llptr, ll.Constant(lli32, len(args)),
|
||||
name="subkernel.args")
|
||||
for index, arg in enumerate(args):
|
||||
if builtins.is_none(arg.type):
|
||||
llargslot = self.llbuilder.alloca(llunit,
|
||||
name="subkernel.arg{}".format(index))
|
||||
else:
|
||||
llarg = self.map(arg)
|
||||
llargslot = self.llbuilder.alloca(llarg.type,
|
||||
name="subkernel.arg{}".format(index))
|
||||
self.llbuilder.store(llarg, llargslot)
|
||||
llargslot = self.llbuilder.bitcast(llargslot, llptr)
|
||||
|
||||
llargptr = self.llbuilder.gep(llargs, [ll.Constant(lli32, index)])
|
||||
self.llbuilder.store(llargslot, llargptr)
|
||||
|
||||
llargcount = ll.Constant(lli8, len(args))
|
||||
|
||||
self.llbuilder.call(self.llbuiltin("subkernel_send_message"),
|
||||
[llsid, llargcount, lltagptr, llargs])
|
||||
self.llbuilder.call(self.llbuiltin("llvm.stackrestore"), [llstackptr])
|
||||
|
||||
return llsid
|
||||
|
||||
def _build_subkernel_return(self, insn):
|
||||
# builds a remote return.
|
||||
# unlike args, return only sends one thing.
|
||||
if builtins.is_none(insn.value().type):
|
||||
# do not waste time and bandwidth on Nones
|
||||
return
|
||||
|
||||
def ret_error_handler(typ):
|
||||
printer = types.TypePrinter()
|
||||
note = diagnostic.Diagnostic("note",
|
||||
"value of type {type}",
|
||||
{"type": printer.name(typ)},
|
||||
fun_loc)
|
||||
diag = diagnostic.Diagnostic("error",
|
||||
"return type {type} is not supported in subkernel returns",
|
||||
{"type": printer.name(fun_type.ret)},
|
||||
fun_loc, notes=[note])
|
||||
self.engine.process(diag)
|
||||
tag = ir.rpc_tag(insn.value().type, ret_error_handler)
|
||||
tag += b":"
|
||||
lltag = self.llconst_of_const(ir.Constant(tag, builtins.TStr()))
|
||||
lltagptr = self.llbuilder.alloca(lltag.type)
|
||||
self.llbuilder.store(lltag, lltagptr)
|
||||
|
||||
llrets = self.llbuilder.alloca(llptr, ll.Constant(lli32, 1),
|
||||
name="subkernel.return")
|
||||
llret = self.map(insn.value())
|
||||
llretslot = self.llbuilder.alloca(llret.type, name="subkernel.retval")
|
||||
self.llbuilder.store(llret, llretslot)
|
||||
llretslot = self.llbuilder.bitcast(llretslot, llptr)
|
||||
self.llbuilder.store(llretslot, llrets)
|
||||
|
||||
llsid = ll.Constant(lli32, 0) # return goes back to master, sid is ignored
|
||||
lltagcount = ll.Constant(lli8, 1) # only one thing is returned
|
||||
self.llbuilder.call(self.llbuiltin("subkernel_send_message"),
|
||||
[llsid, lltagcount, lltagptr, llrets])
|
||||
|
||||
def process_Call(self, insn):
|
||||
functiontyp = insn.target_function().type
|
||||
if types.is_rpc(functiontyp):
|
||||
|
@ -1559,6 +1736,10 @@ class LLVMIRGenerator:
|
|||
functiontyp,
|
||||
insn.arguments(),
|
||||
llnormalblock=None, llunwindblock=None)
|
||||
elif types.is_subkernel(functiontyp):
|
||||
return self._build_subkernel_call(insn.target_function().loc,
|
||||
functiontyp,
|
||||
insn.arguments())
|
||||
elif types.is_external_function(functiontyp):
|
||||
llfun, llargs, llarg_attrs, llcallstackptr = self._prepare_ffi_call(insn)
|
||||
else:
|
||||
|
@ -1595,6 +1776,11 @@ class LLVMIRGenerator:
|
|||
functiontyp,
|
||||
insn.arguments(),
|
||||
llnormalblock, llunwindblock)
|
||||
elif types.is_subkernel(functiontyp):
|
||||
return self._build_subkernel_call(insn.target_function().loc,
|
||||
functiontyp,
|
||||
insn.arguments(),
|
||||
llnormalblock, llunwindblock)
|
||||
elif types.is_external_function(functiontyp):
|
||||
llfun, llargs, llarg_attrs, llcallstackptr = self._prepare_ffi_call(insn)
|
||||
else:
|
||||
|
@ -1673,7 +1859,8 @@ class LLVMIRGenerator:
|
|||
attrvalue = getattr(value, attr)
|
||||
is_class_function = (types.is_constructor(typ) and
|
||||
types.is_function(typ.attributes[attr]) and
|
||||
not types.is_external_function(typ.attributes[attr]))
|
||||
not types.is_external_function(typ.attributes[attr]) and
|
||||
not types.is_subkernel(typ.attributes[attr]))
|
||||
if is_class_function:
|
||||
attrvalue = self.embedding_map.specialize_function(typ.instance, attrvalue)
|
||||
if not (types.is_instance(typ) and attr in typ.constant_attributes):
|
||||
|
@ -1758,7 +1945,8 @@ class LLVMIRGenerator:
|
|||
llelts = [self._quote(v, t, lambda: path() + [str(i)])
|
||||
for i, (v, t) in enumerate(zip(value, typ.elts))]
|
||||
return ll.Constant(llty, llelts)
|
||||
elif types.is_rpc(typ) or types.is_external_function(typ) or types.is_builtin_function(typ):
|
||||
elif types.is_rpc(typ) or types.is_external_function(typ) or \
|
||||
types.is_builtin_function(typ) or types.is_subkernel(typ):
|
||||
# RPC, C and builtin functions have no runtime representation.
|
||||
return ll.Constant(llty, ll.Undefined)
|
||||
elif types.is_function(typ):
|
||||
|
@ -1813,6 +2001,8 @@ class LLVMIRGenerator:
|
|||
return llinsn
|
||||
|
||||
def process_Return(self, insn):
|
||||
if insn.remote_return:
|
||||
self._build_subkernel_return(insn)
|
||||
if builtins.is_none(insn.value().type):
|
||||
return self.llbuilder.ret_void()
|
||||
else:
|
||||
|
|
|
@ -385,6 +385,50 @@ class TRPC(Type):
|
|||
def __hash__(self):
|
||||
return hash(self.service)
|
||||
|
||||
class TSubkernel(TFunction):
|
||||
"""
|
||||
A kernel to be run on a satellite.
|
||||
|
||||
:ivar args: (:class:`collections.OrderedDict` of string to :class:`Type`)
|
||||
function arguments
|
||||
:ivar ret: (:class:`Type`)
|
||||
return type
|
||||
:ivar sid: (int) subkernel ID number
|
||||
:ivar destination: (int) satellite destination number
|
||||
"""
|
||||
|
||||
attributes = OrderedDict()
|
||||
|
||||
def __init__(self, args, optargs, ret, sid, destination):
|
||||
assert isinstance(ret, Type)
|
||||
super().__init__(args, optargs, ret)
|
||||
self.sid, self.destination = sid, destination
|
||||
self.delay = TFixedDelay(iodelay.Const(0))
|
||||
|
||||
def unify(self, other):
|
||||
if other is self:
|
||||
return
|
||||
if isinstance(other, TSubkernel) and \
|
||||
self.sid == other.sid and \
|
||||
self.destination == other.destination:
|
||||
self.ret.unify(other.ret)
|
||||
elif isinstance(other, TVar):
|
||||
other.unify(self)
|
||||
else:
|
||||
raise UnificationError(self, other)
|
||||
|
||||
def __repr__(self):
|
||||
if getattr(builtins, "__in_sphinx__", False):
|
||||
return str(self)
|
||||
return "artiq.compiler.types.TSubkernel({})".format(repr(self.ret))
|
||||
|
||||
def __eq__(self, other):
|
||||
return isinstance(other, TSubkernel) and \
|
||||
self.sid == other.sid
|
||||
|
||||
def __hash__(self):
|
||||
return hash(self.sid)
|
||||
|
||||
class TBuiltin(Type):
|
||||
"""
|
||||
An instance of builtin type. Every instance of a builtin
|
||||
|
@ -644,6 +688,9 @@ def is_function(typ):
|
|||
def is_rpc(typ):
|
||||
return isinstance(typ.find(), TRPC)
|
||||
|
||||
def is_subkernel(typ):
|
||||
return isinstance(typ.find(), TSubkernel)
|
||||
|
||||
def is_external_function(typ, name=None):
|
||||
typ = typ.find()
|
||||
if name is None:
|
||||
|
@ -810,6 +857,10 @@ class TypePrinter(object):
|
|||
return "[rpc{} #{}](...)->{}".format(typ.service,
|
||||
" async" if typ.is_async else "",
|
||||
self.name(typ.ret, depth + 1))
|
||||
elif isinstance(typ, TSubkernel):
|
||||
return "<subkernel{} dest#{}>->{}".format(typ.sid,
|
||||
typ.destination,
|
||||
self.name(typ.ret, depth + 1))
|
||||
elif isinstance(typ, TBuiltinFunction):
|
||||
return "<function {}>".format(typ.name)
|
||||
elif isinstance(typ, (TConstructor, TExceptionConstructor)):
|
||||
|
|
|
@ -73,8 +73,8 @@ def main():
|
|||
finally:
|
||||
dataset_db.close_db()
|
||||
|
||||
if object_map.has_rpc():
|
||||
raise ValueError("Experiment must not use RPC")
|
||||
if object_map.has_rpc_or_subkernel():
|
||||
raise ValueError("Experiment must not use RPC or subkernels")
|
||||
|
||||
output = args.output
|
||||
if output is None:
|
||||
|
|
|
@ -7,7 +7,7 @@ from functools import wraps
|
|||
import numpy
|
||||
|
||||
|
||||
__all__ = ["kernel", "portable", "rpc", "syscall", "host_only",
|
||||
__all__ = ["kernel", "portable", "rpc", "subkernel", "syscall", "host_only",
|
||||
"kernel_from_string", "set_time_manager", "set_watchdog_factory",
|
||||
"TerminationRequested"]
|
||||
|
||||
|
@ -21,7 +21,7 @@ __all__.extend(kernel_globals)
|
|||
|
||||
|
||||
_ARTIQEmbeddedInfo = namedtuple("_ARTIQEmbeddedInfo",
|
||||
"core_name portable function syscall forbidden flags")
|
||||
"core_name portable function syscall forbidden destination flags")
|
||||
|
||||
def kernel(arg=None, flags={}):
|
||||
"""
|
||||
|
@ -54,7 +54,7 @@ def kernel(arg=None, flags={}):
|
|||
return getattr(self, arg).run(run_on_core, ((self,) + k_args), k_kwargs)
|
||||
run_on_core.artiq_embedded = _ARTIQEmbeddedInfo(
|
||||
core_name=arg, portable=False, function=function, syscall=None,
|
||||
forbidden=False, flags=set(flags))
|
||||
forbidden=False, destination=None, flags=set(flags))
|
||||
return run_on_core
|
||||
return inner_decorator
|
||||
elif arg is None:
|
||||
|
@ -64,6 +64,50 @@ def kernel(arg=None, flags={}):
|
|||
else:
|
||||
return kernel("core", flags)(arg)
|
||||
|
||||
def subkernel(arg=None, destination=0, flags={}):
|
||||
"""
|
||||
This decorator marks an object's method or function for execution on a satellite device.
|
||||
Destination must be given, and it must be between 1 and 255 (inclusive).
|
||||
|
||||
Subkernels behave similarly to kernels, with few key differences:
|
||||
|
||||
- they are started from main kernels,
|
||||
- they do not support RPCs, or running subsequent subkernels on other devices,
|
||||
- but they can call other kernels or subkernels with the same destination.
|
||||
|
||||
Subkernels can accept arguments and return values. However, they must be fully
|
||||
annotated with ARTIQ types.
|
||||
|
||||
To call a subkernel, call it like a normal function.
|
||||
|
||||
To await its finishing execution, call ``subkernel_await(subkernel, [timeout])``.
|
||||
The timeout parameter is optional, and by default is equal to 10000 (miliseconds).
|
||||
This time can be adjusted for subkernels that take a long time to execute.
|
||||
|
||||
The compiled subkernel is copied to satellites, but not yet to the kernel core
|
||||
until it's called. For bigger subkernels it may take some time before they
|
||||
actually start running. To help with that, subkernels can be preloaded, with
|
||||
``subkernel_preload(subkernel)`` function. A call to a preloaded subkernel
|
||||
will take less time, but only one subkernel can be preloaded at a time.
|
||||
"""
|
||||
if isinstance(arg, str):
|
||||
def inner_decorator(function):
|
||||
@wraps(function)
|
||||
def run_subkernel(self, *k_args, **k_kwargs):
|
||||
sid = getattr(self, arg).prepare_subkernel(destination, run_subkernel, ((self,) + k_args), k_kwargs)
|
||||
getattr(self, arg).run_subkernel(sid)
|
||||
run_subkernel.artiq_embedded = _ARTIQEmbeddedInfo(
|
||||
core_name=arg, portable=False, function=function, syscall=None,
|
||||
forbidden=False, destination=destination, flags=set(flags))
|
||||
return run_subkernel
|
||||
return inner_decorator
|
||||
elif arg is None:
|
||||
def inner_decorator(function):
|
||||
return subkernel(function, destination, flags)
|
||||
return inner_decorator
|
||||
else:
|
||||
return subkernel("core", destination, flags)(arg)
|
||||
|
||||
def portable(arg=None, flags={}):
|
||||
"""
|
||||
This decorator marks a function for execution on the same device as its
|
||||
|
@ -84,7 +128,7 @@ def portable(arg=None, flags={}):
|
|||
else:
|
||||
arg.artiq_embedded = \
|
||||
_ARTIQEmbeddedInfo(core_name=None, portable=True, function=arg, syscall=None,
|
||||
forbidden=False, flags=set(flags))
|
||||
forbidden=False, destination=None, flags=set(flags))
|
||||
return arg
|
||||
|
||||
def rpc(arg=None, flags={}):
|
||||
|
@ -100,7 +144,7 @@ def rpc(arg=None, flags={}):
|
|||
else:
|
||||
arg.artiq_embedded = \
|
||||
_ARTIQEmbeddedInfo(core_name=None, portable=False, function=arg, syscall=None,
|
||||
forbidden=False, flags=set(flags))
|
||||
forbidden=False, destination=None, flags=set(flags))
|
||||
return arg
|
||||
|
||||
def syscall(arg=None, flags={}):
|
||||
|
@ -118,7 +162,7 @@ def syscall(arg=None, flags={}):
|
|||
def inner_decorator(function):
|
||||
function.artiq_embedded = \
|
||||
_ARTIQEmbeddedInfo(core_name=None, portable=False, function=None,
|
||||
syscall=arg, forbidden=False,
|
||||
syscall=arg, forbidden=False, destination=None,
|
||||
flags=set(flags))
|
||||
return function
|
||||
return inner_decorator
|
||||
|
@ -136,7 +180,7 @@ def host_only(function):
|
|||
"""
|
||||
function.artiq_embedded = \
|
||||
_ARTIQEmbeddedInfo(core_name=None, portable=False, function=None, syscall=None,
|
||||
forbidden=True, flags={})
|
||||
forbidden=True, destination=None, flags={})
|
||||
return function
|
||||
|
||||
|
||||
|
|
Loading…
Reference in New Issue