From 694265ed6d2c2dcde065d1eb491084036da3eac3 Mon Sep 17 00:00:00 2001 From: David Mak Date: Wed, 18 Dec 2024 10:28:56 +0800 Subject: [PATCH] [core] codegen/ndarray: Implement np_{shape,strides} Based on 40c24486: core/ndstrides: implement np_shape() and np_strides() These functions are not important, but they are handy for debugging. `np.strides()` is not an actual NumPy function, but `ndarray.strides` is used. --- nac3core/src/codegen/values/ndarray/mod.rs | 81 ++++++++++++++++++- nac3core/src/toplevel/builtins.rs | 53 ++++++++++++ nac3core/src/toplevel/helper.rs | 8 ++ ...el__test__test_analyze__generic_class.snap | 2 +- ...t__test_analyze__inheritance_override.snap | 2 +- ...est__test_analyze__list_tuple_generic.snap | 4 +- ...__toplevel__test__test_analyze__self1.snap | 2 +- ...t__test_analyze__simple_class_compose.snap | 4 +- nac3core/src/typecheck/type_inferencer/mod.rs | 41 +++++++++- nac3standalone/demo/interpret_demo.py | 4 + 10 files changed, 189 insertions(+), 12 deletions(-) diff --git a/nac3core/src/codegen/values/ndarray/mod.rs b/nac3core/src/codegen/values/ndarray/mod.rs index 2907445e..b46b3d8e 100644 --- a/nac3core/src/codegen/values/ndarray/mod.rs +++ b/nac3core/src/codegen/values/ndarray/mod.rs @@ -1,19 +1,22 @@ +use std::iter::repeat_n; + use inkwell::{ types::{AnyType, AnyTypeEnum, BasicType, BasicTypeEnum, IntType}, - values::{BasicValueEnum, IntValue, PointerValue}, + values::{BasicValue, BasicValueEnum, IntValue, PointerValue}, AddressSpace, IntPredicate, }; +use itertools::Itertools; use super::{ - ArrayLikeIndexer, ArrayLikeValue, ProxyValue, TypedArrayLikeAccessor, TypedArrayLikeAdapter, - TypedArrayLikeMutator, UntypedArrayLikeAccessor, UntypedArrayLikeMutator, + ArrayLikeIndexer, ArrayLikeValue, ProxyValue, TupleValue, TypedArrayLikeAccessor, TypedArrayLikeAdapter, TypedArrayLikeMutator, + UntypedArrayLikeAccessor, UntypedArrayLikeMutator, }; use crate::codegen::{ irrt, llvm_intrinsics::{call_int_umin, call_memcpy_generic_array}, stmt::gen_for_callback_incrementing, type_aligned_alloca, - types::{ndarray::NDArrayType, structure::StructField}, + types::{ndarray::NDArrayType, structure::StructField, TupleType}, CodeGenContext, CodeGenerator, }; pub use contiguous::*; @@ -417,6 +420,76 @@ impl<'ctx> NDArrayValue<'ctx> { .unwrap(); } + /// Create the shape tuple of this ndarray like `np.shape()`. + /// + /// The returned integers in the tuple are in int32. + pub fn make_shape_tuple( + &self, + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + ) -> TupleValue<'ctx> { + assert!(self.ndims.is_some(), "NDArrayValue::make_shape_tuple can only be called on an instance with compile-time known ndims (self.ndims = Some(ndims))"); + + let llvm_i32 = ctx.ctx.i32_type(); + + let objects = (0..self.ndims.unwrap()) + .map(|i| { + let dim = unsafe { + self.shape().get_typed_unchecked( + ctx, + generator, + &self.llvm_usize.const_int(i, false), + None, + ) + }; + ctx.builder.build_int_truncate_or_bit_cast(dim, llvm_i32, "").unwrap() + }) + .map(|obj| obj.as_basic_value_enum()) + .collect_vec(); + + TupleType::new( + generator, + ctx.ctx, + &repeat_n(llvm_i32.into(), self.ndims.unwrap() as usize).collect_vec(), + ) + .construct_from_objects(ctx, objects, None) + } + + /// Create the strides tuple of this ndarray like `.strides`. + /// + /// The returned integers in the tuple are in int32. + pub fn make_strides_tuple( + &self, + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + ) -> TupleValue<'ctx> { + assert!(self.ndims.is_some(), "NDArrayValue::make_strides_tuple can only be called on an instance with compile-time known ndims (self.ndims = Some(ndims))"); + + let llvm_i32 = ctx.ctx.i32_type(); + + let objects = (0..self.ndims.unwrap()) + .map(|i| { + let dim = unsafe { + self.strides().get_typed_unchecked( + ctx, + generator, + &self.llvm_usize.const_int(i, false), + None, + ) + }; + ctx.builder.build_int_truncate_or_bit_cast(dim, llvm_i32, "").unwrap() + }) + .map(|obj| obj.as_basic_value_enum()) + .collect_vec(); + + TupleType::new( + generator, + ctx.ctx, + &repeat_n(llvm_i32.into(), self.ndims.unwrap() as usize).collect_vec(), + ) + .construct_from_objects(ctx, objects, None) + } + /// Returns true if this ndarray is unsized - `ndims == 0` and only contains a scalar. #[must_use] pub fn is_unsized(&self) -> Option { diff --git a/nac3core/src/toplevel/builtins.rs b/nac3core/src/toplevel/builtins.rs index 36bd85d1..ac3fa08f 100644 --- a/nac3core/src/toplevel/builtins.rs +++ b/nac3core/src/toplevel/builtins.rs @@ -14,6 +14,7 @@ use crate::{ builtin_fns, numpy::*, stmt::exn_constructor, + types::ndarray::NDArrayType, values::{ProxyValue, RangeValue}, }, symbol_resolver::SymbolValue, @@ -368,6 +369,10 @@ impl<'a> BuiltinBuilder<'a> { | PrimDef::FunNpEye | PrimDef::FunNpIdentity => self.build_ndarray_other_factory_function(prim), + PrimDef::FunNpShape | PrimDef::FunNpStrides => { + self.build_ndarray_property_getter_function(prim) + } + PrimDef::FunStr => self.build_str_function(), PrimDef::FunFloor | PrimDef::FunFloor64 | PrimDef::FunCeil | PrimDef::FunCeil64 => { @@ -1242,6 +1247,54 @@ impl<'a> BuiltinBuilder<'a> { } } + fn build_ndarray_property_getter_function(&mut self, prim: PrimDef) -> TopLevelDef { + debug_assert_prim_is_allowed(prim, &[PrimDef::FunNpShape, PrimDef::FunNpStrides]); + + let in_ndarray_ty = self.unifier.get_fresh_var_with_range( + &[self.primitives.ndarray], + Some("T".into()), + None, + ); + + match prim { + PrimDef::FunNpShape | PrimDef::FunNpStrides => { + // The function signatures of `np_shape` an `np_size` are the same. + // Mixed together for convenience. + + // The return type is a tuple of variable length depending on the ndims of the input ndarray. + let ret_ty = self.unifier.get_dummy_var().ty; // Handled by special folding + + create_fn_by_codegen( + self.unifier, + &into_var_map([in_ndarray_ty]), + prim.name(), + ret_ty, + &[(in_ndarray_ty.ty, "a")], + Box::new(move |ctx, obj, fun, args, generator| { + assert!(obj.is_none()); + assert_eq!(args.len(), 1); + + let ndarray_ty = fun.0.args[0].ty; + let ndarray = + args[0].1.clone().to_basic_value_enum(ctx, generator, ndarray_ty)?; + + let ndarray = NDArrayType::from_unifier_type(generator, ctx, ndarray_ty) + .map_value(ndarray.into_pointer_value(), None); + + let result_tuple = match prim { + PrimDef::FunNpShape => ndarray.make_shape_tuple(generator, ctx), + PrimDef::FunNpStrides => ndarray.make_strides_tuple(generator, ctx), + _ => unreachable!(), + }; + + Ok(Some(result_tuple.as_base_value().into())) + }), + ) + } + _ => unreachable!(), + } + } + /// Build the `str()` function. fn build_str_function(&mut self) -> TopLevelDef { let prim = PrimDef::FunStr; diff --git a/nac3core/src/toplevel/helper.rs b/nac3core/src/toplevel/helper.rs index 71c1859b..75a7eabc 100644 --- a/nac3core/src/toplevel/helper.rs +++ b/nac3core/src/toplevel/helper.rs @@ -54,6 +54,10 @@ pub enum PrimDef { FunNpEye, FunNpIdentity, + // NumPy ndarray property getters + FunNpShape, + FunNpStrides, + // Miscellaneous NumPy & SciPy functions FunNpRound, FunNpFloor, @@ -240,6 +244,10 @@ impl PrimDef { PrimDef::FunNpEye => fun("np_eye", None), PrimDef::FunNpIdentity => fun("np_identity", None), + // NumPy NDArray property getters, + PrimDef::FunNpShape => fun("np_shape", None), + PrimDef::FunNpStrides => fun("np_strides", None), + // Miscellaneous NumPy & SciPy functions PrimDef::FunNpRound => fun("np_round", None), PrimDef::FunNpFloor => fun("np_floor", None), diff --git a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__generic_class.snap b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__generic_class.snap index 93f2096b..9313448e 100644 --- a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__generic_class.snap +++ b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__generic_class.snap @@ -8,5 +8,5 @@ expression: res_vec "Function {\nname: \"B.foo\",\nsig: \"fn[[b:T], none]\",\nvar_id: []\n}\n", "Class {\nname: \"Generic_A\",\nancestors: [\"Generic_A[V]\", \"B\"],\nfields: [\"aa\", \"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"foo\", \"fn[[b:T], none]\"), (\"fun\", \"fn[[a:int32], V]\")],\ntype_vars: [\"V\"]\n}\n", "Function {\nname: \"Generic_A.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", - "Function {\nname: \"Generic_A.fun\",\nsig: \"fn[[a:int32], V]\",\nvar_id: [TypeVarId(245)]\n}\n", + "Function {\nname: \"Generic_A.fun\",\nsig: \"fn[[a:int32], V]\",\nvar_id: [TypeVarId(249)]\n}\n", ] diff --git a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__inheritance_override.snap b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__inheritance_override.snap index d3301d00..0aa21de1 100644 --- a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__inheritance_override.snap +++ b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__inheritance_override.snap @@ -7,7 +7,7 @@ expression: res_vec "Function {\nname: \"A.__init__\",\nsig: \"fn[[t:T], none]\",\nvar_id: []\n}\n", "Function {\nname: \"A.fun\",\nsig: \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\",\nvar_id: []\n}\n", "Function {\nname: \"A.foo\",\nsig: \"fn[[c:C], none]\",\nvar_id: []\n}\n", - "Class {\nname: \"B\",\nancestors: [\"B[typevar229]\", \"A[float]\"],\nfields: [\"a\", \"b\", \"c\", \"d\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\"), (\"foo\", \"fn[[c:C], none]\")],\ntype_vars: [\"typevar229\"]\n}\n", + "Class {\nname: \"B\",\nancestors: [\"B[typevar233]\", \"A[float]\"],\nfields: [\"a\", \"b\", \"c\", \"d\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\"), (\"foo\", \"fn[[c:C], none]\")],\ntype_vars: [\"typevar233\"]\n}\n", "Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", "Function {\nname: \"B.fun\",\nsig: \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\",\nvar_id: []\n}\n", "Class {\nname: \"C\",\nancestors: [\"C\", \"B[bool]\", \"A[float]\"],\nfields: [\"a\", \"b\", \"c\", \"d\", \"e\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\"), (\"foo\", \"fn[[c:C], none]\")],\ntype_vars: []\n}\n", diff --git a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__list_tuple_generic.snap b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__list_tuple_generic.snap index 911426b9..2490cc75 100644 --- a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__list_tuple_generic.snap +++ b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__list_tuple_generic.snap @@ -5,8 +5,8 @@ expression: res_vec [ "Function {\nname: \"foo\",\nsig: \"fn[[a:list[int32], b:tuple[T, float]], A[B, bool]]\",\nvar_id: []\n}\n", "Class {\nname: \"A\",\nancestors: [\"A[T, V]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[v:V], none]\"), (\"fun\", \"fn[[a:T], V]\")],\ntype_vars: [\"T\", \"V\"]\n}\n", - "Function {\nname: \"A.__init__\",\nsig: \"fn[[v:V], none]\",\nvar_id: [TypeVarId(242)]\n}\n", - "Function {\nname: \"A.fun\",\nsig: \"fn[[a:T], V]\",\nvar_id: [TypeVarId(247)]\n}\n", + "Function {\nname: \"A.__init__\",\nsig: \"fn[[v:V], none]\",\nvar_id: [TypeVarId(246)]\n}\n", + "Function {\nname: \"A.fun\",\nsig: \"fn[[a:T], V]\",\nvar_id: [TypeVarId(251)]\n}\n", "Function {\nname: \"gfun\",\nsig: \"fn[[a:A[list[float], int32]], none]\",\nvar_id: []\n}\n", "Class {\nname: \"B\",\nancestors: [\"B\"],\nfields: [],\nmethods: [(\"__init__\", \"fn[[], none]\")],\ntype_vars: []\n}\n", "Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", diff --git a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__self1.snap b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__self1.snap index d60daf83..a7230f4d 100644 --- a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__self1.snap +++ b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__self1.snap @@ -3,7 +3,7 @@ source: nac3core/src/toplevel/test.rs expression: res_vec --- [ - "Class {\nname: \"A\",\nancestors: [\"A[typevar228, typevar229]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[a:A[float, bool], b:B], none]\"), (\"fun\", \"fn[[a:A[float, bool]], A[bool, int32]]\")],\ntype_vars: [\"typevar228\", \"typevar229\"]\n}\n", + "Class {\nname: \"A\",\nancestors: [\"A[typevar232, typevar233]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[a:A[float, bool], b:B], none]\"), (\"fun\", \"fn[[a:A[float, bool]], A[bool, int32]]\")],\ntype_vars: [\"typevar232\", \"typevar233\"]\n}\n", "Function {\nname: \"A.__init__\",\nsig: \"fn[[a:A[float, bool], b:B], none]\",\nvar_id: []\n}\n", "Function {\nname: \"A.fun\",\nsig: \"fn[[a:A[float, bool]], A[bool, int32]]\",\nvar_id: []\n}\n", "Class {\nname: \"B\",\nancestors: [\"B\", \"A[int64, bool]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:A[float, bool]], A[bool, int32]]\"), (\"foo\", \"fn[[b:B], B]\"), (\"bar\", \"fn[[a:A[list[B], int32]], tuple[A[virtual[A[B, int32]], bool], B]]\")],\ntype_vars: []\n}\n", diff --git a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__simple_class_compose.snap b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__simple_class_compose.snap index 517f6846..871a2f89 100644 --- a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__simple_class_compose.snap +++ b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__simple_class_compose.snap @@ -6,12 +6,12 @@ expression: res_vec "Class {\nname: \"A\",\nancestors: [\"A\"],\nfields: [\"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n", "Function {\nname: \"A.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", "Function {\nname: \"A.fun\",\nsig: \"fn[[b:B], none]\",\nvar_id: []\n}\n", - "Function {\nname: \"A.foo\",\nsig: \"fn[[a:T, b:V], none]\",\nvar_id: [TypeVarId(248)]\n}\n", + "Function {\nname: \"A.foo\",\nsig: \"fn[[a:T, b:V], none]\",\nvar_id: [TypeVarId(252)]\n}\n", "Class {\nname: \"C\",\nancestors: [\"C\", \"A\"],\nfields: [\"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n", "Function {\nname: \"C.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", "Function {\nname: \"C.fun\",\nsig: \"fn[[b:B], none]\",\nvar_id: []\n}\n", "Class {\nname: \"B\",\nancestors: [\"B\", \"C\", \"A\"],\nfields: [\"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n", "Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", "Function {\nname: \"foo\",\nsig: \"fn[[a:A], none]\",\nvar_id: []\n}\n", - "Function {\nname: \"ff\",\nsig: \"fn[[a:T], V]\",\nvar_id: [TypeVarId(256)]\n}\n", + "Function {\nname: \"ff\",\nsig: \"fn[[a:T], V]\",\nvar_id: [TypeVarId(260)]\n}\n", ] diff --git a/nac3core/src/typecheck/type_inferencer/mod.rs b/nac3core/src/typecheck/type_inferencer/mod.rs index 6068f630..8f1c54fc 100644 --- a/nac3core/src/typecheck/type_inferencer/mod.rs +++ b/nac3core/src/typecheck/type_inferencer/mod.rs @@ -3,7 +3,7 @@ use std::{ cmp::max, collections::{HashMap, HashSet}, convert::{From, TryInto}, - iter::once, + iter::{once, repeat_n}, sync::Arc, }; @@ -1234,6 +1234,45 @@ impl<'a> Inferencer<'a> { })); } + if ["np_shape".into(), "np_strides".into()].contains(id) && args.len() == 1 { + let ndarray = self.fold_expr(args.remove(0))?; + + let ndims = arraylike_get_ndims(self.unifier, ndarray.custom.unwrap()); + + // Make a tuple of size `ndims` full of int32 (TODO: Make it usize) + let ret_ty = TypeEnum::TTuple { + ty: repeat_n(self.primitives.int32, ndims as usize).collect_vec(), + is_vararg_ctx: false, + }; + let ret_ty = self.unifier.add_ty(ret_ty); + + let func_ty = TypeEnum::TFunc(FunSignature { + args: vec![FuncArg { + name: "a".into(), + default_value: None, + ty: ndarray.custom.unwrap(), + is_vararg: false, + }], + ret: ret_ty, + vars: VarMap::new(), + }); + let func_ty = self.unifier.add_ty(func_ty); + + return Ok(Some(Located { + location, + custom: Some(ret_ty), + node: ExprKind::Call { + func: Box::new(Located { + custom: Some(func_ty), + location: func.location, + node: ExprKind::Name { id: *id, ctx: *ctx }, + }), + args: vec![ndarray], + keywords: vec![], + }, + })); + } + if id == &"np_dot".into() { let arg0 = self.fold_expr(args.remove(0))?; let arg1 = self.fold_expr(args.remove(0))?; diff --git a/nac3standalone/demo/interpret_demo.py b/nac3standalone/demo/interpret_demo.py index 4f19db95..5bcf4bb5 100755 --- a/nac3standalone/demo/interpret_demo.py +++ b/nac3standalone/demo/interpret_demo.py @@ -179,6 +179,10 @@ def patch(module): module.np_identity = np.identity module.np_array = np.array + # NumPy NDArray property getters + module.np_shape = np.shape + module.np_strides = lambda ndarray: ndarray.strides + # NumPy Math functions module.np_isnan = np.isnan module.np_isinf = np.isinf