forked from M-Labs/nac3
core/ndstrides: add numpy broadcasting utils
This commit is contained in:
parent
eb295cf7e4
commit
e4f6adb1ec
|
@ -0,0 +1,113 @@
|
||||||
|
use itertools::Itertools;
|
||||||
|
|
||||||
|
use crate::{
|
||||||
|
codegen::{
|
||||||
|
irrt::ndarray::broadcast::{
|
||||||
|
call_nac3_ndarray_broadcast_shapes, call_nac3_ndarray_broadcast_to, ShapeEntry,
|
||||||
|
},
|
||||||
|
model::*,
|
||||||
|
numpy_new::util::{create_ndims, extract_ndims},
|
||||||
|
CodeGenContext, CodeGenerator,
|
||||||
|
},
|
||||||
|
typecheck::typedef::Type,
|
||||||
|
};
|
||||||
|
|
||||||
|
use super::object::NDArrayObject;
|
||||||
|
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
pub struct BroadcastAllResult<'ctx> {
|
||||||
|
/// The statically known `ndims` of the broadcast result.
|
||||||
|
pub ndims: u64,
|
||||||
|
/// The broadcasting shape.
|
||||||
|
pub shape: Ptr<'ctx, IntModel<SizeT>>,
|
||||||
|
/// Broadcasted views on the inputs.
|
||||||
|
///
|
||||||
|
/// All of them will have `shape` [`BroadcastAllResult::shape`] and
|
||||||
|
/// `ndims` [`BroadcastAllResult::ndims`]. The length of the vector
|
||||||
|
/// is the same as the input.
|
||||||
|
pub ndarrays: Vec<NDArrayObject<'ctx>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO: DOCUMENT: Behaves like `np.broadcast()`, except returns results differently.
|
||||||
|
pub fn broadcast_all_ndarrays<'ctx, G: CodeGenerator + ?Sized>(
|
||||||
|
generator: &mut G,
|
||||||
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
|
ndarrays: Vec<NDArrayObject<'ctx>>,
|
||||||
|
) -> BroadcastAllResult<'ctx> {
|
||||||
|
assert!(!ndarrays.is_empty());
|
||||||
|
|
||||||
|
let tyctx = generator.type_context(ctx.ctx);
|
||||||
|
let sizet_model = IntModel(SizeT);
|
||||||
|
let shape_model = StructModel(ShapeEntry);
|
||||||
|
|
||||||
|
// We can deduce the final ndims statically and immediately.
|
||||||
|
// It should be `max([ ndarray.ndims for ndarray in ndarrays ])`.
|
||||||
|
let broadcast_ndims =
|
||||||
|
ndarrays.iter().map(|ndarray| extract_ndims(&ctx.unifier, ndarray.ndims)).max().unwrap();
|
||||||
|
let broadcast_ndims_ty = create_ndims(&mut ctx.unifier, broadcast_ndims);
|
||||||
|
|
||||||
|
// NOTE: Now prepare before calling `call_nac3_ndarray_broadcast_shapes`
|
||||||
|
|
||||||
|
// Prepare input shape entries
|
||||||
|
let num_shape_entries =
|
||||||
|
sizet_model.constant(tyctx, ctx.ctx, u64::try_from(ndarrays.len()).unwrap());
|
||||||
|
let shape_entries =
|
||||||
|
shape_model.array_alloca(tyctx, ctx, num_shape_entries.value, "shape_entries");
|
||||||
|
for (i, ndarray) in ndarrays.iter().enumerate() {
|
||||||
|
let i = sizet_model.constant(tyctx, ctx.ctx, i as u64).value;
|
||||||
|
|
||||||
|
let this_shape = ndarray.instance.gep(ctx, |f| f.shape).load(tyctx, ctx, "this_shape");
|
||||||
|
let this_ndims = ndarray.instance.gep(ctx, |f| f.ndims).load(tyctx, ctx, "this_ndims");
|
||||||
|
|
||||||
|
let shape_entry = shape_entries.offset(tyctx, ctx, i, "shape_entry");
|
||||||
|
shape_entry.gep(ctx, |f| f.shape).store(ctx, this_shape);
|
||||||
|
shape_entry.gep(ctx, |f| f.ndims).store(ctx, this_ndims);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Prepare destination
|
||||||
|
let dst_ndims = sizet_model.constant(tyctx, ctx.ctx, broadcast_ndims);
|
||||||
|
let dst_shape = sizet_model.array_alloca(tyctx, ctx, dst_ndims.value, "dst_shape");
|
||||||
|
|
||||||
|
call_nac3_ndarray_broadcast_shapes(
|
||||||
|
generator,
|
||||||
|
ctx,
|
||||||
|
num_shape_entries,
|
||||||
|
shape_entries,
|
||||||
|
dst_ndims,
|
||||||
|
dst_shape,
|
||||||
|
);
|
||||||
|
|
||||||
|
// Now that we know about the broadcasting shape, broadcast all the inputs.
|
||||||
|
|
||||||
|
// Broadcast all the inputs to shape `dst_shape`
|
||||||
|
let broadcasted_ndarrays = ndarrays
|
||||||
|
.into_iter()
|
||||||
|
.map(|ndarray| ndarray.broadcast_to(generator, ctx, broadcast_ndims_ty, dst_shape))
|
||||||
|
.collect_vec();
|
||||||
|
|
||||||
|
BroadcastAllResult { ndims: broadcast_ndims, shape: dst_shape, ndarrays: broadcasted_ndarrays }
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'ctx> NDArrayObject<'ctx> {
|
||||||
|
/// Broadcast an ndarray to a target shape.
|
||||||
|
#[must_use]
|
||||||
|
pub fn broadcast_to<G: CodeGenerator + ?Sized>(
|
||||||
|
&self,
|
||||||
|
generator: &mut G,
|
||||||
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
|
target_ndims_ty: Type,
|
||||||
|
target_shape: Ptr<'ctx, IntModel<SizeT>>,
|
||||||
|
) -> Self {
|
||||||
|
// Please see comment in IRRT on how the caller should prepare `dst_ndarray`
|
||||||
|
let dst_ndarray = NDArrayObject::alloca(
|
||||||
|
generator,
|
||||||
|
ctx,
|
||||||
|
target_ndims_ty,
|
||||||
|
self.dtype,
|
||||||
|
"broadcast_ndarray_to_dst",
|
||||||
|
);
|
||||||
|
dst_ndarray.copy_shape(generator, ctx, target_shape);
|
||||||
|
call_nac3_ndarray_broadcast_to(generator, ctx, self.instance, dst_ndarray.instance);
|
||||||
|
dst_ndarray
|
||||||
|
}
|
||||||
|
}
|
|
@ -1,3 +1,4 @@
|
||||||
|
pub mod broadcast;
|
||||||
pub mod factory;
|
pub mod factory;
|
||||||
pub mod object;
|
pub mod object;
|
||||||
pub mod util;
|
pub mod util;
|
||||||
|
|
Loading…
Reference in New Issue