From 2f85bb38376ff3346709c7344065872759c96e05 Mon Sep 17 00:00:00 2001 From: pca006132 Date: Wed, 9 Mar 2022 22:09:36 +0800 Subject: [PATCH] nac3core: impl call attributes sret for returning large structs, and byval for struct args in extern function calls. --- nac3core/src/codegen/expr.rs | 88 ++++++++++++++++++++++++++++++++---- nac3core/src/codegen/mod.rs | 67 +++++++++++++++++++-------- nac3core/src/codegen/stmt.rs | 4 ++ 3 files changed, 131 insertions(+), 28 deletions(-) diff --git a/nac3core/src/codegen/expr.rs b/nac3core/src/codegen/expr.rs index 09d0afb0b..bfebffb61 100644 --- a/nac3core/src/codegen/expr.rs +++ b/nac3core/src/codegen/expr.rs @@ -13,16 +13,17 @@ use crate::{ typecheck::typedef::{FunSignature, FuncArg, Type, TypeEnum, Unifier}, }; use inkwell::{ - types::{BasicType, BasicTypeEnum}, - values::{BasicValueEnum, FunctionValue, IntValue, PointerValue}, AddressSpace, + attributes::{Attribute, AttributeLoc}, + types::{AnyType, BasicType, BasicTypeEnum}, + values::{BasicValueEnum, FunctionValue, IntValue, PointerValue} }; use itertools::{chain, izip, zip, Itertools}; use nac3parser::ast::{ self, Boolop, Comprehension, Constant, Expr, ExprKind, Location, Operator, StrRef, }; -use super::CodeGenerator; +use super::{CodeGenerator, need_sret}; pub fn get_subst_key( unifier: &mut Unifier, @@ -299,7 +300,40 @@ impl<'ctx, 'a> CodeGenContext<'ctx, 'a> { params: &[BasicValueEnum<'ctx>], call_name: &str, ) -> Option> { - if let Some(target) = self.unwind_target { + let mut loc_params: Vec> = Vec::new(); + let mut return_slot = None; + if fun.count_params() > 0 { + let sret_id = Attribute::get_named_enum_kind_id("sret"); + let byval_id = Attribute::get_named_enum_kind_id("byval"); + + let offset = if fun.get_enum_attribute(AttributeLoc::Param(0), sret_id).is_some() { + return_slot = Some(self.builder.build_alloca(fun.get_type().get_param_types()[0] + .into_pointer_type().get_element_type().into_struct_type(), call_name)); + loc_params.push((*return_slot.as_ref().unwrap()).into()); + 1 + } else { + 0 + }; + for (i, param) in params.iter().enumerate() { + if fun.get_enum_attribute(AttributeLoc::Param((i + offset) as u32), byval_id).is_some() { + // lazy update + if loc_params.is_empty() { + loc_params.extend(params[0..i+offset].iter().copied()); + } + let slot = self.builder.build_alloca(param.get_type(), call_name); + loc_params.push(slot.into()); + self.builder.build_store(slot, *param); + } else if !loc_params.is_empty() { + loc_params.push(*param); + } + } + } + let params = if loc_params.is_empty() { + params + } else { + &loc_params + }; + let result = if let Some(target) = self.unwind_target { let current = self.builder.get_insert_block().unwrap().get_parent().unwrap(); let then_block = self.ctx.append_basic_block(current, &format!("after.{}", call_name)); let result = self @@ -312,6 +346,11 @@ impl<'ctx, 'a> CodeGenContext<'ctx, 'a> { } else { let param: Vec<_> = params.iter().map(|v| (*v).into()).collect(); self.builder.build_call(fun, ¶m, call_name).try_as_basic_value().left() + }; + if let Some(slot) = return_slot { + Some(self.builder.build_load(slot, call_name)) + } else { + result } } @@ -519,6 +558,7 @@ pub fn gen_call<'ctx, 'a, G: CodeGenerator>( let id; let key; let param_vals; + let is_extern; let symbol = { // make sure this lock guard is dropped at the end of this scope... let def = definition.read(); @@ -532,6 +572,7 @@ pub fn gen_call<'ctx, 'a, G: CodeGenerator>( if let Some(callback) = codegen_callback { return callback.run(ctx, obj, fun, params, generator); } + is_extern = instance_to_stmt.is_empty(); let old_key = ctx.get_subst_key(obj.as_ref().map(|a| a.0), fun.0, None); let mut keys = fun.0.args.clone(); let mut mapping = HashMap::new(); @@ -606,14 +647,41 @@ pub fn gen_call<'ctx, 'a, G: CodeGenerator>( if let Some(obj) = &obj { args.insert(0, FuncArg { name: "self".into(), ty: obj.0, default_value: None }); } - let params = - args.iter().map(|arg| ctx.get_llvm_type(generator, arg.ty).into()).collect_vec(); - let fun_ty = if ctx.unifier.unioned(fun.0.ret, ctx.primitives.none) { - ctx.ctx.void_type().fn_type(¶ms, false) + let ret_type = if ctx.unifier.unioned(fun.0.ret, ctx.primitives.none) { + None } else { - ctx.get_llvm_type(generator, fun.0.ret).fn_type(¶ms, false) + Some(ctx.get_llvm_type(generator, fun.0.ret)) }; - ctx.module.add_function(&symbol, fun_ty, None) + let has_sret = ret_type.map_or(false, |ret_type| need_sret(ctx.ctx, ret_type)); + let mut byvals = Vec::new(); + let mut params = + args.iter().enumerate().map(|(i, arg)| match ctx.get_llvm_type(generator, arg.ty) { + BasicTypeEnum::StructType(ty) if is_extern => { + byvals.push((i, ty)); + ty.ptr_type(AddressSpace::Generic).into() + }, + x => x + }.into()).collect_vec(); + if has_sret { + params.insert(0, ret_type.unwrap().ptr_type(AddressSpace::Generic).into()); + } + let fun_ty = match ret_type { + Some(ret_type) if !has_sret => ret_type.fn_type(¶ms, false), + _ => ctx.ctx.void_type().fn_type(¶ms, false) + }; + let fun_val = ctx.module.add_function(&symbol, fun_ty, None); + let offset = if has_sret { + fun_val.add_attribute(AttributeLoc::Param(0), + ctx.ctx.create_type_attribute(Attribute::get_named_enum_kind_id("sret"), ret_type.unwrap().as_any_type_enum())); + 1 + } else { + 0 + }; + for (i, ty) in byvals { + fun_val.add_attribute(AttributeLoc::Param((i as u32) + offset), + ctx.ctx.create_type_attribute(Attribute::get_named_enum_kind_id("byval"), ty.as_any_type_enum())); + } + fun_val }); Ok(ctx.build_call_or_invoke(fun_val, ¶m_vals, "call")) } diff --git a/nac3core/src/codegen/mod.rs b/nac3core/src/codegen/mod.rs index ee7d3a0d2..b9e75e620 100644 --- a/nac3core/src/codegen/mod.rs +++ b/nac3core/src/codegen/mod.rs @@ -8,14 +8,16 @@ use crate::{ }; use crossbeam::channel::{unbounded, Receiver, Sender}; use inkwell::{ + AddressSpace, + OptimizationLevel, + attributes::{Attribute, AttributeLoc}, basic_block::BasicBlock, builder::Builder, context::Context, module::Module, passes::{PassManager, PassManagerBuilder}, - types::{BasicType, BasicTypeEnum}, - values::{BasicValueEnum, FunctionValue, PhiValue, PointerValue}, - AddressSpace, OptimizationLevel, + types::{AnyType, BasicType, BasicTypeEnum}, + values::{BasicValueEnum, FunctionValue, PhiValue, PointerValue} }; use itertools::Itertools; use nac3parser::ast::{Stmt, StrRef}; @@ -74,6 +76,7 @@ pub struct CodeGenContext<'ctx, 'a> { // outer catch clauses pub outer_catch_clauses: Option<(Vec>>, BasicBlock<'ctx>, PhiValue<'ctx>)>, + pub need_sret: bool, } impl<'ctx, 'a> CodeGenContext<'ctx, 'a> { @@ -323,6 +326,19 @@ fn get_llvm_type<'ctx>( }) } +fn need_sret<'ctx>(ctx: &'ctx Context, ty: BasicTypeEnum<'ctx>) -> bool { + fn need_sret_impl<'ctx>(ctx: &'ctx Context, ty: BasicTypeEnum<'ctx>, maybe_large: bool) -> bool { + match ty { + BasicTypeEnum::IntType(_) | BasicTypeEnum::PointerType(_) => false, + BasicTypeEnum::FloatType(_) if maybe_large => false, + BasicTypeEnum::StructType(ty) if maybe_large && ty.count_fields() <= 2 => + ty.get_field_types().iter().any(|ty| need_sret_impl(ctx, *ty, false)), + _ => true, + } + } + need_sret_impl(ctx, ty, true) +} + pub fn gen_func<'ctx, G: CodeGenerator>( context: &'ctx Context, generator: &mut G, @@ -417,7 +433,14 @@ pub fn gen_func<'ctx, G: CodeGenerator>( } else { unreachable!() }; - let params = args + let ret_type = if unifier.unioned(ret, primitives.none) { + None + } else { + Some(get_llvm_type(context, generator, &mut unifier, top_level_ctx.as_ref(), &mut type_cache, ret)) + }; + + let has_sret = ret_type.map_or(false, |ty| need_sret(context, ty)); + let mut params = args .iter() .map(|arg| { get_llvm_type( @@ -432,18 +455,13 @@ pub fn gen_func<'ctx, G: CodeGenerator>( }) .collect_vec(); - let fn_type = if unifier.unioned(ret, primitives.none) { - context.void_type().fn_type(¶ms, false) - } else { - get_llvm_type( - context, - generator, - &mut unifier, - top_level_ctx.as_ref(), - &mut type_cache, - ret, - ) - .fn_type(¶ms, false) + if has_sret { + params.insert(0, ret_type.unwrap().ptr_type(AddressSpace::Generic).into()); + } + + let fn_type = match ret_type { + Some(ret_type) if !has_sret => ret_type.fn_type(¶ms, false), + _ => context.void_type().fn_type(¶ms, false) }; let symbol = &task.symbol_name; @@ -457,14 +475,20 @@ pub fn gen_func<'ctx, G: CodeGenerator>( }); fn_val.set_personality_function(personality); } + if has_sret { + fn_val.add_attribute(AttributeLoc::Param(0), + context.create_type_attribute(Attribute::get_named_enum_kind_id("sret"), + ret_type.unwrap().as_any_type_enum())); + } let init_bb = context.append_basic_block(fn_val, "init"); builder.position_at_end(init_bb); let body_bb = context.append_basic_block(fn_val, "body"); let mut var_assignment = HashMap::new(); + let offset = if has_sret { 1 } else { 0 }; for (n, arg) in args.iter().enumerate() { - let param = fn_val.get_nth_param(n as u32).unwrap(); + let param = fn_val.get_nth_param((n as u32) + offset).unwrap(); let alloca = builder.build_alloca( get_llvm_type( context, @@ -479,7 +503,13 @@ pub fn gen_func<'ctx, G: CodeGenerator>( builder.build_store(alloca, param); var_assignment.insert(arg.name, (alloca, None, 0)); } - let return_buffer = fn_type.get_return_type().map(|v| builder.build_alloca(v, "$ret")); + + let return_buffer = if has_sret { + Some(fn_val.get_nth_param(0).unwrap().into_pointer_value()) + } else { + fn_type.get_return_type().map(|v| builder.build_alloca(v, "$ret")) + }; + let static_values = { let store = registry.static_value_store.lock(); store.store[task.id].clone() @@ -512,6 +542,7 @@ pub fn gen_func<'ctx, G: CodeGenerator>( module, unifier, static_value_store, + need_sret: has_sret }; let mut err = None; diff --git a/nac3core/src/codegen/stmt.rs b/nac3core/src/codegen/stmt.rs index 1bc728578..5c3478065 100644 --- a/nac3core/src/codegen/stmt.rs +++ b/nac3core/src/codegen/stmt.rs @@ -913,6 +913,10 @@ pub fn gen_return<'ctx, 'a, G: CodeGenerator>( ctx.builder.build_store(ctx.return_buffer.unwrap(), value); } ctx.builder.build_unconditional_branch(return_target); + } else if ctx.need_sret { + // sret + ctx.builder.build_store(ctx.return_buffer.unwrap(), value.unwrap()); + ctx.builder.build_return(None); } else { let value = value.as_ref().map(|v| v as &dyn BasicValue); ctx.builder.build_return(value);