diff --git a/nac3core/src/codegen/expr.rs b/nac3core/src/codegen/expr.rs index b22d3cc2a8..1b79c0b64a 100644 --- a/nac3core/src/codegen/expr.rs +++ b/nac3core/src/codegen/expr.rs @@ -1,5 +1,3 @@ -use std::{collections::HashMap, convert::TryInto, iter::once, iter::zip}; - use crate::{ codegen::{ classes::{ @@ -7,7 +5,7 @@ use crate::{ ProxyValue, RangeValue, TypedArrayLikeAccessor, UntypedArrayLikeAccessor, }, concrete_type::{ConcreteFuncArg, ConcreteTypeEnum, ConcreteTypeStore}, - gen_in_range_check, get_llvm_abi_type, get_llvm_type, + gen_in_range_check, get_llvm_abi_type, get_llvm_type, get_va_count_arg_name, irrt::*, llvm_intrinsics::{ call_expect, call_float_floor, call_float_pow, call_float_powi, call_int_smax, @@ -42,6 +40,8 @@ use nac3parser::ast::{ self, Boolop, Cmpop, Comprehension, Constant, Expr, ExprKind, Location, Operator, StrRef, Unaryop, }; +use std::iter::{repeat, repeat_with}; +use std::{collections::HashMap, convert::TryInto, iter::once, iter::zip}; pub fn get_subst_key( unifier: &mut Unifier, @@ -517,16 +517,19 @@ impl<'ctx, 'a> CodeGenContext<'ctx, 'a> { } } } + let params = if loc_params.is_empty() { params } else { &loc_params }; let params = fun .get_type() .get_param_types() .into_iter() + .map(Some) + .chain(repeat(None)) .zip(params.iter()) .map(|(ty, val)| match (ty, val.get_type()) { - (BasicTypeEnum::PointerType(arg_ty), BasicTypeEnum::PointerType(val_ty)) + (Some(BasicTypeEnum::PointerType(arg_ty)), BasicTypeEnum::PointerType(val_ty)) if { - ty != val.get_type() + ty.unwrap() != val.get_type() && arg_ty.get_element_type().is_struct_type() && val_ty.get_element_type().is_struct_type() } => @@ -536,6 +539,7 @@ impl<'ctx, 'a> CodeGenContext<'ctx, 'a> { _ => *val, }) .collect_vec(); + let result = if let Some(target) = self.unwind_target { let current = self.builder.get_insert_block().unwrap().get_parent().unwrap(); let then_block = self.ctx.append_basic_block(current, &format!("after.{call_name}")); @@ -555,6 +559,7 @@ impl<'ctx, 'a> CodeGenContext<'ctx, 'a> { .map(Either::left) .unwrap() }; + if let Some(slot) = return_slot { Some(self.builder.build_load(slot, call_name).unwrap()) } else { @@ -729,10 +734,10 @@ pub fn gen_func_instance<'ctx>( .collect(); let mut signature = store.from_signature(&mut ctx.unifier, &ctx.primitives, sign, &mut cache); + let ConcreteTypeEnum::TFunc { args, .. } = &mut signature else { unreachable!() }; if let Some(obj) = &obj { let zelf = store.from_unifier_type(&mut ctx.unifier, &ctx.primitives, obj.0, &mut cache); - let ConcreteTypeEnum::TFunc { args, .. } = &mut signature else { unreachable!() }; args.insert( 0, @@ -744,6 +749,26 @@ pub fn gen_func_instance<'ctx>( }, ); } + + if let Some(vararg_arg) = sign.args.iter().find(|arg| arg.is_vararg) { + let va_count_arg = get_va_count_arg_name(vararg_arg.name); + + args.insert( + args.len() - 1, + ConcreteFuncArg { + name: va_count_arg, + ty: store.from_unifier_type( + &mut ctx.unifier, + &ctx.primitives, + ctx.primitives.usize(), + &mut cache, + ), + default_value: None, + is_vararg: false, + }, + ); + } + let signature = store.add_cty(signature); ctx.registry.add_task(CodeGenTask { @@ -768,11 +793,17 @@ pub fn gen_call<'ctx, G: CodeGenerator>( fun: (&FunSignature, DefinitionId), params: Vec<(Option, ValueEnum<'ctx>)>, ) -> Result>, String> { + let llvm_usize = generator.get_size_type(ctx.ctx); + let definition = ctx.top_level.definitions.read().get(fun.1 .0).cloned().unwrap(); let id; let key; let param_vals; let is_extern; + let vararg_arg; + + // Ensure that the function object only contains up to 1 vararg parameter + debug_assert!(fun.0.args.iter().filter(|arg| arg.is_vararg).count() <= 1); let symbol = { // make sure this lock guard is dropped at the end of this scope... @@ -788,22 +819,72 @@ pub fn gen_call<'ctx, G: CodeGenerator>( return callback.run(ctx, obj, fun, params, generator); } is_extern = instance_to_stmt.is_empty(); + vararg_arg = fun.0.args.iter().find(|arg| arg.is_vararg); let old_key = ctx.get_subst_key(obj.as_ref().map(|a| a.0), fun.0, None); let mut keys = fun.0.args.clone(); - let mut mapping = HashMap::new(); + let mut mapping = HashMap::<_, Vec>::new(); + for (key, value) in params { - mapping.insert(key.unwrap_or_else(|| keys.remove(0).name), value); + // Find the matching argument + let matching_param = fun + .0 + .args + .iter() + .find_or_last(|p| key.is_some_and(|k| k == p.name)) + .unwrap(); + if matching_param.is_vararg { + if key.is_none() && !keys.is_empty() { + keys.remove(0); + } + + // vararg is lowered into two arguments - va_count and `...` + // Handle va_count first, for each argument encountered we increment it by 1 + let va_count = get_va_count_arg_name(matching_param.name); + if let Some(params) = mapping.get_mut(&va_count) { + debug_assert_eq!(params.len(), 1); + + let param = params[0] + .clone() + .to_basic_value_enum(ctx, generator, ctx.primitives.usize())? + .into_int_value(); + params[0] = param.const_add(llvm_usize.const_int(1, false)).into(); + } else { + mapping.insert(va_count, vec![llvm_usize.const_int(1, false).into()]); + } + + if let Some(param) = mapping.get_mut(&matching_param.name) { + param.push(value); + } else { + mapping.insert(key.unwrap_or(matching_param.name), vec![value]); + } + } else { + mapping.insert(key.unwrap_or_else(|| keys.remove(0).name), vec![value]); + } } + // default value handling for k in keys { if mapping.contains_key(&k.name) { continue; } - mapping.insert( - k.name, - ctx.gen_symbol_val(generator, &k.default_value.unwrap(), k.ty).into(), - ); + + if k.is_vararg { + mapping.insert( + get_va_count_arg_name(k.name), + vec![llvm_usize.const_zero().into()], + ); + + mapping.insert(k.name, Vec::default()); + } else { + mapping.insert( + k.name, + vec![ctx + .gen_symbol_val(generator, &k.default_value.unwrap(), k.ty) + .into()], + ); + } } + // reorder the parameters let mut real_params = fun .0 @@ -812,13 +893,24 @@ pub fn gen_call<'ctx, G: CodeGenerator>( .map(|arg| (mapping.remove(&arg.name).unwrap(), arg.ty)) .collect_vec(); if let Some(obj) = &obj { - real_params.insert(0, (obj.1.clone(), obj.0)); + real_params.insert(0, (vec![obj.1.clone()], obj.0)); } + if let Some(vararg) = vararg_arg { + let vararg_arg_name = get_va_count_arg_name(vararg.name); + + real_params.insert( + real_params.len() - 1, + (mapping[&vararg_arg_name].clone(), ctx.primitives.usize()), + ); + } + let static_params = real_params .iter() .enumerate() .filter_map(|(i, (v, _))| { - if let ValueEnum::Static(s) = v { + if v.len() != 1 { + None + } else if let ValueEnum::Static(s) = &v[0] { Some((i, s.clone())) } else { None @@ -848,8 +940,13 @@ pub fn gen_call<'ctx, G: CodeGenerator>( }; param_vals = real_params .into_iter() - .map(|(p, t)| p.to_basic_value_enum(ctx, generator, t)) - .collect::, String>>()?; + .map(|(ps, t)| { + ps.into_iter().map(|p| p.to_basic_value_enum(ctx, generator, t)).collect() + }) + .collect::>, _>>()? + .into_iter() + .flatten() + .collect::>(); instance_to_symbol.get(&key).cloned().ok_or_else(String::new) } TopLevelDef::Class { .. } => { @@ -878,6 +975,7 @@ pub fn gen_call<'ctx, G: CodeGenerator>( let mut params = args .iter() .enumerate() + .filter(|(_, arg)| !arg.is_vararg) .map(|(i, arg)| { match ctx.get_llvm_abi_type(generator, arg.ty) { BasicTypeEnum::StructType(ty) if is_extern => { @@ -892,9 +990,13 @@ pub fn gen_call<'ctx, G: CodeGenerator>( if has_sret { params.insert(0, ret_type.unwrap().ptr_type(AddressSpace::default()).into()); } + let is_vararg = args.iter().any(|arg| arg.is_vararg); + if is_vararg { + params.push(generator.get_size_type(ctx.ctx).into()); + } let fun_ty = match ret_type { - Some(ret_type) if !has_sret => ret_type.fn_type(¶ms, false), - _ => ctx.ctx.void_type().fn_type(¶ms, false), + Some(ret_type) if !has_sret => ret_type.fn_type(¶ms, is_vararg), + _ => ctx.ctx.void_type().fn_type(¶ms, is_vararg), }; let fun_val = ctx.module.add_function(&symbol, fun_ty, None); let offset = if has_sret { @@ -926,13 +1028,16 @@ pub fn gen_call<'ctx, G: CodeGenerator>( }); // Convert boolean parameter values into i1 + let vararg_ty = vararg_arg.map(|vararg| ctx.get_llvm_abi_type(generator, vararg.ty)); let param_vals = fun_val .get_params() .iter() + .map(BasicValueEnum::get_type) + .chain(repeat_with(|| vararg_ty.unwrap())) .zip(param_vals) .map(|(p, v)| { - if p.is_int_value() && v.is_int_value() { - let expected_ty = p.into_int_value().get_type(); + if p.is_int_type() && v.is_int_value() { + let expected_ty = p.into_int_type(); let param_val = v.into_int_value(); if expected_ty.get_bit_width() == 1 && param_val.get_type().get_bit_width() != 1 { diff --git a/nac3core/src/codegen/mod.rs b/nac3core/src/codegen/mod.rs index 7bc8c98923..4f79ddf4d3 100644 --- a/nac3core/src/codegen/mod.rs +++ b/nac3core/src/codegen/mod.rs @@ -1049,3 +1049,9 @@ fn gen_in_range_check<'ctx>( ctx.builder.build_int_compare(IntPredicate::SLT, lo, hi, "cmp").unwrap() } + +/// Returns the internal name for the `va_count` argument, used to indicate the number of arguments +/// passed to the variadic function. +fn get_va_count_arg_name(arg_name: StrRef) -> StrRef { + format!("__{}_va_count", &arg_name).into() +}