From d2072d9248f2aba3620a03c124087958f1a670fd Mon Sep 17 00:00:00 2001 From: lyken Date: Tue, 20 Aug 2024 16:03:48 +0800 Subject: [PATCH] core/ndstrides: implement np_size() --- nac3core/src/toplevel/builtins.rs | 29 +++++++++++++++++++++++++-- nac3core/src/toplevel/helper.rs | 2 ++ nac3standalone/demo/interpret_demo.py | 1 + 3 files changed, 30 insertions(+), 2 deletions(-) diff --git a/nac3core/src/toplevel/builtins.rs b/nac3core/src/toplevel/builtins.rs index 9f6a2b9c..0453da90 100644 --- a/nac3core/src/toplevel/builtins.rs +++ b/nac3core/src/toplevel/builtins.rs @@ -15,6 +15,7 @@ use crate::{ codegen::{ builtin_fns, classes::{ProxyValue, RangeValue}, + model::*, numpy::*, object::{any::AnyObject, ndarray::NDArrayObject}, stmt::exn_constructor, @@ -512,7 +513,7 @@ impl<'a> BuiltinBuilder<'a> { | PrimDef::FunNpEye | PrimDef::FunNpIdentity => self.build_ndarray_other_factory_function(prim), - PrimDef::FunNpShape | PrimDef::FunNpStrides => { + PrimDef::FunNpSize | PrimDef::FunNpShape | PrimDef::FunNpStrides => { self.build_ndarray_property_getter_function(prim) } @@ -1391,7 +1392,10 @@ impl<'a> BuiltinBuilder<'a> { } fn build_ndarray_property_getter_function(&mut self, prim: PrimDef) -> TopLevelDef { - debug_assert_prim_is_allowed(prim, &[PrimDef::FunNpShape, PrimDef::FunNpStrides]); + debug_assert_prim_is_allowed( + prim, + &[PrimDef::FunNpSize, PrimDef::FunNpShape, PrimDef::FunNpStrides], + ); let in_ndarray_ty = self.unifier.get_fresh_var_with_range( &[self.primitives.ndarray], @@ -1400,6 +1404,27 @@ impl<'a> BuiltinBuilder<'a> { ); match prim { + PrimDef::FunNpSize => create_fn_by_codegen( + self.unifier, + &into_var_map([in_ndarray_ty]), + prim.name(), + self.primitives.int32, + &[(in_ndarray_ty.ty, "a")], + Box::new(|ctx, obj, fun, args, generator| { + assert!(obj.is_none()); + assert_eq!(args.len(), 1); + + let ndarray_ty = fun.0.args[0].ty; + let ndarray = + args[0].1.clone().to_basic_value_enum(ctx, generator, ndarray_ty)?; + let ndarray = AnyObject { ty: ndarray_ty, value: ndarray }; + let ndarray = NDArrayObject::from_object(generator, ctx, ndarray); + + let size = + ndarray.size(generator, ctx).truncate_or_bit_cast(generator, ctx, Int32); + Ok(Some(size.value.as_basic_value_enum())) + }), + ), PrimDef::FunNpShape | PrimDef::FunNpStrides => { // The function signatures of `np_shape` an `np_size` are the same. // Mixed together for convenience. diff --git a/nac3core/src/toplevel/helper.rs b/nac3core/src/toplevel/helper.rs index 477027ca..039e4570 100644 --- a/nac3core/src/toplevel/helper.rs +++ b/nac3core/src/toplevel/helper.rs @@ -54,6 +54,7 @@ pub enum PrimDef { FunNpIdentity, // NumPy ndarray property getters + FunNpSize, FunNpShape, FunNpStrides, @@ -244,6 +245,7 @@ impl PrimDef { PrimDef::FunNpIdentity => fun("np_identity", None), // NumPy NDArray property getters, + PrimDef::FunNpSize => fun("np_size", None), PrimDef::FunNpShape => fun("np_shape", None), PrimDef::FunNpStrides => fun("np_strides", None), diff --git a/nac3standalone/demo/interpret_demo.py b/nac3standalone/demo/interpret_demo.py index 5bcf4bb5..ca17c3da 100755 --- a/nac3standalone/demo/interpret_demo.py +++ b/nac3standalone/demo/interpret_demo.py @@ -180,6 +180,7 @@ def patch(module): module.np_array = np.array # NumPy NDArray property getters + module.np_size = np.size module.np_shape = np.shape module.np_strides = lambda ndarray: ndarray.strides