diff --git a/nac3core/irrt/irrt.cpp b/nac3core/irrt/irrt.cpp index 7e4f2a15..1cdfde02 100644 --- a/nac3core/irrt/irrt.cpp +++ b/nac3core/irrt/irrt.cpp @@ -10,4 +10,5 @@ #include "irrt/ndarray/iter.hpp" #include "irrt/ndarray/indexing.hpp" #include "irrt/ndarray/array.hpp" -#include "irrt/ndarray/reshape.hpp" \ No newline at end of file +#include "irrt/ndarray/reshape.hpp" +#include "irrt/ndarray/broadcast.hpp" \ No newline at end of file diff --git a/nac3core/irrt/irrt/ndarray/broadcast.hpp b/nac3core/irrt/irrt/ndarray/broadcast.hpp new file mode 100644 index 00000000..e419081c --- /dev/null +++ b/nac3core/irrt/irrt/ndarray/broadcast.hpp @@ -0,0 +1,165 @@ +#pragma once + +#include "irrt/int_types.hpp" +#include "irrt/ndarray/def.hpp" +#include "irrt/slice.hpp" + +namespace { +template +struct ShapeEntry { + SizeT ndims; + SizeT* shape; +}; +} // namespace + +namespace { +namespace ndarray { +namespace broadcast { +/** + * @brief Return true if `src_shape` can broadcast to `dst_shape`. + * + * See https://numpy.org/doc/stable/user/basics.broadcasting.html + */ +template +bool can_broadcast_shape_to(SizeT target_ndims, const SizeT* target_shape, SizeT src_ndims, const SizeT* src_shape) { + if (src_ndims > target_ndims) { + return false; + } + + for (SizeT i = 0; i < src_ndims; i++) { + SizeT target_dim = target_shape[target_ndims - i - 1]; + SizeT src_dim = src_shape[src_ndims - i - 1]; + if (!(src_dim == 1 || target_dim == src_dim)) { + return false; + } + } + return true; +} + +/** + * @brief Performs `np.broadcast_shapes()` + * + * @param num_shapes Number of entries in `shapes` + * @param shapes The list of shape to do `np.broadcast_shapes` on. + * @param dst_ndims The length of `dst_shape`. + * `dst_ndims` must be `max([shape.ndims for shape in shapes])`, but the caller has to calculate it/provide it. + * for this function since they should already know in order to allocate `dst_shape` in the first place. + * @param dst_shape The resulting shape. Must be pre-allocated by the caller. This function calculate the result + * of `np.broadcast_shapes` and write it here. + */ +template +void broadcast_shapes(SizeT num_shapes, const ShapeEntry* shapes, SizeT dst_ndims, SizeT* dst_shape) { + for (SizeT dst_axis = 0; dst_axis < dst_ndims; dst_axis++) { + dst_shape[dst_axis] = 1; + } + +#ifdef IRRT_DEBUG_ASSERT + SizeT max_ndims_found = 0; +#endif + + for (SizeT i = 0; i < num_shapes; i++) { + ShapeEntry entry = shapes[i]; + + // Check pre-condition: `dst_ndims` must be `max([shape.ndims for shape in shapes])` + debug_assert(SizeT, entry.ndims <= dst_ndims); + +#ifdef IRRT_DEBUG_ASSERT + max_ndims_found = max(max_ndims_found, entry.ndims); +#endif + + for (SizeT j = 0; j < entry.ndims; j++) { + SizeT entry_axis = entry.ndims - j - 1; + SizeT dst_axis = dst_ndims - j - 1; + + SizeT entry_dim = entry.shape[entry_axis]; + SizeT dst_dim = dst_shape[dst_axis]; + + if (dst_dim == 1) { + dst_shape[dst_axis] = entry_dim; + } else if (entry_dim == 1 || entry_dim == dst_dim) { + // Do nothing + } else { + raise_exception(SizeT, EXN_VALUE_ERROR, + "shape mismatch: objects cannot be broadcast " + "to a single shape.", + NO_PARAM, NO_PARAM, NO_PARAM); + } + } + } + + // Check pre-condition: `dst_ndims` must be `max([shape.ndims for shape in shapes])` + debug_assert_eq(SizeT, max_ndims_found, dst_ndims); +} + +/** + * @brief Perform `np.broadcast_to(, )` and appropriate assertions. + * + * This function attempts to broadcast `src_ndarray` to a new shape defined by `dst_ndarray.shape`, + * and return the result by modifying `dst_ndarray`. + * + * # Notes on `dst_ndarray` + * The caller is responsible for allocating space for the resulting ndarray. + * Here is what this function expects from `dst_ndarray` when called: + * - `dst_ndarray->data` does not have to be initialized. + * - `dst_ndarray->itemsize` does not have to be initialized. + * - `dst_ndarray->ndims` must be initialized, determining the length of `dst_ndarray->shape` + * - `dst_ndarray->shape` must be allocated, and must contain the desired target broadcast shape. + * - `dst_ndarray->strides` must be allocated, through it can contain uninitialized values. + * When this function call ends: + * - `dst_ndarray->data` is set to `src_ndarray->data` (`dst_ndarray` is just a view to `src_ndarray`) + * - `dst_ndarray->itemsize` is set to `src_ndarray->itemsize` + * - `dst_ndarray->ndims` is unchanged. + * - `dst_ndarray->shape` is unchanged. + * - `dst_ndarray->strides` is updated accordingly by how ndarray broadcast_to works. + */ +template +void broadcast_to(const NDArray* src_ndarray, NDArray* dst_ndarray) { + if (!ndarray::broadcast::can_broadcast_shape_to(dst_ndarray->ndims, dst_ndarray->shape, src_ndarray->ndims, + src_ndarray->shape)) { + raise_exception(SizeT, EXN_VALUE_ERROR, "operands could not be broadcast together", NO_PARAM, NO_PARAM, + NO_PARAM); + } + + dst_ndarray->data = src_ndarray->data; + dst_ndarray->itemsize = src_ndarray->itemsize; + + for (SizeT i = 0; i < dst_ndarray->ndims; i++) { + SizeT src_axis = src_ndarray->ndims - i - 1; + SizeT dst_axis = dst_ndarray->ndims - i - 1; + if (src_axis < 0 || (src_ndarray->shape[src_axis] == 1 && dst_ndarray->shape[dst_axis] != 1)) { + // Freeze the steps in-place + dst_ndarray->strides[dst_axis] = 0; + } else { + dst_ndarray->strides[dst_axis] = src_ndarray->strides[src_axis]; + } + } +} +} // namespace broadcast +} // namespace ndarray +} // namespace + +extern "C" { +using namespace ndarray::broadcast; + +void __nac3_ndarray_broadcast_to(NDArray* src_ndarray, NDArray* dst_ndarray) { + broadcast_to(src_ndarray, dst_ndarray); +} + +void __nac3_ndarray_broadcast_to64(NDArray* src_ndarray, NDArray* dst_ndarray) { + broadcast_to(src_ndarray, dst_ndarray); +} + +void __nac3_ndarray_broadcast_shapes(int32_t num_shapes, + const ShapeEntry* shapes, + int32_t dst_ndims, + int32_t* dst_shape) { + broadcast_shapes(num_shapes, shapes, dst_ndims, dst_shape); +} + +void __nac3_ndarray_broadcast_shapes64(int64_t num_shapes, + const ShapeEntry* shapes, + int64_t dst_ndims, + int64_t* dst_shape) { + broadcast_shapes(num_shapes, shapes, dst_ndims, dst_shape); +} +} \ No newline at end of file diff --git a/nac3core/src/codegen/irrt/mod.rs b/nac3core/src/codegen/irrt/mod.rs index 67dde78d..fb27a5df 100644 --- a/nac3core/src/codegen/irrt/mod.rs +++ b/nac3core/src/codegen/irrt/mod.rs @@ -10,7 +10,7 @@ use super::{ model::*, object::{ list::List, - ndarray::{indexing::NDIndex, nditer::NDIter, NDArray}, + ndarray::{broadcast::ShapeEntry, indexing::NDIndex, nditer::NDIter, NDArray}, }, stmt::gen_for_callback_incrementing, CodeGenContext, CodeGenerator, @@ -1176,3 +1176,30 @@ pub fn call_nac3_ndarray_reshape_resolve_and_check_new_shape<'ctx, G: CodeGenera ); FnCall::builder(generator, ctx, &name).arg(size).arg(new_ndims).arg(new_shape).returning_void(); } + +pub fn call_nac3_ndarray_broadcast_to<'ctx, G: CodeGenerator + ?Sized>( + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + src_ndarray: Instance<'ctx, Ptr>>, + dst_ndarray: Instance<'ctx, Ptr>>, +) { + let name = get_sizet_dependent_function_name(generator, ctx, "__nac3_ndarray_broadcast_to"); + FnCall::builder(generator, ctx, &name).arg(src_ndarray).arg(dst_ndarray).returning_void(); +} + +pub fn call_nac3_ndarray_broadcast_shapes<'ctx, G: CodeGenerator + ?Sized>( + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + num_shape_entries: Instance<'ctx, Int>, + shape_entries: Instance<'ctx, Ptr>>, + dst_ndims: Instance<'ctx, Int>, + dst_shape: Instance<'ctx, Ptr>>, +) { + let name = get_sizet_dependent_function_name(generator, ctx, "__nac3_ndarray_broadcast_shapes"); + FnCall::builder(generator, ctx, &name) + .arg(num_shape_entries) + .arg(shape_entries) + .arg(dst_ndims) + .arg(dst_shape) + .returning_void(); +} diff --git a/nac3core/src/codegen/object/ndarray/broadcast.rs b/nac3core/src/codegen/object/ndarray/broadcast.rs new file mode 100644 index 00000000..f41de803 --- /dev/null +++ b/nac3core/src/codegen/object/ndarray/broadcast.rs @@ -0,0 +1,139 @@ +use itertools::Itertools; + +use crate::codegen::{ + irrt::{call_nac3_ndarray_broadcast_shapes, call_nac3_ndarray_broadcast_to}, + model::*, + CodeGenContext, CodeGenerator, +}; + +use super::NDArrayObject; + +/// Fields of [`ShapeEntry`] +pub struct ShapeEntryFields<'ctx, F: FieldTraversal<'ctx>> { + pub ndims: F::Output>, + pub shape: F::Output>>, +} + +/// An IRRT structure used in broadcasting. +#[derive(Debug, Clone, Copy, Default)] +pub struct ShapeEntry; + +impl<'ctx> StructKind<'ctx> for ShapeEntry { + type Fields> = ShapeEntryFields<'ctx, F>; + + fn iter_fields>(&self, traversal: &mut F) -> Self::Fields { + Self::Fields { ndims: traversal.add_auto("ndims"), shape: traversal.add_auto("shape") } + } +} + +impl<'ctx> NDArrayObject<'ctx> { + /// Create a broadcast view on this ndarray with a target shape. + /// + /// The input shape will be checked to make sure that it contains no negative values. + /// + /// * `target_ndims` - The ndims type after broadcasting to the given shape. + /// The caller has to figure this out for this function. + /// * `target_shape` - An array pointer pointing to the target shape. + #[must_use] + pub fn broadcast_to( + &self, + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + target_ndims: u64, + target_shape: Instance<'ctx, Ptr>>, + ) -> Self { + let broadcast_ndarray = NDArrayObject::alloca(generator, ctx, self.dtype, target_ndims); + broadcast_ndarray.copy_shape_from_array(generator, ctx, target_shape); + + call_nac3_ndarray_broadcast_to(generator, ctx, self.instance, broadcast_ndarray.instance); + broadcast_ndarray + } +} +/// A result produced by [`broadcast_all_ndarrays`] +#[derive(Debug, Clone)] +pub struct BroadcastAllResult<'ctx> { + /// The statically known `ndims` of the broadcast result. + pub ndims: u64, + /// The broadcasting shape. + pub shape: Instance<'ctx, Ptr>>, + /// Broadcasted views on the inputs. + /// + /// All of them will have `shape` [`BroadcastAllResult::shape`] and + /// `ndims` [`BroadcastAllResult::ndims`]. The length of the vector + /// is the same as the input. + pub ndarrays: Vec>, +} + +/// Helper function to call `call_nac3_ndarray_broadcast_shapes` +fn broadcast_shapes<'ctx, G: CodeGenerator + ?Sized>( + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + in_shape_entries: &[(Instance<'ctx, Ptr>>, u64)], // (shape, shape's length/ndims) + broadcast_ndims: u64, + broadcast_shape: Instance<'ctx, Ptr>>, +) { + // Prepare input shape entries to be passed to `call_nac3_ndarray_broadcast_shapes`. + let num_shape_entries = Int(SizeT).const_int( + generator, + ctx.ctx, + u64::try_from(in_shape_entries.len()).unwrap(), + false, + ); + let shape_entries = Struct(ShapeEntry).array_alloca(generator, ctx, num_shape_entries.value); + for (i, (in_shape, in_ndims)) in in_shape_entries.iter().enumerate() { + let pshape_entry = shape_entries.offset_const(ctx, i64::try_from(i).unwrap()); + + let in_ndims = Int(SizeT).const_int(generator, ctx.ctx, *in_ndims, false); + pshape_entry.set(ctx, |f| f.ndims, in_ndims); + + pshape_entry.set(ctx, |f| f.shape, *in_shape); + } + + let broadcast_ndims = Int(SizeT).const_int(generator, ctx.ctx, broadcast_ndims, false); + call_nac3_ndarray_broadcast_shapes( + generator, + ctx, + num_shape_entries, + shape_entries, + broadcast_ndims, + broadcast_shape, + ); +} + +impl<'ctx> NDArrayObject<'ctx> { + /// Broadcast all ndarrays according to `np.broadcast()` and return a [`BroadcastAllResult`] + /// containing all the information of the result of the broadcast operation. + pub fn broadcast( + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + ndarrays: &[Self], + ) -> BroadcastAllResult<'ctx> { + assert!(!ndarrays.is_empty()); + + // Infer the broadcast output ndims. + let broadcast_ndims_int = ndarrays.iter().map(|ndarray| ndarray.ndims).max().unwrap(); + + let broadcast_ndims = Int(SizeT).const_int(generator, ctx.ctx, broadcast_ndims_int, false); + let broadcast_shape = Int(SizeT).array_alloca(generator, ctx, broadcast_ndims.value); + + let shape_entries = ndarrays + .iter() + .map(|ndarray| (ndarray.instance.get(generator, ctx, |f| f.shape), ndarray.ndims)) + .collect_vec(); + broadcast_shapes(generator, ctx, &shape_entries, broadcast_ndims_int, broadcast_shape); + + // Broadcast all the inputs to shape `dst_shape`. + let broadcast_ndarrays: Vec<_> = ndarrays + .iter() + .map(|ndarray| { + ndarray.broadcast_to(generator, ctx, broadcast_ndims_int, broadcast_shape) + }) + .collect_vec(); + + BroadcastAllResult { + ndims: broadcast_ndims_int, + shape: broadcast_shape, + ndarrays: broadcast_ndarrays, + } + } +} diff --git a/nac3core/src/codegen/object/ndarray/mod.rs b/nac3core/src/codegen/object/ndarray/mod.rs index 9692f227..f9787078 100644 --- a/nac3core/src/codegen/object/ndarray/mod.rs +++ b/nac3core/src/codegen/object/ndarray/mod.rs @@ -1,4 +1,5 @@ pub mod array; +pub mod broadcast; pub mod factory; pub mod indexing; pub mod nditer; diff --git a/nac3core/src/toplevel/builtins.rs b/nac3core/src/toplevel/builtins.rs index b2a6da15..485b6347 100644 --- a/nac3core/src/toplevel/builtins.rs +++ b/nac3core/src/toplevel/builtins.rs @@ -521,7 +521,7 @@ impl<'a> BuiltinBuilder<'a> { self.build_ndarray_property_getter_function(prim) } - PrimDef::FunNpTranspose | PrimDef::FunNpReshape => { + PrimDef::FunNpBroadcastTo | PrimDef::FunNpTranspose | PrimDef::FunNpReshape => { self.build_ndarray_view_function(prim) } @@ -1469,7 +1469,10 @@ impl<'a> BuiltinBuilder<'a> { /// Build np/sp functions that take as input `NDArray` only fn build_ndarray_view_function(&mut self, prim: PrimDef) -> TopLevelDef { - debug_assert_prim_is_allowed(prim, &[PrimDef::FunNpTranspose, PrimDef::FunNpReshape]); + debug_assert_prim_is_allowed( + prim, + &[PrimDef::FunNpBroadcastTo, PrimDef::FunNpTranspose, PrimDef::FunNpReshape], + ); let in_ndarray_ty = self.unifier.get_fresh_var_with_range( &[self.primitives.ndarray], @@ -1497,7 +1500,10 @@ impl<'a> BuiltinBuilder<'a> { // Similar to `build_ndarray_from_shape_factory_function` we delegate the responsibility of typechecking // to [`typecheck::type_inferencer::Inferencer::fold_numpy_function_call_shape_argument`], // and use a dummy [`TypeVar`] `ndarray_factory_fn_shape_arg_tvar` as a placeholder for `param_ty`. - PrimDef::FunNpReshape => { + PrimDef::FunNpBroadcastTo | PrimDef::FunNpReshape => { + // These two functions have the same function signature. + // Mixed together for convenience. + let ret_ty = self.unifier.get_dummy_var().ty; // Handled by special holding create_fn_by_codegen( @@ -1528,7 +1534,15 @@ impl<'a> BuiltinBuilder<'a> { let (_, ndims) = unpack_ndarray_var_tys(&mut ctx.unifier, fun.0.ret); let ndims = extract_ndims(&ctx.unifier, ndims); - let new_ndarray = ndarray.reshape_or_copy(generator, ctx, ndims, shape); + let new_ndarray = match prim { + PrimDef::FunNpBroadcastTo => { + ndarray.broadcast_to(generator, ctx, ndims, shape) + } + PrimDef::FunNpReshape => { + ndarray.reshape_or_copy(generator, ctx, ndims, shape) + } + _ => unreachable!(), + }; Ok(Some(new_ndarray.instance.value.as_basic_value_enum())) }), ) diff --git a/nac3core/src/toplevel/helper.rs b/nac3core/src/toplevel/helper.rs index 0a69a009..f35b2e55 100644 --- a/nac3core/src/toplevel/helper.rs +++ b/nac3core/src/toplevel/helper.rs @@ -59,6 +59,7 @@ pub enum PrimDef { FunNpStrides, // NumPy ndarray view functions + FunNpBroadcastTo, FunNpTranspose, FunNpReshape, @@ -252,6 +253,7 @@ impl PrimDef { PrimDef::FunNpStrides => fun("np_strides", None), // NumPy NDArray view functions + PrimDef::FunNpBroadcastTo => fun("np_broadcast_to", None), PrimDef::FunNpTranspose => fun("np_transpose", None), PrimDef::FunNpReshape => fun("np_reshape", None), diff --git a/nac3core/src/typecheck/type_inferencer/mod.rs b/nac3core/src/typecheck/type_inferencer/mod.rs index 6bdf48df..95aa8aba 100644 --- a/nac3core/src/typecheck/type_inferencer/mod.rs +++ b/nac3core/src/typecheck/type_inferencer/mod.rs @@ -1543,7 +1543,7 @@ impl<'a> Inferencer<'a> { })); } // 2-argument ndarray n-dimensional factory functions - if id == &"np_reshape".into() && args.len() == 2 { + if ["np_reshape".into(), "np_broadcast_to".into()].contains(id) && args.len() == 2 { let arg0 = self.fold_expr(args.remove(0))?; let shape_expr = args.remove(0); diff --git a/nac3standalone/demo/interpret_demo.py b/nac3standalone/demo/interpret_demo.py index 56c6126d..8784ce53 100755 --- a/nac3standalone/demo/interpret_demo.py +++ b/nac3standalone/demo/interpret_demo.py @@ -180,6 +180,7 @@ def patch(module): module.np_array = np.array # NumPy NDArray view functions + module.np_broadcast_to = np.broadcast_to module.np_transpose = np.transpose module.np_reshape = np.reshape