forked from M-Labs/nac3
1
0
Fork 0

core/codegen/expr: Implement vararg handling in gen_call

This commit is contained in:
David Mak 2024-07-11 18:01:48 +08:00
parent faa3bb97ad
commit f5fb504a15
2 changed files with 131 additions and 20 deletions

View File

@ -1,5 +1,3 @@
use std::{collections::HashMap, convert::TryInto, iter::once, iter::zip};
use crate::{
codegen::{
classes::{
@ -7,7 +5,7 @@ use crate::{
ProxyValue, RangeValue, TypedArrayLikeAccessor, UntypedArrayLikeAccessor,
},
concrete_type::{ConcreteFuncArg, ConcreteTypeEnum, ConcreteTypeStore},
gen_in_range_check, get_llvm_abi_type, get_llvm_type,
gen_in_range_check, get_llvm_abi_type, get_llvm_type, get_va_count_arg_name,
irrt::*,
llvm_intrinsics::{
call_expect, call_float_floor, call_float_pow, call_float_powi, call_int_smax,
@ -42,6 +40,8 @@ use nac3parser::ast::{
self, Boolop, Cmpop, Comprehension, Constant, Expr, ExprKind, Location, Operator, StrRef,
Unaryop,
};
use std::iter::{repeat, repeat_with};
use std::{collections::HashMap, convert::TryInto, iter::once, iter::zip};
pub fn get_subst_key(
unifier: &mut Unifier,
@ -517,16 +517,19 @@ impl<'ctx, 'a> CodeGenContext<'ctx, 'a> {
}
}
}
let params = if loc_params.is_empty() { params } else { &loc_params };
let params = fun
.get_type()
.get_param_types()
.into_iter()
.map(Some)
.chain(repeat(None))
.zip(params.iter())
.map(|(ty, val)| match (ty, val.get_type()) {
(BasicTypeEnum::PointerType(arg_ty), BasicTypeEnum::PointerType(val_ty))
(Some(BasicTypeEnum::PointerType(arg_ty)), BasicTypeEnum::PointerType(val_ty))
if {
ty != val.get_type()
ty.unwrap() != val.get_type()
&& arg_ty.get_element_type().is_struct_type()
&& val_ty.get_element_type().is_struct_type()
} =>
@ -536,6 +539,7 @@ impl<'ctx, 'a> CodeGenContext<'ctx, 'a> {
_ => *val,
})
.collect_vec();
let result = if let Some(target) = self.unwind_target {
let current = self.builder.get_insert_block().unwrap().get_parent().unwrap();
let then_block = self.ctx.append_basic_block(current, &format!("after.{call_name}"));
@ -555,6 +559,7 @@ impl<'ctx, 'a> CodeGenContext<'ctx, 'a> {
.map(Either::left)
.unwrap()
};
if let Some(slot) = return_slot {
Some(self.builder.build_load(slot, call_name).unwrap())
} else {
@ -729,10 +734,10 @@ pub fn gen_func_instance<'ctx>(
.collect();
let mut signature = store.from_signature(&mut ctx.unifier, &ctx.primitives, sign, &mut cache);
let ConcreteTypeEnum::TFunc { args, .. } = &mut signature else { unreachable!() };
if let Some(obj) = &obj {
let zelf = store.from_unifier_type(&mut ctx.unifier, &ctx.primitives, obj.0, &mut cache);
let ConcreteTypeEnum::TFunc { args, .. } = &mut signature else { unreachable!() };
args.insert(
0,
@ -744,6 +749,26 @@ pub fn gen_func_instance<'ctx>(
},
);
}
if let Some(vararg_arg) = sign.args.iter().find(|arg| arg.is_vararg) {
let va_count_arg = get_va_count_arg_name(vararg_arg.name);
args.insert(
args.len() - 1,
ConcreteFuncArg {
name: va_count_arg,
ty: store.from_unifier_type(
&mut ctx.unifier,
&ctx.primitives,
ctx.primitives.usize(),
&mut cache,
),
default_value: None,
is_vararg: false,
},
);
}
let signature = store.add_cty(signature);
ctx.registry.add_task(CodeGenTask {
@ -768,11 +793,17 @@ pub fn gen_call<'ctx, G: CodeGenerator>(
fun: (&FunSignature, DefinitionId),
params: Vec<(Option<StrRef>, ValueEnum<'ctx>)>,
) -> Result<Option<BasicValueEnum<'ctx>>, String> {
let llvm_usize = generator.get_size_type(ctx.ctx);
let definition = ctx.top_level.definitions.read().get(fun.1 .0).cloned().unwrap();
let id;
let key;
let param_vals;
let is_extern;
let vararg_arg;
// Ensure that the function object only contains up to 1 vararg parameter
debug_assert!(fun.0.args.iter().filter(|arg| arg.is_vararg).count() <= 1);
let symbol = {
// make sure this lock guard is dropped at the end of this scope...
@ -788,22 +819,72 @@ pub fn gen_call<'ctx, G: CodeGenerator>(
return callback.run(ctx, obj, fun, params, generator);
}
is_extern = instance_to_stmt.is_empty();
vararg_arg = fun.0.args.iter().find(|arg| arg.is_vararg);
let old_key = ctx.get_subst_key(obj.as_ref().map(|a| a.0), fun.0, None);
let mut keys = fun.0.args.clone();
let mut mapping = HashMap::new();
let mut mapping = HashMap::<_, Vec<ValueEnum>>::new();
for (key, value) in params {
mapping.insert(key.unwrap_or_else(|| keys.remove(0).name), value);
// Find the matching argument
let matching_param = fun
.0
.args
.iter()
.find_or_last(|p| key.is_some_and(|k| k == p.name))
.unwrap();
if matching_param.is_vararg {
if key.is_none() && !keys.is_empty() {
keys.remove(0);
}
// vararg is lowered into two arguments - va_count and `...`
// Handle va_count first, for each argument encountered we increment it by 1
let va_count = get_va_count_arg_name(matching_param.name);
if let Some(params) = mapping.get_mut(&va_count) {
debug_assert_eq!(params.len(), 1);
let param = params[0]
.clone()
.to_basic_value_enum(ctx, generator, ctx.primitives.usize())?
.into_int_value();
params[0] = param.const_add(llvm_usize.const_int(1, false)).into();
} else {
mapping.insert(va_count, vec![llvm_usize.const_int(1, false).into()]);
}
if let Some(param) = mapping.get_mut(&matching_param.name) {
param.push(value);
} else {
mapping.insert(key.unwrap_or(matching_param.name), vec![value]);
}
} else {
mapping.insert(key.unwrap_or_else(|| keys.remove(0).name), vec![value]);
}
}
// default value handling
for k in keys {
if mapping.contains_key(&k.name) {
continue;
}
if k.is_vararg {
mapping.insert(
get_va_count_arg_name(k.name),
vec![llvm_usize.const_zero().into()],
);
mapping.insert(k.name, Vec::default());
} else {
mapping.insert(
k.name,
ctx.gen_symbol_val(generator, &k.default_value.unwrap(), k.ty).into(),
vec![ctx
.gen_symbol_val(generator, &k.default_value.unwrap(), k.ty)
.into()],
);
}
}
// reorder the parameters
let mut real_params = fun
.0
@ -812,13 +893,24 @@ pub fn gen_call<'ctx, G: CodeGenerator>(
.map(|arg| (mapping.remove(&arg.name).unwrap(), arg.ty))
.collect_vec();
if let Some(obj) = &obj {
real_params.insert(0, (obj.1.clone(), obj.0));
real_params.insert(0, (vec![obj.1.clone()], obj.0));
}
if let Some(vararg) = vararg_arg {
let vararg_arg_name = get_va_count_arg_name(vararg.name);
real_params.insert(
real_params.len() - 1,
(mapping[&vararg_arg_name].clone(), ctx.primitives.usize()),
);
}
let static_params = real_params
.iter()
.enumerate()
.filter_map(|(i, (v, _))| {
if let ValueEnum::Static(s) = v {
if v.len() != 1 {
None
} else if let ValueEnum::Static(s) = &v[0] {
Some((i, s.clone()))
} else {
None
@ -848,8 +940,13 @@ pub fn gen_call<'ctx, G: CodeGenerator>(
};
param_vals = real_params
.into_iter()
.map(|(p, t)| p.to_basic_value_enum(ctx, generator, t))
.collect::<Result<Vec<_>, String>>()?;
.map(|(ps, t)| {
ps.into_iter().map(|p| p.to_basic_value_enum(ctx, generator, t)).collect()
})
.collect::<Result<Vec<Vec<_>>, _>>()?
.into_iter()
.flatten()
.collect::<Vec<_>>();
instance_to_symbol.get(&key).cloned().ok_or_else(String::new)
}
TopLevelDef::Class { .. } => {
@ -878,6 +975,7 @@ pub fn gen_call<'ctx, G: CodeGenerator>(
let mut params = args
.iter()
.enumerate()
.filter(|(_, arg)| !arg.is_vararg)
.map(|(i, arg)| {
match ctx.get_llvm_abi_type(generator, arg.ty) {
BasicTypeEnum::StructType(ty) if is_extern => {
@ -892,9 +990,13 @@ pub fn gen_call<'ctx, G: CodeGenerator>(
if has_sret {
params.insert(0, ret_type.unwrap().ptr_type(AddressSpace::default()).into());
}
let is_vararg = args.iter().any(|arg| arg.is_vararg);
if is_vararg {
params.push(generator.get_size_type(ctx.ctx).into());
}
let fun_ty = match ret_type {
Some(ret_type) if !has_sret => ret_type.fn_type(&params, false),
_ => ctx.ctx.void_type().fn_type(&params, false),
Some(ret_type) if !has_sret => ret_type.fn_type(&params, is_vararg),
_ => ctx.ctx.void_type().fn_type(&params, is_vararg),
};
let fun_val = ctx.module.add_function(&symbol, fun_ty, None);
let offset = if has_sret {
@ -926,13 +1028,16 @@ pub fn gen_call<'ctx, G: CodeGenerator>(
});
// Convert boolean parameter values into i1
let vararg_ty = vararg_arg.map(|vararg| ctx.get_llvm_abi_type(generator, vararg.ty));
let param_vals = fun_val
.get_params()
.iter()
.map(BasicValueEnum::get_type)
.chain(repeat_with(|| vararg_ty.unwrap()))
.zip(param_vals)
.map(|(p, v)| {
if p.is_int_value() && v.is_int_value() {
let expected_ty = p.into_int_value().get_type();
if p.is_int_type() && v.is_int_value() {
let expected_ty = p.into_int_type();
let param_val = v.into_int_value();
if expected_ty.get_bit_width() == 1 && param_val.get_type().get_bit_width() != 1 {

View File

@ -1049,3 +1049,9 @@ fn gen_in_range_check<'ctx>(
ctx.builder.build_int_compare(IntPredicate::SLT, lo, hi, "cmp").unwrap()
}
/// Returns the internal name for the `va_count` argument, used to indicate the number of arguments
/// passed to the variadic function.
fn get_va_count_arg_name(arg_name: StrRef) -> StrRef {
format!("__{}_va_count", &arg_name).into()
}