forked from M-Labs/nac3
1
0
Fork 0

core/ndstrides: implement broadcasting & np_broadcast_to()

This commit is contained in:
lyken 2024-08-22 09:59:58 +08:00
parent 34709cf076
commit 25be81cc83
No known key found for this signature in database
GPG Key ID: 3BD5FC6AC8325DD8
9 changed files with 375 additions and 6 deletions

View File

@ -4,6 +4,7 @@
#include <irrt/math_util.hpp>
#include <irrt/ndarray/array.hpp>
#include <irrt/ndarray/basic.hpp>
#include <irrt/ndarray/broadcast.hpp>
#include <irrt/ndarray/def.hpp>
#include <irrt/ndarray/indexing.hpp>
#include <irrt/ndarray/iter.hpp>

View File

@ -0,0 +1,188 @@
#pragma once
#include <irrt/int_types.hpp>
#include <irrt/ndarray/def.hpp>
#include <irrt/slice.hpp>
namespace
{
template <typename SizeT> struct ShapeEntry
{
SizeT ndims;
SizeT *shape;
};
} // namespace
namespace
{
namespace ndarray
{
namespace broadcast
{
/**
* @brief Return true if `src_shape` can broadcast to `dst_shape`.
*
* See https://numpy.org/doc/stable/user/basics.broadcasting.html
*/
template <typename SizeT>
bool can_broadcast_shape_to(SizeT target_ndims, const SizeT *target_shape, SizeT src_ndims, const SizeT *src_shape)
{
if (src_ndims > target_ndims)
{
return false;
}
for (SizeT i = 0; i < src_ndims; i++)
{
SizeT target_dim = target_shape[target_ndims - i - 1];
SizeT src_dim = src_shape[src_ndims - i - 1];
if (!(src_dim == 1 || target_dim == src_dim))
{
return false;
}
}
return true;
}
/**
* @brief Performs `np.broadcast_shapes(<shapes>)`
*
* @param num_shapes Number of entries in `shapes`
* @param shapes The list of shape to do `np.broadcast_shapes` on.
* @param dst_ndims The length of `dst_shape`.
* `dst_ndims` must be `max([shape.ndims for shape in shapes])`, but the caller has to calculate it/provide it.
* for this function since they should already know in order to allocate `dst_shape` in the first place.
* @param dst_shape The resulting shape. Must be pre-allocated by the caller. This function calculate the result
* of `np.broadcast_shapes` and write it here.
*/
template <typename SizeT>
void broadcast_shapes(SizeT num_shapes, const ShapeEntry<SizeT> *shapes, SizeT dst_ndims, SizeT *dst_shape)
{
for (SizeT dst_axis = 0; dst_axis < dst_ndims; dst_axis++)
{
dst_shape[dst_axis] = 1;
}
#ifdef IRRT_DEBUG_ASSERT
SizeT max_ndims_found = 0;
#endif
for (SizeT i = 0; i < num_shapes; i++)
{
ShapeEntry<SizeT> entry = shapes[i];
// Check pre-condition: `dst_ndims` must be `max([shape.ndims for shape in shapes])`
debug_assert(SizeT, entry.ndims <= dst_ndims);
#ifdef IRRT_DEBUG_ASSERT
max_ndims_found = max(max_ndims_found, entry.ndims);
#endif
for (SizeT j = 0; j < entry.ndims; j++)
{
SizeT entry_axis = entry.ndims - j - 1;
SizeT dst_axis = dst_ndims - j - 1;
SizeT entry_dim = entry.shape[entry_axis];
SizeT dst_dim = dst_shape[dst_axis];
if (dst_dim == 1)
{
dst_shape[dst_axis] = entry_dim;
}
else if (entry_dim == 1 || entry_dim == dst_dim)
{
// Do nothing
}
else
{
raise_exception(SizeT, EXN_VALUE_ERROR,
"shape mismatch: objects cannot be broadcast "
"to a single shape.",
NO_PARAM, NO_PARAM, NO_PARAM);
}
}
}
// Check pre-condition: `dst_ndims` must be `max([shape.ndims for shape in shapes])`
debug_assert_eq(SizeT, max_ndims_found, dst_ndims);
}
/**
* @brief Perform `np.broadcast_to(<ndarray>, <target_shape>)` and appropriate assertions.
*
* This function attempts to broadcast `src_ndarray` to a new shape defined by `dst_ndarray.shape`,
* and return the result by modifying `dst_ndarray`.
*
* # Notes on `dst_ndarray`
* The caller is responsible for allocating space for the resulting ndarray.
* Here is what this function expects from `dst_ndarray` when called:
* - `dst_ndarray->data` does not have to be initialized.
* - `dst_ndarray->itemsize` does not have to be initialized.
* - `dst_ndarray->ndims` must be initialized, determining the length of `dst_ndarray->shape`
* - `dst_ndarray->shape` must be allocated, and must contain the desired target broadcast shape.
* - `dst_ndarray->strides` must be allocated, through it can contain uninitialized values.
* When this function call ends:
* - `dst_ndarray->data` is set to `src_ndarray->data` (`dst_ndarray` is just a view to `src_ndarray`)
* - `dst_ndarray->itemsize` is set to `src_ndarray->itemsize`
* - `dst_ndarray->ndims` is unchanged.
* - `dst_ndarray->shape` is unchanged.
* - `dst_ndarray->strides` is updated accordingly by how ndarray broadcast_to works.
*/
template <typename SizeT> void broadcast_to(const NDArray<SizeT> *src_ndarray, NDArray<SizeT> *dst_ndarray)
{
if (!ndarray::broadcast::can_broadcast_shape_to(dst_ndarray->ndims, dst_ndarray->shape, src_ndarray->ndims,
src_ndarray->shape))
{
raise_exception(SizeT, EXN_VALUE_ERROR, "operands could not be broadcast together", NO_PARAM, NO_PARAM,
NO_PARAM);
}
dst_ndarray->data = src_ndarray->data;
dst_ndarray->itemsize = src_ndarray->itemsize;
for (SizeT i = 0; i < dst_ndarray->ndims; i++)
{
SizeT src_axis = src_ndarray->ndims - i - 1;
SizeT dst_axis = dst_ndarray->ndims - i - 1;
if (src_axis < 0 || (src_ndarray->shape[src_axis] == 1 && dst_ndarray->shape[dst_axis] != 1))
{
// Freeze the steps in-place
dst_ndarray->strides[dst_axis] = 0;
}
else
{
dst_ndarray->strides[dst_axis] = src_ndarray->strides[src_axis];
}
}
}
} // namespace broadcast
} // namespace ndarray
} // namespace
extern "C"
{
using namespace ndarray::broadcast;
void __nac3_ndarray_broadcast_to(NDArray<int32_t> *src_ndarray, NDArray<int32_t> *dst_ndarray)
{
broadcast_to(src_ndarray, dst_ndarray);
}
void __nac3_ndarray_broadcast_to64(NDArray<int64_t> *src_ndarray, NDArray<int64_t> *dst_ndarray)
{
broadcast_to(src_ndarray, dst_ndarray);
}
void __nac3_ndarray_broadcast_shapes(int32_t num_shapes, const ShapeEntry<int32_t> *shapes, int32_t dst_ndims,
int32_t *dst_shape)
{
broadcast_shapes(num_shapes, shapes, dst_ndims, dst_shape);
}
void __nac3_ndarray_broadcast_shapes64(int64_t num_shapes, const ShapeEntry<int64_t> *shapes, int64_t dst_ndims,
int64_t *dst_shape)
{
broadcast_shapes(num_shapes, shapes, dst_ndims, dst_shape);
}
}

