1
0
forked from M-Labs/nac3

[artiq] add subkernel decorator, call subkernels

This commit is contained in:
mwojcik 2025-01-03 11:34:19 +08:00
parent 2f0847d77b
commit fa22ac80ae
5 changed files with 321 additions and 4 deletions

View File

@ -5,12 +5,17 @@ class EmbeddingMap:
self.string_map = {} self.string_map = {}
self.string_reverse_map = {} self.string_reverse_map = {}
self.function_map = {} self.function_map = {}
self.subkernel_map = {}
self.attributes_writeback = [] self.attributes_writeback = []
def store_function(self, key, fun): def store_function(self, key, fun):
self.function_map[key] = fun self.function_map[key] = fun
return key return key
# Record `fun` in `subkernel_map` under `key` and hand `key` back to the caller.
# Mirrors `store_function`, but writes to the subkernel table instead.
def store_subkernel(self, key, fun):
self.subkernel_map[key] = fun
return key
def store_object(self, obj): def store_object(self, obj):
obj_id = id(obj) obj_id = id(obj)
if obj_id in self.object_inverse_map: if obj_id in self.object_inverse_map:
@ -37,3 +42,5 @@ class EmbeddingMap:
def retrieve_str(self, key): def retrieve_str(self, key):
return self.string_map[key] return self.string_map[key]
# Return the dict filled by `store_subkernel` (the same object, not a copy).
def subkernels(self):
return self.subkernel_map

View File

