core: WIP - floor and ceil works now

This commit is contained in:
David Mak 2024-04-29 15:27:15 +08:00
parent 37a29162c6
commit 163867381a
4 changed files with 273 additions and 102 deletions

View File

@ -10,6 +10,8 @@ use crate::toplevel::helper::PRIMITIVE_DEF_IDS;
use crate::toplevel::numpy::unpack_ndarray_var_tys;
use crate::typecheck::typedef::Type;
// TODO: Rename ret_ty to ret_elem_ty or similar
/// Shorthand for [`unreachable!()`] when a type of argument is not supported.
///
/// The generated message will contain the function name and the name of the unsupported type.
@ -567,55 +569,117 @@ pub fn call_bool<'ctx, G: CodeGenerator + ?Sized>(
}
/// Invokes the `floor` builtin function.
pub fn call_floor<'ctx>(
pub fn call_floor<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
n: (Type, FloatValue<'ctx>),
llvm_ret_ty: BasicTypeEnum<'ctx>,
) -> BasicValueEnum<'ctx> {
n: (Type, BasicValueEnum<'ctx>),
ret_ty: Type,
) -> Result<BasicValueEnum<'ctx>, String> {
const FN_NAME: &str = "floor";
let llvm_usize = generator.get_size_type(ctx.ctx);
let (n_ty, n) = n;
debug_assert!(ctx.unifier.unioned(n_ty, ctx.primitives.float));
let llvm_ret_ty = ctx.get_llvm_abi_type(generator, ret_ty);
let val = llvm_intrinsics::call_float_floor(ctx, n, None);
match llvm_ret_ty {
_ if llvm_ret_ty == val.get_type().into() => val.into(),
Ok(match n.get_type() {
BasicTypeEnum::IntType(_)
| BasicTypeEnum::FloatType(_) => {
debug_assert!([
ctx.primitives.int32,
ctx.primitives.uint32,
ctx.primitives.int64,
ctx.primitives.uint64,
ctx.primitives.float,
].iter().any(|ty| ctx.unifier.unioned(n_ty, *ty)));
BasicTypeEnum::IntType(_) => {
ctx.builder
.build_float_to_signed_int(val, llvm_ret_ty.into_int_type(), FN_NAME)
.map(Into::into)
.unwrap()
let val = llvm_intrinsics::call_float_floor(ctx, n.into_float_value(), None);
if llvm_ret_ty.is_int_type() {
ctx.builder
.build_float_to_signed_int(val, llvm_ret_ty.into_int_type(), FN_NAME)
.map(Into::into)
.unwrap()
} else {
val.into()
}
}
BasicTypeEnum::PointerType(_) if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray) => {
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, n_ty);
let ndarray = ndarray_elementwise_unaryop_impl(
generator,
ctx,
ret_ty,
None,
NDArrayValue::from_ptr_val(n.into_pointer_value(), llvm_usize, None),
|generator, ctx, val| {
call_floor(generator, ctx, (elem_ty, val), ret_ty)
},
)?;
ndarray.as_ptr_value().into()
}
_ => unsupported_type(ctx, FN_NAME, &[n_ty])
}
})
}
/// Invokes the `ceil` builtin function.
pub fn call_ceil<'ctx>(
pub fn call_ceil<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
n: (Type, FloatValue<'ctx>),
llvm_ret_ty: BasicTypeEnum<'ctx>,
) -> BasicValueEnum<'ctx> {
n: (Type, BasicValueEnum<'ctx>),
ret_ty: Type,
) -> Result<BasicValueEnum<'ctx>, String> {
const FN_NAME: &str = "ceil";
let llvm_usize = generator.get_size_type(ctx.ctx);
let (n_ty, n) = n;
debug_assert!(ctx.unifier.unioned(n_ty, ctx.primitives.float));
let llvm_ret_ty = ctx.get_llvm_abi_type(generator, ret_ty);
let val = llvm_intrinsics::call_float_ceil(ctx, n, None);
match llvm_ret_ty {
_ if llvm_ret_ty == val.get_type().into() => val.into(),
Ok(match n.get_type() {
BasicTypeEnum::IntType(_)
| BasicTypeEnum::FloatType(_) => {
debug_assert!([
ctx.primitives.int32,
ctx.primitives.uint32,
ctx.primitives.int64,
ctx.primitives.uint64,
ctx.primitives.float,
].iter().any(|ty| ctx.unifier.unioned(n_ty, *ty)));
BasicTypeEnum::IntType(_) => {
ctx.builder
.build_float_to_signed_int(val, llvm_ret_ty.into_int_type(), FN_NAME)
.map(Into::into)
.unwrap()
let val = llvm_intrinsics::call_float_ceil(ctx, n.into_float_value(), None);
if llvm_ret_ty.is_int_type() {
ctx.builder
.build_float_to_signed_int(val, llvm_ret_ty.into_int_type(), FN_NAME)
.map(Into::into)
.unwrap()
} else {
val.into()
}
}
BasicTypeEnum::PointerType(_) if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray) => {
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, n_ty);
let ndarray = ndarray_elementwise_unaryop_impl(
generator,
ctx,
ret_ty,
None,
NDArrayValue::from_ptr_val(n.into_pointer_value(), llvm_usize, None),
|generator, ctx, val| {
call_floor(generator, ctx, (elem_ty, val), ret_ty)
},
)?;
ndarray.as_ptr_value().into()
}
_ => unsupported_type(ctx, FN_NAME, &[n_ty])
}
})
}
/// Invokes the `min` builtin function.

