core: WIP - isnan and isinf works now

This commit is contained in:
David Mak 2024-04-29 22:56:01 +08:00
parent 14164c332d
commit e4a5eb4782
4 changed files with 105 additions and 19 deletions

View File

@ -857,30 +857,84 @@ pub fn call_abs<'ctx, G: CodeGenerator + ?Sized>(
pub fn call_numpy_isnan<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
x: (Type, FloatValue<'ctx>),
) -> IntValue<'ctx> {
x: (Type, BasicValueEnum<'ctx>),
) -> Result<BasicValueEnum<'ctx>, String> {
const FN_NAME: &str = "np_isnan";
let llvm_usize = generator.get_size_type(ctx.ctx);
let (x_ty, x) = x;
if !ctx.unifier.unioned(x_ty, ctx.primitives.float) {
unsupported_type(ctx, "np_isnan", &[x_ty])
}
Ok(match x.get_type() {
BasicTypeEnum::FloatType(_) => {
debug_assert!(ctx.unifier.unioned(x_ty, ctx.primitives.float));
irrt::call_isnan(generator, ctx, x)
irrt::call_isnan(generator, ctx, x.into_float_value()).into()
}
BasicTypeEnum::PointerType(_) if x_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray) => {
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x_ty);
let ndarray = ndarray_elementwise_unaryop_impl(
generator,
ctx,
ctx.primitives.bool,
None,
NDArrayValue::from_ptr_val(x.into_pointer_value(), llvm_usize, None),
|generator, ctx, val| {
let val = call_numpy_isnan(generator, ctx, (elem_ty, val))?;
Ok(generator.bool_to_i8(ctx, val.into_int_value()).into())
},
)?;
ndarray.as_ptr_value().into()
}
_ => unsupported_type(ctx, FN_NAME, &[x_ty])
})
}
/// Invokes the `np_isinf` builtin function.
pub fn call_numpy_isinf<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
x: (Type, FloatValue<'ctx>),
) -> IntValue<'ctx> {
x: (Type, BasicValueEnum<'ctx>),
) -> Result<BasicValueEnum<'ctx>, String> {
const FN_NAME: &str = "np_isinf";
let llvm_usize = generator.get_size_type(ctx.ctx);
let (x_ty, x) = x;
if !ctx.unifier.unioned(x_ty, ctx.primitives.float) {
unsupported_type(ctx, "np_isinf", &[x_ty])
}
Ok(match x.get_type() {
BasicTypeEnum::FloatType(_) => {
debug_assert!(ctx.unifier.unioned(x_ty, ctx.primitives.float));
irrt::call_isinf(generator, ctx, x)
irrt::call_isinf(generator, ctx, x.into_float_value()).into()
}
BasicTypeEnum::PointerType(_) if x_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray) => {
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x_ty);
let ndarray = ndarray_elementwise_unaryop_impl(
generator,
ctx,
ctx.primitives.bool,
None,
NDArrayValue::from_ptr_val(x.into_pointer_value(), llvm_usize, None),
|generator, ctx, val| {
let val = call_numpy_isinf(generator, ctx, (elem_ty, val))?;
Ok(generator.bool_to_i8(ctx, val.into_int_value()).into())
},
)?;
ndarray.as_ptr_value().into()
}
_ => unsupported_type(ctx, FN_NAME, &[x_ty])
})
}
/// Invokes the `np_sin` builtin function.

View File

@ -1393,10 +1393,9 @@ pub fn get_builtins(unifier: &mut Unifier, primitives: &PrimitiveStore) -> Built
Box::new(|ctx, _, fun, args, generator| {
let x_ty = fun.0.args[0].ty;
let x_val = args[0].1.clone()
.to_basic_value_enum(ctx, generator, x_ty)?
.into_float_value();
.to_basic_value_enum(ctx, generator, x_ty)?;
Ok(Some(builtin_fns::call_numpy_isnan(generator, ctx, (x_ty, x_val)).into()))
Ok(Some(builtin_fns::call_numpy_isnan(generator, ctx, (x_ty, x_val))?))
}),
),
create_fn_by_codegen(
@ -1408,10 +1407,9 @@ pub fn get_builtins(unifier: &mut Unifier, primitives: &PrimitiveStore) -> Built
Box::new(|ctx, _, fun, args, generator| {
let x_ty = fun.0.args[0].ty;
let x_val = args[0].1.clone()
.to_basic_value_enum(ctx, generator, x_ty)?
.into_float_value();
.to_basic_value_enum(ctx, generator, x_ty)?;
Ok(Some(builtin_fns::call_numpy_isinf(generator, ctx, (x_ty, x_val)).into()))
Ok(Some(builtin_fns::call_numpy_isinf(generator, ctx, (x_ty, x_val))?))
}),
),
create_fn_by_codegen(

View File

@ -855,12 +855,14 @@ impl<'a> Inferencer<'a> {
"int32",
"float",
"bool",
"np_isnan",
"np_isinf",
].iter().any(|fun_id| id == &(*fun_id).into()) && args.len() == 1 {
let target_ty = if id == &"int32".into() {
self.primitives.int32
} else if id == &"float".into() {
self.primitives.float
} else if id == &"bool".into() {
} else if id == &"bool".into() || id == &"np_isnan".into() || id == &"np_isinf".into() {
self.primitives.bool
} else { unreachable!() };

View File

@ -1,3 +1,11 @@
@extern
def dbl_nan() -> float:
...
@extern
def dbl_inf() -> float:
...
@extern
def output_bool(x: bool):
...
@ -758,6 +766,28 @@ def test_ndarray_abs():
output_ndarray_float_2(x)
output_ndarray_float_2(y)
def test_ndarray_isnan():
x = np_identity(2)
x_isnan = np_isnan(x)
y = np_full([2, 2], dbl_nan())
y_isnan = np_isnan(y)
output_ndarray_float_2(x)
output_ndarray_bool_2(x_isnan)
output_ndarray_float_2(y)
output_ndarray_bool_2(y_isnan)
def test_ndarray_isinf():
x = np_identity(2)
x_isinf = np_isinf(x)
y = np_full([2, 2], dbl_inf())
y_isinf = np_isinf(y)
output_ndarray_float_2(x)
output_ndarray_bool_2(x_isinf)
output_ndarray_float_2(y)
output_ndarray_bool_2(y_isinf)
def run() -> int32:
test_ndarray_ctor()
test_ndarray_empty()
@ -858,5 +888,7 @@ def run() -> int32:
test_ndarray_round()
test_ndarray_floor()
test_ndarray_abs()
test_ndarray_isnan()
test_ndarray_isinf()
return 0