From 79931365b72af47a5a56926a56bb6d5968ffbb31 Mon Sep 17 00:00:00 2001 From: ram Date: Wed, 18 Dec 2024 05:10:01 +0000 Subject: [PATCH] Implement Kwargs support, pending tests and artiq implementation --- nac3artiq/demo/min_artiq.py | 25 ++++++++--- nac3artiq/src/codegen.rs | 42 ++++++++++++++----- nac3core/src/typecheck/type_inferencer/mod.rs | 37 +++++++++++++--- 3 files changed, 83 insertions(+), 21 deletions(-) diff --git a/nac3artiq/demo/min_artiq.py b/nac3artiq/demo/min_artiq.py index 62d32cc..8d769f0 100644 --- a/nac3artiq/demo/min_artiq.py +++ b/nac3artiq/demo/min_artiq.py @@ -114,13 +114,26 @@ def extern(function): def rpc(arg=None, flags={}): - """Decorates a function or method to be executed on the host interpreter.""" + """Decorates a function to be executed on the host interpreter with kwargs support.""" + def decorator(function): + @wraps(function) + def wrapper(*args, **kwargs): + # Get function signature + sig = inspect.signature(function) + + # Validate kwargs against signature + bound_args = sig.bind(*args, **kwargs) + bound_args.apply_defaults() + + # Call RPC with both args and kwargs + return _do_rpc(function.__name__, + bound_args.args, + bound_args.kwargs) + return wrapper + if arg is None: - def inner_decorator(function): - return rpc(function, flags) - return inner_decorator - register_function(arg) - return arg + return decorator + return decorator(arg) def kernel(function_or_method): """Decorates a function or method to be executed on the core device.""" diff --git a/nac3artiq/src/codegen.rs b/nac3artiq/src/codegen.rs index 653f41a..e877a05 100644 --- a/nac3artiq/src/codegen.rs +++ b/nac3artiq/src/codegen.rs @@ -79,8 +79,7 @@ pub struct ArtiqCodeGenerator<'a> { /// The [`ParallelMode`] of the current parallel context. /// - /// The current parallel context refers to the nearest `with parallel` or `with legacy_parallel` - /// statement, which is used to determine when and how the timeline should be updated. + /// The current parallel context refers to the nearest `with` statement, which is used to determine when and how the timeline should be updated. parallel_mode: ParallelMode, } @@ -373,8 +372,14 @@ impl<'b> CodeGenerator for ArtiqCodeGenerator<'b> { fn gen_rpc_tag( ctx: &mut CodeGenContext<'_, '_>, ty: Type, + is_kwarg: bool, // Add this parameter buffer: &mut Vec, ) -> Result<(), String> { + // Add kwarg marker if needed + if is_kwarg { + buffer.push(b'k'); // 'k' for keyword argument + } + use nac3core::typecheck::typedef::TypeEnum::*; let int32 = ctx.primitives.int32; @@ -403,14 +408,14 @@ fn gen_rpc_tag( buffer.push(b't'); buffer.push(ty.len() as u8); for ty in ty { - gen_rpc_tag(ctx, *ty, buffer)?; + gen_rpc_tag(ctx, *ty, false, buffer)?; // Pass false for is_kwarg } } TObj { obj_id, params, .. } if *obj_id == PrimDef::List.id() => { let ty = iter_type_vars(params).next().unwrap().ty; buffer.push(b'l'); - gen_rpc_tag(ctx, ty, buffer)?; + gen_rpc_tag(ctx, ty, false, buffer)?; // Pass false for is_kwarg } TObj { obj_id, .. } if *obj_id == PrimDef::NDArray.id() => { let (ndarray_dtype, ndarray_ndims) = unpack_ndarray_var_tys(&mut ctx.unifier, ty); @@ -434,7 +439,7 @@ fn gen_rpc_tag( buffer.push(b'a'); buffer.push((ndarray_ndims & 0xFF) as u8); - gen_rpc_tag(ctx, ndarray_dtype, buffer)?; + gen_rpc_tag(ctx, ndarray_dtype, false, buffer)?; // Pass false for is_kwarg } _ => return Err(format!("Unsupported type: {:?}", ctx.unifier.stringify(ty))), } @@ -808,10 +813,10 @@ fn rpc_codegen_callback_fn<'ctx>( tag.push(b'O'); } for arg in &fun.0.args { - gen_rpc_tag(ctx, arg.ty, &mut tag)?; + gen_rpc_tag(ctx, arg.ty, false, &mut tag)?; // Pass false for is_kwarg } tag.push(b':'); - gen_rpc_tag(ctx, fun.0.ret, &mut tag)?; + gen_rpc_tag(ctx, fun.0.ret, false, &mut tag)?; let mut hasher = DefaultHasher::new(); tag.hash(&mut hasher); @@ -858,8 +863,17 @@ fn rpc_codegen_callback_fn<'ctx>( // -- rpc args handling let mut keys = fun.0.args.clone(); let mut mapping = HashMap::new(); + let mut is_keyword_arg = HashMap::new(); + for (key, value) in args { - mapping.insert(key.unwrap_or_else(|| keys.remove(0).name), value); + if let Some(key_name) = key { + mapping.insert(key_name, value); + is_keyword_arg.insert(key_name, true); + } else { + let arg_name = keys.remove(0).name; + mapping.insert(arg_name, value); + is_keyword_arg.insert(arg_name, false); + } } // default value handling for k in keys { @@ -901,6 +915,14 @@ fn rpc_codegen_callback_fn<'ctx>( ctx.builder.build_store(arg_ptr, arg_slot).unwrap(); } + // Before calling rpc_send/rpc_send_async, add keyword arg info to tag + for arg in &fun.0.args { + if *is_keyword_arg.get(&arg.name).unwrap_or(&false) { + tag.push(b'k'); // Mark as keyword argument + } + gen_rpc_tag(ctx, arg.ty, true, &mut tag)?; // Pass true for is_kwarg + } + // call if is_async { let rpc_send_async = ctx.module.get_function("rpc_send_async").unwrap_or_else(|| { @@ -1007,7 +1029,7 @@ pub fn attributes_writeback<'ctx>( if !is_mutable { continue; } - if gen_rpc_tag(ctx, *field_ty, &mut scratch_buffer).is_ok() { + if gen_rpc_tag(ctx, *field_ty, false, &mut scratch_buffer).is_ok() { attributes.push(name.to_string()); let (index, _) = ctx.get_attr_index(ty, *name); values.push(( @@ -1030,7 +1052,7 @@ pub fn attributes_writeback<'ctx>( TypeEnum::TObj { obj_id, params, .. } if *obj_id == PrimDef::List.id() => { let elem_ty = iter_type_vars(params).next().unwrap().ty; - if gen_rpc_tag(ctx, elem_ty, &mut scratch_buffer).is_ok() { + if gen_rpc_tag(ctx, elem_ty, false, &mut scratch_buffer).is_ok() { let pydict = PyDict::new(py); pydict.set_item("obj", val)?; host_attributes.append(pydict)?; diff --git a/nac3core/src/typecheck/type_inferencer/mod.rs b/nac3core/src/typecheck/type_inferencer/mod.rs index 6068f63..bb57905 100644 --- a/nac3core/src/typecheck/type_inferencer/mod.rs +++ b/nac3core/src/typecheck/type_inferencer/mod.rs @@ -1832,20 +1832,47 @@ impl<'a> Inferencer<'a> { if let TypeEnum::TFunc(sign) = &*self.unifier.get_ty(func.custom.unwrap()) { if sign.vars.is_empty() { + // Build keyword argument map + let mut kwargs_map = HashMap::new(); + for kw in &keywords { + if let Some(name) = &kw.node.arg { + // Check if keyword arg exists in function signature + if !sign.args.iter().any(|arg| arg.name == *name) { + return report_error( + &format!("Unexpected keyword argument '{}'", name), + kw.location, + ); + } + kwargs_map.insert(*name, kw.node.value.custom.unwrap()); + } + } + + // Validate that all required args are provided + for arg in &sign.args { + if arg.default_value.is_none() + && !kwargs_map.contains_key(&arg.name) + && args.len() < sign.args.len() + { + return report_error( + &format!("Missing required argument '{}'", arg.name), + location, + ); + } + } + let call = Call { posargs: args.iter().map(|v| v.custom.unwrap()).collect(), - kwargs: keywords - .iter() - .map(|v| (*v.node.arg.as_ref().unwrap(), v.node.value.custom.unwrap())) - .collect(), + kwargs: kwargs_map, fun: RefCell::new(None), ret: sign.ret, loc: Some(location), operator_info: None, }; + self.unifier.unify_call(&call, func.custom.unwrap(), sign).map_err(|e| { HashSet::from([e.at(Some(location)).to_display(self.unifier).to_string()]) })?; + return Ok(Located { location, custom: Some(sign.ret), @@ -1859,7 +1886,7 @@ impl<'a> Inferencer<'a> { posargs: args.iter().map(|v| v.custom.unwrap()).collect(), kwargs: keywords .iter() - .map(|v| (*v.node.arg.as_ref().unwrap(), v.custom.unwrap())) + .filter_map(|v| v.node.arg.map(|name| (name, v.node.value.custom.unwrap()))) .collect(), fun: RefCell::new(None), ret,