diff --git a/nac3core/src/toplevel/builtins.rs b/nac3core/src/toplevel/builtins.rs index 1c076dce..371bcc64 100644 --- a/nac3core/src/toplevel/builtins.rs +++ b/nac3core/src/toplevel/builtins.rs @@ -16,6 +16,7 @@ use crate::{ codegen::{ builtin_fns, classes::{ProxyValue, RangeValue}, + model::*, numpy::*, object::{ any::AnyObject, @@ -516,7 +517,7 @@ impl<'a> BuiltinBuilder<'a> { | PrimDef::FunNpEye | PrimDef::FunNpIdentity => self.build_ndarray_other_factory_function(prim), - PrimDef::FunNpShape | PrimDef::FunNpStrides => { + PrimDef::FunNpSize | PrimDef::FunNpShape | PrimDef::FunNpStrides => { self.build_ndarray_property_getter_function(prim) } @@ -1395,7 +1396,10 @@ impl<'a> BuiltinBuilder<'a> { } fn build_ndarray_property_getter_function(&mut self, prim: PrimDef) -> TopLevelDef { - debug_assert_prim_is_allowed(prim, &[PrimDef::FunNpShape, PrimDef::FunNpStrides]); + debug_assert_prim_is_allowed( + prim, + &[PrimDef::FunNpSize, PrimDef::FunNpShape, PrimDef::FunNpStrides], + ); let mut var_map = self.num_var_map.clone(); var_map.insert(self.ndarray_dtype_tvar.id, self.ndarray_dtype_tvar.ty); @@ -1407,6 +1411,26 @@ impl<'a> BuiltinBuilder<'a> { ); match prim { + PrimDef::FunNpSize => create_fn_by_codegen( + self.unifier, + &VarMap::new(), + prim.name(), + self.primitives.int32, + &[(in_ndarray_ty.ty, "a")], + Box::new(|ctx, obj, fun, args, generator| { + assert!(obj.is_none()); + assert_eq!(args.len(), 1); + + let ndarray_ty = fun.0.args[0].ty; + let ndarray = + args[0].1.clone().to_basic_value_enum(ctx, generator, ndarray_ty)?; + let ndarray = AnyObject { ty: ndarray_ty, value: ndarray }; + let ndarray = NDArrayObject::from_object(generator, ctx, ndarray); + + let size = ndarray.size(generator, ctx).truncate(generator, ctx, Int32); + Ok(Some(size.value.as_basic_value_enum())) + }), + ), PrimDef::FunNpShape | PrimDef::FunNpStrides => { // The fnuction signatures of `np_shape` an `np_size` are the same. diff --git a/nac3core/src/toplevel/helper.rs b/nac3core/src/toplevel/helper.rs index 95ffdf7c..2533489a 100644 --- a/nac3core/src/toplevel/helper.rs +++ b/nac3core/src/toplevel/helper.rs @@ -53,6 +53,7 @@ pub enum PrimDef { FunNpIdentity, // NumPy ndarray property getters + FunNpSize, FunNpShape, FunNpStrides, @@ -246,6 +247,7 @@ impl PrimDef { PrimDef::FunNpIdentity => fun("np_identity", None), // NumPy NDArray property getters, + PrimDef::FunNpSize => fun("np_size", None), PrimDef::FunNpShape => fun("np_shape", None), PrimDef::FunNpStrides => fun("np_strides", None), diff --git a/nac3standalone/demo/interpret_demo.py b/nac3standalone/demo/interpret_demo.py index f9ad5630..8784ce53 100755 --- a/nac3standalone/demo/interpret_demo.py +++ b/nac3standalone/demo/interpret_demo.py @@ -185,6 +185,7 @@ def patch(module): module.np_reshape = np.reshape # NumPy NDArray property getters + module.np_size = np.size module.np_shape = np.shape module.np_strides = lambda ndarray: ndarray.strides