View File

@ -9,7 +9,7 @@ use super::{
model::*,
object::{
list::List,
ndarray::{indexing::NDIndex, nditer::NDIter, NDArray},
ndarray::{broadcast::ShapeEntry, indexing::NDIndex, nditer::NDIter, NDArray},
},
CodeGenContext, CodeGenerator,
};
@ -1187,3 +1187,30 @@ pub fn call_nac3_ndarray_reshape_resolve_and_check_new_shape<'ctx, G: CodeGenera
.arg(new_shape)
.returning_void();
}
pub fn call_nac3_ndarray_broadcast_to<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
src_ndarray: Instance<'ctx, Ptr<Struct<NDArray>>>,
dst_ndarray: Instance<'ctx, Ptr<Struct<NDArray>>>,
) {
let name = get_sizet_dependent_function_name(generator, ctx, "__nac3_ndarray_broadcast_to");
CallFunction::begin(generator, ctx, &name).arg(src_ndarray).arg(dst_ndarray).returning_void();
}
pub fn call_nac3_ndarray_broadcast_shapes<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
num_shape_entries: Instance<'ctx, Int<SizeT>>,
shape_entries: Instance<'ctx, Ptr<Struct<ShapeEntry>>>,
dst_ndims: Instance<'ctx, Int<SizeT>>,
dst_shape: Instance<'ctx, Ptr<Int<SizeT>>>,
) {
let name = get_sizet_dependent_function_name(generator, ctx, "__nac3_ndarray_broadcast_shapes");
CallFunction::begin(generator, ctx, &name)
.arg(num_shape_entries)
.arg(shape_entries)
.arg(dst_ndims)
.arg(dst_shape)
.returning_void();
}

View File

