forked from M-Labs/nac3
core: Add create_fn_by_* functions
Used for abstracting the creation of function from different sources.
This commit is contained in:
parent
068f0d9faf
commit
7cf7634985
|
@ -5,7 +5,13 @@ use crate::{
|
|||
},
|
||||
symbol_resolver::SymbolValue,
|
||||
};
|
||||
use inkwell::{types::BasicType, FloatPredicate, IntPredicate};
|
||||
use inkwell::{
|
||||
attributes::{Attribute, AttributeLoc},
|
||||
types::{BasicType, BasicMetadataTypeEnum},
|
||||
values::BasicMetadataValueEnum,
|
||||
FloatPredicate,
|
||||
IntPredicate
|
||||
};
|
||||
|
||||
type BuiltinInfo = (Vec<(Arc<RwLock<TopLevelDef>>, Option<Stmt>)>, &'static [&'static str]);
|
||||
|
||||
|
@ -78,6 +84,182 @@ pub fn get_exn_constructor(
|
|||
(fun_def, class_def, signature, exn_type)
|
||||
}
|
||||
|
||||
/// Creates a NumPy [TopLevelDef] function by code generation.
|
||||
///
|
||||
/// * `name`: The name of the implemented NumPy function.
|
||||
/// * `ret_ty`: The return type of this function.
|
||||
/// * `param_ty`: The parameters accepted by this function, represented by a tuple of the
|
||||
/// [parameter type][Type] and the parameter symbol name.
|
||||
/// * `codegen_callback`: A lambda generating LLVM IR for the implementation of this function.
|
||||
fn create_fn_by_codegen(
|
||||
primitives: &mut (PrimitiveStore, Unifier),
|
||||
var_map: &HashMap<u32, Type>,
|
||||
name: &'static str,
|
||||
ret_ty: Type,
|
||||
param_ty: &[(Type, &'static str)],
|
||||
codegen_callback: GenCallCallback,
|
||||
) -> Arc<RwLock<TopLevelDef>> {
|
||||
Arc::new(RwLock::new(TopLevelDef::Function {
|
||||
name: name.into(),
|
||||
simple_name: name.into(),
|
||||
signature: primitives.1.add_ty(TypeEnum::TFunc(FunSignature {
|
||||
args: param_ty.iter().map(|p| FuncArg {
|
||||
name: p.1.into(),
|
||||
ty: p.0,
|
||||
default_value: None,
|
||||
}).collect(),
|
||||
ret: ret_ty.clone(),
|
||||
vars: var_map.clone(),
|
||||
})),
|
||||
var_id: Default::default(),
|
||||
instance_to_symbol: Default::default(),
|
||||
instance_to_stmt: Default::default(),
|
||||
resolver: None,
|
||||
codegen_callback: Some(Arc::new(GenCall::new(codegen_callback))),
|
||||
loc: None,
|
||||
}))
|
||||
}
|
||||
|
||||
/// Creates a NumPy [TopLevelDef] function using an LLVM intrinsic.
|
||||
///
|
||||
/// * `name`: The name of the implemented NumPy function.
|
||||
/// * `ret_ty`: The return type of this function.
|
||||
/// * `param_ty`: The parameters accepted by this function, represented by a tuple of the
|
||||
/// [parameter type][Type] and the parameter symbol name.
|
||||
/// * `intrinsic_fn`: The fully-qualified name of the LLVM intrinsic function.
|
||||
fn create_fn_by_intrinsic(
|
||||
primitives: &mut (PrimitiveStore, Unifier),
|
||||
var_map: &HashMap<u32, Type>,
|
||||
name: &'static str,
|
||||
ret_ty: Type,
|
||||
params: &[(Type, &'static str)],
|
||||
intrinsic_fn: &'static str,
|
||||
) -> Arc<RwLock<TopLevelDef>> {
|
||||
let param_tys = params.iter()
|
||||
.map(|p| p.0)
|
||||
.collect_vec();
|
||||
|
||||
create_fn_by_codegen(
|
||||
primitives,
|
||||
var_map,
|
||||
name,
|
||||
ret_ty,
|
||||
params,
|
||||
Box::new(move |ctx, _, fun, args, generator| {
|
||||
let args_ty = fun.0.args.iter().map(|a| a.ty).collect_vec();
|
||||
|
||||
assert!(param_tys.iter().zip(&args_ty)
|
||||
.all(|(expected, actual)| ctx.unifier.unioned(*expected, *actual)));
|
||||
|
||||
let args_val = args_ty.iter().zip_eq(args.iter())
|
||||
.map(|(ty, arg)| {
|
||||
arg.1.clone()
|
||||
.to_basic_value_enum(ctx, generator, ty.clone())
|
||||
.unwrap()
|
||||
})
|
||||
.map_into::<BasicMetadataValueEnum>()
|
||||
.collect_vec();
|
||||
|
||||
let intrinsic_fn = ctx.module.get_function(intrinsic_fn).unwrap_or_else(|| {
|
||||
let ret_llvm_ty = ctx.get_llvm_abi_type(generator, ret_ty.clone());
|
||||
let param_llvm_ty = param_tys.iter()
|
||||
.map(|p| ctx.get_llvm_abi_type(generator, *p))
|
||||
.map_into::<BasicMetadataTypeEnum>()
|
||||
.collect_vec();
|
||||
let fn_type = ret_llvm_ty.fn_type(param_llvm_ty.as_slice(), false);
|
||||
|
||||
ctx.module.add_function(intrinsic_fn, fn_type, None)
|
||||
});
|
||||
|
||||
let call = ctx.builder
|
||||
.build_call(intrinsic_fn, args_val.as_slice(), name);
|
||||
|
||||
let val = call.try_as_basic_value()
|
||||
.left()
|
||||
.unwrap();
|
||||
Ok(val.into())
|
||||
}),
|
||||
)
|
||||
}
|
||||
|
||||
/// Creates a unary NumPy [TopLevelDef] function using an extern function (e.g. from `libc` or
|
||||
/// `libm`).
|
||||
///
|
||||
/// * `name`: The name of the implemented NumPy function.
|
||||
/// * `ret_ty`: The return type of this function.
|
||||
/// * `param_ty`: The parameters accepted by this function, represented by a tuple of the
|
||||
/// [parameter type][Type] and the parameter symbol name.
|
||||
/// * `extern_fn`: The fully-qualified name of the extern function used as the implementation.
|
||||
/// * `attrs`: The list of attributes to apply to this function declaration. Note that `nounwind` is
|
||||
/// already implied by the C ABI.
|
||||
fn create_fn_by_extern(
|
||||
primitives: &mut (PrimitiveStore, Unifier),
|
||||
var_map: &HashMap<u32, Type>,
|
||||
name: &'static str,
|
||||
ret_ty: Type,
|
||||
params: &[(Type, &'static str)],
|
||||
extern_fn: &'static str,
|
||||
attrs: &'static [&str],
|
||||
) -> Arc<RwLock<TopLevelDef>> {
|
||||
let param_tys = params.iter()
|
||||
.map(|p| p.0)
|
||||
.collect_vec();
|
||||
|
||||
create_fn_by_codegen(
|
||||
primitives,
|
||||
var_map,
|
||||
name,
|
||||
ret_ty,
|
||||
params,
|
||||
Box::new(move |ctx, _, fun, args, generator| {
|
||||
let args_ty = fun.0.args.iter().map(|a| a.ty).collect_vec();
|
||||
|
||||
assert!(param_tys.iter().zip(&args_ty)
|
||||
.all(|(expected, actual)| ctx.unifier.unioned(*expected, *actual)));
|
||||
|
||||
let args_val = args_ty.iter().zip_eq(args.iter())
|
||||
.map(|(ty, arg)| {
|
||||
arg.1.clone()
|
||||
.to_basic_value_enum(ctx, generator, ty.clone())
|
||||
.unwrap()
|
||||
})
|
||||
.map_into::<BasicMetadataValueEnum>()
|
||||
.collect_vec();
|
||||
|
||||
let intrinsic_fn = ctx.module.get_function(extern_fn).unwrap_or_else(|| {
|
||||
let ret_llvm_ty = ctx.get_llvm_abi_type(generator, ret_ty.clone());
|
||||
let param_llvm_ty = param_tys.iter()
|
||||
.map(|p| ctx.get_llvm_abi_type(generator, *p))
|
||||
.map_into::<BasicMetadataTypeEnum>()
|
||||
.collect_vec();
|
||||
let fn_type = ret_llvm_ty.fn_type(param_llvm_ty.as_slice(), false);
|
||||
let func = ctx.module.add_function(extern_fn, fn_type, None);
|
||||
func.add_attribute(
|
||||
AttributeLoc::Function,
|
||||
ctx.ctx.create_enum_attribute(Attribute::get_named_enum_kind_id("nounwind"), 0)
|
||||
);
|
||||
|
||||
for attr in attrs {
|
||||
func.add_attribute(
|
||||
AttributeLoc::Function,
|
||||
ctx.ctx.create_enum_attribute(Attribute::get_named_enum_kind_id(attr), 0)
|
||||
);
|
||||
}
|
||||
|
||||
func
|
||||
});
|
||||
|
||||
let call = ctx.builder
|
||||
.build_call(intrinsic_fn, &args_val, name);
|
||||
|
||||
let val = call.try_as_basic_value()
|
||||
.left()
|
||||
.unwrap();
|
||||
Ok(val.into())
|
||||
}),
|
||||
)
|
||||
}
|
||||
|
||||
pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo {
|
||||
let int32 = primitives.0.int32;
|
||||
let int64 = primitives.0.int64;
|
||||
|
@ -981,7 +1163,7 @@ pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo {
|
|||
signature: primitives.1.add_ty(TypeEnum::TFunc(FunSignature {
|
||||
args: vec![FuncArg { name: "n".into(), ty: num_ty.0, default_value: None }],
|
||||
ret: num_ty.0,
|
||||
vars: var_map,
|
||||
vars: var_map.clone(),
|
||||
})),
|
||||
var_id: Default::default(),
|
||||
instance_to_symbol: Default::default(),
|
||||
|
|
Loading…
Reference in New Issue