NDArray with strides + NDArrayObject + Models + Exceptions in IRRT. #506

Closed
lyken wants to merge 51 commits from ndstrides-intro into ndstrides
Showing only changes of commit fda7f8d827 - Show all commits

View File

@ -4,7 +4,7 @@
#include <irrt/math_util.hpp>
// NDArray indices are always `uint32_t`.
using NDIndex = uint32_t;
using NDIndexInt = uint32_t;
// The type of an index or a value describing the length of a range/slice is always `int32_t`.
using SliceIndex = int32_t;
@ -44,7 +44,7 @@ SizeT __nac3_ndarray_calc_size_impl(const SizeT *list_data, SizeT list_len, Size
}
template <typename SizeT>
void __nac3_ndarray_calc_nd_indices_impl(SizeT index, const SizeT *dims, SizeT num_dims, NDIndex *idxs)
void __nac3_ndarray_calc_nd_indices_impl(SizeT index, const SizeT *dims, SizeT num_dims, NDIndexInt *idxs)
{
SizeT stride = 1;
for (SizeT dim = 0; dim < num_dims; dim++)
@ -57,7 +57,7 @@ void __nac3_ndarray_calc_nd_indices_impl(SizeT index, const SizeT *dims, SizeT n
}
template <typename SizeT>
SizeT __nac3_ndarray_flatten_index_impl(const SizeT *dims, SizeT num_dims, const NDIndex *indices, SizeT num_indices)
SizeT __nac3_ndarray_flatten_index_impl(const SizeT *dims, SizeT num_dims, const NDIndexInt *indices, SizeT num_indices)
{
SizeT idx = 0;
SizeT stride = 1;
@ -115,8 +115,8 @@ void __nac3_ndarray_calc_broadcast_impl(const SizeT *lhs_dims, SizeT lhs_ndims,
}
template <typename SizeT>
void __nac3_ndarray_calc_broadcast_idx_impl(const SizeT *src_dims, SizeT src_ndims, const NDIndex *in_idx,
NDIndex *out_idx)
void __nac3_ndarray_calc_broadcast_idx_impl(const SizeT *src_dims, SizeT src_ndims, const NDIndexInt *in_idx,
NDIndexInt *out_idx)
{
for (SizeT i = 0; i < src_ndims; ++i)
{
@ -324,23 +324,23 @@ extern "C"
return __nac3_ndarray_calc_size_impl(list_data, list_len, begin_idx, end_idx);
}
void __nac3_ndarray_calc_nd_indices(uint32_t index, const uint32_t *dims, uint32_t num_dims, NDIndex *idxs)
void __nac3_ndarray_calc_nd_indices(uint32_t index, const uint32_t *dims, uint32_t num_dims, NDIndexInt *idxs)
{
__nac3_ndarray_calc_nd_indices_impl(index, dims, num_dims, idxs);
}
void __nac3_ndarray_calc_nd_indices64(uint64_t index, const uint64_t *dims, uint64_t num_dims, NDIndex *idxs)
void __nac3_ndarray_calc_nd_indices64(uint64_t index, const uint64_t *dims, uint64_t num_dims, NDIndexInt *idxs)
{
__nac3_ndarray_calc_nd_indices_impl(index, dims, num_dims, idxs);
}
uint32_t __nac3_ndarray_flatten_index(const uint32_t *dims, uint32_t num_dims, const NDIndex *indices,
uint32_t __nac3_ndarray_flatten_index(const uint32_t *dims, uint32_t num_dims, const NDIndexInt *indices,
uint32_t num_indices)
{
return __nac3_ndarray_flatten_index_impl(dims, num_dims, indices, num_indices);
}
uint64_t __nac3_ndarray_flatten_index64(const uint64_t *dims, uint64_t num_dims, const NDIndex *indices,
uint64_t __nac3_ndarray_flatten_index64(const uint64_t *dims, uint64_t num_dims, const NDIndexInt *indices,
uint64_t num_indices)
{
return __nac3_ndarray_flatten_index_impl(dims, num_dims, indices, num_indices);
@ -358,14 +358,14 @@ extern "C"
return __nac3_ndarray_calc_broadcast_impl(lhs_dims, lhs_ndims, rhs_dims, rhs_ndims, out_dims);
}
void __nac3_ndarray_calc_broadcast_idx(const uint32_t *src_dims, uint32_t src_ndims, const NDIndex *in_idx,
NDIndex *out_idx)
void __nac3_ndarray_calc_broadcast_idx(const uint32_t *src_dims, uint32_t src_ndims, const NDIndexInt *in_idx,
NDIndexInt *out_idx)
{
__nac3_ndarray_calc_broadcast_idx_impl(src_dims, src_ndims, in_idx, out_idx);
}
void __nac3_ndarray_calc_broadcast_idx64(const uint64_t *src_dims, uint64_t src_ndims, const NDIndex *in_idx,
NDIndex *out_idx)
void __nac3_ndarray_calc_broadcast_idx64(const uint64_t *src_dims, uint64_t src_ndims, const NDIndexInt *in_idx,
NDIndexInt *out_idx)
{
__nac3_ndarray_calc_broadcast_idx_impl(src_dims, src_ndims, in_idx, out_idx);
}