forked from M-Labs/nac3
1
0
Fork 0

core: Revert breaking changes to round-family functions

These functions should return ints as the math.* functions do instead of
following the convention of numpy.* functions.
This commit is contained in:
David Mak 2023-11-02 14:56:35 +08:00 committed by sb10q
parent 2e055e8ab1
commit 4dbe07a0c0
3 changed files with 223 additions and 14 deletions

View File

@ -814,6 +814,66 @@ pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo {
)))), )))),
loc: None, loc: None,
})), })),
create_fn_by_codegen(
primitives,
&var_map,
"round",
int32,
&[(float, "n")],
Box::new(|ctx, _, _, args, generator| {
let llvm_f64 = ctx.ctx.f64_type();
let llvm_i32 = ctx.ctx.i32_type();
let arg = args[0].1.clone()
.to_basic_value_enum(ctx, generator, ctx.primitives.float)?;
let intrinsic_fn = ctx.module.get_function("llvm.round.f64").unwrap_or_else(|| {
let fn_type = llvm_f64.fn_type(&[llvm_f64.into()], false);
ctx.module.add_function("llvm.round.f64", fn_type, None)
});
let val = ctx
.builder
.build_call(intrinsic_fn, &[arg.into()], "")
.try_as_basic_value()
.left()
.unwrap();
let val_toint = ctx.builder
.build_float_to_signed_int(val.into_float_value(), llvm_i32, "round");
Ok(Some(val_toint.into()))
}),
),
create_fn_by_codegen(
primitives,
&var_map,
"round64",
int64,
&[(float, "n")],
Box::new(|ctx, _, _, args, generator| {
let llvm_f64 = ctx.ctx.f64_type();
let llvm_i64 = ctx.ctx.i64_type();
let arg = args[0].1.clone()
.to_basic_value_enum(ctx, generator, ctx.primitives.float)?;
let intrinsic_fn = ctx.module.get_function("llvm.round.f64").unwrap_or_else(|| {
let fn_type = llvm_f64.fn_type(&[llvm_f64.into()], false);
ctx.module.add_function("llvm.round.f64", fn_type, None)
});
let val = ctx
.builder
.build_call(intrinsic_fn, &[arg.into()], "")
.try_as_basic_value()
.left()
.unwrap();
let val_toint = ctx.builder
.build_float_to_signed_int(val.into_float_value(), llvm_i64, "round");
Ok(Some(val_toint.into()))
}),
),
Arc::new(RwLock::new(TopLevelDef::Function { Arc::new(RwLock::new(TopLevelDef::Function {
name: "range".into(), name: "range".into(),
simple_name: "range".into(), simple_name: "range".into(),
@ -996,21 +1056,125 @@ pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo {
)))), )))),
loc: None, loc: None,
})), })),
create_fn_by_intrinsic( create_fn_by_codegen(
primitives, primitives,
&var_map, &var_map,
"floor", "floor",
float, int32,
&[(float, "x")], &[(float, "n")],
"llvm.floor.f64", Box::new(|ctx, _, _, args, generator| {
let llvm_f64 = ctx.ctx.f64_type();
let llvm_i32 = ctx.ctx.i32_type();
let arg = args[0].1.clone()
.to_basic_value_enum(ctx, generator, ctx.primitives.float)?;
let intrinsic_fn = ctx.module.get_function("llvm.floor.f64").unwrap_or_else(|| {
let fn_type = llvm_f64.fn_type(&[llvm_f64.into()], false);
ctx.module.add_function("llvm.floor.f64", fn_type, None)
});
let val = ctx
.builder
.build_call(intrinsic_fn, &[arg.into()], "")
.try_as_basic_value()
.left()
.unwrap();
let val_toint = ctx.builder
.build_float_to_signed_int(val.into_float_value(), llvm_i32, "floor");
Ok(Some(val_toint.into()))
}),
), ),
create_fn_by_intrinsic( create_fn_by_codegen(
primitives,
&var_map,
"floor64",
int64,
&[(float, "n")],
Box::new(|ctx, _, _, args, generator| {
let llvm_f64 = ctx.ctx.f64_type();
let llvm_i64 = ctx.ctx.i64_type();
let arg = args[0].1.clone()
.to_basic_value_enum(ctx, generator, ctx.primitives.float)?;
let intrinsic_fn = ctx.module.get_function("llvm.floor.f64").unwrap_or_else(|| {
let fn_type = llvm_f64.fn_type(&[llvm_f64.into()], false);
ctx.module.add_function("llvm.floor.f64", fn_type, None)
});
let val = ctx
.builder
.build_call(intrinsic_fn, &[arg.into()], "")
.try_as_basic_value()
.left()
.unwrap();
let val_toint = ctx.builder
.build_float_to_signed_int(val.into_float_value(), llvm_i64, "floor");
Ok(Some(val_toint.into()))
}),
),
create_fn_by_codegen(
primitives, primitives,
&var_map, &var_map,
"ceil", "ceil",
float, int32,
&[(float, "x")], &[(float, "n")],
"llvm.ceil.f64", Box::new(|ctx, _, _, args, generator| {
let llvm_f64 = ctx.ctx.f64_type();
let llvm_i32 = ctx.ctx.i32_type();
let arg = args[0].1.clone()
.to_basic_value_enum(ctx, generator, ctx.primitives.float)?;
let intrinsic_fn = ctx.module.get_function("llvm.ceil.f64").unwrap_or_else(|| {
let fn_type = llvm_f64.fn_type(&[llvm_f64.into()], false);
ctx.module.add_function("llvm.ceil.f64", fn_type, None)
});
let val = ctx
.builder
.build_call(intrinsic_fn, &[arg.into()], "")
.try_as_basic_value()
.left()
.unwrap();
let val_toint = ctx.builder
.build_float_to_signed_int(val.into_float_value(), llvm_i32, "ceil");
Ok(Some(val_toint.into()))
}),
),
create_fn_by_codegen(
primitives,
&var_map,
"ceil64",
int64,
&[(float, "n")],
Box::new(|ctx, _, _, args, generator| {
let llvm_f64 = ctx.ctx.f64_type();
let llvm_i64 = ctx.ctx.i64_type();
let arg = args[0].1.clone()
.to_basic_value_enum(ctx, generator, ctx.primitives.float)?;
let intrinsic_fn = ctx.module.get_function("llvm.ceil.f64").unwrap_or_else(|| {
let fn_type = llvm_f64.fn_type(&[llvm_f64.into()], false);
ctx.module.add_function("llvm.ceil.f64", fn_type, None)
});
let val = ctx
.builder
.build_call(intrinsic_fn, &[arg.into()], "")
.try_as_basic_value()
.left()
.unwrap();
let val_toint = ctx.builder
.build_float_to_signed_int(val.into_float_value(), llvm_i64, "ceil");
Ok(Some(val_toint.into()))
}),
), ),
Arc::new(RwLock::new({ Arc::new(RwLock::new({
let list_var = primitives.1.get_fresh_var(Some("L".into()), None); let list_var = primitives.1.get_fresh_var(Some("L".into()), None);
@ -1815,11 +1979,15 @@ pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo {
"uint32", "uint32",
"uint64", "uint64",
"float", "float",
"round",
"round64",
"range", "range",
"str", "str",
"bool", "bool",
"floor", "floor",
"floor64",
"ceil", "ceil",
"ceil64",
"len", "len",
"min", "min",
"max", "max",

View File

@ -3,6 +3,7 @@
import sys import sys
import importlib.util import importlib.util
import importlib.machinery import importlib.machinery
import math
import numpy as np import numpy as np
import pathlib import pathlib
import scipy import scipy
@ -43,6 +44,12 @@ def Some(v: T) -> Option[T]:
none = Option(None) none = Option(None)
def round_away_zero(x):
if x >= 0.0:
return math.floor(x + 0.5)
else:
return math.ceil(x - 0.5)
def patch(module): def patch(module):
def dbl_nan(): def dbl_nan():
return np.nan return np.nan
@ -98,6 +105,14 @@ def patch(module):
module.Some = Some module.Some = Some
module.none = none module.none = none
# Builtin Math functions
module.round = round_away_zero
module.round64 = round_away_zero
module.floor = math.floor
module.floor64 = math.floor
module.ceil = math.ceil
module.ceil64 = math.ceil
# NumPy Math functions # NumPy Math functions
module.isnan = np.isnan module.isnan = np.isnan
module.isinf = np.isinf module.isinf = np.isinf
@ -109,8 +124,6 @@ def patch(module):
module.log10 = np.log10 module.log10 = np.log10
module.log2 = np.log2 module.log2 = np.log2
module.fabs = np.fabs module.fabs = np.fabs
module.floor = np.floor
module.ceil = np.ceil
module.trunc = np.trunc module.trunc = np.trunc
module.sqrt = np.sqrt module.sqrt = np.sqrt
module.rint = np.rint module.rint = np.rint

View File

@ -2,6 +2,14 @@
def output_bool(x: bool): def output_bool(x: bool):
... ...
@extern
def output_int32(x: int32):
...
@extern
def output_int64(x: int64):
...
@extern @extern
def output_float64(x: float): def output_float64(x: float):
... ...
@ -20,6 +28,14 @@ def dbl_pi() -> float:
def dbl_e() -> float: def dbl_e() -> float:
return 2.71828182845904523536028747135266249775724709369995 return 2.71828182845904523536028747135266249775724709369995
def test_round():
for x in [-1.5, -0.5, 0.5, 1.5]:
output_int32(round(x))
def test_round64():
for x in [-1.5, -0.5, 0.5, 1.5]:
output_int64(round64(x))
def test_isnan(): def test_isnan():
for x in [dbl_nan(), 0.0, dbl_inf()]: for x in [dbl_nan(), 0.0, dbl_inf()]:
output_bool(isnan(x)) output_bool(isnan(x))
@ -64,12 +80,20 @@ def test_fabs():
output_float64(fabs(x)) output_float64(fabs(x))
def test_floor(): def test_floor():
for x in [-1.5, -0.5, 0.5, 1.5, dbl_inf(), -dbl_inf(), dbl_nan()]: for x in [-1.5, -0.5, 0.5, 1.5]:
output_float64(floor(x)) output_int32(floor(x))
def test_floor64():
for x in [-1.5, -0.5, 0.5, 1.5]:
output_int64(floor64(x))
def test_ceil(): def test_ceil():
for x in [-1.5, -0.5, 0.5, 1.5, dbl_inf(), -dbl_inf(), dbl_nan()]: for x in [-1.5, -0.5, 0.5, 1.5]:
output_float64(ceil(x)) output_int32(ceil(x))
def test_ceil64():
for x in [-1.5, -0.5, 0.5, 1.5]:
output_int64(ceil64(x))
def test_trunc(): def test_trunc():
for x in [-1.5, -0.5, 0.5, 1.5, dbl_inf(), -dbl_inf(), dbl_nan()]: for x in [-1.5, -0.5, 0.5, 1.5, dbl_inf(), -dbl_inf(), dbl_nan()]:
@ -192,6 +216,8 @@ def test_nextafter():
output_float64(nextafter(x1, x2)) output_float64(nextafter(x1, x2))
def run() -> int32: def run() -> int32:
test_round()
test_round64()
test_isnan() test_isnan()
test_isinf() test_isinf()
test_sin() test_sin()
@ -203,7 +229,9 @@ def run() -> int32:
test_log2() test_log2()
test_fabs() test_fabs()
test_floor() test_floor()
test_floor64()
test_ceil() test_ceil()
test_ceil64()
test_trunc() test_trunc()
test_sqrt() test_sqrt()
test_rint() test_rint()