forked from M-Labs/nac3
core/ndstrides: implement numpy broadcasting IRRT
This commit is contained in:
parent
7501a086d0
commit
eb295cf7e4
|
@ -0,0 +1,221 @@
|
|||
#pragma once
|
||||
|
||||
#include <irrt/error_context.hpp>
|
||||
#include <irrt/int_defs.hpp>
|
||||
#include <irrt/ndarray/def.hpp>
|
||||
#include <irrt/slice.hpp>
|
||||
|
||||
namespace {
|
||||
template <typename SizeT>
|
||||
struct ShapeEntry {
|
||||
SizeT ndims;
|
||||
SizeT* shape;
|
||||
};
|
||||
} // namespace
|
||||
|
||||
namespace {
|
||||
namespace ndarray {
|
||||
namespace broadcast {
|
||||
namespace util {
|
||||
/**
|
||||
* @brief Return true if `src_shape` can broadcast to `dst_shape`.
|
||||
*/
|
||||
template <typename SizeT>
|
||||
bool can_broadcast_shape_to(SizeT target_ndims, const SizeT* target_shape,
|
||||
SizeT src_ndims, const SizeT* src_shape) {
|
||||
/*
|
||||
* // See https://numpy.org/doc/stable/user/basics.broadcasting.html
|
||||
|
||||
* This function handles this example:
|
||||
* ```
|
||||
* Image (3d array): 256 x 256 x 3
|
||||
* Scale (1d array): 3
|
||||
* Result (3d array): 256 x 256 x 3
|
||||
* ```
|
||||
|
||||
* Other interesting examples to consider:
|
||||
* - `can_broadcast_shape_to([3], [1, 1, 1, 1, 3]) == true`
|
||||
* - `can_broadcast_shape_to([3], [3, 1]) == false`
|
||||
* - `can_broadcast_shape_to([256, 256, 3], [256, 1, 3]) == true`
|
||||
|
||||
* In cases when the shapes contain zero(es):
|
||||
* - `can_broadcast_shape_to([0], [1]) == true`
|
||||
* - `can_broadcast_shape_to([0], [2]) == false`
|
||||
* - `can_broadcast_shape_to([0, 4, 0, 0], [1]) == true`
|
||||
* - `can_broadcast_shape_to([0, 4, 0, 0], [1, 1, 1, 1]) == true`
|
||||
* - `can_broadcast_shape_to([0, 4, 0, 0], [1, 4, 1, 1]) == true`
|
||||
* - `can_broadcast_shape_to([4, 3], [0, 3]) == false`
|
||||
* - `can_broadcast_shape_to([4, 3], [0, 0]) == false`
|
||||
*/
|
||||
|
||||
// This is essentially doing the following in Python:
|
||||
// `for target_dim, src_dim in itertools.zip_longest(target_shape[::-1], src_shape[::-1], fillvalue=1)`
|
||||
for (SizeT i = 0; i < max(target_ndims, src_ndims); i++) {
|
||||
SizeT target_dim_i = target_ndims - i - 1;
|
||||
SizeT src_dim_i = src_ndims - i - 1;
|
||||
|
||||
bool target_dim_exists = target_dim_i >= 0;
|
||||
bool src_dim_exists = src_dim_i >= 0;
|
||||
|
||||
SizeT target_dim = target_dim_exists ? target_shape[target_dim_i] : 1;
|
||||
SizeT src_dim = src_dim_exists ? src_shape[src_dim_i] : 1;
|
||||
|
||||
bool ok = src_dim == 1 || target_dim == src_dim;
|
||||
if (!ok) return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Performs `np.broadcast_shapes`
|
||||
*/
|
||||
template <typename SizeT>
|
||||
void broadcast_shapes(ErrorContext* errctx, SizeT num_shapes,
|
||||
const ShapeEntry<SizeT>* shapes, SizeT dst_ndims,
|
||||
SizeT* dst_shape) {
|
||||
// `dst_ndims` must be `max([shape.ndims for shape in shapes])`, but the caller has to calculate it/provide it
|
||||
// for this function since it should already know in order to allocate `dst_shape` in the first place.
|
||||
// `dst_shape` must be pre-allocated.
|
||||
// `dst_shape` does not have to be initialized
|
||||
|
||||
// TODO: Implementation is not obvious
|
||||
|
||||
// This is essentially a `mconcat` where the neutral element is `[1, 1, 1, 1, ...]`, and the operation is commutative.
|
||||
|
||||
// Set `dst_shape` to all `1`s.
|
||||
for (SizeT dst_axis = 0; dst_axis < dst_ndims; dst_axis++) {
|
||||
dst_shape[dst_axis] = 0;
|
||||
}
|
||||
|
||||
for (SizeT i = 0; i < num_shapes; i++) {
|
||||
ShapeEntry<SizeT> entry = shapes[i];
|
||||
SizeT entry_axis = entry.ndims - i;
|
||||
SizeT dst_axis = dst_ndims - i;
|
||||
|
||||
SizeT entry_dim = entry.shape[entry_axis];
|
||||
SizeT dst_dim = dst_shape[dst_axis];
|
||||
|
||||
if (dst_dim == 1) {
|
||||
dst_shape[dst_axis] = entry_dim;
|
||||
} else if (entry_dim == 1) {
|
||||
// Do nothing
|
||||
} else if (entry_dim == dst_dim) {
|
||||
// Do nothing
|
||||
} else {
|
||||
errctx->set_exception(errctx->exceptions->value_error,
|
||||
"shape mismatch: objects cannot be broadcast "
|
||||
"to a single shape.");
|
||||
return;
|
||||
}
|
||||
}
|
||||
}
|
||||
} // namespace util
|
||||
|
||||
/**
|
||||
* @brief Perform `np.broadcast_to(<ndarray>, <target_shape>)` and appropriate assertions.
|
||||
*
|
||||
* Cautious note on https://github.com/numpy/numpy/issues/21744..
|
||||
*
|
||||
* This function attempts to broadcast `src_ndarray` to a new shape defined by `dst_ndarray.shape`,
|
||||
* and return the result by modifying `dst_ndarray`.
|
||||
*
|
||||
* # Notes on `dst_ndarray`
|
||||
* The caller is responsible for allocating space for the resulting ndarray.
|
||||
* Here is what this function expects from `dst_ndarray` when called:
|
||||
* - `dst_ndarray->data` does not have to be initialized.
|
||||
* - `dst_ndarray->itemsize` does not have to be initialized.
|
||||
* - `dst_ndarray->ndims` must be initialized, determining the length of `dst_ndarray->shape`
|
||||
* - `dst_ndarray->shape` must be allocated, and must contain the desired target broadcast shape.
|
||||
* - `dst_ndarray->strides` must be allocated, through it can contain uninitialized values.
|
||||
* When this function call ends:
|
||||
* - `dst_ndarray->data` is set to `src_ndarray->data` (`dst_ndarray` is just a view to `src_ndarray`)
|
||||
* - `dst_ndarray->itemsize` is set to `src_ndarray->itemsize`
|
||||
* - `dst_ndarray->ndims` is unchanged.
|
||||
* - `dst_ndarray->shape` is unchanged.
|
||||
* - `dst_ndarray->strides` is updated accordingly by how ndarray broadcast_to works.
|
||||
*/
|
||||
template <typename SizeT>
|
||||
void broadcast_to(ErrorContext* errctx, const NDArray<SizeT>* src_ndarray,
|
||||
NDArray<SizeT>* dst_ndarray) {
|
||||
/*
|
||||
* Cautions:
|
||||
* ```
|
||||
* xs = np.zeros((4,))
|
||||
* ys = np.zero((4, 1))
|
||||
* ys[:] = xs # ok
|
||||
*
|
||||
* xs = np.zeros((1, 4))
|
||||
* ys = np.zero((4,))
|
||||
* ys[:] = xs # allowed
|
||||
* # However `np.broadcast_to(xs, (4,))` would fails, as per numpy's broadcasting rule.
|
||||
* # and apparently numpy will "deprecate" this? SEE https://github.com/numpy/numpy/issues/21744
|
||||
* # This implementation will NOT support this assignment.
|
||||
* ```
|
||||
*/
|
||||
|
||||
if (!ndarray::broadcast::util::can_broadcast_shape_to(
|
||||
dst_ndarray->ndims, dst_ndarray->shape, src_ndarray->ndims,
|
||||
src_ndarray->shape)) {
|
||||
errctx->set_exception(errctx->exceptions->value_error,
|
||||
"operands could not be broadcast together");
|
||||
return;
|
||||
}
|
||||
|
||||
dst_ndarray->data = src_ndarray->data;
|
||||
dst_ndarray->itemsize = src_ndarray->itemsize;
|
||||
|
||||
// TODO: Implementation is not obvious
|
||||
SizeT stride_product = 1;
|
||||
for (SizeT i = 0; i < max(src_ndarray->ndims, dst_ndarray->ndims); i++) {
|
||||
SizeT src_ndarray_dim_i = src_ndarray->ndims - i - 1;
|
||||
SizeT dst_dim_i = dst_ndarray->ndims - i - 1;
|
||||
|
||||
bool src_ndarray_dim_exists = src_ndarray_dim_i >= 0;
|
||||
bool dst_dim_exists = dst_dim_i >= 0;
|
||||
|
||||
bool c1 = src_ndarray_dim_exists &&
|
||||
src_ndarray->shape[src_ndarray_dim_i] == 1;
|
||||
bool c2 = dst_dim_exists && dst_ndarray->shape[dst_dim_i] != 1;
|
||||
if (!src_ndarray_dim_exists || (c1 && c2)) {
|
||||
dst_ndarray->strides[dst_dim_i] = 0; // Freeze it in-place
|
||||
} else {
|
||||
dst_ndarray->strides[dst_dim_i] =
|
||||
stride_product * src_ndarray->itemsize;
|
||||
stride_product *= src_ndarray->shape[src_ndarray_dim_i];
|
||||
}
|
||||
}
|
||||
}
|
||||
} // namespace broadcast
|
||||
} // namespace ndarray
|
||||
} // namespace
|
||||
|
||||
extern "C" {
|
||||
using namespace ndarray::broadcast;
|
||||
|
||||
void __nac3_ndarray_broadcast_to(ErrorContext* errctx,
|
||||
NDArray<int32_t>* src_ndarray,
|
||||
NDArray<int32_t>* dst_ndarray) {
|
||||
broadcast_to(errctx, src_ndarray, dst_ndarray);
|
||||
}
|
||||
|
||||
void __nac3_ndarray_broadcast_to64(ErrorContext* errctx,
|
||||
NDArray<int64_t>* src_ndarray,
|
||||
NDArray<int64_t>* dst_ndarray) {
|
||||
broadcast_to(errctx, src_ndarray, dst_ndarray);
|
||||
}
|
||||
|
||||
void __nac3_ndarray_broadcast_shapes(ErrorContext* errctx, int32_t num_shapes,
|
||||
const ShapeEntry<int32_t>* shapes,
|
||||
int32_t dst_ndims, int32_t* dst_shape) {
|
||||
ndarray::broadcast::util::broadcast_shapes(errctx, num_shapes, shapes,
|
||||
dst_ndims, dst_shape);
|
||||
}
|
||||
|
||||
void __nac3_ndarray_broadcast_shapes64(ErrorContext* errctx, int64_t num_shapes,
|
||||
const ShapeEntry<int64_t>* shapes,
|
||||
int64_t dst_ndims, int64_t* dst_shape) {
|
||||
ndarray::broadcast::util::broadcast_shapes(errctx, num_shapes, shapes,
|
||||
dst_ndims, dst_shape);
|
||||
}
|
||||
}
|
|
@ -5,6 +5,7 @@
|
|||
#include <irrt/error_context.hpp>
|
||||
#include <irrt/int_defs.hpp>
|
||||
#include <irrt/ndarray/basic.hpp>
|
||||
#include <irrt/ndarray/broadcast.hpp>
|
||||
#include <irrt/ndarray/def.hpp>
|
||||
#include <irrt/ndarray/indexing.hpp>
|
||||
#include <irrt/ndarray/reshape.hpp>
|
||||
|
|
|
@ -6,6 +6,7 @@
|
|||
#include <cstdlib>
|
||||
#include <test/test_core.hpp>
|
||||
#include <test/test_ndarray_basic.hpp>
|
||||
#include <test/test_ndarray_broadcast.hpp>
|
||||
#include <test/test_ndarray_indexing.hpp>
|
||||
#include <test/test_slice.hpp>
|
||||
|
||||
|
@ -14,5 +15,6 @@ int main() {
|
|||
test::slice::run();
|
||||
test::ndarray_basic::run();
|
||||
test::ndarray_indexing::run();
|
||||
test::ndarray_broadcast::run();
|
||||
return 0;
|
||||
}
|
|
@ -0,0 +1,129 @@
|
|||
#pragma once
|
||||
|
||||
#include <test/includes.hpp>
|
||||
|
||||
namespace test {
|
||||
namespace ndarray_broadcast {
|
||||
void test_can_broadcast_shape() {
|
||||
BEGIN_TEST();
|
||||
|
||||
assert_values_match(true,
|
||||
ndarray::broadcast::util::can_broadcast_shape_to(
|
||||
1, (int32_t[]){3}, 5, (int32_t[]){1, 1, 1, 1, 3}));
|
||||
assert_values_match(false, ndarray::broadcast::util::can_broadcast_shape_to(
|
||||
1, (int32_t[]){3}, 2, (int32_t[]){3, 1}));
|
||||
assert_values_match(true, ndarray::broadcast::util::can_broadcast_shape_to(
|
||||
1, (int32_t[]){3}, 1, (int32_t[]){3}));
|
||||
assert_values_match(false, ndarray::broadcast::util::can_broadcast_shape_to(
|
||||
1, (int32_t[]){1}, 1, (int32_t[]){3}));
|
||||
assert_values_match(true, ndarray::broadcast::util::can_broadcast_shape_to(
|
||||
1, (int32_t[]){1}, 1, (int32_t[]){1}));
|
||||
assert_values_match(
|
||||
true, ndarray::broadcast::util::can_broadcast_shape_to(
|
||||
3, (int32_t[]){256, 256, 3}, 3, (int32_t[]){256, 1, 3}));
|
||||
assert_values_match(true,
|
||||
ndarray::broadcast::util::can_broadcast_shape_to(
|
||||
3, (int32_t[]){256, 256, 3}, 1, (int32_t[]){3}));
|
||||
assert_values_match(false,
|
||||
ndarray::broadcast::util::can_broadcast_shape_to(
|
||||
3, (int32_t[]){256, 256, 3}, 1, (int32_t[]){2}));
|
||||
assert_values_match(true,
|
||||
ndarray::broadcast::util::can_broadcast_shape_to(
|
||||
3, (int32_t[]){256, 256, 3}, 1, (int32_t[]){1}));
|
||||
|
||||
// In cases when the shapes contain zero(es)
|
||||
assert_values_match(true, ndarray::broadcast::util::can_broadcast_shape_to(
|
||||
1, (int32_t[]){0}, 1, (int32_t[]){1}));
|
||||
assert_values_match(false, ndarray::broadcast::util::can_broadcast_shape_to(
|
||||
1, (int32_t[]){0}, 1, (int32_t[]){2}));
|
||||
assert_values_match(true,
|
||||
ndarray::broadcast::util::can_broadcast_shape_to(
|
||||
4, (int32_t[]){0, 4, 0, 0}, 1, (int32_t[]){1}));
|
||||
assert_values_match(
|
||||
true, ndarray::broadcast::util::can_broadcast_shape_to(
|
||||
4, (int32_t[]){0, 4, 0, 0}, 4, (int32_t[]){1, 1, 1, 1}));
|
||||
assert_values_match(
|
||||
true, ndarray::broadcast::util::can_broadcast_shape_to(
|
||||
4, (int32_t[]){0, 4, 0, 0}, 4, (int32_t[]){1, 4, 1, 1}));
|
||||
assert_values_match(false, ndarray::broadcast::util::can_broadcast_shape_to(
|
||||
2, (int32_t[]){4, 3}, 2, (int32_t[]){0, 3}));
|
||||
assert_values_match(false, ndarray::broadcast::util::can_broadcast_shape_to(
|
||||
2, (int32_t[]){4, 3}, 2, (int32_t[]){0, 0}));
|
||||
}
|
||||
|
||||
void test_ndarray_broadcast() {
|
||||
/*
|
||||
# array = np.array([[19.9, 29.9, 39.9, 49.9]], dtype=np.float64)
|
||||
# >>> [[19.9 29.9 39.9 49.9]]
|
||||
#
|
||||
# array = np.broadcast_to(array, (2, 3, 4))
|
||||
# >>> [[[19.9 29.9 39.9 49.9]
|
||||
# >>> [19.9 29.9 39.9 49.9]
|
||||
# >>> [19.9 29.9 39.9 49.9]]
|
||||
# >>> [[19.9 29.9 39.9 49.9]
|
||||
# >>> [19.9 29.9 39.9 49.9]
|
||||
# >>> [19.9 29.9 39.9 49.9]]]
|
||||
#
|
||||
# assery array.strides == (0, 0, 8)
|
||||
|
||||
*/
|
||||
BEGIN_TEST();
|
||||
|
||||
double in_data[4] = {19.9, 29.9, 39.9, 49.9};
|
||||
const int32_t in_ndims = 2;
|
||||
int32_t in_shape[in_ndims] = {1, 4};
|
||||
int32_t in_strides[in_ndims] = {};
|
||||
NDArray<int32_t> ndarray = {.data = (uint8_t*)in_data,
|
||||
.itemsize = sizeof(double),
|
||||
.ndims = in_ndims,
|
||||
.shape = in_shape,
|
||||
.strides = in_strides};
|
||||
ndarray::basic::set_strides_by_shape(&ndarray);
|
||||
|
||||
const int32_t dst_ndims = 3;
|
||||
int32_t dst_shape[dst_ndims] = {2, 3, 4};
|
||||
int32_t dst_strides[dst_ndims] = {};
|
||||
NDArray<int32_t> dst_ndarray = {
|
||||
.ndims = dst_ndims, .shape = dst_shape, .strides = dst_strides};
|
||||
|
||||
ErrorContext errctx = create_testing_errctx();
|
||||
ndarray::broadcast::broadcast_to(&errctx, &ndarray, &dst_ndarray);
|
||||
assert_errctx_no_exception(&errctx);
|
||||
|
||||
assert_arrays_match(dst_ndims, ((int32_t[]){0, 0, 8}), dst_ndarray.strides);
|
||||
|
||||
assert_values_match(19.9,
|
||||
*((double*)ndarray::basic::get_pelement_by_indices(
|
||||
&dst_ndarray, ((int32_t[]){0, 0, 0}))));
|
||||
assert_values_match(29.9,
|
||||
*((double*)ndarray::basic::get_pelement_by_indices(
|
||||
&dst_ndarray, ((int32_t[]){0, 0, 1}))));
|
||||
assert_values_match(39.9,
|
||||
*((double*)ndarray::basic::get_pelement_by_indices(
|
||||
&dst_ndarray, ((int32_t[]){0, 0, 2}))));
|
||||
assert_values_match(49.9,
|
||||
*((double*)ndarray::basic::get_pelement_by_indices(
|
||||
&dst_ndarray, ((int32_t[]){0, 0, 3}))));
|
||||
assert_values_match(19.9,
|
||||
*((double*)ndarray::basic::get_pelement_by_indices(
|
||||
&dst_ndarray, ((int32_t[]){0, 1, 0}))));
|
||||
assert_values_match(29.9,
|
||||
*((double*)ndarray::basic::get_pelement_by_indices(
|
||||
&dst_ndarray, ((int32_t[]){0, 1, 1}))));
|
||||
assert_values_match(39.9,
|
||||
*((double*)ndarray::basic::get_pelement_by_indices(
|
||||
&dst_ndarray, ((int32_t[]){0, 1, 2}))));
|
||||
assert_values_match(49.9,
|
||||
*((double*)ndarray::basic::get_pelement_by_indices(
|
||||
&dst_ndarray, ((int32_t[]){0, 1, 3}))));
|
||||
assert_values_match(49.9,
|
||||
*((double*)ndarray::basic::get_pelement_by_indices(
|
||||
&dst_ndarray, ((int32_t[]){1, 2, 3}))));
|
||||
}
|
||||
|
||||
void run() {
|
||||
test_can_broadcast_shape();
|
||||
test_ndarray_broadcast();
|
||||
}
|
||||
} // namespace ndarray_broadcast
|
||||
} // namespace test
|
|
@ -0,0 +1,74 @@
|
|||
use crate::codegen::{
|
||||
irrt::{
|
||||
error_context::{check_error_context, setup_error_context},
|
||||
util::{function::CallFunction, get_sizet_dependent_function_name},
|
||||
},
|
||||
model::*,
|
||||
structure::ndarray::NpArray,
|
||||
CodeGenContext, CodeGenerator,
|
||||
};
|
||||
|
||||
pub fn call_nac3_ndarray_broadcast_to<'ctx, G: CodeGenerator + ?Sized>(
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
src_ndarray: Ptr<'ctx, StructModel<NpArray>>,
|
||||
dst_ndarray: Ptr<'ctx, StructModel<NpArray>>,
|
||||
) {
|
||||
let tyctx = generator.type_context(ctx.ctx);
|
||||
|
||||
let perrctx = setup_error_context(tyctx, ctx);
|
||||
CallFunction::begin(
|
||||
tyctx,
|
||||
ctx,
|
||||
&get_sizet_dependent_function_name(tyctx, "__nac3_ndarray_broadcast_to"),
|
||||
)
|
||||
.arg("errctx", perrctx)
|
||||
.arg("src_ndarray", src_ndarray)
|
||||
.arg("dst_ndarray", dst_ndarray)
|
||||
.returning_void();
|
||||
|
||||
check_error_context(generator, ctx, perrctx);
|
||||
}
|
||||
|
||||
/// Fields of [`ShapeEntry`]
|
||||
pub struct ShapeEntryFields<F: FieldVisitor> {
|
||||
pub ndims: F::Field<IntModel<SizeT>>,
|
||||
pub shape: F::Field<PtrModel<IntModel<SizeT>>>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy, Default)]
|
||||
pub struct ShapeEntry;
|
||||
|
||||
impl StructKind for ShapeEntry {
|
||||
type Fields<F: FieldVisitor> = ShapeEntryFields<F>;
|
||||
|
||||
fn visit_fields<F: FieldVisitor>(&self, visitor: &mut F) -> Self::Fields<F> {
|
||||
Self::Fields { ndims: visitor.add("ndims"), shape: visitor.add("shape") }
|
||||
}
|
||||
}
|
||||
|
||||
pub fn call_nac3_ndarray_broadcast_shapes<'ctx, G: CodeGenerator + ?Sized>(
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
num_shape_entries: Int<'ctx, SizeT>,
|
||||
shape_entries: Ptr<'ctx, StructModel<ShapeEntry>>,
|
||||
dst_ndims: Int<'ctx, SizeT>,
|
||||
dst_shape: Ptr<'ctx, IntModel<SizeT>>,
|
||||
) {
|
||||
let tyctx = generator.type_context(ctx.ctx);
|
||||
|
||||
let perrctx = setup_error_context(tyctx, ctx);
|
||||
CallFunction::begin(
|
||||
tyctx,
|
||||
ctx,
|
||||
&get_sizet_dependent_function_name(tyctx, "__nac3_ndarray_broadcast_shapes"),
|
||||
)
|
||||
.arg("errctx", perrctx)
|
||||
.arg("num_shapes", num_shape_entries)
|
||||
.arg("shapes", shape_entries)
|
||||
.arg("dst_ndims", dst_ndims)
|
||||
.arg("dst_shape", dst_shape)
|
||||
.returning_void();
|
||||
|
||||
check_error_context(generator, ctx, perrctx);
|
||||
}
|
|
@ -1,3 +1,4 @@
|
|||
pub mod basic;
|
||||
pub mod broadcast;
|
||||
pub mod indexing;
|
||||
pub mod reshape;
|
||||
|
|
Loading…
Reference in New Issue