forked from M-Labs/nac3
core/ndstrides: implement np_size()
This commit is contained in:
parent
854a3eb6f0
commit
ee66cac8c2
|
@ -15,6 +15,7 @@ use crate::{
|
|||
codegen::{
|
||||
builtin_fns,
|
||||
classes::{ProxyValue, RangeValue},
|
||||
model::*,
|
||||
numpy::*,
|
||||
object::{any::AnyObject, ndarray::NDArrayObject},
|
||||
stmt::exn_constructor,
|
||||
|
@ -512,7 +513,7 @@ impl<'a> BuiltinBuilder<'a> {
|
|||
| PrimDef::FunNpEye
|
||||
| PrimDef::FunNpIdentity => self.build_ndarray_other_factory_function(prim),
|
||||
|
||||
PrimDef::FunNpShape | PrimDef::FunNpStrides => {
|
||||
PrimDef::FunNpSize | PrimDef::FunNpShape | PrimDef::FunNpStrides => {
|
||||
self.build_ndarray_property_getter_function(prim)
|
||||
}
|
||||
|
||||
|
@ -1391,7 +1392,10 @@ impl<'a> BuiltinBuilder<'a> {
|
|||
}
|
||||
|
||||
fn build_ndarray_property_getter_function(&mut self, prim: PrimDef) -> TopLevelDef {
|
||||
debug_assert_prim_is_allowed(prim, &[PrimDef::FunNpShape, PrimDef::FunNpStrides]);
|
||||
debug_assert_prim_is_allowed(
|
||||
prim,
|
||||
&[PrimDef::FunNpSize, PrimDef::FunNpShape, PrimDef::FunNpStrides],
|
||||
);
|
||||
|
||||
let in_ndarray_ty = self.unifier.get_fresh_var_with_range(
|
||||
&[self.primitives.ndarray],
|
||||
|
@ -1400,6 +1404,27 @@ impl<'a> BuiltinBuilder<'a> {
|
|||
);
|
||||
|
||||
match prim {
|
||||
PrimDef::FunNpSize => create_fn_by_codegen(
|
||||
self.unifier,
|
||||
&into_var_map([in_ndarray_ty]),
|
||||
prim.name(),
|
||||
self.primitives.int32,
|
||||
&[(in_ndarray_ty.ty, "a")],
|
||||
Box::new(|ctx, obj, fun, args, generator| {
|
||||
assert!(obj.is_none());
|
||||
assert_eq!(args.len(), 1);
|
||||
|
||||
let ndarray_ty = fun.0.args[0].ty;
|
||||
let ndarray =
|
||||
args[0].1.clone().to_basic_value_enum(ctx, generator, ndarray_ty)?;
|
||||
let ndarray = AnyObject { ty: ndarray_ty, value: ndarray };
|
||||
let ndarray = NDArrayObject::from_object(generator, ctx, ndarray);
|
||||
|
||||
let size =
|
||||
ndarray.size(generator, ctx).truncate_or_bit_cast(generator, ctx, Int32);
|
||||
Ok(Some(size.value.as_basic_value_enum()))
|
||||
}),
|
||||
),
|
||||
PrimDef::FunNpShape | PrimDef::FunNpStrides => {
|
||||
// The function signatures of `np_shape` an `np_size` are the same.
|
||||
// Mixed together for convenience.
|
||||
|
|
|
@ -53,6 +53,7 @@ pub enum PrimDef {
|
|||
FunNpIdentity,
|
||||
|
||||
// NumPy ndarray property getters
|
||||
FunNpSize,
|
||||
FunNpShape,
|
||||
FunNpStrides,
|
||||
|
||||
|
@ -243,6 +244,7 @@ impl PrimDef {
|
|||
PrimDef::FunNpIdentity => fun("np_identity", None),
|
||||
|
||||
// NumPy NDArray property getters,
|
||||
PrimDef::FunNpSize => fun("np_size", None),
|
||||
PrimDef::FunNpShape => fun("np_shape", None),
|
||||
PrimDef::FunNpStrides => fun("np_strides", None),
|
||||
|
||||
|
|
|
@ -180,6 +180,7 @@ def patch(module):
|
|||
module.np_array = np.array
|
||||
|
||||
# NumPy NDArray property getters
|
||||
module.np_size = np.size
|
||||
module.np_shape = np.shape
|
||||
module.np_strides = lambda ndarray: ndarray.strides
|
||||
|
||||
|
|
Loading…
Reference in New Issue