From be19165eadce5b0da1f0d3b26214b261500a4664 Mon Sep 17 00:00:00 2001 From: lyken Date: Tue, 20 Aug 2024 15:29:10 +0800 Subject: [PATCH] core/ndstrides: implement np_shape() and np_strides() These functions are not important, but they are handy for debugging. `np.strides()` is not an actual NumPy function, but `ndarray.strides` is used. --- nac3core/src/codegen/object/ndarray/mod.rs | 58 ++++++++++++++++++- nac3core/src/toplevel/builtins.rs | 57 +++++++++++++++++- nac3core/src/toplevel/helper.rs | 8 +++ nac3core/src/typecheck/type_inferencer/mod.rs | 41 ++++++++++++- nac3standalone/demo/interpret_demo.py | 4 ++ 5 files changed, 164 insertions(+), 4 deletions(-) diff --git a/nac3core/src/codegen/object/ndarray/mod.rs b/nac3core/src/codegen/object/ndarray/mod.rs index 48677864..9692f227 100644 --- a/nac3core/src/codegen/object/ndarray/mod.rs +++ b/nac3core/src/codegen/object/ndarray/mod.rs @@ -27,7 +27,7 @@ use crate::{ typecheck::typedef::Type, }; -use super::any::AnyObject; +use super::{any::AnyObject, tuple::TupleObject}; /// Fields of [`NDArray`] pub struct NDArrayFields<'ctx, F: FieldTraversal<'ctx>> { @@ -427,6 +427,62 @@ impl<'ctx> NDArrayObject<'ctx> { }) .unwrap(); } + + /// Create the shape tuple of this ndarray like `np.shape()`. + /// + /// The returned integers in the tuple are in int32. + pub fn make_shape_tuple( + &self, + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + ) -> TupleObject<'ctx> { + // TODO: Return a tuple of SizeT + + let mut objects = Vec::with_capacity(self.ndims as usize); + + for i in 0..self.ndims { + let dim = self + .instance + .get(generator, ctx, |f| f.shape) + .get_index_const(generator, ctx, i64::try_from(i).unwrap()) + .truncate_or_bit_cast(generator, ctx, Int32); + + objects.push(AnyObject { + ty: ctx.primitives.int32, + value: dim.value.as_basic_value_enum(), + }); + } + + TupleObject::from_objects(generator, ctx, objects) + } + + /// Create the strides tuple of this ndarray like `.strides`. + /// + /// The returned integers in the tuple are in int32. + pub fn make_strides_tuple( + &self, + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + ) -> TupleObject<'ctx> { + // TODO: Return a tuple of SizeT. + + let mut objects = Vec::with_capacity(self.ndims as usize); + + for i in 0..self.ndims { + let dim = self + .instance + .get(generator, ctx, |f| f.strides) + .get_index_const(generator, ctx, i64::try_from(i).unwrap()) + .truncate_or_bit_cast(generator, ctx, Int32); + + objects.push(AnyObject { + ty: ctx.primitives.int32, + value: dim.value.as_basic_value_enum(), + }); + } + + TupleObject::from_objects(generator, ctx, objects) + } } /// A convenience enum for implementing functions that acts on scalars or ndarrays or both. diff --git a/nac3core/src/toplevel/builtins.rs b/nac3core/src/toplevel/builtins.rs index e2325ebb..9f6a2b9c 100644 --- a/nac3core/src/toplevel/builtins.rs +++ b/nac3core/src/toplevel/builtins.rs @@ -16,6 +16,7 @@ use crate::{ builtin_fns, classes::{ProxyValue, RangeValue}, numpy::*, + object::{any::AnyObject, ndarray::NDArrayObject}, stmt::exn_constructor, }, symbol_resolver::SymbolValue, @@ -511,6 +512,10 @@ impl<'a> BuiltinBuilder<'a> { | PrimDef::FunNpEye | PrimDef::FunNpIdentity => self.build_ndarray_other_factory_function(prim), + PrimDef::FunNpShape | PrimDef::FunNpStrides => { + self.build_ndarray_property_getter_function(prim) + } + PrimDef::FunStr => self.build_str_function(), PrimDef::FunFloor | PrimDef::FunFloor64 | PrimDef::FunCeil | PrimDef::FunCeil64 => { @@ -1385,6 +1390,54 @@ impl<'a> BuiltinBuilder<'a> { } } + fn build_ndarray_property_getter_function(&mut self, prim: PrimDef) -> TopLevelDef { + debug_assert_prim_is_allowed(prim, &[PrimDef::FunNpShape, PrimDef::FunNpStrides]); + + let in_ndarray_ty = self.unifier.get_fresh_var_with_range( + &[self.primitives.ndarray], + Some("T".into()), + None, + ); + + match prim { + PrimDef::FunNpShape | PrimDef::FunNpStrides => { + // The function signatures of `np_shape` an `np_size` are the same. + // Mixed together for convenience. + + // The return type is a tuple of variable length depending on the ndims of the input ndarray. + let ret_ty = self.unifier.get_dummy_var().ty; // Handled by special folding + + create_fn_by_codegen( + self.unifier, + &into_var_map([in_ndarray_ty]), + prim.name(), + ret_ty, + &[(in_ndarray_ty.ty, "a")], + Box::new(move |ctx, obj, fun, args, generator| { + assert!(obj.is_none()); + assert_eq!(args.len(), 1); + + let ndarray_ty = fun.0.args[0].ty; + let ndarray = + args[0].1.clone().to_basic_value_enum(ctx, generator, ndarray_ty)?; + + let ndarray = AnyObject { ty: ndarray_ty, value: ndarray }; + let ndarray = NDArrayObject::from_object(generator, ctx, ndarray); + + let result_tuple = match prim { + PrimDef::FunNpShape => ndarray.make_shape_tuple(generator, ctx), + PrimDef::FunNpStrides => ndarray.make_strides_tuple(generator, ctx), + _ => unreachable!(), + }; + + Ok(Some(result_tuple.value.as_basic_value_enum())) + }), + ) + } + _ => unreachable!(), + } + } + /// Build the `str()` function. fn build_str_function(&mut self) -> TopLevelDef { let prim = PrimDef::FunStr; @@ -1887,8 +1940,8 @@ impl<'a> BuiltinBuilder<'a> { self.unifier, &into_var_map([ndarray_ty]), prim.name(), - ndarray_ty.ty, - &[(ndarray_ty.ty, "x")], + self.ndarray_num_ty, + &[(self.ndarray_num_ty, "x")], Box::new(move |ctx, _, fun, args, generator| { let arg_ty = fun.0.args[0].ty; let arg_val = diff --git a/nac3core/src/toplevel/helper.rs b/nac3core/src/toplevel/helper.rs index 690f06c3..477027ca 100644 --- a/nac3core/src/toplevel/helper.rs +++ b/nac3core/src/toplevel/helper.rs @@ -53,6 +53,10 @@ pub enum PrimDef { FunNpEye, FunNpIdentity, + // NumPy ndarray property getters + FunNpShape, + FunNpStrides, + // Miscellaneous NumPy & SciPy functions FunNpRound, FunNpFloor, @@ -239,6 +243,10 @@ impl PrimDef { PrimDef::FunNpEye => fun("np_eye", None), PrimDef::FunNpIdentity => fun("np_identity", None), + // NumPy NDArray property getters, + PrimDef::FunNpShape => fun("np_shape", None), + PrimDef::FunNpStrides => fun("np_strides", None), + // Miscellaneous NumPy & SciPy functions PrimDef::FunNpRound => fun("np_round", None), PrimDef::FunNpFloor => fun("np_floor", None), diff --git a/nac3core/src/typecheck/type_inferencer/mod.rs b/nac3core/src/typecheck/type_inferencer/mod.rs index a5b8cd49..6bdf48df 100644 --- a/nac3core/src/typecheck/type_inferencer/mod.rs +++ b/nac3core/src/typecheck/type_inferencer/mod.rs @@ -1,7 +1,7 @@ use std::cmp::max; use std::collections::{HashMap, HashSet}; use std::convert::{From, TryInto}; -use std::iter::once; +use std::iter::{self, once}; use std::{cell::RefCell, sync::Arc}; use super::{ @@ -1183,6 +1183,45 @@ impl<'a> Inferencer<'a> { })); } + if ["np_shape".into(), "np_strides".into()].contains(id) && args.len() == 1 { + let ndarray = self.fold_expr(args.remove(0))?; + + let ndims = arraylike_get_ndims(self.unifier, ndarray.custom.unwrap()); + + // Make a tuple of size `ndims` full of int32 (TODO: Make it usize) + let ret_ty = TypeEnum::TTuple { + ty: iter::repeat(self.primitives.int32).take(ndims as usize).collect_vec(), + is_vararg_ctx: false, + }; + let ret_ty = self.unifier.add_ty(ret_ty); + + let func_ty = TypeEnum::TFunc(FunSignature { + args: vec![FuncArg { + name: "a".into(), + default_value: None, + ty: ndarray.custom.unwrap(), + is_vararg: false, + }], + ret: ret_ty, + vars: VarMap::new(), + }); + let func_ty = self.unifier.add_ty(func_ty); + + return Ok(Some(Located { + location, + custom: Some(ret_ty), + node: ExprKind::Call { + func: Box::new(Located { + custom: Some(func_ty), + location: func.location, + node: ExprKind::Name { id: *id, ctx: *ctx }, + }), + args: vec![ndarray], + keywords: vec![], + }, + })); + } + if id == &"np_dot".into() { let arg0 = self.fold_expr(args.remove(0))?; let arg1 = self.fold_expr(args.remove(0))?; diff --git a/nac3standalone/demo/interpret_demo.py b/nac3standalone/demo/interpret_demo.py index 4f19db95..5bcf4bb5 100755 --- a/nac3standalone/demo/interpret_demo.py +++ b/nac3standalone/demo/interpret_demo.py @@ -179,6 +179,10 @@ def patch(module): module.np_identity = np.identity module.np_array = np.array + # NumPy NDArray property getters + module.np_shape = np.shape + module.np_strides = lambda ndarray: ndarray.strides + # NumPy Math functions module.np_isnan = np.isnan module.np_isinf = np.isinf