forked from M-Labs/nac3
1
0
Fork 0

core/irrt: rename NDIndex to NDIndexInt

Unfortunately the name `NDIndex` is used in later commits. Renaming this
typedef to `NDIndexInt` to avoid amending. `NDIndexInt` will be removed
anyway when ndarray strides is completed.
This commit is contained in:
lyken 2024-08-15 22:28:23 +08:00
parent b6a1880226
commit 853fa39537
No known key found for this signature in database
GPG Key ID: 3BD5FC6AC8325DD8
2 changed files with 19 additions and 14 deletions

View File

@ -8,6 +8,6 @@ using int64_t = _BitInt(64);
using uint64_t = unsigned _BitInt(64); using uint64_t = unsigned _BitInt(64);
// NDArray indices are always `uint32_t`. // NDArray indices are always `uint32_t`.
using NDIndex = uint32_t; using NDIndexInt = uint32_t;
// The type of an index or a value describing the length of a range/slice is always `int32_t`. // The type of an index or a value describing the length of a range/slice is always `int32_t`.
using SliceIndex = int32_t; using SliceIndex = int32_t;

View File

@ -19,7 +19,7 @@ SizeT __nac3_ndarray_calc_size_impl(const SizeT* list_data, SizeT list_len, Size
} }
template<typename SizeT> template<typename SizeT>
void __nac3_ndarray_calc_nd_indices_impl(SizeT index, const SizeT* dims, SizeT num_dims, NDIndex* idxs) { void __nac3_ndarray_calc_nd_indices_impl(SizeT index, const SizeT* dims, SizeT num_dims, NDIndexInt* idxs) {
SizeT stride = 1; SizeT stride = 1;
for (SizeT dim = 0; dim < num_dims; dim++) { for (SizeT dim = 0; dim < num_dims; dim++) {
SizeT i = num_dims - dim - 1; SizeT i = num_dims - dim - 1;
@ -30,7 +30,10 @@ void __nac3_ndarray_calc_nd_indices_impl(SizeT index, const SizeT* dims, SizeT n
} }
template<typename SizeT> template<typename SizeT>
SizeT __nac3_ndarray_flatten_index_impl(const SizeT* dims, SizeT num_dims, const NDIndex* indices, SizeT num_indices) { SizeT __nac3_ndarray_flatten_index_impl(const SizeT* dims,
SizeT num_dims,
const NDIndexInt* indices,
SizeT num_indices) {
SizeT idx = 0; SizeT idx = 0;
SizeT stride = 1; SizeT stride = 1;
for (SizeT i = 0; i < num_dims; ++i) { for (SizeT i = 0; i < num_dims; ++i) {
@ -77,8 +80,8 @@ void __nac3_ndarray_calc_broadcast_impl(const SizeT* lhs_dims,
template<typename SizeT> template<typename SizeT>
void __nac3_ndarray_calc_broadcast_idx_impl(const SizeT* src_dims, void __nac3_ndarray_calc_broadcast_idx_impl(const SizeT* src_dims,
SizeT src_ndims, SizeT src_ndims,
const NDIndex* in_idx, const NDIndexInt* in_idx,
NDIndex* out_idx) { NDIndexInt* out_idx) {
for (SizeT i = 0; i < src_ndims; ++i) { for (SizeT i = 0; i < src_ndims; ++i) {
SizeT src_i = src_ndims - i - 1; SizeT src_i = src_ndims - i - 1;
out_idx[src_i] = src_dims[src_i] == 1 ? 0 : in_idx[src_i]; out_idx[src_i] = src_dims[src_i] == 1 ? 0 : in_idx[src_i];
@ -96,21 +99,23 @@ __nac3_ndarray_calc_size64(const uint64_t* list_data, uint64_t list_len, uint64_
return __nac3_ndarray_calc_size_impl(list_data, list_len, begin_idx, end_idx); return __nac3_ndarray_calc_size_impl(list_data, list_len, begin_idx, end_idx);
} }
void __nac3_ndarray_calc_nd_indices(uint32_t index, const uint32_t* dims, uint32_t num_dims, NDIndex* idxs) { void __nac3_ndarray_calc_nd_indices(uint32_t index, const uint32_t* dims, uint32_t num_dims, NDIndexInt* idxs) {
__nac3_ndarray_calc_nd_indices_impl(index, dims, num_dims, idxs); __nac3_ndarray_calc_nd_indices_impl(index, dims, num_dims, idxs);
} }
void __nac3_ndarray_calc_nd_indices64(uint64_t index, const uint64_t* dims, uint64_t num_dims, NDIndex* idxs) { void __nac3_ndarray_calc_nd_indices64(uint64_t index, const uint64_t* dims, uint64_t num_dims, NDIndexInt* idxs) {
__nac3_ndarray_calc_nd_indices_impl(index, dims, num_dims, idxs); __nac3_ndarray_calc_nd_indices_impl(index, dims, num_dims, idxs);
} }
uint32_t uint32_t
__nac3_ndarray_flatten_index(const uint32_t* dims, uint32_t num_dims, const NDIndex* indices, uint32_t num_indices) { __nac3_ndarray_flatten_index(const uint32_t* dims, uint32_t num_dims, const NDIndexInt* indices, uint32_t num_indices) {
return __nac3_ndarray_flatten_index_impl(dims, num_dims, indices, num_indices); return __nac3_ndarray_flatten_index_impl(dims, num_dims, indices, num_indices);
} }
uint64_t uint64_t __nac3_ndarray_flatten_index64(const uint64_t* dims,
__nac3_ndarray_flatten_index64(const uint64_t* dims, uint64_t num_dims, const NDIndex* indices, uint64_t num_indices) { uint64_t num_dims,
const NDIndexInt* indices,
uint64_t num_indices) {
return __nac3_ndarray_flatten_index_impl(dims, num_dims, indices, num_indices); return __nac3_ndarray_flatten_index_impl(dims, num_dims, indices, num_indices);
} }
@ -132,15 +137,15 @@ void __nac3_ndarray_calc_broadcast64(const uint64_t* lhs_dims,
void __nac3_ndarray_calc_broadcast_idx(const uint32_t* src_dims, void __nac3_ndarray_calc_broadcast_idx(const uint32_t* src_dims,
uint32_t src_ndims, uint32_t src_ndims,
const NDIndex* in_idx, const NDIndexInt* in_idx,
NDIndex* out_idx) { NDIndexInt* out_idx) {
__nac3_ndarray_calc_broadcast_idx_impl(src_dims, src_ndims, in_idx, out_idx); __nac3_ndarray_calc_broadcast_idx_impl(src_dims, src_ndims, in_idx, out_idx);
} }
void __nac3_ndarray_calc_broadcast_idx64(const uint64_t* src_dims, void __nac3_ndarray_calc_broadcast_idx64(const uint64_t* src_dims,
uint64_t src_ndims, uint64_t src_ndims,
const NDIndex* in_idx, const NDIndexInt* in_idx,
NDIndex* out_idx) { NDIndexInt* out_idx) {
__nac3_ndarray_calc_broadcast_idx_impl(src_dims, src_ndims, in_idx, out_idx); __nac3_ndarray_calc_broadcast_idx_impl(src_dims, src_ndims, in_idx, out_idx);
} }
} }