core/ndstrides: implement np_size()

This commit is contained in:
lyken 2024-08-20 16:03:48 +08:00
parent 5bce12a333
commit 56b5764980
No known key found for this signature in database
GPG Key ID: 3BD5FC6AC8325DD8
3 changed files with 29 additions and 2 deletions

View File

@ -15,6 +15,7 @@ use crate::{
codegen::{ codegen::{
builtin_fns, builtin_fns,
classes::{ProxyValue, RangeValue}, classes::{ProxyValue, RangeValue},
model::*,
numpy::*, numpy::*,
object::{any::AnyObject, ndarray::NDArrayObject}, object::{any::AnyObject, ndarray::NDArrayObject},
stmt::exn_constructor, stmt::exn_constructor,
@ -512,7 +513,7 @@ impl<'a> BuiltinBuilder<'a> {
| PrimDef::FunNpEye | PrimDef::FunNpEye
| PrimDef::FunNpIdentity => self.build_ndarray_other_factory_function(prim), | PrimDef::FunNpIdentity => self.build_ndarray_other_factory_function(prim),
PrimDef::FunNpShape | PrimDef::FunNpStrides => { PrimDef::FunNpSize | PrimDef::FunNpShape | PrimDef::FunNpStrides => {
self.build_ndarray_property_getter_function(prim) self.build_ndarray_property_getter_function(prim)
} }
@ -1391,7 +1392,10 @@ impl<'a> BuiltinBuilder<'a> {
} }
fn build_ndarray_property_getter_function(&mut self, prim: PrimDef) -> TopLevelDef { fn build_ndarray_property_getter_function(&mut self, prim: PrimDef) -> TopLevelDef {
debug_assert_prim_is_allowed(prim, &[PrimDef::FunNpShape, PrimDef::FunNpStrides]); debug_assert_prim_is_allowed(
prim,
&[PrimDef::FunNpSize, PrimDef::FunNpShape, PrimDef::FunNpStrides],
);
let mut var_map = self.num_var_map.clone(); let mut var_map = self.num_var_map.clone();
var_map.insert(self.ndarray_dtype_tvar.id, self.ndarray_dtype_tvar.ty); var_map.insert(self.ndarray_dtype_tvar.id, self.ndarray_dtype_tvar.ty);
@ -1403,6 +1407,26 @@ impl<'a> BuiltinBuilder<'a> {
); );
match prim { match prim {
PrimDef::FunNpSize => create_fn_by_codegen(
self.unifier,
&VarMap::new(),
prim.name(),
self.primitives.int32,
&[(in_ndarray_ty.ty, "a")],
Box::new(|ctx, obj, fun, args, generator| {
assert!(obj.is_none());
assert_eq!(args.len(), 1);
let ndarray_ty = fun.0.args[0].ty;
let ndarray =
args[0].1.clone().to_basic_value_enum(ctx, generator, ndarray_ty)?;
let ndarray = AnyObject { ty: ndarray_ty, value: ndarray };
let ndarray = NDArrayObject::from_object(generator, ctx, ndarray);
let size = ndarray.size(generator, ctx).truncate(generator, ctx, Int32);
Ok(Some(size.value.as_basic_value_enum()))
}),
),
PrimDef::FunNpShape | PrimDef::FunNpStrides => { PrimDef::FunNpShape | PrimDef::FunNpStrides => {
// The fnuction signatures of `np_shape` an `np_size` are the same. // The fnuction signatures of `np_shape` an `np_size` are the same.

View File

@ -53,6 +53,7 @@ pub enum PrimDef {
FunNpIdentity, FunNpIdentity,
// NumPy ndarray property getters // NumPy ndarray property getters
FunNpSize,
FunNpShape, FunNpShape,
FunNpStrides, FunNpStrides,
@ -243,6 +244,7 @@ impl PrimDef {
PrimDef::FunNpIdentity => fun("np_identity", None), PrimDef::FunNpIdentity => fun("np_identity", None),
// NumPy NDArray property getters, // NumPy NDArray property getters,
PrimDef::FunNpSize => fun("np_size", None),
PrimDef::FunNpShape => fun("np_shape", None), PrimDef::FunNpShape => fun("np_shape", None),
PrimDef::FunNpStrides => fun("np_strides", None), PrimDef::FunNpStrides => fun("np_strides", None),

View File

@ -180,6 +180,7 @@ def patch(module):
module.np_array = np.array module.np_array = np.array
# NumPy NDArray property getters # NumPy NDArray property getters
module.np_size = np.size
module.np_shape = np.shape module.np_shape = np.shape
module.np_strides = lambda ndarray: ndarray.strides module.np_strides = lambda ndarray: ndarray.strides