View File

@ -363,6 +363,8 @@ pub fn get_builtins(unifier: &mut Unifier, primitives: &PrimitiveStore) -> Built
let ndarray_copy_ty = *ndarray_fields.get(&"copy".into()).unwrap();
let ndarray_fill_ty = *ndarray_fields.get(&"fill".into()).unwrap();
// TODO: Double check (T | ndarray[T]).to_basic_value_enum converts correctly
// TODO: Directly obtain Type instance after get_fresh_var_with_
let top_level_def_list = vec![
Arc::new(RwLock::new(TopLevelComposer::make_top_level_class_def(
PRIMITIVE_DEF_IDS.int32,
@ -1025,102 +1027,172 @@ pub fn get_builtins(unifier: &mut Unifier, primitives: &PrimitiveStore) -> Built
)))),
loc: None,
})),
create_fn_by_codegen(
unifier,
&var_map,
"floor",
int32,
&[(float, "n")],
Box::new(|ctx, _, fun, args, generator| {
let llvm_i32 = ctx.ctx.i32_type();
{
let common_ndim = unifier.get_fresh_const_generic_var(
primitives.usize(),
Some("N".into()),
None,
);
let ndarray_int32 = make_ndarray_ty(unifier, primitives, Some(int32), Some(common_ndim.0));
let ndarray_float = make_ndarray_ty(unifier, primitives, Some(float), Some(common_ndim.0));
let arg_ty = fun.0.args[0].ty;
let arg = args[0].1.clone()
.to_basic_value_enum(ctx, generator, ctx.primitives.float)?
.into_float_value();
let p0_ty = unifier.get_fresh_var_with_range(
&[float, ndarray_float],
Some("T".into()),
None,
);
let ret_ty = unifier.get_fresh_var_with_range(
&[int32, ndarray_int32],
Some("R".into()),
None,
);
Ok(Some(builtin_fns::call_floor(ctx, (arg_ty, arg), llvm_i32.into())))
}),
),
create_fn_by_codegen(
unifier,
&var_map,
"floor64",
int64,
&[(float, "n")],
Box::new(|ctx, _, fun, args, generator| {
let llvm_i64 = ctx.ctx.i64_type();
create_fn_by_codegen(
unifier,
&var_map,
"floor",
ret_ty.0,
&[(p0_ty.0, "n")],
Box::new(|ctx, _, fun, args, generator| {
let arg_ty = fun.0.args[0].ty;
let arg = args[0].1.clone()
.to_basic_value_enum(ctx, generator, ctx.primitives.float)?;
let arg_ty = fun.0.args[0].ty;
let arg = args[0].1.clone()
.to_basic_value_enum(ctx, generator, ctx.primitives.float)?
.into_float_value();
Ok(Some(builtin_fns::call_floor(generator, ctx, (arg_ty, arg), ctx.primitives.int32)?))
}),
)
},
{
let common_ndim = unifier.get_fresh_const_generic_var(
primitives.usize(),
Some("N".into()),
None,
);
let ndarray_int64 = make_ndarray_ty(unifier, primitives, Some(int64), Some(common_ndim.0));
let ndarray_float = make_ndarray_ty(unifier, primitives, Some(float), Some(common_ndim.0));
Ok(Some(builtin_fns::call_floor(ctx, (arg_ty, arg), llvm_i64.into())))
}),
),
let p0_ty = unifier.get_fresh_var_with_range(
&[float, ndarray_float],
Some("T".into()),
None,
);
let ret_ty = unifier.get_fresh_var_with_range(
&[int64, ndarray_int64],
Some("R".into()),
None,
);
create_fn_by_codegen(
unifier,
&var_map,
"floor64",
ret_ty.0,
&[(p0_ty.0, "n")],
Box::new(|ctx, _, fun, args, generator| {
let arg_ty = fun.0.args[0].ty;
let arg = args[0].1.clone()
.to_basic_value_enum(ctx, generator, ctx.primitives.float)?;
Ok(Some(builtin_fns::call_floor(generator, ctx, (arg_ty, arg), ctx.primitives.int64)?))
}),
)
},
create_fn_by_codegen(
unifier,
&var_map,
"np_floor",
float,
&[(float, "n")],
float_or_ndarray_ty.0,
&[(float_or_ndarray_ty.0, "n")],
Box::new(|ctx, _, fun, args, generator| {
let arg_ty = fun.0.args[0].ty;
let arg = args[0].1.clone()
.to_basic_value_enum(ctx, generator, ctx.primitives.float)?
.into_float_value();
.to_basic_value_enum(ctx, generator, ctx.primitives.float)?;
Ok(Some(builtin_fns::call_floor(ctx, (arg_ty, arg), arg.get_type().into())))
Ok(Some(builtin_fns::call_floor(generator, ctx, (arg_ty, arg), ctx.primitives.float)?))
}),
),
create_fn_by_codegen(
unifier,
&var_map,
"ceil",
int32,
&[(float, "n")],
Box::new(|ctx, _, fun, args, generator| {
let llvm_i32 = ctx.ctx.i32_type();
{
let common_ndim = unifier.get_fresh_const_generic_var(
primitives.usize(),
Some("N".into()),
None,
);
let ndarray_int32 = make_ndarray_ty(unifier, primitives, Some(int32), Some(common_ndim.0));
let ndarray_float = make_ndarray_ty(unifier, primitives, Some(float), Some(common_ndim.0));
let arg_ty = fun.0.args[0].ty;
let arg = args[0].1.clone()
.to_basic_value_enum(ctx, generator, ctx.primitives.float)?
.into_float_value();
let p0_ty = unifier.get_fresh_var_with_range(
&[float, ndarray_float],
Some("T".into()),
None,
);
let ret_ty = unifier.get_fresh_var_with_range(
&[int32, ndarray_int32],
Some("R".into()),
None,
);
Ok(Some(builtin_fns::call_ceil(ctx, (arg_ty, arg), llvm_i32.into())))
}),
),
create_fn_by_codegen(
unifier,
&var_map,
"ceil64",
int64,
&[(float, "n")],
Box::new(|ctx, _, fun, args, generator| {
let llvm_i64 = ctx.ctx.i64_type();
create_fn_by_codegen(
unifier,
&var_map,
"ceil",
ret_ty.0,
&[(p0_ty.0, "n")],
Box::new(|ctx, _, fun, args, generator| {
let arg_ty = fun.0.args[0].ty;
let arg = args[0].1.clone()
.to_basic_value_enum(ctx, generator, ctx.primitives.float)?;
let arg_ty = fun.0.args[0].ty;
let arg = args[0].1.clone()
.to_basic_value_enum(ctx, generator, ctx.primitives.float)?
.into_float_value();
Ok(Some(builtin_fns::call_ceil(generator, ctx, (arg_ty, arg), ctx.primitives.int32)?))
}),
)
},
{
let common_ndim = unifier.get_fresh_const_generic_var(
primitives.usize(),
Some("N".into()),
None,
);
let ndarray_int64 = make_ndarray_ty(unifier, primitives, Some(int64), Some(common_ndim.0));
let ndarray_float = make_ndarray_ty(unifier, primitives, Some(float), Some(common_ndim.0));
Ok(Some(builtin_fns::call_ceil(ctx, (arg_ty, arg), llvm_i64.into())))
}),
),
let p0_ty = unifier.get_fresh_var_with_range(
&[float, ndarray_float],
Some("T".into()),
None,
);
let ret_ty = unifier.get_fresh_var_with_range(
&[int64, ndarray_int64],
Some("R".into()),
None,
);
create_fn_by_codegen(
unifier,
&var_map,
"ceil64",
ret_ty.0,
&[(p0_ty.0, "n")],
Box::new(|ctx, _, fun, args, generator| {
let arg_ty = fun.0.args[0].ty;
let arg = args[0].1.clone()
.to_basic_value_enum(ctx, generator, ctx.primitives.float)?;
Ok(Some(builtin_fns::call_ceil(generator, ctx, (arg_ty, arg), ctx.primitives.int64)?))
}),
)
},
create_fn_by_codegen(
unifier,
&var_map,
"np_ceil",
float,
&[(float, "n")],
float_or_ndarray_ty.0,
&[(float_or_ndarray_ty.0, "n")],
Box::new(|ctx, _, fun, args, generator| {
let arg_ty = fun.0.args[0].ty;
let arg = args[0].1.clone()
.to_basic_value_enum(ctx, generator, ctx.primitives.float)?
.into_float_value();
.to_basic_value_enum(ctx, generator, ctx.primitives.float)?;
Ok(Some(builtin_fns::call_ceil(ctx, (arg_ty, arg), arg.get_type().into())))
Ok(Some(builtin_fns::call_ceil(generator, ctx, (arg_ty, arg), ctx.primitives.float)?))
}),
),
Arc::new(RwLock::new({

View File

@ -79,6 +79,18 @@ def round_away_zero(x):
else:
return math.ceil(x - 0.5)
def _floor(x):
if isinstance(x, np.ndarray):
return np.vectorize(_floor)(x)
else:
return math.floor(x)
def _ceil(x):
if isinstance(x, np.ndarray):
return np.vectorize(_ceil)(x)
else:
return math.ceil(x)
def patch(module):
def dbl_nan():
return np.nan
@ -142,11 +154,11 @@ def patch(module):
module.round = round_away_zero
module.round64 = round_away_zero
module.np_round = np.round
module.floor = math.floor
module.floor64 = math.floor
module.floor = _floor
module.floor64 = _floor
module.np_floor = np.floor
module.ceil = math.ceil
module.ceil64 = math.ceil
module.ceil = _ceil
module.ceil64 = _ceil
module.np_ceil = np.ceil
# NumPy ndarray functions

View File

@ -729,6 +729,28 @@ def test_ndarray_round():
output_ndarray_int64_2(xf64)
output_ndarray_float_2(xff)
def test_ndarray_floor():
x = np_identity(2)
xf32 = floor(x)
xf64 = floor64(x)
xff = np_floor(x)
output_ndarray_float_2(x)
output_ndarray_int32_2(xf32)
output_ndarray_int64_2(xf64)
output_ndarray_float_2(xff)
def test_ndarray_ceil():
x = np_identity(2)
xf32 = ceil(x)
xf64 = ceil64(x)
xff = np_ceil(x)
output_ndarray_float_2(x)
output_ndarray_int32_2(xf32)
output_ndarray_int64_2(xf64)
output_ndarray_float_2(xff)
def run() -> int32:
test_ndarray_ctor()
test_ndarray_empty()
@ -827,5 +849,6 @@ def run() -> int32:
test_ndarray_bool()
test_ndarray_round()
test_ndarray_floor()
return 0