core: WIP - round works now

This commit is contained in:
David Mak 2024-04-26 19:31:15 +08:00
parent 92a40c4b6e
commit 37a29162c6
4 changed files with 165 additions and 57 deletions

View File

@ -1,5 +1,5 @@
use inkwell::{FloatPredicate, IntPredicate};
use inkwell::types::{BasicTypeEnum, IntType};
use inkwell::types::BasicTypeEnum;
use inkwell::values::{BasicValueEnum, FloatValue, IntValue};
use itertools::Itertools;
@ -406,37 +406,89 @@ pub fn call_float<'ctx, G: CodeGenerator + ?Sized>(
}
/// Invokes the `round` builtin function.
pub fn call_round<'ctx>(
pub fn call_round<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
n: (Type, FloatValue<'ctx>),
llvm_ret_ty: IntType<'ctx>,
) -> IntValue<'ctx> {
n: (Type, BasicValueEnum<'ctx>),
ret_ty: Type,
) -> Result<BasicValueEnum<'ctx>, String> {
const FN_NAME: &str = "round";
let llvm_usize = generator.get_size_type(ctx.ctx);
let (n_ty, n) = n;
let llvm_ret_ty = ctx.get_llvm_abi_type(generator, ret_ty).into_int_type();
if !ctx.unifier.unioned(n_ty, ctx.primitives.float) {
unsupported_type(ctx, FN_NAME, &[n_ty])
}
Ok(match n.get_type() {
BasicTypeEnum::FloatType(_) => {
debug_assert!(ctx.unifier.unioned(n_ty, ctx.primitives.float));
let val = llvm_intrinsics::call_float_round(ctx, n, None);
ctx.builder
.build_float_to_signed_int(val, llvm_ret_ty, FN_NAME)
.unwrap()
let val = llvm_intrinsics::call_float_round(ctx, n.into_float_value(), None);
ctx.builder
.build_float_to_signed_int(val, llvm_ret_ty, FN_NAME)
.map(Into::into)
.unwrap()
}
BasicTypeEnum::PointerType(_) if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray) => {
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, n_ty);
let ndarray = ndarray_elementwise_unaryop_impl(
generator,
ctx,
ret_ty,
None,
NDArrayValue::from_ptr_val(n.into_pointer_value(), llvm_usize, None),
|generator, ctx, val| {
call_round(generator, ctx, (elem_ty, val), ret_ty)
},
)?;
ndarray.as_ptr_value().into()
}
_ => unsupported_type(ctx, FN_NAME, &[n_ty])
})
}
/// Invokes the `np_round` builtin function.
pub fn call_numpy_round<'ctx>(
pub fn call_numpy_round<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
n: (Type, FloatValue<'ctx>),
) -> FloatValue<'ctx> {
n: (Type, BasicValueEnum<'ctx>),
) -> Result<BasicValueEnum<'ctx>, String> {
const FN_NAME: &str = "np_round";
let llvm_usize = generator.get_size_type(ctx.ctx);
let (n_ty, n) = n;
if !ctx.unifier.unioned(n_ty, ctx.primitives.float) {
unsupported_type(ctx, "np_round", &[n_ty])
}
Ok(match n.get_type() {
BasicTypeEnum::FloatType(_) => {
debug_assert!(ctx.unifier.unioned(n_ty, ctx.primitives.float));
llvm_intrinsics::call_float_roundeven(ctx, n, None)
llvm_intrinsics::call_float_roundeven(ctx, n.into_float_value(), None).into()
}
BasicTypeEnum::PointerType(_) if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray) => {
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, n_ty);
let ndarray = ndarray_elementwise_unaryop_impl(
generator,
ctx,
ctx.primitives.float,
None,
NDArrayValue::from_ptr_val(n.into_pointer_value(), llvm_usize, None),
|generator, ctx, val| {
call_numpy_round(generator, ctx, (elem_ty, val))
},
)?;
ndarray.as_ptr_value().into()
}
_ => unsupported_type(ctx, FN_NAME, &[n_ty])
})
}
/// Invokes the `bool` builtin function.

View File

