forked from M-Labs/nac3
core: Add np_{round,floor,ceil}
These functions are NumPy variants of round/floor/ceil, which returns floats instead of ints.
This commit is contained in:
parent
0af1e37e99
commit
5c5620692f
|
@ -881,6 +881,33 @@ pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo {
|
|||
Ok(Some(val_toint.into()))
|
||||
}),
|
||||
),
|
||||
create_fn_by_codegen(
|
||||
primitives,
|
||||
&var_map,
|
||||
"np_round",
|
||||
float,
|
||||
&[(float, "n")],
|
||||
Box::new(|ctx, _, _, args, generator| {
|
||||
let llvm_f64 = ctx.ctx.f64_type();
|
||||
|
||||
let arg = args[0].1.clone()
|
||||
.to_basic_value_enum(ctx, generator, ctx.primitives.float)?;
|
||||
|
||||
let intrinsic_fn = ctx.module.get_function("llvm.roundeven.f64").unwrap_or_else(|| {
|
||||
let fn_type = llvm_f64.fn_type(&[llvm_f64.into()], false);
|
||||
|
||||
ctx.module.add_function("llvm.roundeven.f64", fn_type, None)
|
||||
});
|
||||
|
||||
let val = ctx
|
||||
.builder
|
||||
.build_call(intrinsic_fn, &[arg.into()], "")
|
||||
.try_as_basic_value()
|
||||
.left()
|
||||
.unwrap();
|
||||
Ok(Some(val.into()))
|
||||
}),
|
||||
),
|
||||
Arc::new(RwLock::new(TopLevelDef::Function {
|
||||
name: "range".into(),
|
||||
simple_name: "range".into(),
|
||||
|
@ -1123,6 +1150,33 @@ pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo {
|
|||
Ok(Some(val_toint.into()))
|
||||
}),
|
||||
),
|
||||
create_fn_by_codegen(
|
||||
primitives,
|
||||
&var_map,
|
||||
"np_floor",
|
||||
float,
|
||||
&[(float, "n")],
|
||||
Box::new(|ctx, _, _, args, generator| {
|
||||
let llvm_f64 = ctx.ctx.f64_type();
|
||||
|
||||
let arg = args[0].1.clone()
|
||||
.to_basic_value_enum(ctx, generator, ctx.primitives.float)?;
|
||||
|
||||
let intrinsic_fn = ctx.module.get_function("llvm.floor.f64").unwrap_or_else(|| {
|
||||
let fn_type = llvm_f64.fn_type(&[llvm_f64.into()], false);
|
||||
|
||||
ctx.module.add_function("llvm.floor.f64", fn_type, None)
|
||||
});
|
||||
|
||||
let val = ctx
|
||||
.builder
|
||||
.build_call(intrinsic_fn, &[arg.into()], "")
|
||||
.try_as_basic_value()
|
||||
.left()
|
||||
.unwrap();
|
||||
Ok(Some(val.into()))
|
||||
}),
|
||||
),
|
||||
create_fn_by_codegen(
|
||||
primitives,
|
||||
&var_map,
|
||||
|
@ -1183,6 +1237,33 @@ pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo {
|
|||
Ok(Some(val_toint.into()))
|
||||
}),
|
||||
),
|
||||
create_fn_by_codegen(
|
||||
primitives,
|
||||
&var_map,
|
||||
"np_ceil",
|
||||
float,
|
||||
&[(float, "n")],
|
||||
Box::new(|ctx, _, _, args, generator| {
|
||||
let llvm_f64 = ctx.ctx.f64_type();
|
||||
|
||||
let arg = args[0].1.clone()
|
||||
.to_basic_value_enum(ctx, generator, ctx.primitives.float)?;
|
||||
|
||||
let intrinsic_fn = ctx.module.get_function("llvm.ceil.f64").unwrap_or_else(|| {
|
||||
let fn_type = llvm_f64.fn_type(&[llvm_f64.into()], false);
|
||||
|
||||
ctx.module.add_function("llvm.ceil.f64", fn_type, None)
|
||||
});
|
||||
|
||||
let val = ctx
|
||||
.builder
|
||||
.build_call(intrinsic_fn, &[arg.into()], "")
|
||||
.try_as_basic_value()
|
||||
.left()
|
||||
.unwrap();
|
||||
Ok(Some(val.into()))
|
||||
}),
|
||||
),
|
||||
Arc::new(RwLock::new({
|
||||
let list_var = primitives.1.get_fresh_var(Some("L".into()), None);
|
||||
let list = primitives.1.add_ty(TypeEnum::TList { ty: list_var.0 });
|
||||
|
@ -1835,13 +1916,16 @@ pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo {
|
|||
"float",
|
||||
"round",
|
||||
"round64",
|
||||
"np_round",
|
||||
"range",
|
||||
"str",
|
||||
"bool",
|
||||
"floor",
|
||||
"floor64",
|
||||
"np_floor",
|
||||
"ceil",
|
||||
"ceil64",
|
||||
"np_ceil",
|
||||
"len",
|
||||
"min",
|
||||
"max",
|
||||
|
|
|
@ -108,10 +108,13 @@ def patch(module):
|
|||
# Builtin Math functions
|
||||
module.round = round_away_zero
|
||||
module.round64 = round_away_zero
|
||||
module.np_round = np.round
|
||||
module.floor = math.floor
|
||||
module.floor64 = math.floor
|
||||
module.np_floor = np.floor
|
||||
module.ceil = math.ceil
|
||||
module.ceil64 = math.ceil
|
||||
module.np_ceil = np.ceil
|
||||
|
||||
# NumPy Math functions
|
||||
module.np_isnan = np.isnan
|
||||
|
|
|
@ -36,6 +36,10 @@ def test_round64():
|
|||
for x in [-1.5, -0.5, 0.5, 1.5]:
|
||||
output_int64(round64(x))
|
||||
|
||||
def test_np_round():
|
||||
for x in [-1.5, -0.5, 0.5, 1.5, dbl_inf(), -dbl_inf(), dbl_nan()]:
|
||||
output_float64(np_round(x))
|
||||
|
||||
def test_np_isnan():
|
||||
for x in [dbl_nan(), 0.0, dbl_inf()]:
|
||||
output_bool(np_isnan(x))
|
||||
|
@ -87,6 +91,10 @@ def test_floor64():
|
|||
for x in [-1.5, -0.5, 0.5, 1.5]:
|
||||
output_int64(floor64(x))
|
||||
|
||||
def test_np_floor():
|
||||
for x in [-1.5, -0.5, 0.5, 1.5, dbl_inf(), -dbl_inf(), dbl_nan()]:
|
||||
output_float64(np_floor(x))
|
||||
|
||||
def test_ceil():
|
||||
for x in [-1.5, -0.5, 0.5, 1.5]:
|
||||
output_int32(ceil(x))
|
||||
|
@ -95,6 +103,10 @@ def test_ceil64():
|
|||
for x in [-1.5, -0.5, 0.5, 1.5]:
|
||||
output_int64(ceil64(x))
|
||||
|
||||
def test_np_ceil():
|
||||
for x in [-1.5, -0.5, 0.5, 1.5, dbl_inf(), -dbl_inf(), dbl_nan()]:
|
||||
output_float64(np_ceil(x))
|
||||
|
||||
def test_np_sqrt():
|
||||
for x in [1.0, 2.0, 4.0, dbl_inf(), -dbl_inf(), dbl_nan()]:
|
||||
output_float64(np_sqrt(x))
|
||||
|
@ -214,6 +226,7 @@ def test_np_nextafter():
|
|||
def run() -> int32:
|
||||
test_round()
|
||||
test_round64()
|
||||
test_np_round()
|
||||
test_np_isnan()
|
||||
test_np_isinf()
|
||||
test_np_sin()
|
||||
|
@ -226,8 +239,10 @@ def run() -> int32:
|
|||
test_np_fabs()
|
||||
test_floor()
|
||||
test_floor64()
|
||||
test_np_floor()
|
||||
test_ceil()
|
||||
test_ceil64()
|
||||
test_np_ceil()
|
||||
test_np_sqrt()
|
||||
test_np_rint()
|
||||
test_np_tan()
|
||||
|
|
Loading…
Reference in New Issue