1
0
forked from M-Labs/nac3

[core] codegen/ndarray: Implement np_transpose without axes argument

Based on 052b67c8: core/ndstrides: implement np_transpose() (no axes
argument)

The IRRT implementation knows how to handle axes. But the argument is
not in NAC3 yet.
This commit is contained in:
David Mak 2024-12-18 16:22:59 +08:00
parent 43e440d2fd
commit 7375983e0c
8 changed files with 267 additions and 121 deletions

View File

@ -12,3 +12,4 @@
#include "irrt/ndarray/array.hpp" #include "irrt/ndarray/array.hpp"
#include "irrt/ndarray/reshape.hpp" #include "irrt/ndarray/reshape.hpp"
#include "irrt/ndarray/broadcast.hpp" #include "irrt/ndarray/broadcast.hpp"
#include "irrt/ndarray/transpose.hpp"

View File

@ -0,0 +1,143 @@
#pragma once
#include "irrt/debug.hpp"
#include "irrt/exception.hpp"
#include "irrt/int_types.hpp"
#include "irrt/ndarray/def.hpp"
#include "irrt/slice.hpp"
/*
* Notes on `np.transpose(<array>, <axes>)`
*
* TODO: `axes`, if specified, can actually contain negative indices,
* but it is not documented in numpy.
*
* Supporting it for now.
*/
namespace {
namespace ndarray::transpose {
/**
* @brief Do assertions on `<axes>` in `np.transpose(<array>, <axes>)`.
*
* Note that `np.transpose`'s `<axe>` argument is optional. If the argument
* is specified but the user, use this function to do assertions on it.
*
* @param ndims The number of dimensions of `<array>`
* @param num_axes Number of elements in `<axes>` as specified by the user.
* This should be equal to `ndims`. If not, a "ValueError: axes don't match array" is thrown.
* @param axes The user specified `<axes>`.
*/
template<typename SizeT>
void assert_transpose_axes(SizeT ndims, SizeT num_axes, const SizeT* axes) {
if (ndims != num_axes) {
raise_exception(SizeT, EXN_VALUE_ERROR, "axes don't match array", NO_PARAM, NO_PARAM, NO_PARAM);
}
// TODO: Optimize this
bool* axe_specified = (bool*)__builtin_alloca(sizeof(bool) * ndims);
for (SizeT i = 0; i < ndims; i++)
axe_specified[i] = false;
for (SizeT i = 0; i < ndims; i++) {
SizeT axis = slice::resolve_index_in_length(ndims, axes[i]);
if (axis == -1) {
// TODO: numpy actually throws a `numpy.exceptions.AxisError`
raise_exception(SizeT, EXN_VALUE_ERROR, "axis {0} is out of bounds for array of dimension {1}", axis, ndims,
NO_PARAM);
}
if (axe_specified[axis]) {
raise_exception(SizeT, EXN_VALUE_ERROR, "repeated axis in transpose", NO_PARAM, NO_PARAM, NO_PARAM);
}
axe_specified[axis] = true;
}
}
/**
* @brief Create a transpose view of `src_ndarray` and perform proper assertions.
*
* This function is very similar to doing `dst_ndarray = np.transpose(src_ndarray, <axes>)`.
* If `<axes>` is supposed to be `None`, caller can pass in a `nullptr` to `<axes>`.
*
* The transpose view created is returned by modifying `dst_ndarray`.
*
* The caller is responsible for setting up `dst_ndarray` before calling this function.
* Here is what this function expects from `dst_ndarray` when called:
* - `dst_ndarray->data` does not have to be initialized.
* - `dst_ndarray->itemsize` does not have to be initialized.
* - `dst_ndarray->ndims` must be initialized, must be equal to `src_ndarray->ndims`.
* - `dst_ndarray->shape` must be allocated, through it can contain uninitialized values.
* - `dst_ndarray->strides` must be allocated, through it can contain uninitialized values.
* When this function call ends:
* - `dst_ndarray->data` is set to `src_ndarray->data` (`dst_ndarray` is just a view to `src_ndarray`)
* - `dst_ndarray->itemsize` is set to `src_ndarray->itemsize`
* - `dst_ndarray->ndims` is unchanged
* - `dst_ndarray->shape` is updated according to how `np.transpose` works
* - `dst_ndarray->strides` is updated according to how `np.transpose` works
*
* @param src_ndarray The NDArray to build a transpose view on
* @param dst_ndarray The resulting NDArray after transpose. Further details in the comments above,
* @param num_axes Number of elements in axes. Unused if `axes` is nullptr.
* @param axes Axes permutation. Set it to `nullptr` if `<axes>` is `None`.
*/
template<typename SizeT>
void transpose(const NDArray<SizeT>* src_ndarray, NDArray<SizeT>* dst_ndarray, SizeT num_axes, const SizeT* axes) {
debug_assert_eq(SizeT, src_ndarray->ndims, dst_ndarray->ndims);
const auto ndims = src_ndarray->ndims;
if (axes != nullptr)
assert_transpose_axes(ndims, num_axes, axes);
dst_ndarray->data = src_ndarray->data;
dst_ndarray->itemsize = src_ndarray->itemsize;
// Check out https://ajcr.net/stride-guide-part-2/ to see how `np.transpose` works behind the scenes.
if (axes == nullptr) {
// `np.transpose(<array>, axes=None)`
/*
* Minor note: `np.transpose(<array>, axes=None)` is equivalent to
* `np.transpose(<array>, axes=[N-1, N-2, ..., 0])` - basically it
* is reversing the order of strides and shape.
*
* This is a fast implementation to handle this special (but very common) case.
*/
for (SizeT axis = 0; axis < ndims; axis++) {
dst_ndarray->shape[axis] = src_ndarray->shape[ndims - axis - 1];
dst_ndarray->strides[axis] = src_ndarray->strides[ndims - axis - 1];
}
} else {
// `np.transpose(<array>, <axes>)`
// Permute strides and shape according to `axes`, while resolving negative indices in `axes`
for (SizeT axis = 0; axis < ndims; axis++) {
// `i` cannot be OUT_OF_BOUNDS because of assertions
SizeT i = slice::resolve_index_in_length(ndims, axes[axis]);
dst_ndarray->shape[axis] = src_ndarray->shape[i];
dst_ndarray->strides[axis] = src_ndarray->strides[i];
}
}
}
} // namespace ndarray::transpose
} // namespace
extern "C" {
using namespace ndarray::transpose;
void __nac3_ndarray_transpose(const NDArray<int32_t>* src_ndarray,
NDArray<int32_t>* dst_ndarray,
int32_t num_axes,
const int32_t* axes) {
transpose(src_ndarray, dst_ndarray, num_axes, axes);
}
void __nac3_ndarray_transpose64(const NDArray<int64_t>* src_ndarray,
NDArray<int64_t>* dst_ndarray,
int64_t num_axes,
const int64_t* axes) {
transpose(src_ndarray, dst_ndarray, num_axes, axes);
}
}

View File

@ -22,6 +22,7 @@ pub use broadcast::*;
pub use indexing::*; pub use indexing::*;
pub use iter::*; pub use iter::*;
pub use reshape::*; pub use reshape::*;
pub use transpose::*;
mod array; mod array;
mod basic; mod basic;
@ -29,6 +30,7 @@ mod broadcast;
mod indexing; mod indexing;
mod iter; mod iter;
mod reshape; mod reshape;
mod transpose;
/// Generates a call to `__nac3_ndarray_calc_size`. Returns a /// Generates a call to `__nac3_ndarray_calc_size`. Returns a
/// [`usize`][CodeGenerator::get_size_type] representing the calculated total size. /// [`usize`][CodeGenerator::get_size_type] representing the calculated total size.

View File

@ -0,0 +1,48 @@
use inkwell::{values::IntValue, AddressSpace};
use crate::codegen::{
expr::infer_and_call_function,
irrt::get_usize_dependent_function_name,
values::{ndarray::NDArrayValue, ProxyValue, TypedArrayLikeAccessor},
CodeGenContext, CodeGenerator,
};
/// Generates a call to `__nac3_ndarray_transpose`.
///
/// Creates a transpose view of `src_ndarray` and writes the result to `dst_ndarray`.
///
/// `dst_ndarray` must fulfill the following preconditions:
///
/// - `dst_ndarray.ndims` must be initialized and must be equal to `src_ndarray.ndims`.
/// - `dst_ndarray.shape` must be allocated and may contain uninitialized values.
/// - `dst_ndarray.strides` must be allocated and may contain uninitialized values.
pub fn call_nac3_ndarray_transpose<'ctx, G: CodeGenerator + ?Sized>(
generator: &G,
ctx: &CodeGenContext<'ctx, '_>,
src_ndarray: NDArrayValue<'ctx>,
dst_ndarray: NDArrayValue<'ctx>,
axes: Option<&impl TypedArrayLikeAccessor<'ctx, G, IntValue<'ctx>>>,
) {
let llvm_usize = generator.get_size_type(ctx.ctx);
assert!(axes.is_none_or(|axes| axes.size(ctx, generator).get_type() == llvm_usize));
assert!(axes.is_none_or(|axes| axes.element_type(ctx, generator) == llvm_usize.into()));
let name = get_usize_dependent_function_name(generator, ctx, "__nac3_ndarray_transpose");
infer_and_call_function(
ctx,
&name,
None,
&[
src_ndarray.as_base_value().into(),
dst_ndarray.as_base_value().into(),
axes.map_or(llvm_usize.const_zero(), |axes| axes.size(ctx, generator)).into(),
axes.map_or(llvm_usize.ptr_type(AddressSpace::default()).const_null(), |axes| {
axes.base_ptr(ctx, generator)
})
.into(),
],
None,
None,
);
}

