From eb295cf7e4f0c31d754320cc817cf0700fa37a89 Mon Sep 17 00:00:00 2001 From: lyken Date: Tue, 30 Jul 2024 16:02:46 +0800 Subject: [PATCH] core/ndstrides: implement numpy broadcasting IRRT --- nac3core/irrt/irrt/ndarray/broadcast.hpp | 221 ++++++++++++++++++ nac3core/irrt/irrt_everything.hpp | 1 + nac3core/irrt/irrt_test.cpp | 2 + nac3core/irrt/test/test_ndarray_broadcast.hpp | 129 ++++++++++ .../src/codegen/irrt/ndarray/broadcast.rs | 74 ++++++ nac3core/src/codegen/irrt/ndarray/mod.rs | 1 + 6 files changed, 428 insertions(+) create mode 100644 nac3core/irrt/irrt/ndarray/broadcast.hpp create mode 100644 nac3core/irrt/test/test_ndarray_broadcast.hpp create mode 100644 nac3core/src/codegen/irrt/ndarray/broadcast.rs diff --git a/nac3core/irrt/irrt/ndarray/broadcast.hpp b/nac3core/irrt/irrt/ndarray/broadcast.hpp new file mode 100644 index 00000000..1dcbc577 --- /dev/null +++ b/nac3core/irrt/irrt/ndarray/broadcast.hpp @@ -0,0 +1,221 @@ +#pragma once + +#include +#include +#include +#include + +namespace { +/** + * @brief One input shape for `broadcast_shapes`: `shape` points to `ndims` dimension sizes. + */ +template +struct ShapeEntry { + SizeT ndims; + SizeT* shape; +}; +} // namespace + +namespace { +namespace ndarray { +namespace broadcast { +namespace util { +/** + * @brief Return true if `src_shape` can broadcast to `target_shape`. + * + * Axes are paired starting from the last one; an axis missing on either side counts as 1. + * A pair is compatible when the source dimension is 1 or equals the target dimension. 
+ */
+template
+bool can_broadcast_shape_to(SizeT target_ndims, const SizeT* target_shape,
+                            SizeT src_ndims, const SizeT* src_shape) {
+    /*
+     * // See https://numpy.org/doc/stable/user/basics.broadcasting.html
+
+     * This function handles this example:
+     * ```
+     * Image (3d array): 256 x 256 x 3
+     * Scale (1d array): 3
+     * Result (3d array): 256 x 256 x 3
+     * ```
+
+     * Other interesting examples to consider (written as `(target_shape, src_shape)`):
+     * - `can_broadcast_shape_to([3], [1, 1, 1, 1, 3]) == true`
+     * - `can_broadcast_shape_to([3], [3, 1]) == false`
+     * - `can_broadcast_shape_to([256, 256, 3], [256, 1, 3]) == true`
+
+     * In cases when the shapes contain zero(es):
+     * - `can_broadcast_shape_to([0], [1]) == true`
+     * - `can_broadcast_shape_to([0], [2]) == false`
+     * - `can_broadcast_shape_to([0, 4, 0, 0], [1]) == true`
+     * - `can_broadcast_shape_to([0, 4, 0, 0], [1, 1, 1, 1]) == true`
+     * - `can_broadcast_shape_to([0, 4, 0, 0], [1, 4, 1, 1]) == true`
+     * - `can_broadcast_shape_to([4, 3], [0, 3]) == false`
+     * - `can_broadcast_shape_to([4, 3], [0, 0]) == false`
+     */
+
+    // This is essentially doing the following in Python:
+    // `for target_dim, src_dim in itertools.zip_longest(target_shape[::-1], src_shape[::-1], fillvalue=1)`
+    for (SizeT i = 0; i < max(target_ndims, src_ndims); i++) {
+        SizeT target_dim_i = target_ndims - i - 1;
+        SizeT src_dim_i = src_ndims - i - 1;
+
+        bool target_dim_exists = target_dim_i >= 0;
+        bool src_dim_exists = src_dim_i >= 0;
+
+        SizeT target_dim = target_dim_exists ? target_shape[target_dim_i] : 1;
+        SizeT src_dim = src_dim_exists ? src_shape[src_dim_i] : 1;
+
+        bool ok = src_dim == 1 || target_dim == src_dim;
+        if (!ok) return false;
+    }
+
+    return true;
+}
+
+/**
+ * @brief Performs `np.broadcast_shapes`: merges all `shapes`, right-aligned, into `dst_shape`.
+ */
+template
+void broadcast_shapes(ErrorContext* errctx, SizeT num_shapes,
+                      const ShapeEntry* shapes, SizeT dst_ndims,
+                      SizeT* dst_shape) {
+    // `dst_ndims` must be `max([shape.ndims for shape in shapes])`, but the caller has to calculate it/provide it
+    // for this function since it should already know in order to allocate `dst_shape` in the first place.
+    // `dst_shape` must be pre-allocated.
+    // `dst_shape` does not have to be initialized
+    // On a mismatch a `value_error` is set on `errctx` and the function returns early; `dst_shape` is then only partially merged.
+
+    // This is essentially a `mconcat` where the neutral element is `[1, 1, 1, 1, ...]`, and the operation is commutative.
+
+    // Set `dst_shape` to all `1`s.
+    for (SizeT dst_axis = 0; dst_axis < dst_ndims; dst_axis++) {
+        dst_shape[dst_axis] = 1;
+    }
+
+    for (SizeT i = 0; i < num_shapes; i++) {
+        ShapeEntry entry = shapes[i];
+        // Pair the entry's axes with `dst_shape` starting from the last axis of both.
+        for (SizeT j = 0; j < entry.ndims; j++) {
+            SizeT entry_axis = entry.ndims - j - 1;
+            SizeT dst_axis = dst_ndims - j - 1;
+
+            SizeT entry_dim = entry.shape[entry_axis];
+            SizeT dst_dim = dst_shape[dst_axis];
+
+            if (dst_dim == 1) {
+                dst_shape[dst_axis] = entry_dim;
+            } else if (entry_dim == 1 || entry_dim == dst_dim) {
+                // Do nothing
+            } else {
+                errctx->set_exception(errctx->exceptions->value_error,
+                                      "shape mismatch: objects cannot be broadcast "
+                                      "to a single shape.");
+                return;
+            }
+        }
+    }
+}
+}  // namespace util
+
+/**
+ * @brief Perform `np.broadcast_to(src_ndarray, dst_ndarray->shape)` and appropriate assertions.
+ *
+ * Cautious note on https://github.com/numpy/numpy/issues/21744..
+ *
+ * This function attempts to broadcast `src_ndarray` to a new shape defined by `dst_ndarray.shape`,
+ * and return the result by modifying `dst_ndarray`.
+ *
+ * # Notes on `dst_ndarray`
+ * The caller is responsible for allocating space for the resulting ndarray.
+ * Here is what this function expects from `dst_ndarray` when called: + * - `dst_ndarray->data` does not have to be initialized. + * - `dst_ndarray->itemsize` does not have to be initialized. + * - `dst_ndarray->ndims` must be initialized, determining the length of `dst_ndarray->shape` + * - `dst_ndarray->shape` must be allocated, and must contain the desired target broadcast shape. + * - `dst_ndarray->strides` must be allocated, though it can contain uninitialized values. + * When this function call ends: + * - `dst_ndarray->data` is set to `src_ndarray->data` (`dst_ndarray` is just a view to `src_ndarray`) + * - `dst_ndarray->itemsize` is set to `src_ndarray->itemsize` + * - `dst_ndarray->ndims` is unchanged. + * - `dst_ndarray->shape` is unchanged. + * - `dst_ndarray->strides` is overwritten with the strides of the broadcast view (0 on every axis that gets repeated). + * If the shapes are incompatible, `errctx` receives a `value_error` and `dst_ndarray` is left unmodified. + */ +template +void broadcast_to(ErrorContext* errctx, const NDArray* src_ndarray, + NDArray* dst_ndarray) { + /* + * Cautions: + * ``` + * xs = np.zeros((4,)) + * ys = np.zeros((4, 1)) + * ys[:] = xs # ok + * + * xs = np.zeros((1, 4)) + * ys = np.zeros((4,)) + * ys[:] = xs # allowed + * # However `np.broadcast_to(xs, (4,))` would fail, as per numpy's broadcasting rule. + * # and apparently numpy will "deprecate" this? SEE https://github.com/numpy/numpy/issues/21744 + * # This implementation will NOT support this assignment. 
+     * ```
+     */
+
+    // A source with more axes than the target is rejected as well: `can_broadcast_shape_to` accepts it when the
+    // extra leading source axes are 1, and the loop below would then write in front of `dst_ndarray->strides`.
+    if (src_ndarray->ndims > dst_ndarray->ndims ||
+        !ndarray::broadcast::util::can_broadcast_shape_to(
+            dst_ndarray->ndims, dst_ndarray->shape, src_ndarray->ndims,
+            src_ndarray->shape)) {
+        errctx->set_exception(errctx->exceptions->value_error,
+                              "operands could not be broadcast together");
+        return;
+    }
+
+    dst_ndarray->data = src_ndarray->data;
+    dst_ndarray->itemsize = src_ndarray->itemsize;
+
+    // Pair the axes of source and target starting from the last axis of both, so that trailing axes line up.
+    // Only target axes are visited; the check above guarantees the source has no axis without a partner.
+    for (SizeT i = 0; i < dst_ndarray->ndims; i++) {
+        SizeT src_axis = src_ndarray->ndims - i - 1;
+        SizeT dst_axis = dst_ndarray->ndims - i - 1;
+
+        bool src_axis_exists = src_axis >= 0;
+        bool stretched = src_axis_exists &&
+                         src_ndarray->shape[src_axis] == 1 &&
+                         dst_ndarray->shape[dst_axis] != 1;
+        if (!src_axis_exists || stretched) {
+            dst_ndarray->strides[dst_axis] = 0;  // Freeze it in-place
+        } else {
+            // Reuse the source's own stride, so a source that is not laid out contiguously is still read correctly.
+            dst_ndarray->strides[dst_axis] = src_ndarray->strides[src_axis];
+        }
+    }
+}
+}  // namespace broadcast
+}  // namespace ndarray
+}  // namespace
+
+extern "C" {
+using namespace ndarray::broadcast;
+
+void __nac3_ndarray_broadcast_to(ErrorContext* errctx,
+                                 NDArray* src_ndarray,
+                                 NDArray* dst_ndarray) {
+    broadcast_to(errctx, src_ndarray, dst_ndarray);
+}
+
+void __nac3_ndarray_broadcast_to64(ErrorContext* errctx,
+                                   NDArray* src_ndarray,
+                                   NDArray* dst_ndarray) {
+    broadcast_to(errctx, src_ndarray, dst_ndarray);
+}
+
+void __nac3_ndarray_broadcast_shapes(ErrorContext* errctx, int32_t num_shapes,
+                                     const ShapeEntry* shapes,
+                                     int32_t dst_ndims, int32_t* dst_shape) {
+    ndarray::broadcast::util::broadcast_shapes(errctx, num_shapes, shapes,
+                                               dst_ndims, dst_shape);
+}
+
+void __nac3_ndarray_broadcast_shapes64(ErrorContext* errctx, int64_t num_shapes,
+                                       const ShapeEntry* shapes,
+                                       int64_t dst_ndims, int64_t* dst_shape) {
+    
ndarray::broadcast::util::broadcast_shapes(errctx, num_shapes, shapes, + dst_ndims, dst_shape); +} +} \ No newline at end of file diff --git a/nac3core/irrt/irrt_everything.hpp b/nac3core/irrt/irrt_everything.hpp index 1608b861..4200c729 100644 --- a/nac3core/irrt/irrt_everything.hpp +++ b/nac3core/irrt/irrt_everything.hpp @@ -5,6 +5,7 @@ #include #include #include +#include #include #include #include diff --git a/nac3core/irrt/irrt_test.cpp b/nac3core/irrt/irrt_test.cpp index 8ce5d196..e40abf8e 100644 --- a/nac3core/irrt/irrt_test.cpp +++ b/nac3core/irrt/irrt_test.cpp @@ -6,6 +6,7 @@ #include #include #include +#include #include #include @@ -14,5 +15,6 @@ int main() { test::slice::run(); test::ndarray_basic::run(); test::ndarray_indexing::run(); + test::ndarray_broadcast::run(); return 0; } \ No newline at end of file diff --git a/nac3core/irrt/test/test_ndarray_broadcast.hpp b/nac3core/irrt/test/test_ndarray_broadcast.hpp new file mode 100644 index 00000000..a5cc0c79 --- /dev/null +++ b/nac3core/irrt/test/test_ndarray_broadcast.hpp @@ -0,0 +1,129 @@ +#pragma once + +#include + +namespace test { +namespace ndarray_broadcast { +void test_can_broadcast_shape() { + BEGIN_TEST(); + + // Arguments are (target_ndims, target_shape, src_ndims, src_shape). + assert_values_match(true, + ndarray::broadcast::util::can_broadcast_shape_to( + 1, (int32_t[]){3}, 5, (int32_t[]){1, 1, 1, 1, 3})); + assert_values_match(false, ndarray::broadcast::util::can_broadcast_shape_to( + 1, (int32_t[]){3}, 2, (int32_t[]){3, 1})); + assert_values_match(true, ndarray::broadcast::util::can_broadcast_shape_to( + 1, (int32_t[]){3}, 1, (int32_t[]){3})); + assert_values_match(false, ndarray::broadcast::util::can_broadcast_shape_to( + 1, (int32_t[]){1}, 1, (int32_t[]){3})); + assert_values_match(true, ndarray::broadcast::util::can_broadcast_shape_to( + 1, (int32_t[]){1}, 1, (int32_t[]){1})); + assert_values_match( + true, ndarray::broadcast::util::can_broadcast_shape_to( + 3, (int32_t[]){256, 256, 3}, 3, (int32_t[]){256, 1, 3})); + assert_values_match(true, + 
ndarray::broadcast::util::can_broadcast_shape_to( + 3, (int32_t[]){256, 256, 3}, 1, (int32_t[]){3})); + assert_values_match(false, + ndarray::broadcast::util::can_broadcast_shape_to( + 3, (int32_t[]){256, 256, 3}, 1, (int32_t[]){2})); + assert_values_match(true, + ndarray::broadcast::util::can_broadcast_shape_to( + 3, (int32_t[]){256, 256, 3}, 1, (int32_t[]){1})); + + // In cases when the shapes contain zero(es) + assert_values_match(true, ndarray::broadcast::util::can_broadcast_shape_to( + 1, (int32_t[]){0}, 1, (int32_t[]){1})); + assert_values_match(false, ndarray::broadcast::util::can_broadcast_shape_to( + 1, (int32_t[]){0}, 1, (int32_t[]){2})); + assert_values_match(true, + ndarray::broadcast::util::can_broadcast_shape_to( + 4, (int32_t[]){0, 4, 0, 0}, 1, (int32_t[]){1})); + assert_values_match( + true, ndarray::broadcast::util::can_broadcast_shape_to( + 4, (int32_t[]){0, 4, 0, 0}, 4, (int32_t[]){1, 1, 1, 1})); + assert_values_match( + true, ndarray::broadcast::util::can_broadcast_shape_to( + 4, (int32_t[]){0, 4, 0, 0}, 4, (int32_t[]){1, 4, 1, 1})); + assert_values_match(false, ndarray::broadcast::util::can_broadcast_shape_to( + 2, (int32_t[]){4, 3}, 2, (int32_t[]){0, 3})); + assert_values_match(false, ndarray::broadcast::util::can_broadcast_shape_to( + 2, (int32_t[]){4, 3}, 2, (int32_t[]){0, 0})); +} + +void test_ndarray_broadcast() { + /* + # array = np.array([[19.9, 29.9, 39.9, 49.9]], dtype=np.float64) + # >>> [[19.9 29.9 39.9 49.9]] + # + # array = np.broadcast_to(array, (2, 3, 4)) + # >>> [[[19.9 29.9 39.9 49.9] + # >>> [19.9 29.9 39.9 49.9] + # >>> [19.9 29.9 39.9 49.9]] + # >>> [[19.9 29.9 39.9 49.9] + # >>> [19.9 29.9 39.9 49.9] + # >>> [19.9 29.9 39.9 49.9]]] + # + # assert array.strides == (0, 0, 8) + + */ + BEGIN_TEST(); + + double in_data[4] = {19.9, 29.9, 39.9, 49.9}; + const int32_t in_ndims = 2; + int32_t in_shape[in_ndims] = {1, 4}; + int32_t in_strides[in_ndims] = {}; + NDArray ndarray = {.data = (uint8_t*)in_data, + .itemsize = 
sizeof(double), + .ndims = in_ndims, + .shape = in_shape, + .strides = in_strides}; + ndarray::basic::set_strides_by_shape(&ndarray); + + const int32_t dst_ndims = 3; + int32_t dst_shape[dst_ndims] = {2, 3, 4}; + int32_t dst_strides[dst_ndims] = {}; + NDArray dst_ndarray = { + .ndims = dst_ndims, .shape = dst_shape, .strides = dst_strides}; + + // `data` and `itemsize` of `dst_ndarray` are deliberately left unset; `broadcast_to` fills them in. + ErrorContext errctx = create_testing_errctx(); + ndarray::broadcast::broadcast_to(&errctx, &ndarray, &dst_ndarray); + assert_errctx_no_exception(&errctx); + + assert_arrays_match(dst_ndims, ((int32_t[]){0, 0, 8}), dst_ndarray.strides); + + // Axes 0 and 1 have stride 0, so every index along them must read the same four source elements. + assert_values_match(19.9, + *((double*)ndarray::basic::get_pelement_by_indices( + &dst_ndarray, ((int32_t[]){0, 0, 0})))); + assert_values_match(29.9, + *((double*)ndarray::basic::get_pelement_by_indices( + &dst_ndarray, ((int32_t[]){0, 0, 1})))); + assert_values_match(39.9, + *((double*)ndarray::basic::get_pelement_by_indices( + &dst_ndarray, ((int32_t[]){0, 0, 2})))); + assert_values_match(49.9, + *((double*)ndarray::basic::get_pelement_by_indices( + &dst_ndarray, ((int32_t[]){0, 0, 3})))); + assert_values_match(19.9, + *((double*)ndarray::basic::get_pelement_by_indices( + &dst_ndarray, ((int32_t[]){0, 1, 0})))); + assert_values_match(29.9, + *((double*)ndarray::basic::get_pelement_by_indices( + &dst_ndarray, ((int32_t[]){0, 1, 1})))); + assert_values_match(39.9, + *((double*)ndarray::basic::get_pelement_by_indices( + &dst_ndarray, ((int32_t[]){0, 1, 2})))); + assert_values_match(49.9, + *((double*)ndarray::basic::get_pelement_by_indices( + &dst_ndarray, ((int32_t[]){0, 1, 3})))); + assert_values_match(49.9, + *((double*)ndarray::basic::get_pelement_by_indices( + &dst_ndarray, ((int32_t[]){1, 2, 3})))); +} + +void run() { + test_can_broadcast_shape(); + test_ndarray_broadcast(); +} +} // namespace ndarray_broadcast +} // namespace test \ No newline at end of file diff --git a/nac3core/src/codegen/irrt/ndarray/broadcast.rs b/nac3core/src/codegen/irrt/ndarray/broadcast.rs new file mode 
100644 index 00000000..f6da23b0 --- /dev/null +++ b/nac3core/src/codegen/irrt/ndarray/broadcast.rs @@ -0,0 +1,74 @@ +use crate::codegen::{ + irrt::{ + error_context::{check_error_context, setup_error_context}, + util::{function::CallFunction, get_sizet_dependent_function_name}, + }, + model::*, + structure::ndarray::NpArray, + CodeGenContext, CodeGenerator, +}; + +/// Emits a call to the IRRT function `__nac3_ndarray_broadcast_to`, whose name is first passed through +/// [`get_sizet_dependent_function_name`]. +/// +/// A new error context is set up, passed as the `errctx` argument together with `src_ndarray` and +/// `dst_ndarray`, and handed to [`check_error_context`] once the call has been emitted. +pub fn call_nac3_ndarray_broadcast_to<'ctx, G: CodeGenerator + ?Sized>( + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + src_ndarray: Ptr<'ctx, StructModel>, + dst_ndarray: Ptr<'ctx, StructModel>, +) { + let tyctx = generator.type_context(ctx.ctx); + + let perrctx = setup_error_context(tyctx, ctx); + CallFunction::begin( + tyctx, + ctx, + &get_sizet_dependent_function_name(tyctx, "__nac3_ndarray_broadcast_to"), + ) + .arg("errctx", perrctx) + .arg("src_ndarray", src_ndarray) + .arg("dst_ndarray", dst_ndarray) + .returning_void(); + + check_error_context(generator, ctx, perrctx); +} + +/// Fields of [`ShapeEntry`] +pub struct ShapeEntryFields { + pub ndims: F::Field>, + pub shape: F::Field>>, +} + +/// Struct kind with the two fields `ndims` and `shape`, registered in that order by `visit_fields`. +/// +/// NOTE(review): presumably mirrors the `ShapeEntry` struct of the IRRT `broadcast.hpp` -- confirm the layouts match. +#[derive(Debug, Clone, Copy, Default)] +pub struct ShapeEntry; + +impl StructKind for ShapeEntry { + type Fields = ShapeEntryFields; + + fn visit_fields(&self, visitor: &mut F) -> Self::Fields { + Self::Fields { ndims: visitor.add("ndims"), shape: visitor.add("shape") } + } +} + +/// Emits a call to the IRRT function `__nac3_ndarray_broadcast_shapes`, whose name is first passed through +/// [`get_sizet_dependent_function_name`]. +/// +/// A new error context is set up, passed as `errctx` together with `num_shape_entries` (as `num_shapes`), +/// `shape_entries` (as `shapes`), `dst_ndims` and `dst_shape`, and handed to [`check_error_context`] afterwards. +pub fn call_nac3_ndarray_broadcast_shapes<'ctx, G: CodeGenerator + ?Sized>( + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + num_shape_entries: Int<'ctx, SizeT>, + shape_entries: Ptr<'ctx, StructModel>, + dst_ndims: Int<'ctx, SizeT>, + dst_shape: Ptr<'ctx, IntModel>, +) { + let tyctx = generator.type_context(ctx.ctx); + + let perrctx = setup_error_context(tyctx, ctx); + CallFunction::begin( + tyctx, + ctx, + &get_sizet_dependent_function_name(tyctx, "__nac3_ndarray_broadcast_shapes"), + ) + .arg("errctx", perrctx) + .arg("num_shapes", num_shape_entries) + .arg("shapes", shape_entries) + .arg("dst_ndims", 
dst_ndims) + .arg("dst_shape", dst_shape) + .returning_void(); + + check_error_context(generator, ctx, perrctx); +} diff --git a/nac3core/src/codegen/irrt/ndarray/mod.rs b/nac3core/src/codegen/irrt/ndarray/mod.rs index f6c7c8a1..a49870c6 100644 --- a/nac3core/src/codegen/irrt/ndarray/mod.rs +++ b/nac3core/src/codegen/irrt/ndarray/mod.rs @@ -1,3 +1,4 @@ pub mod basic; +pub mod broadcast; pub mod indexing; pub mod reshape;