forked from M-Labs/nac3
[core] toplevel: Implement np_{any,all}
This commit is contained in:
parent
18e8e5269f
commit
1cfaa1a779
@ -6,7 +6,8 @@ use strum::IntoEnumIterator;
|
|||||||
|
|
||||||
use super::{
|
use super::{
|
||||||
helper::{
|
helper::{
|
||||||
debug_assert_prim_is_allowed, extract_ndims, make_exception_fields, PrimDef, PrimDefDetails,
|
arraylike_flatten_element_type, debug_assert_prim_is_allowed, extract_ndims,
|
||||||
|
make_exception_fields, PrimDef, PrimDefDetails,
|
||||||
},
|
},
|
||||||
numpy::{make_ndarray_ty, unpack_ndarray_var_tys},
|
numpy::{make_ndarray_ty, unpack_ndarray_var_tys},
|
||||||
*,
|
*,
|
||||||
@ -15,9 +16,12 @@ use crate::{
|
|||||||
codegen::{
|
codegen::{
|
||||||
builtin_fns,
|
builtin_fns,
|
||||||
numpy::*,
|
numpy::*,
|
||||||
stmt::exn_constructor,
|
stmt::{exn_constructor, gen_if_callback},
|
||||||
types::ndarray::NDArrayType,
|
types::ndarray::NDArrayType,
|
||||||
values::{ndarray::shape::parse_numpy_int_sequence, ProxyValue, RangeValue},
|
values::{
|
||||||
|
ndarray::{shape::parse_numpy_int_sequence, ScalarOrNDArray},
|
||||||
|
ProxyValue, RangeValue,
|
||||||
|
},
|
||||||
},
|
},
|
||||||
symbol_resolver::SymbolValue,
|
symbol_resolver::SymbolValue,
|
||||||
typecheck::typedef::{into_var_map, iter_type_vars, TypeVar, VarMap},
|
typecheck::typedef::{into_var_map, iter_type_vars, TypeVar, VarMap},
|
||||||
@ -405,6 +409,8 @@ impl<'a> BuiltinBuilder<'a> {
|
|||||||
|
|
||||||
PrimDef::FunNpIsNan | PrimDef::FunNpIsInf => self.build_np_float_to_bool_function(prim),
|
PrimDef::FunNpIsNan | PrimDef::FunNpIsInf => self.build_np_float_to_bool_function(prim),
|
||||||
|
|
||||||
|
PrimDef::FunNpAny | PrimDef::FunNpAll => self.build_np_any_all_function(prim),
|
||||||
|
|
||||||
PrimDef::FunNpSin
|
PrimDef::FunNpSin
|
||||||
| PrimDef::FunNpCos
|
| PrimDef::FunNpCos
|
||||||
| PrimDef::FunNpTan
|
| PrimDef::FunNpTan
|
||||||
@ -1720,6 +1726,64 @@ impl<'a> BuiltinBuilder<'a> {
|
|||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn build_np_any_all_function(&mut self, prim: PrimDef) -> TopLevelDef {
|
||||||
|
debug_assert_prim_is_allowed(prim, &[PrimDef::FunNpAny, PrimDef::FunNpAll]);
|
||||||
|
|
||||||
|
let param_ty = &[(self.num_or_ndarray_ty.ty, "a")];
|
||||||
|
let ret_ty = self.primitives.bool;
|
||||||
|
let var_map = &self.num_or_ndarray_var_map;
|
||||||
|
let codegen_callback: Box<GenCallCallback> =
|
||||||
|
Box::new(move |ctx, _, fun, args, generator| {
|
||||||
|
let llvm_i1 = ctx.ctx.bool_type();
|
||||||
|
let llvm_i1_k0 = llvm_i1.const_zero();
|
||||||
|
let llvm_i1_k1 = llvm_i1.const_all_ones();
|
||||||
|
|
||||||
|
let a_ty = fun.0.args[0].ty;
|
||||||
|
let a_val = args[0].1.clone().to_basic_value_enum(ctx, generator, a_ty)?;
|
||||||
|
let a = ScalarOrNDArray::from_value(generator, ctx, (a_ty, a_val));
|
||||||
|
let a_elem_ty = arraylike_flatten_element_type(&mut ctx.unifier, a_ty);
|
||||||
|
|
||||||
|
let (init, sc_val) = match prim {
|
||||||
|
PrimDef::FunNpAny => (llvm_i1_k0, llvm_i1_k1),
|
||||||
|
PrimDef::FunNpAll => (llvm_i1_k1, llvm_i1_k0),
|
||||||
|
_ => unreachable!(),
|
||||||
|
};
|
||||||
|
|
||||||
|
let acc = a.fold(generator, ctx, init, |generator, ctx, hooks, acc, elem| {
|
||||||
|
gen_if_callback(
|
||||||
|
generator,
|
||||||
|
ctx,
|
||||||
|
|_, ctx| {
|
||||||
|
Ok(ctx
|
||||||
|
.builder
|
||||||
|
.build_int_compare(IntPredicate::EQ, acc, sc_val, "")
|
||||||
|
.unwrap())
|
||||||
|
},
|
||||||
|
|_, ctx| {
|
||||||
|
if let Some(hooks) = hooks {
|
||||||
|
hooks.build_break_branch(&ctx.builder);
|
||||||
|
}
|
||||||
|
Ok(())
|
||||||
|
},
|
||||||
|
|_, _| Ok(()),
|
||||||
|
)?;
|
||||||
|
|
||||||
|
let is_truthy =
|
||||||
|
builtin_fns::call_bool(generator, ctx, (a_elem_ty, elem))?.into_int_value();
|
||||||
|
|
||||||
|
Ok(match prim {
|
||||||
|
PrimDef::FunNpAny => ctx.builder.build_or(acc, is_truthy, "").unwrap(),
|
||||||
|
PrimDef::FunNpAll => ctx.builder.build_and(acc, is_truthy, "").unwrap(),
|
||||||
|
_ => unreachable!(),
|
||||||
|
})
|
||||||
|
})?;
|
||||||
|
|
||||||
|
Ok(Some(acc.as_basic_value_enum()))
|
||||||
|
});
|
||||||
|
|
||||||
|
create_fn_by_codegen(self.unifier, var_map, prim.name(), ret_ty, param_ty, codegen_callback)
|
||||||
|
}
|
||||||
|
|
||||||
/// Build 1-ary numpy/scipy functions that take in a float or an ndarray and return a value of the same type as the input.
|
/// Build 1-ary numpy/scipy functions that take in a float or an ndarray and return a value of the same type as the input.
|
||||||
fn build_np_sp_float_or_ndarray_1ary_function(&mut self, prim: PrimDef) -> TopLevelDef {
|
fn build_np_sp_float_or_ndarray_1ary_function(&mut self, prim: PrimDef) -> TopLevelDef {
|
||||||
debug_assert_prim_is_allowed(
|
debug_assert_prim_is_allowed(
|
||||||
|
@ -111,6 +111,8 @@ pub enum PrimDef {
|
|||||||
FunNpLdExp,
|
FunNpLdExp,
|
||||||
FunNpHypot,
|
FunNpHypot,
|
||||||
FunNpNextAfter,
|
FunNpNextAfter,
|
||||||
|
FunNpAny,
|
||||||
|
FunNpAll,
|
||||||
|
|
||||||
// Linalg functions
|
// Linalg functions
|
||||||
FunNpDot,
|
FunNpDot,
|
||||||
@ -305,6 +307,8 @@ impl PrimDef {
|
|||||||
PrimDef::FunNpLdExp => fun("np_ldexp", None),
|
PrimDef::FunNpLdExp => fun("np_ldexp", None),
|
||||||
PrimDef::FunNpHypot => fun("np_hypot", None),
|
PrimDef::FunNpHypot => fun("np_hypot", None),
|
||||||
PrimDef::FunNpNextAfter => fun("np_nextafter", None),
|
PrimDef::FunNpNextAfter => fun("np_nextafter", None),
|
||||||
|
PrimDef::FunNpAny => fun("np_any", None),
|
||||||
|
PrimDef::FunNpAll => fun("np_all", None),
|
||||||
|
|
||||||
// Linalg functions
|
// Linalg functions
|
||||||
PrimDef::FunNpDot => fun("np_dot", None),
|
PrimDef::FunNpDot => fun("np_dot", None),
|
||||||
|
@ -232,6 +232,8 @@ def patch(module):
|
|||||||
module.np_ldexp = np.ldexp
|
module.np_ldexp = np.ldexp
|
||||||
module.np_hypot = np.hypot
|
module.np_hypot = np.hypot
|
||||||
module.np_nextafter = np.nextafter
|
module.np_nextafter = np.nextafter
|
||||||
|
module.np_any = np.any
|
||||||
|
module.np_all = np.all
|
||||||
|
|
||||||
# SciPy Math functions
|
# SciPy Math functions
|
||||||
module.sp_spec_erf = special.erf
|
module.sp_spec_erf = special.erf
|
||||||
|
@ -1551,6 +1551,59 @@ def test_ndarray_nextafter_broadcast_rhs_scalar():
|
|||||||
output_ndarray_float_2(nextafter_x_zeros)
|
output_ndarray_float_2(nextafter_x_zeros)
|
||||||
output_ndarray_float_2(nextafter_x_ones)
|
output_ndarray_float_2(nextafter_x_ones)
|
||||||
|
|
||||||
|
|
||||||
|
def test_ndarray_any():
|
||||||
|
s0 = 0
|
||||||
|
output_bool(np_any(s0))
|
||||||
|
s1 = 1
|
||||||
|
output_bool(np_any(s1))
|
||||||
|
|
||||||
|
x1 = np_identity(5)
|
||||||
|
y1 = np_any(x1)
|
||||||
|
output_ndarray_float_2(x1)
|
||||||
|
output_bool(y1)
|
||||||
|
|
||||||
|
x2 = np_identity(1)
|
||||||
|
y2 = np_any(x2)
|
||||||
|
output_ndarray_float_2(x2)
|
||||||
|
output_bool(y2)
|
||||||
|
|
||||||
|
x3 = np_array([[1.0, 2.0], [3.0, 4.0]])
|
||||||
|
y3 = np_any(x3)
|
||||||
|
output_ndarray_float_2(x3)
|
||||||
|
output_bool(y3)
|
||||||
|
|
||||||
|
x4 = np_zeros([3, 5])
|
||||||
|
y4 = np_any(x4)
|
||||||
|
output_ndarray_float_2(x4)
|
||||||
|
output_bool(y4)
|
||||||
|
|
||||||
|
def test_ndarray_all():
|
||||||
|
s0 = 0
|
||||||
|
output_bool(np_all(s0))
|
||||||
|
s1 = 1
|
||||||
|
output_bool(np_all(s1))
|
||||||
|
|
||||||
|
x1 = np_identity(5)
|
||||||
|
y1 = np_all(x1)
|
||||||
|
output_ndarray_float_2(x1)
|
||||||
|
output_bool(y1)
|
||||||
|
|
||||||
|
x2 = np_identity(1)
|
||||||
|
y2 = np_all(x2)
|
||||||
|
output_ndarray_float_2(x2)
|
||||||
|
output_bool(y2)
|
||||||
|
|
||||||
|
x3 = np_array([[1.0, 2.0], [3.0, 4.0]])
|
||||||
|
y3 = np_all(x3)
|
||||||
|
output_ndarray_float_2(x3)
|
||||||
|
output_bool(y3)
|
||||||
|
|
||||||
|
x4 = np_zeros([3, 5])
|
||||||
|
y4 = np_all(x4)
|
||||||
|
output_ndarray_float_2(x4)
|
||||||
|
output_bool(y4)
|
||||||
|
|
||||||
def test_ndarray_dot():
|
def test_ndarray_dot():
|
||||||
x1: ndarray[float, 1] = np_array([5.0, 1.0, 4.0, 2.0])
|
x1: ndarray[float, 1] = np_array([5.0, 1.0, 4.0, 2.0])
|
||||||
y1: ndarray[float, 1] = np_array([5.0, 1.0, 6.0, 6.0])
|
y1: ndarray[float, 1] = np_array([5.0, 1.0, 6.0, 6.0])
|
||||||
@ -1851,6 +1904,9 @@ def run() -> int32:
|
|||||||
test_ndarray_nextafter_broadcast_lhs_scalar()
|
test_ndarray_nextafter_broadcast_lhs_scalar()
|
||||||
test_ndarray_nextafter_broadcast_rhs_scalar()
|
test_ndarray_nextafter_broadcast_rhs_scalar()
|
||||||
|
|
||||||
|
test_ndarray_any()
|
||||||
|
test_ndarray_all()
|
||||||
|
|
||||||
test_ndarray_dot()
|
test_ndarray_dot()
|
||||||
test_ndarray_cholesky()
|
test_ndarray_cholesky()
|
||||||
test_ndarray_qr()
|
test_ndarray_qr()
|
||||||
|
Loading…
Reference in New Issue
Block a user