diff --git a/nac3core/irrt/irrt.cpp b/nac3core/irrt/irrt.cpp index abd3e932..4e39ec0d 100644 --- a/nac3core/irrt/irrt.cpp +++ b/nac3core/irrt/irrt.cpp @@ -4,6 +4,7 @@ #include #include #include +#include #include #include #include diff --git a/nac3core/irrt/irrt/ndarray/broadcast.hpp b/nac3core/irrt/irrt/ndarray/broadcast.hpp new file mode 100644 index 00000000..699bd8fa --- /dev/null +++ b/nac3core/irrt/irrt/ndarray/broadcast.hpp @@ -0,0 +1,188 @@ +#pragma once + +#include +#include +#include + +namespace +{ +template struct ShapeEntry +{ + SizeT ndims; + SizeT *shape; +}; +} // namespace + +namespace +{ +namespace ndarray +{ +namespace broadcast +{ +/** + * @brief Return true if `src_shape` can broadcast to `dst_shape`. + * + * See https://numpy.org/doc/stable/user/basics.broadcasting.html + */ +template +bool can_broadcast_shape_to(SizeT target_ndims, const SizeT *target_shape, SizeT src_ndims, const SizeT *src_shape) +{ + if (src_ndims > target_ndims) + { + return false; + } + + for (SizeT i = 0; i < src_ndims; i++) + { + SizeT target_dim = target_shape[target_ndims - i - 1]; + SizeT src_dim = src_shape[src_ndims - i - 1]; + if (!(src_dim == 1 || target_dim == src_dim)) + { + return false; + } + } + return true; +} + +/** + * @brief Performs `np.broadcast_shapes()` + * + * @param num_shapes Number of entries in `shapes` + * @param shapes The list of shape to do `np.broadcast_shapes` on. + * @param dst_ndims The length of `dst_shape`. + * `dst_ndims` must be `max([shape.ndims for shape in shapes])`, but the caller has to calculate it/provide it. + * for this function since they should already know in order to allocate `dst_shape` in the first place. + * @param dst_shape The resulting shape. Must be pre-allocated by the caller. This function calculate the result + * of `np.broadcast_shapes` and write it here. + */ +template +void broadcast_shapes(SizeT num_shapes, const ShapeEntry *shapes, SizeT dst_ndims, SizeT *dst_shape) +{ + for (SizeT dst_axis = 0; dst_axis < dst_ndims; dst_axis++) + { + dst_shape[dst_axis] = 1; + } + +#ifdef IRRT_DEBUG_ASSERT + SizeT max_ndims_found = 0; +#endif + + for (SizeT i = 0; i < num_shapes; i++) + { + ShapeEntry entry = shapes[i]; + + // Check pre-condition: `dst_ndims` must be `max([shape.ndims for shape in shapes])` + debug_assert(SizeT, entry.ndims <= dst_ndims); + +#ifdef IRRT_DEBUG_ASSERT + max_ndims_found = max(max_ndims_found, entry.ndims); +#endif + + for (SizeT j = 0; j < entry.ndims; j++) + { + SizeT entry_axis = entry.ndims - j - 1; + SizeT dst_axis = dst_ndims - j - 1; + + SizeT entry_dim = entry.shape[entry_axis]; + SizeT dst_dim = dst_shape[dst_axis]; + + if (dst_dim == 1) + { + dst_shape[dst_axis] = entry_dim; + } + else if (entry_dim == 1 || entry_dim == dst_dim) + { + // Do nothing + } + else + { + raise_exception(SizeT, EXN_VALUE_ERROR, + "shape mismatch: objects cannot be broadcast " + "to a single shape.", + NO_PARAM, NO_PARAM, NO_PARAM); + } + } + } + + // Check pre-condition: `dst_ndims` must be `max([shape.ndims for shape in shapes])` + debug_assert_eq(SizeT, max_ndims_found, dst_ndims); +} + +/** + * @brief Perform `np.broadcast_to(, )` and appropriate assertions. + * + * This function attempts to broadcast `src_ndarray` to a new shape defined by `dst_ndarray.shape`, + * and return the result by modifying `dst_ndarray`. + * + * # Notes on `dst_ndarray` + * The caller is responsible for allocating space for the resulting ndarray. + * Here is what this function expects from `dst_ndarray` when called: + * - `dst_ndarray->data` does not have to be initialized. + * - `dst_ndarray->itemsize` does not have to be initialized. + * - `dst_ndarray->ndims` must be initialized, determining the length of `dst_ndarray->shape` + * - `dst_ndarray->shape` must be allocated, and must contain the desired target broadcast shape. + * - `dst_ndarray->strides` must be allocated, through it can contain uninitialized values. + * When this function call ends: + * - `dst_ndarray->data` is set to `src_ndarray->data` (`dst_ndarray` is just a view to `src_ndarray`) + * - `dst_ndarray->itemsize` is set to `src_ndarray->itemsize` + * - `dst_ndarray->ndims` is unchanged. + * - `dst_ndarray->shape` is unchanged. + * - `dst_ndarray->strides` is updated accordingly by how ndarray broadcast_to works. + */ +template void broadcast_to(const NDArray *src_ndarray, NDArray *dst_ndarray) +{ + if (!ndarray::broadcast::can_broadcast_shape_to(dst_ndarray->ndims, dst_ndarray->shape, src_ndarray->ndims, + src_ndarray->shape)) + { + raise_exception(SizeT, EXN_VALUE_ERROR, "operands could not be broadcast together", NO_PARAM, NO_PARAM, + NO_PARAM); + } + + dst_ndarray->data = src_ndarray->data; + dst_ndarray->itemsize = src_ndarray->itemsize; + + for (SizeT i = 0; i < dst_ndarray->ndims; i++) + { + SizeT src_axis = src_ndarray->ndims - i - 1; + SizeT dst_axis = dst_ndarray->ndims - i - 1; + if (src_axis < 0 || (src_ndarray->shape[src_axis] == 1 && dst_ndarray->shape[dst_axis] != 1)) + { + // Freeze the steps in-place + dst_ndarray->strides[dst_axis] = 0; + } + else + { + dst_ndarray->strides[dst_axis] = src_ndarray->strides[src_axis]; + } + } +} +} // namespace broadcast +} // namespace ndarray +} // namespace + +extern "C" +{ + using namespace ndarray::broadcast; + + void __nac3_ndarray_broadcast_to(NDArray *src_ndarray, NDArray *dst_ndarray) + { + broadcast_to(src_ndarray, dst_ndarray); + } + + void __nac3_ndarray_broadcast_to64(NDArray *src_ndarray, NDArray *dst_ndarray) + { + broadcast_to(src_ndarray, dst_ndarray); + } + + void __nac3_ndarray_broadcast_shapes(int32_t num_shapes, const ShapeEntry *shapes, int32_t dst_ndims, + int32_t *dst_shape) + { + broadcast_shapes(num_shapes, shapes, dst_ndims, dst_shape); + } + + void __nac3_ndarray_broadcast_shapes64(int64_t num_shapes, const ShapeEntry *shapes, int64_t dst_ndims, + int64_t *dst_shape) + { + broadcast_shapes(num_shapes, shapes, dst_ndims, dst_shape); + } +} \ No newline at end of file diff --git a/nac3core/src/codegen/irrt/mod.rs b/nac3core/src/codegen/irrt/mod.rs index d3664fe3..5f410c21 100644 --- a/nac3core/src/codegen/irrt/mod.rs +++ b/nac3core/src/codegen/irrt/mod.rs @@ -9,7 +9,7 @@ use super::{ model::*, object::{ list::List, - ndarray::{indexing::NDIndex, nditer::NDIter, NDArray}, + ndarray::{broadcast::ShapeEntry, indexing::NDIndex, nditer::NDIter, NDArray}, }, CodeGenContext, CodeGenerator, }; @@ -1187,3 +1187,30 @@ pub fn call_nac3_ndarray_reshape_resolve_and_check_new_shape<'ctx, G: CodeGenera .arg(new_shape) .returning_void(); } + +pub fn call_nac3_ndarray_broadcast_to<'ctx, G: CodeGenerator + ?Sized>( + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + src_ndarray: Instance<'ctx, Ptr>>, + dst_ndarray: Instance<'ctx, Ptr>>, +) { + let name = get_sizet_dependent_function_name(generator, ctx, "__nac3_ndarray_broadcast_to"); + CallFunction::begin(generator, ctx, &name).arg(src_ndarray).arg(dst_ndarray).returning_void(); +} + +pub fn call_nac3_ndarray_broadcast_shapes<'ctx, G: CodeGenerator + ?Sized>( + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + num_shape_entries: Instance<'ctx, Int>, + shape_entries: Instance<'ctx, Ptr>>, + dst_ndims: Instance<'ctx, Int>, + dst_shape: Instance<'ctx, Ptr>>, +) { + let name = get_sizet_dependent_function_name(generator, ctx, "__nac3_ndarray_broadcast_shapes"); + CallFunction::begin(generator, ctx, &name) + .arg(num_shape_entries) + .arg(shape_entries) + .arg(dst_ndims) + .arg(dst_shape) + .returning_void(); +} diff --git a/nac3core/src/codegen/object/ndarray/broadcast.rs b/nac3core/src/codegen/object/ndarray/broadcast.rs new file mode 100644 index 00000000..87c34a35 --- /dev/null +++ b/nac3core/src/codegen/object/ndarray/broadcast.rs @@ -0,0 +1,135 @@ +use itertools::Itertools; + +use crate::codegen::{ + irrt::{call_nac3_ndarray_broadcast_shapes, call_nac3_ndarray_broadcast_to}, + model::*, + CodeGenContext, CodeGenerator, +}; + +use super::NDArrayObject; + +/// Fields of [`ShapeEntry`] +pub struct ShapeEntryFields<'ctx, F: FieldTraversal<'ctx>> { + pub ndims: F::Out>, + pub shape: F::Out>>, +} + +/// An IRRT structure used in broadcasting. +#[derive(Debug, Clone, Copy, Default)] +pub struct ShapeEntry; + +impl<'ctx> StructKind<'ctx> for ShapeEntry { + type Fields> = ShapeEntryFields<'ctx, F>; + + fn traverse_fields>(&self, traversal: &mut F) -> Self::Fields { + Self::Fields { ndims: traversal.add_auto("ndims"), shape: traversal.add_auto("shape") } + } +} + +impl<'ctx> NDArrayObject<'ctx> { + /// Create a broadcast view on this ndarray with a target shape. + /// + /// The input shape will be checked to make sure that it contains no negative values. + /// + /// * `target_ndims` - The ndims type after broadcasting to the given shape. + /// The caller has to figure this out for this function. + /// * `target_shape` - An array pointer pointing to the target shape. + #[must_use] + pub fn broadcast_to( + &self, + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + target_ndims: u64, + target_shape: Instance<'ctx, Ptr>>, + ) -> Self { + let broadcast_ndarray = NDArrayObject::alloca(generator, ctx, self.dtype, target_ndims); + broadcast_ndarray.copy_shape_from_array(generator, ctx, target_shape); + + call_nac3_ndarray_broadcast_to(generator, ctx, self.instance, broadcast_ndarray.instance); + broadcast_ndarray + } +} +/// A result produced by [`broadcast_all_ndarrays`] +#[derive(Debug, Clone)] +pub struct BroadcastAllResult<'ctx> { + /// The statically known `ndims` of the broadcast result. + pub ndims: u64, + /// The broadcasting shape. + pub shape: Instance<'ctx, Ptr>>, + /// Broadcasted views on the inputs. + /// + /// All of them will have `shape` [`BroadcastAllResult::shape`] and + /// `ndims` [`BroadcastAllResult::ndims`]. The length of the vector + /// is the same as the input. + pub ndarrays: Vec>, +} + +/// Helper function to call `call_nac3_ndarray_broadcast_shapes` +fn broadcast_shapes<'ctx, G: CodeGenerator + ?Sized>( + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + in_shape_entries: &[(Instance<'ctx, Ptr>>, u64)], // (shape, shape's length/ndims) + broadcast_ndims: u64, + broadcast_shape: Instance<'ctx, Ptr>>, +) { + // Prepare input shape entries to be passed to `call_nac3_ndarray_broadcast_shapes`. + let num_shape_entries = + Int(SizeT).const_int(generator, ctx.ctx, u64::try_from(in_shape_entries.len()).unwrap()); + let shape_entries = Struct(ShapeEntry).array_alloca(generator, ctx, num_shape_entries.value); + for (i, (in_shape, in_ndims)) in in_shape_entries.iter().enumerate() { + let pshape_entry = shape_entries.offset_const(ctx, i as u64); + + let in_ndims = Int(SizeT).const_int(generator, ctx.ctx, *in_ndims); + pshape_entry.set(ctx, |f| f.ndims, in_ndims); + + pshape_entry.set(ctx, |f| f.shape, *in_shape); + } + + let broadcast_ndims = Int(SizeT).const_int(generator, ctx.ctx, broadcast_ndims); + call_nac3_ndarray_broadcast_shapes( + generator, + ctx, + num_shape_entries, + shape_entries, + broadcast_ndims, + broadcast_shape, + ); +} + +impl<'ctx> NDArrayObject<'ctx> { + /// Broadcast all ndarrays according to `np.broadcast()` and return a [`BroadcastAllResult`] + /// containing all the information of the result of the broadcast operation. + pub fn broadcast( + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + ndarrays: &[Self], + ) -> BroadcastAllResult<'ctx> { + assert!(!ndarrays.is_empty()); + + // Infer the broadcast output ndims. + let broadcast_ndims_int = ndarrays.iter().map(|ndarray| ndarray.ndims).max().unwrap(); + + let broadcast_ndims = Int(SizeT).const_int(generator, ctx.ctx, broadcast_ndims_int); + let broadcast_shape = Int(SizeT).array_alloca(generator, ctx, broadcast_ndims.value); + + let shape_entries = ndarrays + .iter() + .map(|ndarray| (ndarray.instance.get(generator, ctx, |f| f.shape), ndarray.ndims)) + .collect_vec(); + broadcast_shapes(generator, ctx, &shape_entries, broadcast_ndims_int, broadcast_shape); + + // Broadcast all the inputs to shape `dst_shape`. + let broadcast_ndarrays: Vec<_> = ndarrays + .iter() + .map(|ndarray| { + ndarray.broadcast_to(generator, ctx, broadcast_ndims_int, broadcast_shape) + }) + .collect_vec(); + + BroadcastAllResult { + ndims: broadcast_ndims_int, + shape: broadcast_shape, + ndarrays: broadcast_ndarrays, + } + } +} diff --git a/nac3core/src/codegen/object/ndarray/mod.rs b/nac3core/src/codegen/object/ndarray/mod.rs index 8318458c..a7985b4b 100644 --- a/nac3core/src/codegen/object/ndarray/mod.rs +++ b/nac3core/src/codegen/object/ndarray/mod.rs @@ -1,4 +1,5 @@ pub mod array; +pub mod broadcast; pub mod factory; pub mod indexing; pub mod nditer; diff --git a/nac3core/src/toplevel/builtins.rs b/nac3core/src/toplevel/builtins.rs index dcfc81ef..855fd7d1 100644 --- a/nac3core/src/toplevel/builtins.rs +++ b/nac3core/src/toplevel/builtins.rs @@ -520,7 +520,7 @@ impl<'a> BuiltinBuilder<'a> { self.build_ndarray_property_getter_function(prim) } - PrimDef::FunNpTranspose | PrimDef::FunNpReshape => { + PrimDef::FunNpBroadcastTo | PrimDef::FunNpTranspose | PrimDef::FunNpReshape => { self.build_ndarray_view_function(prim) } @@ -1446,7 +1446,10 @@ impl<'a> BuiltinBuilder<'a> { /// Build np/sp functions that take as input `NDArray` only fn build_ndarray_view_function(&mut self, prim: PrimDef) -> TopLevelDef { - debug_assert_prim_is_allowed(prim, &[PrimDef::FunNpTranspose, PrimDef::FunNpReshape]); + debug_assert_prim_is_allowed( + prim, + &[PrimDef::FunNpBroadcastTo, PrimDef::FunNpTranspose, PrimDef::FunNpReshape], + ); let in_ndarray_ty = self.unifier.get_fresh_var_with_range( &[self.primitives.ndarray], @@ -1474,7 +1477,7 @@ impl<'a> BuiltinBuilder<'a> { // Similar to `build_ndarray_from_shape_factory_function` we delegate the responsibility of typechecking // to [`typecheck::type_inferencer::Inferencer::fold_numpy_function_call_shape_argument`], // and use a dummy [`TypeVar`] `ndarray_factory_fn_shape_arg_tvar` as a placeholder for `param_ty`. - PrimDef::FunNpReshape => { + PrimDef::FunNpBroadcastTo | PrimDef::FunNpReshape => { let ret_ty = self.unifier.get_dummy_var().ty; // Handled by special holding create_fn_by_codegen( @@ -1505,7 +1508,15 @@ impl<'a> BuiltinBuilder<'a> { let (_, ndims) = unpack_ndarray_var_tys(&mut ctx.unifier, fun.0.ret); let ndims = extract_ndims(&ctx.unifier, ndims); - let new_ndarray = ndarray.reshape_or_copy(generator, ctx, ndims, shape); + let new_ndarray = match prim { + PrimDef::FunNpBroadcastTo => { + ndarray.broadcast_to(generator, ctx, ndims, shape) + } + PrimDef::FunNpReshape => { + ndarray.reshape_or_copy(generator, ctx, ndims, shape) + } + _ => unreachable!(), + }; Ok(Some(new_ndarray.instance.value.as_basic_value_enum())) }), ) diff --git a/nac3core/src/toplevel/helper.rs b/nac3core/src/toplevel/helper.rs index f7d939a8..95ffdf7c 100644 --- a/nac3core/src/toplevel/helper.rs +++ b/nac3core/src/toplevel/helper.rs @@ -57,6 +57,7 @@ pub enum PrimDef { FunNpStrides, // NumPy ndarray view functions + FunNpBroadcastTo, FunNpTranspose, FunNpReshape, @@ -249,6 +250,7 @@ impl PrimDef { PrimDef::FunNpStrides => fun("np_strides", None), // NumPy NDArray view functions + PrimDef::FunNpBroadcastTo => fun("np_broadcast_to", None), PrimDef::FunNpTranspose => fun("np_transpose", None), PrimDef::FunNpReshape => fun("np_reshape", None), diff --git a/nac3standalone/demo/interpret_demo.py b/nac3standalone/demo/interpret_demo.py index bb097bbd..f9ad5630 100755 --- a/nac3standalone/demo/interpret_demo.py +++ b/nac3standalone/demo/interpret_demo.py @@ -180,6 +180,7 @@ def patch(module): module.np_array = np.array # NumPy NDArray view functions + module.np_broadcast_to = np.broadcast_to module.np_transpose = np.transpose module.np_reshape = np.reshape