forked from M-Labs/nac3
core: irrt reformat & more progress
progress details: - organize test suite sources with namespaces - organize ndarray implementations with namespaces - remove extraneous code/comment - add more tests - some renaming - fix pre-existing bugs - ndarray::subscript now throw errors
This commit is contained in:
parent
61dd9762d8
commit
628965e519
|
@ -4,81 +4,81 @@
|
|||
#include <irrt/utils.hpp>
|
||||
|
||||
namespace {
|
||||
// nac3core's "str" struct type definition
|
||||
template <typename SizeT>
|
||||
struct Str {
|
||||
const char* content;
|
||||
SizeT length;
|
||||
};
|
||||
|
||||
// A limited set of errors IRRT could use.
|
||||
typedef uint32_t ErrorId;
|
||||
struct ErrorIds {
|
||||
ErrorId index_error;
|
||||
ErrorId value_error;
|
||||
ErrorId assertion_error;
|
||||
ErrorId runtime_error;
|
||||
};
|
||||
|
||||
struct ErrorContext {
|
||||
// Context
|
||||
ErrorIds* error_ids;
|
||||
|
||||
// Error thrown by IRRT
|
||||
ErrorId error_id;
|
||||
const char* message_template; // MUST BE `&'static`
|
||||
uint64_t param1;
|
||||
uint64_t param2;
|
||||
uint64_t param3;
|
||||
|
||||
void initialize(ErrorIds* error_ids) {
|
||||
this->error_ids = error_ids;
|
||||
clear_error();
|
||||
}
|
||||
|
||||
void clear_error() {
|
||||
// Point the message_template to an empty str. Don't set it to nullptr as a sentinel
|
||||
this->message_template = "";
|
||||
}
|
||||
|
||||
void set_error(ErrorId error_id, const char* message, uint64_t param1 = 0, uint64_t param2 = 0, uint64_t param3 = 0) {
|
||||
this->error_id = error_id;
|
||||
this->message_template = message;
|
||||
this->param1 = param1;
|
||||
this->param2 = param2;
|
||||
this->param3 = param3;
|
||||
}
|
||||
|
||||
bool has_error() {
|
||||
return !cstr_utils::is_empty(message_template);
|
||||
}
|
||||
|
||||
// nac3core's "str" struct type definition
|
||||
template <typename SizeT>
|
||||
void get_error_str(Str<SizeT> *dst_str) {
|
||||
dst_str->content = message_template;
|
||||
dst_str->length = (SizeT) cstr_utils::length(message_template);
|
||||
}
|
||||
};
|
||||
struct Str {
|
||||
const char* content;
|
||||
SizeT length;
|
||||
};
|
||||
|
||||
// A limited set of errors IRRT could use.
|
||||
typedef uint32_t ErrorId;
|
||||
struct ErrorIds {
|
||||
ErrorId index_error;
|
||||
ErrorId value_error;
|
||||
ErrorId assertion_error;
|
||||
ErrorId runtime_error;
|
||||
};
|
||||
|
||||
struct ErrorContext {
|
||||
// Context
|
||||
const ErrorIds* error_ids;
|
||||
|
||||
// Error thrown by IRRT
|
||||
ErrorId error_id;
|
||||
const char* message_template; // MUST BE `&'static`
|
||||
int64_t param1;
|
||||
int64_t param2;
|
||||
int64_t param3;
|
||||
|
||||
void initialize(const ErrorIds* error_ids) {
|
||||
this->error_ids = error_ids;
|
||||
clear_error();
|
||||
}
|
||||
|
||||
void clear_error() {
|
||||
// Point the message_template to an empty str. Don't set it to nullptr as a sentinel
|
||||
this->message_template = "";
|
||||
}
|
||||
|
||||
void set_error(ErrorId error_id, const char* message, int64_t param1 = 0, int64_t param2 = 0, int64_t param3 = 0) {
|
||||
this->error_id = error_id;
|
||||
this->message_template = message;
|
||||
this->param1 = param1;
|
||||
this->param2 = param2;
|
||||
this->param3 = param3;
|
||||
}
|
||||
|
||||
bool has_error() {
|
||||
return !cstr_utils::is_empty(message_template);
|
||||
}
|
||||
|
||||
template <typename SizeT>
|
||||
void get_error_str(Str<SizeT> *dst_str) {
|
||||
dst_str->content = message_template;
|
||||
dst_str->length = (SizeT) cstr_utils::length(message_template);
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
extern "C" {
|
||||
void __nac3_error_context_initialize(ErrorContext* errctx, ErrorIds* error_ids) {
|
||||
errctx->initialize(error_ids);
|
||||
}
|
||||
void __nac3_error_context_initialize(ErrorContext* errctx, const ErrorIds* error_ids) {
|
||||
errctx->initialize(error_ids);
|
||||
}
|
||||
|
||||
bool __nac3_error_context_has_no_error(ErrorContext* errctx) {
|
||||
return !errctx->has_error();
|
||||
}
|
||||
bool __nac3_error_context_has_no_error(ErrorContext* errctx) {
|
||||
return !errctx->has_error();
|
||||
}
|
||||
|
||||
void __nac3_error_context_get_error_str(ErrorContext* errctx, Str<int32_t> *dst_str) {
|
||||
errctx->get_error_str<int32_t>(dst_str);
|
||||
}
|
||||
void __nac3_error_context_get_error_str(ErrorContext* errctx, Str<int32_t> *dst_str) {
|
||||
errctx->get_error_str<int32_t>(dst_str);
|
||||
}
|
||||
|
||||
void __nac3_error_context_get_error_str64(ErrorContext* errctx, Str<int64_t> *dst_str) {
|
||||
errctx->get_error_str<int64_t>(dst_str);
|
||||
}
|
||||
void __nac3_error_context_get_error_str64(ErrorContext* errctx, Str<int64_t> *dst_str) {
|
||||
errctx->get_error_str<int64_t>(dst_str);
|
||||
}
|
||||
|
||||
void __nac3_error_dummy_raise(ErrorContext* errctx) {
|
||||
errctx->set_error(errctx->error_ids->runtime_error, "THROWN FROM __nac3_error_dummy_raise!!!!!!");
|
||||
}
|
||||
void __nac3_error_dummy_raise(ErrorContext* errctx) {
|
||||
errctx->set_error(errctx->error_ids->runtime_error, "THROWN FROM __nac3_error_dummy_raise!!!!!!");
|
||||
}
|
||||
}
|
|
@ -4,10 +4,9 @@
|
|||
#include <irrt/error_context.hpp>
|
||||
#include <irrt/numpy/ndarray_def.hpp>
|
||||
|
||||
namespace {
|
||||
namespace ndarray {
|
||||
namespace { namespace ndarray { namespace basic {
|
||||
namespace util {
|
||||
// Throw an error if there is an axis with negative dimension
|
||||
// throw an error if there is an axis with negative dimension
|
||||
template <typename SizeT>
|
||||
void assert_shape_no_negative(ErrorContext* errctx, SizeT ndims, const SizeT* shape) {
|
||||
for (SizeT axis = 0; axis < ndims; axis++) {
|
||||
|
@ -22,7 +21,7 @@ namespace ndarray {
|
|||
}
|
||||
}
|
||||
|
||||
// Compute the size/# of elements of an ndarray given its shape
|
||||
// compute the size/# of elements of an ndarray given its shape
|
||||
template <typename SizeT>
|
||||
SizeT calc_size_from_shape(SizeT ndims, const SizeT* shape) {
|
||||
SizeT size = 1;
|
||||
|
@ -30,12 +29,12 @@ namespace ndarray {
|
|||
return size;
|
||||
}
|
||||
|
||||
// Compute the strides of an ndarray given an ndarray `shape`
|
||||
// and assuming that the ndarray is *fully C-contagious*.
|
||||
// compute the strides of an ndarray given an ndarray `shape`
|
||||
// and assuming that the ndarray is *fully c-contagious*.
|
||||
//
|
||||
// You might want to read up on https://ajcr.net/stride-guide-part-1/.
|
||||
// you might want to read up on https://ajcr.net/stride-guide-part-1/.
|
||||
//
|
||||
// This function might be used in isolation without an ndarray. That's
|
||||
// this function might be used in isolation without an ndarray. that's
|
||||
// why it separated out into its own util function.
|
||||
template <typename SizeT>
|
||||
void set_strides_by_shape(SizeT itemsize, SizeT ndims, SizeT* dst_strides, const SizeT* shape) {
|
||||
|
@ -59,24 +58,23 @@ namespace ndarray {
|
|||
}
|
||||
}
|
||||
|
||||
// Calculate the size/# of elements of an `ndarray`.
|
||||
// This function corresponds to `np.size(<ndarray>)` or `ndarray.size`
|
||||
// calculate the size/# of elements of an `ndarray`.
|
||||
// this function corresponds to `np.size(<ndarray>)` or `ndarray.size`
|
||||
template <typename SizeT>
|
||||
SizeT size(NDArray<SizeT>* ndarray) {
|
||||
return ndarray::util::calc_size_from_shape(ndarray->ndims, ndarray->shape);
|
||||
return util::calc_size_from_shape(ndarray->ndims, ndarray->shape);
|
||||
}
|
||||
|
||||
// Calculate the number of bytes of its content of an `ndarray` *in its view*.
|
||||
// This function corresponds to `ndarray.nbytes`
|
||||
// calculate the number of bytes of its content of an `ndarray` *in its view*.
|
||||
// this function corresponds to `ndarray.nbytes`
|
||||
template <typename SizeT>
|
||||
SizeT nbytes(NDArray<SizeT>* ndarray) {
|
||||
return ndarray::size(ndarray) * ndarray->itemsize;
|
||||
return size(ndarray) * ndarray->itemsize;
|
||||
}
|
||||
|
||||
// Set the strides of the ndarray with `ndarray_util::set_strides_by_shape`
|
||||
template <typename SizeT>
|
||||
void set_strides_by_shape(NDArray<SizeT>* ndarray) {
|
||||
ndarray::util::set_strides_by_shape(ndarray->itemsize, ndarray->ndims, ndarray->strides, ndarray->shape);
|
||||
util::set_strides_by_shape(ndarray->itemsize, ndarray->ndims, ndarray->strides, ndarray->shape);
|
||||
}
|
||||
|
||||
template <typename SizeT>
|
||||
|
@ -90,11 +88,11 @@ namespace ndarray {
|
|||
template <typename SizeT>
|
||||
uint8_t* get_nth_pelement(NDArray<SizeT>* ndarray, SizeT nth) {
|
||||
SizeT* indices = (SizeT*) __builtin_alloca(sizeof(SizeT) * ndarray->ndims);
|
||||
ndarray::util::set_indices_by_nth(ndarray->ndims, ndarray->shape, indices, nth);
|
||||
return ndarray::get_pelement_by_indices(ndarray, indices);
|
||||
util::set_indices_by_nth(ndarray->ndims, ndarray->shape, indices, nth);
|
||||
return get_pelement_by_indices(ndarray, indices);
|
||||
}
|
||||
|
||||
// Get the pointer to the nth element of the ndarray as if it were flattened.
|
||||
// get the pointer to the nth element of the ndarray as if it were flattened.
|
||||
template <typename SizeT>
|
||||
uint8_t* checked_get_nth_pelement(NDArray<SizeT>* ndarray, ErrorContext* errctx, SizeT nth) {
|
||||
SizeT arr_size = ndarray->size();
|
||||
|
@ -106,46 +104,47 @@ namespace ndarray {
|
|||
);
|
||||
return 0;
|
||||
}
|
||||
return ndarray::get_nth_pelement(ndarray, nth);
|
||||
return get_nth_pelement(ndarray, nth);
|
||||
}
|
||||
|
||||
template <typename SizeT>
|
||||
void set_pelement_value(NDArray<SizeT>* ndarray, uint8_t* pelement, const uint8_t* pvalue) {
|
||||
__builtin_memcpy(pelement, pvalue, ndarray->itemsize);
|
||||
}
|
||||
};
|
||||
}
|
||||
} } }
|
||||
|
||||
extern "C" {
|
||||
uint32_t __nac3_ndarray_size(NDArray<int32_t>* ndarray) {
|
||||
return ndarray::size(ndarray);
|
||||
}
|
||||
using namespace ndarray::basic;
|
||||
|
||||
uint64_t __nac3_ndarray_size64(NDArray<int64_t>* ndarray) {
|
||||
return ndarray::size(ndarray);
|
||||
}
|
||||
uint32_t __nac3_ndarray_size(NDArray<int32_t>* ndarray) {
|
||||
return size(ndarray);
|
||||
}
|
||||
|
||||
uint32_t __nac3_ndarray_nbytes(NDArray<int32_t>* ndarray) {
|
||||
return ndarray::nbytes(ndarray);
|
||||
}
|
||||
uint64_t __nac3_ndarray_size64(NDArray<int64_t>* ndarray) {
|
||||
return size(ndarray);
|
||||
}
|
||||
|
||||
uint64_t __nac3_ndarray_nbytes64(NDArray<int64_t>* ndarray) {
|
||||
return ndarray::nbytes(ndarray);
|
||||
}
|
||||
uint32_t __nac3_ndarray_nbytes(NDArray<int32_t>* ndarray) {
|
||||
return nbytes(ndarray);
|
||||
}
|
||||
|
||||
void __nac3_ndarray_util_assert_shape_no_negative(ErrorContext* errctx, int32_t ndims, int32_t* shape) {
|
||||
ndarray::util::assert_shape_no_negative(errctx, ndims, shape);
|
||||
}
|
||||
uint64_t __nac3_ndarray_nbytes64(NDArray<int64_t>* ndarray) {
|
||||
return nbytes(ndarray);
|
||||
}
|
||||
|
||||
void __nac3_ndarray_util_assert_shape_no_negative64(ErrorContext* errctx, int64_t ndims, int64_t* shape) {
|
||||
ndarray::util::assert_shape_no_negative(errctx, ndims, shape);
|
||||
}
|
||||
void __nac3_ndarray_util_assert_shape_no_negative(ErrorContext* errctx, int32_t ndims, int32_t* shape) {
|
||||
util::assert_shape_no_negative(errctx, ndims, shape);
|
||||
}
|
||||
|
||||
void __nac3_ndarray_set_strides_by_shape(NDArray<int32_t>* ndarray) {
|
||||
ndarray::set_strides_by_shape(ndarray);
|
||||
}
|
||||
void __nac3_ndarray_util_assert_shape_no_negative64(ErrorContext* errctx, int64_t ndims, int64_t* shape) {
|
||||
util::assert_shape_no_negative(errctx, ndims, shape);
|
||||
}
|
||||
|
||||
void __nac3_ndarray_set_strides_by_shape64(NDArray<int64_t>* ndarray) {
|
||||
ndarray::set_strides_by_shape(ndarray);
|
||||
}
|
||||
void __nac3_ndarray_set_strides_by_shape(NDArray<int32_t>* ndarray) {
|
||||
set_strides_by_shape(ndarray);
|
||||
}
|
||||
|
||||
void __nac3_ndarray_set_strides_by_shape64(NDArray<int64_t>* ndarray) {
|
||||
set_strides_by_shape(ndarray);
|
||||
}
|
||||
}
|
|
@ -1,7 +1,6 @@
|
|||
#include <irrt/numpy/ndarray_def.hpp>
|
||||
|
||||
namespace {
|
||||
namespace ndarray {
|
||||
namespace { namespace ndarray { namespace broadcast {
|
||||
namespace util {
|
||||
template <typename SizeT>
|
||||
bool can_broadcast_shape_to(
|
||||
|
@ -112,5 +111,4 @@ namespace ndarray {
|
|||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
} } }
|
|
@ -1,52 +1,55 @@
|
|||
#pragma once
|
||||
|
||||
namespace {
|
||||
// The NDArray object. `SizeT` is the *signed* size type of this ndarray.
|
||||
//
|
||||
// NOTE: The order of fields is IMPORTANT. DON'T TOUCH IT
|
||||
//
|
||||
// Some resources you might find helpful:
|
||||
// - The official numpy implementations:
|
||||
// - https://github.com/numpy/numpy/blob/735a477f0bc2b5b84d0e72d92f224bde78d4e069/doc/source/reference/c-api/types-and-structures.rst
|
||||
// - On strides (about reshaping, slicing, C-contagiousness, etc)
|
||||
// - https://ajcr.net/stride-guide-part-1/.
|
||||
// - https://ajcr.net/stride-guide-part-2/.
|
||||
// - https://ajcr.net/stride-guide-part-3/.
|
||||
template <typename SizeT>
|
||||
struct NDArray {
|
||||
// The underlying data this `ndarray` is pointing to.
|
||||
// The NDArray object. `SizeT` is the *signed* size type of this ndarray.
|
||||
//
|
||||
// NOTE: Formally this should be of type `void *`, but clang
|
||||
// translates `void *` to `i8 *` when run with `-S -emit-llvm`,
|
||||
// so we will put `uint8_t *` here for clarity.
|
||||
// NOTE: The order of fields is IMPORTANT. DON'T TOUCH IT
|
||||
//
|
||||
// This pointer should point to the first element of the ndarray directly
|
||||
uint8_t *data;
|
||||
// Some resources you might find helpful:
|
||||
// - The official numpy implementations:
|
||||
// - https://github.com/numpy/numpy/blob/735a477f0bc2b5b84d0e72d92f224bde78d4e069/doc/source/reference/c-api/types-and-structures.rst
|
||||
// - On strides (about reshaping, slicing, C-contagiousness, etc)
|
||||
// - https://ajcr.net/stride-guide-part-1/.
|
||||
// - https://ajcr.net/stride-guide-part-2/.
|
||||
// - https://ajcr.net/stride-guide-part-3/.
|
||||
template <typename SizeT>
|
||||
struct NDArray {
|
||||
// The underlying data this `ndarray` is pointing to.
|
||||
//
|
||||
// NOTE: Formally this should be of type `void *`, but clang
|
||||
// translates `void *` to `i8 *` when run with `-S -emit-llvm`,
|
||||
// so we will put `uint8_t *` here for clarity.
|
||||
//
|
||||
// This pointer should point to the first element of the ndarray directly
|
||||
uint8_t *data;
|
||||
|
||||
// The number of bytes of a single element in `data`.
|
||||
//
|
||||
// The `SizeT` is treated as `unsigned`.
|
||||
SizeT itemsize;
|
||||
// The number of bytes of a single element in `data`.
|
||||
//
|
||||
// The `SizeT` is treated as `unsigned`.
|
||||
SizeT itemsize;
|
||||
|
||||
// The number of dimensions of this shape.
|
||||
//
|
||||
// The `SizeT` is treated as `unsigned`.
|
||||
SizeT ndims;
|
||||
// The number of dimensions of this shape.
|
||||
//
|
||||
// The `SizeT` is treated as `unsigned`.
|
||||
SizeT ndims;
|
||||
|
||||
// Array shape, with length equal to `ndims`.
|
||||
//
|
||||
// The `SizeT` is treated as `unsigned`.
|
||||
//
|
||||
// NOTE: `shape` can contain 0.
|
||||
// (those appear when the user makes an out of bounds slice into an ndarray, e.g., `np.zeros((3, 3))[400:].shape == (0, 3)`)
|
||||
SizeT *shape;
|
||||
// Array shape, with length equal to `ndims`.
|
||||
//
|
||||
// The `SizeT` is treated as `unsigned`.
|
||||
//
|
||||
// NOTE: `shape` can contain 0.
|
||||
// (those appear when the user makes an out of bounds slice into an ndarray, e.g., `np.zeros((3, 3))[400:].shape == (0, 3)`)
|
||||
SizeT *shape;
|
||||
|
||||
// Array strides (stride value is in number of bytes, NOT number of elements), with length equal to `ndims`.
|
||||
//
|
||||
// The `SizeT` is treated as `signed`.
|
||||
//
|
||||
// NOTE: `strides` can have negative numbers.
|
||||
// (those appear when there is a slice with a negative step, e.g., `my_array[::-1]`)
|
||||
SizeT *strides;
|
||||
};
|
||||
// Array strides (stride value is in number of bytes, NOT number of elements), with length equal to `ndims`.
|
||||
//
|
||||
// The `SizeT` is treated as `signed`.
|
||||
//
|
||||
// NOTE: `strides` can have negative numbers.
|
||||
// (those appear when there is a slice with a negative step, e.g., `my_array[::-1]`)
|
||||
SizeT *strides;
|
||||
};
|
||||
|
||||
// Because ndarray is so complicated, its functions are splitted into
|
||||
// different files and namespaces.
|
||||
}
|
|
@ -3,26 +3,26 @@
|
|||
#include <irrt/numpy/ndarray_def.hpp>
|
||||
#include <irrt/numpy/ndarray_basic.hpp>
|
||||
|
||||
namespace {
|
||||
namespace ndarray {
|
||||
namespace { namespace ndarray { namespace fill {
|
||||
// Fill the ndarray with a value
|
||||
template <typename SizeT>
|
||||
void fill_generic(NDArray<SizeT>* ndarray, const uint8_t* pvalue) {
|
||||
const SizeT size = ndarray::size(ndarray);
|
||||
const SizeT size = ndarray::basic::size(ndarray);
|
||||
for (SizeT i = 0; i < size; i++) {
|
||||
uint8_t* pelement = ndarray::get_nth_pelement(ndarray, i); // No need for checked_get_nth_pelement
|
||||
ndarray::set_pelement_value(ndarray, pelement, pvalue);
|
||||
uint8_t* pelement = ndarray::basic::get_nth_pelement(ndarray, i); // No need for checked_get_nth_pelement
|
||||
ndarray::basic::set_pelement_value(ndarray, pelement, pvalue);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
} } }
|
||||
|
||||
extern "C" {
|
||||
void __nac3_ndarray_fill_generic(NDArray<int32_t>* ndarray, uint8_t* pvalue) {
|
||||
ndarray::fill_generic(ndarray, pvalue);
|
||||
}
|
||||
using namespace ndarray::fill;
|
||||
|
||||
void __nac3_ndarray_fill_generic64(NDArray<int64_t>* ndarray, uint8_t* pvalue) {
|
||||
ndarray::fill_generic(ndarray, pvalue);
|
||||
}
|
||||
void __nac3_ndarray_fill_generic(NDArray<int32_t>* ndarray, uint8_t* pvalue) {
|
||||
fill_generic(ndarray, pvalue);
|
||||
}
|
||||
|
||||
void __nac3_ndarray_fill_generic64(NDArray<int64_t>* ndarray, uint8_t* pvalue) {
|
||||
fill_generic(ndarray, pvalue);
|
||||
}
|
||||
}
|
|
@ -3,31 +3,31 @@
|
|||
#include <irrt/slice.hpp>
|
||||
#include <irrt/numpy/ndarray_def.hpp>
|
||||
#include <irrt/numpy/ndarray_basic.hpp>
|
||||
#include <irrt/error_context.hpp>
|
||||
|
||||
namespace {
|
||||
typedef uint8_t NDSubscriptType;
|
||||
typedef uint8_t NDSubscriptType;
|
||||
|
||||
extern "C" {
|
||||
const NDSubscriptType INPUT_SUBSCRIPT_TYPE_INDEX = 0;
|
||||
const NDSubscriptType INPUT_SUBSCRIPT_TYPE_SLICE = 1;
|
||||
const NDSubscriptType INPUT_SUBSCRIPT_TYPE_INDEX = 0;
|
||||
const NDSubscriptType INPUT_SUBSCRIPT_TYPE_SLICE = 1;
|
||||
|
||||
struct NDSubscript {
|
||||
// A poor-man's enum variant type
|
||||
NDSubscriptType type;
|
||||
|
||||
/*
|
||||
if type == INPUT_SUBSCRIPT_TYPE_INDEX => `slice` points to a single `SliceIndex`
|
||||
if type == INPUT_SUBSCRIPT_TYPE_SLICE => `slice` points to a single `UserRange`
|
||||
|
||||
`SizeT` is controlled by the caller: `NDSubscript` only cares about where that
|
||||
slice is (the pointer), `NDSubscript` does not care/know about the actual `sizeof()`
|
||||
of the slice value.
|
||||
*/
|
||||
uint8_t* data;
|
||||
};
|
||||
}
|
||||
|
||||
struct NDSubscript {
|
||||
// A poor-man's enum variant type
|
||||
NDSubscriptType type;
|
||||
|
||||
/*
|
||||
if type == INPUT_SUBSCRIPT_TYPE_INDEX => `slice` points to a single `SizeT`
|
||||
if type == INPUT_SUBSCRIPT_TYPE_SLICE => `slice` points to a single `UserRange<SizeT>`
|
||||
|
||||
`SizeT` is controlled by the caller: `NDSubscript` only cares about where that
|
||||
slice is (the pointer), `NDSubscript` does not care/know about the actual `sizeof()`
|
||||
of the slice value.
|
||||
*/
|
||||
uint8_t* data;
|
||||
};
|
||||
|
||||
namespace ndarray {
|
||||
namespace { namespace ndarray { namespace subscript {
|
||||
namespace util {
|
||||
template<typename SizeT>
|
||||
SizeT deduce_ndims_after_slicing(SizeT ndims, SizeT num_subscripts, const NDSubscript* subscripts) {
|
||||
|
@ -61,7 +61,7 @@ namespace ndarray {
|
|||
// - `dst_ndarray->itemsize` does not have to be set, it will be set to `src_ndarray->itemsize`
|
||||
// - `dst_ndarray->shape` and `dst_ndarray.strides` can contain empty values
|
||||
template <typename SizeT>
|
||||
void subscript(SizeT num_subscripts, NDSubscript* subscripts, NDArray<SizeT>* src_ndarray, NDArray<SizeT>* dst_ndarray) {
|
||||
void subscript(ErrorContext* errctx, SizeT num_subscripts, NDSubscript* subscripts, NDArray<SizeT>* src_ndarray, NDArray<SizeT>* dst_ndarray) {
|
||||
// REFERENCE CODE (check out `_index_helper` in `__getitem__`):
|
||||
// https://github.com/wadetb/tinynumpy/blob/0d23d22e07062ffab2afa287374c7b366eebdda1/tinynumpy/tinynumpy.py#L652
|
||||
|
||||
|
@ -79,8 +79,19 @@ namespace ndarray {
|
|||
// Handle when the ndsubscript is just a single (possibly negative) integer
|
||||
// e.g., `my_array[::2, -5, ::-1]`
|
||||
// ^^------ like this
|
||||
SizeT index_user = *((SizeT*) ndsubscript->data);
|
||||
SizeT index = slice::resolve_index_in_length(src_ndarray->shape[src_axis], index_user);
|
||||
SliceIndex input_index = *((SliceIndex*) ndsubscript->data);
|
||||
|
||||
SliceIndex index = slice::resolve_index_in_length(src_ndarray->shape[src_axis], input_index);
|
||||
if (index == slice::OUT_OF_BOUNDS) {
|
||||
// Error message copied from numpy by doing `np.zeros((3, 4))[100]`
|
||||
errctx->set_error(
|
||||
errctx->error_ids->index_error,
|
||||
"index {0} is out of bounds for axis {1} with size {2}",
|
||||
input_index, src_axis, src_ndarray->shape[src_axis]
|
||||
);
|
||||
return; // Terminate
|
||||
}
|
||||
|
||||
dst_ndarray->data += index * src_ndarray->strides[src_axis]; // Add offset
|
||||
|
||||
// Next
|
||||
|
@ -89,11 +100,14 @@ namespace ndarray {
|
|||
// Handle when the ndsubscript is a slice (represented by UserSlice in IRRT)
|
||||
// e.g., `my_array[::2, -5, ::-1]`
|
||||
// ^^^------^^^^----- like these
|
||||
UserSlice* user_slice = (UserSlice*) ndsubscript->data;
|
||||
UserSlice* input_user_slice = (UserSlice*) ndsubscript->data;
|
||||
|
||||
// TODO: use checked indices
|
||||
Slice slice;
|
||||
user_slice->indices(src_ndarray->shape[src_axis], &slice); // To resolve negative indices and other funny stuff written by the user
|
||||
input_user_slice->indices_checked(errctx, src_ndarray->shape[src_axis], &slice); // To resolve negative indices and other funny stuff written by the user
|
||||
if (errctx->has_error()) {
|
||||
return; // Propagate error
|
||||
}
|
||||
|
||||
// NOTE: There is no need to write special code to handle negative steps/strides.
|
||||
// This simple implementation meticulously handles both positive and negative steps/strides.
|
||||
|
@ -123,15 +137,16 @@ namespace ndarray {
|
|||
dst_ndarray->strides[dst_axis] = src_ndarray->strides[src_axis];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
} } }
|
||||
|
||||
extern "C" {
|
||||
void __nac3_ndarray_subscript(int32_t num_subscripts, NDSubscript* subscripts, NDArray<int32_t>* src_ndarray, NDArray<int32_t> *dst_ndarray) {
|
||||
ndarray::subscript(num_subscripts, subscripts, src_ndarray, dst_ndarray);
|
||||
}
|
||||
using namespace ndarray::subscript;
|
||||
|
||||
void __nac3_ndarray_subscript64(int64_t num_subscripts, NDSubscript* subscripts, NDArray<int64_t>* src_ndarray, NDArray<int64_t> *dst_ndarray) {
|
||||
ndarray::subscript(num_subscripts, subscripts, src_ndarray, dst_ndarray);
|
||||
}
|
||||
void __nac3_ndarray_subscript(ErrorContext* errctx, int32_t num_subscripts, NDSubscript* subscripts, NDArray<int32_t>* src_ndarray, NDArray<int32_t> *dst_ndarray) {
|
||||
subscript(errctx, num_subscripts, subscripts, src_ndarray, dst_ndarray);
|
||||
}
|
||||
|
||||
void __nac3_ndarray_subscript64(ErrorContext* errctx, int64_t num_subscripts, NDSubscript* subscripts, NDArray<int64_t>* src_ndarray, NDArray<int64_t> *dst_ndarray) {
|
||||
subscript(errctx, num_subscripts, subscripts, src_ndarray, dst_ndarray);
|
||||
}
|
||||
}
|
|
@ -4,129 +4,140 @@
|
|||
#include <irrt/slice.hpp>
|
||||
|
||||
namespace {
|
||||
struct Slice {
|
||||
SliceIndex start;
|
||||
SliceIndex stop;
|
||||
SliceIndex step;
|
||||
struct Slice {
|
||||
SliceIndex start;
|
||||
SliceIndex stop;
|
||||
SliceIndex step;
|
||||
|
||||
// The length/The number of elements of the slice if it were a range,
|
||||
// i.e., the value of `len(range(this->start, this->stop, this->end))`
|
||||
SliceIndex len() {
|
||||
SliceIndex diff = stop - start;
|
||||
if (diff > 0 && step > 0) {
|
||||
return ((diff - 1) / step) + 1;
|
||||
} else if (diff < 0 && step < 0) {
|
||||
return ((diff + 1) / step) + 1;
|
||||
} else {
|
||||
return 0;
|
||||
// The length/The number of elements of the slice if it were a range,
|
||||
// i.e., the value of `len(range(this->start, this->stop, this->end))`
|
||||
SliceIndex len() {
|
||||
SliceIndex diff = stop - start;
|
||||
if (diff > 0 && step > 0) {
|
||||
return ((diff - 1) / step) + 1;
|
||||
} else if (diff < 0 && step < 0) {
|
||||
return ((diff + 1) / step) + 1;
|
||||
} else {
|
||||
return 0;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
namespace slice {
|
||||
// "Resolve" an index value under a length in Python lists.
|
||||
// If you have a `list` of length 100, `list[-1]` would resolve to `list[100-1] == list[99]`.
|
||||
//
|
||||
// If length == 0, this function returns 0
|
||||
//
|
||||
// If index is out of bounds, this function clamps the value
|
||||
// (to `list[0]` or `list[-1]` in the context of a list and depending on if index is + or -)
|
||||
SliceIndex resolve_index_in_length_clamped(SliceIndex length, SliceIndex index) {
|
||||
if (index < 0) {
|
||||
// Remember that index is negative, so do a plus here
|
||||
return max<SliceIndex>(length + index, 0);
|
||||
} else {
|
||||
return min<SliceIndex>(length, index);
|
||||
}
|
||||
}
|
||||
|
||||
const SliceIndex OUT_OF_BOUNDS = -1;
|
||||
|
||||
// Like `resolve_index_in_length`.
|
||||
// But also checks if the resolved index is in
|
||||
// bounds (function returns true) or out of bounds
|
||||
// (function returns false); `0 <= resolved index < length` is false).
|
||||
SliceIndex resolve_index_in_length(SliceIndex length, SliceIndex index) {
|
||||
SliceIndex resolved = index < 0 ? length + index : index;
|
||||
|
||||
bool in_bounds = 0 <= resolved && resolved < length;
|
||||
return in_bounds ? resolved : OUT_OF_BOUNDS;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
namespace slice {
|
||||
// "Resolve" an index value under a length in Python lists.
|
||||
// If you have a `list` of length 100, `list[-1]` would resolve to `list[100-1] == list[99]`.
|
||||
// A user-written Python-like slice.
|
||||
//
|
||||
// If length == 0, this function returns 0
|
||||
// i.e., this slice is a triple of either an int or nothing. (e.g., `my_array[:10:2]`, `start` is None)
|
||||
//
|
||||
// If index is out of bounds, this function clamps the value
|
||||
// (to `list[0]` or `list[-1]` in the context of a list and depending on if index is + or -)
|
||||
SliceIndex resolve_index_in_length(SliceIndex length, SliceIndex index) {
|
||||
if (index < 0) {
|
||||
// Remember that index is negative, so do a plus here
|
||||
return max<SliceIndex>(length + index, 0);
|
||||
} else {
|
||||
return min<SliceIndex>(length, index);
|
||||
}
|
||||
}
|
||||
}
|
||||
// You can "resolve" a `UserSlice` by using `user_slice.indices(<length>)`
|
||||
struct UserSlice {
|
||||
// Did the user specify `start`? If 0, `start` is undefined (and contains an empty value)
|
||||
bool start_defined;
|
||||
SliceIndex start;
|
||||
|
||||
// A user-written Python-like slice.
|
||||
//
|
||||
// i.e., this slice is a triple of either an int or nothing. (e.g., `my_array[:10:2]`, `start` is None)
|
||||
//
|
||||
// You can "resolve" a `UserSlice` by using `UserSlice::indices(<length>)`
|
||||
struct UserSlice {
|
||||
// Did the user specify `start`? If 0, `start` is undefined (and contains an empty value)
|
||||
bool start_defined;
|
||||
SliceIndex start;
|
||||
// Similar to `start_defined`
|
||||
bool stop_defined;
|
||||
SliceIndex stop;
|
||||
|
||||
// Similar to `start_defined`
|
||||
bool stop_defined;
|
||||
SliceIndex stop;
|
||||
// Similar to `start_defined`
|
||||
bool step_defined;
|
||||
SliceIndex step;
|
||||
|
||||
// Similar to `start_defined`
|
||||
bool step_defined;
|
||||
SliceIndex step;
|
||||
|
||||
// Constructor faithfully follows Python's `slice()`.
|
||||
explicit UserSlice(SliceIndex stop) {
|
||||
start_defined = false;
|
||||
stop_defined = true;
|
||||
step_defined = false;
|
||||
|
||||
this->stop = stop;
|
||||
}
|
||||
|
||||
explicit UserSlice(SliceIndex start, SliceIndex stop) {
|
||||
start_defined = true;
|
||||
stop_defined = true;
|
||||
step_defined = false;
|
||||
|
||||
this->start = start;
|
||||
this->stop = stop;
|
||||
}
|
||||
|
||||
explicit UserSlice(SliceIndex start, SliceIndex stop, SliceIndex step) {
|
||||
start_defined = true;
|
||||
stop_defined = true;
|
||||
step_defined = true;
|
||||
|
||||
this->start = start;
|
||||
this->stop = stop;
|
||||
this->step = step;
|
||||
}
|
||||
|
||||
// Like Python's `slice(start, stop, step).indices(length)`
|
||||
void indices(SliceIndex length, Slice* result) {
|
||||
// NOTE: This function implements Python's `slice.indices` *FAITHFULLY*.
|
||||
// SEE: https://github.com/python/cpython/blob/f62161837e68c1c77961435f1b954412dd5c2b65/Objects/sliceobject.c#L546
|
||||
result->step = step_defined ? step : 1;
|
||||
bool step_is_negative = result->step < 0;
|
||||
|
||||
if (start_defined) {
|
||||
result->start = slice::resolve_index_in_length(length, start);
|
||||
} else {
|
||||
result->start = step_is_negative ? length - 1 : 0;
|
||||
}
|
||||
|
||||
if (stop_defined) {
|
||||
result->stop = slice::resolve_index_in_length(length, stop);
|
||||
} else {
|
||||
result->stop = step_is_negative ? -1 : length;
|
||||
}
|
||||
}
|
||||
|
||||
// `indices()` but asserts `this->step != 0` and `this->length >= 0`
|
||||
void checked_indices(ErrorContext* errctx, SliceIndex length, Slice* result) {
|
||||
if (!(length >= 0)) {
|
||||
errctx->set_error(
|
||||
errctx->error_ids->value_error,
|
||||
"length should not be negative, got {0}", // Edited. Error message copied from python by doing `slice(0, 0, 0).indices(100)`
|
||||
length
|
||||
);
|
||||
return;
|
||||
// Convenient constructor for C++ internal use only (say testing)
|
||||
UserSlice() {
|
||||
this->reset();
|
||||
}
|
||||
|
||||
if (!(this->step_defined && this->step != 0)) {
|
||||
// Error message
|
||||
errctx->set_error(
|
||||
errctx->error_ids->value_error,
|
||||
"slice step cannot be zero" // Error message copied from python by doing `slice(0, 0, 0).indices(100)`
|
||||
);
|
||||
return;
|
||||
void reset() {
|
||||
this->start_defined = false;
|
||||
this->stop_defined = false;
|
||||
this->step_defined = false;
|
||||
}
|
||||
this->indices(length, result);
|
||||
}
|
||||
};
|
||||
|
||||
void set_start(SliceIndex start) {
|
||||
this->start_defined = true;
|
||||
this->start = start;
|
||||
}
|
||||
|
||||
void set_stop(SliceIndex stop) {
|
||||
this->stop_defined = true;
|
||||
this->stop = stop;
|
||||
}
|
||||
|
||||
void set_step(SliceIndex step) {
|
||||
this->step_defined = true;
|
||||
this->step = step;
|
||||
}
|
||||
|
||||
// Like Python's `slice(start, stop, step).indices(length)`
|
||||
void indices(SliceIndex length, Slice* result) {
|
||||
// NOTE: This function implements Python's `slice.indices` *FAITHFULLY*.
|
||||
// SEE: https://github.com/python/cpython/blob/f62161837e68c1c77961435f1b954412dd5c2b65/Objects/sliceobject.c#L546
|
||||
result->step = step_defined ? step : 1;
|
||||
bool step_is_negative = result->step < 0;
|
||||
|
||||
if (start_defined) {
|
||||
result->start = slice::resolve_index_in_length_clamped(length, start);
|
||||
} else {
|
||||
result->start = step_is_negative ? length - 1 : 0;
|
||||
}
|
||||
|
||||
if (stop_defined) {
|
||||
result->stop = slice::resolve_index_in_length_clamped(length, stop);
|
||||
} else {
|
||||
result->stop = step_is_negative ? -1 : length;
|
||||
}
|
||||
}
|
||||
|
||||
// `indices()` but asserts `this->step != 0` and `this->length >= 0`
|
||||
void indices_checked(ErrorContext* errctx, SliceIndex length, Slice* result) {
|
||||
if (length < 0) {
|
||||
errctx->set_error(
|
||||
errctx->error_ids->value_error,
|
||||
"length should not be negative, got {0}", // Edited. Error message copied from python by doing `slice(0, 0, 0).indices(100)`
|
||||
length
|
||||
);
|
||||
return;
|
||||
}
|
||||
|
||||
if (this->step_defined && this->step == 0) {
|
||||
// Error message
|
||||
errctx->set_error(
|
||||
errctx->error_ids->value_error,
|
||||
"slice step cannot be zero" // Error message copied from python by doing `slice(0, 0, 0).indices(100)`
|
||||
);
|
||||
return;
|
||||
}
|
||||
this->indices(length, result);
|
||||
}
|
||||
};
|
||||
}
|
|
@ -3,76 +3,76 @@
|
|||
#include <irrt/int_defs.hpp>
|
||||
|
||||
namespace {
|
||||
template <typename T>
|
||||
const T& max(const T& a, const T& b) {
|
||||
return a > b ? a : b;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
const T& min(const T& a, const T& b) {
|
||||
return a > b ? b : a;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
bool arrays_match(int len, T* as, T* bs) {
|
||||
for (int i = 0; i < len; i++) {
|
||||
if (as[i] != bs[i]) return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
namespace cstr_utils {
|
||||
bool is_empty(const char* str) {
|
||||
return str[0] == '\0';
|
||||
template <typename T>
|
||||
const T& max(const T& a, const T& b) {
|
||||
return a > b ? a : b;
|
||||
}
|
||||
|
||||
int8_t compare(const char* a, const char* b) {
|
||||
uint32_t i = 0;
|
||||
while (true) {
|
||||
if (a[i] < b[i]) {
|
||||
return -1;
|
||||
} else if (a[i] > b[i]) {
|
||||
return 1;
|
||||
} else { // a[i] == b[i]
|
||||
if (a[i] == '\0') {
|
||||
return 0;
|
||||
} else {
|
||||
i++;
|
||||
template <typename T>
|
||||
const T& min(const T& a, const T& b) {
|
||||
return a > b ? b : a;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
bool arrays_match(int len, T* as, T* bs) {
|
||||
for (int i = 0; i < len; i++) {
|
||||
if (as[i] != bs[i]) return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
namespace cstr_utils {
|
||||
bool is_empty(const char* str) {
|
||||
return str[0] == '\0';
|
||||
}
|
||||
|
||||
int8_t compare(const char* a, const char* b) {
|
||||
uint32_t i = 0;
|
||||
while (true) {
|
||||
if (a[i] < b[i]) {
|
||||
return -1;
|
||||
} else if (a[i] > b[i]) {
|
||||
return 1;
|
||||
} else { // a[i] == b[i]
|
||||
if (a[i] == '\0') {
|
||||
return 0;
|
||||
} else {
|
||||
i++;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
int8_t equal(const char* a, const char* b) {
|
||||
return compare(a, b) == 0;
|
||||
}
|
||||
|
||||
uint32_t length(const char* str) {
|
||||
uint32_t length = 0;
|
||||
while (*str != '\0') {
|
||||
length++;
|
||||
str++;
|
||||
}
|
||||
return length;
|
||||
}
|
||||
|
||||
bool copy(const char* src, char* dst, uint32_t dst_max_size) {
|
||||
for (uint32_t i = 0; i < dst_max_size; i++) {
|
||||
bool is_last = i + 1 == dst_max_size;
|
||||
if (is_last && src[i] != '\0') {
|
||||
dst[i] = '\0';
|
||||
return false;
|
||||
}
|
||||
|
||||
if (src[i] == '\0') {
|
||||
dst[i] = '\0';
|
||||
return true;
|
||||
}
|
||||
|
||||
dst[i] = src[i];
|
||||
int8_t equal(const char* a, const char* b) {
|
||||
return compare(a, b) == 0;
|
||||
}
|
||||
|
||||
__builtin_unreachable();
|
||||
uint32_t length(const char* str) {
|
||||
uint32_t length = 0;
|
||||
while (*str != '\0') {
|
||||
length++;
|
||||
str++;
|
||||
}
|
||||
return length;
|
||||
}
|
||||
|
||||
bool copy(const char* src, char* dst, uint32_t dst_max_size) {
|
||||
for (uint32_t i = 0; i < dst_max_size; i++) {
|
||||
bool is_last = i + 1 == dst_max_size;
|
||||
if (is_last && src[i] != '\0') {
|
||||
dst[i] = '\0';
|
||||
return false;
|
||||
}
|
||||
|
||||
if (src[i] == '\0') {
|
||||
dst[i] = '\0';
|
||||
return true;
|
||||
}
|
||||
|
||||
dst[i] = src[i];
|
||||
}
|
||||
|
||||
__builtin_unreachable();
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
|
@ -9,12 +9,15 @@
|
|||
|
||||
#include <test/core.hpp>
|
||||
#include <test/test_core.hpp>
|
||||
#include <test/test_ndarray.hpp>
|
||||
#include <test/test_ndarray_basic.hpp>
|
||||
#include <test/test_ndarray_subscript.hpp>
|
||||
#include <test/test_slice.hpp>
|
||||
|
||||
int main() {
|
||||
test_int_exp();
|
||||
run_all_tests_ndarray();
|
||||
run_all_tests_ndarray_slice();
|
||||
// Be wise about the order of suites!!
|
||||
test::core::run();
|
||||
test::slice::run();
|
||||
test::ndarray_basic::run();
|
||||
test::ndarray_subscript::run();
|
||||
return 0;
|
||||
}
|
|
@ -85,4 +85,69 @@ void __assert_values_match(const char* file, int line, T expected, T got) {
|
|||
}
|
||||
}
|
||||
|
||||
#define assert_values_match(expected, got) __assert_values_match(__FILE__, __LINE__, expected, got)
|
||||
#define assert_values_match(expected, got) __assert_values_match(__FILE__, __LINE__, expected, got)
|
||||
|
||||
// A fake set of ErrorIds for testing only
|
||||
const ErrorIds TEST_ERROR_IDS = {
|
||||
.index_error = 0,
|
||||
.value_error = 1,
|
||||
.assertion_error = 2,
|
||||
.runtime_error = 3,
|
||||
};
|
||||
|
||||
ErrorContext create_testing_errctx() {
|
||||
// Everything is global so it is fine to directly return a struct ErrorContext
|
||||
ErrorContext errctx;
|
||||
errctx.initialize(&TEST_ERROR_IDS);
|
||||
return errctx;
|
||||
}
|
||||
|
||||
void debug_print_errctx_content(ErrorContext* errctx) {
|
||||
if (errctx->has_error()) {
|
||||
printf(
|
||||
"(Error ID %d): %s ... where param1 = %ld, param2 = %ld, param3 = %ld\n",
|
||||
errctx->error_id,
|
||||
errctx->message_template,
|
||||
errctx->param1,
|
||||
errctx->param2,
|
||||
errctx->param3
|
||||
);
|
||||
} else {
|
||||
printf("<no error>\n");
|
||||
}
|
||||
}
|
||||
|
||||
void __assert_errctx_no_error(const char* file, int line, ErrorContext* errctx) {
|
||||
if (errctx->has_error()) {
|
||||
print_assertion_failed(file, line);
|
||||
printf("Expecting no error but caught the following:\n\n");
|
||||
debug_print_errctx_content(errctx);
|
||||
test_fail();
|
||||
}
|
||||
}
|
||||
|
||||
#define assert_errctx_no_error(errctx) __assert_errctx_no_error(__FILE__, __LINE__, errctx)
|
||||
|
||||
void __assert_errctx_has_error(const char* file, int line, ErrorContext *errctx, ErrorId expected_error_id) {
|
||||
if (errctx->has_error()) {
|
||||
if (errctx->error_id == expected_error_id) {
|
||||
// OK
|
||||
} else {
|
||||
// Otherwise it got the wrong kind of error
|
||||
print_assertion_failed(file, line);
|
||||
printf(
|
||||
"Expecting error id %d but got error id %d. Error caught:\n\n",
|
||||
expected_error_id,
|
||||
errctx->error_id
|
||||
);
|
||||
debug_print_errctx_content(errctx);
|
||||
test_fail();
|
||||
}
|
||||
} else {
|
||||
print_assertion_failed(file, line);
|
||||
printf("Expecting an error, but there is none.");
|
||||
test_fail();
|
||||
}
|
||||
}
|
||||
|
||||
#define assert_errctx_has_error(errctx, expected_error_id) __assert_errctx_has_error(__FILE__, __LINE__, errctx, expected_error_id)
|
|
@ -6,6 +6,11 @@
|
|||
template <class T>
|
||||
void print_value(const T& value) {}
|
||||
|
||||
template <>
|
||||
void print_value(const bool& value) {
|
||||
printf("%s", value ? "true" : "false");
|
||||
}
|
||||
|
||||
template <>
|
||||
void print_value(const int8_t& value) {
|
||||
printf("%d", value);
|
||||
|
|
|
@ -1,11 +1,19 @@
|
|||
#pragma once
|
||||
|
||||
#include <test/core.hpp>
|
||||
#include <irrt/core.hpp>
|
||||
#include <irrt_everything.hpp>
|
||||
|
||||
namespace test {
|
||||
namespace core {
|
||||
void test_int_exp() {
|
||||
BEGIN_TEST();
|
||||
|
||||
assert_values_match(125, __nac3_int_exp_impl<int32_t>(5, 3));
|
||||
assert_values_match(3125, __nac3_int_exp_impl<int32_t>(5, 5));
|
||||
}
|
||||
|
||||
void run() {
|
||||
test_int_exp();
|
||||
}
|
||||
}
|
||||
}
|
|
@ -3,12 +3,14 @@
|
|||
#include <test/core.hpp>
|
||||
#include <irrt_everything.hpp>
|
||||
|
||||
namespace test {
|
||||
namespace ndarray_basic {
|
||||
void test_calc_size_from_shape_normal() {
|
||||
// Test shapes with normal values
|
||||
BEGIN_TEST();
|
||||
|
||||
int32_t shape[4] = { 2, 3, 5, 7 };
|
||||
assert_values_match(210, ndarray::util::calc_size_from_shape<int32_t>(4, shape));
|
||||
assert_values_match(210, ndarray::basic::util::calc_size_from_shape<int32_t>(4, shape));
|
||||
}
|
||||
|
||||
void test_calc_size_from_shape_has_zero() {
|
||||
|
@ -16,7 +18,7 @@ void test_calc_size_from_shape_has_zero() {
|
|||
BEGIN_TEST();
|
||||
|
||||
int32_t shape[4] = { 2, 0, 5, 7 };
|
||||
assert_values_match(0, ndarray::util::calc_size_from_shape<int32_t>(4, shape));
|
||||
assert_values_match(0, ndarray::basic::util::calc_size_from_shape<int32_t>(4, shape));
|
||||
}
|
||||
|
||||
void test_set_strides_by_shape() {
|
||||
|
@ -25,7 +27,7 @@ void test_set_strides_by_shape() {
|
|||
|
||||
int32_t shape[4] = { 99, 3, 5, 7 };
|
||||
int32_t strides[4] = { 0 };
|
||||
ndarray::util::set_strides_by_shape((int32_t) sizeof(int32_t), 4, strides, shape);
|
||||
ndarray::basic::util::set_strides_by_shape((int32_t) sizeof(int32_t), 4, strides, shape);
|
||||
|
||||
int32_t expected_strides[4] = {
|
||||
105 * sizeof(int32_t),
|
||||
|
@ -36,8 +38,10 @@ void test_set_strides_by_shape() {
|
|||
assert_arrays_match(4, expected_strides, strides);
|
||||
}
|
||||
|
||||
void run_all_tests_ndarray() {
|
||||
void run() {
|
||||
test_calc_size_from_shape_normal();
|
||||
test_calc_size_from_shape_has_zero();
|
||||
test_set_strides_by_shape();
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,182 @@
|
|||
#pragma once
|
||||
|
||||
#include <test/core.hpp>
|
||||
#include <irrt_everything.hpp>
|
||||
|
||||
namespace test { namespace ndarray_subscript {
|
||||
void test_ndsubscript_normal_1() {
|
||||
/*
|
||||
Reference Python code:
|
||||
```python
|
||||
ndarray = np.arange(12, dtype=np.float64).reshape((3, 4));
|
||||
# array([[ 0., 1., 2., 3.],
|
||||
# [ 4., 5., 6., 7.],
|
||||
# [ 8., 9., 10., 11.]])
|
||||
|
||||
dst_ndarray = ndarray[-2:, 1::2]
|
||||
# array([[ 5., 7.],
|
||||
# [ 9., 11.]])
|
||||
|
||||
assert dst_ndarray.shape == (2, 2)
|
||||
assert dst_ndarray.strides == (32, 16)
|
||||
assert dst_ndarray[0, 0] == 5.0
|
||||
assert dst_ndarray[0, 1] == 7.0
|
||||
assert dst_ndarray[1, 0] == 9.0
|
||||
assert dst_ndarray[1, 1] == 11.0
|
||||
```
|
||||
*/
|
||||
BEGIN_TEST();
|
||||
|
||||
double src_data[12] = { 0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0 };
|
||||
int32_t src_itemsize = sizeof(double);
|
||||
const int32_t src_ndims = 2;
|
||||
int32_t src_shape[src_ndims] = { 3, 4 };
|
||||
int32_t src_strides[src_ndims] = {};
|
||||
NDArray<int32_t> src_ndarray = {
|
||||
.data = (uint8_t*) src_data,
|
||||
.itemsize = src_itemsize,
|
||||
.ndims = src_ndims,
|
||||
.shape = src_shape,
|
||||
.strides = src_strides
|
||||
};
|
||||
ndarray::basic::set_strides_by_shape(&src_ndarray);
|
||||
|
||||
// Destination ndarray
|
||||
// As documented, ndims and shape & strides must be allocated and determined by the caller.
|
||||
const int32_t dst_ndims = 2;
|
||||
int32_t dst_shape[dst_ndims] = {999, 999}; // Empty values
|
||||
int32_t dst_strides[dst_ndims] = {999, 999}; // Empty values
|
||||
NDArray<int32_t> dst_ndarray = {
|
||||
.data = nullptr,
|
||||
.ndims = dst_ndims,
|
||||
.shape = dst_shape,
|
||||
.strides = dst_strides
|
||||
};
|
||||
|
||||
// Create the slice in `ndarray[-2::, 1::2]`
|
||||
UserSlice subscript_1;
|
||||
subscript_1.set_start(-2);
|
||||
|
||||
UserSlice subscript_2;
|
||||
subscript_2.set_start(1);
|
||||
subscript_2.set_step(2);
|
||||
|
||||
const int32_t num_ndsubscripts = 2;
|
||||
NDSubscript ndsubscripts[num_ndsubscripts] = {
|
||||
{ .type = INPUT_SUBSCRIPT_TYPE_SLICE, .data = (uint8_t*) &subscript_1 },
|
||||
{ .type = INPUT_SUBSCRIPT_TYPE_SLICE, .data = (uint8_t*) &subscript_2 }
|
||||
};
|
||||
|
||||
ErrorContext errctx = create_testing_errctx();
|
||||
ndarray::subscript::subscript(&errctx, num_ndsubscripts, ndsubscripts, &src_ndarray, &dst_ndarray);
|
||||
assert_errctx_no_error(&errctx);
|
||||
|
||||
int32_t expected_shape[dst_ndims] = { 2, 2 };
|
||||
int32_t expected_strides[dst_ndims] = { 32, 16 };
|
||||
|
||||
assert_arrays_match(dst_ndims, expected_shape, dst_ndarray.shape);
|
||||
assert_arrays_match(dst_ndims, expected_strides, dst_ndarray.strides);
|
||||
|
||||
// dst_ndarray[0, 0]
|
||||
assert_values_match(
|
||||
5.0,
|
||||
*((double *) ndarray::basic::get_pelement_by_indices(&dst_ndarray, (int32_t[dst_ndims]) { 0, 0 }))
|
||||
);
|
||||
// dst_ndarray[0, 1]
|
||||
assert_values_match(
|
||||
7.0,
|
||||
*((double *) ndarray::basic::get_pelement_by_indices(&dst_ndarray, (int32_t[dst_ndims]) { 0, 1 }))
|
||||
);
|
||||
// dst_ndarray[1, 0]
|
||||
assert_values_match(
|
||||
9.0,
|
||||
*((double *) ndarray::basic::get_pelement_by_indices(&dst_ndarray, (int32_t[dst_ndims]) { 1, 0 }))
|
||||
);
|
||||
// dst_ndarray[1, 1]
|
||||
assert_values_match(
|
||||
11.0,
|
||||
*((double *) ndarray::basic::get_pelement_by_indices(&dst_ndarray, (int32_t[dst_ndims]) { 1, 1 }))
|
||||
);
|
||||
}
|
||||
|
||||
void test_ndsubscript_normal_2() {
|
||||
/*
|
||||
```python
|
||||
ndarray = np.arange(12, dtype=np.float64).reshape((3, 4))
|
||||
# array([[ 0., 1., 2., 3.],
|
||||
# [ 4., 5., 6., 7.],
|
||||
# [ 8., 9., 10., 11.]])
|
||||
|
||||
dst_ndarray = ndarray[2, ::-2]
|
||||
# array([11., 9.])
|
||||
|
||||
assert dst_ndarray.shape == (2,)
|
||||
assert dst_ndarray.strides == (-16,)
|
||||
assert dst_ndarray[0] == 11.0
|
||||
assert dst_ndarray[1] == 9.0
|
||||
```
|
||||
*/
|
||||
BEGIN_TEST();
|
||||
|
||||
double src_data[12] = { 0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0 };
|
||||
int32_t src_itemsize = sizeof(double);
|
||||
const int32_t src_ndims = 2;
|
||||
int32_t src_shape[src_ndims] = { 3, 4 };
|
||||
int32_t src_strides[src_ndims] = {};
|
||||
NDArray<int32_t> src_ndarray = {
|
||||
.data = (uint8_t*) src_data,
|
||||
.itemsize = src_itemsize,
|
||||
.ndims = src_ndims,
|
||||
.shape = src_shape,
|
||||
.strides = src_strides
|
||||
};
|
||||
ndarray::basic::set_strides_by_shape(&src_ndarray);
|
||||
|
||||
// Destination ndarray
|
||||
// As documented, ndims and shape & strides must be allocated and determined by the caller.
|
||||
const int32_t dst_ndims = 1;
|
||||
int32_t dst_shape[dst_ndims] = {999}; // Empty values
|
||||
int32_t dst_strides[dst_ndims] = {999}; // Empty values
|
||||
NDArray<int32_t> dst_ndarray = {
|
||||
.data = nullptr,
|
||||
.ndims = dst_ndims,
|
||||
.shape = dst_shape,
|
||||
.strides = dst_strides
|
||||
};
|
||||
|
||||
// Create the slice in `ndarray[2, ::-2]`
|
||||
int32_t subscript_1 = 2;
|
||||
|
||||
UserSlice subscript_2;
|
||||
subscript_2.set_step(-2);
|
||||
|
||||
const int32_t num_ndsubscripts = 2;
|
||||
NDSubscript ndsubscripts[num_ndsubscripts] = {
|
||||
{ .type = INPUT_SUBSCRIPT_TYPE_INDEX, .data = (uint8_t*) &subscript_1 },
|
||||
{ .type = INPUT_SUBSCRIPT_TYPE_SLICE, .data = (uint8_t*) &subscript_2 }
|
||||
};
|
||||
|
||||
ErrorContext errctx = create_testing_errctx();
|
||||
ndarray::subscript::subscript(&errctx, num_ndsubscripts, ndsubscripts, &src_ndarray, &dst_ndarray);
|
||||
assert_errctx_no_error(&errctx);
|
||||
|
||||
int32_t expected_shape[dst_ndims] = { 2 };
|
||||
int32_t expected_strides[dst_ndims] = { -16 };
|
||||
assert_arrays_match(dst_ndims, expected_shape, dst_ndarray.shape);
|
||||
assert_arrays_match(dst_ndims, expected_strides, dst_ndarray.strides);
|
||||
|
||||
assert_values_match(
|
||||
11.0,
|
||||
*((double *) ndarray::basic::get_pelement_by_indices(&dst_ndarray, (int32_t[dst_ndims]) { 0 }))
|
||||
);
|
||||
assert_values_match(
|
||||
9.0,
|
||||
*((double *) ndarray::basic::get_pelement_by_indices(&dst_ndarray, (int32_t[dst_ndims]) { 1 }))
|
||||
);
|
||||
}
|
||||
|
||||
void run() {
|
||||
test_ndsubscript_normal_1();
|
||||
test_ndsubscript_normal_2();
|
||||
}
|
||||
} }
|
|
@ -3,18 +3,94 @@
|
|||
#include <test/core.hpp>
|
||||
#include <irrt_everything.hpp>
|
||||
|
||||
void test_slice_1() {
|
||||
namespace test {
|
||||
namespace slice {
|
||||
void test_slice_normal() {
|
||||
// Normal situation
|
||||
BEGIN_TEST();
|
||||
|
||||
UserSlice user_slice(5);
|
||||
UserSlice user_slice;
|
||||
user_slice.set_stop(5);
|
||||
|
||||
Slice slice;
|
||||
user_slice.indices(100, &slice);
|
||||
|
||||
printf("%d, %d, %d\n", slice.start, slice.stop, slice.step);
|
||||
|
||||
assert_values_match(0, slice.start);
|
||||
assert_values_match(5, slice.stop);
|
||||
assert_values_match(1, slice.step);
|
||||
}
|
||||
|
||||
void run_all_tests_ndarray_slice() {
|
||||
test_slice_1();
|
||||
void test_slice_start_too_large() {
|
||||
// Start is too large and should be clamped to length
|
||||
BEGIN_TEST();
|
||||
|
||||
UserSlice user_slice;
|
||||
user_slice.set_start(400);
|
||||
|
||||
Slice slice;
|
||||
user_slice.indices(100, &slice);
|
||||
|
||||
assert_values_match(100, slice.start);
|
||||
assert_values_match(100, slice.stop);
|
||||
assert_values_match(1, slice.step);
|
||||
}
|
||||
|
||||
void test_slice_negative_start_stop() {
|
||||
// Negative start/stop should be resolved
|
||||
BEGIN_TEST();
|
||||
|
||||
UserSlice user_slice;
|
||||
user_slice.set_start(-10);
|
||||
user_slice.set_stop(-5);
|
||||
|
||||
Slice slice;
|
||||
user_slice.indices(100, &slice);
|
||||
|
||||
assert_values_match(90, slice.start);
|
||||
assert_values_match(95, slice.stop);
|
||||
assert_values_match(1, slice.step);
|
||||
}
|
||||
|
||||
void test_slice_only_negative_step() {
|
||||
// Things like `[::-5]` should be handled correctly
|
||||
BEGIN_TEST();
|
||||
|
||||
UserSlice user_slice;
|
||||
user_slice.set_step(-5);
|
||||
|
||||
Slice slice;
|
||||
user_slice.indices(100, &slice);
|
||||
|
||||
assert_values_match(99, slice.start);
|
||||
assert_values_match(-1, slice.stop);
|
||||
assert_values_match(-5, slice.step);
|
||||
}
|
||||
|
||||
void test_slice_step_zero() {
|
||||
// Step = 0 is a value error
|
||||
BEGIN_TEST();
|
||||
|
||||
ErrorContext errctx = create_testing_errctx();
|
||||
|
||||
UserSlice user_slice;
|
||||
user_slice.set_start(2);
|
||||
user_slice.set_stop(12);
|
||||
user_slice.set_step(0);
|
||||
|
||||
Slice slice;
|
||||
user_slice.indices_checked(&errctx, 100, &slice);
|
||||
|
||||
assert_errctx_has_error(&errctx, errctx.error_ids->value_error);
|
||||
}
|
||||
|
||||
void run() {
|
||||
test_slice_normal();
|
||||
test_slice_start_too_large();
|
||||
test_slice_negative_start_stop();
|
||||
test_slice_only_negative_step();
|
||||
test_slice_step_zero();
|
||||
}
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue