forked from M-Labs/nac3
core/ndstrides: implement np_size()
This commit is contained in:
parent
be6d704020
commit
31931b7b26
|
@ -16,6 +16,7 @@ use crate::{
|
||||||
codegen::{
|
codegen::{
|
||||||
builtin_fns,
|
builtin_fns,
|
||||||
classes::{ProxyValue, RangeValue},
|
classes::{ProxyValue, RangeValue},
|
||||||
|
model::*,
|
||||||
numpy::*,
|
numpy::*,
|
||||||
object::{
|
object::{
|
||||||
any::AnyObject,
|
any::AnyObject,
|
||||||
|
@ -516,7 +517,7 @@ impl<'a> BuiltinBuilder<'a> {
|
||||||
| PrimDef::FunNpEye
|
| PrimDef::FunNpEye
|
||||||
| PrimDef::FunNpIdentity => self.build_ndarray_other_factory_function(prim),
|
| PrimDef::FunNpIdentity => self.build_ndarray_other_factory_function(prim),
|
||||||
|
|
||||||
PrimDef::FunNpShape | PrimDef::FunNpStrides => {
|
PrimDef::FunNpSize | PrimDef::FunNpShape | PrimDef::FunNpStrides => {
|
||||||
self.build_ndarray_property_getter_function(prim)
|
self.build_ndarray_property_getter_function(prim)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1395,7 +1396,10 @@ impl<'a> BuiltinBuilder<'a> {
|
||||||
}
|
}
|
||||||
|
|
||||||
fn build_ndarray_property_getter_function(&mut self, prim: PrimDef) -> TopLevelDef {
|
fn build_ndarray_property_getter_function(&mut self, prim: PrimDef) -> TopLevelDef {
|
||||||
debug_assert_prim_is_allowed(prim, &[PrimDef::FunNpShape, PrimDef::FunNpStrides]);
|
debug_assert_prim_is_allowed(
|
||||||
|
prim,
|
||||||
|
&[PrimDef::FunNpSize, PrimDef::FunNpShape, PrimDef::FunNpStrides],
|
||||||
|
);
|
||||||
|
|
||||||
let mut var_map = self.num_var_map.clone();
|
let mut var_map = self.num_var_map.clone();
|
||||||
var_map.insert(self.ndarray_dtype_tvar.id, self.ndarray_dtype_tvar.ty);
|
var_map.insert(self.ndarray_dtype_tvar.id, self.ndarray_dtype_tvar.ty);
|
||||||
|
@ -1407,6 +1411,26 @@ impl<'a> BuiltinBuilder<'a> {
|
||||||
);
|
);
|
||||||
|
|
||||||
match prim {
|
match prim {
|
||||||
|
PrimDef::FunNpSize => create_fn_by_codegen(
|
||||||
|
self.unifier,
|
||||||
|
&VarMap::new(),
|
||||||
|
prim.name(),
|
||||||
|
self.primitives.int32,
|
||||||
|
&[(in_ndarray_ty.ty, "a")],
|
||||||
|
Box::new(|ctx, obj, fun, args, generator| {
|
||||||
|
assert!(obj.is_none());
|
||||||
|
assert_eq!(args.len(), 1);
|
||||||
|
|
||||||
|
let ndarray_ty = fun.0.args[0].ty;
|
||||||
|
let ndarray =
|
||||||
|
args[0].1.clone().to_basic_value_enum(ctx, generator, ndarray_ty)?;
|
||||||
|
let ndarray = AnyObject { ty: ndarray_ty, value: ndarray };
|
||||||
|
let ndarray = NDArrayObject::from_object(generator, ctx, ndarray);
|
||||||
|
|
||||||
|
let size = ndarray.size(generator, ctx).truncate(generator, ctx, Int32);
|
||||||
|
Ok(Some(size.value.as_basic_value_enum()))
|
||||||
|
}),
|
||||||
|
),
|
||||||
PrimDef::FunNpShape | PrimDef::FunNpStrides => {
|
PrimDef::FunNpShape | PrimDef::FunNpStrides => {
|
||||||
// The fnuction signatures of `np_shape` an `np_size` are the same.
|
// The fnuction signatures of `np_shape` an `np_size` are the same.
|
||||||
|
|
||||||
|
|
|
@ -53,6 +53,7 @@ pub enum PrimDef {
|
||||||
FunNpIdentity,
|
FunNpIdentity,
|
||||||
|
|
||||||
// NumPy ndarray property getters
|
// NumPy ndarray property getters
|
||||||
|
FunNpSize,
|
||||||
FunNpShape,
|
FunNpShape,
|
||||||
FunNpStrides,
|
FunNpStrides,
|
||||||
|
|
||||||
|
@ -246,6 +247,7 @@ impl PrimDef {
|
||||||
PrimDef::FunNpIdentity => fun("np_identity", None),
|
PrimDef::FunNpIdentity => fun("np_identity", None),
|
||||||
|
|
||||||
// NumPy NDArray property getters,
|
// NumPy NDArray property getters,
|
||||||
|
PrimDef::FunNpSize => fun("np_size", None),
|
||||||
PrimDef::FunNpShape => fun("np_shape", None),
|
PrimDef::FunNpShape => fun("np_shape", None),
|
||||||
PrimDef::FunNpStrides => fun("np_strides", None),
|
PrimDef::FunNpStrides => fun("np_strides", None),
|
||||||
|
|
||||||
|
|
|
@ -185,6 +185,7 @@ def patch(module):
|
||||||
module.np_reshape = np.reshape
|
module.np_reshape = np.reshape
|
||||||
|
|
||||||
# NumPy NDArray property getters
|
# NumPy NDArray property getters
|
||||||
|
module.np_size = np.size
|
||||||
module.np_shape = np.shape
|
module.np_shape = np.shape
|
||||||
module.np_strides = lambda ndarray: ndarray.strides
|
module.np_strides = lambda ndarray: ndarray.strides
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue