forked from M-Labs/nac3
[core] codegen/ndarray: Implement np_transpose without axes argument
Based on 052b67c8
: core/ndstrides: implement np_transpose() (no axes
argument)
The IRRT implementation knows how to handle axes. But the argument is
not in NAC3 yet.
This commit is contained in:
parent
43e440d2fd
commit
7375983e0c
@ -11,4 +11,5 @@
|
|||||||
#include "irrt/ndarray/indexing.hpp"
|
#include "irrt/ndarray/indexing.hpp"
|
||||||
#include "irrt/ndarray/array.hpp"
|
#include "irrt/ndarray/array.hpp"
|
||||||
#include "irrt/ndarray/reshape.hpp"
|
#include "irrt/ndarray/reshape.hpp"
|
||||||
#include "irrt/ndarray/broadcast.hpp"
|
#include "irrt/ndarray/broadcast.hpp"
|
||||||
|
#include "irrt/ndarray/transpose.hpp"
|
143
nac3core/irrt/irrt/ndarray/transpose.hpp
Normal file
143
nac3core/irrt/irrt/ndarray/transpose.hpp
Normal file
@ -0,0 +1,143 @@
|
|||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include "irrt/debug.hpp"
|
||||||
|
#include "irrt/exception.hpp"
|
||||||
|
#include "irrt/int_types.hpp"
|
||||||
|
#include "irrt/ndarray/def.hpp"
|
||||||
|
#include "irrt/slice.hpp"
|
||||||
|
|
||||||
|
/*
|
||||||
|
* Notes on `np.transpose(<array>, <axes>)`
|
||||||
|
*
|
||||||
|
* TODO: `axes`, if specified, can actually contain negative indices,
|
||||||
|
* but it is not documented in numpy.
|
||||||
|
*
|
||||||
|
* Supporting it for now.
|
||||||
|
*/
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
namespace ndarray::transpose {
|
||||||
|
/**
|
||||||
|
* @brief Do assertions on `<axes>` in `np.transpose(<array>, <axes>)`.
|
||||||
|
*
|
||||||
|
* Note that `np.transpose`'s `<axe>` argument is optional. If the argument
|
||||||
|
* is specified but the user, use this function to do assertions on it.
|
||||||
|
*
|
||||||
|
* @param ndims The number of dimensions of `<array>`
|
||||||
|
* @param num_axes Number of elements in `<axes>` as specified by the user.
|
||||||
|
* This should be equal to `ndims`. If not, a "ValueError: axes don't match array" is thrown.
|
||||||
|
* @param axes The user specified `<axes>`.
|
||||||
|
*/
|
||||||
|
template<typename SizeT>
|
||||||
|
void assert_transpose_axes(SizeT ndims, SizeT num_axes, const SizeT* axes) {
|
||||||
|
if (ndims != num_axes) {
|
||||||
|
raise_exception(SizeT, EXN_VALUE_ERROR, "axes don't match array", NO_PARAM, NO_PARAM, NO_PARAM);
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO: Optimize this
|
||||||
|
bool* axe_specified = (bool*)__builtin_alloca(sizeof(bool) * ndims);
|
||||||
|
for (SizeT i = 0; i < ndims; i++)
|
||||||
|
axe_specified[i] = false;
|
||||||
|
|
||||||
|
for (SizeT i = 0; i < ndims; i++) {
|
||||||
|
SizeT axis = slice::resolve_index_in_length(ndims, axes[i]);
|
||||||
|
if (axis == -1) {
|
||||||
|
// TODO: numpy actually throws a `numpy.exceptions.AxisError`
|
||||||
|
raise_exception(SizeT, EXN_VALUE_ERROR, "axis {0} is out of bounds for array of dimension {1}", axis, ndims,
|
||||||
|
NO_PARAM);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (axe_specified[axis]) {
|
||||||
|
raise_exception(SizeT, EXN_VALUE_ERROR, "repeated axis in transpose", NO_PARAM, NO_PARAM, NO_PARAM);
|
||||||
|
}
|
||||||
|
|
||||||
|
axe_specified[axis] = true;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Create a transpose view of `src_ndarray` and perform proper assertions.
|
||||||
|
*
|
||||||
|
* This function is very similar to doing `dst_ndarray = np.transpose(src_ndarray, <axes>)`.
|
||||||
|
* If `<axes>` is supposed to be `None`, caller can pass in a `nullptr` to `<axes>`.
|
||||||
|
*
|
||||||
|
* The transpose view created is returned by modifying `dst_ndarray`.
|
||||||
|
*
|
||||||
|
* The caller is responsible for setting up `dst_ndarray` before calling this function.
|
||||||
|
* Here is what this function expects from `dst_ndarray` when called:
|
||||||
|
* - `dst_ndarray->data` does not have to be initialized.
|
||||||
|
* - `dst_ndarray->itemsize` does not have to be initialized.
|
||||||
|
* - `dst_ndarray->ndims` must be initialized, must be equal to `src_ndarray->ndims`.
|
||||||
|
* - `dst_ndarray->shape` must be allocated, through it can contain uninitialized values.
|
||||||
|
* - `dst_ndarray->strides` must be allocated, through it can contain uninitialized values.
|
||||||
|
* When this function call ends:
|
||||||
|
* - `dst_ndarray->data` is set to `src_ndarray->data` (`dst_ndarray` is just a view to `src_ndarray`)
|
||||||
|
* - `dst_ndarray->itemsize` is set to `src_ndarray->itemsize`
|
||||||
|
* - `dst_ndarray->ndims` is unchanged
|
||||||
|
* - `dst_ndarray->shape` is updated according to how `np.transpose` works
|
||||||
|
* - `dst_ndarray->strides` is updated according to how `np.transpose` works
|
||||||
|
*
|
||||||
|
* @param src_ndarray The NDArray to build a transpose view on
|
||||||
|
* @param dst_ndarray The resulting NDArray after transpose. Further details in the comments above,
|
||||||
|
* @param num_axes Number of elements in axes. Unused if `axes` is nullptr.
|
||||||
|
* @param axes Axes permutation. Set it to `nullptr` if `<axes>` is `None`.
|
||||||
|
*/
|
||||||
|
template<typename SizeT>
|
||||||
|
void transpose(const NDArray<SizeT>* src_ndarray, NDArray<SizeT>* dst_ndarray, SizeT num_axes, const SizeT* axes) {
|
||||||
|
debug_assert_eq(SizeT, src_ndarray->ndims, dst_ndarray->ndims);
|
||||||
|
const auto ndims = src_ndarray->ndims;
|
||||||
|
|
||||||
|
if (axes != nullptr)
|
||||||
|
assert_transpose_axes(ndims, num_axes, axes);
|
||||||
|
|
||||||
|
dst_ndarray->data = src_ndarray->data;
|
||||||
|
dst_ndarray->itemsize = src_ndarray->itemsize;
|
||||||
|
|
||||||
|
// Check out https://ajcr.net/stride-guide-part-2/ to see how `np.transpose` works behind the scenes.
|
||||||
|
if (axes == nullptr) {
|
||||||
|
// `np.transpose(<array>, axes=None)`
|
||||||
|
|
||||||
|
/*
|
||||||
|
* Minor note: `np.transpose(<array>, axes=None)` is equivalent to
|
||||||
|
* `np.transpose(<array>, axes=[N-1, N-2, ..., 0])` - basically it
|
||||||
|
* is reversing the order of strides and shape.
|
||||||
|
*
|
||||||
|
* This is a fast implementation to handle this special (but very common) case.
|
||||||
|
*/
|
||||||
|
|
||||||
|
for (SizeT axis = 0; axis < ndims; axis++) {
|
||||||
|
dst_ndarray->shape[axis] = src_ndarray->shape[ndims - axis - 1];
|
||||||
|
dst_ndarray->strides[axis] = src_ndarray->strides[ndims - axis - 1];
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// `np.transpose(<array>, <axes>)`
|
||||||
|
|
||||||
|
// Permute strides and shape according to `axes`, while resolving negative indices in `axes`
|
||||||
|
for (SizeT axis = 0; axis < ndims; axis++) {
|
||||||
|
// `i` cannot be OUT_OF_BOUNDS because of assertions
|
||||||
|
SizeT i = slice::resolve_index_in_length(ndims, axes[axis]);
|
||||||
|
|
||||||
|
dst_ndarray->shape[axis] = src_ndarray->shape[i];
|
||||||
|
dst_ndarray->strides[axis] = src_ndarray->strides[i];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} // namespace ndarray::transpose
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
extern "C" {
|
||||||
|
using namespace ndarray::transpose;
|
||||||
|
void __nac3_ndarray_transpose(const NDArray<int32_t>* src_ndarray,
|
||||||
|
NDArray<int32_t>* dst_ndarray,
|
||||||
|
int32_t num_axes,
|
||||||
|
const int32_t* axes) {
|
||||||
|
transpose(src_ndarray, dst_ndarray, num_axes, axes);
|
||||||
|
}
|
||||||
|
|
||||||
|
void __nac3_ndarray_transpose64(const NDArray<int64_t>* src_ndarray,
|
||||||
|
NDArray<int64_t>* dst_ndarray,
|
||||||
|
int64_t num_axes,
|
||||||
|
const int64_t* axes) {
|
||||||
|
transpose(src_ndarray, dst_ndarray, num_axes, axes);
|
||||||
|
}
|
||||||
|
}
|
@ -22,6 +22,7 @@ pub use broadcast::*;
|
|||||||
pub use indexing::*;
|
pub use indexing::*;
|
||||||
pub use iter::*;
|
pub use iter::*;
|
||||||
pub use reshape::*;
|
pub use reshape::*;
|
||||||
|
pub use transpose::*;
|
||||||
|
|
||||||
mod array;
|
mod array;
|
||||||
mod basic;
|
mod basic;
|
||||||
@ -29,6 +30,7 @@ mod broadcast;
|
|||||||
mod indexing;
|
mod indexing;
|
||||||
mod iter;
|
mod iter;
|
||||||
mod reshape;
|
mod reshape;
|
||||||
|
mod transpose;
|
||||||
|
|
||||||
/// Generates a call to `__nac3_ndarray_calc_size`. Returns a
|
/// Generates a call to `__nac3_ndarray_calc_size`. Returns a
|
||||||
/// [`usize`][CodeGenerator::get_size_type] representing the calculated total size.
|
/// [`usize`][CodeGenerator::get_size_type] representing the calculated total size.
|
||||||
|
48
nac3core/src/codegen/irrt/ndarray/transpose.rs
Normal file
48
nac3core/src/codegen/irrt/ndarray/transpose.rs
Normal file
@ -0,0 +1,48 @@
|
|||||||
|
use inkwell::{values::IntValue, AddressSpace};
|
||||||
|
|
||||||
|
use crate::codegen::{
|
||||||
|
expr::infer_and_call_function,
|
||||||
|
irrt::get_usize_dependent_function_name,
|
||||||
|
values::{ndarray::NDArrayValue, ProxyValue, TypedArrayLikeAccessor},
|
||||||
|
CodeGenContext, CodeGenerator,
|
||||||
|
};
|
||||||
|
|
||||||
|
/// Generates a call to `__nac3_ndarray_transpose`.
|
||||||
|
///
|
||||||
|
/// Creates a transpose view of `src_ndarray` and writes the result to `dst_ndarray`.
|
||||||
|
///
|
||||||
|
/// `dst_ndarray` must fulfill the following preconditions:
|
||||||
|
///
|
||||||
|
/// - `dst_ndarray.ndims` must be initialized and must be equal to `src_ndarray.ndims`.
|
||||||
|
/// - `dst_ndarray.shape` must be allocated and may contain uninitialized values.
|
||||||
|
/// - `dst_ndarray.strides` must be allocated and may contain uninitialized values.
|
||||||
|
pub fn call_nac3_ndarray_transpose<'ctx, G: CodeGenerator + ?Sized>(
|
||||||
|
generator: &G,
|
||||||
|
ctx: &CodeGenContext<'ctx, '_>,
|
||||||
|
src_ndarray: NDArrayValue<'ctx>,
|
||||||
|
dst_ndarray: NDArrayValue<'ctx>,
|
||||||
|
axes: Option<&impl TypedArrayLikeAccessor<'ctx, G, IntValue<'ctx>>>,
|
||||||
|
) {
|
||||||
|
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||||
|
|
||||||
|
assert!(axes.is_none_or(|axes| axes.size(ctx, generator).get_type() == llvm_usize));
|
||||||
|
assert!(axes.is_none_or(|axes| axes.element_type(ctx, generator) == llvm_usize.into()));
|
||||||
|
|
||||||
|
let name = get_usize_dependent_function_name(generator, ctx, "__nac3_ndarray_transpose");
|
||||||
|
infer_and_call_function(
|
||||||
|
ctx,
|
||||||
|
&name,
|
||||||
|
None,
|
||||||
|
&[
|
||||||
|
src_ndarray.as_base_value().into(),
|
||||||
|
dst_ndarray.as_base_value().into(),
|
||||||
|
axes.map_or(llvm_usize.const_zero(), |axes| axes.size(ctx, generator)).into(),
|
||||||
|
axes.map_or(llvm_usize.ptr_type(AddressSpace::default()).const_null(), |axes| {
|
||||||
|
axes.base_ptr(ctx, generator)
|
||||||
|
})
|
||||||
|
.into(),
|
||||||
|
],
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
);
|
||||||
|
}
|
@ -1307,114 +1307,6 @@ pub fn gen_ndarray_fill<'ctx>(
|
|||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Generates LLVM IR for `ndarray.transpose`.
|
|
||||||
pub fn ndarray_transpose<'ctx, G: CodeGenerator + ?Sized>(
|
|
||||||
generator: &mut G,
|
|
||||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
|
||||||
(x1_ty, x1): (Type, BasicValueEnum<'ctx>),
|
|
||||||
) -> Result<BasicValueEnum<'ctx>, String> {
|
|
||||||
const FN_NAME: &str = "ndarray_transpose";
|
|
||||||
|
|
||||||
let llvm_usize = generator.get_size_type(ctx.ctx);
|
|
||||||
|
|
||||||
if let BasicValueEnum::PointerValue(n1) = x1 {
|
|
||||||
let llvm_ndarray_ty = NDArrayType::from_unifier_type(generator, ctx, x1_ty);
|
|
||||||
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty);
|
|
||||||
let n1 = llvm_ndarray_ty.map_value(n1, None);
|
|
||||||
let n_sz = n1.size(generator, ctx);
|
|
||||||
|
|
||||||
// Dimensions are reversed in the transposed array
|
|
||||||
let out = create_ndarray_dyn_shape(
|
|
||||||
generator,
|
|
||||||
ctx,
|
|
||||||
elem_ty,
|
|
||||||
&n1,
|
|
||||||
|_, ctx, n| Ok(n.load_ndims(ctx)),
|
|
||||||
|generator, ctx, n, idx| {
|
|
||||||
let new_idx = ctx.builder.build_int_sub(n.load_ndims(ctx), idx, "").unwrap();
|
|
||||||
let new_idx = ctx
|
|
||||||
.builder
|
|
||||||
.build_int_sub(new_idx, new_idx.get_type().const_int(1, false), "")
|
|
||||||
.unwrap();
|
|
||||||
unsafe { Ok(n.shape().get_typed_unchecked(ctx, generator, &new_idx, None)) }
|
|
||||||
},
|
|
||||||
)
|
|
||||||
.unwrap();
|
|
||||||
|
|
||||||
gen_for_callback_incrementing(
|
|
||||||
generator,
|
|
||||||
ctx,
|
|
||||||
None,
|
|
||||||
llvm_usize.const_zero(),
|
|
||||||
(n_sz, false),
|
|
||||||
|generator, ctx, _, idx| {
|
|
||||||
let elem = unsafe { n1.data().get_unchecked(ctx, generator, &idx, None) };
|
|
||||||
|
|
||||||
let new_idx = generator.gen_var_alloc(ctx, llvm_usize.into(), None)?;
|
|
||||||
let rem_idx = generator.gen_var_alloc(ctx, llvm_usize.into(), None)?;
|
|
||||||
ctx.builder.build_store(new_idx, llvm_usize.const_zero()).unwrap();
|
|
||||||
ctx.builder.build_store(rem_idx, idx).unwrap();
|
|
||||||
|
|
||||||
// Incrementally calculate the new index in the transposed array
|
|
||||||
// For each index, we first decompose it into the n-dims and use those to reconstruct the new index
|
|
||||||
// The formula used for indexing is:
|
|
||||||
// idx = dim_n * ( ... (dim2 * (dim0 * dim1) + dim1) + dim2 ... ) + dim_n
|
|
||||||
gen_for_callback_incrementing(
|
|
||||||
generator,
|
|
||||||
ctx,
|
|
||||||
None,
|
|
||||||
llvm_usize.const_zero(),
|
|
||||||
(n1.load_ndims(ctx), false),
|
|
||||||
|generator, ctx, _, ndim| {
|
|
||||||
let ndim_rev =
|
|
||||||
ctx.builder.build_int_sub(n1.load_ndims(ctx), ndim, "").unwrap();
|
|
||||||
let ndim_rev = ctx
|
|
||||||
.builder
|
|
||||||
.build_int_sub(ndim_rev, llvm_usize.const_int(1, false), "")
|
|
||||||
.unwrap();
|
|
||||||
let dim = unsafe {
|
|
||||||
n1.shape().get_typed_unchecked(ctx, generator, &ndim_rev, None)
|
|
||||||
};
|
|
||||||
|
|
||||||
let rem_idx_val =
|
|
||||||
ctx.builder.build_load(rem_idx, "").unwrap().into_int_value();
|
|
||||||
let new_idx_val =
|
|
||||||
ctx.builder.build_load(new_idx, "").unwrap().into_int_value();
|
|
||||||
|
|
||||||
let add_component =
|
|
||||||
ctx.builder.build_int_unsigned_rem(rem_idx_val, dim, "").unwrap();
|
|
||||||
let rem_idx_val =
|
|
||||||
ctx.builder.build_int_unsigned_div(rem_idx_val, dim, "").unwrap();
|
|
||||||
|
|
||||||
let new_idx_val = ctx.builder.build_int_mul(new_idx_val, dim, "").unwrap();
|
|
||||||
let new_idx_val =
|
|
||||||
ctx.builder.build_int_add(new_idx_val, add_component, "").unwrap();
|
|
||||||
|
|
||||||
ctx.builder.build_store(rem_idx, rem_idx_val).unwrap();
|
|
||||||
ctx.builder.build_store(new_idx, new_idx_val).unwrap();
|
|
||||||
|
|
||||||
Ok(())
|
|
||||||
},
|
|
||||||
llvm_usize.const_int(1, false),
|
|
||||||
)?;
|
|
||||||
|
|
||||||
let new_idx_val = ctx.builder.build_load(new_idx, "").unwrap().into_int_value();
|
|
||||||
unsafe { out.data().set_unchecked(ctx, generator, &new_idx_val, elem) };
|
|
||||||
Ok(())
|
|
||||||
},
|
|
||||||
llvm_usize.const_int(1, false),
|
|
||||||
)?;
|
|
||||||
|
|
||||||
Ok(out.as_base_value().into())
|
|
||||||
} else {
|
|
||||||
codegen_unreachable!(
|
|
||||||
ctx,
|
|
||||||
"{FN_NAME}() not supported for '{}'",
|
|
||||||
format!("'{}'", ctx.unifier.stringify(x1_ty))
|
|
||||||
)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Generates LLVM IR for `ndarray.dot`.
|
/// Generates LLVM IR for `ndarray.dot`.
|
||||||
/// Calculate inner product of two vectors or literals
|
/// Calculate inner product of two vectors or literals
|
||||||
/// For matrix multiplication use `np_matmul`
|
/// For matrix multiplication use `np_matmul`
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
use std::iter::{once, repeat_n};
|
use std::iter::{once, repeat_n};
|
||||||
|
|
||||||
use inkwell::values::IntValue;
|
use inkwell::values::{IntValue, PointerValue};
|
||||||
use itertools::Itertools;
|
use itertools::Itertools;
|
||||||
|
|
||||||
use crate::codegen::{
|
use crate::codegen::{
|
||||||
@ -9,7 +9,7 @@ use crate::codegen::{
|
|||||||
types::ndarray::NDArrayType,
|
types::ndarray::NDArrayType,
|
||||||
values::{
|
values::{
|
||||||
ndarray::{NDArrayValue, RustNDIndex},
|
ndarray::{NDArrayValue, RustNDIndex},
|
||||||
ArrayLikeValue, ProxyValue, TypedArrayLikeAccessor,
|
ArrayLikeValue, ArraySliceValue, ProxyValue, TypedArrayLikeAccessor, TypedArrayLikeAdapter,
|
||||||
},
|
},
|
||||||
CodeGenContext, CodeGenerator,
|
CodeGenContext, CodeGenerator,
|
||||||
};
|
};
|
||||||
@ -108,4 +108,50 @@ impl<'ctx> NDArrayValue<'ctx> {
|
|||||||
|
|
||||||
dst_ndarray
|
dst_ndarray
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Create a transposed view on this ndarray like
|
||||||
|
/// [`np.transpose(<ndarray>, <axes> = None)`](https://numpy.org/doc/stable/reference/generated/numpy.transpose.html).
|
||||||
|
///
|
||||||
|
/// * `axes` - If specified, should be an array of the permutation (negative indices are
|
||||||
|
/// **allowed**).
|
||||||
|
#[must_use]
|
||||||
|
pub fn transpose<G: CodeGenerator + ?Sized>(
|
||||||
|
&self,
|
||||||
|
generator: &mut G,
|
||||||
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
|
axes: Option<PointerValue<'ctx>>,
|
||||||
|
) -> Self {
|
||||||
|
assert!(self.ndims.is_some(), "NDArrayValue::transpose is only supported for instances with compile-time known ndims (self.ndims = Some(...))");
|
||||||
|
assert!(
|
||||||
|
axes.is_none_or(|axes| axes.get_type().get_element_type() == self.llvm_usize.into())
|
||||||
|
);
|
||||||
|
|
||||||
|
// Define models
|
||||||
|
let transposed_ndarray = self.get_type().construct_uninitialized(generator, ctx, None);
|
||||||
|
|
||||||
|
let axes = if let Some(axes) = axes {
|
||||||
|
let num_axes = self.llvm_usize.const_int(self.ndims.unwrap(), false);
|
||||||
|
|
||||||
|
// `axes = nullptr` if `axes` is unspecified.
|
||||||
|
let axes = ArraySliceValue::from_ptr_val(axes, num_axes, None);
|
||||||
|
|
||||||
|
Some(TypedArrayLikeAdapter::from(
|
||||||
|
axes,
|
||||||
|
|_, _, val| val.into_int_value(),
|
||||||
|
|_, _, val| val.into(),
|
||||||
|
))
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
};
|
||||||
|
|
||||||
|
irrt::ndarray::call_nac3_ndarray_transpose(
|
||||||
|
generator,
|
||||||
|
ctx,
|
||||||
|
*self,
|
||||||
|
transposed_ndarray,
|
||||||
|
axes.as_ref(),
|
||||||
|
);
|
||||||
|
|
||||||
|
transposed_ndarray
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
@ -1349,7 +1349,12 @@ impl<'a> BuiltinBuilder<'a> {
|
|||||||
Box::new(move |ctx, _, fun, args, generator| {
|
Box::new(move |ctx, _, fun, args, generator| {
|
||||||
let arg_ty = fun.0.args[0].ty;
|
let arg_ty = fun.0.args[0].ty;
|
||||||
let arg_val = args[0].1.clone().to_basic_value_enum(ctx, generator, arg_ty)?;
|
let arg_val = args[0].1.clone().to_basic_value_enum(ctx, generator, arg_ty)?;
|
||||||
Ok(Some(ndarray_transpose(generator, ctx, (arg_ty, arg_val))?))
|
|
||||||
|
let ndarray = NDArrayType::from_unifier_type(generator, ctx, arg_ty)
|
||||||
|
.map_value(arg_val.into_pointer_value(), None);
|
||||||
|
|
||||||
|
let ndarray = ndarray.transpose(generator, ctx, None); // TODO: Add axes argument
|
||||||
|
Ok(Some(ndarray.as_base_value().into()))
|
||||||
}),
|
}),
|
||||||
),
|
),
|
||||||
|
|
||||||
|
@ -210,6 +210,23 @@ def test_ndarray_nd_idx():
|
|||||||
output_float64(x[1, 0])
|
output_float64(x[1, 0])
|
||||||
output_float64(x[1, 1])
|
output_float64(x[1, 1])
|
||||||
|
|
||||||
|
def test_ndarray_transpose():
|
||||||
|
x: ndarray[float, 2] = np_array([[1., 2., 3.], [4., 5., 6.]])
|
||||||
|
y = np_transpose(x)
|
||||||
|
z = np_transpose(y)
|
||||||
|
|
||||||
|
output_int32(np_shape(x)[0])
|
||||||
|
output_int32(np_shape(x)[1])
|
||||||
|
output_ndarray_float_2(x)
|
||||||
|
|
||||||
|
output_int32(np_shape(y)[0])
|
||||||
|
output_int32(np_shape(y)[1])
|
||||||
|
output_ndarray_float_2(y)
|
||||||
|
|
||||||
|
output_int32(np_shape(z)[0])
|
||||||
|
output_int32(np_shape(z)[1])
|
||||||
|
output_ndarray_float_2(z)
|
||||||
|
|
||||||
def test_ndarray_reshape():
|
def test_ndarray_reshape():
|
||||||
w: ndarray[float, 1] = np_array([1., 2., 3., 4., 5., 6., 7., 8., 9., 10.])
|
w: ndarray[float, 1] = np_array([1., 2., 3., 4., 5., 6., 7., 8., 9., 10.])
|
||||||
x = np_reshape(w, (1, 2, 1, -1))
|
x = np_reshape(w, (1, 2, 1, -1))
|
||||||
@ -1502,14 +1519,6 @@ def test_ndarray_nextafter_broadcast_rhs_scalar():
|
|||||||
output_ndarray_float_2(nextafter_x_zeros)
|
output_ndarray_float_2(nextafter_x_zeros)
|
||||||
output_ndarray_float_2(nextafter_x_ones)
|
output_ndarray_float_2(nextafter_x_ones)
|
||||||
|
|
||||||
def test_ndarray_transpose():
|
|
||||||
x: ndarray[float, 2] = np_array([[1., 2., 3.], [4., 5., 6.]])
|
|
||||||
y = np_transpose(x)
|
|
||||||
z = np_transpose(y)
|
|
||||||
|
|
||||||
output_ndarray_float_2(x)
|
|
||||||
output_ndarray_float_2(y)
|
|
||||||
|
|
||||||
def test_ndarray_dot():
|
def test_ndarray_dot():
|
||||||
x1: ndarray[float, 1] = np_array([5.0, 1.0, 4.0, 2.0])
|
x1: ndarray[float, 1] = np_array([5.0, 1.0, 4.0, 2.0])
|
||||||
y1: ndarray[float, 1] = np_array([5.0, 1.0, 6.0, 6.0])
|
y1: ndarray[float, 1] = np_array([5.0, 1.0, 6.0, 6.0])
|
||||||
@ -1641,6 +1650,7 @@ def run() -> int32:
|
|||||||
test_ndarray_slices()
|
test_ndarray_slices()
|
||||||
test_ndarray_nd_idx()
|
test_ndarray_nd_idx()
|
||||||
|
|
||||||
|
test_ndarray_transpose()
|
||||||
test_ndarray_reshape()
|
test_ndarray_reshape()
|
||||||
test_ndarray_broadcast_to()
|
test_ndarray_broadcast_to()
|
||||||
|
|
||||||
@ -1807,7 +1817,6 @@ def run() -> int32:
|
|||||||
test_ndarray_nextafter_broadcast()
|
test_ndarray_nextafter_broadcast()
|
||||||
test_ndarray_nextafter_broadcast_lhs_scalar()
|
test_ndarray_nextafter_broadcast_lhs_scalar()
|
||||||
test_ndarray_nextafter_broadcast_rhs_scalar()
|
test_ndarray_nextafter_broadcast_rhs_scalar()
|
||||||
test_ndarray_transpose()
|
|
||||||
|
|
||||||
test_ndarray_dot()
|
test_ndarray_dot()
|
||||||
test_ndarray_cholesky()
|
test_ndarray_cholesky()
|
||||||
|
Loading…
Reference in New Issue
Block a user