From 7a7a67b522ec9b48b48ff2a7103b078678c41656 Mon Sep 17 00:00:00 2001 From: lyken Date: Tue, 20 Aug 2024 16:35:20 +0800 Subject: [PATCH] core/ndstrides: implement np_transpose() (no axes argument) The IRRT implementation knows how to handle axes. But the argument is not in NAC3 yet. --- nac3core/irrt/irrt.cpp | 1 + nac3core/irrt/irrt/ndarray/transpose.hpp | 155 ++++++++++++++++++++ nac3core/src/codegen/irrt/mod.rs | 17 +++ nac3core/src/codegen/numpy.rs | 106 ------------- nac3core/src/codegen/object/ndarray/view.rs | 34 ++++- nac3core/src/toplevel/builtins.rs | 32 ++-- 6 files changed, 225 insertions(+), 120 deletions(-) create mode 100644 nac3core/irrt/irrt/ndarray/transpose.hpp diff --git a/nac3core/irrt/irrt.cpp b/nac3core/irrt/irrt.cpp index f47d539..923fe32 100644 --- a/nac3core/irrt/irrt.cpp +++ b/nac3core/irrt/irrt.cpp @@ -9,6 +9,7 @@ #include #include #include +#include #include #include #include \ No newline at end of file diff --git a/nac3core/irrt/irrt/ndarray/transpose.hpp b/nac3core/irrt/irrt/ndarray/transpose.hpp new file mode 100644 index 0000000..ab5fe00 --- /dev/null +++ b/nac3core/irrt/irrt/ndarray/transpose.hpp @@ -0,0 +1,155 @@ +#pragma once + +#include +#include +#include + +/* + * Notes on `np.transpose(, )` + * + * TODO: `axes`, if specified, can actually contain negative indices, + * but it is not documented in numpy. + * + * Supporting it for now. + */ + +namespace +{ +namespace ndarray +{ +namespace transpose +{ +/** + * @brief Do assertions on `` in `np.transpose(, )`. + * + * Note that `np.transpose`'s `` argument is optional. If the argument + * is specified but the user, use this function to do assertions on it. + * + * @param ndims The number of dimensions of `` + * @param num_axes Number of elements in `` as specified by the user. + * This should be equal to `ndims`. If not, a "ValueError: axes don't match array" is thrown. + * @param axes The user specified ``. + */ +template void assert_transpose_axes(SizeT ndims, SizeT num_axes, const SizeT *axes) +{ + if (ndims != num_axes) + { + raise_exception(SizeT, EXN_VALUE_ERROR, "axes don't match array", NO_PARAM, NO_PARAM, NO_PARAM); + } + + // TODO: Optimize this + bool *axe_specified = (bool *)__builtin_alloca(sizeof(bool) * ndims); + for (SizeT i = 0; i < ndims; i++) + axe_specified[i] = false; + + for (SizeT i = 0; i < ndims; i++) + { + SizeT axis = slice::resolve_index_in_length(ndims, axes[i]); + if (axis == -1) + { + // TODO: numpy actually throws a `numpy.exceptions.AxisError` + raise_exception(SizeT, EXN_VALUE_ERROR, "axis {0} is out of bounds for array of dimension {1}", axis, ndims, + NO_PARAM); + } + + if (axe_specified[axis]) + { + raise_exception(SizeT, EXN_VALUE_ERROR, "repeated axis in transpose", NO_PARAM, NO_PARAM, NO_PARAM); + } + + axe_specified[axis] = true; + } +} + +/** + * @brief Create a transpose view of `src_ndarray` and perform proper assertions. + * + * This function is very similar to doing `dst_ndarray = np.transpose(src_ndarray, )`. + * If `` is supposed to be `None`, caller can pass in a `nullptr` to ``. + * + * The transpose view created is returned by modifying `dst_ndarray`. + * + * The caller is responsible for setting up `dst_ndarray` before calling this function. + * Here is what this function expects from `dst_ndarray` when called: + * - `dst_ndarray->data` does not have to be initialized. + * - `dst_ndarray->itemsize` does not have to be initialized. + * - `dst_ndarray->ndims` must be initialized, must be equal to `src_ndarray->ndims`. + * - `dst_ndarray->shape` must be allocated, through it can contain uninitialized values. + * - `dst_ndarray->strides` must be allocated, through it can contain uninitialized values. + * When this function call ends: + * - `dst_ndarray->data` is set to `src_ndarray->data` (`dst_ndarray` is just a view to `src_ndarray`) + * - `dst_ndarray->itemsize` is set to `src_ndarray->itemsize` + * - `dst_ndarray->ndims` is unchanged + * - `dst_ndarray->shape` is updated according to how `np.transpose` works + * - `dst_ndarray->strides` is updated according to how `np.transpose` works + * + * @param src_ndarray The NDArray to build a transpose view on + * @param dst_ndarray The resulting NDArray after transpose. Further details in the comments above, + * @param num_axes Number of elements in axes. Unused if `axes` is nullptr. + * @param axes Axes permutation. Set it to `nullptr` if `` is `None`. + */ +template +void transpose(const NDArray *src_ndarray, NDArray *dst_ndarray, SizeT num_axes, const SizeT *axes) +{ + debug_assert_eq(SizeT, src_ndarray->ndims, dst_ndarray->ndims); + const auto ndims = src_ndarray->ndims; + + if (axes != nullptr) + assert_transpose_axes(ndims, num_axes, axes); + + dst_ndarray->data = src_ndarray->data; + dst_ndarray->itemsize = src_ndarray->itemsize; + + // Check out https://ajcr.net/stride-guide-part-2/ to see how `np.transpose` works behind the scenes. + if (axes == nullptr) + { + // `np.transpose(, axes=None)` + + /* + * Minor note: `np.transpose(, axes=None)` is equivalent to + * `np.transpose(, axes=[N-1, N-2, ..., 0])` - basically it + * is reversing the order of strides and shape. + * + * This is a fast implementation to handle this special (but very common) case. + */ + + for (SizeT axis = 0; axis < ndims; axis++) + { + dst_ndarray->shape[axis] = src_ndarray->shape[ndims - axis - 1]; + dst_ndarray->strides[axis] = src_ndarray->strides[ndims - axis - 1]; + } + } + else + { + // `np.transpose(, )` + + // Permute strides and shape according to `axes`, while resolving negative indices in `axes` + for (SizeT axis = 0; axis < ndims; axis++) + { + // `i` cannot be OUT_OF_BOUNDS because of assertions + SizeT i = slice::resolve_index_in_length(ndims, axes[axis]); + + dst_ndarray->shape[axis] = src_ndarray->shape[i]; + dst_ndarray->strides[axis] = src_ndarray->strides[i]; + } + } +} +} // namespace transpose +} // namespace ndarray +} // namespace + +extern "C" +{ + using namespace ndarray::transpose; + void __nac3_ndarray_transpose(const NDArray *src_ndarray, NDArray *dst_ndarray, int32_t num_axes, + const int32_t *axes) + { + transpose(src_ndarray, dst_ndarray, num_axes, axes); + } + + void __nac3_ndarray_transpose64(const NDArray *src_ndarray, NDArray *dst_ndarray, + int64_t num_axes, const int64_t *axes) + { + transpose(src_ndarray, dst_ndarray, num_axes, axes); + } +} \ No newline at end of file diff --git a/nac3core/src/codegen/irrt/mod.rs b/nac3core/src/codegen/irrt/mod.rs index 22540c7..6a9547c 100644 --- a/nac3core/src/codegen/irrt/mod.rs +++ b/nac3core/src/codegen/irrt/mod.rs @@ -1214,3 +1214,20 @@ pub fn call_nac3_ndarray_broadcast_shapes<'ctx, G: CodeGenerator + ?Sized>( .arg(dst_shape) .returning_void(); } + +pub fn call_nac3_ndarray_transpose<'ctx, G: CodeGenerator + ?Sized>( + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + src_ndarray: Instance<'ctx, Ptr>>, + dst_ndarray: Instance<'ctx, Ptr>>, + num_axes: Instance<'ctx, Int>, + axes: Instance<'ctx, Ptr>>, +) { + let name = get_sizet_dependent_function_name(generator, ctx, "__nac3_ndarray_transpose"); + CallFunction::begin(generator, ctx, &name) + .arg(src_ndarray) + .arg(dst_ndarray) + .arg(num_axes) + .arg(axes) + .returning_void(); +} diff --git a/nac3core/src/codegen/numpy.rs b/nac3core/src/codegen/numpy.rs index d7ea4fb..cc351af 100644 --- a/nac3core/src/codegen/numpy.rs +++ b/nac3core/src/codegen/numpy.rs @@ -1990,112 +1990,6 @@ pub fn gen_ndarray_fill<'ctx>( Ok(()) } -/// Generates LLVM IR for `ndarray.transpose`. -pub fn ndarray_transpose<'ctx, G: CodeGenerator + ?Sized>( - generator: &mut G, - ctx: &mut CodeGenContext<'ctx, '_>, - x1: (Type, BasicValueEnum<'ctx>), -) -> Result, String> { - const FN_NAME: &str = "ndarray_transpose"; - let (x1_ty, x1) = x1; - let llvm_usize = generator.get_size_type(ctx.ctx); - - if let BasicValueEnum::PointerValue(n1) = x1 { - let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty); - let n1 = NDArrayValue::from_ptr_val(n1, llvm_usize, None); - let n_sz = call_ndarray_calc_size(generator, ctx, &n1.dim_sizes(), (None, None)); - - // Dimensions are reversed in the transposed array - let out = create_ndarray_dyn_shape( - generator, - ctx, - elem_ty, - &n1, - |_, ctx, n| Ok(n.load_ndims(ctx)), - |generator, ctx, n, idx| { - let new_idx = ctx.builder.build_int_sub(n.load_ndims(ctx), idx, "").unwrap(); - let new_idx = ctx - .builder - .build_int_sub(new_idx, new_idx.get_type().const_int(1, false), "") - .unwrap(); - unsafe { Ok(n.dim_sizes().get_typed_unchecked(ctx, generator, &new_idx, None)) } - }, - ) - .unwrap(); - - gen_for_callback_incrementing( - generator, - ctx, - None, - llvm_usize.const_zero(), - (n_sz, false), - |generator, ctx, _, idx| { - let elem = unsafe { n1.data().get_unchecked(ctx, generator, &idx, None) }; - - let new_idx = generator.gen_var_alloc(ctx, llvm_usize.into(), None)?; - let rem_idx = generator.gen_var_alloc(ctx, llvm_usize.into(), None)?; - ctx.builder.build_store(new_idx, llvm_usize.const_zero()).unwrap(); - ctx.builder.build_store(rem_idx, idx).unwrap(); - - // Incrementally calculate the new index in the transposed array - // For each index, we first decompose it into the n-dims and use those to reconstruct the new index - // The formula used for indexing is: - // idx = dim_n * ( ... (dim2 * (dim0 * dim1) + dim1) + dim2 ... ) + dim_n - gen_for_callback_incrementing( - generator, - ctx, - None, - llvm_usize.const_zero(), - (n1.load_ndims(ctx), false), - |generator, ctx, _, ndim| { - let ndim_rev = - ctx.builder.build_int_sub(n1.load_ndims(ctx), ndim, "").unwrap(); - let ndim_rev = ctx - .builder - .build_int_sub(ndim_rev, llvm_usize.const_int(1, false), "") - .unwrap(); - let dim = unsafe { - n1.dim_sizes().get_typed_unchecked(ctx, generator, &ndim_rev, None) - }; - - let rem_idx_val = - ctx.builder.build_load(rem_idx, "").unwrap().into_int_value(); - let new_idx_val = - ctx.builder.build_load(new_idx, "").unwrap().into_int_value(); - - let add_component = - ctx.builder.build_int_unsigned_rem(rem_idx_val, dim, "").unwrap(); - let rem_idx_val = - ctx.builder.build_int_unsigned_div(rem_idx_val, dim, "").unwrap(); - - let new_idx_val = ctx.builder.build_int_mul(new_idx_val, dim, "").unwrap(); - let new_idx_val = - ctx.builder.build_int_add(new_idx_val, add_component, "").unwrap(); - - ctx.builder.build_store(rem_idx, rem_idx_val).unwrap(); - ctx.builder.build_store(new_idx, new_idx_val).unwrap(); - - Ok(()) - }, - llvm_usize.const_int(1, false), - )?; - - let new_idx_val = ctx.builder.build_load(new_idx, "").unwrap().into_int_value(); - unsafe { out.data().set_unchecked(ctx, generator, &new_idx_val, elem) }; - Ok(()) - }, - llvm_usize.const_int(1, false), - )?; - - Ok(out.as_base_value().into()) - } else { - unreachable!( - "{FN_NAME}() not supported for '{}'", - format!("'{}'", ctx.unifier.stringify(x1_ty)) - ) - } -} - /// Generates LLVM IR for `ndarray.dot`. /// Calculate inner product of two vectors or literals /// For matrix multiplication use `np_matmul` diff --git a/nac3core/src/codegen/object/ndarray/view.rs b/nac3core/src/codegen/object/ndarray/view.rs index f08e05a..bae0bbd 100644 --- a/nac3core/src/codegen/object/ndarray/view.rs +++ b/nac3core/src/codegen/object/ndarray/view.rs @@ -1,6 +1,7 @@ use crate::codegen::{ - irrt::call_nac3_ndarray_reshape_resolve_and_check_new_shape, model::*, CodeGenContext, - CodeGenerator, + irrt::{call_nac3_ndarray_reshape_resolve_and_check_new_shape, call_nac3_ndarray_transpose}, + model::*, + CodeGenContext, CodeGenerator, }; use super::{indexing::RustNDIndex, NDArrayObject}; @@ -86,4 +87,33 @@ impl<'ctx> NDArrayObject<'ctx> { dst_ndarray } + + /// Create a transposed view on this ndarray like `np.transpose(, = None)`. + /// * `axes` - If specified, should be an array of the permutation (negative indices are **allowed**). + #[must_use] + pub fn transpose( + &self, + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + axes: Option>>>, + ) -> Self { + // Define models + let transposed_ndarray = NDArrayObject::alloca(generator, ctx, self.dtype, self.ndims); + + let num_axes = self.ndims_llvm(generator, ctx.ctx); + + // `axes = nullptr` if `axes` is unspecified. + let axes = axes.unwrap_or_else(|| Ptr(Int(SizeT)).nullptr(generator, ctx.ctx)); + + call_nac3_ndarray_transpose( + generator, + ctx, + self.instance, + transposed_ndarray.instance, + num_axes, + axes, + ); + + transposed_ndarray + } } diff --git a/nac3core/src/toplevel/builtins.rs b/nac3core/src/toplevel/builtins.rs index 485b634..7eae6a2 100644 --- a/nac3core/src/toplevel/builtins.rs +++ b/nac3core/src/toplevel/builtins.rs @@ -1481,18 +1481,26 @@ impl<'a> BuiltinBuilder<'a> { ); match prim { - PrimDef::FunNpTranspose => create_fn_by_codegen( - self.unifier, - &into_var_map([in_ndarray_ty]), - prim.name(), - in_ndarray_ty.ty, - &[(in_ndarray_ty.ty, "x")], - Box::new(move |ctx, _, fun, args, generator| { - let arg_ty = fun.0.args[0].ty; - let arg_val = args[0].1.clone().to_basic_value_enum(ctx, generator, arg_ty)?; - Ok(Some(ndarray_transpose(generator, ctx, (arg_ty, arg_val))?)) - }), - ), + PrimDef::FunNpTranspose => { + create_fn_by_codegen( + self.unifier, + &into_var_map([in_ndarray_ty]), + prim.name(), + in_ndarray_ty.ty, + &[(in_ndarray_ty.ty, "x")], + Box::new(move |ctx, _, fun, args, generator| { + let arg_ty = fun.0.args[0].ty; + let arg_val = + args[0].1.clone().to_basic_value_enum(ctx, generator, arg_ty)?; + + let arg = AnyObject { ty: arg_ty, value: arg_val }; + let ndarray = NDArrayObject::from_object(generator, ctx, arg); + + let ndarray = ndarray.transpose(generator, ctx, None); // TODO: Add axes argument + Ok(Some(ndarray.instance.value.as_basic_value_enum())) + }), + ) + } // NOTE: on `ndarray_factory_fn_shape_arg_tvar` and // the `param_ty` for `create_fn_by_codegen`.