From 819e1e46081f0421d146e9ca65e795a551165570 Mon Sep 17 00:00:00 2001 From: lyken Date: Sat, 27 Jul 2024 12:35:22 +0800 Subject: [PATCH] a --- flake.nix | 4 +- nac3core/irrt/irrt/ndarray/transpose.hpp | 162 ++++++++++++++++++ nac3core/irrt/irrt_everything.hpp | 1 + nac3core/src/codegen/irrt/ndarray/mod.rs | 1 + .../src/codegen/irrt/ndarray/transpose.rs | 43 +++++ nac3core/src/codegen/numpy_new/view.rs | 40 ++++- nac3core/src/codegen/util/control_flow.rs | 28 +++ nac3core/src/codegen/util/mod.rs | 1 + nac3core/src/toplevel/builtins.rs | 56 +++++- nac3core/src/toplevel/helper.rs | 2 + nac3standalone/demo/interpret_demo.py | 1 + 11 files changed, 335 insertions(+), 4 deletions(-) create mode 100644 nac3core/irrt/irrt/ndarray/transpose.hpp create mode 100644 nac3core/src/codegen/irrt/ndarray/transpose.rs create mode 100644 nac3core/src/codegen/util/control_flow.rs diff --git a/flake.nix b/flake.nix index a6ce5fce..ffebbb95 100644 --- a/flake.nix +++ b/flake.nix @@ -151,7 +151,7 @@ buildInputs = with pkgs; [ # build dependencies packages.x86_64-linux.llvm-nac3 - llvmPackages_14.clang llvmPackages_14.llvm.out # for running nac3standalone demos + llvmPackages_14.clang llvmPackages_14.llvm.out llvmPackages_14.lldb.out # for running nac3standalone demos packages.x86_64-linux.llvm-tools-irrt cargo rustc @@ -163,7 +163,9 @@ clippy pre-commit rustfmt + rust-analyzer ]; + RUST_SRC_PATH = "${pkgs.rust.packages.stable.rustPlatform.rustLibSrc}"; }; devShells.x86_64-linux.msys2 = pkgs.mkShell { name = "nac3-dev-shell-msys2"; diff --git a/nac3core/irrt/irrt/ndarray/transpose.hpp b/nac3core/irrt/irrt/ndarray/transpose.hpp new file mode 100644 index 00000000..ad415062 --- /dev/null +++ b/nac3core/irrt/irrt/ndarray/transpose.hpp @@ -0,0 +1,162 @@ +#pragma once + +#include +#include +#include + +/* + * Notes on `np.transpose(, )` + * + * TODO: `axes`, if specified, can actually contain negative indices, + * but it is not documented in numpy. + * + * Supporting it for now. + */ + +namespace { +namespace ndarray { +namespace transpose { +namespace util { + +/** + * @brief Do assertions on `` in `np.transpose(, )`. + * + * Note that `np.transpose`'s `` argument is optional. If the argument + * is specified but the user, use this function to do assertions on it. + * + * @param ndims The number of dimensions of `` + * @param num_axes Number of elements in `` as specified by the user. + * This should be equal to `ndims`. If not, a "ValueError: axes don't match array" is thrown. + * @param axes The user specified ``. + */ +template +void assert_transpose_axes(ErrorContext* errctx, SizeT ndims, SizeT num_axes, + const SizeT* axes) { + /* + * TODO: `axes` can actually contain negative indices, but it is not documented in numpy. + * + * Supporting it for now. + */ + + if (ndims != num_axes) { + errctx->set_error(errctx->error_ids->value_error, + "axes don't match array"); + return; + } + + // TODO: Optimize this + bool* axe_specified = (bool*)__builtin_alloca(sizeof(bool) * ndims); + for (SizeT i = 0; i < ndims; i++) axe_specified[i] = false; + + for (SizeT i = 0; i < ndims; i++) { + SizeT axis = slice::resolve_index_in_length(ndims, axes[i]); + if (axis == slice::OUT_OF_BOUNDS) { + // TODO: numpy actually throws a `numpy.exceptions.AxisError` + errctx->set_error( + errctx->error_ids->value_error, + "axis {0} is out of bounds for array of dimension {1}", axis, + ndims); + return; + } + + if (axe_specified[axis]) { + errctx->set_error(errctx->error_ids->value_error, + "repeated axis in transpose"); + return; + } + + axe_specified[axis] = true; + } +} +} // namespace util + +/** + * @brief Create a transpose view of `src_ndarray` and perform proper assertions. + * + * This function is very similar to doing `dst_ndarray = np.transpose(src_ndarray, )`. + * If `` is supposed to be `None`, caller can pass in a `nullptr` to ``. + * + * The transpose view created is returned by modifying `dst_ndarray`. + * + * The caller is responsible for setting up `dst_ndarray` before calling this function. + * Here is what this function expects from `dst_ndarray` when called: + * - `dst_ndarray->data` does not have to be initialized. + * - `dst_ndarray->itemsize` does not have to be initialized. + * - `dst_ndarray->ndims` must be initialized, must be equal to `src_ndarray->ndims`. + * - `dst_ndarray->shape` must be allocated, through it can contain uninitialized values. + * - `dst_ndarray->strides` must be allocated, through it can contain uninitialized values. + * When this function call ends: + * - `dst_ndarray->data` is set to `src_ndarray->data` (`dst_ndarray` is just a view to `src_ndarray`) + * - `dst_ndarray->itemsize` is set to `src_ndarray->itemsize` + * - `dst_ndarray->ndims` is unchanged + * - `dst_ndarray->shape` is updated according to how `np.transpose` works + * - `dst_ndarray->strides` is updated according to how `np.transpose` works + * + * @param src_ndarray The NDArray to build a transpose view on + * @param dst_ndarray The resulting NDArray after transpose. Further details in the comments above, + * @param num_axes Number of elements in axes, can be undefined if `axes` is nullptr. + * @param axes Axes permutation. Set it to `nullptr` if `` is supposed to be `None`. + */ +template +void transpose(ErrorContext* errctx, const NDArray* src_ndarray, + NDArray* dst_ndarray, SizeT num_axes, const SizeT* axes) { + __builtin_assume(src_ndarray->ndims == dst_ndarray->ndims); + const auto ndims = src_ndarray->ndims; + + if (axes != nullptr) { + util::assert_transpose_axes(errctx, ndims, num_axes, axes); + if (errctx->has_error()) return; + } + + dst_ndarray->data = src_ndarray->data; + dst_ndarray->itemsize = src_ndarray->itemsize; + + // Check out https://ajcr.net/stride-guide-part-2/ to see how `np.transpose` works behind the scenes. + if (axes == nullptr) { + // `np.transpose(, axes=None)` + + /* + * Minor note: `np.transpose(, axes=None)` is equivalent to + * `np.transpose(, axes=[N-1, N-2, ..., 0])` - basically it + * is reversing the order of strides and shape. + * + * This is a fast implementation to handle this special (but very common) case. + */ + + for (SizeT axis = 0; axis < ndims; axis++) { + dst_ndarray->shape[axis] = src_ndarray->shape[ndims - axis - 1]; + dst_ndarray->strides[axis] = src_ndarray->strides[ndims - axis - 1]; + } + } else { + // `np.transpose(, )` + + // Permute strides and shape according to `axes`, while resolving negative indices in `axes` + for (SizeT axis = 0; axis < ndims; axis++) { + // `i` cannot be OUT_OF_BOUNDS because of assertions + SizeT i = slice::resolve_index_in_length(ndims, axes[axis]); + + dst_ndarray->shape[axis] = src_ndarray->shape[i]; + dst_ndarray->strides[axis] = src_ndarray->strides[i]; + } + } +} +} // namespace transpose +} // namespace ndarray +} // namespace + +extern "C" { +using namespace ndarray::transpose; +void __nac3_ndarray_transpose(ErrorContext* errctx, + const NDArray* src_ndarray, + NDArray* dst_ndarray, int32_t num_axes, + const int32_t* axes) { + transpose(errctx, src_ndarray, dst_ndarray, num_axes, axes); +} + +void __nac3_ndarray_transpose64(ErrorContext* errctx, + const NDArray* src_ndarray, + NDArray* dst_ndarray, int64_t num_axes, + const int64_t* axes) { + transpose(errctx, src_ndarray, dst_ndarray, num_axes, axes); +} +} \ No newline at end of file diff --git a/nac3core/irrt/irrt_everything.hpp b/nac3core/irrt/irrt_everything.hpp index 5af2d50b..70086570 100644 --- a/nac3core/irrt/irrt_everything.hpp +++ b/nac3core/irrt/irrt_everything.hpp @@ -9,5 +9,6 @@ #include #include #include +#include #include #include \ No newline at end of file diff --git a/nac3core/src/codegen/irrt/ndarray/mod.rs b/nac3core/src/codegen/irrt/ndarray/mod.rs index 65b7f558..6a72a914 100644 --- a/nac3core/src/codegen/irrt/ndarray/mod.rs +++ b/nac3core/src/codegen/irrt/ndarray/mod.rs @@ -3,3 +3,4 @@ pub mod basic; pub mod fill; pub mod indexing; pub mod reshape; +pub mod transpose; diff --git a/nac3core/src/codegen/irrt/ndarray/transpose.rs b/nac3core/src/codegen/irrt/ndarray/transpose.rs new file mode 100644 index 00000000..867b632c --- /dev/null +++ b/nac3core/src/codegen/irrt/ndarray/transpose.rs @@ -0,0 +1,43 @@ +use crate::codegen::{ + irrt::{ + error_context::{check_error_context, setup_error_context}, + util::get_sized_dependent_function_name, + }, + model::*, + structs::ndarray::NpArray, + CodeGenContext, CodeGenerator, +}; + +pub fn call_nac3_ndarray_transpose<'ctx, G: CodeGenerator + ?Sized>( + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + src_ndarray: Pointer<'ctx, StructModel>>, + dst_ndarray: Pointer<'ctx, StructModel>>, + axes_or_none: Option, SizeTModel<'ctx>>>, +) -> Pointer<'ctx, StructModel>> { + let sizet = generator.get_sizet(ctx.ctx); + let axes_model = PointerModel(sizet); + + let (num_axes, axes) = match axes_or_none { + Some(axes) => (axes.num_elements, axes.pointer), + None => { + // Please refer to the comment in the IRRT implementation + (sizet.constant(ctx.ctx, 0), axes_model.nullptr(ctx.ctx)) + } + }; + + let perrctx = setup_error_context(ctx); + FunctionBuilder::begin( + ctx, + &get_sized_dependent_function_name(sizet, "__nac3_ndarray_transpose"), + ) + .arg("errctx", perrctx) + .arg("src_ndarray", src_ndarray) + .arg("dst_ndarray", dst_ndarray) + .arg("num_axes", num_axes) + .arg("axes", axes) + .returning_void(); + check_error_context(generator, ctx, perrctx); + + dst_ndarray +} diff --git a/nac3core/src/codegen/numpy_new/view.rs b/nac3core/src/codegen/numpy_new/view.rs index 7e6e0198..d99eb679 100644 --- a/nac3core/src/codegen/numpy_new/view.rs +++ b/nac3core/src/codegen/numpy_new/view.rs @@ -10,9 +10,10 @@ use crate::{ call_nac3_ndarray_set_strides_by_shape, call_nac3_ndarray_size, }, reshape::call_nac3_ndarray_resolve_and_check_new_shape, + transpose::call_nac3_ndarray_transpose, }, model::*, - structs::ndarray::NpArray, + structs::{list::List, ndarray::NpArray}, util::{array_writer::ArrayWriter, shape::parse_input_shape_arg}, CodeGenContext, CodeGenerator, }, @@ -129,3 +130,40 @@ pub fn gen_ndarray_reshape<'ctx>( let reshaped_ndarray = reshape_ndarray_or_copy(generator, context, src_ndarray, &new_shape)?; Ok(reshaped_ndarray.value) } + +pub fn gen_ndarray_transpose<'ctx>( + context: &mut CodeGenContext<'ctx, '_>, + obj: &Option<(Type, ValueEnum<'ctx>)>, + fun: (&FunSignature, DefinitionId), + args: &[(Option, ValueEnum<'ctx>)], + generator: &mut dyn CodeGenerator, +) -> Result, String> { + assert!(obj.is_none()); + assert!(matches!(args.len(), 1 | 2)); + + let sizet = generator.get_sizet(context.ctx); + let in_axes_model = PointerModel(StructModel(List { sizet, element: NIntModel(Int32) })); + + // Parse argument #1 ndarray + let ndarray_ty = fun.0.args[0].ty; + let ndarray_arg = args[0].1.clone().to_basic_value_enum(context, generator, ndarray_ty)?; + + // Parse argument #2 axes (optional) + let in_axes = if args.len() == 2 { + let in_shape_ty = fun.0.args[1].ty; + let in_shape_arg = + args[1].1.clone().to_basic_value_enum(context, generator, in_shape_ty)?; + + let in_shape = in_axes_model.review_value(context.ctx, in_shape_arg).unwrap(); + + let num_axes = in_shape.gep(context, |f| f.size).load(context, "num_axes"); + let axes = sizet.array_alloca(context, num_axes, "num_axes"); + + Some((in_shape_ty, in_shape_arg)) + } else { + None + }; + // call_nac3_ndarray_transpose(generator, ctx, src_ndarray, dst_ndarray, axes_or_none) + + todo!() +} diff --git a/nac3core/src/codegen/util/control_flow.rs b/nac3core/src/codegen/util/control_flow.rs new file mode 100644 index 00000000..1f06c427 --- /dev/null +++ b/nac3core/src/codegen/util/control_flow.rs @@ -0,0 +1,28 @@ +use inkwell::{types::IntType, values::IntValue}; + +use crate::codegen::{model::*, CodeGenContext, CodeGenerator}; + +/// Convenient structure that looks like +/// Python `range`, with a dependent int type. +pub struct ForRange { + pub start: T, + pub stop: T, + pub step: T, +} + +impl<'ctx, T> ForRange +where + T: Model<'ctx, Value = IntValue<'ctx>, Type = IntType<'ctx>>, +{ + pub fn end(ctx: &'ctx Context, end: T) -> Self { + todo!() + } +} + +pub fn for_range<'ctx, G: CodeGenerator + ?Sized>( + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, 'a>, + range: ForRange, + body: BodyFn, +) { +} diff --git a/nac3core/src/codegen/util/mod.rs b/nac3core/src/codegen/util/mod.rs index b3ffc4ec..bd507c61 100644 --- a/nac3core/src/codegen/util/mod.rs +++ b/nac3core/src/codegen/util/mod.rs @@ -1,2 +1,3 @@ pub mod array_writer; +pub mod control_flow; pub mod shape; diff --git a/nac3core/src/toplevel/builtins.rs b/nac3core/src/toplevel/builtins.rs index 8f043555..d310ef32 100644 --- a/nac3core/src/toplevel/builtins.rs +++ b/nac3core/src/toplevel/builtins.rs @@ -496,7 +496,9 @@ impl<'a> BuiltinBuilder<'a> { | PrimDef::FunNpEye | PrimDef::FunNpIdentity => self.build_ndarray_other_factory_function(prim), - PrimDef::FunNpReshape => self.build_ndarray_view_functions(prim), + PrimDef::FunNpReshape | PrimDef::FunNpTranspose => { + self.build_ndarray_view_functions(prim) + } PrimDef::FunStr => self.build_str_function(), @@ -1337,7 +1339,7 @@ impl<'a> BuiltinBuilder<'a> { // Build functions related to NDArray views fn build_ndarray_view_functions(&mut self, prim: PrimDef) -> TopLevelDef { - debug_assert_prim_is_allowed(prim, &[PrimDef::FunNpReshape]); + debug_assert_prim_is_allowed(prim, &[PrimDef::FunNpReshape, PrimDef::FunNpTranspose]); match prim { PrimDef::FunNpReshape => { @@ -1364,6 +1366,56 @@ impl<'a> BuiltinBuilder<'a> { }), ) } + PrimDef::FunNpTranspose => { + /* + # NDim has to be known (for checking axes's len) + def np_transpose( + array: NDArray[DType, NDim], + axes: Optional[List[int32]] = None, + ) -> NDArray[DType, NDim] + */ + + // TODO: Allow tuples (or even general iterables in the very far future) on `axes` + let optional_axes_ty = self + .unifier + .subst( + self.primitives.option, + &VarMap::from([(self.option_tvar.id, self.list_int32)]), + ) + .unwrap(); + + TopLevelDef::Function { + name: prim.name().into(), + simple_name: prim.simple_name().into(), + signature: self.unifier.add_ty(TypeEnum::TFunc(FunSignature { + args: vec![ + FuncArg { + name: "array".into(), + ty: self.primitives.ndarray, + default_value: None, + }, + FuncArg { + name: "axes".into(), + ty: optional_axes_ty, + default_value: Some(SymbolValue::OptionNone), + }, + ], + ret: self.primitives.ndarray, + vars: VarMap::default(), + })), + var_id: vec![self.ndarray_ndims_tvar.id], + instance_to_symbol: HashMap::default(), + instance_to_stmt: HashMap::default(), + resolver: None, + codegen_callback: Some(Arc::new(GenCall::new(Box::new( + |ctx, obj, fun, args, generator| { + numpy_new::view::gen_ndarray_transpose(ctx, &obj, fun, &args, generator) + .map(|val| Some(val.as_basic_value_enum())) + }, + )))), + loc: None, + } + } _ => unreachable!(), } } diff --git a/nac3core/src/toplevel/helper.rs b/nac3core/src/toplevel/helper.rs index cc22e639..5420b733 100644 --- a/nac3core/src/toplevel/helper.rs +++ b/nac3core/src/toplevel/helper.rs @@ -47,6 +47,7 @@ pub enum PrimDef { FunNpEye, FunNpIdentity, FunNpReshape, + FunNpTranspose, FunRound, FunRound64, FunNpRound, @@ -206,6 +207,7 @@ impl PrimDef { PrimDef::FunNpEye => fun("np_eye", None), PrimDef::FunNpIdentity => fun("np_identity", None), PrimDef::FunNpReshape => fun("np_reshape", None), + PrimDef::FunNpTranspose => fun("np_transpose", None), PrimDef::FunRound => fun("round", None), PrimDef::FunRound64 => fun("round64", None), PrimDef::FunNpRound => fun("np_round", None), diff --git a/nac3standalone/demo/interpret_demo.py b/nac3standalone/demo/interpret_demo.py index 4267aa66..a5ad1429 100755 --- a/nac3standalone/demo/interpret_demo.py +++ b/nac3standalone/demo/interpret_demo.py @@ -180,6 +180,7 @@ def patch(module): # NumPy view functions module.np_reshape = np.reshape + module.np_transpose = np.transpose # NumPy Math functions module.np_isnan = np.isnan