core/ndstrides: implement np_shape() and np_strides()

These functions are not important, but they are handy for debugging +
implementing them takes little effort.

NOTE: `np.strides()` is not an actual NumPy function. You can only(?)
access them thru `ndarray.strides`.
This commit is contained in:
lyken 2024-08-20 15:29:10 +08:00
parent 9da0b825d1
commit 1d189125e7
No known key found for this signature in database
GPG Key ID: 3BD5FC6AC8325DD8
5 changed files with 166 additions and 4 deletions

View File

@ -26,7 +26,7 @@ use crate::{
typecheck::typedef::{Type, TypeEnum},
};
use super::any::AnyObject;
use super::{any::AnyObject, tuple::TupleObject};
/// Fields of [`NDArray`]
pub struct NDArrayFields<'ctx, F: FieldTraversal<'ctx>> {
@ -403,6 +403,62 @@ impl<'ctx> NDArrayObject<'ctx> {
})
.unwrap();
}
/// Create the shape tuple of this ndarray like `np.shape(<ndarray>)`.
///
/// The returned integers in the tuple are in int32.
pub fn make_shape_tuple<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
) -> TupleObject<'ctx> {
// TODO: Return a tuple of SizeT
let mut objects = Vec::with_capacity(self.ndims as usize);
for i in 0..self.ndims {
let dim = self
.instance
.get(generator, ctx, |f| f.shape)
.get_index_const(generator, ctx, i)
.truncate(generator, ctx, Int32);
objects.push(AnyObject {
ty: ctx.primitives.int32,
value: dim.value.as_basic_value_enum(),
});
}
TupleObject::from_objects(generator, ctx, objects)
}
/// Create the strides tuple of this ndarray like `np.strides(<ndarray>)`.
///
/// The returned integers in the tuple are in int32.
pub fn make_strides_tuple<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
) -> TupleObject<'ctx> {
// TODO: Return a tuple of SizeT.
let mut objects = Vec::with_capacity(self.ndims as usize);
for i in 0..self.ndims {
let dim = self
.instance
.get(generator, ctx, |f| f.strides)
.get_index_const(generator, ctx, i)
.truncate(generator, ctx, Int32);
objects.push(AnyObject {
ty: ctx.primitives.int32,
value: dim.value.as_basic_value_enum(),
});
}
TupleObject::from_objects(generator, ctx, objects)
}
}
/// A convenience enum for implementing functions that acts on scalars or ndarrays or both.

View File

@ -16,6 +16,7 @@ use crate::{
builtin_fns,
classes::{ProxyValue, RangeValue},
numpy::*,
object::{any::AnyObject, ndarray::NDArrayObject},
stmt::exn_constructor,
},
symbol_resolver::SymbolValue,
@ -511,6 +512,10 @@ impl<'a> BuiltinBuilder<'a> {
| PrimDef::FunNpEye
| PrimDef::FunNpIdentity => self.build_ndarray_other_factory_function(prim),
PrimDef::FunNpShape | PrimDef::FunNpStrides => {
self.build_ndarray_property_getter_function(prim)
}
PrimDef::FunStr => self.build_str_function(),
PrimDef::FunFloor | PrimDef::FunFloor64 | PrimDef::FunCeil | PrimDef::FunCeil64 => {
@ -1385,6 +1390,56 @@ impl<'a> BuiltinBuilder<'a> {
}
}
fn build_ndarray_property_getter_function(&mut self, prim: PrimDef) -> TopLevelDef {
debug_assert_prim_is_allowed(prim, &[PrimDef::FunNpShape, PrimDef::FunNpStrides]);
let mut var_map = self.num_var_map.clone();
var_map.insert(self.ndarray_dtype_tvar.id, self.ndarray_dtype_tvar.ty);
let in_ndarray_ty = self.unifier.get_fresh_var_with_range(
&[self.primitives.ndarray],
Some("T".into()),
None,
);
match prim {
PrimDef::FunNpShape | PrimDef::FunNpStrides => {
// The fnuction signatures of `np_shape` an `np_size` are the same.
// The return type is a tuple of variable length depending on the ndims of the input ndarray.
let ret_ty = self.unifier.get_dummy_var().ty; // Handled by special folding
create_fn_by_codegen(
self.unifier,
&VarMap::new(),
prim.name(),
ret_ty,
&[(in_ndarray_ty.ty, "a")],
Box::new(move |ctx, obj, fun, args, generator| {
assert!(obj.is_none());
assert_eq!(args.len(), 1);
let ndarray_ty = fun.0.args[0].ty;
let ndarray =
args[0].1.clone().to_basic_value_enum(ctx, generator, ndarray_ty)?;
let ndarray = AnyObject { ty: ndarray_ty, value: ndarray };
let ndarray = NDArrayObject::from_object(generator, ctx, ndarray);
let result_tuple = match prim {
PrimDef::FunNpShape => ndarray.make_shape_tuple(generator, ctx),
PrimDef::FunNpStrides => ndarray.make_strides_tuple(generator, ctx),
_ => unreachable!(),
};
Ok(Some(result_tuple.value.as_basic_value_enum()))
}),
)
}
_ => unreachable!(),
}
}
/// Build the `str()` function.
fn build_str_function(&mut self) -> TopLevelDef {
let prim = PrimDef::FunStr;
@ -1887,8 +1942,8 @@ impl<'a> BuiltinBuilder<'a> {
self.unifier,
&into_var_map([ndarray_ty]),
prim.name(),
ndarray_ty.ty,
&[(ndarray_ty.ty, "x")],
self.ndarray_num_ty,
&[(self.ndarray_num_ty, "x")],
Box::new(move |ctx, _, fun, args, generator| {
let arg_ty = fun.0.args[0].ty;
let arg_val =

View File

@ -52,6 +52,10 @@ pub enum PrimDef {
FunNpEye,
FunNpIdentity,
// NumPy ndarray property getters
FunNpShape,
FunNpStrides,
// Miscellaneous NumPy & SciPy functions
FunNpRound,
FunNpFloor,
@ -238,6 +242,10 @@ impl PrimDef {
PrimDef::FunNpEye => fun("np_eye", None),
PrimDef::FunNpIdentity => fun("np_identity", None),
// NumPy NDArray property getters,
PrimDef::FunNpShape => fun("np_shape", None),
PrimDef::FunNpStrides => fun("np_strides", None),
// Miscellaneous NumPy & SciPy functions
PrimDef::FunNpRound => fun("np_round", None),
PrimDef::FunNpFloor => fun("np_floor", None),

View File

@ -1,7 +1,7 @@
use std::cmp::max;
use std::collections::{HashMap, HashSet};
use std::convert::{From, TryInto};
use std::iter::once;
use std::iter::{self, once};
use std::{cell::RefCell, sync::Arc};
use super::{
@ -1181,6 +1181,45 @@ impl<'a> Inferencer<'a> {
}));
}
if ["np_shape".into(), "np_strides".into()].contains(id) && args.len() == 1 {
let ndarray = self.fold_expr(args.remove(0))?;
let ndims = arraylike_get_ndims(self.unifier, ndarray.custom.unwrap());
// Make a tuple of size `ndims` full of int32 (TODO: Make it usize)
let ret_ty = TypeEnum::TTuple {
ty: iter::repeat(self.primitives.int32).take(ndims as usize).collect_vec(),
is_vararg_ctx: false,
};
let ret_ty = self.unifier.add_ty(ret_ty);
let func_ty = TypeEnum::TFunc(FunSignature {
args: vec![FuncArg {
name: "a".into(),
default_value: None,
ty: ndarray.custom.unwrap(),
is_vararg: false,
}],
ret: ret_ty,
vars: VarMap::new(),
});
let func_ty = self.unifier.add_ty(func_ty);
return Ok(Some(Located {
location,
custom: Some(ret_ty),
node: ExprKind::Call {
func: Box::new(Located {
custom: Some(func_ty),
location: func.location,
node: ExprKind::Name { id: *id, ctx: *ctx },
}),
args: vec![ndarray],
keywords: vec![],
},
}));
}
if id == &"np_dot".into() {
let arg0 = self.fold_expr(args.remove(0))?;
let arg1 = self.fold_expr(args.remove(0))?;

View File

@ -179,6 +179,10 @@ def patch(module):
module.np_identity = np.identity
module.np_array = np.array
# NumPy NDArray property getters
module.np_shape = np.shape
module.np_strides = lambda ndarray: ndarray.strides
# NumPy Math functions
module.np_isnan = np.isnan
module.np_isinf = np.isinf