forked from M-Labs/nac3
core: Revert breaking changes to round-family functions
These functions should return ints as the math.* functions do instead of following the convention of numpy.* functions.
This commit is contained in:
parent
2e055e8ab1
commit
4dbe07a0c0
|
@ -814,6 +814,66 @@ pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo {
|
||||||
)))),
|
)))),
|
||||||
loc: None,
|
loc: None,
|
||||||
})),
|
})),
|
||||||
|
create_fn_by_codegen(
|
||||||
|
primitives,
|
||||||
|
&var_map,
|
||||||
|
"round",
|
||||||
|
int32,
|
||||||
|
&[(float, "n")],
|
||||||
|
Box::new(|ctx, _, _, args, generator| {
|
||||||
|
let llvm_f64 = ctx.ctx.f64_type();
|
||||||
|
let llvm_i32 = ctx.ctx.i32_type();
|
||||||
|
|
||||||
|
let arg = args[0].1.clone()
|
||||||
|
.to_basic_value_enum(ctx, generator, ctx.primitives.float)?;
|
||||||
|
|
||||||
|
let intrinsic_fn = ctx.module.get_function("llvm.round.f64").unwrap_or_else(|| {
|
||||||
|
let fn_type = llvm_f64.fn_type(&[llvm_f64.into()], false);
|
||||||
|
|
||||||
|
ctx.module.add_function("llvm.round.f64", fn_type, None)
|
||||||
|
});
|
||||||
|
|
||||||
|
let val = ctx
|
||||||
|
.builder
|
||||||
|
.build_call(intrinsic_fn, &[arg.into()], "")
|
||||||
|
.try_as_basic_value()
|
||||||
|
.left()
|
||||||
|
.unwrap();
|
||||||
|
let val_toint = ctx.builder
|
||||||
|
.build_float_to_signed_int(val.into_float_value(), llvm_i32, "round");
|
||||||
|
Ok(Some(val_toint.into()))
|
||||||
|
}),
|
||||||
|
),
|
||||||
|
create_fn_by_codegen(
|
||||||
|
primitives,
|
||||||
|
&var_map,
|
||||||
|
"round64",
|
||||||
|
int64,
|
||||||
|
&[(float, "n")],
|
||||||
|
Box::new(|ctx, _, _, args, generator| {
|
||||||
|
let llvm_f64 = ctx.ctx.f64_type();
|
||||||
|
let llvm_i64 = ctx.ctx.i64_type();
|
||||||
|
|
||||||
|
let arg = args[0].1.clone()
|
||||||
|
.to_basic_value_enum(ctx, generator, ctx.primitives.float)?;
|
||||||
|
|
||||||
|
let intrinsic_fn = ctx.module.get_function("llvm.round.f64").unwrap_or_else(|| {
|
||||||
|
let fn_type = llvm_f64.fn_type(&[llvm_f64.into()], false);
|
||||||
|
|
||||||
|
ctx.module.add_function("llvm.round.f64", fn_type, None)
|
||||||
|
});
|
||||||
|
|
||||||
|
let val = ctx
|
||||||
|
.builder
|
||||||
|
.build_call(intrinsic_fn, &[arg.into()], "")
|
||||||
|
.try_as_basic_value()
|
||||||
|
.left()
|
||||||
|
.unwrap();
|
||||||
|
let val_toint = ctx.builder
|
||||||
|
.build_float_to_signed_int(val.into_float_value(), llvm_i64, "round");
|
||||||
|
Ok(Some(val_toint.into()))
|
||||||
|
}),
|
||||||
|
),
|
||||||
Arc::new(RwLock::new(TopLevelDef::Function {
|
Arc::new(RwLock::new(TopLevelDef::Function {
|
||||||
name: "range".into(),
|
name: "range".into(),
|
||||||
simple_name: "range".into(),
|
simple_name: "range".into(),
|
||||||
|
@ -996,21 +1056,125 @@ pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo {
|
||||||
)))),
|
)))),
|
||||||
loc: None,
|
loc: None,
|
||||||
})),
|
})),
|
||||||
create_fn_by_intrinsic(
|
create_fn_by_codegen(
|
||||||
primitives,
|
primitives,
|
||||||
&var_map,
|
&var_map,
|
||||||
"floor",
|
"floor",
|
||||||
float,
|
int32,
|
||||||
&[(float, "x")],
|
&[(float, "n")],
|
||||||
"llvm.floor.f64",
|
Box::new(|ctx, _, _, args, generator| {
|
||||||
|
let llvm_f64 = ctx.ctx.f64_type();
|
||||||
|
let llvm_i32 = ctx.ctx.i32_type();
|
||||||
|
|
||||||
|
let arg = args[0].1.clone()
|
||||||
|
.to_basic_value_enum(ctx, generator, ctx.primitives.float)?;
|
||||||
|
|
||||||
|
let intrinsic_fn = ctx.module.get_function("llvm.floor.f64").unwrap_or_else(|| {
|
||||||
|
let fn_type = llvm_f64.fn_type(&[llvm_f64.into()], false);
|
||||||
|
|
||||||
|
ctx.module.add_function("llvm.floor.f64", fn_type, None)
|
||||||
|
});
|
||||||
|
|
||||||
|
let val = ctx
|
||||||
|
.builder
|
||||||
|
.build_call(intrinsic_fn, &[arg.into()], "")
|
||||||
|
.try_as_basic_value()
|
||||||
|
.left()
|
||||||
|
.unwrap();
|
||||||
|
let val_toint = ctx.builder
|
||||||
|
.build_float_to_signed_int(val.into_float_value(), llvm_i32, "floor");
|
||||||
|
Ok(Some(val_toint.into()))
|
||||||
|
}),
|
||||||
),
|
),
|
||||||
create_fn_by_intrinsic(
|
create_fn_by_codegen(
|
||||||
|
primitives,
|
||||||
|
&var_map,
|
||||||
|
"floor64",
|
||||||
|
int64,
|
||||||
|
&[(float, "n")],
|
||||||
|
Box::new(|ctx, _, _, args, generator| {
|
||||||
|
let llvm_f64 = ctx.ctx.f64_type();
|
||||||
|
let llvm_i64 = ctx.ctx.i64_type();
|
||||||
|
|
||||||
|
let arg = args[0].1.clone()
|
||||||
|
.to_basic_value_enum(ctx, generator, ctx.primitives.float)?;
|
||||||
|
|
||||||
|
let intrinsic_fn = ctx.module.get_function("llvm.floor.f64").unwrap_or_else(|| {
|
||||||
|
let fn_type = llvm_f64.fn_type(&[llvm_f64.into()], false);
|
||||||
|
|
||||||
|
ctx.module.add_function("llvm.floor.f64", fn_type, None)
|
||||||
|
});
|
||||||
|
|
||||||
|
let val = ctx
|
||||||
|
.builder
|
||||||
|
.build_call(intrinsic_fn, &[arg.into()], "")
|
||||||
|
.try_as_basic_value()
|
||||||
|
.left()
|
||||||
|
.unwrap();
|
||||||
|
let val_toint = ctx.builder
|
||||||
|
.build_float_to_signed_int(val.into_float_value(), llvm_i64, "floor");
|
||||||
|
Ok(Some(val_toint.into()))
|
||||||
|
}),
|
||||||
|
),
|
||||||
|
create_fn_by_codegen(
|
||||||
primitives,
|
primitives,
|
||||||
&var_map,
|
&var_map,
|
||||||
"ceil",
|
"ceil",
|
||||||
float,
|
int32,
|
||||||
&[(float, "x")],
|
&[(float, "n")],
|
||||||
"llvm.ceil.f64",
|
Box::new(|ctx, _, _, args, generator| {
|
||||||
|
let llvm_f64 = ctx.ctx.f64_type();
|
||||||
|
let llvm_i32 = ctx.ctx.i32_type();
|
||||||
|
|
||||||
|
let arg = args[0].1.clone()
|
||||||
|
.to_basic_value_enum(ctx, generator, ctx.primitives.float)?;
|
||||||
|
|
||||||
|
let intrinsic_fn = ctx.module.get_function("llvm.ceil.f64").unwrap_or_else(|| {
|
||||||
|
let fn_type = llvm_f64.fn_type(&[llvm_f64.into()], false);
|
||||||
|
|
||||||
|
ctx.module.add_function("llvm.ceil.f64", fn_type, None)
|
||||||
|
});
|
||||||
|
|
||||||
|
let val = ctx
|
||||||
|
.builder
|
||||||
|
.build_call(intrinsic_fn, &[arg.into()], "")
|
||||||
|
.try_as_basic_value()
|
||||||
|
.left()
|
||||||
|
.unwrap();
|
||||||
|
let val_toint = ctx.builder
|
||||||
|
.build_float_to_signed_int(val.into_float_value(), llvm_i32, "ceil");
|
||||||
|
Ok(Some(val_toint.into()))
|
||||||
|
}),
|
||||||
|
),
|
||||||
|
create_fn_by_codegen(
|
||||||
|
primitives,
|
||||||
|
&var_map,
|
||||||
|
"ceil64",
|
||||||
|
int64,
|
||||||
|
&[(float, "n")],
|
||||||
|
Box::new(|ctx, _, _, args, generator| {
|
||||||
|
let llvm_f64 = ctx.ctx.f64_type();
|
||||||
|
let llvm_i64 = ctx.ctx.i64_type();
|
||||||
|
|
||||||
|
let arg = args[0].1.clone()
|
||||||
|
.to_basic_value_enum(ctx, generator, ctx.primitives.float)?;
|
||||||
|
|
||||||
|
let intrinsic_fn = ctx.module.get_function("llvm.ceil.f64").unwrap_or_else(|| {
|
||||||
|
let fn_type = llvm_f64.fn_type(&[llvm_f64.into()], false);
|
||||||
|
|
||||||
|
ctx.module.add_function("llvm.ceil.f64", fn_type, None)
|
||||||
|
});
|
||||||
|
|
||||||
|
let val = ctx
|
||||||
|
.builder
|
||||||
|
.build_call(intrinsic_fn, &[arg.into()], "")
|
||||||
|
.try_as_basic_value()
|
||||||
|
.left()
|
||||||
|
.unwrap();
|
||||||
|
let val_toint = ctx.builder
|
||||||
|
.build_float_to_signed_int(val.into_float_value(), llvm_i64, "ceil");
|
||||||
|
Ok(Some(val_toint.into()))
|
||||||
|
}),
|
||||||
),
|
),
|
||||||
Arc::new(RwLock::new({
|
Arc::new(RwLock::new({
|
||||||
let list_var = primitives.1.get_fresh_var(Some("L".into()), None);
|
let list_var = primitives.1.get_fresh_var(Some("L".into()), None);
|
||||||
|
@ -1815,11 +1979,15 @@ pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo {
|
||||||
"uint32",
|
"uint32",
|
||||||
"uint64",
|
"uint64",
|
||||||
"float",
|
"float",
|
||||||
|
"round",
|
||||||
|
"round64",
|
||||||
"range",
|
"range",
|
||||||
"str",
|
"str",
|
||||||
"bool",
|
"bool",
|
||||||
"floor",
|
"floor",
|
||||||
|
"floor64",
|
||||||
"ceil",
|
"ceil",
|
||||||
|
"ceil64",
|
||||||
"len",
|
"len",
|
||||||
"min",
|
"min",
|
||||||
"max",
|
"max",
|
||||||
|
|
|
@ -3,6 +3,7 @@
|
||||||
import sys
|
import sys
|
||||||
import importlib.util
|
import importlib.util
|
||||||
import importlib.machinery
|
import importlib.machinery
|
||||||
|
import math
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import pathlib
|
import pathlib
|
||||||
import scipy
|
import scipy
|
||||||
|
@ -43,6 +44,12 @@ def Some(v: T) -> Option[T]:
|
||||||
|
|
||||||
none = Option(None)
|
none = Option(None)
|
||||||
|
|
||||||
|
def round_away_zero(x):
|
||||||
|
if x >= 0.0:
|
||||||
|
return math.floor(x + 0.5)
|
||||||
|
else:
|
||||||
|
return math.ceil(x - 0.5)
|
||||||
|
|
||||||
def patch(module):
|
def patch(module):
|
||||||
def dbl_nan():
|
def dbl_nan():
|
||||||
return np.nan
|
return np.nan
|
||||||
|
@ -98,6 +105,14 @@ def patch(module):
|
||||||
module.Some = Some
|
module.Some = Some
|
||||||
module.none = none
|
module.none = none
|
||||||
|
|
||||||
|
# Builtin Math functions
|
||||||
|
module.round = round_away_zero
|
||||||
|
module.round64 = round_away_zero
|
||||||
|
module.floor = math.floor
|
||||||
|
module.floor64 = math.floor
|
||||||
|
module.ceil = math.ceil
|
||||||
|
module.ceil64 = math.ceil
|
||||||
|
|
||||||
# NumPy Math functions
|
# NumPy Math functions
|
||||||
module.isnan = np.isnan
|
module.isnan = np.isnan
|
||||||
module.isinf = np.isinf
|
module.isinf = np.isinf
|
||||||
|
@ -109,8 +124,6 @@ def patch(module):
|
||||||
module.log10 = np.log10
|
module.log10 = np.log10
|
||||||
module.log2 = np.log2
|
module.log2 = np.log2
|
||||||
module.fabs = np.fabs
|
module.fabs = np.fabs
|
||||||
module.floor = np.floor
|
|
||||||
module.ceil = np.ceil
|
|
||||||
module.trunc = np.trunc
|
module.trunc = np.trunc
|
||||||
module.sqrt = np.sqrt
|
module.sqrt = np.sqrt
|
||||||
module.rint = np.rint
|
module.rint = np.rint
|
||||||
|
|
|
@ -2,6 +2,14 @@
|
||||||
def output_bool(x: bool):
|
def output_bool(x: bool):
|
||||||
...
|
...
|
||||||
|
|
||||||
|
@extern
|
||||||
|
def output_int32(x: int32):
|
||||||
|
...
|
||||||
|
|
||||||
|
@extern
|
||||||
|
def output_int64(x: int64):
|
||||||
|
...
|
||||||
|
|
||||||
@extern
|
@extern
|
||||||
def output_float64(x: float):
|
def output_float64(x: float):
|
||||||
...
|
...
|
||||||
|
@ -20,6 +28,14 @@ def dbl_pi() -> float:
|
||||||
def dbl_e() -> float:
|
def dbl_e() -> float:
|
||||||
return 2.71828182845904523536028747135266249775724709369995
|
return 2.71828182845904523536028747135266249775724709369995
|
||||||
|
|
||||||
|
def test_round():
|
||||||
|
for x in [-1.5, -0.5, 0.5, 1.5]:
|
||||||
|
output_int32(round(x))
|
||||||
|
|
||||||
|
def test_round64():
|
||||||
|
for x in [-1.5, -0.5, 0.5, 1.5]:
|
||||||
|
output_int64(round64(x))
|
||||||
|
|
||||||
def test_isnan():
|
def test_isnan():
|
||||||
for x in [dbl_nan(), 0.0, dbl_inf()]:
|
for x in [dbl_nan(), 0.0, dbl_inf()]:
|
||||||
output_bool(isnan(x))
|
output_bool(isnan(x))
|
||||||
|
@ -64,12 +80,20 @@ def test_fabs():
|
||||||
output_float64(fabs(x))
|
output_float64(fabs(x))
|
||||||
|
|
||||||
def test_floor():
|
def test_floor():
|
||||||
for x in [-1.5, -0.5, 0.5, 1.5, dbl_inf(), -dbl_inf(), dbl_nan()]:
|
for x in [-1.5, -0.5, 0.5, 1.5]:
|
||||||
output_float64(floor(x))
|
output_int32(floor(x))
|
||||||
|
|
||||||
|
def test_floor64():
|
||||||
|
for x in [-1.5, -0.5, 0.5, 1.5]:
|
||||||
|
output_int64(floor64(x))
|
||||||
|
|
||||||
def test_ceil():
|
def test_ceil():
|
||||||
for x in [-1.5, -0.5, 0.5, 1.5, dbl_inf(), -dbl_inf(), dbl_nan()]:
|
for x in [-1.5, -0.5, 0.5, 1.5]:
|
||||||
output_float64(ceil(x))
|
output_int32(ceil(x))
|
||||||
|
|
||||||
|
def test_ceil64():
|
||||||
|
for x in [-1.5, -0.5, 0.5, 1.5]:
|
||||||
|
output_int64(ceil64(x))
|
||||||
|
|
||||||
def test_trunc():
|
def test_trunc():
|
||||||
for x in [-1.5, -0.5, 0.5, 1.5, dbl_inf(), -dbl_inf(), dbl_nan()]:
|
for x in [-1.5, -0.5, 0.5, 1.5, dbl_inf(), -dbl_inf(), dbl_nan()]:
|
||||||
|
@ -192,6 +216,8 @@ def test_nextafter():
|
||||||
output_float64(nextafter(x1, x2))
|
output_float64(nextafter(x1, x2))
|
||||||
|
|
||||||
def run() -> int32:
|
def run() -> int32:
|
||||||
|
test_round()
|
||||||
|
test_round64()
|
||||||
test_isnan()
|
test_isnan()
|
||||||
test_isinf()
|
test_isinf()
|
||||||
test_sin()
|
test_sin()
|
||||||
|
@ -203,7 +229,9 @@ def run() -> int32:
|
||||||
test_log2()
|
test_log2()
|
||||||
test_fabs()
|
test_fabs()
|
||||||
test_floor()
|
test_floor()
|
||||||
|
test_floor64()
|
||||||
test_ceil()
|
test_ceil()
|
||||||
|
test_ceil64()
|
||||||
test_trunc()
|
test_trunc()
|
||||||
test_sqrt()
|
test_sqrt()
|
||||||
test_rint()
|
test_rint()
|
||||||
|
|
Loading…
Reference in New Issue