forked from M-Labs/nac3
core/codegen/expr: Implement vararg handling in gen_call
This commit is contained in:
parent
faa3bb97ad
commit
f5fb504a15
|
@ -1,5 +1,3 @@
|
||||||
use std::{collections::HashMap, convert::TryInto, iter::once, iter::zip};
|
|
||||||
|
|
||||||
use crate::{
|
use crate::{
|
||||||
codegen::{
|
codegen::{
|
||||||
classes::{
|
classes::{
|
||||||
|
@ -7,7 +5,7 @@ use crate::{
|
||||||
ProxyValue, RangeValue, TypedArrayLikeAccessor, UntypedArrayLikeAccessor,
|
ProxyValue, RangeValue, TypedArrayLikeAccessor, UntypedArrayLikeAccessor,
|
||||||
},
|
},
|
||||||
concrete_type::{ConcreteFuncArg, ConcreteTypeEnum, ConcreteTypeStore},
|
concrete_type::{ConcreteFuncArg, ConcreteTypeEnum, ConcreteTypeStore},
|
||||||
gen_in_range_check, get_llvm_abi_type, get_llvm_type,
|
gen_in_range_check, get_llvm_abi_type, get_llvm_type, get_va_count_arg_name,
|
||||||
irrt::*,
|
irrt::*,
|
||||||
llvm_intrinsics::{
|
llvm_intrinsics::{
|
||||||
call_expect, call_float_floor, call_float_pow, call_float_powi, call_int_smax,
|
call_expect, call_float_floor, call_float_pow, call_float_powi, call_int_smax,
|
||||||
|
@ -42,6 +40,8 @@ use nac3parser::ast::{
|
||||||
self, Boolop, Cmpop, Comprehension, Constant, Expr, ExprKind, Location, Operator, StrRef,
|
self, Boolop, Cmpop, Comprehension, Constant, Expr, ExprKind, Location, Operator, StrRef,
|
||||||
Unaryop,
|
Unaryop,
|
||||||
};
|
};
|
||||||
|
use std::iter::{repeat, repeat_with};
|
||||||
|
use std::{collections::HashMap, convert::TryInto, iter::once, iter::zip};
|
||||||
|
|
||||||
pub fn get_subst_key(
|
pub fn get_subst_key(
|
||||||
unifier: &mut Unifier,
|
unifier: &mut Unifier,
|
||||||
|
@ -517,16 +517,19 @@ impl<'ctx, 'a> CodeGenContext<'ctx, 'a> {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
let params = if loc_params.is_empty() { params } else { &loc_params };
|
let params = if loc_params.is_empty() { params } else { &loc_params };
|
||||||
let params = fun
|
let params = fun
|
||||||
.get_type()
|
.get_type()
|
||||||
.get_param_types()
|
.get_param_types()
|
||||||
.into_iter()
|
.into_iter()
|
||||||
|
.map(Some)
|
||||||
|
.chain(repeat(None))
|
||||||
.zip(params.iter())
|
.zip(params.iter())
|
||||||
.map(|(ty, val)| match (ty, val.get_type()) {
|
.map(|(ty, val)| match (ty, val.get_type()) {
|
||||||
(BasicTypeEnum::PointerType(arg_ty), BasicTypeEnum::PointerType(val_ty))
|
(Some(BasicTypeEnum::PointerType(arg_ty)), BasicTypeEnum::PointerType(val_ty))
|
||||||
if {
|
if {
|
||||||
ty != val.get_type()
|
ty.unwrap() != val.get_type()
|
||||||
&& arg_ty.get_element_type().is_struct_type()
|
&& arg_ty.get_element_type().is_struct_type()
|
||||||
&& val_ty.get_element_type().is_struct_type()
|
&& val_ty.get_element_type().is_struct_type()
|
||||||
} =>
|
} =>
|
||||||
|
@ -536,6 +539,7 @@ impl<'ctx, 'a> CodeGenContext<'ctx, 'a> {
|
||||||
_ => *val,
|
_ => *val,
|
||||||
})
|
})
|
||||||
.collect_vec();
|
.collect_vec();
|
||||||
|
|
||||||
let result = if let Some(target) = self.unwind_target {
|
let result = if let Some(target) = self.unwind_target {
|
||||||
let current = self.builder.get_insert_block().unwrap().get_parent().unwrap();
|
let current = self.builder.get_insert_block().unwrap().get_parent().unwrap();
|
||||||
let then_block = self.ctx.append_basic_block(current, &format!("after.{call_name}"));
|
let then_block = self.ctx.append_basic_block(current, &format!("after.{call_name}"));
|
||||||
|
@ -555,6 +559,7 @@ impl<'ctx, 'a> CodeGenContext<'ctx, 'a> {
|
||||||
.map(Either::left)
|
.map(Either::left)
|
||||||
.unwrap()
|
.unwrap()
|
||||||
};
|
};
|
||||||
|
|
||||||
if let Some(slot) = return_slot {
|
if let Some(slot) = return_slot {
|
||||||
Some(self.builder.build_load(slot, call_name).unwrap())
|
Some(self.builder.build_load(slot, call_name).unwrap())
|
||||||
} else {
|
} else {
|
||||||
|
@ -729,10 +734,10 @@ pub fn gen_func_instance<'ctx>(
|
||||||
.collect();
|
.collect();
|
||||||
|
|
||||||
let mut signature = store.from_signature(&mut ctx.unifier, &ctx.primitives, sign, &mut cache);
|
let mut signature = store.from_signature(&mut ctx.unifier, &ctx.primitives, sign, &mut cache);
|
||||||
|
let ConcreteTypeEnum::TFunc { args, .. } = &mut signature else { unreachable!() };
|
||||||
|
|
||||||
if let Some(obj) = &obj {
|
if let Some(obj) = &obj {
|
||||||
let zelf = store.from_unifier_type(&mut ctx.unifier, &ctx.primitives, obj.0, &mut cache);
|
let zelf = store.from_unifier_type(&mut ctx.unifier, &ctx.primitives, obj.0, &mut cache);
|
||||||
let ConcreteTypeEnum::TFunc { args, .. } = &mut signature else { unreachable!() };
|
|
||||||
|
|
||||||
args.insert(
|
args.insert(
|
||||||
0,
|
0,
|
||||||
|
@ -744,6 +749,26 @@ pub fn gen_func_instance<'ctx>(
|
||||||
},
|
},
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if let Some(vararg_arg) = sign.args.iter().find(|arg| arg.is_vararg) {
|
||||||
|
let va_count_arg = get_va_count_arg_name(vararg_arg.name);
|
||||||
|
|
||||||
|
args.insert(
|
||||||
|
args.len() - 1,
|
||||||
|
ConcreteFuncArg {
|
||||||
|
name: va_count_arg,
|
||||||
|
ty: store.from_unifier_type(
|
||||||
|
&mut ctx.unifier,
|
||||||
|
&ctx.primitives,
|
||||||
|
ctx.primitives.usize(),
|
||||||
|
&mut cache,
|
||||||
|
),
|
||||||
|
default_value: None,
|
||||||
|
is_vararg: false,
|
||||||
|
},
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
let signature = store.add_cty(signature);
|
let signature = store.add_cty(signature);
|
||||||
|
|
||||||
ctx.registry.add_task(CodeGenTask {
|
ctx.registry.add_task(CodeGenTask {
|
||||||
|
@ -768,11 +793,17 @@ pub fn gen_call<'ctx, G: CodeGenerator>(
|
||||||
fun: (&FunSignature, DefinitionId),
|
fun: (&FunSignature, DefinitionId),
|
||||||
params: Vec<(Option<StrRef>, ValueEnum<'ctx>)>,
|
params: Vec<(Option<StrRef>, ValueEnum<'ctx>)>,
|
||||||
) -> Result<Option<BasicValueEnum<'ctx>>, String> {
|
) -> Result<Option<BasicValueEnum<'ctx>>, String> {
|
||||||
|
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||||
|
|
||||||
let definition = ctx.top_level.definitions.read().get(fun.1 .0).cloned().unwrap();
|
let definition = ctx.top_level.definitions.read().get(fun.1 .0).cloned().unwrap();
|
||||||
let id;
|
let id;
|
||||||
let key;
|
let key;
|
||||||
let param_vals;
|
let param_vals;
|
||||||
let is_extern;
|
let is_extern;
|
||||||
|
let vararg_arg;
|
||||||
|
|
||||||
|
// Ensure that the function object only contains up to 1 vararg parameter
|
||||||
|
debug_assert!(fun.0.args.iter().filter(|arg| arg.is_vararg).count() <= 1);
|
||||||
|
|
||||||
let symbol = {
|
let symbol = {
|
||||||
// make sure this lock guard is dropped at the end of this scope...
|
// make sure this lock guard is dropped at the end of this scope...
|
||||||
|
@ -788,22 +819,72 @@ pub fn gen_call<'ctx, G: CodeGenerator>(
|
||||||
return callback.run(ctx, obj, fun, params, generator);
|
return callback.run(ctx, obj, fun, params, generator);
|
||||||
}
|
}
|
||||||
is_extern = instance_to_stmt.is_empty();
|
is_extern = instance_to_stmt.is_empty();
|
||||||
|
vararg_arg = fun.0.args.iter().find(|arg| arg.is_vararg);
|
||||||
let old_key = ctx.get_subst_key(obj.as_ref().map(|a| a.0), fun.0, None);
|
let old_key = ctx.get_subst_key(obj.as_ref().map(|a| a.0), fun.0, None);
|
||||||
let mut keys = fun.0.args.clone();
|
let mut keys = fun.0.args.clone();
|
||||||
let mut mapping = HashMap::new();
|
let mut mapping = HashMap::<_, Vec<ValueEnum>>::new();
|
||||||
|
|
||||||
for (key, value) in params {
|
for (key, value) in params {
|
||||||
mapping.insert(key.unwrap_or_else(|| keys.remove(0).name), value);
|
// Find the matching argument
|
||||||
|
let matching_param = fun
|
||||||
|
.0
|
||||||
|
.args
|
||||||
|
.iter()
|
||||||
|
.find_or_last(|p| key.is_some_and(|k| k == p.name))
|
||||||
|
.unwrap();
|
||||||
|
if matching_param.is_vararg {
|
||||||
|
if key.is_none() && !keys.is_empty() {
|
||||||
|
keys.remove(0);
|
||||||
|
}
|
||||||
|
|
||||||
|
// vararg is lowered into two arguments - va_count and `...`
|
||||||
|
// Handle va_count first, for each argument encountered we increment it by 1
|
||||||
|
let va_count = get_va_count_arg_name(matching_param.name);
|
||||||
|
if let Some(params) = mapping.get_mut(&va_count) {
|
||||||
|
debug_assert_eq!(params.len(), 1);
|
||||||
|
|
||||||
|
let param = params[0]
|
||||||
|
.clone()
|
||||||
|
.to_basic_value_enum(ctx, generator, ctx.primitives.usize())?
|
||||||
|
.into_int_value();
|
||||||
|
params[0] = param.const_add(llvm_usize.const_int(1, false)).into();
|
||||||
|
} else {
|
||||||
|
mapping.insert(va_count, vec![llvm_usize.const_int(1, false).into()]);
|
||||||
|
}
|
||||||
|
|
||||||
|
if let Some(param) = mapping.get_mut(&matching_param.name) {
|
||||||
|
param.push(value);
|
||||||
|
} else {
|
||||||
|
mapping.insert(key.unwrap_or(matching_param.name), vec![value]);
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
mapping.insert(key.unwrap_or_else(|| keys.remove(0).name), vec![value]);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// default value handling
|
// default value handling
|
||||||
for k in keys {
|
for k in keys {
|
||||||
if mapping.contains_key(&k.name) {
|
if mapping.contains_key(&k.name) {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
mapping.insert(
|
|
||||||
k.name,
|
if k.is_vararg {
|
||||||
ctx.gen_symbol_val(generator, &k.default_value.unwrap(), k.ty).into(),
|
mapping.insert(
|
||||||
);
|
get_va_count_arg_name(k.name),
|
||||||
|
vec![llvm_usize.const_zero().into()],
|
||||||
|
);
|
||||||
|
|
||||||
|
mapping.insert(k.name, Vec::default());
|
||||||
|
} else {
|
||||||
|
mapping.insert(
|
||||||
|
k.name,
|
||||||
|
vec![ctx
|
||||||
|
.gen_symbol_val(generator, &k.default_value.unwrap(), k.ty)
|
||||||
|
.into()],
|
||||||
|
);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// reorder the parameters
|
// reorder the parameters
|
||||||
let mut real_params = fun
|
let mut real_params = fun
|
||||||
.0
|
.0
|
||||||
|
@ -812,13 +893,24 @@ pub fn gen_call<'ctx, G: CodeGenerator>(
|
||||||
.map(|arg| (mapping.remove(&arg.name).unwrap(), arg.ty))
|
.map(|arg| (mapping.remove(&arg.name).unwrap(), arg.ty))
|
||||||
.collect_vec();
|
.collect_vec();
|
||||||
if let Some(obj) = &obj {
|
if let Some(obj) = &obj {
|
||||||
real_params.insert(0, (obj.1.clone(), obj.0));
|
real_params.insert(0, (vec![obj.1.clone()], obj.0));
|
||||||
}
|
}
|
||||||
|
if let Some(vararg) = vararg_arg {
|
||||||
|
let vararg_arg_name = get_va_count_arg_name(vararg.name);
|
||||||
|
|
||||||
|
real_params.insert(
|
||||||
|
real_params.len() - 1,
|
||||||
|
(mapping[&vararg_arg_name].clone(), ctx.primitives.usize()),
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
let static_params = real_params
|
let static_params = real_params
|
||||||
.iter()
|
.iter()
|
||||||
.enumerate()
|
.enumerate()
|
||||||
.filter_map(|(i, (v, _))| {
|
.filter_map(|(i, (v, _))| {
|
||||||
if let ValueEnum::Static(s) = v {
|
if v.len() != 1 {
|
||||||
|
None
|
||||||
|
} else if let ValueEnum::Static(s) = &v[0] {
|
||||||
Some((i, s.clone()))
|
Some((i, s.clone()))
|
||||||
} else {
|
} else {
|
||||||
None
|
None
|
||||||
|
@ -848,8 +940,13 @@ pub fn gen_call<'ctx, G: CodeGenerator>(
|
||||||
};
|
};
|
||||||
param_vals = real_params
|
param_vals = real_params
|
||||||
.into_iter()
|
.into_iter()
|
||||||
.map(|(p, t)| p.to_basic_value_enum(ctx, generator, t))
|
.map(|(ps, t)| {
|
||||||
.collect::<Result<Vec<_>, String>>()?;
|
ps.into_iter().map(|p| p.to_basic_value_enum(ctx, generator, t)).collect()
|
||||||
|
})
|
||||||
|
.collect::<Result<Vec<Vec<_>>, _>>()?
|
||||||
|
.into_iter()
|
||||||
|
.flatten()
|
||||||
|
.collect::<Vec<_>>();
|
||||||
instance_to_symbol.get(&key).cloned().ok_or_else(String::new)
|
instance_to_symbol.get(&key).cloned().ok_or_else(String::new)
|
||||||
}
|
}
|
||||||
TopLevelDef::Class { .. } => {
|
TopLevelDef::Class { .. } => {
|
||||||
|
@ -878,6 +975,7 @@ pub fn gen_call<'ctx, G: CodeGenerator>(
|
||||||
let mut params = args
|
let mut params = args
|
||||||
.iter()
|
.iter()
|
||||||
.enumerate()
|
.enumerate()
|
||||||
|
.filter(|(_, arg)| !arg.is_vararg)
|
||||||
.map(|(i, arg)| {
|
.map(|(i, arg)| {
|
||||||
match ctx.get_llvm_abi_type(generator, arg.ty) {
|
match ctx.get_llvm_abi_type(generator, arg.ty) {
|
||||||
BasicTypeEnum::StructType(ty) if is_extern => {
|
BasicTypeEnum::StructType(ty) if is_extern => {
|
||||||
|
@ -892,9 +990,13 @@ pub fn gen_call<'ctx, G: CodeGenerator>(
|
||||||
if has_sret {
|
if has_sret {
|
||||||
params.insert(0, ret_type.unwrap().ptr_type(AddressSpace::default()).into());
|
params.insert(0, ret_type.unwrap().ptr_type(AddressSpace::default()).into());
|
||||||
}
|
}
|
||||||
|
let is_vararg = args.iter().any(|arg| arg.is_vararg);
|
||||||
|
if is_vararg {
|
||||||
|
params.push(generator.get_size_type(ctx.ctx).into());
|
||||||
|
}
|
||||||
let fun_ty = match ret_type {
|
let fun_ty = match ret_type {
|
||||||
Some(ret_type) if !has_sret => ret_type.fn_type(¶ms, false),
|
Some(ret_type) if !has_sret => ret_type.fn_type(¶ms, is_vararg),
|
||||||
_ => ctx.ctx.void_type().fn_type(¶ms, false),
|
_ => ctx.ctx.void_type().fn_type(¶ms, is_vararg),
|
||||||
};
|
};
|
||||||
let fun_val = ctx.module.add_function(&symbol, fun_ty, None);
|
let fun_val = ctx.module.add_function(&symbol, fun_ty, None);
|
||||||
let offset = if has_sret {
|
let offset = if has_sret {
|
||||||
|
@ -926,13 +1028,16 @@ pub fn gen_call<'ctx, G: CodeGenerator>(
|
||||||
});
|
});
|
||||||
|
|
||||||
// Convert boolean parameter values into i1
|
// Convert boolean parameter values into i1
|
||||||
|
let vararg_ty = vararg_arg.map(|vararg| ctx.get_llvm_abi_type(generator, vararg.ty));
|
||||||
let param_vals = fun_val
|
let param_vals = fun_val
|
||||||
.get_params()
|
.get_params()
|
||||||
.iter()
|
.iter()
|
||||||
|
.map(BasicValueEnum::get_type)
|
||||||
|
.chain(repeat_with(|| vararg_ty.unwrap()))
|
||||||
.zip(param_vals)
|
.zip(param_vals)
|
||||||
.map(|(p, v)| {
|
.map(|(p, v)| {
|
||||||
if p.is_int_value() && v.is_int_value() {
|
if p.is_int_type() && v.is_int_value() {
|
||||||
let expected_ty = p.into_int_value().get_type();
|
let expected_ty = p.into_int_type();
|
||||||
let param_val = v.into_int_value();
|
let param_val = v.into_int_value();
|
||||||
|
|
||||||
if expected_ty.get_bit_width() == 1 && param_val.get_type().get_bit_width() != 1 {
|
if expected_ty.get_bit_width() == 1 && param_val.get_type().get_bit_width() != 1 {
|
||||||
|
|
|
@ -1049,3 +1049,9 @@ fn gen_in_range_check<'ctx>(
|
||||||
|
|
||||||
ctx.builder.build_int_compare(IntPredicate::SLT, lo, hi, "cmp").unwrap()
|
ctx.builder.build_int_compare(IntPredicate::SLT, lo, hi, "cmp").unwrap()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Returns the internal name for the `va_count` argument, used to indicate the number of arguments
|
||||||
|
/// passed to the variadic function.
|
||||||
|
fn get_va_count_arg_name(arg_name: StrRef) -> StrRef {
|
||||||
|
format!("__{}_va_count", &arg_name).into()
|
||||||
|
}
|
||||||
|
|
Loading…
Reference in New Issue