core/ndstrides: implement np_transpose() (no axes argument)

The IRRT implementation knows how to handle axes. But the argument is
not in NAC3 yet.
This commit is contained in:
lyken 2024-08-20 16:35:20 +08:00
parent d32268fb5d
commit b9e837109b
No known key found for this signature in database
GPG Key ID: 3BD5FC6AC8325DD8
6 changed files with 216 additions and 122 deletions

View File

@ -11,4 +11,5 @@
#include "irrt/ndarray/indexing.hpp"
#include "irrt/ndarray/array.hpp"
#include "irrt/ndarray/reshape.hpp"
#include "irrt/ndarray/broadcast.hpp"
#include "irrt/ndarray/broadcast.hpp"
#include "irrt/ndarray/transpose.hpp"

View File

@ -0,0 +1,145 @@
#pragma once
#include "irrt/debug.hpp"
#include "irrt/exception.hpp"
#include "irrt/int_types.hpp"
#include "irrt/ndarray/def.hpp"
#include "irrt/slice.hpp"
/*
* Notes on `np.transpose(<array>, <axes>)`
*
* TODO: `axes`, if specified, can actually contain negative indices,
* but it is not documented in numpy.
*
* Supporting it for now.
*/
namespace {
namespace ndarray {
namespace transpose {
/**
* @brief Do assertions on `<axes>` in `np.transpose(<array>, <axes>)`.
*
* Note that `np.transpose`'s `<axe>` argument is optional. If the argument
* is specified but the user, use this function to do assertions on it.
*
* @param ndims The number of dimensions of `<array>`
* @param num_axes Number of elements in `<axes>` as specified by the user.
* This should be equal to `ndims`. If not, a "ValueError: axes don't match array" is thrown.
* @param axes The user specified `<axes>`.
*/
template<typename SizeT>
void assert_transpose_axes(SizeT ndims, SizeT num_axes, const SizeT* axes) {
if (ndims != num_axes) {
raise_exception(SizeT, EXN_VALUE_ERROR, "axes don't match array", NO_PARAM, NO_PARAM, NO_PARAM);
}
// TODO: Optimize this
bool* axe_specified = (bool*)__builtin_alloca(sizeof(bool) * ndims);
for (SizeT i = 0; i < ndims; i++)
axe_specified[i] = false;
for (SizeT i = 0; i < ndims; i++) {
SizeT axis = slice::resolve_index_in_length(ndims, axes[i]);
if (axis == -1) {
// TODO: numpy actually throws a `numpy.exceptions.AxisError`
raise_exception(SizeT, EXN_VALUE_ERROR, "axis {0} is out of bounds for array of dimension {1}", axis, ndims,
NO_PARAM);
}
if (axe_specified[axis]) {
raise_exception(SizeT, EXN_VALUE_ERROR, "repeated axis in transpose", NO_PARAM, NO_PARAM, NO_PARAM);
}
axe_specified[axis] = true;
}
}
/**
* @brief Create a transpose view of `src_ndarray` and perform proper assertions.
*
* This function is very similar to doing `dst_ndarray = np.transpose(src_ndarray, <axes>)`.
* If `<axes>` is supposed to be `None`, caller can pass in a `nullptr` to `<axes>`.
*
* The transpose view created is returned by modifying `dst_ndarray`.
*
* The caller is responsible for setting up `dst_ndarray` before calling this function.
* Here is what this function expects from `dst_ndarray` when called:
* - `dst_ndarray->data` does not have to be initialized.
* - `dst_ndarray->itemsize` does not have to be initialized.
* - `dst_ndarray->ndims` must be initialized, must be equal to `src_ndarray->ndims`.
* - `dst_ndarray->shape` must be allocated, through it can contain uninitialized values.
* - `dst_ndarray->strides` must be allocated, through it can contain uninitialized values.
* When this function call ends:
* - `dst_ndarray->data` is set to `src_ndarray->data` (`dst_ndarray` is just a view to `src_ndarray`)
* - `dst_ndarray->itemsize` is set to `src_ndarray->itemsize`
* - `dst_ndarray->ndims` is unchanged
* - `dst_ndarray->shape` is updated according to how `np.transpose` works
* - `dst_ndarray->strides` is updated according to how `np.transpose` works
*
* @param src_ndarray The NDArray to build a transpose view on
* @param dst_ndarray The resulting NDArray after transpose. Further details in the comments above,
* @param num_axes Number of elements in axes. Unused if `axes` is nullptr.
* @param axes Axes permutation. Set it to `nullptr` if `<axes>` is `None`.
*/
template<typename SizeT>
void transpose(const NDArray<SizeT>* src_ndarray, NDArray<SizeT>* dst_ndarray, SizeT num_axes, const SizeT* axes) {
debug_assert_eq(SizeT, src_ndarray->ndims, dst_ndarray->ndims);
const auto ndims = src_ndarray->ndims;
if (axes != nullptr)
assert_transpose_axes(ndims, num_axes, axes);
dst_ndarray->data = src_ndarray->data;
dst_ndarray->itemsize = src_ndarray->itemsize;
// Check out https://ajcr.net/stride-guide-part-2/ to see how `np.transpose` works behind the scenes.
if (axes == nullptr) {
// `np.transpose(<array>, axes=None)`
/*
* Minor note: `np.transpose(<array>, axes=None)` is equivalent to
* `np.transpose(<array>, axes=[N-1, N-2, ..., 0])` - basically it
* is reversing the order of strides and shape.
*
* This is a fast implementation to handle this special (but very common) case.
*/
for (SizeT axis = 0; axis < ndims; axis++) {
dst_ndarray->shape[axis] = src_ndarray->shape[ndims - axis - 1];
dst_ndarray->strides[axis] = src_ndarray->strides[ndims - axis - 1];
}
} else {
// `np.transpose(<array>, <axes>)`
// Permute strides and shape according to `axes`, while resolving negative indices in `axes`
for (SizeT axis = 0; axis < ndims; axis++) {
// `i` cannot be OUT_OF_BOUNDS because of assertions
SizeT i = slice::resolve_index_in_length(ndims, axes[axis]);
dst_ndarray->shape[axis] = src_ndarray->shape[i];
dst_ndarray->strides[axis] = src_ndarray->strides[i];
}
}
}
} // namespace transpose
} // namespace ndarray
} // namespace
extern "C" {
using namespace ndarray::transpose;
void __nac3_ndarray_transpose(const NDArray<int32_t>* src_ndarray,
NDArray<int32_t>* dst_ndarray,
int32_t num_axes,
const int32_t* axes) {
transpose(src_ndarray, dst_ndarray, num_axes, axes);
}
void __nac3_ndarray_transpose64(const NDArray<int64_t>* src_ndarray,
NDArray<int64_t>* dst_ndarray,
int64_t num_axes,
const int64_t* axes) {
transpose(src_ndarray, dst_ndarray, num_axes, axes);
}
}

