forked from M-Labs/nac3
core/codegen/expr: Implement vararg handling in gen_call
This commit is contained in:
parent
faa3bb97ad
commit
f5fb504a15
|
@ -1,5 +1,3 @@
|
|||
use std::{collections::HashMap, convert::TryInto, iter::once, iter::zip};
|
||||
|
||||
use crate::{
|
||||
codegen::{
|
||||
classes::{
|
||||
|
@ -7,7 +5,7 @@ use crate::{
|
|||
ProxyValue, RangeValue, TypedArrayLikeAccessor, UntypedArrayLikeAccessor,
|
||||
},
|
||||
concrete_type::{ConcreteFuncArg, ConcreteTypeEnum, ConcreteTypeStore},
|
||||
gen_in_range_check, get_llvm_abi_type, get_llvm_type,
|
||||
gen_in_range_check, get_llvm_abi_type, get_llvm_type, get_va_count_arg_name,
|
||||
irrt::*,
|
||||
llvm_intrinsics::{
|
||||
call_expect, call_float_floor, call_float_pow, call_float_powi, call_int_smax,
|
||||
|
@ -42,6 +40,8 @@ use nac3parser::ast::{
|
|||
self, Boolop, Cmpop, Comprehension, Constant, Expr, ExprKind, Location, Operator, StrRef,
|
||||
Unaryop,
|
||||
};
|
||||
use std::iter::{repeat, repeat_with};
|
||||
use std::{collections::HashMap, convert::TryInto, iter::once, iter::zip};
|
||||
|
||||
pub fn get_subst_key(
|
||||
unifier: &mut Unifier,
|
||||
|
@ -517,16 +517,19 @@ impl<'ctx, 'a> CodeGenContext<'ctx, 'a> {
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
let params = if loc_params.is_empty() { params } else { &loc_params };
|
||||
let params = fun
|
||||
.get_type()
|
||||
.get_param_types()
|
||||
.into_iter()
|
||||
.map(Some)
|
||||
.chain(repeat(None))
|
||||
.zip(params.iter())
|
||||
.map(|(ty, val)| match (ty, val.get_type()) {
|
||||
(BasicTypeEnum::PointerType(arg_ty), BasicTypeEnum::PointerType(val_ty))
|
||||
(Some(BasicTypeEnum::PointerType(arg_ty)), BasicTypeEnum::PointerType(val_ty))
|
||||
if {
|
||||
ty != val.get_type()
|
||||
ty.unwrap() != val.get_type()
|
||||
&& arg_ty.get_element_type().is_struct_type()
|
||||
&& val_ty.get_element_type().is_struct_type()
|
||||
} =>
|
||||
|
@ -536,6 +539,7 @@ impl<'ctx, 'a> CodeGenContext<'ctx, 'a> {
|
|||
_ => *val,
|
||||
})
|
||||
.collect_vec();
|
||||
|
||||
let result = if let Some(target) = self.unwind_target {
|
||||
let current = self.builder.get_insert_block().unwrap().get_parent().unwrap();
|
||||
let then_block = self.ctx.append_basic_block(current, &format!("after.{call_name}"));
|
||||
|
@ -555,6 +559,7 @@ impl<'ctx, 'a> CodeGenContext<'ctx, 'a> {
|
|||
.map(Either::left)
|
||||
.unwrap()
|
||||
};
|
||||
|
||||
if let Some(slot) = return_slot {
|
||||
Some(self.builder.build_load(slot, call_name).unwrap())
|
||||
} else {
|
||||
|
@ -729,10 +734,10 @@ pub fn gen_func_instance<'ctx>(
|
|||
.collect();
|
||||
|
||||
let mut signature = store.from_signature(&mut ctx.unifier, &ctx.primitives, sign, &mut cache);
|
||||
let ConcreteTypeEnum::TFunc { args, .. } = &mut signature else { unreachable!() };
|
||||
|
||||
if let Some(obj) = &obj {
|
||||
let zelf = store.from_unifier_type(&mut ctx.unifier, &ctx.primitives, obj.0, &mut cache);
|
||||
let ConcreteTypeEnum::TFunc { args, .. } = &mut signature else { unreachable!() };
|
||||
|
||||
args.insert(
|
||||
0,
|
||||
|
@ -744,6 +749,26 @@ pub fn gen_func_instance<'ctx>(
|
|||
},
|
||||
);
|
||||
}
|
||||
|
||||
if let Some(vararg_arg) = sign.args.iter().find(|arg| arg.is_vararg) {
|
||||
let va_count_arg = get_va_count_arg_name(vararg_arg.name);
|
||||
|
||||
args.insert(
|
||||
args.len() - 1,
|
||||
ConcreteFuncArg {
|
||||
name: va_count_arg,
|
||||
ty: store.from_unifier_type(
|
||||
&mut ctx.unifier,
|
||||
&ctx.primitives,
|
||||
ctx.primitives.usize(),
|
||||
&mut cache,
|
||||
),
|
||||
default_value: None,
|
||||
is_vararg: false,
|
||||
},
|
||||
);
|
||||
}
|
||||
|
||||
let signature = store.add_cty(signature);
|
||||
|
||||
ctx.registry.add_task(CodeGenTask {
|
||||
|
@ -768,11 +793,17 @@ pub fn gen_call<'ctx, G: CodeGenerator>(
|
|||
fun: (&FunSignature, DefinitionId),
|
||||
params: Vec<(Option<StrRef>, ValueEnum<'ctx>)>,
|
||||
) -> Result<Option<BasicValueEnum<'ctx>>, String> {
|
||||
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||
|
||||
let definition = ctx.top_level.definitions.read().get(fun.1 .0).cloned().unwrap();
|
||||
let id;
|
||||
let key;
|
||||
let param_vals;
|
||||
let is_extern;
|
||||
let vararg_arg;
|
||||
|
||||
// Ensure that the function object only contains up to 1 vararg parameter
|
||||
debug_assert!(fun.0.args.iter().filter(|arg| arg.is_vararg).count() <= 1);
|
||||
|
||||
let symbol = {
|
||||
// make sure this lock guard is dropped at the end of this scope...
|
||||
|
@ -788,22 +819,72 @@ pub fn gen_call<'ctx, G: CodeGenerator>(
|
|||
return callback.run(ctx, obj, fun, params, generator);
|
||||
}
|
||||
is_extern = instance_to_stmt.is_empty();
|
||||
vararg_arg = fun.0.args.iter().find(|arg| arg.is_vararg);
|
||||
let old_key = ctx.get_subst_key(obj.as_ref().map(|a| a.0), fun.0, None);
|
||||
let mut keys = fun.0.args.clone();
|
||||
let mut mapping = HashMap::new();
|
||||
let mut mapping = HashMap::<_, Vec<ValueEnum>>::new();
|
||||
|
||||
for (key, value) in params {
|
||||
mapping.insert(key.unwrap_or_else(|| keys.remove(0).name), value);
|
||||
// Find the matching argument
|
||||
let matching_param = fun
|
||||
.0
|
||||
.args
|
||||
.iter()
|
||||
.find_or_last(|p| key.is_some_and(|k| k == p.name))
|
||||
.unwrap();
|
||||
if matching_param.is_vararg {
|
||||
if key.is_none() && !keys.is_empty() {
|
||||
keys.remove(0);
|
||||
}
|
||||
|
||||
// vararg is lowered into two arguments - va_count and `...`
|
||||
// Handle va_count first, for each argument encountered we increment it by 1
|
||||
let va_count = get_va_count_arg_name(matching_param.name);
|
||||
if let Some(params) = mapping.get_mut(&va_count) {
|
||||
debug_assert_eq!(params.len(), 1);
|
||||
|
||||
let param = params[0]
|
||||
.clone()
|
||||
.to_basic_value_enum(ctx, generator, ctx.primitives.usize())?
|
||||
.into_int_value();
|
||||
params[0] = param.const_add(llvm_usize.const_int(1, false)).into();
|
||||
} else {
|
||||
mapping.insert(va_count, vec![llvm_usize.const_int(1, false).into()]);
|
||||
}
|
||||
|
||||
if let Some(param) = mapping.get_mut(&matching_param.name) {
|
||||
param.push(value);
|
||||
} else {
|
||||
mapping.insert(key.unwrap_or(matching_param.name), vec![value]);
|
||||
}
|
||||
} else {
|
||||
mapping.insert(key.unwrap_or_else(|| keys.remove(0).name), vec![value]);
|
||||
}
|
||||
}
|
||||
|
||||
// default value handling
|
||||
for k in keys {
|
||||
if mapping.contains_key(&k.name) {
|
||||
continue;
|
||||
}
|
||||
|
||||
if k.is_vararg {
|
||||
mapping.insert(
|
||||
get_va_count_arg_name(k.name),
|
||||
vec![llvm_usize.const_zero().into()],
|
||||
);
|
||||
|
||||
mapping.insert(k.name, Vec::default());
|
||||
} else {
|
||||
mapping.insert(
|
||||
k.name,
|
||||
ctx.gen_symbol_val(generator, &k.default_value.unwrap(), k.ty).into(),
|
||||
vec![ctx
|
||||
.gen_symbol_val(generator, &k.default_value.unwrap(), k.ty)
|
||||
.into()],
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
// reorder the parameters
|
||||
let mut real_params = fun
|
||||
.0
|
||||
|
@ -812,13 +893,24 @@ pub fn gen_call<'ctx, G: CodeGenerator>(
|
|||
.map(|arg| (mapping.remove(&arg.name).unwrap(), arg.ty))
|
||||
.collect_vec();
|
||||
if let Some(obj) = &obj {
|
||||
real_params.insert(0, (obj.1.clone(), obj.0));
|
||||
real_params.insert(0, (vec![obj.1.clone()], obj.0));
|
||||
}
|
||||
if let Some(vararg) = vararg_arg {
|
||||
let vararg_arg_name = get_va_count_arg_name(vararg.name);
|
||||
|
||||
real_params.insert(
|
||||
real_params.len() - 1,
|
||||
(mapping[&vararg_arg_name].clone(), ctx.primitives.usize()),
|
||||
);
|
||||
}
|
||||
|
||||
let static_params = real_params
|
||||
.iter()
|
||||
.enumerate()
|
||||
.filter_map(|(i, (v, _))| {
|
||||
if let ValueEnum::Static(s) = v {
|
||||
if v.len() != 1 {
|
||||
None
|
||||
} else if let ValueEnum::Static(s) = &v[0] {
|
||||
Some((i, s.clone()))
|
||||
} else {
|
||||
None
|
||||
|
@ -848,8 +940,13 @@ pub fn gen_call<'ctx, G: CodeGenerator>(
|
|||
};
|
||||
param_vals = real_params
|
||||
.into_iter()
|
||||
.map(|(p, t)| p.to_basic_value_enum(ctx, generator, t))
|
||||
.collect::<Result<Vec<_>, String>>()?;
|
||||
.map(|(ps, t)| {
|
||||
ps.into_iter().map(|p| p.to_basic_value_enum(ctx, generator, t)).collect()
|
||||
})
|
||||
.collect::<Result<Vec<Vec<_>>, _>>()?
|
||||
.into_iter()
|
||||
.flatten()
|
||||
.collect::<Vec<_>>();
|
||||
instance_to_symbol.get(&key).cloned().ok_or_else(String::new)
|
||||
}
|
||||
TopLevelDef::Class { .. } => {
|
||||
|
@ -878,6 +975,7 @@ pub fn gen_call<'ctx, G: CodeGenerator>(
|
|||
let mut params = args
|
||||
.iter()
|
||||
.enumerate()
|
||||
.filter(|(_, arg)| !arg.is_vararg)
|
||||
.map(|(i, arg)| {
|
||||
match ctx.get_llvm_abi_type(generator, arg.ty) {
|
||||
BasicTypeEnum::StructType(ty) if is_extern => {
|
||||
|
@ -892,9 +990,13 @@ pub fn gen_call<'ctx, G: CodeGenerator>(
|
|||
if has_sret {
|
||||
params.insert(0, ret_type.unwrap().ptr_type(AddressSpace::default()).into());
|
||||
}
|
||||
let is_vararg = args.iter().any(|arg| arg.is_vararg);
|
||||
if is_vararg {
|
||||
params.push(generator.get_size_type(ctx.ctx).into());
|
||||
}
|
||||
let fun_ty = match ret_type {
|
||||
Some(ret_type) if !has_sret => ret_type.fn_type(¶ms, false),
|
||||
_ => ctx.ctx.void_type().fn_type(¶ms, false),
|
||||
Some(ret_type) if !has_sret => ret_type.fn_type(¶ms, is_vararg),
|
||||
_ => ctx.ctx.void_type().fn_type(¶ms, is_vararg),
|
||||
};
|
||||
let fun_val = ctx.module.add_function(&symbol, fun_ty, None);
|
||||
let offset = if has_sret {
|
||||
|
@ -926,13 +1028,16 @@ pub fn gen_call<'ctx, G: CodeGenerator>(
|
|||
});
|
||||
|
||||
// Convert boolean parameter values into i1
|
||||
let vararg_ty = vararg_arg.map(|vararg| ctx.get_llvm_abi_type(generator, vararg.ty));
|
||||
let param_vals = fun_val
|
||||
.get_params()
|
||||
.iter()
|
||||
.map(BasicValueEnum::get_type)
|
||||
.chain(repeat_with(|| vararg_ty.unwrap()))
|
||||
.zip(param_vals)
|
||||
.map(|(p, v)| {
|
||||
if p.is_int_value() && v.is_int_value() {
|
||||
let expected_ty = p.into_int_value().get_type();
|
||||
if p.is_int_type() && v.is_int_value() {
|
||||
let expected_ty = p.into_int_type();
|
||||
let param_val = v.into_int_value();
|
||||
|
||||
if expected_ty.get_bit_width() == 1 && param_val.get_type().get_bit_width() != 1 {
|
||||
|
|
|
@ -1049,3 +1049,9 @@ fn gen_in_range_check<'ctx>(
|
|||
|
||||
ctx.builder.build_int_compare(IntPredicate::SLT, lo, hi, "cmp").unwrap()
|
||||
}
|
||||
|
||||
/// Returns the internal name for the `va_count` argument, used to indicate the number of arguments
|
||||
/// passed to the variadic function.
|
||||
fn get_va_count_arg_name(arg_name: StrRef) -> StrRef {
|
||||
format!("__{}_va_count", &arg_name).into()
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue