#pragma once #include "irrt_utils.hpp" #include "irrt_typedefs.hpp" #include "irrt_slice.hpp" /* NDArray-related implementations. `*/ // NDArray indices are always `uint32_t`. using NDIndex = uint32_t; namespace { namespace ndarray_util { // Compute the strides of an ndarray given an ndarray `shape` // and assuming that the ndarray is *fully C-contagious*. // // You might want to read up on https://ajcr.net/stride-guide-part-1/. template static void set_strides_by_shape(SizeT ndims, SizeT* dst_strides, const SizeT* shape) { SizeT stride_product = 1; for (SizeT i = 0; i < ndims; i++) { int dim_i = ndims - i - 1; dst_strides[dim_i] = stride_product; stride_product *= shape[dim_i]; } } // Compute the size/# of elements of an ndarray given its shape template static SizeT calc_size_from_shape(SizeT ndims, const SizeT* shape) { SizeT size = 1; for (SizeT dim_i = 0; dim_i < ndims; dim_i++) size *= shape[dim_i]; return size; } } typedef uint8_t NDSliceType; extern "C" { const NDSliceType INPUT_SLICE_TYPE_INTEGER = 0; const NDSliceType INPUT_SLICE_TYPE_SLICE = 1; } struct NDSlice { NDSliceType type; /* type = INPUT_SLICE_TYPE_INTEGER => `slice` points to a single `SizeT` type = INPUT_SLICE_TYPE_SLICE => `slice` points to a single `NDSliceRange` */ uint8_t *slice; }; template SizeT deduce_ndims_after_slicing(SizeT ndims, const SizeT num_slices, const NDSlice *slices) { nac3_assert(num_slices <= ndims); SizeT final_ndims = ndims; for (SizeT i = 0; i < num_slices; i++) { if (slices[i].type == INPUT_SLICE_TYPE_INTEGER) { final_ndims--; // An integer slice demotes the rank by 1 } } return final_ndims; } template struct NDArrayIndicesIter { SizeT ndims; const SizeT *shape; SizeT *indices; void set_indices_zero() { __builtin_memset(indices, 0, sizeof(SizeT) * ndims); } void next() { for (SizeT i = 0; i < ndims; i++) { SizeT dim_i = ndims - i - 1; indices[dim_i]++; if (indices[dim_i] < shape[dim_i]) { break; } else { indices[dim_i] = 0; } } } }; // The NDArray object. `SizeT` is the *signed* size type of this ndarray. // // NOTE: The order of fields is IMPORTANT. DON'T TOUCH IT // // Some resources you might find helpful: // - The official numpy implementations: // - https://github.com/numpy/numpy/blob/735a477f0bc2b5b84d0e72d92f224bde78d4e069/doc/source/reference/c-api/types-and-structures.rst // - On strides (about reshaping, slicing, C-contagiousness, etc) // - https://ajcr.net/stride-guide-part-1/. // - https://ajcr.net/stride-guide-part-2/. // - https://ajcr.net/stride-guide-part-3/. template struct NDArray { // The underlying data this `ndarray` is pointing to. // // NOTE: Formally this should be of type `void *`, but clang // translates `void *` to `i8 *` when run with `-S -emit-llvm`, // so we will put `uint8_t *` here for clarity. uint8_t *data; // The number of bytes of a single element in `data`. // // The `SizeT` is treated as `unsigned`. SizeT itemsize; // The number of dimensions of this shape. // // The `SizeT` is treated as `unsigned`. SizeT ndims; // Array shape, with length equal to `ndims`. // // The `SizeT` is treated as `unsigned`. // // NOTE: `shape` can contain 0. // (those appear when the user makes an out of bounds slice into an ndarray, e.g., `np.zeros((3, 3))[400:].shape == (0, 3)`) SizeT *shape; // Array strides (stride value is in number of bytes, NOT number of elements), with length equal to `ndims`. // // The `SizeT` is treated as `signed`. // // NOTE: `strides` can have negative numbers. // (those appear when there is a slice with a negative step, e.g., `my_array[::-1]`) SizeT *strides; // Calculate the size/# of elements of an `ndarray`. // This function corresponds to `np.size()` or `ndarray.size` SizeT size() { return ndarray_util::calc_size_from_shape(ndims, shape); } // Calculate the number of bytes of its content of an `ndarray` *in its view*. // This function corresponds to `ndarray.nbytes` SizeT nbytes() { return this->size() * itemsize; } void set_value_at_pelement(uint8_t* pelement, uint8_t* pvalue) { __builtin_memcpy(pelement, pvalue, itemsize); } uint8_t* get_pelement(SizeT *indices) { uint8_t* element = data; for (SizeT dim_i = 0; dim_i < ndims; dim_i++) element += indices[dim_i] * strides[dim_i] * itemsize; return element; } // Is the given `indices` valid/in-bounds? bool in_bounds(SizeT *indices) { for (SizeT dim_i = 0; dim_i < ndims; dim_i++) { bool dim_ok = indices[dim_i] < shape[dim_i]; if (!dim_ok) return false; } return true; } // Fill the ndarray with a value void fill_generic(uint8_t* pvalue) { NDArrayIndicesIter iter; iter.ndims = this->ndims; iter.shape = this->shape; iter.indices = (SizeT*) __builtin_alloca(sizeof(SizeT) * ndims); iter.set_indices_zero(); for (SizeT i = 0; i < this->size(); i++, iter.next()) { uint8_t* pelement = get_pelement(iter.indices); set_value_at_pelement(pelement, pvalue); } } // Set the strides of the ndarray with `ndarray_util::set_strides_by_shape` void set_strides_by_shape() { ndarray_util::set_strides_by_shape(ndims, strides, shape); } // https://numpy.org/doc/stable/reference/generated/numpy.eye.html void set_to_eye(SizeT k, uint8_t* zero_pvalue, uint8_t* one_pvalue) { __builtin_assume(ndims == 2); // TODO: Better implementation fill_generic(zero_pvalue); for (SizeT i = 0; i < min(shape[0], shape[1]); i++) { SizeT row = i; SizeT col = i + k; SizeT indices[2] = { row, col }; if (!in_bounds(indices)) continue; uint8_t* pelement = get_pelement(indices); set_value_at_pelement(pelement, one_pvalue); } } // To support numpy complex slices (e.g., `my_array[:50:2,4,:2:-1]`) void slice(SizeT num_slices, NDSlice* slices, NDArray*dst_ndarray) { // It is assumed that `dst_ndarray` is allocated by the caller and // has the correct `ndims`. nac3_assert(dst_ndarray->ndims == deduce_ndims_after_slicing(this->ndims, num_slices, slices)); SizeT this_axis = 0; SizeT guest_axis = 0; // for () { // } } }; } extern "C" { uint32_t __nac3_ndarray_size(NDArray* ndarray) { return ndarray->size(); } uint64_t __nac3_ndarray_size64(NDArray* ndarray) { return ndarray->size(); } void __nac3_ndarray_fill_generic(NDArray* ndarray, uint8_t* pvalue) { ndarray->fill_generic(pvalue); } void __nac3_ndarray_fill_generic64(NDArray* ndarray, uint8_t* pvalue) { ndarray->fill_generic(pvalue); } // void __nac3_ndarray_slice(NDArray* ndarray, int32_t num_slices, NDSlice *slices, NDArray *dst_ndarray) { // // ndarray->slice(num_slices, slices, dst_ndarray); // } }