Implement passing kwarg to RPC #533 #579

Open
ramtej wants to merge 11 commits from ramtej/nac3:feature/rpc-keywords into master
2 changed files with 145 additions and 101 deletions

View File

@ -0,0 +1,37 @@
from min_artiq import *
from numpy import int32
@rpc
def sum_3(a: int32, b: int32 = 10, c: int32 = 20) -> int32:
"""
An RPC function to test NAC3's handling of positional/keyword arguments.
"""
return int32(a + b + c)
@nac3
class RpcKwargTest:
core: KernelInvariant[Core]
def __init__(self):
self.core = Core()
@kernel
def run(self):
#1) All positional => a=1, b=2, c=3 -> total=6
s1 = sum_3(1, 2, 3)
assert s1 == 6
#2) Use the default b=10, c=20 => a=5 => total=35
s2 = sum_3(5)
assert s2 == 35
#3) a=1 (positional), b=100 (keyword), omit c => c=20 => total=121
s3 = sum_3(1, b=100)
assert s3 == 121
#4) a=2, c=300 => b=10 (default) => total=312
s4 = sum_3(a=2, c=300)
assert s4 == 312
if __name__ == "__main__":
RpcKwargTest().run()

View File

@ -1,5 +1,5 @@
use std::{ use std::{
collections::{hash_map::DefaultHasher, HashMap}, collections::hash_map::DefaultHasher,
hash::{Hash, Hasher}, hash::{Hash, Hasher},
iter::once, iter::once,
mem, mem,
@ -43,7 +43,9 @@ use nac3core::{
}, },
typecheck::{ typecheck::{
type_inferencer::PrimitiveStore, type_inferencer::PrimitiveStore,
typedef::{iter_type_vars, FunSignature, FuncArg, Type, TypeEnum, VarMap}, typedef::{
iter_type_vars, FunSignature, FuncArg, Type, TypeEnum, TypeEnum::*, VarMap,
},
}, },
}; };
@ -83,8 +85,8 @@ pub struct ArtiqCodeGenerator<'a> {
/// The [`ParallelMode`] of the current parallel context. /// The [`ParallelMode`] of the current parallel context.
/// ///
/// The current parallel context refers to the nearest `with parallel` or `with legacy_parallel` /// The current parallel context refers to the nearest `with` statement,
/// statement, which is used to determine when and how the timeline should be updated. /// which is used to determine when and how the timeline should be updated.
parallel_mode: ParallelMode, parallel_mode: ParallelMode,
/// Specially treated python IDs to identify `with parallel` and `with sequential` blocks. /// Specially treated python IDs to identify `with parallel` and `with sequential` blocks.
@ -412,9 +414,12 @@ impl CodeGenerator for ArtiqCodeGenerator<'_> {
fn gen_rpc_tag( fn gen_rpc_tag(
ctx: &mut CodeGenContext<'_, '_>, ctx: &mut CodeGenContext<'_, '_>,
ty: Type, ty: Type,
is_kwarg: bool,
buffer: &mut Vec<u8>, buffer: &mut Vec<u8>,
) -> Result<(), String> { ) -> Result<(), String> {
use nac3core::typecheck::typedef::TypeEnum::*; if is_kwarg {
buffer.push(b'k');
}
let PrimitiveStore { int32, int64, float, bool, str, none, .. } = ctx.primitives; let PrimitiveStore { int32, int64, float, bool, str, none, .. } = ctx.primitives;
@ -437,14 +442,14 @@ fn gen_rpc_tag(
buffer.push(b't'); buffer.push(b't');
buffer.push(ty.len() as u8); buffer.push(ty.len() as u8);
for ty in ty { for ty in ty {
gen_rpc_tag(ctx, *ty, buffer)?; gen_rpc_tag(ctx, *ty, false, buffer)?;
} }
} }
TObj { obj_id, params, .. } if *obj_id == PrimDef::List.id() => { TObj { obj_id, params, .. } if *obj_id == PrimDef::List.id() => {
let ty = iter_type_vars(params).next().unwrap().ty; let ty = iter_type_vars(params).next().unwrap().ty;
buffer.push(b'l'); buffer.push(b'l');
gen_rpc_tag(ctx, ty, buffer)?; gen_rpc_tag(ctx, ty, false, buffer)?;
} }
TObj { obj_id, .. } if *obj_id == PrimDef::NDArray.id() => { TObj { obj_id, .. } if *obj_id == PrimDef::NDArray.id() => {
let (ndarray_dtype, ndarray_ndims) = unpack_ndarray_var_tys(&mut ctx.unifier, ty); let (ndarray_dtype, ndarray_ndims) = unpack_ndarray_var_tys(&mut ctx.unifier, ty);
@ -468,7 +473,7 @@ fn gen_rpc_tag(
buffer.push(b'a'); buffer.push(b'a');
buffer.push((ndarray_ndims & 0xFF) as u8); buffer.push((ndarray_ndims & 0xFF) as u8);
gen_rpc_tag(ctx, ndarray_dtype, buffer)?; gen_rpc_tag(ctx, ndarray_dtype, false, buffer)?;
} }
_ => return Err(format!("Unsupported type: {:?}", ctx.unifier.stringify(ty))), _ => return Err(format!("Unsupported type: {:?}", ctx.unifier.stringify(ty))),
} }
@ -833,106 +838,119 @@ fn rpc_codegen_callback_fn<'ctx>(
let int32 = ctx.ctx.i32_type(); let int32 = ctx.ctx.i32_type();
let size_type = ctx.get_size_type(); let size_type = ctx.get_size_type();
let ptr_type = int8.ptr_type(AddressSpace::default()); let ptr_type = int8.ptr_type(AddressSpace::default());
let tag_ptr_type = ctx.ctx.struct_type(&[ptr_type.into(), size_type.into()], false);
let service_id = int32.const_int(fun.1 .0 as u64, false); let service_id = int32.const_int(fun.1 .0 as u64, false);
// -- setup rpc tags // -- setup rpc tags
let mut tag = Vec::new(); let mut tag = Vec::new();
if obj.is_some() { if obj.is_some() {
tag.push(b'O'); tag.push(b'O');
} }
for arg in &fun.0.args { for param in &fun.0.args {
gen_rpc_tag(ctx, arg.ty, &mut tag)?; gen_rpc_tag(ctx, param.ty, false, &mut tag)?;
} }
tag.push(b':'); tag.push(b':');
gen_rpc_tag(ctx, fun.0.ret, &mut tag)?; gen_rpc_tag(ctx, fun.0.ret, false, &mut tag)?;
let mut hasher = DefaultHasher::new(); let mut hasher = DefaultHasher::new();
tag.hash(&mut hasher); tag.hash(&mut hasher);
let hash = format!("{}", hasher.finish()); let hash = format!("rpc_tag_{}", hasher.finish());
let tag_ptr = ctx let maybe_existing = ctx.module.get_global(&hash);
.module let tag_ptr = if let Some(gv) = maybe_existing {
.get_global(hash.as_str()) gv.as_pointer_value()
.unwrap_or_else(|| { } else {
let tag_arr_ptr = ctx.module.add_global( let tag_len = tag.len();
int8.array_type(tag.len() as u32), let arr_ty = int8.array_type(tag_len as u32);
None, let tag_const = int8
format!("tagptr{}", fun.1 .0).as_str(), .const_array(&tag.iter().map(|&b| int8.const_int(u64::from(b), false)).collect_vec());
); let arr_gv = ctx.module.add_global(arr_ty, None, &format!("{hash}.arr"));
tag_arr_ptr.set_initializer(&int8.const_array( arr_gv.set_linkage(Linkage::Private);
&tag.iter().map(|v| int8.const_int(u64::from(*v), false)).collect::<Vec<_>>(), arr_gv.set_initializer(&tag_const);
));
tag_arr_ptr.set_linkage(Linkage::Private); let st = ctx.ctx.const_struct(
let tag_ptr = ctx.module.add_global(tag_ptr_type, None, &hash);
tag_ptr.set_linkage(Linkage::Private);
tag_ptr.set_initializer(&ctx.ctx.const_struct(
&[ &[
tag_arr_ptr.as_pointer_value().const_cast(ptr_type).into(), arr_gv.as_pointer_value().const_cast(ptr_type).into(),
size_type.const_int(tag.len() as u64, false).into(), size_type.const_int(tag_len as u64, false).into(),
], ],
false, false,
)); );
tag_ptr let st_gv = ctx.module.add_global(st.get_type(), None, &hash);
}) st_gv.set_linkage(Linkage::Private);
.as_pointer_value(); st_gv.set_initializer(&st);
let arg_length = args.len() + usize::from(obj.is_some()); st_gv.as_pointer_value()
};
let stackptr = call_stacksave(ctx, Some("rpc.stack")); let n_params = fun.0.args.len();
let args_ptr = ctx let mut param_map: Vec<Option<ValueEnum<'ctx>>> = vec![None; n_params];
.builder
.build_array_alloca(
ptr_type,
ctx.ctx.i32_type().const_int(arg_length as u64, false),
"argptr",
)
.unwrap();
// -- rpc args handling let mut pos_index = 0usize;
let mut keys = fun.0.args.clone();
let mut mapping = HashMap::new(); if let Some((_obj_ty, obj_val)) = obj {
for (key, value) in args { param_map[0] = Some(obj_val);
mapping.insert(key.unwrap_or_else(|| keys.remove(0).name), value); pos_index = 1;
} }
// default value handling for (maybe_key, val_enum) in args {
for k in keys { if let Some(kw_name) = maybe_key {
mapping let param_pos = fun
.insert(k.name, ctx.gen_symbol_val(generator, &k.default_value.unwrap(), k.ty).into());
}
// reorder the parameters
let mut real_params = fun
.0 .0
.args .args
.iter() .iter()
.map(|arg| { .position(|arg| arg.name == kw_name)
mapping .ok_or_else(|| format!("Unknown keyword argument '{kw_name}'"))?;
.remove(&arg.name)
.unwrap() if param_map[param_pos].is_some() {
.to_basic_value_enum(ctx, generator, arg.ty) return Err(format!("Multiple values for argument '{kw_name}'"));
.map(|llvm_val| (llvm_val, arg.ty)) }
}) param_map[param_pos] = Some(val_enum);
.collect::<Result<Vec<(_, _)>, _>>()?;
if let Some(obj) = obj {
if let ValueEnum::Static(obj_val) = obj.1 {
real_params.insert(0, (obj_val.get_const_obj(ctx, generator), obj.0));
} else { } else {
// should be an error here... while pos_index < n_params && param_map[pos_index].is_some() {
panic!("only host object is allowed"); pos_index += 1;
}
if pos_index >= n_params {
return Err("Too many positional arguments given to function.".to_string());
}
param_map[pos_index] = Some(val_enum);
pos_index += 1;
} }
} }
for (i, (arg, arg_ty)) in real_params.iter().enumerate() { for (i, param) in fun.0.args.iter().enumerate() {
let arg_slot = format_rpc_arg(generator, ctx, (*arg, *arg_ty, i)); if param_map[i].is_none() {
let arg_ptr = unsafe { if let Some(default_expr) = &param.default_value {
let default_val = ctx.gen_symbol_val(generator, default_expr, param.ty).into();
param_map[i] = Some(default_val);
} else {
return Err(format!("Missing required argument '{}'", param.name));
}
}
}
let mut real_params = Vec::with_capacity(n_params);
for (i, param_spec) in fun.0.args.iter().enumerate() {
let some_valenum = param_map[i].take().unwrap();
let llvm_val = some_valenum.to_basic_value_enum(ctx, generator, param_spec.ty)?;
real_params.push((llvm_val, param_spec.ty));
}
let arg_count = real_params.len() as u64;
let stackptr = call_stacksave(ctx, Some("rpc.stack"));
let i32_ty = ctx.ctx.i32_type();
let arg_array = ctx
.builder
.build_array_alloca(ptr_type, i32_ty.const_int(arg_count, false), "rpc.arg_array")
.unwrap();
for (i, (llvm_val, ty)) in real_params.iter().enumerate() {
let arg_slot_ptr = unsafe {
ctx.builder.build_gep( ctx.builder.build_gep(
args_ptr, arg_array,
&[int32.const_int(i as u64, false)], &[i32_ty.const_int(i as u64, false)],
&format!("rpc.arg{i}"), &format!("rpc.arg_slot_{i}"),
) )
} }
.unwrap(); .unwrap();
ctx.builder.build_store(arg_ptr, arg_slot).unwrap(); let arg_ptr = format_rpc_arg(generator, ctx, (*llvm_val, *ty, i));
ctx.builder.build_store(arg_slot_ptr, arg_ptr).unwrap();
} }
// call // call
@ -940,7 +958,7 @@ fn rpc_codegen_callback_fn<'ctx>(
ctx, ctx,
if is_async { "rpc_send_async" } else { "rpc_send" }, if is_async { "rpc_send_async" } else { "rpc_send" },
None, None,
&[service_id.into(), tag_ptr.into(), args_ptr.into()], &[service_id.into(), tag_ptr.into(), arg_array.into()],
Some("rpc.send"), Some("rpc.send"),
None, None,
); );
@ -948,19 +966,8 @@ fn rpc_codegen_callback_fn<'ctx>(
// reclaim stack space used by arguments // reclaim stack space used by arguments
call_stackrestore(ctx, stackptr); call_stackrestore(ctx, stackptr);
if is_async { let maybe_ret = format_rpc_ret(generator, ctx, fun.0.ret);
// async RPCs do not return any values Ok(maybe_ret)
Ok(None)
} else {
let result = format_rpc_ret(generator, ctx, fun.0.ret);
if !result.is_some_and(|res| res.get_type().is_pointer_type()) {
// An RPC returning an NDArray would not touch here.
call_stackrestore(ctx, stackptr);
}
Ok(result)
}
} }
pub fn attributes_writeback<'ctx>( pub fn attributes_writeback<'ctx>(
@ -1008,7 +1015,7 @@ pub fn attributes_writeback<'ctx>(
if !is_mutable { if !is_mutable {
continue; continue;
} }
if gen_rpc_tag(ctx, *field_ty, &mut scratch_buffer).is_ok() { if gen_rpc_tag(ctx, *field_ty, false, &mut scratch_buffer).is_ok() {
attributes.push(name.to_string()); attributes.push(name.to_string());
let (index, _) = ctx.get_attr_index(ty, *name); let (index, _) = ctx.get_attr_index(ty, *name);
values.push(( values.push((
@ -1031,7 +1038,7 @@ pub fn attributes_writeback<'ctx>(
TypeEnum::TObj { obj_id, params, .. } if *obj_id == PrimDef::List.id() => { TypeEnum::TObj { obj_id, params, .. } if *obj_id == PrimDef::List.id() => {
let elem_ty = iter_type_vars(params).next().unwrap().ty; let elem_ty = iter_type_vars(params).next().unwrap().ty;
if gen_rpc_tag(ctx, elem_ty, &mut scratch_buffer).is_ok() { if gen_rpc_tag(ctx, elem_ty, false, &mut scratch_buffer).is_ok() {
let pydict = PyDict::new(py); let pydict = PyDict::new(py);
pydict.set_item("obj", val)?; pydict.set_item("obj", val)?;
host_attributes.append(pydict)?; host_attributes.append(pydict)?;
@ -1049,7 +1056,7 @@ pub fn attributes_writeback<'ctx>(
if *is_method { if *is_method {
continue; continue;
} }
if gen_rpc_tag(ctx, *field_ty, &mut scratch_buffer).is_ok() { if gen_rpc_tag(ctx, *field_ty, false, &mut scratch_buffer).is_ok() {
fields.push(name.to_string()); fields.push(name.to_string());
let (index, _) = ctx.get_attr_index(ty, *name); let (index, _) = ctx.get_attr_index(ty, *name);
values.push(( values.push((