diff --git a/nac3artiq/demo/rpc_kwargs_test.py b/nac3artiq/demo/rpc_kwargs_test.py index acc86bba..d4a444b2 100644 --- a/nac3artiq/demo/rpc_kwargs_test.py +++ b/nac3artiq/demo/rpc_kwargs_test.py @@ -19,23 +19,19 @@ class RpcKwargTest: def run(self): #1) All positional => a=1, b=2, c=3 -> total=6 s1 = sum_3(1, 2, 3) - if s1 != 6: - raise ValueError("sum_3(1,2,3) gave the wrong result.") + assert s1 == 6 #2) Use the default b=10, c=20 => a=5 => total=35 s2 = sum_3(5) - if s2 != 35: - raise ValueError("sum_3(5) gave the wrong result.") + assert s2 == 35 #3) a=1 (positional), b=100 (keyword), omit c => c=20 => total=121 s3 = sum_3(1, b=100) - if s3 != 121: - raise ValueError("sum_3(1, b=100) gave the wrong result.") + assert s3 == 121 #4) a=2, c=300 => b=10 (default) => total=312 s4 = sum_3(a=2, c=300) - if s4 != 312: - raise ValueError("sum_3(a=2, c=300) gave the wrong result.") + assert s4 == 312 if __name__ == "__main__": RpcKwargTest().run() \ No newline at end of file diff --git a/nac3artiq/src/codegen.rs b/nac3artiq/src/codegen.rs index 26916279..9f0211e6 100644 --- a/nac3artiq/src/codegen.rs +++ b/nac3artiq/src/codegen.rs @@ -1,5 +1,5 @@ use std::{ - collections::{hash_map::DefaultHasher, HashMap}, + collections::hash_map::DefaultHasher, hash::{Hash, Hasher}, iter::once, mem, @@ -79,7 +79,8 @@ pub struct ArtiqCodeGenerator<'a> { /// The [`ParallelMode`] of the current parallel context. /// - /// The current parallel context refers to the nearest `with` statement, which is used to determine when and how the timeline should be updated. + /// The current parallel context refers to the nearest `with` statement, + /// which is used to determine when and how the timeline should be updated. parallel_mode: ParallelMode, } @@ -372,12 +373,11 @@ impl<'b> CodeGenerator for ArtiqCodeGenerator<'b> { fn gen_rpc_tag( ctx: &mut CodeGenContext<'_, '_>, ty: Type, - is_kwarg: bool, // Add this parameter + is_kwarg: bool, buffer: &mut Vec, ) -> Result<(), String> { - // Add kwarg marker if needed if is_kwarg { - buffer.push(b'k'); // 'k' for keyword argument + buffer.push(b'k'); } use nac3core::typecheck::typedef::TypeEnum::*; @@ -408,14 +408,14 @@ fn gen_rpc_tag( buffer.push(b't'); buffer.push(ty.len() as u8); for ty in ty { - gen_rpc_tag(ctx, *ty, false, buffer)?; // Pass false for is_kwarg + gen_rpc_tag(ctx, *ty, false, buffer)?; } } TObj { obj_id, params, .. } if *obj_id == PrimDef::List.id() => { let ty = iter_type_vars(params).next().unwrap().ty; buffer.push(b'l'); - gen_rpc_tag(ctx, ty, false, buffer)?; // Pass false for is_kwarg + gen_rpc_tag(ctx, ty, false, buffer)?; } TObj { obj_id, .. } if *obj_id == PrimDef::NDArray.id() => { let (ndarray_dtype, ndarray_ndims) = unpack_ndarray_var_tys(&mut ctx.unifier, ty); @@ -796,7 +796,7 @@ pub fn rpc_codegen_callback_fn<'ctx>( ctx: &mut CodeGenContext<'ctx, '_>, obj: Option<(Type, ValueEnum<'ctx>)>, fun: (&FunSignature, DefinitionId), - mut args: Vec<(Option, ValueEnum<'ctx>)>, + args: Vec<(Option, ValueEnum<'ctx>)>, generator: &mut dyn CodeGenerator, is_async: bool, ) -> Result>, String> { @@ -820,8 +820,6 @@ pub fn rpc_codegen_callback_fn<'ctx>( tag.push(b':'); gen_rpc_tag(ctx, fun.0.ret, false, &mut tag)?; - use std::collections::hash_map::DefaultHasher; - use std::hash::{Hash, Hasher}; let mut hasher = DefaultHasher::new(); tag.hash(&mut hasher); let hash = format!("rpc_tag_{}", hasher.finish()); @@ -832,12 +830,9 @@ pub fn rpc_codegen_callback_fn<'ctx>( } else { let tag_len = tag.len(); let arr_ty = int8.array_type(tag_len as u32); - let tag_const = int8.const_array( - &tag.iter().map(|&b| int8.const_int(b as u64, false)).collect::>(), - ); - let arr_gv = ctx - .module - .add_global(arr_ty, None, &format!("{}.arr", hash)); + let tag_const = int8 + .const_array(&tag.iter().map(|&b| int8.const_int(b as u64, false)).collect::>()); + let arr_gv = ctx.module.add_global(arr_ty, None, &format!("{}.arr", hash)); arr_gv.set_linkage(Linkage::Private); arr_gv.set_initializer(&tag_const); @@ -860,13 +855,14 @@ pub fn rpc_codegen_callback_fn<'ctx>( let mut pos_index = 0usize; - if let Some((obj_ty, obj_val)) = obj { + if let Some((_obj_ty, obj_val)) = obj { param_map[0] = Some(obj_val); pos_index = 1; } - for (maybe_key, val_enum) in args.drain(..) { + for (maybe_key, val_enum) in args { if let Some(kw_name) = maybe_key { - let param_pos = fun.0 + let param_pos = fun + .0 .args .iter() .position(|arg| arg.name == kw_name) @@ -911,23 +907,18 @@ pub fn rpc_codegen_callback_fn<'ctx>( let i32_ty = ctx.ctx.i32_type(); let arg_array = ctx .builder - .build_array_alloca( - ptr_type, - i32_ty.const_int(arg_count, false), - "rpc.arg_array", - ) + .build_array_alloca(ptr_type, i32_ty.const_int(arg_count, false), "rpc.arg_array") .unwrap(); for (i, (llvm_val, ty)) in real_params.iter().enumerate() { let arg_slot_ptr = unsafe { ctx.builder.build_gep( arg_array, - &[ - i32_ty.const_int(i as u64, false), - ], + &[i32_ty.const_int(i as u64, false)], &format!("rpc.arg_slot_{}", i), ) - }.unwrap(); + } + .unwrap(); let arg_ptr = format_rpc_arg(generator, ctx, (*llvm_val, *ty, i)); ctx.builder.build_store(arg_slot_ptr, arg_ptr).unwrap(); } @@ -937,7 +928,11 @@ pub fn rpc_codegen_callback_fn<'ctx>( ctx.module.add_function( "rpc_send_async", ctx.ctx.void_type().fn_type( - &[int32.into(), tag_ptr_type.into(), ptr_type.ptr_type(AddressSpace::default()).into()], + &[ + int32.into(), + tag_ptr_type.into(), + ptr_type.ptr_type(AddressSpace::default()).into(), + ], false, ), None, @@ -946,11 +941,7 @@ pub fn rpc_codegen_callback_fn<'ctx>( ctx.builder .build_call( rpc_send_async, - &[ - service_id.into(), - tag_ptr.into(), - arg_array.into(), - ], + &[service_id.into(), tag_ptr.into(), arg_array.into()], "rpc.send_async", ) .unwrap(); @@ -961,7 +952,11 @@ pub fn rpc_codegen_callback_fn<'ctx>( ctx.module.add_function( "rpc_send", ctx.ctx.void_type().fn_type( - &[int32.into(), tag_ptr_type.into(), ptr_type.ptr_type(AddressSpace::default()).into()], + &[ + int32.into(), + tag_ptr_type.into(), + ptr_type.ptr_type(AddressSpace::default()).into(), + ], false, ), None, @@ -970,11 +965,7 @@ pub fn rpc_codegen_callback_fn<'ctx>( ctx.builder .build_call( rpc_send, - &[ - service_id.into(), - tag_ptr.into(), - arg_array.into(), - ], + &[service_id.into(), tag_ptr.into(), arg_array.into()], "rpc.send", ) .unwrap();