diff --git a/nac3artiq/demo/rpc_kwargs_test.py b/nac3artiq/demo/rpc_kwargs_test.py new file mode 100644 index 00000000..d4a444b2 --- /dev/null +++ b/nac3artiq/demo/rpc_kwargs_test.py @@ -0,0 +1,37 @@ +from min_artiq import * +from numpy import int32 + +@rpc +def sum_3(a: int32, b: int32 = 10, c: int32 = 20) -> int32: + """ + An RPC function to test NAC3's handling of positional/keyword arguments. + """ + return int32(a + b + c) + +@nac3 +class RpcKwargTest: + core: KernelInvariant[Core] + + def __init__(self): + self.core = Core() + + @kernel + def run(self): + #1) All positional => a=1, b=2, c=3 -> total=6 + s1 = sum_3(1, 2, 3) + assert s1 == 6 + + #2) Use the default b=10, c=20 => a=5 => total=35 + s2 = sum_3(5) + assert s2 == 35 + + #3) a=1 (positional), b=100 (keyword), omit c => c=20 => total=121 + s3 = sum_3(1, b=100) + assert s3 == 121 + + #4) a=2, c=300 => b=10 (default) => total=312 + s4 = sum_3(a=2, c=300) + assert s4 == 312 + +if __name__ == "__main__": + RpcKwargTest().run() \ No newline at end of file diff --git a/nac3artiq/src/codegen.rs b/nac3artiq/src/codegen.rs index cc625a02..34ae43f7 100644 --- a/nac3artiq/src/codegen.rs +++ b/nac3artiq/src/codegen.rs @@ -1,5 +1,5 @@ use std::{ - collections::{hash_map::DefaultHasher, HashMap}, + collections::hash_map::DefaultHasher, hash::{Hash, Hasher}, iter::once, mem, @@ -43,7 +43,9 @@ use nac3core::{ }, typecheck::{ type_inferencer::PrimitiveStore, - typedef::{iter_type_vars, FunSignature, FuncArg, Type, TypeEnum, VarMap}, + typedef::{ + iter_type_vars, FunSignature, FuncArg, Type, TypeEnum, TypeEnum::*, VarMap, + }, }, }; @@ -83,8 +85,8 @@ pub struct ArtiqCodeGenerator<'a> { /// The [`ParallelMode`] of the current parallel context. /// - /// The current parallel context refers to the nearest `with parallel` or `with legacy_parallel` - /// statement, which is used to determine when and how the timeline should be updated. + /// The current parallel context refers to the nearest `with` statement, + /// which is used to determine when and how the timeline should be updated. parallel_mode: ParallelMode, /// Specially treated python IDs to identify `with parallel` and `with sequential` blocks. @@ -412,9 +414,12 @@ impl CodeGenerator for ArtiqCodeGenerator<'_> { fn gen_rpc_tag( ctx: &mut CodeGenContext<'_, '_>, ty: Type, + is_kwarg: bool, buffer: &mut Vec, ) -> Result<(), String> { - use nac3core::typecheck::typedef::TypeEnum::*; + if is_kwarg { + buffer.push(b'k'); + } let PrimitiveStore { int32, int64, float, bool, str, none, .. } = ctx.primitives; @@ -437,14 +442,14 @@ fn gen_rpc_tag( buffer.push(b't'); buffer.push(ty.len() as u8); for ty in ty { - gen_rpc_tag(ctx, *ty, buffer)?; + gen_rpc_tag(ctx, *ty, false, buffer)?; } } TObj { obj_id, params, .. } if *obj_id == PrimDef::List.id() => { let ty = iter_type_vars(params).next().unwrap().ty; buffer.push(b'l'); - gen_rpc_tag(ctx, ty, buffer)?; + gen_rpc_tag(ctx, ty, false, buffer)?; } TObj { obj_id, .. } if *obj_id == PrimDef::NDArray.id() => { let (ndarray_dtype, ndarray_ndims) = unpack_ndarray_var_tys(&mut ctx.unifier, ty); @@ -468,7 +473,7 @@ fn gen_rpc_tag( buffer.push(b'a'); buffer.push((ndarray_ndims & 0xFF) as u8); - gen_rpc_tag(ctx, ndarray_dtype, buffer)?; + gen_rpc_tag(ctx, ndarray_dtype, false, buffer)?; } _ => return Err(format!("Unsupported type: {:?}", ctx.unifier.stringify(ty))), } @@ -833,106 +838,119 @@ fn rpc_codegen_callback_fn<'ctx>( let int32 = ctx.ctx.i32_type(); let size_type = ctx.get_size_type(); let ptr_type = int8.ptr_type(AddressSpace::default()); - let tag_ptr_type = ctx.ctx.struct_type(&[ptr_type.into(), size_type.into()], false); - let service_id = int32.const_int(fun.1 .0 as u64, false); // -- setup rpc tags let mut tag = Vec::new(); if obj.is_some() { tag.push(b'O'); } - for arg in &fun.0.args { - gen_rpc_tag(ctx, arg.ty, &mut tag)?; + for param in &fun.0.args { + gen_rpc_tag(ctx, param.ty, false, &mut tag)?; } tag.push(b':'); - gen_rpc_tag(ctx, fun.0.ret, &mut tag)?; + gen_rpc_tag(ctx, fun.0.ret, false, &mut tag)?; let mut hasher = DefaultHasher::new(); tag.hash(&mut hasher); - let hash = format!("{}", hasher.finish()); + let hash = format!("rpc_tag_{}", hasher.finish()); - let tag_ptr = ctx - .module - .get_global(hash.as_str()) - .unwrap_or_else(|| { - let tag_arr_ptr = ctx.module.add_global( - int8.array_type(tag.len() as u32), - None, - format!("tagptr{}", fun.1 .0).as_str(), - ); - tag_arr_ptr.set_initializer(&int8.const_array( - &tag.iter().map(|v| int8.const_int(u64::from(*v), false)).collect::>(), - )); - tag_arr_ptr.set_linkage(Linkage::Private); - let tag_ptr = ctx.module.add_global(tag_ptr_type, None, &hash); - tag_ptr.set_linkage(Linkage::Private); - tag_ptr.set_initializer(&ctx.ctx.const_struct( - &[ - tag_arr_ptr.as_pointer_value().const_cast(ptr_type).into(), - size_type.const_int(tag.len() as u64, false).into(), - ], - false, - )); - tag_ptr - }) - .as_pointer_value(); + let maybe_existing = ctx.module.get_global(&hash); + let tag_ptr = if let Some(gv) = maybe_existing { + gv.as_pointer_value() + } else { + let tag_len = tag.len(); + let arr_ty = int8.array_type(tag_len as u32); + let tag_const = int8 + .const_array(&tag.iter().map(|&b| int8.const_int(u64::from(b), false)).collect_vec()); + let arr_gv = ctx.module.add_global(arr_ty, None, &format!("{hash}.arr")); + arr_gv.set_linkage(Linkage::Private); + arr_gv.set_initializer(&tag_const); - let arg_length = args.len() + usize::from(obj.is_some()); + let st = ctx.ctx.const_struct( + &[ + arr_gv.as_pointer_value().const_cast(ptr_type).into(), + size_type.const_int(tag_len as u64, false).into(), + ], + false, + ); + let st_gv = ctx.module.add_global(st.get_type(), None, &hash); + st_gv.set_linkage(Linkage::Private); + st_gv.set_initializer(&st); - let stackptr = call_stacksave(ctx, Some("rpc.stack")); - let args_ptr = ctx - .builder - .build_array_alloca( - ptr_type, - ctx.ctx.i32_type().const_int(arg_length as u64, false), - "argptr", - ) - .unwrap(); + st_gv.as_pointer_value() + }; - // -- rpc args handling - let mut keys = fun.0.args.clone(); - let mut mapping = HashMap::new(); - for (key, value) in args { - mapping.insert(key.unwrap_or_else(|| keys.remove(0).name), value); + let n_params = fun.0.args.len(); + let mut param_map: Vec>> = vec![None; n_params]; + + let mut pos_index = 0usize; + + if let Some((_obj_ty, obj_val)) = obj { + param_map[0] = Some(obj_val); + pos_index = 1; } - // default value handling - for k in keys { - mapping - .insert(k.name, ctx.gen_symbol_val(generator, &k.default_value.unwrap(), k.ty).into()); - } - // reorder the parameters - let mut real_params = fun - .0 - .args - .iter() - .map(|arg| { - mapping - .remove(&arg.name) - .unwrap() - .to_basic_value_enum(ctx, generator, arg.ty) - .map(|llvm_val| (llvm_val, arg.ty)) - }) - .collect::, _>>()?; - if let Some(obj) = obj { - if let ValueEnum::Static(obj_val) = obj.1 { - real_params.insert(0, (obj_val.get_const_obj(ctx, generator), obj.0)); + for (maybe_key, val_enum) in args { + if let Some(kw_name) = maybe_key { + let param_pos = fun + .0 + .args + .iter() + .position(|arg| arg.name == kw_name) + .ok_or_else(|| format!("Unknown keyword argument '{kw_name}'"))?; + + if param_map[param_pos].is_some() { + return Err(format!("Multiple values for argument '{kw_name}'")); + } + param_map[param_pos] = Some(val_enum); } else { - // should be an error here... - panic!("only host object is allowed"); + while pos_index < n_params && param_map[pos_index].is_some() { + pos_index += 1; + } + if pos_index >= n_params { + return Err("Too many positional arguments given to function.".to_string()); + } + param_map[pos_index] = Some(val_enum); + pos_index += 1; } } - for (i, (arg, arg_ty)) in real_params.iter().enumerate() { - let arg_slot = format_rpc_arg(generator, ctx, (*arg, *arg_ty, i)); - let arg_ptr = unsafe { + for (i, param) in fun.0.args.iter().enumerate() { + if param_map[i].is_none() { + if let Some(default_expr) = ¶m.default_value { + let default_val = ctx.gen_symbol_val(generator, default_expr, param.ty).into(); + param_map[i] = Some(default_val); + } else { + return Err(format!("Missing required argument '{}'", param.name)); + } + } + } + let mut real_params = Vec::with_capacity(n_params); + for (i, param_spec) in fun.0.args.iter().enumerate() { + let some_valenum = param_map[i].take().unwrap(); + let llvm_val = some_valenum.to_basic_value_enum(ctx, generator, param_spec.ty)?; + real_params.push((llvm_val, param_spec.ty)); + } + + let arg_count = real_params.len() as u64; + let stackptr = call_stacksave(ctx, Some("rpc.stack")); + + let i32_ty = ctx.ctx.i32_type(); + let arg_array = ctx + .builder + .build_array_alloca(ptr_type, i32_ty.const_int(arg_count, false), "rpc.arg_array") + .unwrap(); + + for (i, (llvm_val, ty)) in real_params.iter().enumerate() { + let arg_slot_ptr = unsafe { ctx.builder.build_gep( - args_ptr, - &[int32.const_int(i as u64, false)], - &format!("rpc.arg{i}"), + arg_array, + &[i32_ty.const_int(i as u64, false)], + &format!("rpc.arg_slot_{i}"), ) } .unwrap(); - ctx.builder.build_store(arg_ptr, arg_slot).unwrap(); + let arg_ptr = format_rpc_arg(generator, ctx, (*llvm_val, *ty, i)); + ctx.builder.build_store(arg_slot_ptr, arg_ptr).unwrap(); } // call @@ -940,7 +958,7 @@ fn rpc_codegen_callback_fn<'ctx>( ctx, if is_async { "rpc_send_async" } else { "rpc_send" }, None, - &[service_id.into(), tag_ptr.into(), args_ptr.into()], + &[service_id.into(), tag_ptr.into(), arg_array.into()], Some("rpc.send"), None, ); @@ -948,19 +966,8 @@ fn rpc_codegen_callback_fn<'ctx>( // reclaim stack space used by arguments call_stackrestore(ctx, stackptr); - if is_async { - // async RPCs do not return any values - Ok(None) - } else { - let result = format_rpc_ret(generator, ctx, fun.0.ret); - - if !result.is_some_and(|res| res.get_type().is_pointer_type()) { - // An RPC returning an NDArray would not touch here. - call_stackrestore(ctx, stackptr); - } - - Ok(result) - } + let maybe_ret = format_rpc_ret(generator, ctx, fun.0.ret); + Ok(maybe_ret) } pub fn attributes_writeback<'ctx>( @@ -1008,7 +1015,7 @@ pub fn attributes_writeback<'ctx>( if !is_mutable { continue; } - if gen_rpc_tag(ctx, *field_ty, &mut scratch_buffer).is_ok() { + if gen_rpc_tag(ctx, *field_ty, false, &mut scratch_buffer).is_ok() { attributes.push(name.to_string()); let (index, _) = ctx.get_attr_index(ty, *name); values.push(( @@ -1031,7 +1038,7 @@ pub fn attributes_writeback<'ctx>( TypeEnum::TObj { obj_id, params, .. } if *obj_id == PrimDef::List.id() => { let elem_ty = iter_type_vars(params).next().unwrap().ty; - if gen_rpc_tag(ctx, elem_ty, &mut scratch_buffer).is_ok() { + if gen_rpc_tag(ctx, elem_ty, false, &mut scratch_buffer).is_ok() { let pydict = PyDict::new(py); pydict.set_item("obj", val)?; host_attributes.append(pydict)?; @@ -1049,7 +1056,7 @@ pub fn attributes_writeback<'ctx>( if *is_method { continue; } - if gen_rpc_tag(ctx, *field_ty, &mut scratch_buffer).is_ok() { + if gen_rpc_tag(ctx, *field_ty, false, &mut scratch_buffer).is_ok() { fields.push(name.to_string()); let (index, _) = ctx.get_attr_index(ty, *name); values.push((