View File

@ -1203,3 +1203,20 @@ pub fn call_nac3_ndarray_broadcast_shapes<'ctx, G: CodeGenerator + ?Sized>(
.arg(dst_shape)
.returning_void();
}
pub fn call_nac3_ndarray_transpose<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
src_ndarray: Instance<'ctx, Ptr<Struct<NDArray>>>,
dst_ndarray: Instance<'ctx, Ptr<Struct<NDArray>>>,
num_axes: Instance<'ctx, Int<SizeT>>,
axes: Instance<'ctx, Ptr<Int<SizeT>>>,
) {
let name = get_sizet_dependent_function_name(generator, ctx, "__nac3_ndarray_transpose");
FnCall::builder(generator, ctx, &name)
.arg(src_ndarray)
.arg(dst_ndarray)
.arg(num_axes)
.arg(axes)
.returning_void();
}

View File

@ -1991,113 +1991,6 @@ pub fn gen_ndarray_fill<'ctx>(
Ok(())
}
/// Generates LLVM IR for `ndarray.transpose`.
pub fn ndarray_transpose<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
x1: (Type, BasicValueEnum<'ctx>),
) -> Result<BasicValueEnum<'ctx>, String> {
const FN_NAME: &str = "ndarray_transpose";
let (x1_ty, x1) = x1;
let llvm_usize = generator.get_size_type(ctx.ctx);
if let BasicValueEnum::PointerValue(n1) = x1 {
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty);
let n1 = NDArrayValue::from_ptr_val(n1, llvm_usize, None);
let n_sz = call_ndarray_calc_size(generator, ctx, &n1.dim_sizes(), (None, None));
// Dimensions are reversed in the transposed array
let out = create_ndarray_dyn_shape(
generator,
ctx,
elem_ty,
&n1,
|_, ctx, n| Ok(n.load_ndims(ctx)),
|generator, ctx, n, idx| {
let new_idx = ctx.builder.build_int_sub(n.load_ndims(ctx), idx, "").unwrap();
let new_idx = ctx
.builder
.build_int_sub(new_idx, new_idx.get_type().const_int(1, false), "")
.unwrap();
unsafe { Ok(n.dim_sizes().get_typed_unchecked(ctx, generator, &new_idx, None)) }
},
)
.unwrap();
gen_for_callback_incrementing(
generator,
ctx,
None,
llvm_usize.const_zero(),
(n_sz, false),
|generator, ctx, _, idx| {
let elem = unsafe { n1.data().get_unchecked(ctx, generator, &idx, None) };
let new_idx = generator.gen_var_alloc(ctx, llvm_usize.into(), None)?;
let rem_idx = generator.gen_var_alloc(ctx, llvm_usize.into(), None)?;
ctx.builder.build_store(new_idx, llvm_usize.const_zero()).unwrap();
ctx.builder.build_store(rem_idx, idx).unwrap();
// Incrementally calculate the new index in the transposed array
// For each index, we first decompose it into the n-dims and use those to reconstruct the new index
// The formula used for indexing is:
// idx = dim_n * ( ... (dim2 * (dim0 * dim1) + dim1) + dim2 ... ) + dim_n
gen_for_callback_incrementing(
generator,
ctx,
None,
llvm_usize.const_zero(),
(n1.load_ndims(ctx), false),
|generator, ctx, _, ndim| {
let ndim_rev =
ctx.builder.build_int_sub(n1.load_ndims(ctx), ndim, "").unwrap();
let ndim_rev = ctx
.builder
.build_int_sub(ndim_rev, llvm_usize.const_int(1, false), "")
.unwrap();
let dim = unsafe {
n1.dim_sizes().get_typed_unchecked(ctx, generator, &ndim_rev, None)
};
let rem_idx_val =
ctx.builder.build_load(rem_idx, "").unwrap().into_int_value();
let new_idx_val =
ctx.builder.build_load(new_idx, "").unwrap().into_int_value();
let add_component =
ctx.builder.build_int_unsigned_rem(rem_idx_val, dim, "").unwrap();
let rem_idx_val =
ctx.builder.build_int_unsigned_div(rem_idx_val, dim, "").unwrap();
let new_idx_val = ctx.builder.build_int_mul(new_idx_val, dim, "").unwrap();
let new_idx_val =
ctx.builder.build_int_add(new_idx_val, add_component, "").unwrap();
ctx.builder.build_store(rem_idx, rem_idx_val).unwrap();
ctx.builder.build_store(new_idx, new_idx_val).unwrap();
Ok(())
},
llvm_usize.const_int(1, false),
)?;
let new_idx_val = ctx.builder.build_load(new_idx, "").unwrap().into_int_value();
unsafe { out.data().set_unchecked(ctx, generator, &new_idx_val, elem) };
Ok(())
},
llvm_usize.const_int(1, false),
)?;
Ok(out.as_base_value().into())
} else {
codegen_unreachable!(
ctx,
"{FN_NAME}() not supported for '{}'",
format!("'{}'", ctx.unifier.stringify(x1_ty))
)
}
}
/// Generates LLVM IR for `ndarray.dot`.
/// Calculate inner product of two vectors or literals
/// For matrix multiplication use `np_matmul`