@ -0,0 +1,135 @@
use itertools::Itertools;
use crate::codegen::{
irrt::{call_nac3_ndarray_broadcast_shapes, call_nac3_ndarray_broadcast_to},
model::*,
CodeGenContext, CodeGenerator,
};
use super::NDArrayObject;
/// Fields of [`ShapeEntry`]
pub struct ShapeEntryFields<'ctx, F: FieldTraversal<'ctx>> {
pub ndims: F::Out<Int<SizeT>>,
pub shape: F::Out<Ptr<Int<SizeT>>>,
}
/// An IRRT structure used in broadcasting.
#[derive(Debug, Clone, Copy, Default)]
pub struct ShapeEntry;
impl<'ctx> StructKind<'ctx> for ShapeEntry {
type Fields<F: FieldTraversal<'ctx>> = ShapeEntryFields<'ctx, F>;
fn traverse_fields<F: FieldTraversal<'ctx>>(&self, traversal: &mut F) -> Self::Fields<F> {
Self::Fields { ndims: traversal.add_auto("ndims"), shape: traversal.add_auto("shape") }
}
}
impl<'ctx> NDArrayObject<'ctx> {
/// Create a broadcast view on this ndarray with a target shape.
///
/// The input shape will be checked to make sure that it contains no negative values.
///
/// * `target_ndims` - The ndims type after broadcasting to the given shape.
/// The caller has to figure this out for this function.
/// * `target_shape` - An array pointer pointing to the target shape.
#[must_use]
pub fn broadcast_to<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
target_ndims: u64,
target_shape: Instance<'ctx, Ptr<Int<SizeT>>>,
) -> Self {
let broadcast_ndarray = NDArrayObject::alloca(generator, ctx, self.dtype, target_ndims);
broadcast_ndarray.copy_shape_from_array(generator, ctx, target_shape);
call_nac3_ndarray_broadcast_to(generator, ctx, self.instance, broadcast_ndarray.instance);
broadcast_ndarray
}
}
/// A result produced by [`broadcast_all_ndarrays`]
#[derive(Debug, Clone)]
pub struct BroadcastAllResult<'ctx> {
/// The statically known `ndims` of the broadcast result.
pub ndims: u64,
/// The broadcasting shape.
pub shape: Instance<'ctx, Ptr<Int<SizeT>>>,
/// Broadcasted views on the inputs.
///
/// All of them will have `shape` [`BroadcastAllResult::shape`] and
/// `ndims` [`BroadcastAllResult::ndims`]. The length of the vector
/// is the same as the input.
pub ndarrays: Vec<NDArrayObject<'ctx>>,
}
/// Helper function to call `call_nac3_ndarray_broadcast_shapes`
fn broadcast_shapes<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
in_shape_entries: &[(Instance<'ctx, Ptr<Int<SizeT>>>, u64)], // (shape, shape's length/ndims)
broadcast_ndims: u64,
broadcast_shape: Instance<'ctx, Ptr<Int<SizeT>>>,
) {
// Prepare input shape entries to be passed to `call_nac3_ndarray_broadcast_shapes`.
let num_shape_entries =
Int(SizeT).const_int(generator, ctx.ctx, u64::try_from(in_shape_entries.len()).unwrap());
let shape_entries = Struct(ShapeEntry).array_alloca(generator, ctx, num_shape_entries.value);
for (i, (in_shape, in_ndims)) in in_shape_entries.iter().enumerate() {
let pshape_entry = shape_entries.offset_const(ctx, i as u64);
let in_ndims = Int(SizeT).const_int(generator, ctx.ctx, *in_ndims);
pshape_entry.set(ctx, |f| f.ndims, in_ndims);
pshape_entry.set(ctx, |f| f.shape, *in_shape);
}
let broadcast_ndims = Int(SizeT).const_int(generator, ctx.ctx, broadcast_ndims);
call_nac3_ndarray_broadcast_shapes(
generator,
ctx,
num_shape_entries,
shape_entries,
broadcast_ndims,
broadcast_shape,
);
}
impl<'ctx> NDArrayObject<'ctx> {
/// Broadcast all ndarrays according to `np.broadcast()` and return a [`BroadcastAllResult`]
/// containing all the information of the result of the broadcast operation.
pub fn broadcast<G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
ndarrays: &[Self],
) -> BroadcastAllResult<'ctx> {
assert!(!ndarrays.is_empty());
// Infer the broadcast output ndims.
let broadcast_ndims_int = ndarrays.iter().map(|ndarray| ndarray.ndims).max().unwrap();
let broadcast_ndims = Int(SizeT).const_int(generator, ctx.ctx, broadcast_ndims_int);
let broadcast_shape = Int(SizeT).array_alloca(generator, ctx, broadcast_ndims.value);
let shape_entries = ndarrays
.iter()
.map(|ndarray| (ndarray.instance.get(generator, ctx, |f| f.shape), ndarray.ndims))
.collect_vec();
broadcast_shapes(generator, ctx, &shape_entries, broadcast_ndims_int, broadcast_shape);
// Broadcast all the inputs to shape `dst_shape`.
let broadcast_ndarrays: Vec<_> = ndarrays
.iter()
.map(|ndarray| {
ndarray.broadcast_to(generator, ctx, broadcast_ndims_int, broadcast_shape)
})
.collect_vec();
BroadcastAllResult {
ndims: broadcast_ndims_int,
shape: broadcast_shape,
ndarrays: broadcast_ndarrays,
}
}
}