@ -13,7 +13,7 @@ __all__ = [
"Kernel", "KernelInvariant", "virtual", "ConstGeneric", "Kernel", "KernelInvariant", "virtual", "ConstGeneric",
"Option", "Some", "none", "UnwrapNoneError", "Option", "Some", "none", "UnwrapNoneError",
"round64", "floor64", "ceil64", "round64", "floor64", "ceil64",
"extern", "kernel", "portable", "nac3", "extern", "kernel", "subkernel", "portable", "nac3",
"rpc", "ms", "us", "ns", "rpc", "ms", "us", "ns",
"print_int32", "print_int64", "print_int32", "print_int64",
"Core", "TTLOut", "Core", "TTLOut",
@ -137,6 +137,14 @@ def kernel(function_or_method):
raise RuntimeError("Kernel functions need explicit core.run()") raise RuntimeError("Kernel functions need explicit core.run()")
return run_on_core return run_on_core
def subkernel(function_or_method=None, destination=None):
    """Decorates a function or method to be executed as a subkernel on another destination.

    Supports both call styles:

    * ``subkernel(fn, destination)`` -- direct call, returns the wrapper.
    * ``@subkernel(destination=N)`` -- decorator factory; this is the form the
      compiler inspects when it looks for the ``destination`` keyword.

    The function is registered with the compiler via ``register_function``.
    The returned wrapper raises ``RuntimeError`` when invoked, because
    subkernels cannot be called by the host. The destination is kept on the
    wrapper as ``_destination``.
    """
    if function_or_method is None:
        # Used as ``@subkernel(destination=N)``: defer until the function arrives.
        def inner_decorator(function):
            return subkernel(function, destination)
        return inner_decorator

    assert 0 < destination < 255
    register_function(function_or_method)

    @wraps(function_or_method)
    def run_on_core(*args, **kwargs):
        raise RuntimeError("Subkernels cannot be called by the host")

    run_on_core._destination = destination
    # Return the wrapper; a bare ``return`` would bind the decorated name to None.
    return run_on_core
def portable(function): def portable(function):
"""Decorates a function or method to be executed on the same device (host/core device) as the caller.""" """Decorates a function or method to be executed on the same device (host/core device) as the caller."""

View File

@ -962,6 +962,189 @@ fn rpc_codegen_callback_fn<'ctx>(
} }
} }
/// Emits the LLVM IR for calling a subkernel from kernel code.
///
/// The generated code:
/// 1. calls the runtime function `subkernel_load_run(id, destination, true)`;
/// 2. if the call has any arguments, stores a pointer to each formatted argument in a
///    stack-allocated array and calls `subkernel_send_message` with the argument count, the
///    RPC-style type tag and that array;
/// 3. restores the stack pointer saved before the arguments were materialised.
///
/// * `obj` - the receiver for method calls; only a [`ValueEnum::Static`] (host) object is accepted.
/// * `fun` - signature and definition ID of the callee; the definition ID is the subkernel ID.
/// * `args` - call arguments, positional (`None` key) or keyword (`Some(name)`).
/// * `destination` - destination number passed to the runtime as an `i8`.
///
/// Always returns `Ok(None)`: calling a subkernel produces no value at the call site.
///
/// # Errors
///
/// Returns an error if tag generation or argument lowering fails, if `obj` is not a host object,
/// or if there are more arguments than fit in the `i8` count parameter of
/// `subkernel_send_message`.
fn subkernel_call_codegen_callback_fn<'ctx>(
    ctx: &mut CodeGenContext<'ctx, '_>,
    obj: Option<(Type, ValueEnum<'ctx>)>,
    fun: (&FunSignature, DefinitionId),
    args: Vec<(Option<StrRef>, ValueEnum<'ctx>)>,
    generator: &mut dyn CodeGenerator,
    destination: u8,
) -> Result<Option<BasicValueEnum<'ctx>>, String> {
    let int8 = ctx.ctx.i8_type();
    let int32 = ctx.ctx.i32_type();
    let bool_type = ctx.ctx.bool_type();
    let size_type = generator.get_size_type(ctx.ctx);
    let ptr_type = int8.ptr_type(AddressSpace::default());
    let tag_ptr_type = ctx.ctx.struct_type(&[ptr_type.into(), size_type.into()], false);

    let subkernel_id = int32.const_int(fun.1 .0 as u64, false);
    let destination = int8.const_int(u64::from(destination), false);

    // -- setup rpc tags
    let mut tag = Vec::new();
    if obj.is_some() {
        tag.push(b'O');
    }
    for arg in &fun.0.args {
        gen_rpc_tag(ctx, arg.ty, &mut tag)?;
    }
    tag.push(b':');
    gen_rpc_tag(ctx, fun.0.ret, &mut tag)?;

    // The tag descriptor global is named after the hash of the tag bytes, so identical tags
    // share one descriptor within the module.
    let mut hasher = DefaultHasher::new();
    tag.hash(&mut hasher);
    let hash = hasher.finish().to_string();

    let tag_ptr = ctx
        .module
        .get_global(hash.as_str())
        .unwrap_or_else(|| {
            let tag_arr_ptr = ctx.module.add_global(
                int8.array_type(tag.len() as u32),
                None,
                format!("tagptr{}", fun.1 .0).as_str(),
            );
            tag_arr_ptr.set_initializer(&int8.const_array(
                &tag.iter().map(|v| int8.const_int(u64::from(*v), false)).collect::<Vec<_>>(),
            ));
            tag_arr_ptr.set_linkage(Linkage::Private);

            let tag_ptr = ctx.module.add_global(tag_ptr_type, None, &hash);
            tag_ptr.set_linkage(Linkage::Private);
            tag_ptr.set_initializer(&ctx.ctx.const_struct(
                &[
                    tag_arr_ptr.as_pointer_value().const_cast(ptr_type).into(),
                    size_type.const_int(tag.len() as u64, false).into(),
                ],
                false,
            ));
            tag_ptr
        })
        .as_pointer_value();

    // Everything allocated on the stack from here on is reclaimed by the stackrestore below.
    let stackptr = call_stacksave(ctx, Some("subkernel.stack"));

    // -- subkernel args handling
    let mut keys = fun.0.args.clone();
    let mut mapping = HashMap::new();
    for (key, value) in args {
        mapping.insert(key.unwrap_or_else(|| keys.remove(0).name), value);
    }

    // default value handling
    // in old compiler, subkernels would generate default values, and they would not be sent
    // TODO: see if it makes sense
    for k in keys {
        // Keyword arguments are not removed from `keys` above; skip them so the supplied value
        // is not replaced by the default (or `unwrap` hit on a parameter without one).
        if mapping.contains_key(&k.name) {
            continue;
        }
        mapping
            .insert(k.name, ctx.gen_symbol_val(generator, &k.default_value.unwrap(), k.ty).into());
    }

    // 'self' is skipped for subkernels
    let no_self: Vec<_> = fun.0.args.iter().filter(|arg| arg.name != "self".into()).collect();

    // reorder the parameters
    let mut real_params = no_self
        .iter()
        .map(|arg| {
            mapping
                .remove(&arg.name)
                .unwrap()
                .to_basic_value_enum(ctx, generator, arg.ty)
                .map(|llvm_val| (llvm_val, arg.ty))
        })
        .collect::<Result<Vec<(_, _)>, _>>()?;

    if let Some((obj_ty, obj_val)) = obj {
        let ValueEnum::Static(obj_val) = obj_val else {
            return Err("only host object is allowed".to_string());
        };
        real_params.insert(0, (obj_val.get_const_obj(ctx, generator), obj_ty));
    }

    // `subkernel_send_message` takes the argument count as an i8.
    let arg_count = u8::try_from(real_params.len()).map_err(|_| {
        format!("too many arguments in subkernel call ({})", real_params.len())
    })?;

    // Size the pointer array from the parameters actually sent: this includes default-filled
    // parameters, which the number of supplied arguments does not account for.
    let args_ptr = ctx
        .builder
        .build_array_alloca(
            ptr_type,
            int32.const_int(real_params.len() as u64, false),
            "argptr",
        )
        .unwrap();

    for (i, (arg, arg_ty)) in real_params.iter().enumerate() {
        let arg_slot = format_rpc_arg(generator, ctx, (*arg, *arg_ty, i));
        // SAFETY: `i < real_params.len()`, which is the element count of the `args_ptr` array.
        let arg_ptr = unsafe {
            ctx.builder.build_gep(
                args_ptr,
                &[int32.const_int(i as u64, false)],
                &format!("subkernel.arg{i}"),
            )
        }
        .unwrap();
        ctx.builder.build_store(arg_ptr, arg_slot).unwrap();
    }

    // call subkernel first
    let subkernel_call = ctx.module.get_function("subkernel_load_run").unwrap_or_else(|| {
        ctx.module.add_function(
            "subkernel_load_run",
            ctx.ctx.void_type().fn_type(&[int32.into(), int8.into(), bool_type.into()], false),
            None,
        )
    });
    ctx.builder
        .build_call(
            subkernel_call,
            &[subkernel_id.into(), destination.into(), bool_type.const_all_ones().into()],
            "subkernel.call",
        )
        .unwrap();

    if !real_params.is_empty() {
        // send the arguments
        let subkernel_send =
            ctx.module.get_function("subkernel_send_message").unwrap_or_else(|| {
                ctx.module.add_function(
                    "subkernel_send_message",
                    ctx.ctx.void_type().fn_type(
                        &[
                            int32.into(),
                            bool_type.into(),
                            int8.into(),
                            int8.into(),
                            tag_ptr_type.ptr_type(AddressSpace::default()).into(),
                            ptr_type.ptr_type(AddressSpace::default()).into(),
                        ],
                        false,
                    ),
                    None,
                )
            });
        ctx.builder
            .build_call(
                subkernel_send,
                &[
                    subkernel_id.into(),
                    // NOTE(review): presumably the "is return value" flag -- confirm against the
                    // runtime's `subkernel_send_message`.
                    bool_type.const_zero().into(),
                    destination.into(),
                    // Must be an i8 to match the declared parameter type above.
                    int8.const_int(u64::from(arg_count), false).into(),
                    tag_ptr.into(),
                    args_ptr.into(),
                ],
                "subkernel.send",
            )
            .unwrap();

        // reclaim stack space used by arguments
        call_stackrestore(ctx, stackptr);
    }

    // calling a subkernel returns nothing
    Ok(None)
}
pub fn attributes_writeback<'ctx>( pub fn attributes_writeback<'ctx>(
ctx: &mut CodeGenContext<'ctx, '_>, ctx: &mut CodeGenContext<'ctx, '_>,
generator: &mut dyn CodeGenerator, generator: &mut dyn CodeGenerator,
@ -1076,6 +1259,12 @@ pub fn rpc_codegen_callback(is_async: bool) -> Arc<GenCall> {
}))) })))
} }
/// Builds the codegen callback installed on functions decorated with `@subkernel`.
///
/// The returned [`GenCall`] captures `destination` and forwards every call to
/// `subkernel_call_codegen_callback_fn` with it.
pub fn subkernel_call_codegen_callback(destination: u8) -> Arc<GenCall> {
Arc::new(GenCall::new(Box::new(move |ctx, obj, fun, args, generator| {
subkernel_call_codegen_callback_fn(ctx, obj, fun, args, generator, destination)
})))
}
/// Returns the `fprintf` format constant for the given [`llvm_int_t`][`IntType`] on a platform with /// Returns the `fprintf` format constant for the given [`llvm_int_t`][`IntType`] on a platform with
/// [`llvm_usize`] as its native word size. /// [`llvm_usize`] as its native word size.
/// ///

