def subkernel(function_or_method=None, destination=None):
    """Decorates a function or method to be executed as a subkernel on `destination`.

    Both call styles are accepted::

        @subkernel(destination=1)
        def f(): ...

        f = subkernel(f, 1)

    The function is registered with `register_function`, and the returned
    wrapper carries the destination in its `_destination` attribute. Calling
    the wrapper on the host always raises `RuntimeError`.

    Raises:
        TypeError: if `destination` is not given.
        ValueError: if `destination` is not strictly between 0 and 255.
    """
    if destination is None:
        raise TypeError("subkernel() missing required argument: 'destination'")
    # Explicit check rather than `assert`, which is stripped under `python -O`.
    if not 0 < destination < 255:
        raise ValueError(
            f"subkernel destination must satisfy 0 < destination < 255, got {destination!r}"
        )

    def decorate(function):
        register_function(function)

        @wraps(function)
        def run_on_core(*args, **kwargs):
            raise RuntimeError("Subkernels cannot be called by the host")

        run_on_core._destination = destination
        # The wrapper must be returned, otherwise the decorated name is
        # rebound to None.
        return run_on_core

    # Keyword-only use (`@subkernel(destination=N)`) yields the decorator;
    # direct use (`subkernel(fn, N)`) applies it immediately.
    if function_or_method is None:
        return decorate
    return decorate(function_or_method)
