[core] toplevel: Implement np_{any,all}
This commit is contained in:
parent
18e8e5269f
commit
1cfaa1a779
@ -6,7 +6,8 @@ use strum::IntoEnumIterator;
|
||||
|
||||
use super::{
|
||||
helper::{
|
||||
debug_assert_prim_is_allowed, extract_ndims, make_exception_fields, PrimDef, PrimDefDetails,
|
||||
arraylike_flatten_element_type, debug_assert_prim_is_allowed, extract_ndims,
|
||||
make_exception_fields, PrimDef, PrimDefDetails,
|
||||
},
|
||||
numpy::{make_ndarray_ty, unpack_ndarray_var_tys},
|
||||
*,
|
||||
@ -15,9 +16,12 @@ use crate::{
|
||||
codegen::{
|
||||
builtin_fns,
|
||||
numpy::*,
|
||||
stmt::exn_constructor,
|
||||
stmt::{exn_constructor, gen_if_callback},
|
||||
types::ndarray::NDArrayType,
|
||||
values::{ndarray::shape::parse_numpy_int_sequence, ProxyValue, RangeValue},
|
||||
values::{
|
||||
ndarray::{shape::parse_numpy_int_sequence, ScalarOrNDArray},
|
||||
ProxyValue, RangeValue,
|
||||
},
|
||||
},
|
||||
symbol_resolver::SymbolValue,
|
||||
typecheck::typedef::{into_var_map, iter_type_vars, TypeVar, VarMap},
|
||||
@ -405,6 +409,8 @@ impl<'a> BuiltinBuilder<'a> {
|
||||
|
||||
PrimDef::FunNpIsNan | PrimDef::FunNpIsInf => self.build_np_float_to_bool_function(prim),
|
||||
|
||||
PrimDef::FunNpAny | PrimDef::FunNpAll => self.build_np_any_all_function(prim),
|
||||
|
||||
PrimDef::FunNpSin
|
||||
| PrimDef::FunNpCos
|
||||
| PrimDef::FunNpTan
|
||||
@ -1720,6 +1726,64 @@ impl<'a> BuiltinBuilder<'a> {
|
||||
)
|
||||
}
|
||||
|
||||
fn build_np_any_all_function(&mut self, prim: PrimDef) -> TopLevelDef {
|
||||
debug_assert_prim_is_allowed(prim, &[PrimDef::FunNpAny, PrimDef::FunNpAll]);
|
||||
|
||||
let param_ty = &[(self.num_or_ndarray_ty.ty, "a")];
|
||||
let ret_ty = self.primitives.bool;
|
||||
let var_map = &self.num_or_ndarray_var_map;
|
||||
let codegen_callback: Box<GenCallCallback> =
|
||||
Box::new(move |ctx, _, fun, args, generator| {
|
||||
let llvm_i1 = ctx.ctx.bool_type();
|
||||
let llvm_i1_k0 = llvm_i1.const_zero();
|
||||
let llvm_i1_k1 = llvm_i1.const_all_ones();
|
||||
|
||||
let a_ty = fun.0.args[0].ty;
|
||||
let a_val = args[0].1.clone().to_basic_value_enum(ctx, generator, a_ty)?;
|
||||
let a = ScalarOrNDArray::from_value(generator, ctx, (a_ty, a_val));
|
||||
let a_elem_ty = arraylike_flatten_element_type(&mut ctx.unifier, a_ty);
|
||||
|
||||
let (init, sc_val) = match prim {
|
||||
PrimDef::FunNpAny => (llvm_i1_k0, llvm_i1_k1),
|
||||
PrimDef::FunNpAll => (llvm_i1_k1, llvm_i1_k0),
|
||||
_ => unreachable!(),
|
||||
};
|
||||
|
||||
let acc = a.fold(generator, ctx, init, |generator, ctx, hooks, acc, elem| {
|
||||
gen_if_callback(
|
||||
generator,
|
||||
ctx,
|
||||
|_, ctx| {
|
||||
Ok(ctx
|
||||
.builder
|
||||
.build_int_compare(IntPredicate::EQ, acc, sc_val, "")
|
||||
.unwrap())
|
||||
},
|
||||
|_, ctx| {
|
||||
if let Some(hooks) = hooks {
|
||||
hooks.build_break_branch(&ctx.builder);
|
||||
}
|
||||
Ok(())
|
||||
},
|
||||
|_, _| Ok(()),
|
||||
)?;
|
||||
|
||||
let is_truthy =
|
||||
builtin_fns::call_bool(generator, ctx, (a_elem_ty, elem))?.into_int_value();
|
||||
|
||||
Ok(match prim {
|
||||
PrimDef::FunNpAny => ctx.builder.build_or(acc, is_truthy, "").unwrap(),
|
||||
PrimDef::FunNpAll => ctx.builder.build_and(acc, is_truthy, "").unwrap(),
|
||||
_ => unreachable!(),
|
||||
})
|
||||
})?;
|
||||
|
||||
Ok(Some(acc.as_basic_value_enum()))
|
||||
});
|
||||
|
||||
create_fn_by_codegen(self.unifier, var_map, prim.name(), ret_ty, param_ty, codegen_callback)
|
||||
}
|
||||
|
||||
/// Build 1-ary numpy/scipy functions that take in a float or an ndarray and return a value of the same type as the input.
|
||||
fn build_np_sp_float_or_ndarray_1ary_function(&mut self, prim: PrimDef) -> TopLevelDef {
|
||||
debug_assert_prim_is_allowed(
|
||||
|
@ -111,6 +111,8 @@ pub enum PrimDef {
|
||||
FunNpLdExp,
|
||||
FunNpHypot,
|
||||
FunNpNextAfter,
|
||||
FunNpAny,
|
||||
FunNpAll,
|
||||
|
||||
// Linalg functions
|
||||
FunNpDot,
|
||||
@ -305,6 +307,8 @@ impl PrimDef {
|
||||
PrimDef::FunNpLdExp => fun("np_ldexp", None),
|
||||
PrimDef::FunNpHypot => fun("np_hypot", None),
|
||||
PrimDef::FunNpNextAfter => fun("np_nextafter", None),
|
||||
PrimDef::FunNpAny => fun("np_any", None),
|
||||
PrimDef::FunNpAll => fun("np_all", None),
|
||||
|
||||
// Linalg functions
|
||||
PrimDef::FunNpDot => fun("np_dot", None),
|
||||
|
@ -232,6 +232,8 @@ def patch(module):
|
||||
module.np_ldexp = np.ldexp
|
||||
module.np_hypot = np.hypot
|
||||
module.np_nextafter = np.nextafter
|
||||
module.np_any = np.any
|
||||
module.np_all = np.all
|
||||
|
||||
# SciPy Math functions
|
||||
module.sp_spec_erf = special.erf
|
||||
|
@ -1551,6 +1551,59 @@ def test_ndarray_nextafter_broadcast_rhs_scalar():
|
||||
output_ndarray_float_2(nextafter_x_zeros)
|
||||
output_ndarray_float_2(nextafter_x_ones)
|
||||
|
||||
|
||||
def test_ndarray_any():
|
||||
s0 = 0
|
||||
output_bool(np_any(s0))
|
||||
s1 = 1
|
||||
output_bool(np_any(s1))
|
||||
|
||||
x1 = np_identity(5)
|
||||
y1 = np_any(x1)
|
||||
output_ndarray_float_2(x1)
|
||||
output_bool(y1)
|
||||
|
||||
x2 = np_identity(1)
|
||||
y2 = np_any(x2)
|
||||
output_ndarray_float_2(x2)
|
||||
output_bool(y2)
|
||||
|
||||
x3 = np_array([[1.0, 2.0], [3.0, 4.0]])
|
||||
y3 = np_any(x3)
|
||||
output_ndarray_float_2(x3)
|
||||
output_bool(y3)
|
||||
|
||||
x4 = np_zeros([3, 5])
|
||||
y4 = np_any(x4)
|
||||
output_ndarray_float_2(x4)
|
||||
output_bool(y4)
|
||||
|
||||
def test_ndarray_all():
|
||||
s0 = 0
|
||||
output_bool(np_all(s0))
|
||||
s1 = 1
|
||||
output_bool(np_all(s1))
|
||||
|
||||
x1 = np_identity(5)
|
||||
y1 = np_all(x1)
|
||||
output_ndarray_float_2(x1)
|
||||
output_bool(y1)
|
||||
|
||||
x2 = np_identity(1)
|
||||
y2 = np_all(x2)
|
||||
output_ndarray_float_2(x2)
|
||||
output_bool(y2)
|
||||
|
||||
x3 = np_array([[1.0, 2.0], [3.0, 4.0]])
|
||||
y3 = np_all(x3)
|
||||
output_ndarray_float_2(x3)
|
||||
output_bool(y3)
|
||||
|
||||
x4 = np_zeros([3, 5])
|
||||
y4 = np_all(x4)
|
||||
output_ndarray_float_2(x4)
|
||||
output_bool(y4)
|
||||
|
||||
def test_ndarray_dot():
|
||||
x1: ndarray[float, 1] = np_array([5.0, 1.0, 4.0, 2.0])
|
||||
y1: ndarray[float, 1] = np_array([5.0, 1.0, 6.0, 6.0])
|
||||
@ -1851,6 +1904,9 @@ def run() -> int32:
|
||||
test_ndarray_nextafter_broadcast_lhs_scalar()
|
||||
test_ndarray_nextafter_broadcast_rhs_scalar()
|
||||
|
||||
test_ndarray_any()
|
||||
test_ndarray_all()
|
||||
|
||||
test_ndarray_dot()
|
||||
test_ndarray_cholesky()
|
||||
test_ndarray_qr()
|
||||
|
Loading…
Reference in New Issue
Block a user