forked from M-Labs/nac3
core/ndstrides: add numpy broadcasting utils
This commit is contained in:
parent
eb295cf7e4
commit
e4f6adb1ec
|
@ -0,0 +1,113 @@
|
|||
use itertools::Itertools;
|
||||
|
||||
use crate::{
|
||||
codegen::{
|
||||
irrt::ndarray::broadcast::{
|
||||
call_nac3_ndarray_broadcast_shapes, call_nac3_ndarray_broadcast_to, ShapeEntry,
|
||||
},
|
||||
model::*,
|
||||
numpy_new::util::{create_ndims, extract_ndims},
|
||||
CodeGenContext, CodeGenerator,
|
||||
},
|
||||
typecheck::typedef::Type,
|
||||
};
|
||||
|
||||
use super::object::NDArrayObject;
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct BroadcastAllResult<'ctx> {
|
||||
/// The statically known `ndims` of the broadcast result.
|
||||
pub ndims: u64,
|
||||
/// The broadcasting shape.
|
||||
pub shape: Ptr<'ctx, IntModel<SizeT>>,
|
||||
/// Broadcasted views on the inputs.
|
||||
///
|
||||
/// All of them will have `shape` [`BroadcastAllResult::shape`] and
|
||||
/// `ndims` [`BroadcastAllResult::ndims`]. The length of the vector
|
||||
/// is the same as the input.
|
||||
pub ndarrays: Vec<NDArrayObject<'ctx>>,
|
||||
}
|
||||
|
||||
// TODO: DOCUMENT: Behaves like `np.broadcast()`, except returns results differently.
|
||||
pub fn broadcast_all_ndarrays<'ctx, G: CodeGenerator + ?Sized>(
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
ndarrays: Vec<NDArrayObject<'ctx>>,
|
||||
) -> BroadcastAllResult<'ctx> {
|
||||
assert!(!ndarrays.is_empty());
|
||||
|
||||
let tyctx = generator.type_context(ctx.ctx);
|
||||
let sizet_model = IntModel(SizeT);
|
||||
let shape_model = StructModel(ShapeEntry);
|
||||
|
||||
// We can deduce the final ndims statically and immediately.
|
||||
// It should be `max([ ndarray.ndims for ndarray in ndarrays ])`.
|
||||
let broadcast_ndims =
|
||||
ndarrays.iter().map(|ndarray| extract_ndims(&ctx.unifier, ndarray.ndims)).max().unwrap();
|
||||
let broadcast_ndims_ty = create_ndims(&mut ctx.unifier, broadcast_ndims);
|
||||
|
||||
// NOTE: Now prepare before calling `call_nac3_ndarray_broadcast_shapes`
|
||||
|
||||
// Prepare input shape entries
|
||||
let num_shape_entries =
|
||||
sizet_model.constant(tyctx, ctx.ctx, u64::try_from(ndarrays.len()).unwrap());
|
||||
let shape_entries =
|
||||
shape_model.array_alloca(tyctx, ctx, num_shape_entries.value, "shape_entries");
|
||||
for (i, ndarray) in ndarrays.iter().enumerate() {
|
||||
let i = sizet_model.constant(tyctx, ctx.ctx, i as u64).value;
|
||||
|
||||
let this_shape = ndarray.instance.gep(ctx, |f| f.shape).load(tyctx, ctx, "this_shape");
|
||||
let this_ndims = ndarray.instance.gep(ctx, |f| f.ndims).load(tyctx, ctx, "this_ndims");
|
||||
|
||||
let shape_entry = shape_entries.offset(tyctx, ctx, i, "shape_entry");
|
||||
shape_entry.gep(ctx, |f| f.shape).store(ctx, this_shape);
|
||||
shape_entry.gep(ctx, |f| f.ndims).store(ctx, this_ndims);
|
||||
}
|
||||
|
||||
// Prepare destination
|
||||
let dst_ndims = sizet_model.constant(tyctx, ctx.ctx, broadcast_ndims);
|
||||
let dst_shape = sizet_model.array_alloca(tyctx, ctx, dst_ndims.value, "dst_shape");
|
||||
|
||||
call_nac3_ndarray_broadcast_shapes(
|
||||
generator,
|
||||
ctx,
|
||||
num_shape_entries,
|
||||
shape_entries,
|
||||
dst_ndims,
|
||||
dst_shape,
|
||||
);
|
||||
|
||||
// Now that we know about the broadcasting shape, broadcast all the inputs.
|
||||
|
||||
// Broadcast all the inputs to shape `dst_shape`
|
||||
let broadcasted_ndarrays = ndarrays
|
||||
.into_iter()
|
||||
.map(|ndarray| ndarray.broadcast_to(generator, ctx, broadcast_ndims_ty, dst_shape))
|
||||
.collect_vec();
|
||||
|
||||
BroadcastAllResult { ndims: broadcast_ndims, shape: dst_shape, ndarrays: broadcasted_ndarrays }
|
||||
}
|
||||
|
||||
impl<'ctx> NDArrayObject<'ctx> {
|
||||
/// Broadcast an ndarray to a target shape.
|
||||
#[must_use]
|
||||
pub fn broadcast_to<G: CodeGenerator + ?Sized>(
|
||||
&self,
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
target_ndims_ty: Type,
|
||||
target_shape: Ptr<'ctx, IntModel<SizeT>>,
|
||||
) -> Self {
|
||||
// Please see comment in IRRT on how the caller should prepare `dst_ndarray`
|
||||
let dst_ndarray = NDArrayObject::alloca(
|
||||
generator,
|
||||
ctx,
|
||||
target_ndims_ty,
|
||||
self.dtype,
|
||||
"broadcast_ndarray_to_dst",
|
||||
);
|
||||
dst_ndarray.copy_shape(generator, ctx, target_shape);
|
||||
call_nac3_ndarray_broadcast_to(generator, ctx, self.instance, dst_ndarray.instance);
|
||||
dst_ndarray
|
||||
}
|
||||
}
|
|
@ -1,3 +1,4 @@
|
|||
pub mod broadcast;
|
||||
pub mod factory;
|
||||
pub mod object;
|
||||
pub mod util;
|
||||
|
|
Loading…
Reference in New Issue