core: irrt general numpy broadcasting
This commit is contained in:
parent
d18c769cdc
commit
9aae290727
|
@ -13,6 +13,17 @@ using NDIndex = uint32_t;
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
namespace ndarray_util {
|
namespace ndarray_util {
|
||||||
|
template <typename SizeT>
|
||||||
|
static void set_indices_by_nth(SizeT ndims, const SizeT* shape, SizeT* indices, SizeT nth) {
|
||||||
|
for (int32_t i = 0; i < ndims; i++) {
|
||||||
|
int32_t dim_i = ndims - i - 1;
|
||||||
|
int32_t dim = shape[dim_i];
|
||||||
|
|
||||||
|
indices[dim_i] = nth % dim;
|
||||||
|
nth /= dim;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Compute the strides of an ndarray given an ndarray `shape`
|
// Compute the strides of an ndarray given an ndarray `shape`
|
||||||
// and assuming that the ndarray is *fully C-contagious*.
|
// and assuming that the ndarray is *fully C-contagious*.
|
||||||
//
|
//
|
||||||
|
@ -34,6 +45,57 @@ namespace {
|
||||||
for (SizeT dim_i = 0; dim_i < ndims; dim_i++) size *= shape[dim_i];
|
for (SizeT dim_i = 0; dim_i < ndims; dim_i++) size *= shape[dim_i];
|
||||||
return size;
|
return size;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template <typename SizeT>
|
||||||
|
static bool can_broadcast_shape_to(
|
||||||
|
const SizeT target_ndims,
|
||||||
|
const SizeT *target_shape,
|
||||||
|
const SizeT src_ndims,
|
||||||
|
const SizeT *src_shape
|
||||||
|
) {
|
||||||
|
/*
|
||||||
|
// See https://numpy.org/doc/stable/user/basics.broadcasting.html
|
||||||
|
|
||||||
|
This function handles this example:
|
||||||
|
```
|
||||||
|
Image (3d array): 256 x 256 x 3
|
||||||
|
Scale (1d array): 3
|
||||||
|
Result (3d array): 256 x 256 x 3
|
||||||
|
```
|
||||||
|
|
||||||
|
Other interesting examples to consider:
|
||||||
|
- `can_broadcast_shape_to([3], [1, 1, 1, 1, 3]) == true`
|
||||||
|
- `can_broadcast_shape_to([3], [3, 1]) == false`
|
||||||
|
- `can_broadcast_shape_to([256, 256, 3], [256, 1, 3]) == true`
|
||||||
|
|
||||||
|
In cases when the shapes contain zero(es):
|
||||||
|
- `can_broadcast_shape_to([0], [1]) == true`
|
||||||
|
- `can_broadcast_shape_to([0], [2]) == false`
|
||||||
|
- `can_broadcast_shape_to([0, 4, 0, 0], [1]) == true`
|
||||||
|
- `can_broadcast_shape_to([0, 4, 0, 0], [1, 1, 1, 1]) == true`
|
||||||
|
- `can_broadcast_shape_to([0, 4, 0, 0], [1, 4, 1, 1]) == true`
|
||||||
|
- `can_broadcast_shape_to([4, 3], [0, 3]) == false`
|
||||||
|
- `can_broadcast_shape_to([4, 3], [0, 0]) == false`
|
||||||
|
*/
|
||||||
|
|
||||||
|
// This is essentially doing the following in Python:
|
||||||
|
// `for target_dim, src_dim in itertools.zip_longest(target_shape[::-1], src_shape[::-1], fillvalue=1)`
|
||||||
|
for (SizeT i = 0; i < max(target_ndims, src_ndims); i++) {
|
||||||
|
SizeT target_dim_i = target_ndims - i - 1;
|
||||||
|
SizeT src_dim_i = src_ndims - i - 1;
|
||||||
|
|
||||||
|
bool target_dim_exists = target_dim_i >= 0;
|
||||||
|
bool src_dim_exists = src_dim_i >= 0;
|
||||||
|
|
||||||
|
SizeT target_dim = target_dim_exists ? target_shape[target_dim_i] : 1;
|
||||||
|
SizeT src_dim = src_dim_exists ? src_shape[src_dim_i] : 1;
|
||||||
|
|
||||||
|
bool ok = src_dim == 1 || target_dim == src_dim;
|
||||||
|
if (!ok) return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
return true;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
typedef uint8_t NDSliceType;
|
typedef uint8_t NDSliceType;
|
||||||
|
@ -55,7 +117,7 @@ namespace {
|
||||||
|
|
||||||
namespace ndarray_util {
|
namespace ndarray_util {
|
||||||
template<typename SizeT>
|
template<typename SizeT>
|
||||||
SizeT deduce_ndims_after_slicing(SizeT ndims, const SizeT num_slices, const NDSlice *slices) {
|
SizeT deduce_ndims_after_slicing(SizeT ndims, SizeT num_slices, const NDSlice *slices) {
|
||||||
irrt_assert(num_slices <= ndims);
|
irrt_assert(num_slices <= ndims);
|
||||||
|
|
||||||
SizeT final_ndims = ndims;
|
SizeT final_ndims = ndims;
|
||||||
|
@ -150,17 +212,26 @@ namespace {
|
||||||
return this->size() * itemsize;
|
return this->size() * itemsize;
|
||||||
}
|
}
|
||||||
|
|
||||||
void set_value_at_pelement(uint8_t* pelement, uint8_t* pvalue) {
|
void set_value_at_pelement(uint8_t* pelement, const uint8_t* pvalue) {
|
||||||
__builtin_memcpy(pelement, pvalue, itemsize);
|
__builtin_memcpy(pelement, pvalue, itemsize);
|
||||||
}
|
}
|
||||||
|
|
||||||
uint8_t* get_pelement(SizeT *indices) {
|
uint8_t* get_pelement(const SizeT *indices) {
|
||||||
uint8_t* element = data;
|
uint8_t* element = data;
|
||||||
for (SizeT dim_i = 0; dim_i < ndims; dim_i++)
|
for (SizeT dim_i = 0; dim_i < ndims; dim_i++)
|
||||||
element += indices[dim_i] * strides[dim_i];
|
element += indices[dim_i] * strides[dim_i];
|
||||||
return element;
|
return element;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
uint8_t* get_nth_pelement(SizeT nth) {
|
||||||
|
irrt_assert(0 <= nth);
|
||||||
|
irrt_assert(nth < this->size());
|
||||||
|
|
||||||
|
SizeT* indices = (SizeT*) __builtin_alloca(sizeof(SizeT) * this->ndims);
|
||||||
|
ndarray_util::set_indices_by_nth(this->ndims, this->shape, indices, nth);
|
||||||
|
return get_pelement(indices);
|
||||||
|
}
|
||||||
|
|
||||||
// Get pointer to the first element of this ndarray, assuming
|
// Get pointer to the first element of this ndarray, assuming
|
||||||
// `this->size() > 0`, i.e., not "degenerate" due to zeroes in `this->shape`)
|
// `this->size() > 0`, i.e., not "degenerate" due to zeroes in `this->shape`)
|
||||||
//
|
//
|
||||||
|
@ -171,7 +242,7 @@ namespace {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Is the given `indices` valid/in-bounds?
|
// Is the given `indices` valid/in-bounds?
|
||||||
bool in_bounds(SizeT *indices) {
|
bool in_bounds(const SizeT *indices) {
|
||||||
for (SizeT dim_i = 0; dim_i < ndims; dim_i++) {
|
for (SizeT dim_i = 0; dim_i < ndims; dim_i++) {
|
||||||
bool dim_ok = indices[dim_i] < shape[dim_i];
|
bool dim_ok = indices[dim_i] < shape[dim_i];
|
||||||
if (!dim_ok) return false;
|
if (!dim_ok) return false;
|
||||||
|
@ -180,7 +251,7 @@ namespace {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Fill the ndarray with a value
|
// Fill the ndarray with a value
|
||||||
void fill_generic(uint8_t* pvalue) {
|
void fill_generic(const uint8_t* pvalue) {
|
||||||
NDArrayIndicesIter<SizeT> iter;
|
NDArrayIndicesIter<SizeT> iter;
|
||||||
iter.ndims = this->ndims;
|
iter.ndims = this->ndims;
|
||||||
iter.shape = this->shape;
|
iter.shape = this->shape;
|
||||||
|
@ -199,7 +270,7 @@ namespace {
|
||||||
}
|
}
|
||||||
|
|
||||||
// https://numpy.org/doc/stable/reference/generated/numpy.eye.html
|
// https://numpy.org/doc/stable/reference/generated/numpy.eye.html
|
||||||
void set_to_eye(SizeT k, uint8_t* zero_pvalue, uint8_t* one_pvalue) {
|
void set_to_eye(SizeT k, const uint8_t* zero_pvalue, const uint8_t* one_pvalue) {
|
||||||
__builtin_assume(ndims == 2);
|
__builtin_assume(ndims == 2);
|
||||||
|
|
||||||
// TODO: Better implementation
|
// TODO: Better implementation
|
||||||
|
@ -275,6 +346,63 @@ namespace {
|
||||||
|
|
||||||
irrt_assert(dst_axis == dst_ndarray->ndims); // Sanity check on the implementation
|
irrt_assert(dst_axis == dst_ndarray->ndims); // Sanity check on the implementation
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Similar to `np.broadcast_to(<ndarray>, <target_shape>)`
|
||||||
|
// Assumptions:
|
||||||
|
// - `this` has to be fully initialized.
|
||||||
|
// - `dst_ndarray->ndims` has to be set.
|
||||||
|
// - `dst_ndarray->shape` has to be set, this determines the shape `this` broadcasts to.
|
||||||
|
//
|
||||||
|
// Other notes:
|
||||||
|
// - `dst_ndarray->data` does not have to be set, it will be set to `this->data`.
|
||||||
|
// - `dst_ndarray->itemsize` does not have to be set, it will be set to `this->data`.
|
||||||
|
// - `dst_ndarray->strides` does not have to be set, it will be overwritten.
|
||||||
|
//
|
||||||
|
// Cautions:
|
||||||
|
// ```
|
||||||
|
// xs = np.zeros((4,))
|
||||||
|
// ys = np.zero((4, 1))
|
||||||
|
// ys[:] = xs # ok
|
||||||
|
//
|
||||||
|
// xs = np.zeros((1, 4))
|
||||||
|
// ys = np.zero((4,))
|
||||||
|
// ys[:] = xs # allowed
|
||||||
|
// # However `np.broadcast_to(xs, (4,))` would fails, as per numpy's broadcasting rule.
|
||||||
|
// # and apparently numpy will "deprecate" this? SEE https://github.com/numpy/numpy/issues/21744
|
||||||
|
// # This implementation will NOT support this assignment.
|
||||||
|
// ```
|
||||||
|
void broadcast_to(NDArray<SizeT>* dst_ndarray) {
|
||||||
|
dst_ndarray->data = this->data;
|
||||||
|
dst_ndarray->itemsize = this->itemsize;
|
||||||
|
|
||||||
|
irrt_assert(
|
||||||
|
ndarray_util::can_broadcast_shape_to(
|
||||||
|
dst_ndarray->ndims,
|
||||||
|
dst_ndarray->shape,
|
||||||
|
this->ndims,
|
||||||
|
this->shape
|
||||||
|
)
|
||||||
|
);
|
||||||
|
|
||||||
|
SizeT stride_product = 1;
|
||||||
|
for (SizeT i = 0; i < max(this->ndims, dst_ndarray->ndims); i++) {
|
||||||
|
SizeT this_dim_i = this->ndims - i - 1;
|
||||||
|
SizeT dst_dim_i = dst_ndarray->ndims - i - 1;
|
||||||
|
|
||||||
|
bool this_dim_exists = this_dim_i >= 0;
|
||||||
|
bool dst_dim_exists = dst_dim_i >= 0;
|
||||||
|
|
||||||
|
// TODO: Explain how this works
|
||||||
|
bool c1 = this_dim_exists && this->shape[this_dim_i] == 1;
|
||||||
|
bool c2 = dst_dim_exists && dst_ndarray->shape[dst_dim_i] != 1;
|
||||||
|
if (!this_dim_exists || (c1 && c2)) {
|
||||||
|
dst_ndarray->strides[dst_dim_i] = 0; // Freeze it in-place
|
||||||
|
} else {
|
||||||
|
dst_ndarray->strides[dst_dim_i] = stride_product * this->itemsize;
|
||||||
|
stride_product *= this->shape[this_dim_i]; // NOTE: this_dim_exist must be true here.
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -33,10 +33,11 @@ void debug_print_array(const char* format, int len, T* as) {
|
||||||
template <typename T>
|
template <typename T>
|
||||||
void assert_arrays_match(const char* label, const char* format, int len, T* expected, T* got) {
|
void assert_arrays_match(const char* label, const char* format, int len, T* expected, T* got) {
|
||||||
if (!arrays_match(len, expected, got)) {
|
if (!arrays_match(len, expected, got)) {
|
||||||
printf("expected %s: ", label);
|
printf(">>>>>>> %s\n", label);
|
||||||
|
printf(" Expecting = ");
|
||||||
debug_print_array(format, len, expected);
|
debug_print_array(format, len, expected);
|
||||||
printf("\n");
|
printf("\n");
|
||||||
printf("got %s: ", label);
|
printf(" Got = ");
|
||||||
debug_print_array(format, len, got);
|
debug_print_array(format, len, got);
|
||||||
printf("\n");
|
printf("\n");
|
||||||
test_fail();
|
test_fail();
|
||||||
|
@ -46,22 +47,89 @@ void assert_arrays_match(const char* label, const char* format, int len, T* expe
|
||||||
template <typename T>
|
template <typename T>
|
||||||
void assert_values_match(const char* label, const char* format, T expected, T got) {
|
void assert_values_match(const char* label, const char* format, T expected, T got) {
|
||||||
if (expected != got) {
|
if (expected != got) {
|
||||||
printf("expected %s: ", label);
|
printf(">>>>>>> %s\n", label);
|
||||||
|
printf(" Expecting = ");
|
||||||
printf(format, expected);
|
printf(format, expected);
|
||||||
printf("\n");
|
printf("\n");
|
||||||
printf("got %s: ", label);
|
printf(" Got = ");
|
||||||
printf(format, got);
|
printf(format, got);
|
||||||
printf("\n");
|
printf("\n");
|
||||||
test_fail();
|
test_fail();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void print_repeated(const char *str, int count) {
|
||||||
|
for (int i = 0; i < count; i++) {
|
||||||
|
printf("%s", str);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template<typename SizeT, typename ElementT>
|
||||||
|
void __print_ndarray_aux(const char *format, bool first, bool last, SizeT* cursor, SizeT depth, NDArray<SizeT>* ndarray) {
|
||||||
|
// A really lazy recursive implementation
|
||||||
|
|
||||||
|
// Add left padding unless its the first entry (since there would be "[[[" before it)
|
||||||
|
if (!first) {
|
||||||
|
print_repeated(" ", depth);
|
||||||
|
}
|
||||||
|
|
||||||
|
const SizeT dim = ndarray->shape[depth];
|
||||||
|
if (depth + 1 == ndarray->ndims) {
|
||||||
|
// Recursed down to last dimension, print the values in a nice list
|
||||||
|
printf("[");
|
||||||
|
|
||||||
|
SizeT* indices = (SizeT*) __builtin_alloca(sizeof(SizeT) * ndarray->ndims);
|
||||||
|
for (SizeT i = 0; i < dim; i++) {
|
||||||
|
ndarray_util::set_indices_by_nth(ndarray->ndims, ndarray->shape, indices, *cursor);
|
||||||
|
ElementT* pelement = (ElementT*) ndarray->get_pelement(indices);
|
||||||
|
ElementT element = *pelement;
|
||||||
|
|
||||||
|
if (i != 0) printf(", "); // List delimiter
|
||||||
|
printf(format, element);
|
||||||
|
printf("(@");
|
||||||
|
debug_print_array("%d", ndarray->ndims, indices);
|
||||||
|
printf(")");
|
||||||
|
|
||||||
|
(*cursor)++;
|
||||||
|
}
|
||||||
|
printf("]");
|
||||||
|
} else {
|
||||||
|
printf("[");
|
||||||
|
for (SizeT i = 0; i < ndarray->shape[depth]; i++) {
|
||||||
|
__print_ndarray_aux<SizeT, ElementT>(
|
||||||
|
format,
|
||||||
|
i == 0, // first?
|
||||||
|
i + 1 == dim, // last?
|
||||||
|
cursor,
|
||||||
|
depth + 1,
|
||||||
|
ndarray
|
||||||
|
);
|
||||||
|
}
|
||||||
|
printf("]");
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add newline unless its the last entry (since there will be "]]]" after it)
|
||||||
|
if (!last) {
|
||||||
|
print_repeated("\n", depth);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template<typename SizeT, typename ElementT>
|
||||||
|
void print_ndarray(const char *format, NDArray<SizeT>* ndarray) {
|
||||||
|
if (ndarray->ndims == 0) {
|
||||||
|
printf("<empty ndarray>");
|
||||||
|
} else {
|
||||||
|
SizeT cursor = 0;
|
||||||
|
__print_ndarray_aux<SizeT, ElementT>(format, true, true, &cursor, 0, ndarray);
|
||||||
|
}
|
||||||
|
printf("\n");
|
||||||
|
}
|
||||||
|
|
||||||
void test_calc_size_from_shape_normal() {
|
void test_calc_size_from_shape_normal() {
|
||||||
// Test shapes with normal values
|
// Test shapes with normal values
|
||||||
BEGIN_TEST();
|
BEGIN_TEST();
|
||||||
|
|
||||||
int32_t shape[4] = { 2, 3, 5, 7 };
|
int32_t shape[4] = { 2, 3, 5, 7 };
|
||||||
debug_print_array("%d", 4, shape);
|
|
||||||
assert_values_match("size", "%d", 210, ndarray_util::calc_size_from_shape<int32_t>(4, shape));
|
assert_values_match("size", "%d", 210, ndarray_util::calc_size_from_shape<int32_t>(4, shape));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -267,9 +335,6 @@ void test_ndslice_1() {
|
||||||
assert dst_ndarray[0, 1] == 7.0
|
assert dst_ndarray[0, 1] == 7.0
|
||||||
assert dst_ndarray[1, 0] == 9.0
|
assert dst_ndarray[1, 0] == 9.0
|
||||||
assert dst_ndarray[1, 1] == 11.0
|
assert dst_ndarray[1, 1] == 11.0
|
||||||
|
|
||||||
dst_ndarray[1, 0] == 99 # Write to `dst_ndarray`
|
|
||||||
assert ndarray[1, 3] == 99 # `ndarray` also updates!!
|
|
||||||
```
|
```
|
||||||
*/
|
*/
|
||||||
BEGIN_TEST();
|
BEGIN_TEST();
|
||||||
|
@ -410,6 +475,160 @@ void test_ndslice_2() {
|
||||||
assert_values_match("dst_ndarray[1]", "%f", 9.0, *((double *) dst_ndarray.get_pelement((int32_t[dst_ndims]) { 1 })));
|
assert_values_match("dst_ndarray[1]", "%f", 9.0, *((double *) dst_ndarray.get_pelement((int32_t[dst_ndims]) { 1 })));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void test_can_broadcast_shape() {
|
||||||
|
BEGIN_TEST();
|
||||||
|
|
||||||
|
assert_values_match(
|
||||||
|
"can_broadcast_shape_to([3], [1, 1, 1, 1, 3]) == true",
|
||||||
|
"%d",
|
||||||
|
true,
|
||||||
|
ndarray_util::can_broadcast_shape_to(1, (int32_t[]) { 3 }, 5, (int32_t[]) { 1, 1, 1, 1, 3 })
|
||||||
|
);
|
||||||
|
assert_values_match(
|
||||||
|
"can_broadcast_shape_to([3], [3, 1]) == false",
|
||||||
|
"%d",
|
||||||
|
false,
|
||||||
|
ndarray_util::can_broadcast_shape_to(1, (int32_t[]) { 3 }, 2, (int32_t[]) { 3, 1 }));
|
||||||
|
assert_values_match(
|
||||||
|
"can_broadcast_shape_to([3], [3]) == true",
|
||||||
|
"%d",
|
||||||
|
true,
|
||||||
|
ndarray_util::can_broadcast_shape_to(1, (int32_t[]) { 3 }, 1, (int32_t[]) { 3 }));
|
||||||
|
assert_values_match(
|
||||||
|
"can_broadcast_shape_to([1], [3]) == false",
|
||||||
|
"%d",
|
||||||
|
false,
|
||||||
|
ndarray_util::can_broadcast_shape_to(1, (int32_t[]) { 1 }, 1, (int32_t[]) { 3 }));
|
||||||
|
assert_values_match(
|
||||||
|
"can_broadcast_shape_to([1], [1]) == true",
|
||||||
|
"%d",
|
||||||
|
true,
|
||||||
|
ndarray_util::can_broadcast_shape_to(1, (int32_t[]) { 1 }, 1, (int32_t[]) { 1 }));
|
||||||
|
assert_values_match(
|
||||||
|
"can_broadcast_shape_to([256, 256, 3], [256, 1, 3]) == true",
|
||||||
|
"%d",
|
||||||
|
true,
|
||||||
|
ndarray_util::can_broadcast_shape_to(3, (int32_t[]) { 256, 256, 3 }, 3, (int32_t[]) { 256, 1, 3 })
|
||||||
|
);
|
||||||
|
assert_values_match(
|
||||||
|
"can_broadcast_shape_to([256, 256, 3], [3]) == true",
|
||||||
|
"%d",
|
||||||
|
true,
|
||||||
|
ndarray_util::can_broadcast_shape_to(3, (int32_t[]) { 256, 256, 3 }, 1, (int32_t[]) { 3 })
|
||||||
|
);
|
||||||
|
assert_values_match(
|
||||||
|
"can_broadcast_shape_to([256, 256, 3], [2]) == false",
|
||||||
|
"%d",
|
||||||
|
false,
|
||||||
|
ndarray_util::can_broadcast_shape_to(3, (int32_t[]) { 256, 256, 3 }, 1, (int32_t[]) { 2 })
|
||||||
|
);
|
||||||
|
assert_values_match(
|
||||||
|
"can_broadcast_shape_to([256, 256, 3], [1]) == true",
|
||||||
|
"%d",
|
||||||
|
true,
|
||||||
|
ndarray_util::can_broadcast_shape_to(3, (int32_t[]) { 256, 256, 3 }, 1, (int32_t[]) { 1 })
|
||||||
|
);
|
||||||
|
|
||||||
|
// In cases when the shapes contain zero(es)
|
||||||
|
assert_values_match(
|
||||||
|
"can_broadcast_shape_to([0], [1]) == true",
|
||||||
|
"%d",
|
||||||
|
true,
|
||||||
|
ndarray_util::can_broadcast_shape_to(1, (int32_t[]) { 0 }, 1, (int32_t[]) { 1 })
|
||||||
|
);
|
||||||
|
assert_values_match(
|
||||||
|
"can_broadcast_shape_to([0], [2]) == false",
|
||||||
|
"%d",
|
||||||
|
false,
|
||||||
|
ndarray_util::can_broadcast_shape_to(1, (int32_t[]) { 0 }, 1, (int32_t[]) { 2 })
|
||||||
|
);
|
||||||
|
assert_values_match(
|
||||||
|
"can_broadcast_shape_to([0, 4, 0, 0], [1]) == true",
|
||||||
|
"%d",
|
||||||
|
true,
|
||||||
|
ndarray_util::can_broadcast_shape_to(4, (int32_t[]) { 0, 4, 0, 0 }, 1, (int32_t[]) { 1 })
|
||||||
|
);
|
||||||
|
assert_values_match(
|
||||||
|
"can_broadcast_shape_to([0, 4, 0, 0], [1, 1, 1, 1]) == true",
|
||||||
|
"%d",
|
||||||
|
true,
|
||||||
|
ndarray_util::can_broadcast_shape_to(4, (int32_t[]) { 0, 4, 0, 0 }, 4, (int32_t[]) { 1, 1, 1, 1 })
|
||||||
|
);
|
||||||
|
assert_values_match(
|
||||||
|
"can_broadcast_shape_to([0, 4, 0, 0], [1, 4, 1, 1]) == true",
|
||||||
|
"%d",
|
||||||
|
true,
|
||||||
|
ndarray_util::can_broadcast_shape_to(4, (int32_t[]) { 0, 4, 0, 0 }, 4, (int32_t[]) { 1, 4, 1, 1 })
|
||||||
|
);
|
||||||
|
assert_values_match(
|
||||||
|
"can_broadcast_shape_to([4, 3], [0, 3]) == false",
|
||||||
|
"%d",
|
||||||
|
false,
|
||||||
|
ndarray_util::can_broadcast_shape_to(2, (int32_t[]) { 4, 3 }, 2, (int32_t[]) { 0, 3 })
|
||||||
|
);
|
||||||
|
assert_values_match(
|
||||||
|
"can_broadcast_shape_to([4, 3], [0, 0]) == false",
|
||||||
|
"%d",
|
||||||
|
false,
|
||||||
|
ndarray_util::can_broadcast_shape_to(2, (int32_t[]) { 4, 3 }, 2, (int32_t[]) { 0, 0 })
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
void test_ndarray_broadcast_1() {
|
||||||
|
/*
|
||||||
|
# array = np.array([[19.9, 29.9, 39.9, 49.9]], dtype=np.float64)
|
||||||
|
# >>> [[19.9 29.9 39.9 49.9]]
|
||||||
|
#
|
||||||
|
# array = np.broadcast_to(array, (2, 3, 4))
|
||||||
|
# >>> [[[19.9 29.9 39.9 49.9]
|
||||||
|
# >>> [19.9 29.9 39.9 49.9]
|
||||||
|
# >>> [19.9 29.9 39.9 49.9]]
|
||||||
|
# >>> [[19.9 29.9 39.9 49.9]
|
||||||
|
# >>> [19.9 29.9 39.9 49.9]
|
||||||
|
# >>> [19.9 29.9 39.9 49.9]]]
|
||||||
|
#
|
||||||
|
# assery array.strides == (0, 0, 8)
|
||||||
|
|
||||||
|
*/
|
||||||
|
BEGIN_TEST();
|
||||||
|
|
||||||
|
double in_data[4] = { 19.9, 29.9, 39.9, 49.9 };
|
||||||
|
const int32_t in_ndims = 2;
|
||||||
|
int32_t in_shape[in_ndims] = {1, 4};
|
||||||
|
int32_t in_strides[in_ndims] = {};
|
||||||
|
NDArray<int32_t> ndarray = {
|
||||||
|
.data = (uint8_t*) in_data,
|
||||||
|
.itemsize = sizeof(double),
|
||||||
|
.ndims = in_ndims,
|
||||||
|
.shape = in_shape,
|
||||||
|
.strides = in_strides
|
||||||
|
};
|
||||||
|
ndarray.set_strides_by_shape();
|
||||||
|
|
||||||
|
const int32_t dst_ndims = 3;
|
||||||
|
int32_t dst_shape[dst_ndims] = {2, 3, 4};
|
||||||
|
int32_t dst_strides[dst_ndims] = {};
|
||||||
|
NDArray<int32_t> dst_ndarray = {
|
||||||
|
.ndims = dst_ndims,
|
||||||
|
.shape = dst_shape,
|
||||||
|
.strides = dst_strides
|
||||||
|
};
|
||||||
|
|
||||||
|
ndarray.broadcast_to(&dst_ndarray);
|
||||||
|
|
||||||
|
assert_arrays_match("dst_ndarray->strides", "%d", dst_ndims, (int32_t[]) { 0, 0, 8 }, dst_ndarray.strides);
|
||||||
|
|
||||||
|
assert_values_match("dst_ndarray[0, 0, 0]", "%f", 19.9, *((double*) dst_ndarray.get_pelement((int32_t[]) {0, 0, 0})));
|
||||||
|
assert_values_match("dst_ndarray[0, 0, 1]", "%f", 29.9, *((double*) dst_ndarray.get_pelement((int32_t[]) {0, 0, 1})));
|
||||||
|
assert_values_match("dst_ndarray[0, 0, 2]", "%f", 39.9, *((double*) dst_ndarray.get_pelement((int32_t[]) {0, 0, 2})));
|
||||||
|
assert_values_match("dst_ndarray[0, 0, 3]", "%f", 49.9, *((double*) dst_ndarray.get_pelement((int32_t[]) {0, 0, 3})));
|
||||||
|
assert_values_match("dst_ndarray[0, 1, 0]", "%f", 19.9, *((double*) dst_ndarray.get_pelement((int32_t[]) {0, 1, 0})));
|
||||||
|
assert_values_match("dst_ndarray[0, 1, 1]", "%f", 29.9, *((double*) dst_ndarray.get_pelement((int32_t[]) {0, 1, 1})));
|
||||||
|
assert_values_match("dst_ndarray[0, 1, 2]", "%f", 39.9, *((double*) dst_ndarray.get_pelement((int32_t[]) {0, 1, 2})));
|
||||||
|
assert_values_match("dst_ndarray[0, 1, 3]", "%f", 49.9, *((double*) dst_ndarray.get_pelement((int32_t[]) {0, 1, 3})));
|
||||||
|
assert_values_match("dst_ndarray[1, 2, 3]", "%f", 49.9, *((double*) dst_ndarray.get_pelement((int32_t[]) {1, 2, 3})));
|
||||||
|
}
|
||||||
|
|
||||||
int main() {
|
int main() {
|
||||||
test_calc_size_from_shape_normal();
|
test_calc_size_from_shape_normal();
|
||||||
test_calc_size_from_shape_has_zero();
|
test_calc_size_from_shape_has_zero();
|
||||||
|
@ -423,5 +642,7 @@ int main() {
|
||||||
test_slice_4();
|
test_slice_4();
|
||||||
test_ndslice_1();
|
test_ndslice_1();
|
||||||
test_ndslice_2();
|
test_ndslice_2();
|
||||||
|
test_can_broadcast_shape();
|
||||||
|
test_ndarray_broadcast_1();
|
||||||
return 0;
|
return 0;
|
||||||
}
|
}
|
|
@ -30,6 +30,7 @@ namespace {
|
||||||
*death = 0; // TODO: address 0 on hardware might be writable?
|
*death = 0; // TODO: address 0 on hardware might be writable?
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TODO: Make this a macro and allow it to be toggled on/off (e.g., debug vs release)
|
||||||
void irrt_assert(bool condition) {
|
void irrt_assert(bool condition) {
|
||||||
if (!condition) irrt_panic();
|
if (!condition) irrt_panic();
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue