From 31ab9675ca6103a4202ccb58ed597e4cd70cd699 Mon Sep 17 00:00:00 2001 From: lyken Date: Wed, 10 Jul 2024 11:56:31 +0800 Subject: [PATCH] core: more irrt --- nac3core/irrt/irrt.cpp | 6 +- nac3core/irrt/irrt_basic.hpp | 217 ++++++++++++++++++++++++ nac3core/irrt/irrt_everything.hpp | 14 ++ nac3core/irrt/irrt_numpy_ndarray.hpp | 242 ++++++++++++++++++++++++++ nac3core/irrt/irrt_slice.hpp | 65 +++++++ nac3core/irrt/irrt_test.cpp | 245 ++++++++++++++++++++++++--- nac3core/irrt/irrt_typedefs.hpp | 12 ++ nac3core/irrt/irrt_utils.hpp | 27 +++ 8 files changed, 804 insertions(+), 24 deletions(-) create mode 100644 nac3core/irrt/irrt_basic.hpp create mode 100644 nac3core/irrt/irrt_everything.hpp create mode 100644 nac3core/irrt/irrt_numpy_ndarray.hpp create mode 100644 nac3core/irrt/irrt_slice.hpp create mode 100644 nac3core/irrt/irrt_typedefs.hpp create mode 100644 nac3core/irrt/irrt_utils.hpp diff --git a/nac3core/irrt/irrt.cpp b/nac3core/irrt/irrt.cpp index 4cde95b7..d92b497c 100644 --- a/nac3core/irrt/irrt.cpp +++ b/nac3core/irrt/irrt.cpp @@ -1,3 +1,5 @@ -#include "irrt.hpp" +#include "irrt_everything.hpp" -// All the implementations are from `irrt.hpp` +/* + This file will be read by `clang-irrt` to conveniently produce LLVM IR for `nac3core/codegen`. +*/ diff --git a/nac3core/irrt/irrt_basic.hpp b/nac3core/irrt/irrt_basic.hpp new file mode 100644 index 00000000..08214927 --- /dev/null +++ b/nac3core/irrt/irrt_basic.hpp @@ -0,0 +1,217 @@ +#pragma once + +#include "irrt_utils.hpp" +#include "irrt_typedefs.hpp" + +/* + This header contains IRRT implementations + that do not deserved to be categorized (e.g., into numpy, etc.) + + Check out other *.hpp files before including them here!! +*/ + +// The type of an index or a value describing the length of a range/slice is +// always `int32_t`. +typedef int32_t SliceIndex; + +namespace { + // adapted from GNU Scientific Library: https://git.savannah.gnu.org/cgit/gsl.git/tree/sys/pow_int.c + // need to make sure `exp >= 0` before calling this function + template + T __nac3_int_exp_impl(T base, T exp) { + T res = 1; + /* repeated squaring method */ + do { + if (exp & 1) { + res *= base; /* for n odd */ + } + exp >>= 1; + base *= base; + } while (exp); + return res; + } +} + +extern "C" { + #define DEF_nac3_int_exp_(T) \ + T __nac3_int_exp_##T(T base, T exp) {\ + return __nac3_int_exp_impl(base, exp);\ + } + + DEF_nac3_int_exp_(int32_t) + DEF_nac3_int_exp_(int64_t) + DEF_nac3_int_exp_(uint32_t) + DEF_nac3_int_exp_(uint64_t) + + SliceIndex __nac3_slice_index_bound(SliceIndex i, const SliceIndex len) { + if (i < 0) { + i = len + i; + } + if (i < 0) { + return 0; + } else if (i > len) { + return len; + } + return i; + } + + SliceIndex __nac3_range_slice_len( + const SliceIndex start, + const SliceIndex end, + const SliceIndex step + ) { + SliceIndex diff = end - start; + if (diff > 0 && step > 0) { + return ((diff - 1) / step) + 1; + } else if (diff < 0 && step < 0) { + return ((diff + 1) / step) + 1; + } else { + return 0; + } + } + + // Handle list assignment and dropping part of the list when + // both dest_step and src_step are +1. + // - All the index must *not* be out-of-bound or negative, + // - The end index is *inclusive*, + // - The length of src and dest slice size should already + // be checked: if dest.step == 1 then len(src) <= len(dest) else len(src) == len(dest) + SliceIndex __nac3_list_slice_assign_var_size( + SliceIndex dest_start, + SliceIndex dest_end, + SliceIndex dest_step, + uint8_t *dest_arr, + SliceIndex dest_arr_len, + SliceIndex src_start, + SliceIndex src_end, + SliceIndex src_step, + uint8_t *src_arr, + SliceIndex src_arr_len, + const SliceIndex size + ) { + /* if dest_arr_len == 0, do nothing since we do not support extending list */ + if (dest_arr_len == 0) return dest_arr_len; + /* if both step is 1, memmove directly, handle the dropping of the list, and shrink size */ + if (src_step == dest_step && dest_step == 1) { + const SliceIndex src_len = (src_end >= src_start) ? (src_end - src_start + 1) : 0; + const SliceIndex dest_len = (dest_end >= dest_start) ? (dest_end - dest_start + 1) : 0; + if (src_len > 0) { + __builtin_memmove( + dest_arr + dest_start * size, + src_arr + src_start * size, + src_len * size + ); + } + if (dest_len > 0) { + /* dropping */ + __builtin_memmove( + dest_arr + (dest_start + src_len) * size, + dest_arr + (dest_end + 1) * size, + (dest_arr_len - dest_end - 1) * size + ); + } + /* shrink size */ + return dest_arr_len - (dest_len - src_len); + } + /* if two range overlaps, need alloca */ + uint8_t need_alloca = + (dest_arr == src_arr) + && !( + max(dest_start, dest_end) < min(src_start, src_end) + || max(src_start, src_end) < min(dest_start, dest_end) + ); + if (need_alloca) { + uint8_t *tmp = reinterpret_cast(__builtin_alloca(src_arr_len * size)); + __builtin_memcpy(tmp, src_arr, src_arr_len * size); + src_arr = tmp; + } + SliceIndex src_ind = src_start; + SliceIndex dest_ind = dest_start; + for (; + (src_step > 0) ? (src_ind <= src_end) : (src_ind >= src_end); + src_ind += src_step, dest_ind += dest_step + ) { + /* for constant optimization */ + if (size == 1) { + __builtin_memcpy(dest_arr + dest_ind, src_arr + src_ind, 1); + } else if (size == 4) { + __builtin_memcpy(dest_arr + dest_ind * 4, src_arr + src_ind * 4, 4); + } else if (size == 8) { + __builtin_memcpy(dest_arr + dest_ind * 8, src_arr + src_ind * 8, 8); + } else { + /* memcpy for var size, cannot overlap after previous alloca */ + __builtin_memcpy(dest_arr + dest_ind * size, src_arr + src_ind * size, size); + } + } + /* only dest_step == 1 can we shrink the dest list. */ + /* size should be ensured prior to calling this function */ + if (dest_step == 1 && dest_end >= dest_start) { + __builtin_memmove( + dest_arr + dest_ind * size, + dest_arr + (dest_end + 1) * size, + (dest_arr_len - dest_end - 1) * size + ); + return dest_arr_len - (dest_end - dest_ind) - 1; + } + return dest_arr_len; + } + + int32_t __nac3_isinf(double x) { + return __builtin_isinf(x); + } + + int32_t __nac3_isnan(double x) { + return __builtin_isnan(x); + } + + double tgamma(double arg); + + double __nac3_gamma(double z) { + // Handling for denormals + // | x | Python gamma(x) | C tgamma(x) | + // --- | ----------------- | --------------- | ----------- | + // (1) | nan | nan | nan | + // (2) | -inf | -inf | inf | + // (3) | inf | inf | inf | + // (4) | 0.0 | inf | inf | + // (5) | {-1.0, -2.0, ...} | inf | nan | + + // (1)-(3) + if (__builtin_isinf(z) || __builtin_isnan(z)) { + return z; + } + + double v = tgamma(z); + + // (4)-(5) + return __builtin_isinf(v) || __builtin_isnan(v) ? __builtin_inf() : v; + } + + double lgamma(double arg); + + double __nac3_gammaln(double x) { + // libm's handling of value overflows differs from scipy: + // - scipy: gammaln(-inf) -> -inf + // - libm : lgamma(-inf) -> inf + + if (__builtin_isinf(x)) { + return x; + } + + return lgamma(x); + } + + double j0(double x); + + double __nac3_j0(double x) { + // libm's handling of value overflows differs from scipy: + // - scipy: j0(inf) -> nan + // - libm : j0(inf) -> 0.0 + + if (__builtin_isinf(x)) { + return __builtin_nan(""); + } + + return j0(x); + } +} \ No newline at end of file diff --git a/nac3core/irrt/irrt_everything.hpp b/nac3core/irrt/irrt_everything.hpp new file mode 100644 index 00000000..81e0bdc8 --- /dev/null +++ b/nac3core/irrt/irrt_everything.hpp @@ -0,0 +1,14 @@ +#pragma once + +#include "irrt_utils.hpp" +#include "irrt_typedefs.hpp" +#include "irrt_basic.hpp" +#include "irrt_slice.hpp" +#include "irrt_numpy_ndarray.hpp" + +/* + All IRRT implementations. + + We don't have any pre-compiled objects, so we are writing all implementations in headers and + concatenate them with `#include` into one massive source file that contains all the IRRT stuff. +*/ \ No newline at end of file diff --git a/nac3core/irrt/irrt_numpy_ndarray.hpp b/nac3core/irrt/irrt_numpy_ndarray.hpp new file mode 100644 index 00000000..8a4784be --- /dev/null +++ b/nac3core/irrt/irrt_numpy_ndarray.hpp @@ -0,0 +1,242 @@ +#pragma once + +#include "irrt_utils.hpp" +#include "irrt_typedefs.hpp" +#include "irrt_slice.hpp" + +/* + NDArray-related implementations. +`*/ + +// NDArray indices are always `uint32_t`. +using NDIndex = uint32_t; + +namespace { + namespace ndarray_util { + // Compute the strides of an ndarray given an ndarray `shape` + // and assuming that the ndarray is *fully C-contagious*. + // + // You might want to read up on https://ajcr.net/stride-guide-part-1/. + template + static void set_strides_by_shape(SizeT ndims, SizeT* dst_strides, const SizeT* shape) { + SizeT stride_product = 1; + for (SizeT i = 0; i < ndims; i++) { + int dim_i = ndims - i - 1; + dst_strides[dim_i] = stride_product; + stride_product *= shape[dim_i]; + } + } + + // Compute the size/# of elements of an ndarray given its shape + template + static SizeT calc_size_from_shape(SizeT ndims, const SizeT* shape) { + SizeT size = 1; + for (SizeT dim_i = 0; dim_i < ndims; dim_i++) size *= shape[dim_i]; + return size; + } + } + + typedef uint8_t NDSliceType; + extern "C" { + const NDSliceType INPUT_SLICE_TYPE_INTEGER = 0; + const NDSliceType INPUT_SLICE_TYPE_SLICE = 1; + } + + struct NDSlice { + NDSliceType type; + + /* + type = INPUT_SLICE_TYPE_INTEGER => `slice` points to a single `SizeT` + type = INPUT_SLICE_TYPE_SLICE => `slice` points to a single `NDSliceRange` + */ + uint8_t *slice; + }; + + template + SizeT deduce_ndims_after_slicing(SizeT ndims, const SizeT num_slices, const NDSlice *slices) { + nac3_assert(num_slices <= ndims); + + SizeT final_ndims = ndims; + for (SizeT i = 0; i < num_slices; i++) { + if (slices[i].type == INPUT_SLICE_TYPE_INTEGER) { + final_ndims--; // An integer slice demotes the rank by 1 + } + } + return final_ndims; + } + + template + struct NDArrayIndicesIter { + SizeT ndims; + const SizeT *shape; + SizeT *indices; + + void set_indices_zero() { + __builtin_memset(indices, 0, sizeof(SizeT) * ndims); + } + + void next() { + for (SizeT i = 0; i < ndims; i++) { + SizeT dim_i = ndims - i - 1; + + indices[dim_i]++; + if (indices[dim_i] < shape[dim_i]) { + break; + } else { + indices[dim_i] = 0; + } + } + } + }; + + // The NDArray object. `SizeT` is the *signed* size type of this ndarray. + // + // NOTE: The order of fields is IMPORTANT. DON'T TOUCH IT + // + // Some resources you might find helpful: + // - The official numpy implementations: + // - https://github.com/numpy/numpy/blob/735a477f0bc2b5b84d0e72d92f224bde78d4e069/doc/source/reference/c-api/types-and-structures.rst + // - On strides (about reshaping, slicing, C-contagiousness, etc) + // - https://ajcr.net/stride-guide-part-1/. + // - https://ajcr.net/stride-guide-part-2/. + // - https://ajcr.net/stride-guide-part-3/. + template + struct NDArray { + // The underlying data this `ndarray` is pointing to. + // + // NOTE: Formally this should be of type `void *`, but clang + // translates `void *` to `i8 *` when run with `-S -emit-llvm`, + // so we will put `uint8_t *` here for clarity. + uint8_t *data; + + // The number of bytes of a single element in `data`. + // + // The `SizeT` is treated as `unsigned`. + SizeT itemsize; + + // The number of dimensions of this shape. + // + // The `SizeT` is treated as `unsigned`. + SizeT ndims; + + // Array shape, with length equal to `ndims`. + // + // The `SizeT` is treated as `unsigned`. + // + // NOTE: `shape` can contain 0. + // (those appear when the user makes an out of bounds slice into an ndarray, e.g., `np.zeros((3, 3))[400:].shape == (0, 3)`) + SizeT *shape; + + // Array strides (stride value is in number of bytes, NOT number of elements), with length equal to `ndims`. + // + // The `SizeT` is treated as `signed`. + // + // NOTE: `strides` can have negative numbers. + // (those appear when there is a slice with a negative step, e.g., `my_array[::-1]`) + SizeT *strides; + + // Calculate the size/# of elements of an `ndarray`. + // This function corresponds to `np.size()` or `ndarray.size` + SizeT size() { + return ndarray_util::calc_size_from_shape(ndims, shape); + } + + // Calculate the number of bytes of its content of an `ndarray` *in its view*. + // This function corresponds to `ndarray.nbytes` + SizeT nbytes() { + return this->size() * itemsize; + } + + void set_value_at_pelement(uint8_t* pelement, uint8_t* pvalue) { + __builtin_memcpy(pelement, pvalue, itemsize); + } + + uint8_t* get_pelement(SizeT *indices) { + uint8_t* element = data; + for (SizeT dim_i = 0; dim_i < ndims; dim_i++) + element += indices[dim_i] * strides[dim_i] * itemsize; + return element; + } + + // Is the given `indices` valid/in-bounds? + bool in_bounds(SizeT *indices) { + for (SizeT dim_i = 0; dim_i < ndims; dim_i++) { + bool dim_ok = indices[dim_i] < shape[dim_i]; + if (!dim_ok) return false; + } + return true; + } + + // Fill the ndarray with a value + void fill_generic(uint8_t* pvalue) { + NDArrayIndicesIter iter; + iter.ndims = this->ndims; + iter.shape = this->shape; + iter.indices = (SizeT*) __builtin_alloca(sizeof(SizeT) * ndims); + iter.set_indices_zero(); + + for (SizeT i = 0; i < this->size(); i++, iter.next()) { + uint8_t* pelement = get_pelement(iter.indices); + set_value_at_pelement(pelement, pvalue); + } + } + + // Set the strides of the ndarray with `ndarray_util::set_strides_by_shape` + void set_strides_by_shape() { + ndarray_util::set_strides_by_shape(ndims, strides, shape); + } + + // https://numpy.org/doc/stable/reference/generated/numpy.eye.html + void set_to_eye(SizeT k, uint8_t* zero_pvalue, uint8_t* one_pvalue) { + __builtin_assume(ndims == 2); + + // TODO: Better implementation + + fill_generic(zero_pvalue); + for (SizeT i = 0; i < min(shape[0], shape[1]); i++) { + SizeT row = i; + SizeT col = i + k; + SizeT indices[2] = { row, col }; + + if (!in_bounds(indices)) continue; + + uint8_t* pelement = get_pelement(indices); + set_value_at_pelement(pelement, one_pvalue); + } + } + + // To support numpy complex slices (e.g., `my_array[:50:2,4,:2:-1]`) + void slice(SizeT num_slices, NDSlice* slices, NDArray*dst_ndarray) { + // It is assumed that `dst_ndarray` is allocated by the caller and + // has the correct `ndims`. + nac3_assert(dst_ndarray->ndims == deduce_ndims_after_slicing(this->ndims, num_slices, slices)); + + SizeT this_axis = 0; + SizeT guest_axis = 0; + // for () { + // } + } + }; +} + +extern "C" { + uint32_t __nac3_ndarray_size(NDArray* ndarray) { + return ndarray->size(); + } + + uint64_t __nac3_ndarray_size64(NDArray* ndarray) { + return ndarray->size(); + } + + void __nac3_ndarray_fill_generic(NDArray* ndarray, uint8_t* pvalue) { + ndarray->fill_generic(pvalue); + } + + void __nac3_ndarray_fill_generic64(NDArray* ndarray, uint8_t* pvalue) { + ndarray->fill_generic(pvalue); + } + + // void __nac3_ndarray_slice(NDArray* ndarray, int32_t num_slices, NDSlice *slices, NDArray *dst_ndarray) { + // // ndarray->slice(num_slices, slices, dst_ndarray); + // } +} \ No newline at end of file diff --git a/nac3core/irrt/irrt_slice.hpp b/nac3core/irrt/irrt_slice.hpp new file mode 100644 index 00000000..02802d44 --- /dev/null +++ b/nac3core/irrt/irrt_slice.hpp @@ -0,0 +1,65 @@ +#pragma once + +#include "irrt_utils.hpp" +#include "irrt_typedefs.hpp" + +namespace { + // A proper slice in IRRT, all negative indices have be resolved to absolute values. + template + struct Slice { + T start; + T stop; + T step; + }; + + template + T resolve_index_in_length(T length, T index) { + nac3_assert(length >= 0); + if (index < 0) { + // Remember that index is negative, so do a plus here + return max(length + index, 0); + } else { + return min(length, index); + } + } + + // NOTE: using a bitfield for the `*_defined` is better, at the + // cost of a more annoying implementation in nac3core inkwell + template + struct UserSlice { + uint8_t start_defined; + T start; + + uint8_t stop_defined; + T stop; + + uint8_t step_defined; + T step; + + // Like Python's `slice(start, stop, step).indices(length)` + Slice indices(T length) { + // NOTE: This function implements Python's `slice.indices` *FAITHFULLY*. + // SEE: https://github.com/python/cpython/blob/f62161837e68c1c77961435f1b954412dd5c2b65/Objects/sliceobject.c#L546 + nac3_assert(length >= 0); + nac3_assert(!step_defined || step != 0); // step_defined -> step != 0; step cannot be zero if specified by user + + Slice result; + result.step = step_defined ? step : 1; + bool step_is_negative = result.step < 0; + + if (start_defined) { + result.start = resolve_index_in_length(length, start); + } else { + result.start = step_is_negative ? length - 1 : 0; + } + + if (stop_defined) { + result.stop = resolve_index_in_length(length, stop); + } else { + result.stop = step_is_negative ? -1 : length; + } + + return result; + } + }; +} \ No newline at end of file diff --git a/nac3core/irrt/irrt_test.cpp b/nac3core/irrt/irrt_test.cpp index edd3f5e6..e541865a 100644 --- a/nac3core/irrt/irrt_test.cpp +++ b/nac3core/irrt/irrt_test.cpp @@ -2,17 +2,20 @@ #include #include +// set `IRRT_DONT_TYPEDEF_INTS` because `cstdint` has it all #define IRRT_DONT_TYPEDEF_INTS -#include "irrt.hpp" +#include "irrt_everything.hpp" -static void __test_fail(const char *file, int line) { - // NOTE: Try to make the location info follow a format that - // VSCode/other IDEs would recognize as a clickable URL. - printf("[!] test_fail() invoked at %s:%d", file, line); +void test_fail() { + printf("[!] Test failed\n"); exit(1); } -#define test_fail() __test_fail(__FILE__, __LINE__); +void __begin_test(const char* function_name, const char* file, int line) { + printf("######### Running %s @ %s:%d\n", function_name, file, line); +} + +#define BEGIN_TEST() __begin_test(__FUNCTION__, __FILE__, __LINE__) template bool arrays_match(int len, T *as, T *bs) { @@ -23,40 +26,238 @@ bool arrays_match(int len, T *as, T *bs) { } template -void debug_print_array(const char* format, int len, T *as) { +void debug_print_array(const char* format, int len, T* as) { printf("["); for (int i = 0; i < len; i++) { if (i != 0) printf(", "); printf(format, as[i]); } - printf("]\n"); + printf("]"); } template -bool assert_arrays_match(const char *label, const char *format, int len, T *expected, T *got) { - auto match = arrays_match(len, expected, got); - - if (!match) { +void assert_arrays_match(const char* label, const char* format, int len, T* expected, T* got) { + if (!arrays_match(len, expected, got)) { printf("expected %s: ", label); debug_print_array(format, len, expected); + printf("\n"); printf("got %s: ", label); debug_print_array(format, len, got); + printf("\n"); + test_fail(); } - - return match; } -static void test_strides_from_shape() { - const uint64_t ndims = 4; - uint64_t shape[ndims] = { 999, 3, 5, 7 }; - uint64_t strides[ndims] = { 0 }; - __nac3_ndarray_strides_from_shape64(ndims, shape, strides); +template +void assert_values_match(const char* label, const char* format, T expected, T got) { + if (expected != got) { + printf("expected %s: ", label); + printf(format, expected); + printf("\n"); + printf("got %s: ", label); + printf(format, got); + printf("\n"); + test_fail(); + } +} - uint64_t expected_strides[ndims] = { 3*5*7, 5*7, 7, 1 }; - if (!assert_arrays_match("strides", "%u", ndims, expected_strides, strides)) test_fail(); +void test_calc_size_from_shape_normal() { + // Test shapes with normal values + BEGIN_TEST(); + + int32_t shape[4] = { 2, 3, 5, 7 }; + debug_print_array("%d", 4, shape); + assert_values_match("size", "%d", 210, ndarray_util::calc_size_from_shape(4, shape)); +} + +void test_calc_size_from_shape_has_zero() { + // Test shapes with 0 in them + BEGIN_TEST(); + + int32_t shape[4] = { 2, 0, 5, 7 }; + assert_values_match("size", "%d", 0, ndarray_util::calc_size_from_shape(4, shape)); +} + +void test_set_strides_by_shape() { + // Test `set_strides_by_shape()` + BEGIN_TEST(); + + int32_t shape[4] = { 99, 3, 5, 7 }; + int32_t strides[4] = { 0 }; + ndarray_util::set_strides_by_shape(4, strides, shape); + + int32_t expected_strides[4] = { 105, 35, 7, 1 }; + assert_arrays_match("strides", "%u", 4u, expected_strides, strides); +} + +void test_ndarray_indices_iter_normal() { + // Test NDArrayIndicesIter normal behavior + BEGIN_TEST(); + + int32_t shape[3] = { 1, 2, 3 }; + int32_t indices[3] = { 0, 0, 0 }; + auto iter = NDArrayIndicesIter { + .ndims = 3u, + .shape = shape, + .indices = indices + }; + + assert_arrays_match("indices #0", "%u", 3u, iter.indices, (int32_t[3]) { 0, 0, 0 }); + iter.next(); + assert_arrays_match("indices #1", "%u", 3u, iter.indices, (int32_t[3]) { 0, 0, 1 }); + iter.next(); + assert_arrays_match("indices #2", "%u", 3u, iter.indices, (int32_t[3]) { 0, 0, 2 }); + iter.next(); + assert_arrays_match("indices #3", "%u", 3u, iter.indices, (int32_t[3]) { 0, 1, 0 }); + iter.next(); + assert_arrays_match("indices #4", "%u", 3u, iter.indices, (int32_t[3]) { 0, 1, 1 }); + iter.next(); + assert_arrays_match("indices #5", "%u", 3u, iter.indices, (int32_t[3]) { 0, 1, 2 }); + iter.next(); + assert_arrays_match("indices #6", "%u", 3u, iter.indices, (int32_t[3]) { 0, 0, 0 }); // Loops back + iter.next(); + assert_arrays_match("indices #7", "%u", 3u, iter.indices, (int32_t[3]) { 0, 0, 1 }); +} + +void test_ndarray_fill_generic() { + // Test ndarray fill_generic + BEGIN_TEST(); + + // Choose a type that's neither int32_t nor uint64_t (candidates of SizeT) to spice it up + // Also make all the octets non-zero, to see if `memcpy` in `fill_generic` is working perfectly. + uint16_t fill_value = 0xFACE; + + uint16_t in_data[6] = { 100, 101, 102, 103, 104, 105 }; // Fill `data` with values that != `999` + int32_t in_itemsize = sizeof(uint16_t); + const int32_t in_ndims = 2; + int32_t in_shape[in_ndims] = { 2, 3 }; + int32_t in_strides[in_ndims] = {}; + NDArray ndarray = { + .data = (uint8_t*) in_data, + .itemsize = in_itemsize, + .ndims = in_ndims, + .shape = in_shape, + .strides = in_strides, + }; + ndarray.set_strides_by_shape(); + ndarray.fill_generic((uint8_t*) &fill_value); // `fill_generic` here + + uint16_t expected_data[6] = { fill_value, fill_value, fill_value, fill_value, fill_value, fill_value }; + assert_arrays_match("data", "0x%hX", 6, expected_data, in_data); +} + +void test_ndarray_set_to_eye() { + // Test `set_to_eye` behavior (helper function to implement `np.eye()`) + BEGIN_TEST(); + + double in_data[9] = { 99.0, 99.0, 99.0, 99.0, 99.0, 99.0, 99.0, 99.0, 99.0 }; + int32_t in_itemsize = sizeof(double); + const int32_t in_ndims = 2; + int32_t in_shape[in_ndims] = { 3, 3 }; + int32_t in_strides[in_ndims] = {}; + NDArray ndarray = { + .data = (uint8_t*) in_data, + .itemsize = in_itemsize, + .ndims = in_ndims, + .shape = in_shape, + .strides = in_strides, + }; + ndarray.set_strides_by_shape(); + + double zero = 0.0; + double one = 1.0; + ndarray.set_to_eye(1, (uint8_t*) &zero, (uint8_t*) &one); + + assert_values_match("in_data[0]", "%f", 0.0, in_data[0]); + assert_values_match("in_data[1]", "%f", 1.0, in_data[1]); + assert_values_match("in_data[2]", "%f", 0.0, in_data[2]); + assert_values_match("in_data[3]", "%f", 0.0, in_data[3]); + assert_values_match("in_data[4]", "%f", 0.0, in_data[4]); + assert_values_match("in_data[5]", "%f", 1.0, in_data[5]); + assert_values_match("in_data[6]", "%f", 0.0, in_data[6]); + assert_values_match("in_data[7]", "%f", 0.0, in_data[7]); + assert_values_match("in_data[8]", "%f", 0.0, in_data[8]); +} + +void test_slice_1() { + // Test `slice(5, None, None).indices(100) == slice(5, 100, 1)` + BEGIN_TEST(); + + UserSlice user_slice = { + .start_defined = 1, + .start = 5, + .stop_defined = 0, + .step_defined = 0, + }; + + auto slice = user_slice.indices(100); + assert_values_match("start", "%d", 5, slice.start); + assert_values_match("stop", "%d", 100, slice.stop); + assert_values_match("step", "%d", 1, slice.step); +} + +void test_slice_2() { + // Test `slice(400, 999, None).indices(100) == slice(100, 100, 1)` + BEGIN_TEST(); + + UserSlice user_slice = { + .start_defined = 1, + .start = 400, + .stop_defined = 0, + .step_defined = 0, + }; + + auto slice = user_slice.indices(100); + assert_values_match("start", "%d", 100, slice.start); + assert_values_match("stop", "%d", 100, slice.stop); + assert_values_match("step", "%d", 1, slice.step); +} + +void test_slice_3() { + // Test `slice(-10, -5, None).indices(100) == slice(90, 95, 1)` + BEGIN_TEST(); + + UserSlice user_slice = { + .start_defined = 1, + .start = -10, + .stop_defined = 1, + .stop = -5, + .step_defined = 0, + }; + + auto slice = user_slice.indices(100); + assert_values_match("start", "%d", 90, slice.start); + assert_values_match("stop", "%d", 95, slice.stop); + assert_values_match("step", "%d", 1, slice.step); +} + +void test_slice_4() { + // Test `slice(None, None, -5).indices(100) == (99, -1, -5)` + BEGIN_TEST(); + + UserSlice user_slice = { + .start_defined = 0, + .stop_defined = 0, + .step_defined = 1, + .step = -5 + }; + + auto slice = user_slice.indices(100); + assert_values_match("start", "%d", 99, slice.start); + assert_values_match("stop", "%d", -1, slice.stop); + assert_values_match("step", "%d", -5, slice.step); } int main() { - test_strides_from_shape(); + test_calc_size_from_shape_normal(); + test_calc_size_from_shape_has_zero(); + test_set_strides_by_shape(); + test_ndarray_indices_iter_normal(); + test_ndarray_fill_generic(); + test_ndarray_set_to_eye(); + test_slice_1(); + test_slice_2(); + test_slice_3(); + test_slice_4(); return 0; } \ No newline at end of file diff --git a/nac3core/irrt/irrt_typedefs.hpp b/nac3core/irrt/irrt_typedefs.hpp new file mode 100644 index 00000000..acd75da8 --- /dev/null +++ b/nac3core/irrt/irrt_typedefs.hpp @@ -0,0 +1,12 @@ +#pragma once + +// This is made toggleable since `irrt_test.cpp` itself would include +// headers that define the `int_t` family. +#ifndef IRRT_DONT_TYPEDEF_INTS +typedef _BitInt(8) int8_t; +typedef unsigned _BitInt(8) uint8_t; +typedef _BitInt(32) int32_t; +typedef unsigned _BitInt(32) uint32_t; +typedef _BitInt(64) int64_t; +typedef unsigned _BitInt(64) uint64_t; +#endif \ No newline at end of file diff --git a/nac3core/irrt/irrt_utils.hpp b/nac3core/irrt/irrt_utils.hpp new file mode 100644 index 00000000..7ddc9ac0 --- /dev/null +++ b/nac3core/irrt/irrt_utils.hpp @@ -0,0 +1,27 @@ +#pragma once + +#include "irrt_typedefs.hpp" + +namespace { + template + T max(T a, T b) { + return a > b ? a : b; + } + + template + T min(T a, T b) { + return a > b ? b : a; + } + + void nac3_assert(bool condition) { + // Doesn't do anything (for now (?)) + // Helps to make code self-documenting + + if (!condition) { + // TODO: don't crash the program + // TODO: address 0 on hardware might be writable? + uint8_t* death = nullptr; + *death = 0; + } + } +} \ No newline at end of file