forked from M-Labs/nac3
1
0
Fork 0

nac3core: impl call attributes

sret for returning large structs, and byval for struct args in extern
function calls.
This commit is contained in:
pca006132 2022-03-09 22:09:36 +08:00
parent e266d3c2b0
commit 2f85bb3837
3 changed files with 131 additions and 28 deletions

View File

@ -13,16 +13,17 @@ use crate::{
typecheck::typedef::{FunSignature, FuncArg, Type, TypeEnum, Unifier},
};
use inkwell::{
types::{BasicType, BasicTypeEnum},
values::{BasicValueEnum, FunctionValue, IntValue, PointerValue},
AddressSpace,
attributes::{Attribute, AttributeLoc},
types::{AnyType, BasicType, BasicTypeEnum},
values::{BasicValueEnum, FunctionValue, IntValue, PointerValue}
};
use itertools::{chain, izip, zip, Itertools};
use nac3parser::ast::{
self, Boolop, Comprehension, Constant, Expr, ExprKind, Location, Operator, StrRef,
};
use super::CodeGenerator;
use super::{CodeGenerator, need_sret};
pub fn get_subst_key(
unifier: &mut Unifier,
@ -299,7 +300,40 @@ impl<'ctx, 'a> CodeGenContext<'ctx, 'a> {
params: &[BasicValueEnum<'ctx>],
call_name: &str,
) -> Option<BasicValueEnum<'ctx>> {
if let Some(target) = self.unwind_target {
let mut loc_params: Vec<BasicValueEnum<'ctx>> = Vec::new();
let mut return_slot = None;
if fun.count_params() > 0 {
let sret_id = Attribute::get_named_enum_kind_id("sret");
let byval_id = Attribute::get_named_enum_kind_id("byval");
let offset = if fun.get_enum_attribute(AttributeLoc::Param(0), sret_id).is_some() {
return_slot = Some(self.builder.build_alloca(fun.get_type().get_param_types()[0]
.into_pointer_type().get_element_type().into_struct_type(), call_name));
loc_params.push((*return_slot.as_ref().unwrap()).into());
1
} else {
0
};
for (i, param) in params.iter().enumerate() {
if fun.get_enum_attribute(AttributeLoc::Param((i + offset) as u32), byval_id).is_some() {
// lazy update
if loc_params.is_empty() {
loc_params.extend(params[0..i+offset].iter().copied());
}
let slot = self.builder.build_alloca(param.get_type(), call_name);
loc_params.push(slot.into());
self.builder.build_store(slot, *param);
} else if !loc_params.is_empty() {
loc_params.push(*param);
}
}
}
let params = if loc_params.is_empty() {
params
} else {
&loc_params
};
let result = if let Some(target) = self.unwind_target {
let current = self.builder.get_insert_block().unwrap().get_parent().unwrap();
let then_block = self.ctx.append_basic_block(current, &format!("after.{}", call_name));
let result = self
@ -312,6 +346,11 @@ impl<'ctx, 'a> CodeGenContext<'ctx, 'a> {
} else {
let param: Vec<_> = params.iter().map(|v| (*v).into()).collect();
self.builder.build_call(fun, &param, call_name).try_as_basic_value().left()
};
if let Some(slot) = return_slot {
Some(self.builder.build_load(slot, call_name))
} else {
result
}
}
@ -519,6 +558,7 @@ pub fn gen_call<'ctx, 'a, G: CodeGenerator>(
let id;
let key;
let param_vals;
let is_extern;
let symbol = {
// make sure this lock guard is dropped at the end of this scope...
let def = definition.read();
@ -532,6 +572,7 @@ pub fn gen_call<'ctx, 'a, G: CodeGenerator>(
if let Some(callback) = codegen_callback {
return callback.run(ctx, obj, fun, params, generator);
}
is_extern = instance_to_stmt.is_empty();
let old_key = ctx.get_subst_key(obj.as_ref().map(|a| a.0), fun.0, None);
let mut keys = fun.0.args.clone();
let mut mapping = HashMap::new();
@ -606,14 +647,41 @@ pub fn gen_call<'ctx, 'a, G: CodeGenerator>(
if let Some(obj) = &obj {
args.insert(0, FuncArg { name: "self".into(), ty: obj.0, default_value: None });
}
let params =
args.iter().map(|arg| ctx.get_llvm_type(generator, arg.ty).into()).collect_vec();
let fun_ty = if ctx.unifier.unioned(fun.0.ret, ctx.primitives.none) {
ctx.ctx.void_type().fn_type(&params, false)
let ret_type = if ctx.unifier.unioned(fun.0.ret, ctx.primitives.none) {
None
} else {
ctx.get_llvm_type(generator, fun.0.ret).fn_type(&params, false)
Some(ctx.get_llvm_type(generator, fun.0.ret))
};
ctx.module.add_function(&symbol, fun_ty, None)
let has_sret = ret_type.map_or(false, |ret_type| need_sret(ctx.ctx, ret_type));
let mut byvals = Vec::new();
let mut params =
args.iter().enumerate().map(|(i, arg)| match ctx.get_llvm_type(generator, arg.ty) {
BasicTypeEnum::StructType(ty) if is_extern => {
byvals.push((i, ty));
ty.ptr_type(AddressSpace::Generic).into()
},
x => x
}.into()).collect_vec();
if has_sret {
params.insert(0, ret_type.unwrap().ptr_type(AddressSpace::Generic).into());
}
let fun_ty = match ret_type {
Some(ret_type) if !has_sret => ret_type.fn_type(&params, false),
_ => ctx.ctx.void_type().fn_type(&params, false)
};
let fun_val = ctx.module.add_function(&symbol, fun_ty, None);
let offset = if has_sret {
fun_val.add_attribute(AttributeLoc::Param(0),
ctx.ctx.create_type_attribute(Attribute::get_named_enum_kind_id("sret"), ret_type.unwrap().as_any_type_enum()));
1
} else {
0
};
for (i, ty) in byvals {
fun_val.add_attribute(AttributeLoc::Param((i as u32) + offset),
ctx.ctx.create_type_attribute(Attribute::get_named_enum_kind_id("byval"), ty.as_any_type_enum()));
}
fun_val
});
Ok(ctx.build_call_or_invoke(fun_val, &param_vals, "call"))
}

View File

@ -8,14 +8,16 @@ use crate::{
};
use crossbeam::channel::{unbounded, Receiver, Sender};
use inkwell::{
AddressSpace,
OptimizationLevel,
attributes::{Attribute, AttributeLoc},
basic_block::BasicBlock,
builder::Builder,
context::Context,
module::Module,
passes::{PassManager, PassManagerBuilder},
types::{BasicType, BasicTypeEnum},
values::{BasicValueEnum, FunctionValue, PhiValue, PointerValue},
AddressSpace, OptimizationLevel,
types::{AnyType, BasicType, BasicTypeEnum},
values::{BasicValueEnum, FunctionValue, PhiValue, PointerValue}
};
use itertools::Itertools;
use nac3parser::ast::{Stmt, StrRef};
@ -74,6 +76,7 @@ pub struct CodeGenContext<'ctx, 'a> {
// outer catch clauses
pub outer_catch_clauses:
Option<(Vec<Option<BasicValueEnum<'ctx>>>, BasicBlock<'ctx>, PhiValue<'ctx>)>,
pub need_sret: bool,
}
impl<'ctx, 'a> CodeGenContext<'ctx, 'a> {
@ -323,6 +326,19 @@ fn get_llvm_type<'ctx>(
})
}
fn need_sret<'ctx>(ctx: &'ctx Context, ty: BasicTypeEnum<'ctx>) -> bool {
fn need_sret_impl<'ctx>(ctx: &'ctx Context, ty: BasicTypeEnum<'ctx>, maybe_large: bool) -> bool {
match ty {
BasicTypeEnum::IntType(_) | BasicTypeEnum::PointerType(_) => false,
BasicTypeEnum::FloatType(_) if maybe_large => false,
BasicTypeEnum::StructType(ty) if maybe_large && ty.count_fields() <= 2 =>
ty.get_field_types().iter().any(|ty| need_sret_impl(ctx, *ty, false)),
_ => true,
}
}
need_sret_impl(ctx, ty, true)
}
pub fn gen_func<'ctx, G: CodeGenerator>(
context: &'ctx Context,
generator: &mut G,
@ -417,7 +433,14 @@ pub fn gen_func<'ctx, G: CodeGenerator>(
} else {
unreachable!()
};
let params = args
let ret_type = if unifier.unioned(ret, primitives.none) {
None
} else {
Some(get_llvm_type(context, generator, &mut unifier, top_level_ctx.as_ref(), &mut type_cache, ret))
};
let has_sret = ret_type.map_or(false, |ty| need_sret(context, ty));
let mut params = args
.iter()
.map(|arg| {
get_llvm_type(
@ -432,18 +455,13 @@ pub fn gen_func<'ctx, G: CodeGenerator>(
})
.collect_vec();
let fn_type = if unifier.unioned(ret, primitives.none) {
context.void_type().fn_type(&params, false)
} else {
get_llvm_type(
context,
generator,
&mut unifier,
top_level_ctx.as_ref(),
&mut type_cache,
ret,
)
.fn_type(&params, false)
if has_sret {
params.insert(0, ret_type.unwrap().ptr_type(AddressSpace::Generic).into());
}
let fn_type = match ret_type {
Some(ret_type) if !has_sret => ret_type.fn_type(&params, false),
_ => context.void_type().fn_type(&params, false)
};
let symbol = &task.symbol_name;
@ -457,14 +475,20 @@ pub fn gen_func<'ctx, G: CodeGenerator>(
});
fn_val.set_personality_function(personality);
}
if has_sret {
fn_val.add_attribute(AttributeLoc::Param(0),
context.create_type_attribute(Attribute::get_named_enum_kind_id("sret"),
ret_type.unwrap().as_any_type_enum()));
}
let init_bb = context.append_basic_block(fn_val, "init");
builder.position_at_end(init_bb);
let body_bb = context.append_basic_block(fn_val, "body");
let mut var_assignment = HashMap::new();
let offset = if has_sret { 1 } else { 0 };
for (n, arg) in args.iter().enumerate() {
let param = fn_val.get_nth_param(n as u32).unwrap();
let param = fn_val.get_nth_param((n as u32) + offset).unwrap();
let alloca = builder.build_alloca(
get_llvm_type(
context,
@ -479,7 +503,13 @@ pub fn gen_func<'ctx, G: CodeGenerator>(
builder.build_store(alloca, param);
var_assignment.insert(arg.name, (alloca, None, 0));
}
let return_buffer = fn_type.get_return_type().map(|v| builder.build_alloca(v, "$ret"));
let return_buffer = if has_sret {
Some(fn_val.get_nth_param(0).unwrap().into_pointer_value())
} else {
fn_type.get_return_type().map(|v| builder.build_alloca(v, "$ret"))
};
let static_values = {
let store = registry.static_value_store.lock();
store.store[task.id].clone()
@ -512,6 +542,7 @@ pub fn gen_func<'ctx, G: CodeGenerator>(
module,
unifier,
static_value_store,
need_sret: has_sret
};
let mut err = None;

View File

@ -913,6 +913,10 @@ pub fn gen_return<'ctx, 'a, G: CodeGenerator>(
ctx.builder.build_store(ctx.return_buffer.unwrap(), value);
}
ctx.builder.build_unconditional_branch(return_target);
} else if ctx.need_sret {
// sret
ctx.builder.build_store(ctx.return_buffer.unwrap(), value.unwrap());
ctx.builder.build_return(None);
} else {
let value = value.as_ref().map(|v| v as &dyn BasicValue);
ctx.builder.build_return(value);