diff --git a/nac3core/src/toplevel/builtins.rs b/nac3core/src/toplevel/builtins.rs index c271ef9c..97c81e48 100644 --- a/nac3core/src/toplevel/builtins.rs +++ b/nac3core/src/toplevel/builtins.rs @@ -15,6 +15,7 @@ use crate::{ codegen::{ builtin_fns, classes::{ProxyValue, RangeValue}, + model::*, numpy::*, object::{any::AnyObject, ndarray::NDArrayObject}, stmt::exn_constructor, @@ -512,7 +513,7 @@ impl<'a> BuiltinBuilder<'a> { | PrimDef::FunNpEye | PrimDef::FunNpIdentity => self.build_ndarray_other_factory_function(prim), - PrimDef::FunNpShape | PrimDef::FunNpStrides => { + PrimDef::FunNpSize | PrimDef::FunNpShape | PrimDef::FunNpStrides => { self.build_ndarray_property_getter_function(prim) } @@ -1391,7 +1392,10 @@ impl<'a> BuiltinBuilder<'a> { } fn build_ndarray_property_getter_function(&mut self, prim: PrimDef) -> TopLevelDef { - debug_assert_prim_is_allowed(prim, &[PrimDef::FunNpShape, PrimDef::FunNpStrides]); + debug_assert_prim_is_allowed( + prim, + &[PrimDef::FunNpSize, PrimDef::FunNpShape, PrimDef::FunNpStrides], + ); let mut var_map = self.num_var_map.clone(); var_map.insert(self.ndarray_dtype_tvar.id, self.ndarray_dtype_tvar.ty); @@ -1403,6 +1407,26 @@ impl<'a> BuiltinBuilder<'a> { ); match prim { + PrimDef::FunNpSize => create_fn_by_codegen( + self.unifier, + &VarMap::new(), + prim.name(), + self.primitives.int32, + &[(in_ndarray_ty.ty, "a")], + Box::new(|ctx, obj, fun, args, generator| { + assert!(obj.is_none()); + assert_eq!(args.len(), 1); + + let ndarray_ty = fun.0.args[0].ty; + let ndarray = + args[0].1.clone().to_basic_value_enum(ctx, generator, ndarray_ty)?; + let ndarray = AnyObject { ty: ndarray_ty, value: ndarray }; + let ndarray = NDArrayObject::from_object(generator, ctx, ndarray); + + let size = ndarray.size(generator, ctx).truncate(generator, ctx, Int32); + Ok(Some(size.value.as_basic_value_enum())) + }), + ), PrimDef::FunNpShape | PrimDef::FunNpStrides => { // The fnuction signatures of `np_shape` an `np_size` are the same. diff --git a/nac3core/src/toplevel/helper.rs b/nac3core/src/toplevel/helper.rs index c4225811..275b80d8 100644 --- a/nac3core/src/toplevel/helper.rs +++ b/nac3core/src/toplevel/helper.rs @@ -53,6 +53,7 @@ pub enum PrimDef { FunNpIdentity, // NumPy ndarray property getters + FunNpSize, FunNpShape, FunNpStrides, @@ -243,6 +244,7 @@ impl PrimDef { PrimDef::FunNpIdentity => fun("np_identity", None), // NumPy NDArray property getters, + PrimDef::FunNpSize => fun("np_size", None), PrimDef::FunNpShape => fun("np_shape", None), PrimDef::FunNpStrides => fun("np_strides", None), diff --git a/nac3standalone/demo/interpret_demo.py b/nac3standalone/demo/interpret_demo.py index 5bcf4bb5..ca17c3da 100755 --- a/nac3standalone/demo/interpret_demo.py +++ b/nac3standalone/demo/interpret_demo.py @@ -180,6 +180,7 @@ def patch(module): module.np_array = np.array # NumPy NDArray property getters + module.np_size = np.size module.np_shape = np.shape module.np_strides = lambda ndarray: ndarray.strides