-962,6 +962,189 @@ fn rpc_codegen_callback_fn<'ctx>( } } +fn subkernel_call_codegen_callback_fn<'ctx>( + ctx: &mut CodeGenContext<'ctx, '_>, + obj: Option<(Type, ValueEnum<'ctx>)>, + fun: (&FunSignature, DefinitionId), + args: Vec<(Option, ValueEnum<'ctx>)>, + generator: &mut dyn CodeGenerator, + destination: u8, +) -> Result>, String> { + let int8 = ctx.ctx.i8_type(); + let int32 = ctx.ctx.i32_type(); + let bool_type = ctx.ctx.bool_type(); + let size_type = generator.get_size_type(ctx.ctx); + let ptr_type = int8.ptr_type(AddressSpace::default()); + let tag_ptr_type = ctx.ctx.struct_type(&[ptr_type.into(), size_type.into()], false); + + let subkernel_id = int32.const_int(fun.1 .0 as u64, false); + let destination = int8.const_int(destination as u64, false); + // -- setup rpc tags + let mut tag = Vec::new(); + if obj.is_some() { + tag.push(b'O'); + } + for arg in &fun.0.args { + gen_rpc_tag(ctx, arg.ty, &mut tag)?; + } + tag.push(b':'); + gen_rpc_tag(ctx, fun.0.ret, &mut tag)?; + + let mut hasher = DefaultHasher::new(); + tag.hash(&mut hasher); + let hash = format!("{}", hasher.finish()); + + let tag_ptr = ctx + .module + .get_global(hash.as_str()) + .unwrap_or_else(|| { + let tag_arr_ptr = ctx.module.add_global( + int8.array_type(tag.len() as u32), + None, + format!("tagptr{}", fun.1 .0).as_str(), + ); + tag_arr_ptr.set_initializer(&int8.const_array( + &tag.iter().map(|v| int8.const_int(u64::from(*v), false)).collect::>(), + )); + tag_arr_ptr.set_linkage(Linkage::Private); + let tag_ptr = ctx.module.add_global(tag_ptr_type, None, &hash); + tag_ptr.set_linkage(Linkage::Private); + tag_ptr.set_initializer(&ctx.ctx.const_struct( + &[ + tag_arr_ptr.as_pointer_value().const_cast(ptr_type).into(), + size_type.const_int(tag.len() as u64, false).into(), + ], + false, + )); + tag_ptr + }) + .as_pointer_value(); + + let arg_length = args.len() + usize::from(obj.is_some()); + + let stackptr = call_stacksave(ctx, Some("subkernel.stack")); + let args_ptr = ctx + .builder + 
.build_array_alloca( + ptr_type, + ctx.ctx.i32_type().const_int(arg_length as u64, false), + "argptr", + ) + .unwrap(); + + // -- subkernel args handling + let mut keys = fun.0.args.clone(); + let mut mapping = HashMap::new(); + for (key, value) in args { + mapping.insert(key.unwrap_or_else(|| keys.remove(0).name), value); + } + + // default value handling + // in old compiler, subkernels would generate default values, and they would not be sent + // TODO: see if it makes sense + for k in keys { + mapping + .insert(k.name, ctx.gen_symbol_val(generator, &k.default_value.unwrap(), k.ty).into()); + } + + // 'self' is skipped for subkernels + let no_self: Vec<_> = fun.0.args.iter().filter(|arg| arg.name != "self".into()).collect(); + // reorder the parameters + let mut real_params = no_self + .iter() + .map(|arg| { + mapping + .remove(&arg.name) + .unwrap() + .to_basic_value_enum(ctx, generator, arg.ty) + .map(|llvm_val| (llvm_val, arg.ty)) + }) + .collect::, _>>()?; + + if let Some(obj) = obj { + if let ValueEnum::Static(obj_val) = obj.1 { + real_params.insert(0, (obj_val.get_const_obj(ctx, generator), obj.0)); + } else { + // should be an error here... 
+ panic!("only host object is allowed"); + } + } + + for (i, (arg, arg_ty)) in real_params.iter().enumerate() { + let arg_slot = format_rpc_arg(generator, ctx, (*arg, *arg_ty, i)); + let arg_ptr = unsafe { + ctx.builder.build_gep( + args_ptr, + &[int32.const_int(i as u64, false)], + &format!("subkernel.arg{i}"), + ) + } + .unwrap(); + ctx.builder.build_store(arg_ptr, arg_slot).unwrap(); + } + + // call subkernel first + let subkernel_call = ctx.module.get_function("subkernel_load_run").unwrap_or_else(|| { + ctx.module.add_function( + "subkernel_load_run", + ctx.ctx.void_type().fn_type(&[int32.into(), int8.into(), bool_type.into()], false), + None, + ) + }); + ctx.builder + .build_call( + subkernel_call, + &[subkernel_id.into(), destination.into(), bool_type.const_all_ones().into()], + "subkernel.call", + ) + .unwrap(); + + // send the arguments (if any) + if real_params.len() > 0 { + let subkernel_send = + ctx.module.get_function("subkernel_send_message").unwrap_or_else(|| { + ctx.module.add_function( + "subkernel_send_message", + ctx.ctx.void_type().fn_type( + &[ + int32.into(), + bool_type.into(), + int8.into(), + int8.into(), + tag_ptr_type.ptr_type(AddressSpace::default()).into(), + ptr_type.ptr_type(AddressSpace::default()).into(), + ], + false, + ), + None, + ) + }); + + ctx.builder + .build_call( + subkernel_send, + &[ + subkernel_id.into(), + bool_type.const_zero().into(), + destination.into(), + int32.const_int(real_params.len() as u64, false).into(), + tag_ptr.into(), + args_ptr.into(), + ], + "subkernel.send", + ) + .unwrap(); + } + + if real_params.len() > 0 { + // reclaim stack space used by arguments + call_stackrestore(ctx, stackptr); + } + + // calling a subkernel returns nothing + Ok(None) +} + pub fn attributes_writeback<'ctx>( ctx: &mut CodeGenContext<'ctx, '_>, generator: &mut dyn CodeGenerator, @@ -1076,6 +1259,12 @@ pub fn rpc_codegen_callback(is_async: bool) -> Arc { }))) } +pub fn subkernel_call_codegen_callback(destination: u8) -> Arc { + 
Arc::new(GenCall::new(Box::new(move |ctx, obj, fun, args, generator| { + subkernel_call_codegen_callback_fn(ctx, obj, fun, args, generator, destination) + }))) +} + /// Returns the `fprintf` format constant for the given [`llvm_int_t`][`IntType`] on a platform with /// [`llvm_usize`] as its native word size. /// diff --git a/nac3artiq/src/lib.rs b/nac3artiq/src/lib.rs index ca2f2f1..a6b66c4 100644 --- a/nac3artiq/src/lib.rs +++ b/nac3artiq/src/lib.rs @@ -60,7 +60,8 @@ use nac3core::{ use nac3ld::Linker; use codegen::{ - attributes_writeback, gen_core_log, gen_rtio_log, rpc_codegen_callback, ArtiqCodeGenerator, + attributes_writeback, gen_core_log, gen_rtio_log, rpc_codegen_callback, + subkernel_call_codegen_callback, ArtiqCodeGenerator, }; use symbol_resolver::{DeferredEvaluationStore, InnerResolver, PythonHelper, Resolver}; use timeline::TimeFns; @@ -208,7 +209,10 @@ impl Nac3 { if let StmtKind::FunctionDef { ref decorator_list, .. } = stmt.node { decorator_list.iter().any(|decorator| { if let Some(id) = decorator_id_string(decorator) { - id == "kernel" || id == "portable" || id == "rpc" + id == "kernel" + || id == "portable" + || id == "rpc" + || id == "subkernel" } else { false } @@ -222,7 +226,11 @@ impl Nac3 { StmtKind::FunctionDef { ref decorator_list, .. 
} => { decorator_list.iter().any(|decorator| { if let Some(id) = decorator_id_string(decorator) { - id == "extern" || id == "kernel" || id == "portable" || id == "rpc" + id == "extern" + || id == "kernel" + || id == "portable" + || id == "rpc" + || id == "subkernel" } else { false } @@ -394,6 +402,7 @@ impl Nac3 { let store_obj = embedding_map.getattr("store_object").unwrap().to_object(py); let store_str = embedding_map.getattr("store_str").unwrap().to_object(py); let store_fun = embedding_map.getattr("store_function").unwrap().to_object(py); + let store_subk = embedding_map.getattr("store_subkernel").unwrap().to_object(py); let host_attributes = embedding_map.getattr("attributes_writeback").unwrap().to_object(py); let global_value_ids: Arc>> = Arc::new(RwLock::new(HashMap::new())); let helper = PythonHelper { @@ -424,6 +433,7 @@ impl Nac3 { let mut module_to_resolver_cache: HashMap = HashMap::new(); let mut rpc_ids = vec![]; + let mut subkernel_ids = vec![]; for (stmt, path, module) in &self.top_levels { let py_module: &PyAny = module.extract(py)?; let module_id: u64 = id_fn.call1((py_module,))?.extract()?; @@ -507,6 +517,22 @@ impl Nac3 { .any(|constant| *constant == Constant::Str("async".into())) }); rpc_ids.push((None, def_id, is_async)); + } else if decorator_list.iter().any(|decorator| { + decorator_id_string(decorator) == Some("subkernel".to_string()) + }) { + if let Some(Constant::Int(dest)) = decorator_get_destination(decorator_list) + { + store_subk + .call1( + py, + ( + def_id.0.into_py(py), + module.getattr(py, name.to_string().as_str()).unwrap(), + ), + ) + .unwrap(); + subkernel_ids.push((None, def_id, dest)); + } } } StmtKind::ClassDef { name, body, .. 
} => { @@ -529,6 +555,24 @@ impl Nac3 { ))); } rpc_ids.push((Some((class_obj.clone(), *name)), def_id, is_async)); + } else if decorator_list.iter().any(|decorator| { + decorator_id_string(decorator) == Some("subkernel".to_string()) + }) { + if name == &"__init__".into() { + return Err(CompileError::new_err(format!( + "compilation failed\n----------\nThe constructor of class {} should not be decorated with subkernel decorator (at {})", + class_name, stmt.location + ))); + } + if let Some(Constant::Int(dest)) = + decorator_get_destination(decorator_list) + { + subkernel_ids.push(( + Some((class_obj.clone(), *name)), + def_id, + dest, + )); + } } } } @@ -667,6 +711,45 @@ impl Nac3 { } } } + for (class_data, id, destination) in &subkernel_ids { + let mut def = defs[id.0].write(); + match &mut *def { + TopLevelDef::Function { codegen_callback, .. } => { + *codegen_callback = + Some(subkernel_call_codegen_callback(*destination as u8)); + } + TopLevelDef::Class { methods, .. } => { + let (class_def, method_name) = class_data.as_ref().unwrap(); + for (name, _, id) in &*methods { + if name != method_name { + continue; + } + if let TopLevelDef::Function { codegen_callback, .. } = + &mut *defs[id.0].write() + { + *codegen_callback = + Some(subkernel_call_codegen_callback(*destination as u8)); + store_fun + .call1( + py, + ( + id.0.into_py(py), + class_def + .getattr(py, name.to_string().as_str()) + .unwrap(), + ), + ) + .unwrap(); + } + } + } + TopLevelDef::Variable { .. } => { + return Err(CompileError::new_err(String::from( + "Unsupported @subkernel annotation on global variable", + ))) + } + } + } } let instance = { @@ -923,6 +1006,23 @@ fn decorator_get_flags(decorator: &Located) -> Vec { flags } +/// Retrieves destination from subkernel decorator. +fn decorator_get_destination(decorator_list: &Vec>) -> Option { + for decorator in decorator_list { + if let ExprKind::Call { keywords, .. 
} = &decorator.node { + for keyword in keywords { + if keyword.node.arg != Some("destination".into()) { + continue; + } + if let ExprKind::Constant { value, .. } = &keyword.node.value.node { + return Some(value.clone()); + } + } + } + } + None +} + fn link_with_lld(elf_filename: String, obj_filename: String) -> PyResult<()> { let linker_args = vec![ "-shared".to_string(), diff --git a/nac3core/src/toplevel/composer.rs b/nac3core/src/toplevel/composer.rs index bd9a921..2ad9a32 100644 --- a/nac3core/src/toplevel/composer.rs +++ b/nac3core/src/toplevel/composer.rs @@ -1832,6 +1832,19 @@ impl TopLevelComposer { continue; } } + + if let ExprKind::Call { func, .. } = &decorator_list[0].node { + if matches!(&func.node, ExprKind::Name { id, .. } if id == &"subkernel".into()) + { + let TopLevelDef::Function { instance_to_symbol, .. } = + &mut *def.write() + else { + unreachable!() + }; + instance_to_symbol.insert(String::new(), simple_name.to_string()); + continue; + } + } } let fun_body =