From c083245c280f3d28fc286dfb03417f3d7ce521d6 Mon Sep 17 00:00:00 2001 From: ram Date: Wed, 5 Feb 2025 07:20:13 +0000 Subject: [PATCH] Implement passing of kwarg --- nac3artiq/src/codegen.rs | 251 +++++++++--------- nac3core/src/typecheck/type_inferencer/mod.rs | 84 +++--- 2 files changed, 174 insertions(+), 161 deletions(-) diff --git a/nac3artiq/src/codegen.rs b/nac3artiq/src/codegen.rs index e877a055..26916279 100644 --- a/nac3artiq/src/codegen.rs +++ b/nac3artiq/src/codegen.rs @@ -792,11 +792,11 @@ fn format_rpc_ret<'ctx>( Some(result) } -fn rpc_codegen_callback_fn<'ctx>( +pub fn rpc_codegen_callback_fn<'ctx>( ctx: &mut CodeGenContext<'ctx, '_>, obj: Option<(Type, ValueEnum<'ctx>)>, fun: (&FunSignature, DefinitionId), - args: Vec<(Option, ValueEnum<'ctx>)>, + mut args: Vec<(Option, ValueEnum<'ctx>)>, generator: &mut dyn CodeGenerator, is_async: bool, ) -> Result>, String> { @@ -804,136 +804,140 @@ fn rpc_codegen_callback_fn<'ctx>( let int32 = ctx.ctx.i32_type(); let size_type = generator.get_size_type(ctx.ctx); let ptr_type = int8.ptr_type(AddressSpace::default()); - let tag_ptr_type = ctx.ctx.struct_type(&[ptr_type.into(), size_type.into()], false); - + let tag_ptr_type = ctx + .ctx + .struct_type(&[ptr_type.into(), size_type.into()], false) + .ptr_type(AddressSpace::default()); let service_id = int32.const_int(fun.1 .0 as u64, false); // -- setup rpc tags let mut tag = Vec::new(); if obj.is_some() { tag.push(b'O'); } - for arg in &fun.0.args { - gen_rpc_tag(ctx, arg.ty, false, &mut tag)?; // Pass false for is_kwarg + for param in &fun.0.args { + gen_rpc_tag(ctx, param.ty, false, &mut tag)?; } tag.push(b':'); gen_rpc_tag(ctx, fun.0.ret, false, &mut tag)?; + use std::collections::hash_map::DefaultHasher; + use std::hash::{Hash, Hasher}; let mut hasher = DefaultHasher::new(); tag.hash(&mut hasher); - let hash = format!("{}", hasher.finish()); + let hash = format!("rpc_tag_{}", hasher.finish()); - let tag_ptr = ctx - .module - .get_global(hash.as_str()) - .unwrap_or_else(|| { - let tag_arr_ptr = ctx.module.add_global( - int8.array_type(tag.len() as u32), - None, - format!("tagptr{}", fun.1 .0).as_str(), - ); - tag_arr_ptr.set_initializer(&int8.const_array( - &tag.iter().map(|v| int8.const_int(u64::from(*v), false)).collect::>(), - )); - tag_arr_ptr.set_linkage(Linkage::Private); - let tag_ptr = ctx.module.add_global(tag_ptr_type, None, &hash); - tag_ptr.set_linkage(Linkage::Private); - tag_ptr.set_initializer(&ctx.ctx.const_struct( - &[ - tag_arr_ptr.as_pointer_value().const_cast(ptr_type).into(), - size_type.const_int(tag.len() as u64, false).into(), - ], - false, - )); - tag_ptr - }) - .as_pointer_value(); + let maybe_existing = ctx.module.get_global(&hash); + let tag_ptr = if let Some(gv) = maybe_existing { + gv.as_pointer_value() + } else { + let tag_len = tag.len(); + let arr_ty = int8.array_type(tag_len as u32); + let tag_const = int8.const_array( + &tag.iter().map(|&b| int8.const_int(b as u64, false)).collect::>(), + ); + let arr_gv = ctx + .module + .add_global(arr_ty, None, &format!("{}.arr", hash)); + arr_gv.set_linkage(Linkage::Private); + arr_gv.set_initializer(&tag_const); - let arg_length = args.len() + usize::from(obj.is_some()); + let st = ctx.ctx.const_struct( + &[ + arr_gv.as_pointer_value().const_cast(ptr_type).into(), + size_type.const_int(tag_len as u64, false).into(), + ], + false, + ); + let st_gv = ctx.module.add_global(st.get_type(), None, &hash); + st_gv.set_linkage(Linkage::Private); + st_gv.set_initializer(&st); + st_gv.as_pointer_value() + }; + + let n_params = fun.0.args.len(); + let mut param_map: Vec>> = vec![None; n_params]; + + let mut pos_index = 0usize; + + if let Some((obj_ty, obj_val)) = obj { + param_map[0] = Some(obj_val); + pos_index = 1; + } + for (maybe_key, val_enum) in args.drain(..) { + if let Some(kw_name) = maybe_key { + let param_pos = fun.0 + .args + .iter() + .position(|arg| arg.name == kw_name) + .ok_or_else(|| format!("Unknown keyword argument '{}'", kw_name))?; + + if param_map[param_pos].is_some() { + return Err(format!("Multiple values for argument '{}'", kw_name)); + } + param_map[param_pos] = Some(val_enum); + } else { + while pos_index < n_params && param_map[pos_index].is_some() { + pos_index += 1; + } + if pos_index >= n_params { + return Err("Too many positional arguments given to function.".to_string()); + } + param_map[pos_index] = Some(val_enum); + pos_index += 1; + } + } + + for (i, param) in fun.0.args.iter().enumerate() { + if param_map[i].is_none() { + if let Some(default_expr) = ¶m.default_value { + let default_val = ctx.gen_symbol_val(generator, default_expr, param.ty).into(); + param_map[i] = Some(default_val); + } else { + return Err(format!("Missing required argument '{}'", param.name)); + } + } + } + let mut real_params = Vec::with_capacity(n_params); + for (i, param_spec) in fun.0.args.iter().enumerate() { + let some_valenum = param_map[i].take().unwrap(); + let llvm_val = some_valenum.to_basic_value_enum(ctx, generator, param_spec.ty)?; + real_params.push((llvm_val, param_spec.ty)); + } + + let arg_count = real_params.len() as u64; let stackptr = call_stacksave(ctx, Some("rpc.stack")); - let args_ptr = ctx + + let i32_ty = ctx.ctx.i32_type(); + let arg_array = ctx .builder .build_array_alloca( ptr_type, - ctx.ctx.i32_type().const_int(arg_length as u64, false), - "argptr", + i32_ty.const_int(arg_count, false), + "rpc.arg_array", ) .unwrap(); - // -- rpc args handling - let mut keys = fun.0.args.clone(); - let mut mapping = HashMap::new(); - let mut is_keyword_arg = HashMap::new(); - - for (key, value) in args { - if let Some(key_name) = key { - mapping.insert(key_name, value); - is_keyword_arg.insert(key_name, true); - } else { - let arg_name = keys.remove(0).name; - mapping.insert(arg_name, value); - is_keyword_arg.insert(arg_name, false); - } - } - // default value handling - for k in keys { - mapping - .insert(k.name, ctx.gen_symbol_val(generator, &k.default_value.unwrap(), k.ty).into()); - } - // reorder the parameters - let mut real_params = fun - .0 - .args - .iter() - .map(|arg| { - mapping - .remove(&arg.name) - .unwrap() - .to_basic_value_enum(ctx, generator, arg.ty) - .map(|llvm_val| (llvm_val, arg.ty)) - }) - .collect::, _>>()?; - if let Some(obj) = obj { - if let ValueEnum::Static(obj_val) = obj.1 { - real_params.insert(0, (obj_val.get_const_obj(ctx, generator), obj.0)); - } else { - // should be an error here... - panic!("only host object is allowed"); - } - } - - for (i, (arg, arg_ty)) in real_params.iter().enumerate() { - let arg_slot = format_rpc_arg(generator, ctx, (*arg, *arg_ty, i)); - let arg_ptr = unsafe { + for (i, (llvm_val, ty)) in real_params.iter().enumerate() { + let arg_slot_ptr = unsafe { ctx.builder.build_gep( - args_ptr, - &[int32.const_int(i as u64, false)], - &format!("rpc.arg{i}"), + arg_array, + &[ + i32_ty.const_int(i as u64, false), + ], + &format!("rpc.arg_slot_{}", i), ) - } - .unwrap(); - ctx.builder.build_store(arg_ptr, arg_slot).unwrap(); + }.unwrap(); + let arg_ptr = format_rpc_arg(generator, ctx, (*llvm_val, *ty, i)); + ctx.builder.build_store(arg_slot_ptr, arg_ptr).unwrap(); } - // Before calling rpc_send/rpc_send_async, add keyword arg info to tag - for arg in &fun.0.args { - if *is_keyword_arg.get(&arg.name).unwrap_or(&false) { - tag.push(b'k'); // Mark as keyword argument - } - gen_rpc_tag(ctx, arg.ty, true, &mut tag)?; // Pass true for is_kwarg - } - - // call if is_async { let rpc_send_async = ctx.module.get_function("rpc_send_async").unwrap_or_else(|| { ctx.module.add_function( "rpc_send_async", ctx.ctx.void_type().fn_type( - &[ - int32.into(), - tag_ptr_type.ptr_type(AddressSpace::default()).into(), - ptr_type.ptr_type(AddressSpace::default()).into(), - ], + &[int32.into(), tag_ptr_type.into(), ptr_type.ptr_type(AddressSpace::default()).into()], false, ), None, @@ -942,45 +946,42 @@ fn rpc_codegen_callback_fn<'ctx>( ctx.builder .build_call( rpc_send_async, - &[service_id.into(), tag_ptr.into(), args_ptr.into()], - "rpc.send", + &[ + service_id.into(), + tag_ptr.into(), + arg_array.into(), + ], + "rpc.send_async", ) .unwrap(); + call_stackrestore(ctx, stackptr); + return Ok(None); } else { let rpc_send = ctx.module.get_function("rpc_send").unwrap_or_else(|| { ctx.module.add_function( "rpc_send", ctx.ctx.void_type().fn_type( - &[ - int32.into(), - tag_ptr_type.ptr_type(AddressSpace::default()).into(), - ptr_type.ptr_type(AddressSpace::default()).into(), - ], + &[int32.into(), tag_ptr_type.into(), ptr_type.ptr_type(AddressSpace::default()).into()], false, ), None, ) }); ctx.builder - .build_call(rpc_send, &[service_id.into(), tag_ptr.into(), args_ptr.into()], "rpc.send") + .build_call( + rpc_send, + &[ + service_id.into(), + tag_ptr.into(), + arg_array.into(), + ], + "rpc.send", + ) .unwrap(); - } + call_stackrestore(ctx, stackptr); - // reclaim stack space used by arguments - call_stackrestore(ctx, stackptr); - - if is_async { - // async RPCs do not return any values - Ok(None) - } else { - let result = format_rpc_ret(generator, ctx, fun.0.ret); - - if !result.is_some_and(|res| res.get_type().is_pointer_type()) { - // An RPC returning an NDArray would not touch here. - call_stackrestore(ctx, stackptr); - } - - Ok(result) + let maybe_ret = format_rpc_ret(generator, ctx, fun.0.ret); + Ok(maybe_ret) } } diff --git a/nac3core/src/typecheck/type_inferencer/mod.rs b/nac3core/src/typecheck/type_inferencer/mod.rs index bb579054..742fa197 100644 --- a/nac3core/src/typecheck/type_inferencer/mod.rs +++ b/nac3core/src/typecheck/type_inferencer/mod.rs @@ -3,7 +3,7 @@ use std::{ cmp::max, collections::{HashMap, HashSet}, convert::{From, TryInto}, - iter::once, + iter::{once, repeat_n}, sync::Arc, }; @@ -187,7 +187,7 @@ fn fix_assignment_target_context(node: &mut ast::Located) { } } -impl<'a> Fold<()> for Inferencer<'a> { +impl Fold<()> for Inferencer<'_> { type TargetU = Option; type Error = InferenceError; @@ -657,7 +657,7 @@ impl<'a> Fold<()> for Inferencer<'a> { type InferenceResult = Result; -impl<'a> Inferencer<'a> { +impl Inferencer<'_> { /// Constrain a <: b /// Currently implemented as unification fn constrain(&mut self, a: Type, b: Type, location: &Location) -> Result<(), InferenceError> { @@ -1234,6 +1234,45 @@ impl<'a> Inferencer<'a> { })); } + if ["np_shape".into(), "np_strides".into()].contains(id) && args.len() == 1 { + let ndarray = self.fold_expr(args.remove(0))?; + + let ndims = arraylike_get_ndims(self.unifier, ndarray.custom.unwrap()); + + // Make a tuple of size `ndims` full of int32 (TODO: Make it usize) + let ret_ty = TypeEnum::TTuple { + ty: repeat_n(self.primitives.int32, ndims as usize).collect_vec(), + is_vararg_ctx: false, + }; + let ret_ty = self.unifier.add_ty(ret_ty); + + let func_ty = TypeEnum::TFunc(FunSignature { + args: vec![FuncArg { + name: "a".into(), + default_value: None, + ty: ndarray.custom.unwrap(), + is_vararg: false, + }], + ret: ret_ty, + vars: VarMap::new(), + }); + let func_ty = self.unifier.add_ty(func_ty); + + return Ok(Some(Located { + location, + custom: Some(ret_ty), + node: ExprKind::Call { + func: Box::new(Located { + custom: Some(func_ty), + location: func.location, + node: ExprKind::Name { id: *id, ctx: *ctx }, + }), + args: vec![ndarray], + keywords: vec![], + }, + })); + } + if id == &"np_dot".into() { let arg0 = self.fold_expr(args.remove(0))?; let arg1 = self.fold_expr(args.remove(0))?; @@ -1555,7 +1594,7 @@ impl<'a> Inferencer<'a> { })); } // 2-argument ndarray n-dimensional factory functions - if id == &"np_reshape".into() && args.len() == 2 { + if ["np_reshape".into(), "np_broadcast_to".into()].contains(id) && args.len() == 2 { let arg0 = self.fold_expr(args.remove(0))?; let shape_expr = args.remove(0); @@ -1832,47 +1871,20 @@ impl<'a> Inferencer<'a> { if let TypeEnum::TFunc(sign) = &*self.unifier.get_ty(func.custom.unwrap()) { if sign.vars.is_empty() { - // Build keyword argument map - let mut kwargs_map = HashMap::new(); - for kw in &keywords { - if let Some(name) = &kw.node.arg { - // Check if keyword arg exists in function signature - if !sign.args.iter().any(|arg| arg.name == *name) { - return report_error( - &format!("Unexpected keyword argument '{}'", name), - kw.location, - ); - } - kwargs_map.insert(*name, kw.node.value.custom.unwrap()); - } - } - - // Validate that all required args are provided - for arg in &sign.args { - if arg.default_value.is_none() - && !kwargs_map.contains_key(&arg.name) - && args.len() < sign.args.len() - { - return report_error( - &format!("Missing required argument '{}'", arg.name), - location, - ); - } - } - let call = Call { posargs: args.iter().map(|v| v.custom.unwrap()).collect(), - kwargs: kwargs_map, + kwargs: keywords + .iter() + .map(|v| (*v.node.arg.as_ref().unwrap(), v.node.value.custom.unwrap())) + .collect(), fun: RefCell::new(None), ret: sign.ret, loc: Some(location), operator_info: None, }; - self.unifier.unify_call(&call, func.custom.unwrap(), sign).map_err(|e| { HashSet::from([e.at(Some(location)).to_display(self.unifier).to_string()]) })?; - return Ok(Located { location, custom: Some(sign.ret), @@ -1886,7 +1898,7 @@ impl<'a> Inferencer<'a> { posargs: args.iter().map(|v| v.custom.unwrap()).collect(), kwargs: keywords .iter() - .filter_map(|v| v.node.arg.map(|name| (name, v.node.value.custom.unwrap()))) + .map(|v| (*v.node.arg.as_ref().unwrap(), v.custom.unwrap())) .collect(), fun: RefCell::new(None), ret,