forked from M-Labs/nac3
core: Add create_fn_by_* functions
Used for abstracting the creation of function from different sources.
This commit is contained in:
parent
068f0d9faf
commit
7cf7634985
|
@ -5,7 +5,13 @@ use crate::{
|
||||||
},
|
},
|
||||||
symbol_resolver::SymbolValue,
|
symbol_resolver::SymbolValue,
|
||||||
};
|
};
|
||||||
use inkwell::{types::BasicType, FloatPredicate, IntPredicate};
|
use inkwell::{
|
||||||
|
attributes::{Attribute, AttributeLoc},
|
||||||
|
types::{BasicType, BasicMetadataTypeEnum},
|
||||||
|
values::BasicMetadataValueEnum,
|
||||||
|
FloatPredicate,
|
||||||
|
IntPredicate
|
||||||
|
};
|
||||||
|
|
||||||
type BuiltinInfo = (Vec<(Arc<RwLock<TopLevelDef>>, Option<Stmt>)>, &'static [&'static str]);
|
type BuiltinInfo = (Vec<(Arc<RwLock<TopLevelDef>>, Option<Stmt>)>, &'static [&'static str]);
|
||||||
|
|
||||||
|
@ -78,6 +84,182 @@ pub fn get_exn_constructor(
|
||||||
(fun_def, class_def, signature, exn_type)
|
(fun_def, class_def, signature, exn_type)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Creates a NumPy [TopLevelDef] function by code generation.
|
||||||
|
///
|
||||||
|
/// * `name`: The name of the implemented NumPy function.
|
||||||
|
/// * `ret_ty`: The return type of this function.
|
||||||
|
/// * `param_ty`: The parameters accepted by this function, represented by a tuple of the
|
||||||
|
/// [parameter type][Type] and the parameter symbol name.
|
||||||
|
/// * `codegen_callback`: A lambda generating LLVM IR for the implementation of this function.
|
||||||
|
fn create_fn_by_codegen(
|
||||||
|
primitives: &mut (PrimitiveStore, Unifier),
|
||||||
|
var_map: &HashMap<u32, Type>,
|
||||||
|
name: &'static str,
|
||||||
|
ret_ty: Type,
|
||||||
|
param_ty: &[(Type, &'static str)],
|
||||||
|
codegen_callback: GenCallCallback,
|
||||||
|
) -> Arc<RwLock<TopLevelDef>> {
|
||||||
|
Arc::new(RwLock::new(TopLevelDef::Function {
|
||||||
|
name: name.into(),
|
||||||
|
simple_name: name.into(),
|
||||||
|
signature: primitives.1.add_ty(TypeEnum::TFunc(FunSignature {
|
||||||
|
args: param_ty.iter().map(|p| FuncArg {
|
||||||
|
name: p.1.into(),
|
||||||
|
ty: p.0,
|
||||||
|
default_value: None,
|
||||||
|
}).collect(),
|
||||||
|
ret: ret_ty.clone(),
|
||||||
|
vars: var_map.clone(),
|
||||||
|
})),
|
||||||
|
var_id: Default::default(),
|
||||||
|
instance_to_symbol: Default::default(),
|
||||||
|
instance_to_stmt: Default::default(),
|
||||||
|
resolver: None,
|
||||||
|
codegen_callback: Some(Arc::new(GenCall::new(codegen_callback))),
|
||||||
|
loc: None,
|
||||||
|
}))
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Creates a NumPy [TopLevelDef] function using an LLVM intrinsic.
|
||||||
|
///
|
||||||
|
/// * `name`: The name of the implemented NumPy function.
|
||||||
|
/// * `ret_ty`: The return type of this function.
|
||||||
|
/// * `param_ty`: The parameters accepted by this function, represented by a tuple of the
|
||||||
|
/// [parameter type][Type] and the parameter symbol name.
|
||||||
|
/// * `intrinsic_fn`: The fully-qualified name of the LLVM intrinsic function.
|
||||||
|
fn create_fn_by_intrinsic(
|
||||||
|
primitives: &mut (PrimitiveStore, Unifier),
|
||||||
|
var_map: &HashMap<u32, Type>,
|
||||||
|
name: &'static str,
|
||||||
|
ret_ty: Type,
|
||||||
|
params: &[(Type, &'static str)],
|
||||||
|
intrinsic_fn: &'static str,
|
||||||
|
) -> Arc<RwLock<TopLevelDef>> {
|
||||||
|
let param_tys = params.iter()
|
||||||
|
.map(|p| p.0)
|
||||||
|
.collect_vec();
|
||||||
|
|
||||||
|
create_fn_by_codegen(
|
||||||
|
primitives,
|
||||||
|
var_map,
|
||||||
|
name,
|
||||||
|
ret_ty,
|
||||||
|
params,
|
||||||
|
Box::new(move |ctx, _, fun, args, generator| {
|
||||||
|
let args_ty = fun.0.args.iter().map(|a| a.ty).collect_vec();
|
||||||
|
|
||||||
|
assert!(param_tys.iter().zip(&args_ty)
|
||||||
|
.all(|(expected, actual)| ctx.unifier.unioned(*expected, *actual)));
|
||||||
|
|
||||||
|
let args_val = args_ty.iter().zip_eq(args.iter())
|
||||||
|
.map(|(ty, arg)| {
|
||||||
|
arg.1.clone()
|
||||||
|
.to_basic_value_enum(ctx, generator, ty.clone())
|
||||||
|
.unwrap()
|
||||||
|
})
|
||||||
|
.map_into::<BasicMetadataValueEnum>()
|
||||||
|
.collect_vec();
|
||||||
|
|
||||||
|
let intrinsic_fn = ctx.module.get_function(intrinsic_fn).unwrap_or_else(|| {
|
||||||
|
let ret_llvm_ty = ctx.get_llvm_abi_type(generator, ret_ty.clone());
|
||||||
|
let param_llvm_ty = param_tys.iter()
|
||||||
|
.map(|p| ctx.get_llvm_abi_type(generator, *p))
|
||||||
|
.map_into::<BasicMetadataTypeEnum>()
|
||||||
|
.collect_vec();
|
||||||
|
let fn_type = ret_llvm_ty.fn_type(param_llvm_ty.as_slice(), false);
|
||||||
|
|
||||||
|
ctx.module.add_function(intrinsic_fn, fn_type, None)
|
||||||
|
});
|
||||||
|
|
||||||
|
let call = ctx.builder
|
||||||
|
.build_call(intrinsic_fn, args_val.as_slice(), name);
|
||||||
|
|
||||||
|
let val = call.try_as_basic_value()
|
||||||
|
.left()
|
||||||
|
.unwrap();
|
||||||
|
Ok(val.into())
|
||||||
|
}),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Creates a unary NumPy [TopLevelDef] function using an extern function (e.g. from `libc` or
|
||||||
|
/// `libm`).
|
||||||
|
///
|
||||||
|
/// * `name`: The name of the implemented NumPy function.
|
||||||
|
/// * `ret_ty`: The return type of this function.
|
||||||
|
/// * `param_ty`: The parameters accepted by this function, represented by a tuple of the
|
||||||
|
/// [parameter type][Type] and the parameter symbol name.
|
||||||
|
/// * `extern_fn`: The fully-qualified name of the extern function used as the implementation.
|
||||||
|
/// * `attrs`: The list of attributes to apply to this function declaration. Note that `nounwind` is
|
||||||
|
/// already implied by the C ABI.
|
||||||
|
fn create_fn_by_extern(
|
||||||
|
primitives: &mut (PrimitiveStore, Unifier),
|
||||||
|
var_map: &HashMap<u32, Type>,
|
||||||
|
name: &'static str,
|
||||||
|
ret_ty: Type,
|
||||||
|
params: &[(Type, &'static str)],
|
||||||
|
extern_fn: &'static str,
|
||||||
|
attrs: &'static [&str],
|
||||||
|
) -> Arc<RwLock<TopLevelDef>> {
|
||||||
|
let param_tys = params.iter()
|
||||||
|
.map(|p| p.0)
|
||||||
|
.collect_vec();
|
||||||
|
|
||||||
|
create_fn_by_codegen(
|
||||||
|
primitives,
|
||||||
|
var_map,
|
||||||
|
name,
|
||||||
|
ret_ty,
|
||||||
|
params,
|
||||||
|
Box::new(move |ctx, _, fun, args, generator| {
|
||||||
|
let args_ty = fun.0.args.iter().map(|a| a.ty).collect_vec();
|
||||||
|
|
||||||
|
assert!(param_tys.iter().zip(&args_ty)
|
||||||
|
.all(|(expected, actual)| ctx.unifier.unioned(*expected, *actual)));
|
||||||
|
|
||||||
|
let args_val = args_ty.iter().zip_eq(args.iter())
|
||||||
|
.map(|(ty, arg)| {
|
||||||
|
arg.1.clone()
|
||||||
|
.to_basic_value_enum(ctx, generator, ty.clone())
|
||||||
|
.unwrap()
|
||||||
|
})
|
||||||
|
.map_into::<BasicMetadataValueEnum>()
|
||||||
|
.collect_vec();
|
||||||
|
|
||||||
|
let intrinsic_fn = ctx.module.get_function(extern_fn).unwrap_or_else(|| {
|
||||||
|
let ret_llvm_ty = ctx.get_llvm_abi_type(generator, ret_ty.clone());
|
||||||
|
let param_llvm_ty = param_tys.iter()
|
||||||
|
.map(|p| ctx.get_llvm_abi_type(generator, *p))
|
||||||
|
.map_into::<BasicMetadataTypeEnum>()
|
||||||
|
.collect_vec();
|
||||||
|
let fn_type = ret_llvm_ty.fn_type(param_llvm_ty.as_slice(), false);
|
||||||
|
let func = ctx.module.add_function(extern_fn, fn_type, None);
|
||||||
|
func.add_attribute(
|
||||||
|
AttributeLoc::Function,
|
||||||
|
ctx.ctx.create_enum_attribute(Attribute::get_named_enum_kind_id("nounwind"), 0)
|
||||||
|
);
|
||||||
|
|
||||||
|
for attr in attrs {
|
||||||
|
func.add_attribute(
|
||||||
|
AttributeLoc::Function,
|
||||||
|
ctx.ctx.create_enum_attribute(Attribute::get_named_enum_kind_id(attr), 0)
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
func
|
||||||
|
});
|
||||||
|
|
||||||
|
let call = ctx.builder
|
||||||
|
.build_call(intrinsic_fn, &args_val, name);
|
||||||
|
|
||||||
|
let val = call.try_as_basic_value()
|
||||||
|
.left()
|
||||||
|
.unwrap();
|
||||||
|
Ok(val.into())
|
||||||
|
}),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo {
|
pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo {
|
||||||
let int32 = primitives.0.int32;
|
let int32 = primitives.0.int32;
|
||||||
let int64 = primitives.0.int64;
|
let int64 = primitives.0.int64;
|
||||||
|
@ -981,7 +1163,7 @@ pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo {
|
||||||
signature: primitives.1.add_ty(TypeEnum::TFunc(FunSignature {
|
signature: primitives.1.add_ty(TypeEnum::TFunc(FunSignature {
|
||||||
args: vec![FuncArg { name: "n".into(), ty: num_ty.0, default_value: None }],
|
args: vec![FuncArg { name: "n".into(), ty: num_ty.0, default_value: None }],
|
||||||
ret: num_ty.0,
|
ret: num_ty.0,
|
||||||
vars: var_map,
|
vars: var_map.clone(),
|
||||||
})),
|
})),
|
||||||
var_id: Default::default(),
|
var_id: Default::default(),
|
||||||
instance_to_symbol: Default::default(),
|
instance_to_symbol: Default::default(),
|
||||||
|
|
Loading…
Reference in New Issue