Implement passing kwarg to RPC #533 #579

Open
ramtej wants to merge 11 commits from ramtej/nac3:feature/rpc-keywords into master
2 changed files with 145 additions and 101 deletions

View File

@ -0,0 +1,37 @@
from min_artiq import *
from numpy import int32
@rpc
def sum_3(a: int32, b: int32 = 10, c: int32 = 20) -> int32:
"""
An RPC function to test NAC3's handling of positional/keyword arguments.
"""
return int32(a + b + c)
@nac3
class RpcKwargTest:
core: KernelInvariant[Core]
def __init__(self):
self.core = Core()
@kernel
def run(self):
#1) All positional => a=1, b=2, c=3 -> total=6
s1 = sum_3(1, 2, 3)
assert s1 == 6

use assert instead?

use ``assert`` instead?
Outdated
Review

We don't support exceptions in runkernel (yet?). One advantage of nac3artiq/demo is it can run on x86 without hardware.

We don't support exceptions in ``runkernel`` (yet?). One advantage of ``nac3artiq/demo`` is it can run on x86 without hardware.
Outdated
Review

Though of course it's also not clear how to support RPC in runkernel.

Though of course it's also not clear how to support RPC in runkernel.
#2) Use the default b=10, c=20 => a=5 => total=35
s2 = sum_3(5)
assert s2 == 35
#3) a=1 (positional), b=100 (keyword), omit c => c=20 => total=121
s3 = sum_3(1, b=100)
assert s3 == 121
#4) a=2, c=300 => b=10 (default) => total=312
s4 = sum_3(a=2, c=300)
assert s4 == 312
if __name__ == "__main__":
RpcKwargTest().run()

View File

