diff --git a/nac3core/src/toplevel/builtins.rs b/nac3core/src/toplevel/builtins.rs index a0673a1..e06366c 100644 --- a/nac3core/src/toplevel/builtins.rs +++ b/nac3core/src/toplevel/builtins.rs @@ -6,7 +6,8 @@ use strum::IntoEnumIterator; use super::{ helper::{ - debug_assert_prim_is_allowed, extract_ndims, make_exception_fields, PrimDef, PrimDefDetails, + arraylike_flatten_element_type, debug_assert_prim_is_allowed, extract_ndims, + make_exception_fields, PrimDef, PrimDefDetails, }, numpy::{make_ndarray_ty, unpack_ndarray_var_tys}, *, @@ -15,9 +16,12 @@ use crate::{ codegen::{ builtin_fns, numpy::*, - stmt::exn_constructor, + stmt::{exn_constructor, gen_if_callback}, types::ndarray::NDArrayType, - values::{ndarray::shape::parse_numpy_int_sequence, ProxyValue, RangeValue}, + values::{ + ndarray::{shape::parse_numpy_int_sequence, ScalarOrNDArray}, + ProxyValue, RangeValue, + }, }, symbol_resolver::SymbolValue, typecheck::typedef::{into_var_map, iter_type_vars, TypeVar, VarMap}, @@ -405,6 +409,8 @@ impl<'a> BuiltinBuilder<'a> { PrimDef::FunNpIsNan | PrimDef::FunNpIsInf => self.build_np_float_to_bool_function(prim), + PrimDef::FunNpAny | PrimDef::FunNpAll => self.build_np_any_all_function(prim), + PrimDef::FunNpSin | PrimDef::FunNpCos | PrimDef::FunNpTan @@ -1720,6 +1726,64 @@ impl<'a> BuiltinBuilder<'a> { ) } + fn build_np_any_all_function(&mut self, prim: PrimDef) -> TopLevelDef { + debug_assert_prim_is_allowed(prim, &[PrimDef::FunNpAny, PrimDef::FunNpAll]); + + let param_ty = &[(self.num_or_ndarray_ty.ty, "a")]; + let ret_ty = self.primitives.bool; + let var_map = &self.num_or_ndarray_var_map; + let codegen_callback: Box = + Box::new(move |ctx, _, fun, args, generator| { + let llvm_i1 = ctx.ctx.bool_type(); + let llvm_i1_k0 = llvm_i1.const_zero(); + let llvm_i1_k1 = llvm_i1.const_all_ones(); + + let a_ty = fun.0.args[0].ty; + let a_val = args[0].1.clone().to_basic_value_enum(ctx, generator, a_ty)?; + let a = ScalarOrNDArray::from_value(generator, ctx, (a_ty, a_val)); + let a_elem_ty = arraylike_flatten_element_type(&mut ctx.unifier, a_ty); + + let (init, sc_val) = match prim { + PrimDef::FunNpAny => (llvm_i1_k0, llvm_i1_k1), + PrimDef::FunNpAll => (llvm_i1_k1, llvm_i1_k0), + _ => unreachable!(), + }; + + let acc = a.fold(generator, ctx, init, |generator, ctx, hooks, acc, elem| { + gen_if_callback( + generator, + ctx, + |_, ctx| { + Ok(ctx + .builder + .build_int_compare(IntPredicate::EQ, acc, sc_val, "") + .unwrap()) + }, + |_, ctx| { + if let Some(hooks) = hooks { + hooks.build_break_branch(&ctx.builder); + } + Ok(()) + }, + |_, _| Ok(()), + )?; + + let is_truthy = + builtin_fns::call_bool(generator, ctx, (a_elem_ty, elem))?.into_int_value(); + + Ok(match prim { + PrimDef::FunNpAny => ctx.builder.build_or(acc, is_truthy, "").unwrap(), + PrimDef::FunNpAll => ctx.builder.build_and(acc, is_truthy, "").unwrap(), + _ => unreachable!(), + }) + })?; + + Ok(Some(acc.as_basic_value_enum())) + }); + + create_fn_by_codegen(self.unifier, var_map, prim.name(), ret_ty, param_ty, codegen_callback) + } + /// Build 1-ary numpy/scipy functions that take in a float or an ndarray and return a value of the same type as the input. fn build_np_sp_float_or_ndarray_1ary_function(&mut self, prim: PrimDef) -> TopLevelDef { debug_assert_prim_is_allowed( diff --git a/nac3core/src/toplevel/helper.rs b/nac3core/src/toplevel/helper.rs index de90a41..72d3eaa 100644 --- a/nac3core/src/toplevel/helper.rs +++ b/nac3core/src/toplevel/helper.rs @@ -111,6 +111,8 @@ pub enum PrimDef { FunNpLdExp, FunNpHypot, FunNpNextAfter, + FunNpAny, + FunNpAll, // Linalg functions FunNpDot, @@ -305,6 +307,8 @@ impl PrimDef { PrimDef::FunNpLdExp => fun("np_ldexp", None), PrimDef::FunNpHypot => fun("np_hypot", None), PrimDef::FunNpNextAfter => fun("np_nextafter", None), + PrimDef::FunNpAny => fun("np_any", None), + PrimDef::FunNpAll => fun("np_all", None), // Linalg functions PrimDef::FunNpDot => fun("np_dot", None), diff --git a/nac3standalone/demo/interpret_demo.py b/nac3standalone/demo/interpret_demo.py index fa91ed3..180d24f 100755 --- a/nac3standalone/demo/interpret_demo.py +++ b/nac3standalone/demo/interpret_demo.py @@ -232,6 +232,8 @@ def patch(module): module.np_ldexp = np.ldexp module.np_hypot = np.hypot module.np_nextafter = np.nextafter + module.np_any = np.any + module.np_all = np.all # SciPy Math functions module.sp_spec_erf = special.erf diff --git a/nac3standalone/demo/src/ndarray.py b/nac3standalone/demo/src/ndarray.py index b668860..d077b82 100644 --- a/nac3standalone/demo/src/ndarray.py +++ b/nac3standalone/demo/src/ndarray.py @@ -1551,6 +1551,59 @@ def test_ndarray_nextafter_broadcast_rhs_scalar(): output_ndarray_float_2(nextafter_x_zeros) output_ndarray_float_2(nextafter_x_ones) + +def test_ndarray_any(): + s0 = 0 + output_bool(np_any(s0)) + s1 = 1 + output_bool(np_any(s1)) + + x1 = np_identity(5) + y1 = np_any(x1) + output_ndarray_float_2(x1) + output_bool(y1) + + x2 = np_identity(1) + y2 = np_any(x2) + output_ndarray_float_2(x2) + output_bool(y2) + + x3 = np_array([[1.0, 2.0], [3.0, 4.0]]) + y3 = np_any(x3) + output_ndarray_float_2(x3) + output_bool(y3) + + x4 = np_zeros([3, 5]) + y4 = np_any(x4) + output_ndarray_float_2(x4) + output_bool(y4) + +def test_ndarray_all(): + s0 = 0 + output_bool(np_all(s0)) + s1 = 1 + output_bool(np_all(s1)) + + x1 = np_identity(5) + y1 = np_all(x1) + output_ndarray_float_2(x1) + output_bool(y1) + + x2 = np_identity(1) + y2 = np_all(x2) + output_ndarray_float_2(x2) + output_bool(y2) + + x3 = np_array([[1.0, 2.0], [3.0, 4.0]]) + y3 = np_all(x3) + output_ndarray_float_2(x3) + output_bool(y3) + + x4 = np_zeros([3, 5]) + y4 = np_all(x4) + output_ndarray_float_2(x4) + output_bool(y4) + def test_ndarray_dot(): x1: ndarray[float, 1] = np_array([5.0, 1.0, 4.0, 2.0]) y1: ndarray[float, 1] = np_array([5.0, 1.0, 6.0, 6.0]) @@ -1851,6 +1904,9 @@ def run() -> int32: test_ndarray_nextafter_broadcast_lhs_scalar() test_ndarray_nextafter_broadcast_rhs_scalar() + test_ndarray_any() + test_ndarray_all() + test_ndarray_dot() test_ndarray_cholesky() test_ndarray_qr()