[core] toplevel: Implement np_{any,all}

This commit is contained in:
David Mak 2025-01-15 15:36:03 +08:00
parent 18e8e5269f
commit 1cfaa1a779
4 changed files with 129 additions and 3 deletions

View File

@ -6,7 +6,8 @@ use strum::IntoEnumIterator;
use super::{
helper::{
debug_assert_prim_is_allowed, extract_ndims, make_exception_fields, PrimDef, PrimDefDetails,
arraylike_flatten_element_type, debug_assert_prim_is_allowed, extract_ndims,
make_exception_fields, PrimDef, PrimDefDetails,
},
numpy::{make_ndarray_ty, unpack_ndarray_var_tys},
*,
@ -15,9 +16,12 @@ use crate::{
codegen::{
builtin_fns,
numpy::*,
stmt::exn_constructor,
stmt::{exn_constructor, gen_if_callback},
types::ndarray::NDArrayType,
values::{ndarray::shape::parse_numpy_int_sequence, ProxyValue, RangeValue},
values::{
ndarray::{shape::parse_numpy_int_sequence, ScalarOrNDArray},
ProxyValue, RangeValue,
},
},
symbol_resolver::SymbolValue,
typecheck::typedef::{into_var_map, iter_type_vars, TypeVar, VarMap},
@ -405,6 +409,8 @@ impl<'a> BuiltinBuilder<'a> {
PrimDef::FunNpIsNan | PrimDef::FunNpIsInf => self.build_np_float_to_bool_function(prim),
PrimDef::FunNpAny | PrimDef::FunNpAll => self.build_np_any_all_function(prim),
PrimDef::FunNpSin
| PrimDef::FunNpCos
| PrimDef::FunNpTan
@ -1720,6 +1726,64 @@ impl<'a> BuiltinBuilder<'a> {
)
}
fn build_np_any_all_function(&mut self, prim: PrimDef) -> TopLevelDef {
debug_assert_prim_is_allowed(prim, &[PrimDef::FunNpAny, PrimDef::FunNpAll]);
let param_ty = &[(self.num_or_ndarray_ty.ty, "a")];
let ret_ty = self.primitives.bool;
let var_map = &self.num_or_ndarray_var_map;
let codegen_callback: Box<GenCallCallback> =
Box::new(move |ctx, _, fun, args, generator| {
let llvm_i1 = ctx.ctx.bool_type();
let llvm_i1_k0 = llvm_i1.const_zero();
let llvm_i1_k1 = llvm_i1.const_all_ones();
let a_ty = fun.0.args[0].ty;
let a_val = args[0].1.clone().to_basic_value_enum(ctx, generator, a_ty)?;
let a = ScalarOrNDArray::from_value(generator, ctx, (a_ty, a_val));
let a_elem_ty = arraylike_flatten_element_type(&mut ctx.unifier, a_ty);
let (init, sc_val) = match prim {
PrimDef::FunNpAny => (llvm_i1_k0, llvm_i1_k1),
PrimDef::FunNpAll => (llvm_i1_k1, llvm_i1_k0),
_ => unreachable!(),
};
let acc = a.fold(generator, ctx, init, |generator, ctx, hooks, acc, elem| {
gen_if_callback(
generator,
ctx,
|_, ctx| {
Ok(ctx
.builder
.build_int_compare(IntPredicate::EQ, acc, sc_val, "")
.unwrap())
},
|_, ctx| {
if let Some(hooks) = hooks {
hooks.build_break_branch(&ctx.builder);
}
Ok(())
},
|_, _| Ok(()),
)?;
let is_truthy =
builtin_fns::call_bool(generator, ctx, (a_elem_ty, elem))?.into_int_value();
Ok(match prim {
PrimDef::FunNpAny => ctx.builder.build_or(acc, is_truthy, "").unwrap(),
PrimDef::FunNpAll => ctx.builder.build_and(acc, is_truthy, "").unwrap(),
_ => unreachable!(),
})
})?;
Ok(Some(acc.as_basic_value_enum()))
});
create_fn_by_codegen(self.unifier, var_map, prim.name(), ret_ty, param_ty, codegen_callback)
}
/// Build 1-ary numpy/scipy functions that take in a float or an ndarray and return a value of the same type as the input.
fn build_np_sp_float_or_ndarray_1ary_function(&mut self, prim: PrimDef) -> TopLevelDef {
debug_assert_prim_is_allowed(

View File

@ -111,6 +111,8 @@ pub enum PrimDef {
FunNpLdExp,
FunNpHypot,
FunNpNextAfter,
FunNpAny,
FunNpAll,
// Linalg functions
FunNpDot,
@ -305,6 +307,8 @@ impl PrimDef {
PrimDef::FunNpLdExp => fun("np_ldexp", None),
PrimDef::FunNpHypot => fun("np_hypot", None),
PrimDef::FunNpNextAfter => fun("np_nextafter", None),
PrimDef::FunNpAny => fun("np_any", None),
PrimDef::FunNpAll => fun("np_all", None),
// Linalg functions
PrimDef::FunNpDot => fun("np_dot", None),

View File

@ -232,6 +232,8 @@ def patch(module):
module.np_ldexp = np.ldexp
module.np_hypot = np.hypot
module.np_nextafter = np.nextafter
module.np_any = np.any
module.np_all = np.all
# SciPy Math functions
module.sp_spec_erf = special.erf

View File

@ -1551,6 +1551,59 @@ def test_ndarray_nextafter_broadcast_rhs_scalar():
output_ndarray_float_2(nextafter_x_zeros)
output_ndarray_float_2(nextafter_x_ones)
def test_ndarray_any():
s0 = 0
output_bool(np_any(s0))
s1 = 1
output_bool(np_any(s1))
x1 = np_identity(5)
y1 = np_any(x1)
output_ndarray_float_2(x1)
output_bool(y1)
x2 = np_identity(1)
y2 = np_any(x2)
output_ndarray_float_2(x2)
output_bool(y2)
x3 = np_array([[1.0, 2.0], [3.0, 4.0]])
y3 = np_any(x3)
output_ndarray_float_2(x3)
output_bool(y3)
x4 = np_zeros([3, 5])
y4 = np_any(x4)
output_ndarray_float_2(x4)
output_bool(y4)
def test_ndarray_all():
s0 = 0
output_bool(np_all(s0))
s1 = 1
output_bool(np_all(s1))
x1 = np_identity(5)
y1 = np_all(x1)
output_ndarray_float_2(x1)
output_bool(y1)
x2 = np_identity(1)
y2 = np_all(x2)
output_ndarray_float_2(x2)
output_bool(y2)
x3 = np_array([[1.0, 2.0], [3.0, 4.0]])
y3 = np_all(x3)
output_ndarray_float_2(x3)
output_bool(y3)
x4 = np_zeros([3, 5])
y4 = np_all(x4)
output_ndarray_float_2(x4)
output_bool(y4)
def test_ndarray_dot():
x1: ndarray[float, 1] = np_array([5.0, 1.0, 4.0, 2.0])
y1: ndarray[float, 1] = np_array([5.0, 1.0, 6.0, 6.0])
@ -1851,6 +1904,9 @@ def run() -> int32:
test_ndarray_nextafter_broadcast_lhs_scalar()
test_ndarray_nextafter_broadcast_rhs_scalar()
test_ndarray_any()
test_ndarray_all()
test_ndarray_dot()
test_ndarray_cholesky()
test_ndarray_qr()