diff --git a/nac3artiq/src/lib.rs b/nac3artiq/src/lib.rs index 5906320b..4d163278 100644 --- a/nac3artiq/src/lib.rs +++ b/nac3artiq/src/lib.rs @@ -482,6 +482,9 @@ impl Nac3 { if decorator_list.iter().any(|decorator| matches!(decorator.node, ExprKind::Name { id, .. } if id == "rpc".into())) { store_fun.call1(py, (def_id.0.into_py(py), module.getattr(py, name.to_string().as_str()).unwrap())).unwrap(); rpc_ids.push((None, def_id)); + } else if decorator_list.iter().any(|decorator| matches!(decorator.node, ExprKind::Name { id, .. } if id == "subkernel".into())) { + store_fun.call1(py, (def_id.0.into_py(py), module.getattr(py, name.to_string().as_str()).unwrap())).unwrap(); + //rpc_ids.push((None, def_id)); } } StmtKind::ClassDef { name, body, .. } => { diff --git a/nac3artiq/src/subkernels.rs b/nac3artiq/src/subkernels.rs index 3bbd1725..f2c777b4 100644 --- a/nac3artiq/src/subkernels.rs +++ b/nac3artiq/src/subkernels.rs @@ -58,10 +58,20 @@ impl Subkernels { ] } - fn gen_subkernel_await<'ctx>(ctx: &mut CodeGenContext<'ctx, '_>, awaited: BasicValueEnum<'ctx>, timeout: BasicValueEnum<'ctx>) { + fn gen_subkernel_await<'ctx>( + ctx: &mut CodeGenContext<'ctx, '_>, + obj: &Option<(Type, ValueEnum<'ctx>)>, + fun: (&FunSignature, DefinitionId), + args: &[(Option, ValueEnum<'ctx>)], + generator: &mut dyn CodeGenerator, + ) { let sid_type = ctx.ctx.i32_type(); - // how to deal with optional arguments? let timeout_type = ctx.ctx.i64_type(); + assert!(matches!(args.len(), 1..=2)); + let timeout = if args.len() == 1 { + timeout_type.const_zero().to_basic_value_enum(context, generator, obj_ty)?; // ? + } else { args[0].1.clone().to_basic_value_enum(context, generator, obj_ty)?; } // ? + let subkernel_await_finish = ctx.module.get_function("subkernel_await_finish").unwrap_or_else(|| { ctx.module.add_function( "subkernel_await_finish", @@ -74,18 +84,257 @@ impl Subkernels { // generate RPC for receiving return value depending on fun ret } - fn gen_subkernel_preload<'ctx>(ctx: &mut CodeGenContext<'ctx, '_>, preloaded: BasicValueEnum<'ctx>) { + fn gen_subkernel_preload<'ctx>( + ctx: &mut CodeGenContext<'ctx, '_>, + obj: &Option<(Type, ValueEnum<'ctx>)>, + fun: (&FunSignature, DefinitionId), + args: &[(Option, ValueEnum<'ctx>)], + generator: &mut dyn CodeGenerator, + ) { + assert_eq!(args.len(), 1); let sid_type = ctx.ctx.i32_type(); let dest_type = ctx.ctx.i8_type(); - let run_type = ctx.ctx.i1_type(); + let run_type = ctx.ctx.bool_type(); let subkernel_load_run = ctx.module.get_function("subkernel_load_run").unwrap_or_else(|| { ctx.module.add_function( "subkernel_load_run", - ctx.ctx.void_type().fn_type(&[sid_type.into(), dest_type, run_type], false), + ctx.ctx.void_type().fn_type(&[sid_type.into(), dest_type.into(), run_type.into()], false), None, ) }); - // retrieve destination and sid from the fn (?) - // call or invoke + + let subkernel_id = int32.const_int(fun.1 .0 as u64, false); + let destination = int32.const_int(fun.? as u64, false); // TODO + ctx.builder + .build_call_or_invoke(subkernel_load_run, &[subkernel_id.into(), destination.into(), run_type.const_zero()], "subkernel.preload") + .unwrap(); + } + + fn subkernel_callback_fn<'ctx>( + ctx: &mut CodeGenContext<'ctx, '_>, + obj: Option<(Type, ValueEnum<'ctx>)>, + fun: (&FunSignature, DefinitionId), + args: Vec<(Option, ValueEnum<'ctx>)>, + generator: &mut dyn CodeGenerator, + ) -> Result>, String> { + let int8 = ctx.ctx.i8_type(); + let int32 = ctx.ctx.i32_type(); + let size_type = generator.get_size_type(ctx.ctx); + let ptr_type = int8.ptr_type(AddressSpace::default()); + let tag_ptr_type = ctx.ctx.struct_type(&[ptr_type.into(), size_type.into()], false); + + let subkernel_id = int32.const_int(fun.1 .0 as u64, false); + let destination = int32.const_int(fun.? as u64, false); // TODO + + // -- start the subkernel + let sid_type = ctx.ctx.i32_type(); + let dest_type = ctx.ctx.i8_type(); + let run_type = ctx.ctx.i1_type(); + let subkernel_start = ctx.module.get_function("subkernel_load_run").unwrap_or_else(|| { + ctx.module.add_function( + "subkernel_load_run", + ctx.ctx.void_type().fn_type( + &[ + ctx.ctx.void_type().fn_type(&[sid_type.into(), dest_type.into(), run_type.into()], false), + ], + false, + ), + None, + ) + }); + ctx.builder + .build_call_or_invoke(subkernel_start, &[subkernel_id.into(), destination.into(), run_type.const_int(1, false)], "subkernel.run") + .unwrap(); + // -- setup rpc tags + let mut tag = Vec::new(); + if obj.is_some() { + tag.push(b'O'); + } + for arg in &fun.0.args { + gen_rpc_tag(ctx, arg.ty, &mut tag)?; + } + tag.push(b':'); + gen_rpc_tag(ctx, fun.0.ret, &mut tag)?; + + let mut hasher = DefaultHasher::new(); + tag.hash(&mut hasher); + let hash = format!("{}", hasher.finish()); + + let tag_ptr = ctx + .module + .get_global(hash.as_str()) + .unwrap_or_else(|| { + let tag_arr_ptr = ctx.module.add_global( + int8.array_type(tag.len() as u32), + None, + format!("tagptr{}", fun.1 .0).as_str(), + ); + tag_arr_ptr.set_initializer(&int8.const_array( + &tag.iter().map(|v| int8.const_int(u64::from(*v), false)).collect::>(), + )); + tag_arr_ptr.set_linkage(Linkage::Private); + let tag_ptr = ctx.module.add_global(tag_ptr_type, None, &hash); + tag_ptr.set_linkage(Linkage::Private); + tag_ptr.set_initializer(&ctx.ctx.const_struct( + &[ + tag_arr_ptr.as_pointer_value().const_cast(ptr_type).into(), + size_type.const_int(tag.len() as u64, false).into(), + ], + false, + )); + tag_ptr + }) + .as_pointer_value(); + + let arg_length = args.len() + usize::from(obj.is_some()); + + let stackptr = call_stacksave(ctx, Some("rpc.stack")); + let args_ptr = ctx + .builder + .build_array_alloca( + ptr_type, + ctx.ctx.i32_type().const_int(arg_length as u64, false), + "argptr", + ) + .unwrap(); + + // -- rpc args handling + let mut keys = fun.0.args.clone(); + let mut mapping = HashMap::new(); + for (key, value) in args { + mapping.insert(key.unwrap_or_else(|| keys.remove(0).name), value); + } + // default value handling + for k in keys { + mapping + .insert(k.name, ctx.gen_symbol_val(generator, &k.default_value.unwrap(), k.ty).into()); + } + // reorder the parameters + let mut real_params = fun + .0 + .args + .iter() + .map(|arg| { + mapping + .remove(&arg.name) + .unwrap() + .to_basic_value_enum(ctx, generator, arg.ty) + .map(|llvm_val| (llvm_val, arg.ty)) + }) + .collect::, _>>()?; + if let Some(obj) = obj { + if let ValueEnum::Static(obj_val) = obj.1 { + real_params.insert(0, (obj_val.get_const_obj(ctx, generator), obj.0)); + } else { + // should be an error here... + panic!("only host object is allowed"); + } + } + + for (i, (arg, arg_ty)) in real_params.iter().enumerate() { + let arg_slot = format_rpc_arg(generator, ctx, (*arg, *arg_ty, i)); + let arg_ptr = unsafe { + ctx.builder.build_gep( + args_ptr, + &[int32.const_int(i as u64, false)], + &format!("rpc.arg{i}"), + ) + } + .unwrap(); + ctx.builder.build_store(arg_ptr, arg_slot).unwrap(); + } + + // send the message + let subkernel_send = ctx.module.get_function("subkernel_send_message").unwrap_or_else(|| { + ctx.module.add_function( + "subkernel_send_message", + ctx.ctx.void_type().fn_type( + &[ + int32.into(), + tag_ptr_type.ptr_type(AddressSpace::default()).into(), + ptr_type.ptr_type(AddressSpace::default()).into(), + ], + false, + ), + None, + ) + }); + ctx.builder + .build_call_or_invoke(subkernel_send, &[service_id.into(), tag_ptr.into(), args_ptr.into()], "rpc.send") + .unwrap(); + // reclaim stack space used by arguments + call_stackrestore(ctx, stackptr); + } + + pub fn subkernel_codegen_callback() -> Arc { + Arc::new(GenCall::new(Box::new(|ctx, obj, fun, args, generator| { + subkernel_codegen_callback_fn(ctx, obj, fun, args, generator) + }))) + } + + fn subkernel_recv_message<'ctx>( + ctx: &mut CodeGenContext<'ctx, '_>, + obj: Option<(Type, ValueEnum<'ctx>)>, + fun: (&FunSignature, DefinitionId), + args: Vec<(Option, ValueEnum<'ctx>)>, + generator: &mut dyn CodeGenerator, + ) -> Result>, String> { + // -- receive value: + // T result = { + // void *ret_ptr = alloca(sizeof(T)); + // void *ptr = ret_ptr; + // loop: int size = rpc_recv(ptr); + // // Non-zero: Provide `size` bytes of extra storage for variable-length data. + // if(size) { ptr = alloca(size); goto loop; } + // else *(T*)ret_ptr + // } + let rpc_recv = ctx.module.get_function("rpc_recv").unwrap_or_else(|| { + ctx.module.add_function("rpc_recv", int32.fn_type(&[ptr_type.into()], false), None) + }); + + if ctx.unifier.unioned(fun.0.ret, ctx.primitives.none) { + ctx.build_call_or_invoke(rpc_recv, &[ptr_type.const_null().into()], "rpc_recv"); + return Ok(None); + } + + let prehead_bb = ctx.builder.get_insert_block().unwrap(); + let current_function = prehead_bb.get_parent().unwrap(); + let head_bb = ctx.ctx.append_basic_block(current_function, "rpc.head"); + let alloc_bb = ctx.ctx.append_basic_block(current_function, "rpc.continue"); + let tail_bb = ctx.ctx.append_basic_block(current_function, "rpc.tail"); + + let ret_ty = ctx.get_llvm_abi_type(generator, fun.0.ret); + let need_load = !ret_ty.is_pointer_type(); + let slot = ctx.builder.build_alloca(ret_ty, "rpc.ret.slot").unwrap(); + let slotgen = ctx.builder.build_bitcast(slot, ptr_type, "rpc.ret.ptr").unwrap(); + ctx.builder.build_unconditional_branch(head_bb).unwrap(); + ctx.builder.position_at_end(head_bb); + + let phi = ctx.builder.build_phi(ptr_type, "rpc.ptr").unwrap(); + phi.add_incoming(&[(&slotgen, prehead_bb)]); + let alloc_size = ctx + .build_call_or_invoke(rpc_recv, &[phi.as_basic_value()], "rpc.size.next") + .unwrap() + .into_int_value(); + let is_done = ctx + .builder + .build_int_compare(inkwell::IntPredicate::EQ, int32.const_zero(), alloc_size, "rpc.done") + .unwrap(); + + ctx.builder.build_conditional_branch(is_done, tail_bb, alloc_bb).unwrap(); + ctx.builder.position_at_end(alloc_bb); + + let alloc_ptr = ctx.builder.build_array_alloca(ptr_type, alloc_size, "rpc.alloc").unwrap(); + let alloc_ptr = ctx.builder.build_bitcast(alloc_ptr, ptr_type, "rpc.alloc.ptr").unwrap(); + phi.add_incoming(&[(&alloc_ptr, alloc_bb)]); + ctx.builder.build_unconditional_branch(head_bb).unwrap(); + + ctx.builder.position_at_end(tail_bb); + + let result = ctx.builder.build_load(slot, "rpc.result").unwrap(); + if need_load { + call_stackrestore(ctx, stackptr); + } + Ok(Some(result)) } }