View File

@ -1,6 +1,7 @@
use crate::codegen::{
irrt::call_nac3_ndarray_reshape_resolve_and_check_new_shape, model::*, CodeGenContext,
CodeGenerator,
irrt::{call_nac3_ndarray_reshape_resolve_and_check_new_shape, call_nac3_ndarray_transpose},
model::*,
CodeGenContext, CodeGenerator,
};
use super::{indexing::RustNDIndex, NDArrayObject};
@ -86,4 +87,33 @@ impl<'ctx> NDArrayObject<'ctx> {
dst_ndarray
}
/// Create a transposed view on this ndarray like `np.transpose(<ndarray>, <axes> = None)`.
/// * `axes` - If specified, should be an array of the permutation (negative indices are **allowed**).
#[must_use]
pub fn transpose<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
axes: Option<Instance<'ctx, Ptr<Int<SizeT>>>>,
) -> Self {
// Define models
let transposed_ndarray = NDArrayObject::alloca(generator, ctx, self.dtype, self.ndims);
let num_axes = self.ndims_llvm(generator, ctx.ctx);
// `axes = nullptr` if `axes` is unspecified.
let axes = axes.unwrap_or_else(|| Ptr(Int(SizeT)).nullptr(generator, ctx.ctx));
call_nac3_ndarray_transpose(
generator,
ctx,
self.instance,
transposed_ndarray.instance,
num_axes,
axes,
);
transposed_ndarray
}
}

View File

@ -1481,18 +1481,26 @@ impl<'a> BuiltinBuilder<'a> {
);
match prim {
PrimDef::FunNpTranspose => create_fn_by_codegen(
self.unifier,
&into_var_map([in_ndarray_ty]),
prim.name(),
in_ndarray_ty.ty,
&[(in_ndarray_ty.ty, "x")],
Box::new(move |ctx, _, fun, args, generator| {
let arg_ty = fun.0.args[0].ty;
let arg_val = args[0].1.clone().to_basic_value_enum(ctx, generator, arg_ty)?;
Ok(Some(ndarray_transpose(generator, ctx, (arg_ty, arg_val))?))
}),
),
PrimDef::FunNpTranspose => {
create_fn_by_codegen(
self.unifier,
&into_var_map([in_ndarray_ty]),
prim.name(),
in_ndarray_ty.ty,
&[(in_ndarray_ty.ty, "x")],
Box::new(move |ctx, _, fun, args, generator| {
let arg_ty = fun.0.args[0].ty;
let arg_val =
args[0].1.clone().to_basic_value_enum(ctx, generator, arg_ty)?;
let arg = AnyObject { ty: arg_ty, value: arg_val };
let ndarray = NDArrayObject::from_object(generator, ctx, arg);
let ndarray = ndarray.transpose(generator, ctx, None); // TODO: Add axes argument
Ok(Some(ndarray.instance.value.as_basic_value_enum()))
}),
)
}
// NOTE: on `ndarray_factory_fn_shape_arg_tvar` and
// the `param_ty` for `create_fn_by_codegen`.