View File

@ -1,4 +1,5 @@
pub mod array;
pub mod broadcast;
pub mod factory;
pub mod indexing;
pub mod nditer;

View File

@ -521,7 +521,7 @@ impl<'a> BuiltinBuilder<'a> {
self.build_ndarray_property_getter_function(prim)
}
PrimDef::FunNpTranspose | PrimDef::FunNpReshape => {
PrimDef::FunNpBroadcastTo | PrimDef::FunNpTranspose | PrimDef::FunNpReshape => {
self.build_ndarray_view_function(prim)
}
@ -1469,7 +1469,10 @@ impl<'a> BuiltinBuilder<'a> {
/// Build np/sp functions that take as input `NDArray` only
fn build_ndarray_view_function(&mut self, prim: PrimDef) -> TopLevelDef {
debug_assert_prim_is_allowed(prim, &[PrimDef::FunNpTranspose, PrimDef::FunNpReshape]);
debug_assert_prim_is_allowed(
prim,
&[PrimDef::FunNpBroadcastTo, PrimDef::FunNpTranspose, PrimDef::FunNpReshape],
);
let in_ndarray_ty = self.unifier.get_fresh_var_with_range(
&[self.primitives.ndarray],
@ -1497,7 +1500,10 @@ impl<'a> BuiltinBuilder<'a> {
// Similar to `build_ndarray_from_shape_factory_function` we delegate the responsibility of typechecking
// to [`typecheck::type_inferencer::Inferencer::fold_numpy_function_call_shape_argument`],
// and use a dummy [`TypeVar`] `ndarray_factory_fn_shape_arg_tvar` as a placeholder for `param_ty`.
PrimDef::FunNpReshape => {
PrimDef::FunNpBroadcastTo | PrimDef::FunNpReshape => {
// These two functions have the same function signature.
// Mixed together for convenience.
let ret_ty = self.unifier.get_dummy_var().ty; // Handled by special holding
create_fn_by_codegen(
@ -1528,7 +1534,15 @@ impl<'a> BuiltinBuilder<'a> {
let (_, ndims) = unpack_ndarray_var_tys(&mut ctx.unifier, fun.0.ret);
let ndims = extract_ndims(&ctx.unifier, ndims);
let new_ndarray = ndarray.reshape_or_copy(generator, ctx, ndims, shape);
let new_ndarray = match prim {
PrimDef::FunNpBroadcastTo => {
ndarray.broadcast_to(generator, ctx, ndims, shape)
}
PrimDef::FunNpReshape => {
ndarray.reshape_or_copy(generator, ctx, ndims, shape)
}
_ => unreachable!(),
};
Ok(Some(new_ndarray.instance.value.as_basic_value_enum()))
}),
)

View File

@ -58,6 +58,7 @@ pub enum PrimDef {
FunNpStrides,
// NumPy ndarray view functions
FunNpBroadcastTo,
FunNpTranspose,
FunNpReshape,
@ -251,6 +252,7 @@ impl PrimDef {
PrimDef::FunNpStrides => fun("np_strides", None),
// NumPy NDArray view functions
PrimDef::FunNpBroadcastTo => fun("np_broadcast_to", None),
PrimDef::FunNpTranspose => fun("np_transpose", None),
PrimDef::FunNpReshape => fun("np_reshape", None),

View File

@ -1541,7 +1541,7 @@ impl<'a> Inferencer<'a> {
}));
}
// 2-argument ndarray n-dimensional factory functions
if id == &"np_reshape".into() && args.len() == 2 {
if ["np_reshape".into(), "np_broadcast_to".into()].contains(id) && args.len() == 2 {
let arg0 = self.fold_expr(args.remove(0))?;
let shape_expr = args.remove(0);

View File

@ -180,6 +180,7 @@ def patch(module):
module.np_array = np.array
# NumPy NDArray view functions
module.np_broadcast_to = np.broadcast_to
module.np_transpose = np.transpose
module.np_reshape = np.reshape