core: WIP - round works now
parent
fcb6234bbc
commit
12bdf6f77c
|
@ -1,5 +1,5 @@
|
|||
use inkwell::{FloatPredicate, IntPredicate};
|
||||
use inkwell::types::{BasicTypeEnum, IntType};
|
||||
use inkwell::types::BasicTypeEnum;
|
||||
use inkwell::values::{BasicValueEnum, FloatValue, IntValue};
|
||||
use itertools::Itertools;
|
||||
|
||||
|
@ -446,23 +446,58 @@ pub fn call_float<'ctx, G: CodeGenerator + ?Sized>(
|
|||
}
|
||||
|
||||
/// Invokes the `round` builtin function.
|
||||
pub fn call_round<'ctx>(
|
||||
pub fn call_round<'ctx, G: CodeGenerator + ?Sized>(
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
n: (Type, FloatValue<'ctx>),
|
||||
llvm_ret_ty: IntType<'ctx>,
|
||||
) -> IntValue<'ctx> {
|
||||
n: (Type, BasicValueEnum<'ctx>),
|
||||
ret_ty: Type,
|
||||
) -> Result<BasicValueEnum<'ctx>, String> {
|
||||
const FN_NAME: &str = "round";
|
||||
|
||||
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||
|
||||
let (n_ty, n) = n;
|
||||
let llvm_ret_ty = ctx.get_llvm_abi_type(generator, ret_ty).into_int_type();
|
||||
|
||||
if !ctx.unifier.unioned(n_ty, ctx.primitives.float) {
|
||||
unsupported_type(ctx, FN_NAME, &[n_ty])
|
||||
}
|
||||
Ok(match n.get_type() {
|
||||
BasicTypeEnum::FloatType(_) => {
|
||||
debug_assert!(ctx.unifier.unioned(n_ty, ctx.primitives.float));
|
||||
|
||||
let val = llvm_intrinsics::call_float_round(ctx, n, None);
|
||||
ctx.builder
|
||||
.build_float_to_signed_int(val, llvm_ret_ty, FN_NAME)
|
||||
.unwrap()
|
||||
let val = llvm_intrinsics::call_float_round(ctx, n.into_float_value(), None);
|
||||
ctx.builder
|
||||
.build_float_to_signed_int(val, llvm_ret_ty, FN_NAME)
|
||||
.map(Into::into)
|
||||
.unwrap()
|
||||
}
|
||||
|
||||
BasicTypeEnum::PointerType(_) if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray) => {
|
||||
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, n_ty);
|
||||
|
||||
let ndarray = ndarray_elementwise_unaryop_impl(
|
||||
generator,
|
||||
ctx,
|
||||
ret_ty,
|
||||
None,
|
||||
NDArrayValue::from_ptr_val(
|
||||
n.into_pointer_value(),
|
||||
llvm_usize,
|
||||
None,
|
||||
),
|
||||
|generator, ctx, val| {
|
||||
call_round(
|
||||
generator,
|
||||
ctx,
|
||||
(elem_ty, val),
|
||||
ret_ty,
|
||||
)
|
||||
},
|
||||
)?;
|
||||
|
||||
ndarray.as_ptr_value().into()
|
||||
}
|
||||
|
||||
_ => unsupported_type(ctx, FN_NAME, &[n_ty])
|
||||
})
|
||||
}
|
||||
|
||||
/// Invokes the `np_round` builtin function.
|
||||
|
|
|
@ -782,23 +782,35 @@ pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo {
|
|||
.map(|val| Some(val.as_basic_value_enum()))
|
||||
}),
|
||||
),
|
||||
create_fn_by_codegen(
|
||||
primitives,
|
||||
&var_map,
|
||||
"round",
|
||||
int32,
|
||||
&[(float, "n")],
|
||||
Box::new(|ctx, _, fun, args, generator| {
|
||||
let llvm_i32 = ctx.ctx.i32_type();
|
||||
{
|
||||
let common_ndim = primitives.1.get_fresh_const_generic_var(
|
||||
primitives.0.usize(),
|
||||
Some("N".into()),
|
||||
None,
|
||||
);
|
||||
let ndarray_int32 = make_ndarray_ty(&mut primitives.1, &primitives.0, Some(int32), Some(common_ndim.0));
|
||||
let ndarray_float = make_ndarray_ty(&mut primitives.1, &primitives.0, Some(float), Some(common_ndim.0));
|
||||
|
||||
let arg_ty = fun.0.args[0].ty;
|
||||
let arg = args[0].1.clone()
|
||||
.to_basic_value_enum(ctx, generator, ctx.primitives.float)?
|
||||
.into_float_value();
|
||||
let p0_ty = primitives.1
|
||||
.get_fresh_var_with_range(&[float, ndarray_float], Some("T".into()), None);
|
||||
let ret_ty = primitives.1
|
||||
.get_fresh_var_with_range(&[int32, ndarray_int32], Some("R".into()), None);
|
||||
|
||||
Ok(Some(builtin_fns::call_round(ctx, (arg_ty, arg), llvm_i32).into()))
|
||||
}),
|
||||
),
|
||||
create_fn_by_codegen(
|
||||
primitives,
|
||||
&var_map,
|
||||
"round",
|
||||
ret_ty.0,
|
||||
&[(p0_ty.0, "n")],
|
||||
Box::new(|ctx, _, fun, args, generator| {
|
||||
let arg_ty = fun.0.args[0].ty;
|
||||
let arg = args[0].1.clone()
|
||||
.to_basic_value_enum(ctx, generator, ctx.primitives.float)?;
|
||||
|
||||
Ok(Some(builtin_fns::call_round(generator, ctx, (arg_ty, arg), ctx.primitives.int32)?))
|
||||
}),
|
||||
)
|
||||
},
|
||||
create_fn_by_codegen(
|
||||
primitives,
|
||||
&var_map,
|
||||
|
@ -806,14 +818,11 @@ pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo {
|
|||
int64,
|
||||
&[(float, "n")],
|
||||
Box::new(|ctx, _, fun, args, generator| {
|
||||
let llvm_i64 = ctx.ctx.i64_type();
|
||||
|
||||
let arg_ty = fun.0.args[0].ty;
|
||||
let arg = args[0].1.clone()
|
||||
.to_basic_value_enum(ctx, generator, ctx.primitives.float)?
|
||||
.into_float_value();
|
||||
.to_basic_value_enum(ctx, generator, ctx.primitives.float)?;
|
||||
|
||||
Ok(Some(builtin_fns::call_round(ctx, (arg_ty, arg), llvm_i64).into()))
|
||||
Ok(Some(builtin_fns::call_round(generator, ctx, (arg_ty, arg), ctx.primitives.int64)?))
|
||||
}),
|
||||
),
|
||||
create_fn_by_codegen(
|
||||
|
|
|
@ -855,8 +855,9 @@ impl<'a> Inferencer<'a> {
|
|||
"int32",
|
||||
"float",
|
||||
"bool",
|
||||
"round",
|
||||
].iter().any(|fun_id| id == &(*fun_id).into()) && args.len() == 1 {
|
||||
let target_ty = if id == &"int32".into() {
|
||||
let target_ty = if id == &"int32".into() || id == &"round".into() {
|
||||
self.primitives.int32
|
||||
} else if id == &"float".into() {
|
||||
self.primitives.float
|
||||
|
|
|
@ -71,10 +71,13 @@ def _float(x):
|
|||
return float(x)
|
||||
|
||||
def round_away_zero(x):
|
||||
if x >= 0.0:
|
||||
return math.floor(x + 0.5)
|
||||
if isinstance(x, np.ndarray):
|
||||
return np.vectorize(round_away_zero)(x)
|
||||
else:
|
||||
return math.ceil(x - 0.5)
|
||||
if x >= 0.0:
|
||||
return math.floor(x + 0.5)
|
||||
else:
|
||||
return math.ceil(x - 0.5)
|
||||
|
||||
def patch(module):
|
||||
def dbl_nan():
|
||||
|
|
|
@ -718,6 +718,13 @@ def test_ndarray_bool():
|
|||
output_ndarray_float_2(x)
|
||||
output_ndarray_bool_2(y)
|
||||
|
||||
def test_ndarray_round():
|
||||
x = np_identity(2)
|
||||
y = round(x)
|
||||
|
||||
output_ndarray_float_2(x)
|
||||
output_ndarray_int32_2(y)
|
||||
|
||||
def run() -> int32:
|
||||
test_ndarray_ctor()
|
||||
test_ndarray_empty()
|
||||
|
@ -815,4 +822,6 @@ def run() -> int32:
|
|||
test_ndarray_float()
|
||||
test_ndarray_bool()
|
||||
|
||||
test_ndarray_round()
|
||||
|
||||
return 0
|
||||
|
|
Loading…
Reference in New Issue