// Forked from M-Labs/nac3 (Rust, 140 lines, 5.1 KiB).
use itertools::Itertools;
|
|
|
|
use crate::codegen::{
|
|
irrt::{call_nac3_ndarray_broadcast_shapes, call_nac3_ndarray_broadcast_to},
|
|
model::*,
|
|
CodeGenContext, CodeGenerator,
|
|
};
|
|
|
|
use super::NDArrayObject;
|
|
|
|
/// Fields of [`ShapeEntry`]
|
|
pub struct ShapeEntryFields<'ctx, F: FieldTraversal<'ctx>> {
|
|
pub ndims: F::Output<Int<SizeT>>,
|
|
pub shape: F::Output<Ptr<Int<SizeT>>>,
|
|
}
|
|
|
|
/// An IRRT structure used in broadcasting.
|
|
#[derive(Debug, Clone, Copy, Default)]
|
|
pub struct ShapeEntry;
|
|
|
|
impl<'ctx> StructKind<'ctx> for ShapeEntry {
|
|
type Fields<F: FieldTraversal<'ctx>> = ShapeEntryFields<'ctx, F>;
|
|
|
|
fn iter_fields<F: FieldTraversal<'ctx>>(&self, traversal: &mut F) -> Self::Fields<F> {
|
|
Self::Fields { ndims: traversal.add_auto("ndims"), shape: traversal.add_auto("shape") }
|
|
}
|
|
}
|
|
|
|
impl<'ctx> NDArrayObject<'ctx> {
|
|
/// Create a broadcast view on this ndarray with a target shape.
|
|
///
|
|
/// The input shape will be checked to make sure that it contains no negative values.
|
|
///
|
|
/// * `target_ndims` - The ndims type after broadcasting to the given shape.
|
|
/// The caller has to figure this out for this function.
|
|
/// * `target_shape` - An array pointer pointing to the target shape.
|
|
#[must_use]
|
|
pub fn broadcast_to<G: CodeGenerator + ?Sized>(
|
|
&self,
|
|
generator: &mut G,
|
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
|
target_ndims: u64,
|
|
target_shape: Instance<'ctx, Ptr<Int<SizeT>>>,
|
|
) -> Self {
|
|
let broadcast_ndarray = NDArrayObject::alloca(generator, ctx, self.dtype, target_ndims);
|
|
broadcast_ndarray.copy_shape_from_array(generator, ctx, target_shape);
|
|
|
|
call_nac3_ndarray_broadcast_to(generator, ctx, self.instance, broadcast_ndarray.instance);
|
|
broadcast_ndarray
|
|
}
|
|
}
|
|
/// A result produced by [`broadcast_all_ndarrays`]
|
|
#[derive(Debug, Clone)]
|
|
pub struct BroadcastAllResult<'ctx> {
|
|
/// The statically known `ndims` of the broadcast result.
|
|
pub ndims: u64,
|
|
/// The broadcasting shape.
|
|
pub shape: Instance<'ctx, Ptr<Int<SizeT>>>,
|
|
/// Broadcasted views on the inputs.
|
|
///
|
|
/// All of them will have `shape` [`BroadcastAllResult::shape`] and
|
|
/// `ndims` [`BroadcastAllResult::ndims`]. The length of the vector
|
|
/// is the same as the input.
|
|
pub ndarrays: Vec<NDArrayObject<'ctx>>,
|
|
}
|
|
|
|
/// Helper function to call `call_nac3_ndarray_broadcast_shapes`
|
|
fn broadcast_shapes<'ctx, G: CodeGenerator + ?Sized>(
|
|
generator: &mut G,
|
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
|
in_shape_entries: &[(Instance<'ctx, Ptr<Int<SizeT>>>, u64)], // (shape, shape's length/ndims)
|
|
broadcast_ndims: u64,
|
|
broadcast_shape: Instance<'ctx, Ptr<Int<SizeT>>>,
|
|
) {
|
|
// Prepare input shape entries to be passed to `call_nac3_ndarray_broadcast_shapes`.
|
|
let num_shape_entries = Int(SizeT).const_int(
|
|
generator,
|
|
ctx.ctx,
|
|
u64::try_from(in_shape_entries.len()).unwrap(),
|
|
false,
|
|
);
|
|
let shape_entries = Struct(ShapeEntry).array_alloca(generator, ctx, num_shape_entries.value);
|
|
for (i, (in_shape, in_ndims)) in in_shape_entries.iter().enumerate() {
|
|
let pshape_entry = shape_entries.offset_const(ctx, i64::try_from(i).unwrap());
|
|
|
|
let in_ndims = Int(SizeT).const_int(generator, ctx.ctx, *in_ndims, false);
|
|
pshape_entry.set(ctx, |f| f.ndims, in_ndims);
|
|
|
|
pshape_entry.set(ctx, |f| f.shape, *in_shape);
|
|
}
|
|
|
|
let broadcast_ndims = Int(SizeT).const_int(generator, ctx.ctx, broadcast_ndims, false);
|
|
call_nac3_ndarray_broadcast_shapes(
|
|
generator,
|
|
ctx,
|
|
num_shape_entries,
|
|
shape_entries,
|
|
broadcast_ndims,
|
|
broadcast_shape,
|
|
);
|
|
}
|
|
|
|
impl<'ctx> NDArrayObject<'ctx> {
|
|
/// Broadcast all ndarrays according to `np.broadcast()` and return a [`BroadcastAllResult`]
|
|
/// containing all the information of the result of the broadcast operation.
|
|
pub fn broadcast<G: CodeGenerator + ?Sized>(
|
|
generator: &mut G,
|
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
|
ndarrays: &[Self],
|
|
) -> BroadcastAllResult<'ctx> {
|
|
assert!(!ndarrays.is_empty());
|
|
|
|
// Infer the broadcast output ndims.
|
|
let broadcast_ndims_int = ndarrays.iter().map(|ndarray| ndarray.ndims).max().unwrap();
|
|
|
|
let broadcast_ndims = Int(SizeT).const_int(generator, ctx.ctx, broadcast_ndims_int, false);
|
|
let broadcast_shape = Int(SizeT).array_alloca(generator, ctx, broadcast_ndims.value);
|
|
|
|
let shape_entries = ndarrays
|
|
.iter()
|
|
.map(|ndarray| (ndarray.instance.get(generator, ctx, |f| f.shape), ndarray.ndims))
|
|
.collect_vec();
|
|
broadcast_shapes(generator, ctx, &shape_entries, broadcast_ndims_int, broadcast_shape);
|
|
|
|
// Broadcast all the inputs to shape `dst_shape`.
|
|
let broadcast_ndarrays: Vec<_> = ndarrays
|
|
.iter()
|
|
.map(|ndarray| {
|
|
ndarray.broadcast_to(generator, ctx, broadcast_ndims_int, broadcast_shape)
|
|
})
|
|
.collect_vec();
|
|
|
|
BroadcastAllResult {
|
|
ndims: broadcast_ndims_int,
|
|
shape: broadcast_shape,
|
|
ndarrays: broadcast_ndarrays,
|
|
}
|
|
}
|
|
}
|