@ -303,6 +303,11 @@ pub fn get_builtins(unifier: &mut Unifier, primitives: &PrimitiveStore) -> Built
None,
);
let ndarray_num_ty = make_ndarray_ty(unifier, primitives, Some(num_ty.0), None);
let float_or_ndarray_ty = unifier.get_fresh_var_with_range(
&[float, ndarray_float],
Some("T".into()),
None,
);
let num_or_ndarray_ty = unifier.get_fresh_var_with_range(
&[num_ty.0, ndarray_num_ty],
Some("T".into()),
@ -789,53 +794,88 @@ pub fn get_builtins(unifier: &mut Unifier, primitives: &PrimitiveStore) -> Built
.map(|val| Some(val.as_basic_value_enum()))
}),
),
create_fn_by_codegen(
unifier,
&var_map,
"round",
int32,
&[(float, "n")],
Box::new(|ctx, _, fun, args, generator| {
let llvm_i32 = ctx.ctx.i32_type();
{
let common_ndim = unifier.get_fresh_const_generic_var(
primitives.usize(),
Some("N".into()),
None,
);
let ndarray_int32 = make_ndarray_ty(unifier, primitives, Some(int32), Some(common_ndim.0));
let ndarray_float = make_ndarray_ty(unifier, primitives, Some(float), Some(common_ndim.0));
let arg_ty = fun.0.args[0].ty;
let arg = args[0].1.clone()
.to_basic_value_enum(ctx, generator, ctx.primitives.float)?
.into_float_value();
let p0_ty = unifier.get_fresh_var_with_range(
&[float, ndarray_float],
Some("T".into()),
None,
);
let ret_ty = unifier.get_fresh_var_with_range(
&[int32, ndarray_int32],
Some("R".into()),
None,
);
Ok(Some(builtin_fns::call_round(ctx, (arg_ty, arg), llvm_i32).into()))
}),
),
create_fn_by_codegen(
unifier,
&var_map,
"round64",
int64,
&[(float, "n")],
Box::new(|ctx, _, fun, args, generator| {
let llvm_i64 = ctx.ctx.i64_type();
create_fn_by_codegen(
unifier,
&var_map,
"round",
ret_ty.0,
&[(p0_ty.0, "n")],
Box::new(|ctx, _, fun, args, generator| {
let arg_ty = fun.0.args[0].ty;
let arg = args[0].1.clone()
.to_basic_value_enum(ctx, generator, ctx.primitives.float)?;
let arg_ty = fun.0.args[0].ty;
let arg = args[0].1.clone()
.to_basic_value_enum(ctx, generator, ctx.primitives.float)?
.into_float_value();
Ok(Some(builtin_fns::call_round(generator, ctx, (arg_ty, arg), ctx.primitives.int32)?))
}),
)
},
{
let common_ndim = unifier.get_fresh_const_generic_var(
primitives.usize(),
Some("N".into()),
None,
);
let ndarray_int64 = make_ndarray_ty(unifier, primitives, Some(int64), Some(common_ndim.0));
let ndarray_float = make_ndarray_ty(unifier, primitives, Some(float), Some(common_ndim.0));
Ok(Some(builtin_fns::call_round(ctx, (arg_ty, arg), llvm_i64).into()))
}),
),
let p0_ty = unifier.get_fresh_var_with_range(
&[float, ndarray_float],
Some("T".into()),
None,
);
let ret_ty = unifier.get_fresh_var_with_range(
&[int64, ndarray_int64],
Some("R".into()),
None,
);
create_fn_by_codegen(
unifier,
&var_map,
"round64",
ret_ty.0,
&[(p0_ty.0, "n")],
Box::new(|ctx, _, fun, args, generator| {
let arg_ty = fun.0.args[0].ty;
let arg = args[0].1.clone()
.to_basic_value_enum(ctx, generator, ctx.primitives.float)?;
Ok(Some(builtin_fns::call_round(generator, ctx, (arg_ty, arg), ctx.primitives.int64)?))
}),
)
},
create_fn_by_codegen(
unifier,
&var_map,
"np_round",
float,
&[(float, "n")],
float_or_ndarray_ty.0,
&[(float_or_ndarray_ty.0, "n")],
Box::new(|ctx, _, fun, args, generator| {
let arg_ty = fun.0.args[0].ty;
let arg = args[0].1.clone()
.to_basic_value_enum(ctx, generator, ctx.primitives.float)?
.into_float_value();
.to_basic_value_enum(ctx, generator, ctx.primitives.float)?;
Ok(Some(builtin_fns::call_numpy_round(ctx, (arg_ty, arg)).into()))
Ok(Some(builtin_fns::call_numpy_round(generator, ctx, (arg_ty, arg))?))
}),
),
Arc::new(RwLock::new(TopLevelDef::Function {

View File

@ -71,10 +71,13 @@ def _float(x):
return float(x)
def round_away_zero(x):
if x >= 0.0:
return math.floor(x + 0.5)
if isinstance(x, np.ndarray):
return np.vectorize(round_away_zero)(x)
else:
return math.ceil(x - 0.5)
if x >= 0.0:
return math.floor(x + 0.5)
else:
return math.ceil(x - 0.5)
def patch(module):
def dbl_nan():

View File

@ -718,6 +718,17 @@ def test_ndarray_bool():
output_ndarray_float_2(x)
output_ndarray_bool_2(y)
def test_ndarray_round():
x = np_identity(2)
xf32 = round(x)
xf64 = round64(x)
xff = np_round(x)
output_ndarray_float_2(x)
output_ndarray_int32_2(xf32)
output_ndarray_int64_2(xf64)
output_ndarray_float_2(xff)
def run() -> int32:
test_ndarray_ctor()
test_ndarray_empty()
@ -815,4 +826,6 @@ def run() -> int32:
test_ndarray_float()
test_ndarray_bool()
test_ndarray_round()
return 0