View File

@ -1307,114 +1307,6 @@ pub fn gen_ndarray_fill<'ctx>(
Ok(()) Ok(())
} }
/// Generates LLVM IR for `ndarray.transpose`.
pub fn ndarray_transpose<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
(x1_ty, x1): (Type, BasicValueEnum<'ctx>),
) -> Result<BasicValueEnum<'ctx>, String> {
const FN_NAME: &str = "ndarray_transpose";
let llvm_usize = generator.get_size_type(ctx.ctx);
if let BasicValueEnum::PointerValue(n1) = x1 {
let llvm_ndarray_ty = NDArrayType::from_unifier_type(generator, ctx, x1_ty);
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty);
let n1 = llvm_ndarray_ty.map_value(n1, None);
let n_sz = n1.size(generator, ctx);
// Dimensions are reversed in the transposed array
let out = create_ndarray_dyn_shape(
generator,
ctx,
elem_ty,
&n1,
|_, ctx, n| Ok(n.load_ndims(ctx)),
|generator, ctx, n, idx| {
let new_idx = ctx.builder.build_int_sub(n.load_ndims(ctx), idx, "").unwrap();
let new_idx = ctx
.builder
.build_int_sub(new_idx, new_idx.get_type().const_int(1, false), "")
.unwrap();
unsafe { Ok(n.shape().get_typed_unchecked(ctx, generator, &new_idx, None)) }
},
)
.unwrap();
gen_for_callback_incrementing(
generator,
ctx,
None,
llvm_usize.const_zero(),
(n_sz, false),
|generator, ctx, _, idx| {
let elem = unsafe { n1.data().get_unchecked(ctx, generator, &idx, None) };
let new_idx = generator.gen_var_alloc(ctx, llvm_usize.into(), None)?;
let rem_idx = generator.gen_var_alloc(ctx, llvm_usize.into(), None)?;
ctx.builder.build_store(new_idx, llvm_usize.const_zero()).unwrap();
ctx.builder.build_store(rem_idx, idx).unwrap();
// Incrementally calculate the new index in the transposed array
// For each index, we first decompose it into the n-dims and use those to reconstruct the new index
// The formula used for indexing is:
// idx = dim_n * ( ... (dim2 * (dim0 * dim1) + dim1) + dim2 ... ) + dim_n
gen_for_callback_incrementing(
generator,
ctx,
None,
llvm_usize.const_zero(),
(n1.load_ndims(ctx), false),
|generator, ctx, _, ndim| {
let ndim_rev =
ctx.builder.build_int_sub(n1.load_ndims(ctx), ndim, "").unwrap();
let ndim_rev = ctx
.builder
.build_int_sub(ndim_rev, llvm_usize.const_int(1, false), "")
.unwrap();
let dim = unsafe {
n1.shape().get_typed_unchecked(ctx, generator, &ndim_rev, None)
};
let rem_idx_val =
ctx.builder.build_load(rem_idx, "").unwrap().into_int_value();
let new_idx_val =
ctx.builder.build_load(new_idx, "").unwrap().into_int_value();
let add_component =
ctx.builder.build_int_unsigned_rem(rem_idx_val, dim, "").unwrap();
let rem_idx_val =
ctx.builder.build_int_unsigned_div(rem_idx_val, dim, "").unwrap();
let new_idx_val = ctx.builder.build_int_mul(new_idx_val, dim, "").unwrap();
let new_idx_val =
ctx.builder.build_int_add(new_idx_val, add_component, "").unwrap();
ctx.builder.build_store(rem_idx, rem_idx_val).unwrap();
ctx.builder.build_store(new_idx, new_idx_val).unwrap();
Ok(())
},
llvm_usize.const_int(1, false),
)?;
let new_idx_val = ctx.builder.build_load(new_idx, "").unwrap().into_int_value();
unsafe { out.data().set_unchecked(ctx, generator, &new_idx_val, elem) };
Ok(())
},
llvm_usize.const_int(1, false),
)?;
Ok(out.as_base_value().into())
} else {
codegen_unreachable!(
ctx,
"{FN_NAME}() not supported for '{}'",
format!("'{}'", ctx.unifier.stringify(x1_ty))
)
}
}
/// Generates LLVM IR for `ndarray.dot`. /// Generates LLVM IR for `ndarray.dot`.
/// Calculate inner product of two vectors or literals /// Calculate inner product of two vectors or literals
/// For matrix multiplication use `np_matmul` /// For matrix multiplication use `np_matmul`

