From f35fbde837539a7173851225402f3f9bf6b49cbc Mon Sep 17 00:00:00 2001 From: lyken Date: Tue, 20 Aug 2024 16:35:20 +0800 Subject: [PATCH] core/ndstrides: implement np_transpose() (no axes argument) --- nac3core/irrt/irrt.cpp | 2 + nac3core/irrt/irrt/ndarray/broadcast.hpp | 188 ++++++++++++++++++ nac3core/irrt/irrt/ndarray/transpose.hpp | 155 +++++++++++++++ nac3core/src/codegen/irrt/mod.rs | 46 ++++- nac3core/src/codegen/numpy.rs | 106 ---------- .../src/codegen/object/ndarray/broadcast.rs | 135 +++++++++++++ nac3core/src/codegen/object/ndarray/mod.rs | 1 + nac3core/src/codegen/object/ndarray/view.rs | 34 +++- nac3core/src/toplevel/builtins.rs | 51 +++-- nac3core/src/toplevel/helper.rs | 2 + nac3standalone/demo/interpret_demo.py | 1 + 11 files changed, 596 insertions(+), 125 deletions(-) create mode 100644 nac3core/irrt/irrt/ndarray/broadcast.hpp create mode 100644 nac3core/irrt/irrt/ndarray/transpose.hpp create mode 100644 nac3core/src/codegen/object/ndarray/broadcast.rs diff --git a/nac3core/irrt/irrt.cpp b/nac3core/irrt/irrt.cpp index abd3e932..fdbffb39 100644 --- a/nac3core/irrt/irrt.cpp +++ b/nac3core/irrt/irrt.cpp @@ -4,9 +4,11 @@ #include #include #include +#include #include #include #include #include +#include #include #include \ No newline at end of file diff --git a/nac3core/irrt/irrt/ndarray/broadcast.hpp b/nac3core/irrt/irrt/ndarray/broadcast.hpp new file mode 100644 index 00000000..699bd8fa --- /dev/null +++ b/nac3core/irrt/irrt/ndarray/broadcast.hpp @@ -0,0 +1,188 @@ +#pragma once + +#include +#include +#include + +namespace +{ +template struct ShapeEntry +{ + SizeT ndims; + SizeT *shape; +}; +} // namespace + +namespace +{ +namespace ndarray +{ +namespace broadcast +{ +/** + * @brief Return true if `src_shape` can broadcast to `dst_shape`. + * + * See https://numpy.org/doc/stable/user/basics.broadcasting.html + */ +template +bool can_broadcast_shape_to(SizeT target_ndims, const SizeT *target_shape, SizeT src_ndims, const SizeT *src_shape) +{ + if (src_ndims > target_ndims) + { + return false; + } + + for (SizeT i = 0; i < src_ndims; i++) + { + SizeT target_dim = target_shape[target_ndims - i - 1]; + SizeT src_dim = src_shape[src_ndims - i - 1]; + if (!(src_dim == 1 || target_dim == src_dim)) + { + return false; + } + } + return true; +} + +/** + * @brief Performs `np.broadcast_shapes()` + * + * @param num_shapes Number of entries in `shapes` + * @param shapes The list of shape to do `np.broadcast_shapes` on. + * @param dst_ndims The length of `dst_shape`. + * `dst_ndims` must be `max([shape.ndims for shape in shapes])`, but the caller has to calculate it/provide it. + * for this function since they should already know in order to allocate `dst_shape` in the first place. + * @param dst_shape The resulting shape. Must be pre-allocated by the caller. This function calculate the result + * of `np.broadcast_shapes` and write it here. + */ +template +void broadcast_shapes(SizeT num_shapes, const ShapeEntry *shapes, SizeT dst_ndims, SizeT *dst_shape) +{ + for (SizeT dst_axis = 0; dst_axis < dst_ndims; dst_axis++) + { + dst_shape[dst_axis] = 1; + } + +#ifdef IRRT_DEBUG_ASSERT + SizeT max_ndims_found = 0; +#endif + + for (SizeT i = 0; i < num_shapes; i++) + { + ShapeEntry entry = shapes[i]; + + // Check pre-condition: `dst_ndims` must be `max([shape.ndims for shape in shapes])` + debug_assert(SizeT, entry.ndims <= dst_ndims); + +#ifdef IRRT_DEBUG_ASSERT + max_ndims_found = max(max_ndims_found, entry.ndims); +#endif + + for (SizeT j = 0; j < entry.ndims; j++) + { + SizeT entry_axis = entry.ndims - j - 1; + SizeT dst_axis = dst_ndims - j - 1; + + SizeT entry_dim = entry.shape[entry_axis]; + SizeT dst_dim = dst_shape[dst_axis]; + + if (dst_dim == 1) + { + dst_shape[dst_axis] = entry_dim; + } + else if (entry_dim == 1 || entry_dim == dst_dim) + { + // Do nothing + } + else + { + raise_exception(SizeT, EXN_VALUE_ERROR, + "shape mismatch: objects cannot be broadcast " + "to a single shape.", + NO_PARAM, NO_PARAM, NO_PARAM); + } + } + } + + // Check pre-condition: `dst_ndims` must be `max([shape.ndims for shape in shapes])` + debug_assert_eq(SizeT, max_ndims_found, dst_ndims); +} + +/** + * @brief Perform `np.broadcast_to(, )` and appropriate assertions. + * + * This function attempts to broadcast `src_ndarray` to a new shape defined by `dst_ndarray.shape`, + * and return the result by modifying `dst_ndarray`. + * + * # Notes on `dst_ndarray` + * The caller is responsible for allocating space for the resulting ndarray. + * Here is what this function expects from `dst_ndarray` when called: + * - `dst_ndarray->data` does not have to be initialized. + * - `dst_ndarray->itemsize` does not have to be initialized. + * - `dst_ndarray->ndims` must be initialized, determining the length of `dst_ndarray->shape` + * - `dst_ndarray->shape` must be allocated, and must contain the desired target broadcast shape. + * - `dst_ndarray->strides` must be allocated, through it can contain uninitialized values. + * When this function call ends: + * - `dst_ndarray->data` is set to `src_ndarray->data` (`dst_ndarray` is just a view to `src_ndarray`) + * - `dst_ndarray->itemsize` is set to `src_ndarray->itemsize` + * - `dst_ndarray->ndims` is unchanged. + * - `dst_ndarray->shape` is unchanged. + * - `dst_ndarray->strides` is updated accordingly by how ndarray broadcast_to works. + */ +template void broadcast_to(const NDArray *src_ndarray, NDArray *dst_ndarray) +{ + if (!ndarray::broadcast::can_broadcast_shape_to(dst_ndarray->ndims, dst_ndarray->shape, src_ndarray->ndims, + src_ndarray->shape)) + { + raise_exception(SizeT, EXN_VALUE_ERROR, "operands could not be broadcast together", NO_PARAM, NO_PARAM, + NO_PARAM); + } + + dst_ndarray->data = src_ndarray->data; + dst_ndarray->itemsize = src_ndarray->itemsize; + + for (SizeT i = 0; i < dst_ndarray->ndims; i++) + { + SizeT src_axis = src_ndarray->ndims - i - 1; + SizeT dst_axis = dst_ndarray->ndims - i - 1; + if (src_axis < 0 || (src_ndarray->shape[src_axis] == 1 && dst_ndarray->shape[dst_axis] != 1)) + { + // Freeze the steps in-place + dst_ndarray->strides[dst_axis] = 0; + } + else + { + dst_ndarray->strides[dst_axis] = src_ndarray->strides[src_axis]; + } + } +} +} // namespace broadcast +} // namespace ndarray +} // namespace + +extern "C" +{ + using namespace ndarray::broadcast; + + void __nac3_ndarray_broadcast_to(NDArray *src_ndarray, NDArray *dst_ndarray) + { + broadcast_to(src_ndarray, dst_ndarray); + } + + void __nac3_ndarray_broadcast_to64(NDArray *src_ndarray, NDArray *dst_ndarray) + { + broadcast_to(src_ndarray, dst_ndarray); + } + + void __nac3_ndarray_broadcast_shapes(int32_t num_shapes, const ShapeEntry *shapes, int32_t dst_ndims, + int32_t *dst_shape) + { + broadcast_shapes(num_shapes, shapes, dst_ndims, dst_shape); + } + + void __nac3_ndarray_broadcast_shapes64(int64_t num_shapes, const ShapeEntry *shapes, int64_t dst_ndims, + int64_t *dst_shape) + { + broadcast_shapes(num_shapes, shapes, dst_ndims, dst_shape); + } +} \ No newline at end of file diff --git a/nac3core/irrt/irrt/ndarray/transpose.hpp b/nac3core/irrt/irrt/ndarray/transpose.hpp new file mode 100644 index 00000000..1ac73f4d --- /dev/null +++ b/nac3core/irrt/irrt/ndarray/transpose.hpp @@ -0,0 +1,155 @@ +#pragma once + +#include +#include +#include + +/* + * Notes on `np.transpose(, )` + * + * TODO: `axes`, if specified, can actually contain negative indices, + * but it is not documented in numpy. + * + * Supporting it for now. + */ + +namespace +{ +namespace ndarray +{ +namespace transpose +{ +/** + * @brief Do assertions on `` in `np.transpose(, )`. + * + * Note that `np.transpose`'s `` argument is optional. If the argument + * is specified but the user, use this function to do assertions on it. + * + * @param ndims The number of dimensions of `` + * @param num_axes Number of elements in `` as specified by the user. + * This should be equal to `ndims`. If not, a "ValueError: axes don't match array" is thrown. + * @param axes The user specified ``. + */ +template void assert_transpose_axes(SizeT ndims, SizeT num_axes, const SizeT *axes) +{ + if (ndims != num_axes) + { + raise_exception(SizeT, EXN_VALUE_ERROR, "axes don't match array", NO_PARAM, NO_PARAM, NO_PARAM); + } + + // TODO: Optimize this + bool *axe_specified = (bool *)__builtin_alloca(sizeof(bool) * ndims); + for (SizeT i = 0; i < ndims; i++) + axe_specified[i] = false; + + for (SizeT i = 0; i < ndims; i++) + { + SizeT axis = slice::resolve_index_in_length(ndims, axes[i]); + if (axis == slice::OUT_OF_BOUNDS) + { + // TODO: numpy actually throws a `numpy.exceptions.AxisError` + raise_exception(SizeT, EXN_VALUE_ERROR, "axis {0} is out of bounds for array of dimension {1}", axis, ndims, + NO_PARAM); + } + + if (axe_specified[axis]) + { + raise_exception(SizeT, EXN_VALUE_ERROR, "repeated axis in transpose", NO_PARAM, NO_PARAM, NO_PARAM); + } + + axe_specified[axis] = true; + } +} + +/** + * @brief Create a transpose view of `src_ndarray` and perform proper assertions. + * + * This function is very similar to doing `dst_ndarray = np.transpose(src_ndarray, )`. + * If `` is supposed to be `None`, caller can pass in a `nullptr` to ``. + * + * The transpose view created is returned by modifying `dst_ndarray`. + * + * The caller is responsible for setting up `dst_ndarray` before calling this function. + * Here is what this function expects from `dst_ndarray` when called: + * - `dst_ndarray->data` does not have to be initialized. + * - `dst_ndarray->itemsize` does not have to be initialized. + * - `dst_ndarray->ndims` must be initialized, must be equal to `src_ndarray->ndims`. + * - `dst_ndarray->shape` must be allocated, through it can contain uninitialized values. + * - `dst_ndarray->strides` must be allocated, through it can contain uninitialized values. + * When this function call ends: + * - `dst_ndarray->data` is set to `src_ndarray->data` (`dst_ndarray` is just a view to `src_ndarray`) + * - `dst_ndarray->itemsize` is set to `src_ndarray->itemsize` + * - `dst_ndarray->ndims` is unchanged + * - `dst_ndarray->shape` is updated according to how `np.transpose` works + * - `dst_ndarray->strides` is updated according to how `np.transpose` works + * + * @param src_ndarray The NDArray to build a transpose view on + * @param dst_ndarray The resulting NDArray after transpose. Further details in the comments above, + * @param num_axes Number of elements in axes. Unused if `axes` is nullptr. + * @param axes Axes permutation. Set it to `nullptr` if `` is `None`. + */ +template +void transpose(const NDArray *src_ndarray, NDArray *dst_ndarray, SizeT num_axes, const SizeT *axes) +{ + debug_assert_eq(SizeT, src_ndarray->ndims, dst_ndarray->ndims); + const auto ndims = src_ndarray->ndims; + + if (axes != nullptr) + assert_transpose_axes(ndims, num_axes, axes); + + dst_ndarray->data = src_ndarray->data; + dst_ndarray->itemsize = src_ndarray->itemsize; + + // Check out https://ajcr.net/stride-guide-part-2/ to see how `np.transpose` works behind the scenes. + if (axes == nullptr) + { + // `np.transpose(, axes=None)` + + /* + * Minor note: `np.transpose(, axes=None)` is equivalent to + * `np.transpose(, axes=[N-1, N-2, ..., 0])` - basically it + * is reversing the order of strides and shape. + * + * This is a fast implementation to handle this special (but very common) case. + */ + + for (SizeT axis = 0; axis < ndims; axis++) + { + dst_ndarray->shape[axis] = src_ndarray->shape[ndims - axis - 1]; + dst_ndarray->strides[axis] = src_ndarray->strides[ndims - axis - 1]; + } + } + else + { + // `np.transpose(, )` + + // Permute strides and shape according to `axes`, while resolving negative indices in `axes` + for (SizeT axis = 0; axis < ndims; axis++) + { + // `i` cannot be OUT_OF_BOUNDS because of assertions + SizeT i = slice::resolve_index_in_length(ndims, axes[axis]); + + dst_ndarray->shape[axis] = src_ndarray->shape[i]; + dst_ndarray->strides[axis] = src_ndarray->strides[i]; + } + } +} +} // namespace transpose +} // namespace ndarray +} // namespace + +extern "C" +{ + using namespace ndarray::transpose; + void __nac3_ndarray_transpose(const NDArray *src_ndarray, NDArray *dst_ndarray, int32_t num_axes, + const int32_t *axes) + { + transpose(src_ndarray, dst_ndarray, num_axes, axes); + } + + void __nac3_ndarray_transpose64(const NDArray *src_ndarray, NDArray *dst_ndarray, + int64_t num_axes, const int64_t *axes) + { + transpose(src_ndarray, dst_ndarray, num_axes, axes); + } +} \ No newline at end of file diff --git a/nac3core/src/codegen/irrt/mod.rs b/nac3core/src/codegen/irrt/mod.rs index 69ec6b6a..55b3a444 100644 --- a/nac3core/src/codegen/irrt/mod.rs +++ b/nac3core/src/codegen/irrt/mod.rs @@ -9,7 +9,7 @@ use super::{ model::*, object::{ list::List, - ndarray::{indexing::NDIndex, nditer::NDIter, NDArray}, + ndarray::{broadcast::ShapeEntry, indexing::NDIndex, nditer::NDIter, NDArray}, }, CodeGenContext, CodeGenerator, }; @@ -1183,3 +1183,47 @@ pub fn call_nac3_ndarray_reshape_resolve_and_check_new_shape<'ctx, G: CodeGenera .arg(new_shape) .returning_void(); } + +pub fn call_nac3_ndarray_broadcast_to<'ctx, G: CodeGenerator + ?Sized>( + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + src_ndarray: Instance<'ctx, Ptr>>, + dst_ndarray: Instance<'ctx, Ptr>>, +) { + let name = get_sizet_dependent_function_name(generator, ctx, "__nac3_ndarray_broadcast_to"); + CallFunction::begin(generator, ctx, &name).arg(src_ndarray).arg(dst_ndarray).returning_void(); +} + +pub fn call_nac3_ndarray_broadcast_shapes<'ctx, G: CodeGenerator + ?Sized>( + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + num_shape_entries: Instance<'ctx, Int>, + shape_entries: Instance<'ctx, Ptr>>, + dst_ndims: Instance<'ctx, Int>, + dst_shape: Instance<'ctx, Ptr>>, +) { + let name = get_sizet_dependent_function_name(generator, ctx, "__nac3_ndarray_broadcast_shapes"); + CallFunction::begin(generator, ctx, &name) + .arg(num_shape_entries) + .arg(shape_entries) + .arg(dst_ndims) + .arg(dst_shape) + .returning_void(); +} + +pub fn call_nac3_ndarray_transpose<'ctx, G: CodeGenerator + ?Sized>( + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + src_ndarray: Instance<'ctx, Ptr>>, + dst_ndarray: Instance<'ctx, Ptr>>, + num_axes: Instance<'ctx, Int>, + axes: Instance<'ctx, Ptr>>, +) { + let name = get_sizet_dependent_function_name(generator, ctx, "__nac3_ndarray_transpose"); + CallFunction::begin(generator, ctx, &name) + .arg(src_ndarray) + .arg(dst_ndarray) + .arg(num_axes) + .arg(axes) + .returning_void(); +} diff --git a/nac3core/src/codegen/numpy.rs b/nac3core/src/codegen/numpy.rs index a1e69a18..afc97d17 100644 --- a/nac3core/src/codegen/numpy.rs +++ b/nac3core/src/codegen/numpy.rs @@ -1989,112 +1989,6 @@ pub fn gen_ndarray_fill<'ctx>( Ok(()) } -/// Generates LLVM IR for `ndarray.transpose`. -pub fn ndarray_transpose<'ctx, G: CodeGenerator + ?Sized>( - generator: &mut G, - ctx: &mut CodeGenContext<'ctx, '_>, - x1: (Type, BasicValueEnum<'ctx>), -) -> Result, String> { - const FN_NAME: &str = "ndarray_transpose"; - let (x1_ty, x1) = x1; - let llvm_usize = generator.get_size_type(ctx.ctx); - - if let BasicValueEnum::PointerValue(n1) = x1 { - let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty); - let n1 = NDArrayValue::from_ptr_val(n1, llvm_usize, None); - let n_sz = call_ndarray_calc_size(generator, ctx, &n1.dim_sizes(), (None, None)); - - // Dimensions are reversed in the transposed array - let out = create_ndarray_dyn_shape( - generator, - ctx, - elem_ty, - &n1, - |_, ctx, n| Ok(n.load_ndims(ctx)), - |generator, ctx, n, idx| { - let new_idx = ctx.builder.build_int_sub(n.load_ndims(ctx), idx, "").unwrap(); - let new_idx = ctx - .builder - .build_int_sub(new_idx, new_idx.get_type().const_int(1, false), "") - .unwrap(); - unsafe { Ok(n.dim_sizes().get_typed_unchecked(ctx, generator, &new_idx, None)) } - }, - ) - .unwrap(); - - gen_for_callback_incrementing( - generator, - ctx, - None, - llvm_usize.const_zero(), - (n_sz, false), - |generator, ctx, _, idx| { - let elem = unsafe { n1.data().get_unchecked(ctx, generator, &idx, None) }; - - let new_idx = generator.gen_var_alloc(ctx, llvm_usize.into(), None)?; - let rem_idx = generator.gen_var_alloc(ctx, llvm_usize.into(), None)?; - ctx.builder.build_store(new_idx, llvm_usize.const_zero()).unwrap(); - ctx.builder.build_store(rem_idx, idx).unwrap(); - - // Incrementally calculate the new index in the transposed array - // For each index, we first decompose it into the n-dims and use those to reconstruct the new index - // The formula used for indexing is: - // idx = dim_n * ( ... (dim2 * (dim0 * dim1) + dim1) + dim2 ... ) + dim_n - gen_for_callback_incrementing( - generator, - ctx, - None, - llvm_usize.const_zero(), - (n1.load_ndims(ctx), false), - |generator, ctx, _, ndim| { - let ndim_rev = - ctx.builder.build_int_sub(n1.load_ndims(ctx), ndim, "").unwrap(); - let ndim_rev = ctx - .builder - .build_int_sub(ndim_rev, llvm_usize.const_int(1, false), "") - .unwrap(); - let dim = unsafe { - n1.dim_sizes().get_typed_unchecked(ctx, generator, &ndim_rev, None) - }; - - let rem_idx_val = - ctx.builder.build_load(rem_idx, "").unwrap().into_int_value(); - let new_idx_val = - ctx.builder.build_load(new_idx, "").unwrap().into_int_value(); - - let add_component = - ctx.builder.build_int_unsigned_rem(rem_idx_val, dim, "").unwrap(); - let rem_idx_val = - ctx.builder.build_int_unsigned_div(rem_idx_val, dim, "").unwrap(); - - let new_idx_val = ctx.builder.build_int_mul(new_idx_val, dim, "").unwrap(); - let new_idx_val = - ctx.builder.build_int_add(new_idx_val, add_component, "").unwrap(); - - ctx.builder.build_store(rem_idx, rem_idx_val).unwrap(); - ctx.builder.build_store(new_idx, new_idx_val).unwrap(); - - Ok(()) - }, - llvm_usize.const_int(1, false), - )?; - - let new_idx_val = ctx.builder.build_load(new_idx, "").unwrap().into_int_value(); - unsafe { out.data().set_unchecked(ctx, generator, &new_idx_val, elem) }; - Ok(()) - }, - llvm_usize.const_int(1, false), - )?; - - Ok(out.as_base_value().into()) - } else { - unreachable!( - "{FN_NAME}() not supported for '{}'", - format!("'{}'", ctx.unifier.stringify(x1_ty)) - ) - } -} - /// Generates LLVM IR for `ndarray.dot`. /// Calculate inner product of two vectors or literals /// For matrix multiplication use `np_matmul` diff --git a/nac3core/src/codegen/object/ndarray/broadcast.rs b/nac3core/src/codegen/object/ndarray/broadcast.rs new file mode 100644 index 00000000..87c34a35 --- /dev/null +++ b/nac3core/src/codegen/object/ndarray/broadcast.rs @@ -0,0 +1,135 @@ +use itertools::Itertools; + +use crate::codegen::{ + irrt::{call_nac3_ndarray_broadcast_shapes, call_nac3_ndarray_broadcast_to}, + model::*, + CodeGenContext, CodeGenerator, +}; + +use super::NDArrayObject; + +/// Fields of [`ShapeEntry`] +pub struct ShapeEntryFields<'ctx, F: FieldTraversal<'ctx>> { + pub ndims: F::Out>, + pub shape: F::Out>>, +} + +/// An IRRT structure used in broadcasting. +#[derive(Debug, Clone, Copy, Default)] +pub struct ShapeEntry; + +impl<'ctx> StructKind<'ctx> for ShapeEntry { + type Fields> = ShapeEntryFields<'ctx, F>; + + fn traverse_fields>(&self, traversal: &mut F) -> Self::Fields { + Self::Fields { ndims: traversal.add_auto("ndims"), shape: traversal.add_auto("shape") } + } +} + +impl<'ctx> NDArrayObject<'ctx> { + /// Create a broadcast view on this ndarray with a target shape. + /// + /// The input shape will be checked to make sure that it contains no negative values. + /// + /// * `target_ndims` - The ndims type after broadcasting to the given shape. + /// The caller has to figure this out for this function. + /// * `target_shape` - An array pointer pointing to the target shape. + #[must_use] + pub fn broadcast_to( + &self, + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + target_ndims: u64, + target_shape: Instance<'ctx, Ptr>>, + ) -> Self { + let broadcast_ndarray = NDArrayObject::alloca(generator, ctx, self.dtype, target_ndims); + broadcast_ndarray.copy_shape_from_array(generator, ctx, target_shape); + + call_nac3_ndarray_broadcast_to(generator, ctx, self.instance, broadcast_ndarray.instance); + broadcast_ndarray + } +} +/// A result produced by [`broadcast_all_ndarrays`] +#[derive(Debug, Clone)] +pub struct BroadcastAllResult<'ctx> { + /// The statically known `ndims` of the broadcast result. + pub ndims: u64, + /// The broadcasting shape. + pub shape: Instance<'ctx, Ptr>>, + /// Broadcasted views on the inputs. + /// + /// All of them will have `shape` [`BroadcastAllResult::shape`] and + /// `ndims` [`BroadcastAllResult::ndims`]. The length of the vector + /// is the same as the input. + pub ndarrays: Vec>, +} + +/// Helper function to call `call_nac3_ndarray_broadcast_shapes` +fn broadcast_shapes<'ctx, G: CodeGenerator + ?Sized>( + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + in_shape_entries: &[(Instance<'ctx, Ptr>>, u64)], // (shape, shape's length/ndims) + broadcast_ndims: u64, + broadcast_shape: Instance<'ctx, Ptr>>, +) { + // Prepare input shape entries to be passed to `call_nac3_ndarray_broadcast_shapes`. + let num_shape_entries = + Int(SizeT).const_int(generator, ctx.ctx, u64::try_from(in_shape_entries.len()).unwrap()); + let shape_entries = Struct(ShapeEntry).array_alloca(generator, ctx, num_shape_entries.value); + for (i, (in_shape, in_ndims)) in in_shape_entries.iter().enumerate() { + let pshape_entry = shape_entries.offset_const(ctx, i as u64); + + let in_ndims = Int(SizeT).const_int(generator, ctx.ctx, *in_ndims); + pshape_entry.set(ctx, |f| f.ndims, in_ndims); + + pshape_entry.set(ctx, |f| f.shape, *in_shape); + } + + let broadcast_ndims = Int(SizeT).const_int(generator, ctx.ctx, broadcast_ndims); + call_nac3_ndarray_broadcast_shapes( + generator, + ctx, + num_shape_entries, + shape_entries, + broadcast_ndims, + broadcast_shape, + ); +} + +impl<'ctx> NDArrayObject<'ctx> { + /// Broadcast all ndarrays according to `np.broadcast()` and return a [`BroadcastAllResult`] + /// containing all the information of the result of the broadcast operation. + pub fn broadcast( + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + ndarrays: &[Self], + ) -> BroadcastAllResult<'ctx> { + assert!(!ndarrays.is_empty()); + + // Infer the broadcast output ndims. + let broadcast_ndims_int = ndarrays.iter().map(|ndarray| ndarray.ndims).max().unwrap(); + + let broadcast_ndims = Int(SizeT).const_int(generator, ctx.ctx, broadcast_ndims_int); + let broadcast_shape = Int(SizeT).array_alloca(generator, ctx, broadcast_ndims.value); + + let shape_entries = ndarrays + .iter() + .map(|ndarray| (ndarray.instance.get(generator, ctx, |f| f.shape), ndarray.ndims)) + .collect_vec(); + broadcast_shapes(generator, ctx, &shape_entries, broadcast_ndims_int, broadcast_shape); + + // Broadcast all the inputs to shape `dst_shape`. + let broadcast_ndarrays: Vec<_> = ndarrays + .iter() + .map(|ndarray| { + ndarray.broadcast_to(generator, ctx, broadcast_ndims_int, broadcast_shape) + }) + .collect_vec(); + + BroadcastAllResult { + ndims: broadcast_ndims_int, + shape: broadcast_shape, + ndarrays: broadcast_ndarrays, + } + } +} diff --git a/nac3core/src/codegen/object/ndarray/mod.rs b/nac3core/src/codegen/object/ndarray/mod.rs index e3c868ff..613123ad 100644 --- a/nac3core/src/codegen/object/ndarray/mod.rs +++ b/nac3core/src/codegen/object/ndarray/mod.rs @@ -1,4 +1,5 @@ pub mod array; +pub mod broadcast; pub mod factory; pub mod indexing; pub mod nditer; diff --git a/nac3core/src/codegen/object/ndarray/view.rs b/nac3core/src/codegen/object/ndarray/view.rs index f1d2afd4..c6a1be2f 100644 --- a/nac3core/src/codegen/object/ndarray/view.rs +++ b/nac3core/src/codegen/object/ndarray/view.rs @@ -1,6 +1,7 @@ use crate::codegen::{ - irrt::call_nac3_ndarray_reshape_resolve_and_check_new_shape, model::*, CodeGenContext, - CodeGenerator, + irrt::{call_nac3_ndarray_reshape_resolve_and_check_new_shape, call_nac3_ndarray_transpose}, + model::*, + CodeGenContext, CodeGenerator, }; use super::{indexing::RustNDIndex, NDArrayObject}; @@ -86,4 +87,33 @@ impl<'ctx> NDArrayObject<'ctx> { dst_ndarray } + + /// Create a transposed view on this ndarray like `np.transpose(, = None)`. + /// * `axes` - If specified, should be an array of the permutation (negative indices are **allowed**). + #[must_use] + pub fn transpose( + &self, + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + axes: Option>>>, + ) -> Self { + // Define models + let transposed_ndarray = NDArrayObject::alloca(generator, ctx, self.dtype, self.ndims); + + let num_axes = self.ndims_llvm(generator, ctx.ctx); + + // `axes = nullptr` if `axes` is unspecified. + let axes = axes.unwrap_or_else(|| Ptr(Int(SizeT)).nullptr(generator, ctx.ctx)); + + call_nac3_ndarray_transpose( + generator, + ctx, + self.instance, + transposed_ndarray.instance, + num_axes, + axes, + ); + + transposed_ndarray + } } diff --git a/nac3core/src/toplevel/builtins.rs b/nac3core/src/toplevel/builtins.rs index 6c5ce52c..67cfc456 100644 --- a/nac3core/src/toplevel/builtins.rs +++ b/nac3core/src/toplevel/builtins.rs @@ -521,7 +521,7 @@ impl<'a> BuiltinBuilder<'a> { self.build_ndarray_property_getter_function(prim) } - PrimDef::FunNpTranspose | PrimDef::FunNpReshape => { + PrimDef::FunNpBroadcastTo | PrimDef::FunNpTranspose | PrimDef::FunNpReshape => { self.build_ndarray_view_function(prim) } @@ -1470,7 +1470,10 @@ impl<'a> BuiltinBuilder<'a> { /// Build np/sp functions that take as input `NDArray` only fn build_ndarray_view_function(&mut self, prim: PrimDef) -> TopLevelDef { - debug_assert_prim_is_allowed(prim, &[PrimDef::FunNpTranspose, PrimDef::FunNpReshape]); + debug_assert_prim_is_allowed( + prim, + &[PrimDef::FunNpBroadcastTo, PrimDef::FunNpTranspose, PrimDef::FunNpReshape], + ); let in_ndarray_ty = self.unifier.get_fresh_var_with_range( &[self.primitives.ndarray], @@ -1479,18 +1482,26 @@ impl<'a> BuiltinBuilder<'a> { ); match prim { - PrimDef::FunNpTranspose => create_fn_by_codegen( - self.unifier, - &into_var_map([in_ndarray_ty]), - prim.name(), - in_ndarray_ty.ty, - &[(in_ndarray_ty.ty, "x")], - Box::new(move |ctx, _, fun, args, generator| { - let arg_ty = fun.0.args[0].ty; - let arg_val = args[0].1.clone().to_basic_value_enum(ctx, generator, arg_ty)?; - Ok(Some(ndarray_transpose(generator, ctx, (arg_ty, arg_val))?)) - }), - ), + PrimDef::FunNpTranspose => { + create_fn_by_codegen( + self.unifier, + &VarMap::new(), + prim.name(), + in_ndarray_ty.ty, + &[(in_ndarray_ty.ty, "x")], + Box::new(move |ctx, _, fun, args, generator| { + let arg_ty = fun.0.args[0].ty; + let arg_val = + args[0].1.clone().to_basic_value_enum(ctx, generator, arg_ty)?; + + let arg = AnyObject { ty: arg_ty, value: arg_val }; + let ndarray = NDArrayObject::from_object(generator, ctx, arg); + + let ndarray = ndarray.transpose(generator, ctx, None); // TODO: Add axes argument + Ok(Some(ndarray.instance.value.as_basic_value_enum())) + }), + ) + } // NOTE: on `ndarray_factory_fn_shape_arg_tvar` and // the `param_ty` for `create_fn_by_codegen`. @@ -1498,7 +1509,7 @@ impl<'a> BuiltinBuilder<'a> { // Similar to `build_ndarray_from_shape_factory_function` we delegate the responsibility of typechecking // to [`typecheck::type_inferencer::Inferencer::fold_numpy_function_call_shape_argument`], // and use a dummy [`TypeVar`] `ndarray_factory_fn_shape_arg_tvar` as a placeholder for `param_ty`. - PrimDef::FunNpReshape => { + PrimDef::FunNpBroadcastTo | PrimDef::FunNpReshape => { let ret_ty = self.unifier.get_dummy_var().ty; // Handled by special holding create_fn_by_codegen( @@ -1529,7 +1540,15 @@ impl<'a> BuiltinBuilder<'a> { let (_, ndims) = unpack_ndarray_var_tys(&mut ctx.unifier, fun.0.ret); let ndims = extract_ndims(&ctx.unifier, ndims); - let new_ndarray = ndarray.reshape_or_copy(generator, ctx, ndims, shape); + let new_ndarray = match prim { + PrimDef::FunNpBroadcastTo => { + ndarray.broadcast_to(generator, ctx, ndims, shape) + } + PrimDef::FunNpReshape => { + ndarray.reshape_or_copy(generator, ctx, ndims, shape) + } + _ => unreachable!(), + }; Ok(Some(new_ndarray.instance.value.as_basic_value_enum())) }), ) diff --git a/nac3core/src/toplevel/helper.rs b/nac3core/src/toplevel/helper.rs index bd9d70ad..2533489a 100644 --- a/nac3core/src/toplevel/helper.rs +++ b/nac3core/src/toplevel/helper.rs @@ -58,6 +58,7 @@ pub enum PrimDef { FunNpStrides, // NumPy ndarray view functions + FunNpBroadcastTo, FunNpTranspose, FunNpReshape, @@ -251,6 +252,7 @@ impl PrimDef { PrimDef::FunNpStrides => fun("np_strides", None), // NumPy NDArray view functions + PrimDef::FunNpBroadcastTo => fun("np_broadcast_to", None), PrimDef::FunNpTranspose => fun("np_transpose", None), PrimDef::FunNpReshape => fun("np_reshape", None), diff --git a/nac3standalone/demo/interpret_demo.py b/nac3standalone/demo/interpret_demo.py index 56c6126d..8784ce53 100755 --- a/nac3standalone/demo/interpret_demo.py +++ b/nac3standalone/demo/interpret_demo.py @@ -180,6 +180,7 @@ def patch(module): module.np_array = np.array # NumPy NDArray view functions + module.np_broadcast_to = np.broadcast_to module.np_transpose = np.transpose module.np_reshape = np.reshape