1
0
forked from M-Labs/nac3

core: Add np_{round,floor,ceil}

These functions are NumPy variants of round/floor/ceil, which returns
floats instead of ints.
This commit is contained in:
David Mak 2023-11-23 13:45:07 +08:00
parent 0af1e37e99
commit 5c5620692f
3 changed files with 102 additions and 0 deletions

View File

@ -881,6 +881,33 @@ pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo {
Ok(Some(val_toint.into())) Ok(Some(val_toint.into()))
}), }),
), ),
create_fn_by_codegen(
primitives,
&var_map,
"np_round",
float,
&[(float, "n")],
Box::new(|ctx, _, _, args, generator| {
let llvm_f64 = ctx.ctx.f64_type();
let arg = args[0].1.clone()
.to_basic_value_enum(ctx, generator, ctx.primitives.float)?;
let intrinsic_fn = ctx.module.get_function("llvm.roundeven.f64").unwrap_or_else(|| {
let fn_type = llvm_f64.fn_type(&[llvm_f64.into()], false);
ctx.module.add_function("llvm.roundeven.f64", fn_type, None)
});
let val = ctx
.builder
.build_call(intrinsic_fn, &[arg.into()], "")
.try_as_basic_value()
.left()
.unwrap();
Ok(Some(val.into()))
}),
),
Arc::new(RwLock::new(TopLevelDef::Function { Arc::new(RwLock::new(TopLevelDef::Function {
name: "range".into(), name: "range".into(),
simple_name: "range".into(), simple_name: "range".into(),
@ -1123,6 +1150,33 @@ pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo {
Ok(Some(val_toint.into())) Ok(Some(val_toint.into()))
}), }),
), ),
create_fn_by_codegen(
primitives,
&var_map,
"np_floor",
float,
&[(float, "n")],
Box::new(|ctx, _, _, args, generator| {
let llvm_f64 = ctx.ctx.f64_type();
let arg = args[0].1.clone()
.to_basic_value_enum(ctx, generator, ctx.primitives.float)?;
let intrinsic_fn = ctx.module.get_function("llvm.floor.f64").unwrap_or_else(|| {
let fn_type = llvm_f64.fn_type(&[llvm_f64.into()], false);
ctx.module.add_function("llvm.floor.f64", fn_type, None)
});
let val = ctx
.builder
.build_call(intrinsic_fn, &[arg.into()], "")
.try_as_basic_value()
.left()
.unwrap();
Ok(Some(val.into()))
}),
),
create_fn_by_codegen( create_fn_by_codegen(
primitives, primitives,
&var_map, &var_map,
@ -1183,6 +1237,33 @@ pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo {
Ok(Some(val_toint.into())) Ok(Some(val_toint.into()))
}), }),
), ),
create_fn_by_codegen(
primitives,
&var_map,
"np_ceil",
float,
&[(float, "n")],
Box::new(|ctx, _, _, args, generator| {
let llvm_f64 = ctx.ctx.f64_type();
let arg = args[0].1.clone()
.to_basic_value_enum(ctx, generator, ctx.primitives.float)?;
let intrinsic_fn = ctx.module.get_function("llvm.ceil.f64").unwrap_or_else(|| {
let fn_type = llvm_f64.fn_type(&[llvm_f64.into()], false);
ctx.module.add_function("llvm.ceil.f64", fn_type, None)
});
let val = ctx
.builder
.build_call(intrinsic_fn, &[arg.into()], "")
.try_as_basic_value()
.left()
.unwrap();
Ok(Some(val.into()))
}),
),
Arc::new(RwLock::new({ Arc::new(RwLock::new({
let list_var = primitives.1.get_fresh_var(Some("L".into()), None); let list_var = primitives.1.get_fresh_var(Some("L".into()), None);
let list = primitives.1.add_ty(TypeEnum::TList { ty: list_var.0 }); let list = primitives.1.add_ty(TypeEnum::TList { ty: list_var.0 });
@ -1835,13 +1916,16 @@ pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo {
"float", "float",
"round", "round",
"round64", "round64",
"np_round",
"range", "range",
"str", "str",
"bool", "bool",
"floor", "floor",
"floor64", "floor64",
"np_floor",
"ceil", "ceil",
"ceil64", "ceil64",
"np_ceil",
"len", "len",
"min", "min",
"max", "max",

View File

@ -108,10 +108,13 @@ def patch(module):
# Builtin Math functions # Builtin Math functions
module.round = round_away_zero module.round = round_away_zero
module.round64 = round_away_zero module.round64 = round_away_zero
module.np_round = np.round
module.floor = math.floor module.floor = math.floor
module.floor64 = math.floor module.floor64 = math.floor
module.np_floor = np.floor
module.ceil = math.ceil module.ceil = math.ceil
module.ceil64 = math.ceil module.ceil64 = math.ceil
module.np_ceil = np.ceil
# NumPy Math functions # NumPy Math functions
module.np_isnan = np.isnan module.np_isnan = np.isnan

View File

@ -36,6 +36,10 @@ def test_round64():
for x in [-1.5, -0.5, 0.5, 1.5]: for x in [-1.5, -0.5, 0.5, 1.5]:
output_int64(round64(x)) output_int64(round64(x))
def test_np_round():
for x in [-1.5, -0.5, 0.5, 1.5, dbl_inf(), -dbl_inf(), dbl_nan()]:
output_float64(np_round(x))
def test_np_isnan(): def test_np_isnan():
for x in [dbl_nan(), 0.0, dbl_inf()]: for x in [dbl_nan(), 0.0, dbl_inf()]:
output_bool(np_isnan(x)) output_bool(np_isnan(x))
@ -87,6 +91,10 @@ def test_floor64():
for x in [-1.5, -0.5, 0.5, 1.5]: for x in [-1.5, -0.5, 0.5, 1.5]:
output_int64(floor64(x)) output_int64(floor64(x))
def test_np_floor():
for x in [-1.5, -0.5, 0.5, 1.5, dbl_inf(), -dbl_inf(), dbl_nan()]:
output_float64(np_floor(x))
def test_ceil(): def test_ceil():
for x in [-1.5, -0.5, 0.5, 1.5]: for x in [-1.5, -0.5, 0.5, 1.5]:
output_int32(ceil(x)) output_int32(ceil(x))
@ -95,6 +103,10 @@ def test_ceil64():
for x in [-1.5, -0.5, 0.5, 1.5]: for x in [-1.5, -0.5, 0.5, 1.5]:
output_int64(ceil64(x)) output_int64(ceil64(x))
def test_np_ceil():
for x in [-1.5, -0.5, 0.5, 1.5, dbl_inf(), -dbl_inf(), dbl_nan()]:
output_float64(np_ceil(x))
def test_np_sqrt(): def test_np_sqrt():
for x in [1.0, 2.0, 4.0, dbl_inf(), -dbl_inf(), dbl_nan()]: for x in [1.0, 2.0, 4.0, dbl_inf(), -dbl_inf(), dbl_nan()]:
output_float64(np_sqrt(x)) output_float64(np_sqrt(x))
@ -214,6 +226,7 @@ def test_np_nextafter():
def run() -> int32: def run() -> int32:
test_round() test_round()
test_round64() test_round64()
test_np_round()
test_np_isnan() test_np_isnan()
test_np_isinf() test_np_isinf()
test_np_sin() test_np_sin()
@ -226,8 +239,10 @@ def run() -> int32:
test_np_fabs() test_np_fabs()
test_floor() test_floor()
test_floor64() test_floor64()
test_np_floor()
test_ceil() test_ceil()
test_ceil64() test_ceil64()
test_np_ceil()
test_np_sqrt() test_np_sqrt()
test_np_rint() test_np_rint()
test_np_tan() test_np_tan()