@ -1,5 +1,5 @@
use std::{
collections::{hash_map::DefaultHasher, HashMap},
collections::hash_map::DefaultHasher,
hash::{Hash, Hasher},
iter::once,
mem,
@ -43,7 +43,9 @@ use nac3core::{
},
typecheck::{
type_inferencer::PrimitiveStore,
typedef::{iter_type_vars, FunSignature, FuncArg, Type, TypeEnum, VarMap},
typedef::{
iter_type_vars, FunSignature, FuncArg, Type, TypeEnum, TypeEnum::*, VarMap,
},
},
};
@ -83,8 +85,8 @@ pub struct ArtiqCodeGenerator<'a> {
/// The [`ParallelMode`] of the current parallel context.
///
/// The current parallel context refers to the nearest `with parallel` or `with legacy_parallel`
/// statement, which is used to determine when and how the timeline should be updated.
/// The current parallel context refers to the nearest `with` statement,
/// which is used to determine when and how the timeline should be updated.
parallel_mode: ParallelMode,
/// Specially treated python IDs to identify `with parallel` and `with sequential` blocks.
@ -412,9 +414,12 @@ impl CodeGenerator for ArtiqCodeGenerator<'_> {
fn gen_rpc_tag(
ctx: &mut CodeGenContext<'_, '_>,
ty: Type,
is_kwarg: bool,
buffer: &mut Vec<u8>,
) -> Result<(), String> {
use nac3core::typecheck::typedef::TypeEnum::*;
if is_kwarg {
buffer.push(b'k');
}
let PrimitiveStore { int32, int64, float, bool, str, none, .. } = ctx.primitives;
@ -437,14 +442,14 @@ fn gen_rpc_tag(
buffer.push(b't');
buffer.push(ty.len() as u8);
for ty in ty {
gen_rpc_tag(ctx, *ty, buffer)?;
gen_rpc_tag(ctx, *ty, false, buffer)?;
}
}
TObj { obj_id, params, .. } if *obj_id == PrimDef::List.id() => {
let ty = iter_type_vars(params).next().unwrap().ty;
buffer.push(b'l');
gen_rpc_tag(ctx, ty, buffer)?;
gen_rpc_tag(ctx, ty, false, buffer)?;
}
TObj { obj_id, .. } if *obj_id == PrimDef::NDArray.id() => {
ramtej marked this conversation as resolved Outdated

Please remove this comment.

Please remove this comment.
let (ndarray_dtype, ndarray_ndims) = unpack_ndarray_var_tys(&mut ctx.unifier, ty);
@ -468,7 +473,7 @@ fn gen_rpc_tag(
buffer.push(b'a');
buffer.push((ndarray_ndims & 0xFF) as u8);
gen_rpc_tag(ctx, ndarray_dtype, buffer)?;
gen_rpc_tag(ctx, ndarray_dtype, false, buffer)?;
}
_ => return Err(format!("Unsupported type: {:?}", ctx.unifier.stringify(ty))),
}
@ -833,106 +838,119 @@ fn rpc_codegen_callback_fn<'ctx>(
let int32 = ctx.ctx.i32_type();
let size_type = ctx.get_size_type();
let ptr_type = int8.ptr_type(AddressSpace::default());
let tag_ptr_type = ctx.ctx.struct_type(&[ptr_type.into(), size_type.into()], false);
let service_id = int32.const_int(fun.1 .0 as u64, false);
// -- setup rpc tags
let mut tag = Vec::new();
if obj.is_some() {
tag.push(b'O');
}
ramtej marked this conversation as resolved Outdated

Suggestion: collect_vec()

Suggestion: `collect_vec()`
for arg in &fun.0.args {
gen_rpc_tag(ctx, arg.ty, &mut tag)?;
for param in &fun.0.args {
gen_rpc_tag(ctx, param.ty, false, &mut tag)?;
}
tag.push(b':');
gen_rpc_tag(ctx, fun.0.ret, &mut tag)?;
gen_rpc_tag(ctx, fun.0.ret, false, &mut tag)?;
let mut hasher = DefaultHasher::new();
tag.hash(&mut hasher);
let hash = format!("{}", hasher.finish());
let hash = format!("rpc_tag_{}", hasher.finish());
let tag_ptr = ctx
.module
.get_global(hash.as_str())
.unwrap_or_else(|| {
let tag_arr_ptr = ctx.module.add_global(
int8.array_type(tag.len() as u32),
None,
format!("tagptr{}", fun.1 .0).as_str(),
);
tag_arr_ptr.set_initializer(&int8.const_array(
&tag.iter().map(|v| int8.const_int(u64::from(*v), false)).collect::<Vec<_>>(),
));
tag_arr_ptr.set_linkage(Linkage::Private);
let tag_ptr = ctx.module.add_global(tag_ptr_type, None, &hash);
tag_ptr.set_linkage(Linkage::Private);
tag_ptr.set_initializer(&ctx.ctx.const_struct(
&[
tag_arr_ptr.as_pointer_value().const_cast(ptr_type).into(),
size_type.const_int(tag.len() as u64, false).into(),
],
false,
));
tag_ptr
})
.as_pointer_value();
let maybe_existing = ctx.module.get_global(&hash);
let tag_ptr = if let Some(gv) = maybe_existing {
gv.as_pointer_value()
} else {
let tag_len = tag.len();
let arr_ty = int8.array_type(tag_len as u32);
let tag_const = int8
.const_array(&tag.iter().map(|&b| int8.const_int(u64::from(b), false)).collect_vec());
let arr_gv = ctx.module.add_global(arr_ty, None, &format!("{hash}.arr"));
arr_gv.set_linkage(Linkage::Private);
arr_gv.set_initializer(&tag_const);
ramtej marked this conversation as resolved Outdated

I don't think drain is necessary here. You can just iterate on args and it would consume the entire Vec - Plus the mut change to the parameter is also made unnecessary.

I don't think `drain` is necessary here. You can just iterate on `args` and it would consume the entire `Vec` - Plus the `mut` change to the parameter is also made unnecessary.
let arg_length = args.len() + usize::from(obj.is_some());
let st = ctx.ctx.const_struct(
&[
arr_gv.as_pointer_value().const_cast(ptr_type).into(),
size_type.const_int(tag_len as u64, false).into(),
],
false,
);
let st_gv = ctx.module.add_global(st.get_type(), None, &hash);
st_gv.set_linkage(Linkage::Private);
st_gv.set_initializer(&st);
let stackptr = call_stacksave(ctx, Some("rpc.stack"));
let args_ptr = ctx
.builder
.build_array_alloca(
ptr_type,
ctx.ctx.i32_type().const_int(arg_length as u64, false),
"argptr",
)
.unwrap();
st_gv.as_pointer_value()
};
// -- rpc args handling
let mut keys = fun.0.args.clone();
let mut mapping = HashMap::new();
for (key, value) in args {
mapping.insert(key.unwrap_or_else(|| keys.remove(0).name), value);
let n_params = fun.0.args.len();
let mut param_map: Vec<Option<ValueEnum<'ctx>>> = vec![None; n_params];
let mut pos_index = 0usize;
if let Some((_obj_ty, obj_val)) = obj {
param_map[0] = Some(obj_val);
pos_index = 1;
}
// default value handling
for k in keys {
mapping
.insert(k.name, ctx.gen_symbol_val(generator, &k.default_value.unwrap(), k.ty).into());
}
// reorder the parameters
let mut real_params = fun
.0
.args
.iter()
.map(|arg| {
mapping
.remove(&arg.name)
.unwrap()
.to_basic_value_enum(ctx, generator, arg.ty)
.map(|llvm_val| (llvm_val, arg.ty))
})
.collect::<Result<Vec<(_, _)>, _>>()?;
if let Some(obj) = obj {
if let ValueEnum::Static(obj_val) = obj.1 {
real_params.insert(0, (obj_val.get_const_obj(ctx, generator), obj.0));
for (maybe_key, val_enum) in args {
if let Some(kw_name) = maybe_key {
let param_pos = fun
.0
.args
.iter()
.position(|arg| arg.name == kw_name)
.ok_or_else(|| format!("Unknown keyword argument '{kw_name}'"))?;
if param_map[param_pos].is_some() {
return Err(format!("Multiple values for argument '{kw_name}'"));
}
param_map[param_pos] = Some(val_enum);
} else {
// should be an error here...
panic!("only host object is allowed");
while pos_index < n_params && param_map[pos_index].is_some() {
pos_index += 1;
}
if pos_index >= n_params {
return Err("Too many positional arguments given to function.".to_string());
}
param_map[pos_index] = Some(val_enum);
pos_index += 1;
}
}
for (i, (arg, arg_ty)) in real_params.iter().enumerate() {
let arg_slot = format_rpc_arg(generator, ctx, (*arg, *arg_ty, i));
let arg_ptr = unsafe {
for (i, param) in fun.0.args.iter().enumerate() {
if param_map[i].is_none() {
if let Some(default_expr) = &param.default_value {
let default_val = ctx.gen_symbol_val(generator, default_expr, param.ty).into();
param_map[i] = Some(default_val);
} else {
return Err(format!("Missing required argument '{}'", param.name));
}
}
}
let mut real_params = Vec::with_capacity(n_params);
for (i, param_spec) in fun.0.args.iter().enumerate() {
let some_valenum = param_map[i].take().unwrap();
let llvm_val = some_valenum.to_basic_value_enum(ctx, generator, param_spec.ty)?;
real_params.push((llvm_val, param_spec.ty));
}
let arg_count = real_params.len() as u64;
let stackptr = call_stacksave(ctx, Some("rpc.stack"));
let i32_ty = ctx.ctx.i32_type();
let arg_array = ctx
.builder
.build_array_alloca(ptr_type, i32_ty.const_int(arg_count, false), "rpc.arg_array")
.unwrap();
for (i, (llvm_val, ty)) in real_params.iter().enumerate() {
let arg_slot_ptr = unsafe {
ctx.builder.build_gep(
args_ptr,
&[int32.const_int(i as u64, false)],
&format!("rpc.arg{i}"),
arg_array,
&[i32_ty.const_int(i as u64, false)],
&format!("rpc.arg_slot_{i}"),
)
}
.unwrap();
ctx.builder.build_store(arg_ptr, arg_slot).unwrap();
let arg_ptr = format_rpc_arg(generator, ctx, (*llvm_val, *ty, i));
ctx.builder.build_store(arg_slot_ptr, arg_ptr).unwrap();
}
// call
@ -940,7 +958,7 @@ fn rpc_codegen_callback_fn<'ctx>(
ctx,
if is_async { "rpc_send_async" } else { "rpc_send" },
None,
&[service_id.into(), tag_ptr.into(), args_ptr.into()],
&[service_id.into(), tag_ptr.into(), arg_array.into()],
Some("rpc.send"),
None,
);
@ -948,19 +966,8 @@ fn rpc_codegen_callback_fn<'ctx>(
// reclaim stack space used by arguments
call_stackrestore(ctx, stackptr);
if is_async {
// async RPCs do not return any values
Ok(None)
} else {
let result = format_rpc_ret(generator, ctx, fun.0.ret);
if !result.is_some_and(|res| res.get_type().is_pointer_type()) {
// An RPC returning an NDArray would not touch here.
call_stackrestore(ctx, stackptr);
}
Ok(result)
}
let maybe_ret = format_rpc_ret(generator, ctx, fun.0.ret);
Ok(maybe_ret)
}
pub fn attributes_writeback<'ctx>(
@ -1008,7 +1015,7 @@ pub fn attributes_writeback<'ctx>(
if !is_mutable {
continue;
}
if gen_rpc_tag(ctx, *field_ty, &mut scratch_buffer).is_ok() {
if gen_rpc_tag(ctx, *field_ty, false, &mut scratch_buffer).is_ok() {
attributes.push(name.to_string());
let (index, _) = ctx.get_attr_index(ty, *name);
values.push((
@ -1031,7 +1038,7 @@ pub fn attributes_writeback<'ctx>(
TypeEnum::TObj { obj_id, params, .. } if *obj_id == PrimDef::List.id() => {
let elem_ty = iter_type_vars(params).next().unwrap().ty;
if gen_rpc_tag(ctx, elem_ty, &mut scratch_buffer).is_ok() {
if gen_rpc_tag(ctx, elem_ty, false, &mut scratch_buffer).is_ok() {
let pydict = PyDict::new(py);
pydict.set_item("obj", val)?;
host_attributes.append(pydict)?;
@ -1049,7 +1056,7 @@ pub fn attributes_writeback<'ctx>(
if *is_method {
continue;
}
if gen_rpc_tag(ctx, *field_ty, &mut scratch_buffer).is_ok() {
if gen_rpc_tag(ctx, *field_ty, false, &mut scratch_buffer).is_ok() {
fields.push(name.to_string());
let (index, _) = ctx.get_attr_index(ty, *name);
values.push((