forked from M-Labs/nac3
core: Add np_{round,floor,ceil}
These functions are NumPy variants of round/floor/ceil, which returns floats instead of ints.
This commit is contained in:
parent
0af1e37e99
commit
5c5620692f
|
@ -881,6 +881,33 @@ pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo {
|
||||||
Ok(Some(val_toint.into()))
|
Ok(Some(val_toint.into()))
|
||||||
}),
|
}),
|
||||||
),
|
),
|
||||||
|
create_fn_by_codegen(
|
||||||
|
primitives,
|
||||||
|
&var_map,
|
||||||
|
"np_round",
|
||||||
|
float,
|
||||||
|
&[(float, "n")],
|
||||||
|
Box::new(|ctx, _, _, args, generator| {
|
||||||
|
let llvm_f64 = ctx.ctx.f64_type();
|
||||||
|
|
||||||
|
let arg = args[0].1.clone()
|
||||||
|
.to_basic_value_enum(ctx, generator, ctx.primitives.float)?;
|
||||||
|
|
||||||
|
let intrinsic_fn = ctx.module.get_function("llvm.roundeven.f64").unwrap_or_else(|| {
|
||||||
|
let fn_type = llvm_f64.fn_type(&[llvm_f64.into()], false);
|
||||||
|
|
||||||
|
ctx.module.add_function("llvm.roundeven.f64", fn_type, None)
|
||||||
|
});
|
||||||
|
|
||||||
|
let val = ctx
|
||||||
|
.builder
|
||||||
|
.build_call(intrinsic_fn, &[arg.into()], "")
|
||||||
|
.try_as_basic_value()
|
||||||
|
.left()
|
||||||
|
.unwrap();
|
||||||
|
Ok(Some(val.into()))
|
||||||
|
}),
|
||||||
|
),
|
||||||
Arc::new(RwLock::new(TopLevelDef::Function {
|
Arc::new(RwLock::new(TopLevelDef::Function {
|
||||||
name: "range".into(),
|
name: "range".into(),
|
||||||
simple_name: "range".into(),
|
simple_name: "range".into(),
|
||||||
|
@ -1123,6 +1150,33 @@ pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo {
|
||||||
Ok(Some(val_toint.into()))
|
Ok(Some(val_toint.into()))
|
||||||
}),
|
}),
|
||||||
),
|
),
|
||||||
|
create_fn_by_codegen(
|
||||||
|
primitives,
|
||||||
|
&var_map,
|
||||||
|
"np_floor",
|
||||||
|
float,
|
||||||
|
&[(float, "n")],
|
||||||
|
Box::new(|ctx, _, _, args, generator| {
|
||||||
|
let llvm_f64 = ctx.ctx.f64_type();
|
||||||
|
|
||||||
|
let arg = args[0].1.clone()
|
||||||
|
.to_basic_value_enum(ctx, generator, ctx.primitives.float)?;
|
||||||
|
|
||||||
|
let intrinsic_fn = ctx.module.get_function("llvm.floor.f64").unwrap_or_else(|| {
|
||||||
|
let fn_type = llvm_f64.fn_type(&[llvm_f64.into()], false);
|
||||||
|
|
||||||
|
ctx.module.add_function("llvm.floor.f64", fn_type, None)
|
||||||
|
});
|
||||||
|
|
||||||
|
let val = ctx
|
||||||
|
.builder
|
||||||
|
.build_call(intrinsic_fn, &[arg.into()], "")
|
||||||
|
.try_as_basic_value()
|
||||||
|
.left()
|
||||||
|
.unwrap();
|
||||||
|
Ok(Some(val.into()))
|
||||||
|
}),
|
||||||
|
),
|
||||||
create_fn_by_codegen(
|
create_fn_by_codegen(
|
||||||
primitives,
|
primitives,
|
||||||
&var_map,
|
&var_map,
|
||||||
|
@ -1183,6 +1237,33 @@ pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo {
|
||||||
Ok(Some(val_toint.into()))
|
Ok(Some(val_toint.into()))
|
||||||
}),
|
}),
|
||||||
),
|
),
|
||||||
|
create_fn_by_codegen(
|
||||||
|
primitives,
|
||||||
|
&var_map,
|
||||||
|
"np_ceil",
|
||||||
|
float,
|
||||||
|
&[(float, "n")],
|
||||||
|
Box::new(|ctx, _, _, args, generator| {
|
||||||
|
let llvm_f64 = ctx.ctx.f64_type();
|
||||||
|
|
||||||
|
let arg = args[0].1.clone()
|
||||||
|
.to_basic_value_enum(ctx, generator, ctx.primitives.float)?;
|
||||||
|
|
||||||
|
let intrinsic_fn = ctx.module.get_function("llvm.ceil.f64").unwrap_or_else(|| {
|
||||||
|
let fn_type = llvm_f64.fn_type(&[llvm_f64.into()], false);
|
||||||
|
|
||||||
|
ctx.module.add_function("llvm.ceil.f64", fn_type, None)
|
||||||
|
});
|
||||||
|
|
||||||
|
let val = ctx
|
||||||
|
.builder
|
||||||
|
.build_call(intrinsic_fn, &[arg.into()], "")
|
||||||
|
.try_as_basic_value()
|
||||||
|
.left()
|
||||||
|
.unwrap();
|
||||||
|
Ok(Some(val.into()))
|
||||||
|
}),
|
||||||
|
),
|
||||||
Arc::new(RwLock::new({
|
Arc::new(RwLock::new({
|
||||||
let list_var = primitives.1.get_fresh_var(Some("L".into()), None);
|
let list_var = primitives.1.get_fresh_var(Some("L".into()), None);
|
||||||
let list = primitives.1.add_ty(TypeEnum::TList { ty: list_var.0 });
|
let list = primitives.1.add_ty(TypeEnum::TList { ty: list_var.0 });
|
||||||
|
@ -1835,13 +1916,16 @@ pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo {
|
||||||
"float",
|
"float",
|
||||||
"round",
|
"round",
|
||||||
"round64",
|
"round64",
|
||||||
|
"np_round",
|
||||||
"range",
|
"range",
|
||||||
"str",
|
"str",
|
||||||
"bool",
|
"bool",
|
||||||
"floor",
|
"floor",
|
||||||
"floor64",
|
"floor64",
|
||||||
|
"np_floor",
|
||||||
"ceil",
|
"ceil",
|
||||||
"ceil64",
|
"ceil64",
|
||||||
|
"np_ceil",
|
||||||
"len",
|
"len",
|
||||||
"min",
|
"min",
|
||||||
"max",
|
"max",
|
||||||
|
|
|
@ -108,10 +108,13 @@ def patch(module):
|
||||||
# Builtin Math functions
|
# Builtin Math functions
|
||||||
module.round = round_away_zero
|
module.round = round_away_zero
|
||||||
module.round64 = round_away_zero
|
module.round64 = round_away_zero
|
||||||
|
module.np_round = np.round
|
||||||
module.floor = math.floor
|
module.floor = math.floor
|
||||||
module.floor64 = math.floor
|
module.floor64 = math.floor
|
||||||
|
module.np_floor = np.floor
|
||||||
module.ceil = math.ceil
|
module.ceil = math.ceil
|
||||||
module.ceil64 = math.ceil
|
module.ceil64 = math.ceil
|
||||||
|
module.np_ceil = np.ceil
|
||||||
|
|
||||||
# NumPy Math functions
|
# NumPy Math functions
|
||||||
module.np_isnan = np.isnan
|
module.np_isnan = np.isnan
|
||||||
|
|
|
@ -36,6 +36,10 @@ def test_round64():
|
||||||
for x in [-1.5, -0.5, 0.5, 1.5]:
|
for x in [-1.5, -0.5, 0.5, 1.5]:
|
||||||
output_int64(round64(x))
|
output_int64(round64(x))
|
||||||
|
|
||||||
|
def test_np_round():
|
||||||
|
for x in [-1.5, -0.5, 0.5, 1.5, dbl_inf(), -dbl_inf(), dbl_nan()]:
|
||||||
|
output_float64(np_round(x))
|
||||||
|
|
||||||
def test_np_isnan():
|
def test_np_isnan():
|
||||||
for x in [dbl_nan(), 0.0, dbl_inf()]:
|
for x in [dbl_nan(), 0.0, dbl_inf()]:
|
||||||
output_bool(np_isnan(x))
|
output_bool(np_isnan(x))
|
||||||
|
@ -87,6 +91,10 @@ def test_floor64():
|
||||||
for x in [-1.5, -0.5, 0.5, 1.5]:
|
for x in [-1.5, -0.5, 0.5, 1.5]:
|
||||||
output_int64(floor64(x))
|
output_int64(floor64(x))
|
||||||
|
|
||||||
|
def test_np_floor():
|
||||||
|
for x in [-1.5, -0.5, 0.5, 1.5, dbl_inf(), -dbl_inf(), dbl_nan()]:
|
||||||
|
output_float64(np_floor(x))
|
||||||
|
|
||||||
def test_ceil():
|
def test_ceil():
|
||||||
for x in [-1.5, -0.5, 0.5, 1.5]:
|
for x in [-1.5, -0.5, 0.5, 1.5]:
|
||||||
output_int32(ceil(x))
|
output_int32(ceil(x))
|
||||||
|
@ -95,6 +103,10 @@ def test_ceil64():
|
||||||
for x in [-1.5, -0.5, 0.5, 1.5]:
|
for x in [-1.5, -0.5, 0.5, 1.5]:
|
||||||
output_int64(ceil64(x))
|
output_int64(ceil64(x))
|
||||||
|
|
||||||
|
def test_np_ceil():
|
||||||
|
for x in [-1.5, -0.5, 0.5, 1.5, dbl_inf(), -dbl_inf(), dbl_nan()]:
|
||||||
|
output_float64(np_ceil(x))
|
||||||
|
|
||||||
def test_np_sqrt():
|
def test_np_sqrt():
|
||||||
for x in [1.0, 2.0, 4.0, dbl_inf(), -dbl_inf(), dbl_nan()]:
|
for x in [1.0, 2.0, 4.0, dbl_inf(), -dbl_inf(), dbl_nan()]:
|
||||||
output_float64(np_sqrt(x))
|
output_float64(np_sqrt(x))
|
||||||
|
@ -214,6 +226,7 @@ def test_np_nextafter():
|
||||||
def run() -> int32:
|
def run() -> int32:
|
||||||
test_round()
|
test_round()
|
||||||
test_round64()
|
test_round64()
|
||||||
|
test_np_round()
|
||||||
test_np_isnan()
|
test_np_isnan()
|
||||||
test_np_isinf()
|
test_np_isinf()
|
||||||
test_np_sin()
|
test_np_sin()
|
||||||
|
@ -226,8 +239,10 @@ def run() -> int32:
|
||||||
test_np_fabs()
|
test_np_fabs()
|
||||||
test_floor()
|
test_floor()
|
||||||
test_floor64()
|
test_floor64()
|
||||||
|
test_np_floor()
|
||||||
test_ceil()
|
test_ceil()
|
||||||
test_ceil64()
|
test_ceil64()
|
||||||
|
test_np_ceil()
|
||||||
test_np_sqrt()
|
test_np_sqrt()
|
||||||
test_np_rint()
|
test_np_rint()
|
||||||
test_np_tan()
|
test_np_tan()
|
||||||
|
|
Loading…
Reference in New Issue