View File

@ -1,6 +1,6 @@
use std::iter::{once, repeat_n}; use std::iter::{once, repeat_n};
use inkwell::values::IntValue; use inkwell::values::{IntValue, PointerValue};
use itertools::Itertools; use itertools::Itertools;
use crate::codegen::{ use crate::codegen::{
@ -9,7 +9,7 @@ use crate::codegen::{
types::ndarray::NDArrayType, types::ndarray::NDArrayType,
values::{ values::{
ndarray::{NDArrayValue, RustNDIndex}, ndarray::{NDArrayValue, RustNDIndex},
ArrayLikeValue, ProxyValue, TypedArrayLikeAccessor, ArrayLikeValue, ArraySliceValue, ProxyValue, TypedArrayLikeAccessor, TypedArrayLikeAdapter,
}, },
CodeGenContext, CodeGenerator, CodeGenContext, CodeGenerator,
}; };
@ -108,4 +108,50 @@ impl<'ctx> NDArrayValue<'ctx> {
dst_ndarray dst_ndarray
} }
/// Create a transposed view on this ndarray like
/// [`np.transpose(<ndarray>, <axes> = None)`](https://numpy.org/doc/stable/reference/generated/numpy.transpose.html).
///
/// * `axes` - If specified, should be an array of the permutation (negative indices are
/// **allowed**).
#[must_use]
pub fn transpose<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
axes: Option<PointerValue<'ctx>>,
) -> Self {
assert!(self.ndims.is_some(), "NDArrayValue::transpose is only supported for instances with compile-time known ndims (self.ndims = Some(...))");
assert!(
axes.is_none_or(|axes| axes.get_type().get_element_type() == self.llvm_usize.into())
);
// Define models
let transposed_ndarray = self.get_type().construct_uninitialized(generator, ctx, None);
let axes = if let Some(axes) = axes {
let num_axes = self.llvm_usize.const_int(self.ndims.unwrap(), false);
// `axes = nullptr` if `axes` is unspecified.
let axes = ArraySliceValue::from_ptr_val(axes, num_axes, None);
Some(TypedArrayLikeAdapter::from(
axes,
|_, _, val| val.into_int_value(),
|_, _, val| val.into(),
))
} else {
None
};
irrt::ndarray::call_nac3_ndarray_transpose(
generator,
ctx,
*self,
transposed_ndarray,
axes.as_ref(),
);
transposed_ndarray
}
} }

