nac3/nac3core/irrt/irrt/ndarray/transpose.hpp
David Mak 7375983e0c [core] codegen/ndarray: Implement np_transpose without axes argument
Based on 052b67c8: core/ndstrides: implement np_transpose() (no axes
argument)

The IRRT implementation knows how to handle axes. But the argument is
not in NAC3 yet.
2025-01-03 15:15:08 +08:00

143 lines
5.7 KiB
C++

#pragma once
#include "irrt/debug.hpp"
#include "irrt/exception.hpp"
#include "irrt/int_types.hpp"
#include "irrt/ndarray/def.hpp"
#include "irrt/slice.hpp"
/*
* Notes on `np.transpose(<array>, <axes>)`
*
* TODO: `axes`, if specified, can actually contain negative indices,
* but this is not documented in numpy.
*
* Negative indices in `axes` are supported here for now.
*/
namespace {
namespace ndarray::transpose {
/**
 * @brief Do assertions on `<axes>` in `np.transpose(<array>, <axes>)`.
 *
 * Note that `np.transpose`'s `<axes>` argument is optional. If the argument
 * is specified by the user, use this function to do assertions on it.
 *
 * Each entry of `<axes>` may be negative; it is resolved against `ndims` with
 * `slice::resolve_index_in_length` before being checked.
 *
 * Errors raised (all as `EXN_VALUE_ERROR` via `raise_exception`):
 * - "axes don't match array" if `num_axes != ndims`.
 * - "axis {0} is out of bounds for array of dimension {1}" if an entry cannot be
 *   resolved. `{0}` is the entry exactly as the user wrote it.
 * - "repeated axis in transpose" if two entries resolve to the same axis.
 *
 * @param ndims The number of dimensions of `<array>`
 * @param num_axes Number of elements in `<axes>` as specified by the user.
 * This should be equal to `ndims`. If not, a "ValueError: axes don't match array" is thrown.
 * @param axes The user specified `<axes>`.
 */
template<typename SizeT>
void assert_transpose_axes(SizeT ndims, SizeT num_axes, const SizeT* axes) {
    if (ndims != num_axes) {
        raise_exception(SizeT, EXN_VALUE_ERROR, "axes don't match array", NO_PARAM, NO_PARAM, NO_PARAM);
    }

    // axis_seen[k] becomes true once some entry of `axes` has resolved to axis k.
    // It is used to detect repeated axes.
    // TODO: Optimize this
    bool* axis_seen = static_cast<bool*>(__builtin_alloca(sizeof(bool) * ndims));
    for (SizeT i = 0; i < ndims; i++) {
        axis_seen[i] = false;
    }

    for (SizeT i = 0; i < ndims; i++) {
        SizeT axis = slice::resolve_index_in_length(ndims, axes[i]);
        // A resolved value of -1 means `axes[i]` is outside [-ndims, ndims).
        if (axis == -1) {
            // Report the axis as the user specified it (`axes[i]`), not the resolved
            // sentinel -1. Otherwise the message would always read "axis -1 ...".
            // TODO: numpy actually throws a `numpy.exceptions.AxisError`
            raise_exception(SizeT, EXN_VALUE_ERROR, "axis {0} is out of bounds for array of dimension {1}", axes[i],
                            ndims, NO_PARAM);
        }
        if (axis_seen[axis]) {
            raise_exception(SizeT, EXN_VALUE_ERROR, "repeated axis in transpose", NO_PARAM, NO_PARAM, NO_PARAM);
        }
        axis_seen[axis] = true;
    }
}

/**
 * @brief Create a transpose view of `src_ndarray` and perform proper assertions.
 *
 * This function is very similar to doing `dst_ndarray = np.transpose(src_ndarray, <axes>)`.
 * If `<axes>` is supposed to be `None`, caller can pass in a `nullptr` to `<axes>`.
 *
 * The transpose view created is returned by modifying `dst_ndarray`.
 *
 * The caller is responsible for setting up `dst_ndarray` before calling this function.
 * Here is what this function expects from `dst_ndarray` when called:
 * - `dst_ndarray->data` does not have to be initialized.
 * - `dst_ndarray->itemsize` does not have to be initialized.
 * - `dst_ndarray->ndims` must be initialized, must be equal to `src_ndarray->ndims`.
 * - `dst_ndarray->shape` must be allocated, though it can contain uninitialized values.
 * - `dst_ndarray->strides` must be allocated, though it can contain uninitialized values.
 * - `dst_ndarray->shape` / `dst_ndarray->strides` must NOT be the same buffers as
 *   `src_ndarray->shape` / `src_ndarray->strides`. They are written element by element
 *   while the source arrays are still being read.
 * When this function call ends:
 * - `dst_ndarray->data` is set to `src_ndarray->data` (`dst_ndarray` is just a view to `src_ndarray`)
 * - `dst_ndarray->itemsize` is set to `src_ndarray->itemsize`
 * - `dst_ndarray->ndims` is unchanged
 * - `dst_ndarray->shape` is updated according to how `np.transpose` works
 * - `dst_ndarray->strides` is updated according to how `np.transpose` works
 *
 * @param src_ndarray The NDArray to build a transpose view on
 * @param dst_ndarray The resulting NDArray after transpose. Further details in the comments above.
 * @param num_axes Number of elements in axes. Unused if `axes` is nullptr.
 * @param axes Axes permutation. Set it to `nullptr` if `<axes>` is `None`.
 */
template<typename SizeT>
void transpose(const NDArray<SizeT>* src_ndarray, NDArray<SizeT>* dst_ndarray, SizeT num_axes, const SizeT* axes) {
    debug_assert_eq(SizeT, src_ndarray->ndims, dst_ndarray->ndims);
    const auto ndims = src_ndarray->ndims;

    // Validate user-provided axes up front, so the permutation loop below can
    // assume every entry resolves to a distinct in-bounds axis.
    if (axes != nullptr) {
        assert_transpose_axes(ndims, num_axes, axes);
    }

    dst_ndarray->data = src_ndarray->data;
    dst_ndarray->itemsize = src_ndarray->itemsize;

    // Check out https://ajcr.net/stride-guide-part-2/ to see how `np.transpose` works behind the scenes.
    if (axes == nullptr) {
        // `np.transpose(<array>, axes=None)`
        /*
         * Minor note: `np.transpose(<array>, axes=None)` is equivalent to
         * `np.transpose(<array>, axes=[N-1, N-2, ..., 0])` - basically it
         * is reversing the order of strides and shape.
         *
         * This is a fast implementation to handle this special (but very common) case.
         */
        for (SizeT axis = 0; axis < ndims; axis++) {
            dst_ndarray->shape[axis] = src_ndarray->shape[ndims - axis - 1];
            dst_ndarray->strides[axis] = src_ndarray->strides[ndims - axis - 1];
        }
    } else {
        // `np.transpose(<array>, <axes>)`
        // Permute strides and shape according to `axes`, while resolving negative indices in `axes`
        for (SizeT axis = 0; axis < ndims; axis++) {
            // `i` cannot be OUT_OF_BOUNDS because of assertions
            SizeT i = slice::resolve_index_in_length(ndims, axes[axis]);
            dst_ndarray->shape[axis] = src_ndarray->shape[i];
            dst_ndarray->strides[axis] = src_ndarray->strides[i];
        }
    }
}
} // namespace ndarray::transpose
} // namespace
extern "C" {
using namespace ndarray::transpose;
void __nac3_ndarray_transpose(const NDArray<int32_t>* src_ndarray,
NDArray<int32_t>* dst_ndarray,
int32_t num_axes,
const int32_t* axes) {
transpose(src_ndarray, dst_ndarray, num_axes, axes);
}
void __nac3_ndarray_transpose64(const NDArray<int64_t>* src_ndarray,
NDArray<int64_t>* dst_ndarray,
int64_t num_axes,
const int64_t* axes) {
transpose(src_ndarray, dst_ndarray, num_axes, axes);
}
}