core: WIP - floor and ceil works now
This commit is contained in:
parent
37a29162c6
commit
163867381a
|
@ -10,6 +10,8 @@ use crate::toplevel::helper::PRIMITIVE_DEF_IDS;
|
|||
use crate::toplevel::numpy::unpack_ndarray_var_tys;
|
||||
use crate::typecheck::typedef::Type;
|
||||
|
||||
// TODO: Rename ret_ty to ret_elem_ty or similar
|
||||
|
||||
/// Shorthand for [`unreachable!()`] when a type of argument is not supported.
|
||||
///
|
||||
/// The generated message will contain the function name and the name of the unsupported type.
|
||||
|
@ -567,55 +569,117 @@ pub fn call_bool<'ctx, G: CodeGenerator + ?Sized>(
|
|||
}
|
||||
|
||||
/// Invokes the `floor` builtin function.
|
||||
pub fn call_floor<'ctx>(
|
||||
pub fn call_floor<'ctx, G: CodeGenerator + ?Sized>(
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
n: (Type, FloatValue<'ctx>),
|
||||
llvm_ret_ty: BasicTypeEnum<'ctx>,
|
||||
) -> BasicValueEnum<'ctx> {
|
||||
n: (Type, BasicValueEnum<'ctx>),
|
||||
ret_ty: Type,
|
||||
) -> Result<BasicValueEnum<'ctx>, String> {
|
||||
const FN_NAME: &str = "floor";
|
||||
|
||||
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||
|
||||
let (n_ty, n) = n;
|
||||
debug_assert!(ctx.unifier.unioned(n_ty, ctx.primitives.float));
|
||||
let llvm_ret_ty = ctx.get_llvm_abi_type(generator, ret_ty);
|
||||
|
||||
let val = llvm_intrinsics::call_float_floor(ctx, n, None);
|
||||
match llvm_ret_ty {
|
||||
_ if llvm_ret_ty == val.get_type().into() => val.into(),
|
||||
Ok(match n.get_type() {
|
||||
BasicTypeEnum::IntType(_)
|
||||
| BasicTypeEnum::FloatType(_) => {
|
||||
debug_assert!([
|
||||
ctx.primitives.int32,
|
||||
ctx.primitives.uint32,
|
||||
ctx.primitives.int64,
|
||||
ctx.primitives.uint64,
|
||||
ctx.primitives.float,
|
||||
].iter().any(|ty| ctx.unifier.unioned(n_ty, *ty)));
|
||||
|
||||
BasicTypeEnum::IntType(_) => {
|
||||
let val = llvm_intrinsics::call_float_floor(ctx, n.into_float_value(), None);
|
||||
if llvm_ret_ty.is_int_type() {
|
||||
ctx.builder
|
||||
.build_float_to_signed_int(val, llvm_ret_ty.into_int_type(), FN_NAME)
|
||||
.map(Into::into)
|
||||
.unwrap()
|
||||
} else {
|
||||
val.into()
|
||||
}
|
||||
}
|
||||
|
||||
BasicTypeEnum::PointerType(_) if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray) => {
|
||||
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, n_ty);
|
||||
|
||||
let ndarray = ndarray_elementwise_unaryop_impl(
|
||||
generator,
|
||||
ctx,
|
||||
ret_ty,
|
||||
None,
|
||||
NDArrayValue::from_ptr_val(n.into_pointer_value(), llvm_usize, None),
|
||||
|generator, ctx, val| {
|
||||
call_floor(generator, ctx, (elem_ty, val), ret_ty)
|
||||
},
|
||||
)?;
|
||||
|
||||
ndarray.as_ptr_value().into()
|
||||
}
|
||||
|
||||
_ => unsupported_type(ctx, FN_NAME, &[n_ty])
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
/// Invokes the `ceil` builtin function.
|
||||
pub fn call_ceil<'ctx>(
|
||||
pub fn call_ceil<'ctx, G: CodeGenerator + ?Sized>(
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
n: (Type, FloatValue<'ctx>),
|
||||
llvm_ret_ty: BasicTypeEnum<'ctx>,
|
||||
) -> BasicValueEnum<'ctx> {
|
||||
n: (Type, BasicValueEnum<'ctx>),
|
||||
ret_ty: Type,
|
||||
) -> Result<BasicValueEnum<'ctx>, String> {
|
||||
const FN_NAME: &str = "ceil";
|
||||
|
||||
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||
|
||||
let (n_ty, n) = n;
|
||||
debug_assert!(ctx.unifier.unioned(n_ty, ctx.primitives.float));
|
||||
let llvm_ret_ty = ctx.get_llvm_abi_type(generator, ret_ty);
|
||||
|
||||
let val = llvm_intrinsics::call_float_ceil(ctx, n, None);
|
||||
match llvm_ret_ty {
|
||||
_ if llvm_ret_ty == val.get_type().into() => val.into(),
|
||||
Ok(match n.get_type() {
|
||||
BasicTypeEnum::IntType(_)
|
||||
| BasicTypeEnum::FloatType(_) => {
|
||||
debug_assert!([
|
||||
ctx.primitives.int32,
|
||||
ctx.primitives.uint32,
|
||||
ctx.primitives.int64,
|
||||
ctx.primitives.uint64,
|
||||
ctx.primitives.float,
|
||||
].iter().any(|ty| ctx.unifier.unioned(n_ty, *ty)));
|
||||
|
||||
BasicTypeEnum::IntType(_) => {
|
||||
let val = llvm_intrinsics::call_float_ceil(ctx, n.into_float_value(), None);
|
||||
if llvm_ret_ty.is_int_type() {
|
||||
ctx.builder
|
||||
.build_float_to_signed_int(val, llvm_ret_ty.into_int_type(), FN_NAME)
|
||||
.map(Into::into)
|
||||
.unwrap()
|
||||
} else {
|
||||
val.into()
|
||||
}
|
||||
}
|
||||
|
||||
BasicTypeEnum::PointerType(_) if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray) => {
|
||||
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, n_ty);
|
||||
|
||||
let ndarray = ndarray_elementwise_unaryop_impl(
|
||||
generator,
|
||||
ctx,
|
||||
ret_ty,
|
||||
None,
|
||||
NDArrayValue::from_ptr_val(n.into_pointer_value(), llvm_usize, None),
|
||||
|generator, ctx, val| {
|
||||
call_floor(generator, ctx, (elem_ty, val), ret_ty)
|
||||
},
|
||||
)?;
|
||||
|
||||
ndarray.as_ptr_value().into()
|
||||
}
|
||||
|
||||
_ => unsupported_type(ctx, FN_NAME, &[n_ty])
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
/// Invokes the `min` builtin function.
|
||||
|
|
|
@ -363,6 +363,8 @@ pub fn get_builtins(unifier: &mut Unifier, primitives: &PrimitiveStore) -> Built
|
|||
let ndarray_copy_ty = *ndarray_fields.get(&"copy".into()).unwrap();
|
||||
let ndarray_fill_ty = *ndarray_fields.get(&"fill".into()).unwrap();
|
||||
|
||||
// TODO: Double check (T | ndarray[T]).to_basic_value_enum converts correctly
|
||||
// TODO: Directly obtain Type instance after get_fresh_var_with_
|
||||
let top_level_def_list = vec![
|
||||
Arc::new(RwLock::new(TopLevelComposer::make_top_level_class_def(
|
||||
PRIMITIVE_DEF_IDS.int32,
|
||||
|
@ -1025,102 +1027,172 @@ pub fn get_builtins(unifier: &mut Unifier, primitives: &PrimitiveStore) -> Built
|
|||
)))),
|
||||
loc: None,
|
||||
})),
|
||||
{
|
||||
let common_ndim = unifier.get_fresh_const_generic_var(
|
||||
primitives.usize(),
|
||||
Some("N".into()),
|
||||
None,
|
||||
);
|
||||
let ndarray_int32 = make_ndarray_ty(unifier, primitives, Some(int32), Some(common_ndim.0));
|
||||
let ndarray_float = make_ndarray_ty(unifier, primitives, Some(float), Some(common_ndim.0));
|
||||
|
||||
let p0_ty = unifier.get_fresh_var_with_range(
|
||||
&[float, ndarray_float],
|
||||
Some("T".into()),
|
||||
None,
|
||||
);
|
||||
let ret_ty = unifier.get_fresh_var_with_range(
|
||||
&[int32, ndarray_int32],
|
||||
Some("R".into()),
|
||||
None,
|
||||
);
|
||||
|
||||
create_fn_by_codegen(
|
||||
unifier,
|
||||
&var_map,
|
||||
"floor",
|
||||
int32,
|
||||
&[(float, "n")],
|
||||
ret_ty.0,
|
||||
&[(p0_ty.0, "n")],
|
||||
Box::new(|ctx, _, fun, args, generator| {
|
||||
let llvm_i32 = ctx.ctx.i32_type();
|
||||
|
||||
let arg_ty = fun.0.args[0].ty;
|
||||
let arg = args[0].1.clone()
|
||||
.to_basic_value_enum(ctx, generator, ctx.primitives.float)?
|
||||
.into_float_value();
|
||||
.to_basic_value_enum(ctx, generator, ctx.primitives.float)?;
|
||||
|
||||
Ok(Some(builtin_fns::call_floor(ctx, (arg_ty, arg), llvm_i32.into())))
|
||||
Ok(Some(builtin_fns::call_floor(generator, ctx, (arg_ty, arg), ctx.primitives.int32)?))
|
||||
}),
|
||||
),
|
||||
)
|
||||
},
|
||||
{
|
||||
let common_ndim = unifier.get_fresh_const_generic_var(
|
||||
primitives.usize(),
|
||||
Some("N".into()),
|
||||
None,
|
||||
);
|
||||
let ndarray_int64 = make_ndarray_ty(unifier, primitives, Some(int64), Some(common_ndim.0));
|
||||
let ndarray_float = make_ndarray_ty(unifier, primitives, Some(float), Some(common_ndim.0));
|
||||
|
||||
let p0_ty = unifier.get_fresh_var_with_range(
|
||||
&[float, ndarray_float],
|
||||
Some("T".into()),
|
||||
None,
|
||||
);
|
||||
let ret_ty = unifier.get_fresh_var_with_range(
|
||||
&[int64, ndarray_int64],
|
||||
Some("R".into()),
|
||||
None,
|
||||
);
|
||||
|
||||
create_fn_by_codegen(
|
||||
unifier,
|
||||
&var_map,
|
||||
"floor64",
|
||||
int64,
|
||||
&[(float, "n")],
|
||||
ret_ty.0,
|
||||
&[(p0_ty.0, "n")],
|
||||
Box::new(|ctx, _, fun, args, generator| {
|
||||
let llvm_i64 = ctx.ctx.i64_type();
|
||||
|
||||
let arg_ty = fun.0.args[0].ty;
|
||||
let arg = args[0].1.clone()
|
||||
.to_basic_value_enum(ctx, generator, ctx.primitives.float)?
|
||||
.into_float_value();
|
||||
.to_basic_value_enum(ctx, generator, ctx.primitives.float)?;
|
||||
|
||||
Ok(Some(builtin_fns::call_floor(ctx, (arg_ty, arg), llvm_i64.into())))
|
||||
Ok(Some(builtin_fns::call_floor(generator, ctx, (arg_ty, arg), ctx.primitives.int64)?))
|
||||
}),
|
||||
),
|
||||
)
|
||||
},
|
||||
create_fn_by_codegen(
|
||||
unifier,
|
||||
&var_map,
|
||||
"np_floor",
|
||||
float,
|
||||
&[(float, "n")],
|
||||
float_or_ndarray_ty.0,
|
||||
&[(float_or_ndarray_ty.0, "n")],
|
||||
Box::new(|ctx, _, fun, args, generator| {
|
||||
let arg_ty = fun.0.args[0].ty;
|
||||
let arg = args[0].1.clone()
|
||||
.to_basic_value_enum(ctx, generator, ctx.primitives.float)?
|
||||
.into_float_value();
|
||||
.to_basic_value_enum(ctx, generator, ctx.primitives.float)?;
|
||||
|
||||
Ok(Some(builtin_fns::call_floor(ctx, (arg_ty, arg), arg.get_type().into())))
|
||||
Ok(Some(builtin_fns::call_floor(generator, ctx, (arg_ty, arg), ctx.primitives.float)?))
|
||||
}),
|
||||
),
|
||||
{
|
||||
let common_ndim = unifier.get_fresh_const_generic_var(
|
||||
primitives.usize(),
|
||||
Some("N".into()),
|
||||
None,
|
||||
);
|
||||
let ndarray_int32 = make_ndarray_ty(unifier, primitives, Some(int32), Some(common_ndim.0));
|
||||
let ndarray_float = make_ndarray_ty(unifier, primitives, Some(float), Some(common_ndim.0));
|
||||
|
||||
let p0_ty = unifier.get_fresh_var_with_range(
|
||||
&[float, ndarray_float],
|
||||
Some("T".into()),
|
||||
None,
|
||||
);
|
||||
let ret_ty = unifier.get_fresh_var_with_range(
|
||||
&[int32, ndarray_int32],
|
||||
Some("R".into()),
|
||||
None,
|
||||
);
|
||||
|
||||
create_fn_by_codegen(
|
||||
unifier,
|
||||
&var_map,
|
||||
"ceil",
|
||||
int32,
|
||||
&[(float, "n")],
|
||||
ret_ty.0,
|
||||
&[(p0_ty.0, "n")],
|
||||
Box::new(|ctx, _, fun, args, generator| {
|
||||
let llvm_i32 = ctx.ctx.i32_type();
|
||||
|
||||
let arg_ty = fun.0.args[0].ty;
|
||||
let arg = args[0].1.clone()
|
||||
.to_basic_value_enum(ctx, generator, ctx.primitives.float)?
|
||||
.into_float_value();
|
||||
.to_basic_value_enum(ctx, generator, ctx.primitives.float)?;
|
||||
|
||||
Ok(Some(builtin_fns::call_ceil(ctx, (arg_ty, arg), llvm_i32.into())))
|
||||
Ok(Some(builtin_fns::call_ceil(generator, ctx, (arg_ty, arg), ctx.primitives.int32)?))
|
||||
}),
|
||||
),
|
||||
)
|
||||
},
|
||||
{
|
||||
let common_ndim = unifier.get_fresh_const_generic_var(
|
||||
primitives.usize(),
|
||||
Some("N".into()),
|
||||
None,
|
||||
);
|
||||
let ndarray_int64 = make_ndarray_ty(unifier, primitives, Some(int64), Some(common_ndim.0));
|
||||
let ndarray_float = make_ndarray_ty(unifier, primitives, Some(float), Some(common_ndim.0));
|
||||
|
||||
let p0_ty = unifier.get_fresh_var_with_range(
|
||||
&[float, ndarray_float],
|
||||
Some("T".into()),
|
||||
None,
|
||||
);
|
||||
let ret_ty = unifier.get_fresh_var_with_range(
|
||||
&[int64, ndarray_int64],
|
||||
Some("R".into()),
|
||||
None,
|
||||
);
|
||||
|
||||
create_fn_by_codegen(
|
||||
unifier,
|
||||
&var_map,
|
||||
"ceil64",
|
||||
int64,
|
||||
&[(float, "n")],
|
||||
ret_ty.0,
|
||||
&[(p0_ty.0, "n")],
|
||||
Box::new(|ctx, _, fun, args, generator| {
|
||||
let llvm_i64 = ctx.ctx.i64_type();
|
||||
|
||||
let arg_ty = fun.0.args[0].ty;
|
||||
let arg = args[0].1.clone()
|
||||
.to_basic_value_enum(ctx, generator, ctx.primitives.float)?
|
||||
.into_float_value();
|
||||
.to_basic_value_enum(ctx, generator, ctx.primitives.float)?;
|
||||
|
||||
Ok(Some(builtin_fns::call_ceil(ctx, (arg_ty, arg), llvm_i64.into())))
|
||||
Ok(Some(builtin_fns::call_ceil(generator, ctx, (arg_ty, arg), ctx.primitives.int64)?))
|
||||
}),
|
||||
),
|
||||
)
|
||||
},
|
||||
create_fn_by_codegen(
|
||||
unifier,
|
||||
&var_map,
|
||||
"np_ceil",
|
||||
float,
|
||||
&[(float, "n")],
|
||||
float_or_ndarray_ty.0,
|
||||
&[(float_or_ndarray_ty.0, "n")],
|
||||
Box::new(|ctx, _, fun, args, generator| {
|
||||
let arg_ty = fun.0.args[0].ty;
|
||||
let arg = args[0].1.clone()
|
||||
.to_basic_value_enum(ctx, generator, ctx.primitives.float)?
|
||||
.into_float_value();
|
||||
.to_basic_value_enum(ctx, generator, ctx.primitives.float)?;
|
||||
|
||||
Ok(Some(builtin_fns::call_ceil(ctx, (arg_ty, arg), arg.get_type().into())))
|
||||
Ok(Some(builtin_fns::call_ceil(generator, ctx, (arg_ty, arg), ctx.primitives.float)?))
|
||||
}),
|
||||
),
|
||||
Arc::new(RwLock::new({
|
||||
|
|
|
@ -79,6 +79,18 @@ def round_away_zero(x):
|
|||
else:
|
||||
return math.ceil(x - 0.5)
|
||||
|
||||
def _floor(x):
|
||||
if isinstance(x, np.ndarray):
|
||||
return np.vectorize(_floor)(x)
|
||||
else:
|
||||
return math.floor(x)
|
||||
|
||||
def _ceil(x):
|
||||
if isinstance(x, np.ndarray):
|
||||
return np.vectorize(_ceil)(x)
|
||||
else:
|
||||
return math.ceil(x)
|
||||
|
||||
def patch(module):
|
||||
def dbl_nan():
|
||||
return np.nan
|
||||
|
@ -142,11 +154,11 @@ def patch(module):
|
|||
module.round = round_away_zero
|
||||
module.round64 = round_away_zero
|
||||
module.np_round = np.round
|
||||
module.floor = math.floor
|
||||
module.floor64 = math.floor
|
||||
module.floor = _floor
|
||||
module.floor64 = _floor
|
||||
module.np_floor = np.floor
|
||||
module.ceil = math.ceil
|
||||
module.ceil64 = math.ceil
|
||||
module.ceil = _ceil
|
||||
module.ceil64 = _ceil
|
||||
module.np_ceil = np.ceil
|
||||
|
||||
# NumPy ndarray functions
|
||||
|
|
|
@ -729,6 +729,28 @@ def test_ndarray_round():
|
|||
output_ndarray_int64_2(xf64)
|
||||
output_ndarray_float_2(xff)
|
||||
|
||||
def test_ndarray_floor():
|
||||
x = np_identity(2)
|
||||
xf32 = floor(x)
|
||||
xf64 = floor64(x)
|
||||
xff = np_floor(x)
|
||||
|
||||
output_ndarray_float_2(x)
|
||||
output_ndarray_int32_2(xf32)
|
||||
output_ndarray_int64_2(xf64)
|
||||
output_ndarray_float_2(xff)
|
||||
|
||||
def test_ndarray_ceil():
|
||||
x = np_identity(2)
|
||||
xf32 = ceil(x)
|
||||
xf64 = ceil64(x)
|
||||
xff = np_ceil(x)
|
||||
|
||||
output_ndarray_float_2(x)
|
||||
output_ndarray_int32_2(xf32)
|
||||
output_ndarray_int64_2(xf64)
|
||||
output_ndarray_float_2(xff)
|
||||
|
||||
def run() -> int32:
|
||||
test_ndarray_ctor()
|
||||
test_ndarray_empty()
|
||||
|
@ -827,5 +849,6 @@ def run() -> int32:
|
|||
test_ndarray_bool()
|
||||
|
||||
test_ndarray_round()
|
||||
test_ndarray_floor()
|
||||
|
||||
return 0
|
||||
|
|
Loading…
Reference in New Issue