From 7cf763498562e537e1c2f840c4ac75daeb1c86c1 Mon Sep 17 00:00:00 2001 From: David Mak Date: Tue, 10 Oct 2023 14:56:16 +0800 Subject: [PATCH] core: Add create_fn_by_* functions Used for abstracting the creation of function from different sources. --- nac3core/src/toplevel/builtins.rs | 186 +++++++++++++++++++++++++++++- 1 file changed, 184 insertions(+), 2 deletions(-) diff --git a/nac3core/src/toplevel/builtins.rs b/nac3core/src/toplevel/builtins.rs index 1b9ef28f..5e3f1121 100644 --- a/nac3core/src/toplevel/builtins.rs +++ b/nac3core/src/toplevel/builtins.rs @@ -5,7 +5,13 @@ use crate::{ }, symbol_resolver::SymbolValue, }; -use inkwell::{types::BasicType, FloatPredicate, IntPredicate}; +use inkwell::{ + attributes::{Attribute, AttributeLoc}, + types::{BasicType, BasicMetadataTypeEnum}, + values::BasicMetadataValueEnum, + FloatPredicate, + IntPredicate +}; type BuiltinInfo = (Vec<(Arc>, Option)>, &'static [&'static str]); @@ -78,6 +84,182 @@ pub fn get_exn_constructor( (fun_def, class_def, signature, exn_type) } +/// Creates a NumPy [TopLevelDef] function by code generation. +/// +/// * `name`: The name of the implemented NumPy function. +/// * `ret_ty`: The return type of this function. +/// * `param_ty`: The parameters accepted by this function, represented by a tuple of the +/// [parameter type][Type] and the parameter symbol name. +/// * `codegen_callback`: A lambda generating LLVM IR for the implementation of this function. +fn create_fn_by_codegen( + primitives: &mut (PrimitiveStore, Unifier), + var_map: &HashMap, + name: &'static str, + ret_ty: Type, + param_ty: &[(Type, &'static str)], + codegen_callback: GenCallCallback, +) -> Arc> { + Arc::new(RwLock::new(TopLevelDef::Function { + name: name.into(), + simple_name: name.into(), + signature: primitives.1.add_ty(TypeEnum::TFunc(FunSignature { + args: param_ty.iter().map(|p| FuncArg { + name: p.1.into(), + ty: p.0, + default_value: None, + }).collect(), + ret: ret_ty.clone(), + vars: var_map.clone(), + })), + var_id: Default::default(), + instance_to_symbol: Default::default(), + instance_to_stmt: Default::default(), + resolver: None, + codegen_callback: Some(Arc::new(GenCall::new(codegen_callback))), + loc: None, + })) +} + +/// Creates a NumPy [TopLevelDef] function using an LLVM intrinsic. +/// +/// * `name`: The name of the implemented NumPy function. +/// * `ret_ty`: The return type of this function. +/// * `param_ty`: The parameters accepted by this function, represented by a tuple of the +/// [parameter type][Type] and the parameter symbol name. +/// * `intrinsic_fn`: The fully-qualified name of the LLVM intrinsic function. +fn create_fn_by_intrinsic( + primitives: &mut (PrimitiveStore, Unifier), + var_map: &HashMap, + name: &'static str, + ret_ty: Type, + params: &[(Type, &'static str)], + intrinsic_fn: &'static str, +) -> Arc> { + let param_tys = params.iter() + .map(|p| p.0) + .collect_vec(); + + create_fn_by_codegen( + primitives, + var_map, + name, + ret_ty, + params, + Box::new(move |ctx, _, fun, args, generator| { + let args_ty = fun.0.args.iter().map(|a| a.ty).collect_vec(); + + assert!(param_tys.iter().zip(&args_ty) + .all(|(expected, actual)| ctx.unifier.unioned(*expected, *actual))); + + let args_val = args_ty.iter().zip_eq(args.iter()) + .map(|(ty, arg)| { + arg.1.clone() + .to_basic_value_enum(ctx, generator, ty.clone()) + .unwrap() + }) + .map_into::() + .collect_vec(); + + let intrinsic_fn = ctx.module.get_function(intrinsic_fn).unwrap_or_else(|| { + let ret_llvm_ty = ctx.get_llvm_abi_type(generator, ret_ty.clone()); + let param_llvm_ty = param_tys.iter() + .map(|p| ctx.get_llvm_abi_type(generator, *p)) + .map_into::() + .collect_vec(); + let fn_type = ret_llvm_ty.fn_type(param_llvm_ty.as_slice(), false); + + ctx.module.add_function(intrinsic_fn, fn_type, None) + }); + + let call = ctx.builder + .build_call(intrinsic_fn, args_val.as_slice(), name); + + let val = call.try_as_basic_value() + .left() + .unwrap(); + Ok(val.into()) + }), + ) +} + +/// Creates a unary NumPy [TopLevelDef] function using an extern function (e.g. from `libc` or +/// `libm`). +/// +/// * `name`: The name of the implemented NumPy function. +/// * `ret_ty`: The return type of this function. +/// * `param_ty`: The parameters accepted by this function, represented by a tuple of the +/// [parameter type][Type] and the parameter symbol name. +/// * `extern_fn`: The fully-qualified name of the extern function used as the implementation. +/// * `attrs`: The list of attributes to apply to this function declaration. Note that `nounwind` is +/// already implied by the C ABI. +fn create_fn_by_extern( + primitives: &mut (PrimitiveStore, Unifier), + var_map: &HashMap, + name: &'static str, + ret_ty: Type, + params: &[(Type, &'static str)], + extern_fn: &'static str, + attrs: &'static [&str], +) -> Arc> { + let param_tys = params.iter() + .map(|p| p.0) + .collect_vec(); + + create_fn_by_codegen( + primitives, + var_map, + name, + ret_ty, + params, + Box::new(move |ctx, _, fun, args, generator| { + let args_ty = fun.0.args.iter().map(|a| a.ty).collect_vec(); + + assert!(param_tys.iter().zip(&args_ty) + .all(|(expected, actual)| ctx.unifier.unioned(*expected, *actual))); + + let args_val = args_ty.iter().zip_eq(args.iter()) + .map(|(ty, arg)| { + arg.1.clone() + .to_basic_value_enum(ctx, generator, ty.clone()) + .unwrap() + }) + .map_into::() + .collect_vec(); + + let intrinsic_fn = ctx.module.get_function(extern_fn).unwrap_or_else(|| { + let ret_llvm_ty = ctx.get_llvm_abi_type(generator, ret_ty.clone()); + let param_llvm_ty = param_tys.iter() + .map(|p| ctx.get_llvm_abi_type(generator, *p)) + .map_into::() + .collect_vec(); + let fn_type = ret_llvm_ty.fn_type(param_llvm_ty.as_slice(), false); + let func = ctx.module.add_function(extern_fn, fn_type, None); + func.add_attribute( + AttributeLoc::Function, + ctx.ctx.create_enum_attribute(Attribute::get_named_enum_kind_id("nounwind"), 0) + ); + + for attr in attrs { + func.add_attribute( + AttributeLoc::Function, + ctx.ctx.create_enum_attribute(Attribute::get_named_enum_kind_id(attr), 0) + ); + } + + func + }); + + let call = ctx.builder + .build_call(intrinsic_fn, &args_val, name); + + let val = call.try_as_basic_value() + .left() + .unwrap(); + Ok(val.into()) + }), + ) +} + pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo { let int32 = primitives.0.int32; let int64 = primitives.0.int64; @@ -981,7 +1163,7 @@ pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo { signature: primitives.1.add_ty(TypeEnum::TFunc(FunSignature { args: vec![FuncArg { name: "n".into(), ty: num_ty.0, default_value: None }], ret: num_ty.0, - vars: var_map, + vars: var_map.clone(), })), var_id: Default::default(), instance_to_symbol: Default::default(),