View File

@ -60,7 +60,8 @@ use nac3core::{
use nac3ld::Linker; use nac3ld::Linker;
use codegen::{ use codegen::{
attributes_writeback, gen_core_log, gen_rtio_log, rpc_codegen_callback, ArtiqCodeGenerator, attributes_writeback, gen_core_log, gen_rtio_log, rpc_codegen_callback,
subkernel_call_codegen_callback, ArtiqCodeGenerator,
}; };
use symbol_resolver::{DeferredEvaluationStore, InnerResolver, PythonHelper, Resolver}; use symbol_resolver::{DeferredEvaluationStore, InnerResolver, PythonHelper, Resolver};
use timeline::TimeFns; use timeline::TimeFns;
@ -208,7 +209,10 @@ impl Nac3 {
if let StmtKind::FunctionDef { ref decorator_list, .. } = stmt.node { if let StmtKind::FunctionDef { ref decorator_list, .. } = stmt.node {
decorator_list.iter().any(|decorator| { decorator_list.iter().any(|decorator| {
if let Some(id) = decorator_id_string(decorator) { if let Some(id) = decorator_id_string(decorator) {
id == "kernel" || id == "portable" || id == "rpc" id == "kernel"
|| id == "portable"
|| id == "rpc"
|| id == "subkernel"
} else { } else {
false false
} }
@ -222,7 +226,11 @@ impl Nac3 {
StmtKind::FunctionDef { ref decorator_list, .. } => { StmtKind::FunctionDef { ref decorator_list, .. } => {
decorator_list.iter().any(|decorator| { decorator_list.iter().any(|decorator| {
if let Some(id) = decorator_id_string(decorator) { if let Some(id) = decorator_id_string(decorator) {
id == "extern" || id == "kernel" || id == "portable" || id == "rpc" id == "extern"
|| id == "kernel"
|| id == "portable"
|| id == "rpc"
|| id == "subkernel"
} else { } else {
false false
} }
@ -394,6 +402,7 @@ impl Nac3 {
let store_obj = embedding_map.getattr("store_object").unwrap().to_object(py); let store_obj = embedding_map.getattr("store_object").unwrap().to_object(py);
let store_str = embedding_map.getattr("store_str").unwrap().to_object(py); let store_str = embedding_map.getattr("store_str").unwrap().to_object(py);
let store_fun = embedding_map.getattr("store_function").unwrap().to_object(py); let store_fun = embedding_map.getattr("store_function").unwrap().to_object(py);
let store_subk = embedding_map.getattr("store_subkernel").unwrap().to_object(py);
let host_attributes = embedding_map.getattr("attributes_writeback").unwrap().to_object(py); let host_attributes = embedding_map.getattr("attributes_writeback").unwrap().to_object(py);
let global_value_ids: Arc<RwLock<HashMap<_, _>>> = Arc::new(RwLock::new(HashMap::new())); let global_value_ids: Arc<RwLock<HashMap<_, _>>> = Arc::new(RwLock::new(HashMap::new()));
let helper = PythonHelper { let helper = PythonHelper {
@ -424,6 +433,7 @@ impl Nac3 {
let mut module_to_resolver_cache: HashMap<u64, _> = HashMap::new(); let mut module_to_resolver_cache: HashMap<u64, _> = HashMap::new();
let mut rpc_ids = vec![]; let mut rpc_ids = vec![];
let mut subkernel_ids = vec![];
for (stmt, path, module) in &self.top_levels { for (stmt, path, module) in &self.top_levels {
let py_module: &PyAny = module.extract(py)?; let py_module: &PyAny = module.extract(py)?;
let module_id: u64 = id_fn.call1((py_module,))?.extract()?; let module_id: u64 = id_fn.call1((py_module,))?.extract()?;
@ -507,6 +517,22 @@ impl Nac3 {
.any(|constant| *constant == Constant::Str("async".into())) .any(|constant| *constant == Constant::Str("async".into()))
}); });
rpc_ids.push((None, def_id, is_async)); rpc_ids.push((None, def_id, is_async));
} else if decorator_list.iter().any(|decorator| {
decorator_id_string(decorator) == Some("subkernel".to_string())
}) {
if let Some(Constant::Int(dest)) = decorator_get_destination(decorator_list)
{
store_subk
.call1(
py,
(
def_id.0.into_py(py),
module.getattr(py, name.to_string().as_str()).unwrap(),
),
)
.unwrap();
subkernel_ids.push((None, def_id, dest));
}
} }
} }
StmtKind::ClassDef { name, body, .. } => { StmtKind::ClassDef { name, body, .. } => {
@ -529,6 +555,24 @@ impl Nac3 {
))); )));
} }
rpc_ids.push((Some((class_obj.clone(), *name)), def_id, is_async)); rpc_ids.push((Some((class_obj.clone(), *name)), def_id, is_async));
} else if decorator_list.iter().any(|decorator| {
decorator_id_string(decorator) == Some("subkernel".to_string())
}) {
if name == &"__init__".into() {
return Err(CompileError::new_err(format!(
"compilation failed\n----------\nThe constructor of class {} should not be decorated with subkernel decorator (at {})",
class_name, stmt.location
)));
}
if let Some(Constant::Int(dest)) =
decorator_get_destination(decorator_list)
{
subkernel_ids.push((
Some((class_obj.clone(), *name)),
def_id,
dest,
));
}
} }
} }
} }
@ -667,6 +711,45 @@ impl Nac3 {
} }
} }
} }
for (class_data, id, destination) in &subkernel_ids {
let mut def = defs[id.0].write();
match &mut *def {
TopLevelDef::Function { codegen_callback, .. } => {
*codegen_callback =
Some(subkernel_call_codegen_callback(*destination as u8));
}
TopLevelDef::Class { methods, .. } => {
let (class_def, method_name) = class_data.as_ref().unwrap();
for (name, _, id) in &*methods {
if name != method_name {
continue;
}
if let TopLevelDef::Function { codegen_callback, .. } =
&mut *defs[id.0].write()
{
*codegen_callback =
Some(subkernel_call_codegen_callback(*destination as u8));
store_fun
.call1(
py,
(
id.0.into_py(py),
class_def
.getattr(py, name.to_string().as_str())
.unwrap(),
),
)
.unwrap();
}
}
}
TopLevelDef::Variable { .. } => {
return Err(CompileError::new_err(String::from(
"Unsupported @subkernel annotation on global variable",
)))
}
}
}
} }
let instance = { let instance = {
@ -923,6 +1006,23 @@ fn decorator_get_flags(decorator: &Located<ExprKind>) -> Vec<Constant> {
flags flags
} }
/// Retrieves the destination from a subkernel decorator.
///
/// Scans `decorator_list` in order and returns the value of the first `destination=` keyword
/// whose value is a constant expression, looking only at decorators written as calls
/// (e.g. `@subkernel(destination=1)`). Returns `None` if no such keyword exists.
///
/// NOTE(review): the callee name is not checked here, so a `destination=` keyword on any call
/// decorator in the list is accepted -- callers are expected to have matched `subkernel` first.
fn decorator_get_destination(decorator_list: &[Located<ExprKind>]) -> Option<Constant> {
    decorator_list.iter().find_map(|decorator| {
        // Bare decorators (`@subkernel`) carry no keywords.
        let ExprKind::Call { keywords, .. } = &decorator.node else {
            return None;
        };
        keywords.iter().find_map(|keyword| {
            if keyword.node.arg != Some("destination".into()) {
                return None;
            }
            match &keyword.node.value.node {
                ExprKind::Constant { value, .. } => Some(value.clone()),
                // Non-constant destination expressions are ignored; keep searching.
                _ => None,
            }
        })
    })
}
fn link_with_lld(elf_filename: String, obj_filename: String) -> PyResult<()> { fn link_with_lld(elf_filename: String, obj_filename: String) -> PyResult<()> {
let linker_args = vec![ let linker_args = vec![
"-shared".to_string(), "-shared".to_string(),

View File

@ -1832,6 +1832,19 @@ impl TopLevelComposer {
continue; continue;
} }
} }
if let ExprKind::Call { func, .. } = &decorator_list[0].node {
if matches!(&func.node, ExprKind::Name { id, .. } if id == &"subkernel".into())
{
let TopLevelDef::Function { instance_to_symbol, .. } =
&mut *def.write()
else {
unreachable!()
};
instance_to_symbol.insert(String::new(), simple_name.to_string());
continue;
}
}
} }
let fun_body = let fun_body =