View File

@ -1349,7 +1349,12 @@ impl<'a> BuiltinBuilder<'a> {
Box::new(move |ctx, _, fun, args, generator| { Box::new(move |ctx, _, fun, args, generator| {
let arg_ty = fun.0.args[0].ty; let arg_ty = fun.0.args[0].ty;
let arg_val = args[0].1.clone().to_basic_value_enum(ctx, generator, arg_ty)?; let arg_val = args[0].1.clone().to_basic_value_enum(ctx, generator, arg_ty)?;
Ok(Some(ndarray_transpose(generator, ctx, (arg_ty, arg_val))?))
let ndarray = NDArrayType::from_unifier_type(generator, ctx, arg_ty)
.map_value(arg_val.into_pointer_value(), None);
let ndarray = ndarray.transpose(generator, ctx, None); // TODO: Add axes argument
Ok(Some(ndarray.as_base_value().into()))
}), }),
), ),

View File

@ -210,6 +210,23 @@ def test_ndarray_nd_idx():
output_float64(x[1, 0]) output_float64(x[1, 0])
output_float64(x[1, 1]) output_float64(x[1, 1])
def test_ndarray_transpose():
x: ndarray[float, 2] = np_array([[1., 2., 3.], [4., 5., 6.]])
y = np_transpose(x)
z = np_transpose(y)
output_int32(np_shape(x)[0])
output_int32(np_shape(x)[1])
output_ndarray_float_2(x)
output_int32(np_shape(y)[0])
output_int32(np_shape(y)[1])
output_ndarray_float_2(y)
output_int32(np_shape(z)[0])
output_int32(np_shape(z)[1])
output_ndarray_float_2(z)
def test_ndarray_reshape(): def test_ndarray_reshape():
w: ndarray[float, 1] = np_array([1., 2., 3., 4., 5., 6., 7., 8., 9., 10.]) w: ndarray[float, 1] = np_array([1., 2., 3., 4., 5., 6., 7., 8., 9., 10.])
x = np_reshape(w, (1, 2, 1, -1)) x = np_reshape(w, (1, 2, 1, -1))
@ -1502,14 +1519,6 @@ def test_ndarray_nextafter_broadcast_rhs_scalar():
output_ndarray_float_2(nextafter_x_zeros) output_ndarray_float_2(nextafter_x_zeros)
output_ndarray_float_2(nextafter_x_ones) output_ndarray_float_2(nextafter_x_ones)
def test_ndarray_transpose():
x: ndarray[float, 2] = np_array([[1., 2., 3.], [4., 5., 6.]])
y = np_transpose(x)
z = np_transpose(y)
output_ndarray_float_2(x)
output_ndarray_float_2(y)
def test_ndarray_dot(): def test_ndarray_dot():
x1: ndarray[float, 1] = np_array([5.0, 1.0, 4.0, 2.0]) x1: ndarray[float, 1] = np_array([5.0, 1.0, 4.0, 2.0])
y1: ndarray[float, 1] = np_array([5.0, 1.0, 6.0, 6.0]) y1: ndarray[float, 1] = np_array([5.0, 1.0, 6.0, 6.0])
@ -1641,6 +1650,7 @@ def run() -> int32:
test_ndarray_slices() test_ndarray_slices()
test_ndarray_nd_idx() test_ndarray_nd_idx()
test_ndarray_transpose()
test_ndarray_reshape() test_ndarray_reshape()
test_ndarray_broadcast_to() test_ndarray_broadcast_to()
@ -1807,7 +1817,6 @@ def run() -> int32:
test_ndarray_nextafter_broadcast() test_ndarray_nextafter_broadcast()
test_ndarray_nextafter_broadcast_lhs_scalar() test_ndarray_nextafter_broadcast_lhs_scalar()
test_ndarray_nextafter_broadcast_rhs_scalar() test_ndarray_nextafter_broadcast_rhs_scalar()
test_ndarray_transpose()
test_ndarray_dot() test_ndarray_dot()
test_ndarray_cholesky() test_ndarray_cholesky()