1
0
forked from M-Labs/nac3

nac3core: max builtin function

This commit is contained in:
ychenfo 2022-03-08 22:22:00 +08:00
parent e9a17cf8f8
commit 8241a29908

View File

@ -925,7 +925,7 @@ pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo {
FuncArg { name: "n".into(), ty: num_ty.0, default_value: None }, FuncArg { name: "n".into(), ty: num_ty.0, default_value: None },
], ],
ret: num_ty.0, ret: num_ty.0,
vars: var_map, vars: var_map.clone(),
})), })),
var_id: Default::default(), var_id: Default::default(),
instance_to_symbol: Default::default(), instance_to_symbol: Default::default(),
@ -978,6 +978,68 @@ pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo {
)))), )))),
loc: None, loc: None,
})), })),
Arc::new(RwLock::new(TopLevelDef::Function {
name: "max".into(),
simple_name: "max".into(),
signature: primitives.1.add_ty(TypeEnum::TFunc(FunSignature {
args: vec![
FuncArg { name: "m".into(), ty: num_ty.0, default_value: None },
FuncArg { name: "n".into(), ty: num_ty.0, default_value: None },
],
ret: num_ty.0,
vars: var_map,
})),
var_id: Default::default(),
instance_to_symbol: Default::default(),
instance_to_stmt: Default::default(),
resolver: None,
codegen_callback: Some(Arc::new(GenCall::new(Box::new(
|ctx, _, fun, args, generator| {
let boolean = ctx.primitives.bool;
let int32 = ctx.primitives.int32;
let int64 = ctx.primitives.int64;
let uint32 = ctx.primitives.uint32;
let uint64 = ctx.primitives.uint64;
let float = ctx.primitives.float;
let llvm_i1 = ctx.ctx.bool_type().as_basic_type_enum();
let llvm_i32 = ctx.ctx.i32_type().as_basic_type_enum();
let llvm_i64 = ctx.ctx.i64_type().as_basic_type_enum();
let llvm_f64 = ctx.ctx.f64_type().as_basic_type_enum();
let m_ty = fun.0.args[0].ty;
let n_ty = fun.0.args[1].ty;
let m_val = args[0].1.clone().to_basic_value_enum(ctx, generator)?;
let n_val = args[1].1.clone().to_basic_value_enum(ctx, generator)?;
let mut is_type = |a: Type, b: Type| ctx.unifier.unioned(a, b);
let (fun_name, arg_ty) = if is_type(m_ty, n_ty) && is_type(n_ty, boolean) {
("llvm.umax.i1", llvm_i1)
} else if is_type(m_ty, n_ty) && is_type(n_ty, int32) {
("llvm.smax.i32", llvm_i32)
} else if is_type(m_ty, n_ty) && is_type(n_ty, int64) {
("llvm.smax.i64", llvm_i64)
} else if is_type(m_ty, n_ty) && is_type(n_ty, uint32) {
("llvm.umax.i32", llvm_i32)
} else if is_type(m_ty, n_ty) && is_type(n_ty, uint64) {
("llvm.umax.i64", llvm_i64)
} else if is_type(m_ty, n_ty) && is_type(n_ty, float) {
("llvm.maxnum.f64", llvm_f64)
} else {
unreachable!();
};
let intrinsic = ctx.module.get_function(fun_name).unwrap_or_else(|| {
let fn_type = arg_ty.fn_type(&[arg_ty.into(), arg_ty.into()], false);
ctx.module.add_function(fun_name, fn_type, None)
});
let val = ctx
.builder
.build_call(intrinsic, &[m_val.into(), n_val.into()], "max")
.try_as_basic_value()
.left()
.unwrap();
Ok(val.into())
},
)))),
loc: None,
})),
]; ];
let ast_list: Vec<Option<ast::Stmt<()>>> = let ast_list: Vec<Option<ast::Stmt<()>>> =
(0..top_level_def_list.len()).map(|_| None).collect(); (0..top_level_def_list.len()).map(|_| None).collect();
@ -1002,6 +1064,7 @@ pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo {
"ceil64", "ceil64",
"len", "len",
"min", "min",
"max",
], ],
) )
} }