diff --git a/nac3core/src/codegen/builtin_fns.rs b/nac3core/src/codegen/builtin_fns.rs index 7c203df..3f0c08d 100644 --- a/nac3core/src/codegen/builtin_fns.rs +++ b/nac3core/src/codegen/builtin_fns.rs @@ -857,30 +857,84 @@ pub fn call_abs<'ctx, G: CodeGenerator + ?Sized>( pub fn call_numpy_isnan<'ctx, G: CodeGenerator + ?Sized>( generator: &mut G, ctx: &mut CodeGenContext<'ctx, '_>, - x: (Type, FloatValue<'ctx>), -) -> IntValue<'ctx> { + x: (Type, BasicValueEnum<'ctx>), +) -> Result, String> { + const FN_NAME: &str = "np_isnan"; + + let llvm_usize = generator.get_size_type(ctx.ctx); + let (x_ty, x) = x; - if !ctx.unifier.unioned(x_ty, ctx.primitives.float) { - unsupported_type(ctx, "np_isnan", &[x_ty]) - } + Ok(match x.get_type() { + BasicTypeEnum::FloatType(_) => { + debug_assert!(ctx.unifier.unioned(x_ty, ctx.primitives.float)); - irrt::call_isnan(generator, ctx, x) + irrt::call_isnan(generator, ctx, x.into_float_value()).into() + } + + BasicTypeEnum::PointerType(_) if x_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray) => { + let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x_ty); + + let ndarray = ndarray_elementwise_unaryop_impl( + generator, + ctx, + ctx.primitives.bool, + None, + NDArrayValue::from_ptr_val(x.into_pointer_value(), llvm_usize, None), + |generator, ctx, val| { + let val = call_numpy_isnan(generator, ctx, (elem_ty, val))?; + + Ok(generator.bool_to_i8(ctx, val.into_int_value()).into()) + }, + )?; + + ndarray.as_ptr_value().into() + } + + _ => unsupported_type(ctx, FN_NAME, &[x_ty]) + }) } /// Invokes the `np_isinf` builtin function. pub fn call_numpy_isinf<'ctx, G: CodeGenerator + ?Sized>( generator: &mut G, ctx: &mut CodeGenContext<'ctx, '_>, - x: (Type, FloatValue<'ctx>), -) -> IntValue<'ctx> { + x: (Type, BasicValueEnum<'ctx>), +) -> Result, String> { + const FN_NAME: &str = "np_isinf"; + + let llvm_usize = generator.get_size_type(ctx.ctx); + let (x_ty, x) = x; - if !ctx.unifier.unioned(x_ty, ctx.primitives.float) { - unsupported_type(ctx, "np_isinf", &[x_ty]) - } + Ok(match x.get_type() { + BasicTypeEnum::FloatType(_) => { + debug_assert!(ctx.unifier.unioned(x_ty, ctx.primitives.float)); - irrt::call_isinf(generator, ctx, x) + irrt::call_isinf(generator, ctx, x.into_float_value()).into() + } + + BasicTypeEnum::PointerType(_) if x_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray) => { + let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x_ty); + + let ndarray = ndarray_elementwise_unaryop_impl( + generator, + ctx, + ctx.primitives.bool, + None, + NDArrayValue::from_ptr_val(x.into_pointer_value(), llvm_usize, None), + |generator, ctx, val| { + let val = call_numpy_isinf(generator, ctx, (elem_ty, val))?; + + Ok(generator.bool_to_i8(ctx, val.into_int_value()).into()) + }, + )?; + + ndarray.as_ptr_value().into() + } + + _ => unsupported_type(ctx, FN_NAME, &[x_ty]) + }) } /// Invokes the `np_sin` builtin function. diff --git a/nac3core/src/toplevel/builtins.rs b/nac3core/src/toplevel/builtins.rs index f9049d9..cb5fb35 100644 --- a/nac3core/src/toplevel/builtins.rs +++ b/nac3core/src/toplevel/builtins.rs @@ -1393,10 +1393,9 @@ pub fn get_builtins(unifier: &mut Unifier, primitives: &PrimitiveStore) -> Built Box::new(|ctx, _, fun, args, generator| { let x_ty = fun.0.args[0].ty; let x_val = args[0].1.clone() - .to_basic_value_enum(ctx, generator, x_ty)? - .into_float_value(); + .to_basic_value_enum(ctx, generator, x_ty)?; - Ok(Some(builtin_fns::call_numpy_isnan(generator, ctx, (x_ty, x_val)).into())) + Ok(Some(builtin_fns::call_numpy_isnan(generator, ctx, (x_ty, x_val))?)) }), ), create_fn_by_codegen( @@ -1408,10 +1407,9 @@ pub fn get_builtins(unifier: &mut Unifier, primitives: &PrimitiveStore) -> Built Box::new(|ctx, _, fun, args, generator| { let x_ty = fun.0.args[0].ty; let x_val = args[0].1.clone() - .to_basic_value_enum(ctx, generator, x_ty)? - .into_float_value(); + .to_basic_value_enum(ctx, generator, x_ty)?; - Ok(Some(builtin_fns::call_numpy_isinf(generator, ctx, (x_ty, x_val)).into())) + Ok(Some(builtin_fns::call_numpy_isinf(generator, ctx, (x_ty, x_val))?)) }), ), create_fn_by_codegen( diff --git a/nac3core/src/typecheck/type_inferencer/mod.rs b/nac3core/src/typecheck/type_inferencer/mod.rs index 7fa6c63..6ad97e7 100644 --- a/nac3core/src/typecheck/type_inferencer/mod.rs +++ b/nac3core/src/typecheck/type_inferencer/mod.rs @@ -855,12 +855,14 @@ impl<'a> Inferencer<'a> { "int32", "float", "bool", + "np_isnan", + "np_isinf", ].iter().any(|fun_id| id == &(*fun_id).into()) && args.len() == 1 { let target_ty = if id == &"int32".into() { self.primitives.int32 } else if id == &"float".into() { self.primitives.float - } else if id == &"bool".into() { + } else if id == &"bool".into() || id == &"np_isnan".into() || id == &"np_isinf".into() { self.primitives.bool } else { unreachable!() }; diff --git a/nac3standalone/demo/src/ndarray.py b/nac3standalone/demo/src/ndarray.py index d803c80..fe080b9 100644 --- a/nac3standalone/demo/src/ndarray.py +++ b/nac3standalone/demo/src/ndarray.py @@ -1,3 +1,11 @@ +@extern +def dbl_nan() -> float: + ... + +@extern +def dbl_inf() -> float: + ... + @extern def output_bool(x: bool): ... @@ -758,6 +766,28 @@ def test_ndarray_abs(): output_ndarray_float_2(x) output_ndarray_float_2(y) +def test_ndarray_isnan(): + x = np_identity(2) + x_isnan = np_isnan(x) + y = np_full([2, 2], dbl_nan()) + y_isnan = np_isnan(y) + + output_ndarray_float_2(x) + output_ndarray_bool_2(x_isnan) + output_ndarray_float_2(y) + output_ndarray_bool_2(y_isnan) + +def test_ndarray_isinf(): + x = np_identity(2) + x_isinf = np_isinf(x) + y = np_full([2, 2], dbl_inf()) + y_isinf = np_isinf(y) + + output_ndarray_float_2(x) + output_ndarray_bool_2(x_isinf) + output_ndarray_float_2(y) + output_ndarray_bool_2(y_isinf) + def run() -> int32: test_ndarray_ctor() test_ndarray_empty() @@ -858,5 +888,7 @@ def run() -> int32: test_ndarray_round() test_ndarray_floor() test_ndarray_abs() + test_ndarray_isnan() + test_ndarray_isinf() return 0