Revert removal of round and round64 #352

Merged
sb10q merged 3 commits from issue-149 into master 2023-11-04 13:35:53 +08:00
3 changed files with 224 additions and 65 deletions

View File

@ -737,6 +737,66 @@ pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo {
)))), )))),
loc: None, loc: None,
})), })),
create_fn_by_codegen(
primitives,
&var_map,
"round",
int32,
&[(float, "n")],
Box::new(|ctx, _, _, args, generator| {
let llvm_f64 = ctx.ctx.f64_type();
let llvm_i32 = ctx.ctx.i32_type();
let arg = args[0].1.clone()
.to_basic_value_enum(ctx, generator, ctx.primitives.float)?;
let intrinsic_fn = ctx.module.get_function("llvm.round.f64").unwrap_or_else(|| {
let fn_type = llvm_f64.fn_type(&[llvm_f64.into()], false);
ctx.module.add_function("llvm.round.f64", fn_type, None)
});
let val = ctx
.builder
.build_call(intrinsic_fn, &[arg.into()], "")
.try_as_basic_value()
.left()
.unwrap();
let val_toint = ctx.builder
.build_float_to_signed_int(val.into_float_value(), llvm_i32, "round");
Ok(Some(val_toint.into()))
}),
),
create_fn_by_codegen(
primitives,
&var_map,
"round64",
int64,
&[(float, "n")],
Box::new(|ctx, _, _, args, generator| {
let llvm_f64 = ctx.ctx.f64_type();
let llvm_i64 = ctx.ctx.i64_type();
let arg = args[0].1.clone()
.to_basic_value_enum(ctx, generator, ctx.primitives.float)?;
let intrinsic_fn = ctx.module.get_function("llvm.round.f64").unwrap_or_else(|| {
let fn_type = llvm_f64.fn_type(&[llvm_f64.into()], false);
ctx.module.add_function("llvm.round.f64", fn_type, None)
});
let val = ctx
.builder
.build_call(intrinsic_fn, &[arg.into()], "")
.try_as_basic_value()
.left()
.unwrap();
let val_toint = ctx.builder
.build_float_to_signed_int(val.into_float_value(), llvm_i64, "round");
Ok(Some(val_toint.into()))
}),
),
Arc::new(RwLock::new(TopLevelDef::Function { Arc::new(RwLock::new(TopLevelDef::Function {
name: "range".into(), name: "range".into(),
simple_name: "range".into(), simple_name: "range".into(),
@ -919,21 +979,125 @@ pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo {
)))), )))),
loc: None, loc: None,
})), })),
create_fn_by_intrinsic( create_fn_by_codegen(
primitives, primitives,
&var_map, &var_map,
"floor", "floor",
float, int32,
&[(float, "x")], &[(float, "n")],
"llvm.floor.f64", Box::new(|ctx, _, _, args, generator| {
let llvm_f64 = ctx.ctx.f64_type();
let llvm_i32 = ctx.ctx.i32_type();
let arg = args[0].1.clone()
.to_basic_value_enum(ctx, generator, ctx.primitives.float)?;
let intrinsic_fn = ctx.module.get_function("llvm.floor.f64").unwrap_or_else(|| {
let fn_type = llvm_f64.fn_type(&[llvm_f64.into()], false);
ctx.module.add_function("llvm.floor.f64", fn_type, None)
});
let val = ctx
.builder
.build_call(intrinsic_fn, &[arg.into()], "")
.try_as_basic_value()
.left()
.unwrap();
let val_toint = ctx.builder
.build_float_to_signed_int(val.into_float_value(), llvm_i32, "floor");
Ok(Some(val_toint.into()))
}),
), ),
create_fn_by_intrinsic( create_fn_by_codegen(
primitives,
&var_map,
"floor64",
int64,
&[(float, "n")],
Box::new(|ctx, _, _, args, generator| {
let llvm_f64 = ctx.ctx.f64_type();
let llvm_i64 = ctx.ctx.i64_type();
let arg = args[0].1.clone()
.to_basic_value_enum(ctx, generator, ctx.primitives.float)?;
let intrinsic_fn = ctx.module.get_function("llvm.floor.f64").unwrap_or_else(|| {
let fn_type = llvm_f64.fn_type(&[llvm_f64.into()], false);
ctx.module.add_function("llvm.floor.f64", fn_type, None)
});
let val = ctx
.builder
.build_call(intrinsic_fn, &[arg.into()], "")
.try_as_basic_value()
.left()
.unwrap();
let val_toint = ctx.builder
.build_float_to_signed_int(val.into_float_value(), llvm_i64, "floor");
Ok(Some(val_toint.into()))
}),
),
create_fn_by_codegen(
primitives, primitives,
&var_map, &var_map,
"ceil", "ceil",
float, int32,
&[(float, "x")], &[(float, "n")],
"llvm.ceil.f64", Box::new(|ctx, _, _, args, generator| {
let llvm_f64 = ctx.ctx.f64_type();
let llvm_i32 = ctx.ctx.i32_type();
let arg = args[0].1.clone()
.to_basic_value_enum(ctx, generator, ctx.primitives.float)?;
let intrinsic_fn = ctx.module.get_function("llvm.ceil.f64").unwrap_or_else(|| {
let fn_type = llvm_f64.fn_type(&[llvm_f64.into()], false);
ctx.module.add_function("llvm.ceil.f64", fn_type, None)
});
let val = ctx
.builder
.build_call(intrinsic_fn, &[arg.into()], "")
.try_as_basic_value()
.left()
.unwrap();
let val_toint = ctx.builder
.build_float_to_signed_int(val.into_float_value(), llvm_i32, "ceil");
Ok(Some(val_toint.into()))
}),
),
create_fn_by_codegen(
primitives,
&var_map,
"ceil64",
int64,
&[(float, "n")],
Box::new(|ctx, _, _, args, generator| {
let llvm_f64 = ctx.ctx.f64_type();
let llvm_i64 = ctx.ctx.i64_type();
let arg = args[0].1.clone()
.to_basic_value_enum(ctx, generator, ctx.primitives.float)?;
let intrinsic_fn = ctx.module.get_function("llvm.ceil.f64").unwrap_or_else(|| {
let fn_type = llvm_f64.fn_type(&[llvm_f64.into()], false);
ctx.module.add_function("llvm.ceil.f64", fn_type, None)
});
let val = ctx
.builder
.build_call(intrinsic_fn, &[arg.into()], "")
.try_as_basic_value()
.left()
.unwrap();
let val_toint = ctx.builder
.build_float_to_signed_int(val.into_float_value(), llvm_i64, "ceil");
Ok(Some(val_toint.into()))
}),
), ),
Arc::new(RwLock::new({ Arc::new(RwLock::new({
let list_var = primitives.1.get_fresh_var(Some("L".into()), None); let list_var = primitives.1.get_fresh_var(Some("L".into()), None);
@ -1284,14 +1448,6 @@ pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo {
&[(float, "x")], &[(float, "x")],
"llvm.fabs.f64", "llvm.fabs.f64",
), ),
create_fn_by_intrinsic(
primitives,
&var_map,
"trunc",
float,
&[(float, "x")],
"llvm.trunc.f64",
),
create_fn_by_intrinsic( create_fn_by_intrinsic(
primitives, primitives,
&var_map, &var_map,
@ -1300,49 +1456,13 @@ pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo {
&[(float, "x")], &[(float, "x")],
"llvm.sqrt.f64", "llvm.sqrt.f64",
), ),
create_fn_by_codegen( create_fn_by_intrinsic(
primitives, primitives,
&var_map, &var_map,
"rint", "rint",
float, float,
&[(float, "x")], &[(float, "x")],
Box::new(|ctx, _, fun, args, generator| { "llvm.roundeven.f64",
let float = ctx.primitives.float;
let llvm_f64 = ctx.ctx.f64_type();
let x_ty = fun.0.args[0].ty;
let x_val = args[0].1.clone()
.to_basic_value_enum(ctx, generator, x_ty)?;
assert!(ctx.unifier.unioned(x_ty, float));
let intrinsic_fn = ctx.module.get_function("llvm.round.f64").unwrap_or_else(|| {
let fn_type = llvm_f64.fn_type(&[llvm_f64.into()], false);
ctx.module.add_function("llvm.round.f64", fn_type, None)
});
// rint(x) == round(x * 0.5) * 2.0
// %0 = fmul f64 %x, 0.5
let x_half = ctx.builder
.build_float_mul(x_val.into_float_value(), llvm_f64.const_float(0.5), "");
// %1 = call f64 @llvm.round.f64(f64 %0)
let round = ctx.builder
.build_call(
intrinsic_fn,
&vec![x_half.into()],
"",
)
.try_as_basic_value()
.left()
.unwrap();
// %2 = fmul f64 %1, 2.0
let val = ctx.builder
.build_float_mul(round.into_float_value(), llvm_f64.const_float(2.0).into(), "rint");
Ok(Some(val.into()))
}),
), ),
create_fn_by_extern( create_fn_by_extern(
primitives, primitives,
@ -1774,11 +1894,15 @@ pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo {
"uint32", "uint32",
"uint64", "uint64",
"float", "float",
"round",
"round64",
"range", "range",
"str", "str",
"bool", "bool",
"floor", "floor",
"floor64",
"ceil", "ceil",
"ceil64",
"len", "len",
"min", "min",
"max", "max",
@ -1793,7 +1917,6 @@ pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo {
"log10", "log10",
"log2", "log2",
"fabs", "fabs",
"trunc",
"sqrt", "sqrt",
"rint", "rint",
"tan", "tan",

View File

@ -3,6 +3,7 @@
import sys import sys
import importlib.util import importlib.util
import importlib.machinery import importlib.machinery
import math
import numpy as np import numpy as np
import pathlib import pathlib
import scipy import scipy
@ -43,6 +44,12 @@ def Some(v: T) -> Option[T]:
none = Option(None) none = Option(None)
def round_away_zero(x):
if x >= 0.0:
return math.floor(x + 0.5)
else:
return math.ceil(x - 0.5)
def patch(module): def patch(module):
def dbl_nan(): def dbl_nan():
return np.nan return np.nan
@ -98,6 +105,14 @@ def patch(module):
module.Some = Some module.Some = Some
module.none = none module.none = none
# Builtin Math functions
module.round = round_away_zero
module.round64 = round_away_zero
module.floor = math.floor
module.floor64 = math.floor
module.ceil = math.ceil
module.ceil64 = math.ceil
# NumPy Math functions # NumPy Math functions
module.isnan = np.isnan module.isnan = np.isnan
module.isinf = np.isinf module.isinf = np.isinf
@ -109,8 +124,6 @@ def patch(module):
module.log10 = np.log10 module.log10 = np.log10
module.log2 = np.log2 module.log2 = np.log2
module.fabs = np.fabs module.fabs = np.fabs
module.floor = np.floor
module.ceil = np.ceil
module.trunc = np.trunc module.trunc = np.trunc
module.sqrt = np.sqrt module.sqrt = np.sqrt
module.rint = np.rint module.rint = np.rint

View File

@ -2,6 +2,14 @@
def output_bool(x: bool): def output_bool(x: bool):
... ...
@extern
def output_int32(x: int32):
...
@extern
def output_int64(x: int64):
...
@extern @extern
def output_float64(x: float): def output_float64(x: float):
... ...
@ -20,6 +28,14 @@ def dbl_pi() -> float:
def dbl_e() -> float: def dbl_e() -> float:
return 2.71828182845904523536028747135266249775724709369995 return 2.71828182845904523536028747135266249775724709369995
def test_round():
for x in [-1.5, -0.5, 0.5, 1.5]:
output_int32(round(x))
def test_round64():
for x in [-1.5, -0.5, 0.5, 1.5]:
output_int64(round64(x))
def test_isnan(): def test_isnan():
for x in [dbl_nan(), 0.0, dbl_inf()]: for x in [dbl_nan(), 0.0, dbl_inf()]:
output_bool(isnan(x)) output_bool(isnan(x))
@ -64,16 +80,20 @@ def test_fabs():
output_float64(fabs(x)) output_float64(fabs(x))
def test_floor(): def test_floor():
for x in [-1.5, -0.5, 0.5, 1.5, dbl_inf(), -dbl_inf(), dbl_nan()]: for x in [-1.5, -0.5, 0.5, 1.5]:
output_float64(floor(x)) output_int32(floor(x))
def test_floor64():
for x in [-1.5, -0.5, 0.5, 1.5]:
output_int64(floor64(x))
def test_ceil(): def test_ceil():
for x in [-1.5, -0.5, 0.5, 1.5, dbl_inf(), -dbl_inf(), dbl_nan()]: for x in [-1.5, -0.5, 0.5, 1.5]:
output_float64(ceil(x)) output_int32(ceil(x))
def test_trunc(): def test_ceil64():
for x in [-1.5, -0.5, 0.5, 1.5, dbl_inf(), -dbl_inf(), dbl_nan()]: for x in [-1.5, -0.5, 0.5, 1.5]:
output_float64(trunc(x)) output_int64(ceil64(x))
def test_sqrt(): def test_sqrt():
for x in [1.0, 2.0, 4.0, dbl_inf(), -dbl_inf(), dbl_nan()]: for x in [1.0, 2.0, 4.0, dbl_inf(), -dbl_inf(), dbl_nan()]:
@ -192,6 +212,8 @@ def test_nextafter():
output_float64(nextafter(x1, x2)) output_float64(nextafter(x1, x2))
def run() -> int32: def run() -> int32:
test_round()
test_round64()
test_isnan() test_isnan()
test_isinf() test_isinf()
test_sin() test_sin()
@ -203,8 +225,9 @@ def run() -> int32:
test_log2() test_log2()
test_fabs() test_fabs()
test_floor() test_floor()
test_floor64()
test_ceil() test_ceil()
test_trunc() test_ceil64()
test_sqrt() test_sqrt()
test_rint() test_rint()
test_tan() test_tan()