forked from M-Labs/nac3
nac3core: impl call attributes
sret for returning large structs, and byval for struct args in extern function calls.
This commit is contained in:
parent
e266d3c2b0
commit
2f85bb3837
@ -13,16 +13,17 @@ use crate::{
|
||||
typecheck::typedef::{FunSignature, FuncArg, Type, TypeEnum, Unifier},
|
||||
};
|
||||
use inkwell::{
|
||||
types::{BasicType, BasicTypeEnum},
|
||||
values::{BasicValueEnum, FunctionValue, IntValue, PointerValue},
|
||||
AddressSpace,
|
||||
attributes::{Attribute, AttributeLoc},
|
||||
types::{AnyType, BasicType, BasicTypeEnum},
|
||||
values::{BasicValueEnum, FunctionValue, IntValue, PointerValue}
|
||||
};
|
||||
use itertools::{chain, izip, zip, Itertools};
|
||||
use nac3parser::ast::{
|
||||
self, Boolop, Comprehension, Constant, Expr, ExprKind, Location, Operator, StrRef,
|
||||
};
|
||||
|
||||
use super::CodeGenerator;
|
||||
use super::{CodeGenerator, need_sret};
|
||||
|
||||
pub fn get_subst_key(
|
||||
unifier: &mut Unifier,
|
||||
@ -299,7 +300,40 @@ impl<'ctx, 'a> CodeGenContext<'ctx, 'a> {
|
||||
params: &[BasicValueEnum<'ctx>],
|
||||
call_name: &str,
|
||||
) -> Option<BasicValueEnum<'ctx>> {
|
||||
if let Some(target) = self.unwind_target {
|
||||
let mut loc_params: Vec<BasicValueEnum<'ctx>> = Vec::new();
|
||||
let mut return_slot = None;
|
||||
if fun.count_params() > 0 {
|
||||
let sret_id = Attribute::get_named_enum_kind_id("sret");
|
||||
let byval_id = Attribute::get_named_enum_kind_id("byval");
|
||||
|
||||
let offset = if fun.get_enum_attribute(AttributeLoc::Param(0), sret_id).is_some() {
|
||||
return_slot = Some(self.builder.build_alloca(fun.get_type().get_param_types()[0]
|
||||
.into_pointer_type().get_element_type().into_struct_type(), call_name));
|
||||
loc_params.push((*return_slot.as_ref().unwrap()).into());
|
||||
1
|
||||
} else {
|
||||
0
|
||||
};
|
||||
for (i, param) in params.iter().enumerate() {
|
||||
if fun.get_enum_attribute(AttributeLoc::Param((i + offset) as u32), byval_id).is_some() {
|
||||
// lazy update
|
||||
if loc_params.is_empty() {
|
||||
loc_params.extend(params[0..i+offset].iter().copied());
|
||||
}
|
||||
let slot = self.builder.build_alloca(param.get_type(), call_name);
|
||||
loc_params.push(slot.into());
|
||||
self.builder.build_store(slot, *param);
|
||||
} else if !loc_params.is_empty() {
|
||||
loc_params.push(*param);
|
||||
}
|
||||
}
|
||||
}
|
||||
let params = if loc_params.is_empty() {
|
||||
params
|
||||
} else {
|
||||
&loc_params
|
||||
};
|
||||
let result = if let Some(target) = self.unwind_target {
|
||||
let current = self.builder.get_insert_block().unwrap().get_parent().unwrap();
|
||||
let then_block = self.ctx.append_basic_block(current, &format!("after.{}", call_name));
|
||||
let result = self
|
||||
@ -312,6 +346,11 @@ impl<'ctx, 'a> CodeGenContext<'ctx, 'a> {
|
||||
} else {
|
||||
let param: Vec<_> = params.iter().map(|v| (*v).into()).collect();
|
||||
self.builder.build_call(fun, ¶m, call_name).try_as_basic_value().left()
|
||||
};
|
||||
if let Some(slot) = return_slot {
|
||||
Some(self.builder.build_load(slot, call_name))
|
||||
} else {
|
||||
result
|
||||
}
|
||||
}
|
||||
|
||||
@ -519,6 +558,7 @@ pub fn gen_call<'ctx, 'a, G: CodeGenerator>(
|
||||
let id;
|
||||
let key;
|
||||
let param_vals;
|
||||
let is_extern;
|
||||
let symbol = {
|
||||
// make sure this lock guard is dropped at the end of this scope...
|
||||
let def = definition.read();
|
||||
@ -532,6 +572,7 @@ pub fn gen_call<'ctx, 'a, G: CodeGenerator>(
|
||||
if let Some(callback) = codegen_callback {
|
||||
return callback.run(ctx, obj, fun, params, generator);
|
||||
}
|
||||
is_extern = instance_to_stmt.is_empty();
|
||||
let old_key = ctx.get_subst_key(obj.as_ref().map(|a| a.0), fun.0, None);
|
||||
let mut keys = fun.0.args.clone();
|
||||
let mut mapping = HashMap::new();
|
||||
@ -606,14 +647,41 @@ pub fn gen_call<'ctx, 'a, G: CodeGenerator>(
|
||||
if let Some(obj) = &obj {
|
||||
args.insert(0, FuncArg { name: "self".into(), ty: obj.0, default_value: None });
|
||||
}
|
||||
let params =
|
||||
args.iter().map(|arg| ctx.get_llvm_type(generator, arg.ty).into()).collect_vec();
|
||||
let fun_ty = if ctx.unifier.unioned(fun.0.ret, ctx.primitives.none) {
|
||||
ctx.ctx.void_type().fn_type(¶ms, false)
|
||||
let ret_type = if ctx.unifier.unioned(fun.0.ret, ctx.primitives.none) {
|
||||
None
|
||||
} else {
|
||||
ctx.get_llvm_type(generator, fun.0.ret).fn_type(¶ms, false)
|
||||
Some(ctx.get_llvm_type(generator, fun.0.ret))
|
||||
};
|
||||
ctx.module.add_function(&symbol, fun_ty, None)
|
||||
let has_sret = ret_type.map_or(false, |ret_type| need_sret(ctx.ctx, ret_type));
|
||||
let mut byvals = Vec::new();
|
||||
let mut params =
|
||||
args.iter().enumerate().map(|(i, arg)| match ctx.get_llvm_type(generator, arg.ty) {
|
||||
BasicTypeEnum::StructType(ty) if is_extern => {
|
||||
byvals.push((i, ty));
|
||||
ty.ptr_type(AddressSpace::Generic).into()
|
||||
},
|
||||
x => x
|
||||
}.into()).collect_vec();
|
||||
if has_sret {
|
||||
params.insert(0, ret_type.unwrap().ptr_type(AddressSpace::Generic).into());
|
||||
}
|
||||
let fun_ty = match ret_type {
|
||||
Some(ret_type) if !has_sret => ret_type.fn_type(¶ms, false),
|
||||
_ => ctx.ctx.void_type().fn_type(¶ms, false)
|
||||
};
|
||||
let fun_val = ctx.module.add_function(&symbol, fun_ty, None);
|
||||
let offset = if has_sret {
|
||||
fun_val.add_attribute(AttributeLoc::Param(0),
|
||||
ctx.ctx.create_type_attribute(Attribute::get_named_enum_kind_id("sret"), ret_type.unwrap().as_any_type_enum()));
|
||||
1
|
||||
} else {
|
||||
0
|
||||
};
|
||||
for (i, ty) in byvals {
|
||||
fun_val.add_attribute(AttributeLoc::Param((i as u32) + offset),
|
||||
ctx.ctx.create_type_attribute(Attribute::get_named_enum_kind_id("byval"), ty.as_any_type_enum()));
|
||||
}
|
||||
fun_val
|
||||
});
|
||||
Ok(ctx.build_call_or_invoke(fun_val, ¶m_vals, "call"))
|
||||
}
|
||||
|
@ -8,14 +8,16 @@ use crate::{
|
||||
};
|
||||
use crossbeam::channel::{unbounded, Receiver, Sender};
|
||||
use inkwell::{
|
||||
AddressSpace,
|
||||
OptimizationLevel,
|
||||
attributes::{Attribute, AttributeLoc},
|
||||
basic_block::BasicBlock,
|
||||
builder::Builder,
|
||||
context::Context,
|
||||
module::Module,
|
||||
passes::{PassManager, PassManagerBuilder},
|
||||
types::{BasicType, BasicTypeEnum},
|
||||
values::{BasicValueEnum, FunctionValue, PhiValue, PointerValue},
|
||||
AddressSpace, OptimizationLevel,
|
||||
types::{AnyType, BasicType, BasicTypeEnum},
|
||||
values::{BasicValueEnum, FunctionValue, PhiValue, PointerValue}
|
||||
};
|
||||
use itertools::Itertools;
|
||||
use nac3parser::ast::{Stmt, StrRef};
|
||||
@ -74,6 +76,7 @@ pub struct CodeGenContext<'ctx, 'a> {
|
||||
// outer catch clauses
|
||||
pub outer_catch_clauses:
|
||||
Option<(Vec<Option<BasicValueEnum<'ctx>>>, BasicBlock<'ctx>, PhiValue<'ctx>)>,
|
||||
pub need_sret: bool,
|
||||
}
|
||||
|
||||
impl<'ctx, 'a> CodeGenContext<'ctx, 'a> {
|
||||
@ -323,6 +326,19 @@ fn get_llvm_type<'ctx>(
|
||||
})
|
||||
}
|
||||
|
||||
fn need_sret<'ctx>(ctx: &'ctx Context, ty: BasicTypeEnum<'ctx>) -> bool {
|
||||
fn need_sret_impl<'ctx>(ctx: &'ctx Context, ty: BasicTypeEnum<'ctx>, maybe_large: bool) -> bool {
|
||||
match ty {
|
||||
BasicTypeEnum::IntType(_) | BasicTypeEnum::PointerType(_) => false,
|
||||
BasicTypeEnum::FloatType(_) if maybe_large => false,
|
||||
BasicTypeEnum::StructType(ty) if maybe_large && ty.count_fields() <= 2 =>
|
||||
ty.get_field_types().iter().any(|ty| need_sret_impl(ctx, *ty, false)),
|
||||
_ => true,
|
||||
}
|
||||
}
|
||||
need_sret_impl(ctx, ty, true)
|
||||
}
|
||||
|
||||
pub fn gen_func<'ctx, G: CodeGenerator>(
|
||||
context: &'ctx Context,
|
||||
generator: &mut G,
|
||||
@ -417,7 +433,14 @@ pub fn gen_func<'ctx, G: CodeGenerator>(
|
||||
} else {
|
||||
unreachable!()
|
||||
};
|
||||
let params = args
|
||||
let ret_type = if unifier.unioned(ret, primitives.none) {
|
||||
None
|
||||
} else {
|
||||
Some(get_llvm_type(context, generator, &mut unifier, top_level_ctx.as_ref(), &mut type_cache, ret))
|
||||
};
|
||||
|
||||
let has_sret = ret_type.map_or(false, |ty| need_sret(context, ty));
|
||||
let mut params = args
|
||||
.iter()
|
||||
.map(|arg| {
|
||||
get_llvm_type(
|
||||
@ -432,18 +455,13 @@ pub fn gen_func<'ctx, G: CodeGenerator>(
|
||||
})
|
||||
.collect_vec();
|
||||
|
||||
let fn_type = if unifier.unioned(ret, primitives.none) {
|
||||
context.void_type().fn_type(¶ms, false)
|
||||
} else {
|
||||
get_llvm_type(
|
||||
context,
|
||||
generator,
|
||||
&mut unifier,
|
||||
top_level_ctx.as_ref(),
|
||||
&mut type_cache,
|
||||
ret,
|
||||
)
|
||||
.fn_type(¶ms, false)
|
||||
if has_sret {
|
||||
params.insert(0, ret_type.unwrap().ptr_type(AddressSpace::Generic).into());
|
||||
}
|
||||
|
||||
let fn_type = match ret_type {
|
||||
Some(ret_type) if !has_sret => ret_type.fn_type(¶ms, false),
|
||||
_ => context.void_type().fn_type(¶ms, false)
|
||||
};
|
||||
|
||||
let symbol = &task.symbol_name;
|
||||
@ -457,14 +475,20 @@ pub fn gen_func<'ctx, G: CodeGenerator>(
|
||||
});
|
||||
fn_val.set_personality_function(personality);
|
||||
}
|
||||
if has_sret {
|
||||
fn_val.add_attribute(AttributeLoc::Param(0),
|
||||
context.create_type_attribute(Attribute::get_named_enum_kind_id("sret"),
|
||||
ret_type.unwrap().as_any_type_enum()));
|
||||
}
|
||||
|
||||
let init_bb = context.append_basic_block(fn_val, "init");
|
||||
builder.position_at_end(init_bb);
|
||||
let body_bb = context.append_basic_block(fn_val, "body");
|
||||
|
||||
let mut var_assignment = HashMap::new();
|
||||
let offset = if has_sret { 1 } else { 0 };
|
||||
for (n, arg) in args.iter().enumerate() {
|
||||
let param = fn_val.get_nth_param(n as u32).unwrap();
|
||||
let param = fn_val.get_nth_param((n as u32) + offset).unwrap();
|
||||
let alloca = builder.build_alloca(
|
||||
get_llvm_type(
|
||||
context,
|
||||
@ -479,7 +503,13 @@ pub fn gen_func<'ctx, G: CodeGenerator>(
|
||||
builder.build_store(alloca, param);
|
||||
var_assignment.insert(arg.name, (alloca, None, 0));
|
||||
}
|
||||
let return_buffer = fn_type.get_return_type().map(|v| builder.build_alloca(v, "$ret"));
|
||||
|
||||
let return_buffer = if has_sret {
|
||||
Some(fn_val.get_nth_param(0).unwrap().into_pointer_value())
|
||||
} else {
|
||||
fn_type.get_return_type().map(|v| builder.build_alloca(v, "$ret"))
|
||||
};
|
||||
|
||||
let static_values = {
|
||||
let store = registry.static_value_store.lock();
|
||||
store.store[task.id].clone()
|
||||
@ -512,6 +542,7 @@ pub fn gen_func<'ctx, G: CodeGenerator>(
|
||||
module,
|
||||
unifier,
|
||||
static_value_store,
|
||||
need_sret: has_sret
|
||||
};
|
||||
|
||||
let mut err = None;
|
||||
|
@ -913,6 +913,10 @@ pub fn gen_return<'ctx, 'a, G: CodeGenerator>(
|
||||
ctx.builder.build_store(ctx.return_buffer.unwrap(), value);
|
||||
}
|
||||
ctx.builder.build_unconditional_branch(return_target);
|
||||
} else if ctx.need_sret {
|
||||
// sret
|
||||
ctx.builder.build_store(ctx.return_buffer.unwrap(), value.unwrap());
|
||||
ctx.builder.build_return(None);
|
||||
} else {
|
||||
let value = value.as_ref().map(|v| v as &dyn BasicValue);
|
||||
ctx.builder.build_return(value);
|
||||
|
Loading…
Reference in New Issue
Block a user