diff --git a/nac3core/src/codegen/numpy_new/broadcast.rs b/nac3core/src/codegen/numpy_new/broadcast.rs
new file mode 100644
index 00000000..9e5b6532
--- /dev/null
+++ b/nac3core/src/codegen/numpy_new/broadcast.rs
@@ -0,0 +1,131 @@
+use itertools::Itertools;
+
+use crate::{
+    codegen::{
+        irrt::ndarray::broadcast::{
+            call_nac3_ndarray_broadcast_shapes, call_nac3_ndarray_broadcast_to, ShapeEntry,
+        },
+        model::*,
+        numpy_new::util::{create_ndims, extract_ndims},
+        CodeGenContext, CodeGenerator,
+    },
+    typecheck::typedef::Type,
+};
+
+use super::object::NDArrayObject;
+
+/// The output of [`broadcast_all_ndarrays`].
+#[derive(Debug, Clone)]
+pub struct BroadcastAllResult<'ctx> {
+    /// The statically known `ndims` of the broadcast result.
+    pub ndims: u64,
+    /// The broadcasting shape.
+    pub shape: Ptr<'ctx, IntModel<SizeT>>,
+    /// Broadcasted views on the inputs.
+    ///
+    /// All of them will have `shape` [`BroadcastAllResult::shape`] and
+    /// `ndims` [`BroadcastAllResult::ndims`]. The length of the vector
+    /// is the same as the input.
+    pub ndarrays: Vec<NDArrayObject<'ctx>>,
+}
+
+/// Broadcast all `ndarrays` against each other to one common shape.
+///
+/// Behaves like `np.broadcast()`, except that the results are returned as a
+/// [`BroadcastAllResult`]: the common shape plus one broadcasted ndarray per input,
+/// in input order.
+///
+/// # Panics
+///
+/// Panics if `ndarrays` is empty.
+pub fn broadcast_all_ndarrays<'ctx, G: CodeGenerator + ?Sized>(
+    generator: &mut G,
+    ctx: &mut CodeGenContext<'ctx, '_>,
+    ndarrays: Vec<NDArrayObject<'ctx>>,
+) -> BroadcastAllResult<'ctx> {
+    assert!(!ndarrays.is_empty());
+
+    let tyctx = generator.type_context(ctx.ctx);
+    let sizet_model = IntModel(SizeT);
+    let shape_model = StructModel(ShapeEntry);
+
+    // We can deduce the final ndims statically and immediately.
+    // It should be `max([ ndarray.ndims for ndarray in ndarrays ])`.
+    let broadcast_ndims = ndarrays
+        .iter()
+        .map(|ndarray| extract_ndims(&ctx.unifier, ndarray.ndims))
+        .max()
+        .expect("`ndarrays` is asserted to be non-empty");
+    let broadcast_ndims_ty = create_ndims(&mut ctx.unifier, broadcast_ndims);
+
+    // NOTE: Now prepare before calling `call_nac3_ndarray_broadcast_shapes`
+
+    // Prepare input shape entries: each one records the `shape` and `ndims` of one input.
+    let num_shape_entries =
+        sizet_model.constant(tyctx, ctx.ctx, u64::try_from(ndarrays.len()).unwrap());
+    let shape_entries =
+        shape_model.array_alloca(tyctx, ctx, num_shape_entries.value, "shape_entries");
+    for (i, ndarray) in ndarrays.iter().enumerate() {
+        // Same checked `usize` -> `u64` conversion as for `num_shape_entries` above.
+        let i = sizet_model.constant(tyctx, ctx.ctx, u64::try_from(i).unwrap()).value;
+
+        let this_shape = ndarray.instance.gep(ctx, |f| f.shape).load(tyctx, ctx, "this_shape");
+        let this_ndims = ndarray.instance.gep(ctx, |f| f.ndims).load(tyctx, ctx, "this_ndims");
+
+        let shape_entry = shape_entries.offset(tyctx, ctx, i, "shape_entry");
+        shape_entry.gep(ctx, |f| f.shape).store(ctx, this_shape);
+        shape_entry.gep(ctx, |f| f.ndims).store(ctx, this_ndims);
+    }
+
+    // Prepare destination: room for `broadcast_ndims` entries of the resulting shape.
+    let dst_ndims = sizet_model.constant(tyctx, ctx.ctx, broadcast_ndims);
+    let dst_shape = sizet_model.array_alloca(tyctx, ctx, dst_ndims.value, "dst_shape");
+
+    call_nac3_ndarray_broadcast_shapes(
+        generator,
+        ctx,
+        num_shape_entries,
+        shape_entries,
+        dst_ndims,
+        dst_shape,
+    );
+
+    // Now that we know about the broadcasting shape, broadcast all the inputs.
+
+    // Broadcast all the inputs to shape `dst_shape`
+    let broadcasted_ndarrays = ndarrays
+        .into_iter()
+        .map(|ndarray| ndarray.broadcast_to(generator, ctx, broadcast_ndims_ty, dst_shape))
+        .collect_vec();
+
+    BroadcastAllResult { ndims: broadcast_ndims, shape: dst_shape, ndarrays: broadcasted_ndarrays }
+}
+
+impl<'ctx> NDArrayObject<'ctx> {
+    /// Broadcast an ndarray to a target shape.
+    ///
+    /// Allocates a new ndarray via [`NDArrayObject::alloca`] from `target_ndims_ty` and the
+    /// `dtype` of `self`, copies `target_shape` into it with `copy_shape`, then calls
+    /// `call_nac3_ndarray_broadcast_to` with `self` as the source and the new ndarray as
+    /// the destination. The new ndarray is returned.
+    #[must_use]
+    pub fn broadcast_to<G: CodeGenerator + ?Sized>(
+        &self,
+        generator: &mut G,
+        ctx: &mut CodeGenContext<'ctx, '_>,
+        target_ndims_ty: Type,
+        target_shape: Ptr<'ctx, IntModel<SizeT>>,
+    ) -> Self {
+        // Please see comment in IRRT on how the caller should prepare `dst_ndarray`
+        let dst_ndarray = NDArrayObject::alloca(
+            generator,
+            ctx,
+            target_ndims_ty,
+            self.dtype,
+            "broadcast_ndarray_to_dst",
+        );
+        dst_ndarray.copy_shape(generator, ctx, target_shape);
+        call_nac3_ndarray_broadcast_to(generator, ctx, self.instance, dst_ndarray.instance);
+        dst_ndarray
+    }
+}
diff --git a/nac3core/src/codegen/numpy_new/mod.rs b/nac3core/src/codegen/numpy_new/mod.rs
index 6421f856..f6f8a317 100644
--- a/nac3core/src/codegen/numpy_new/mod.rs
+++ b/nac3core/src/codegen/numpy_new/mod.rs
@@ -1,3 +1,4 @@
+pub mod broadcast;
 pub mod factory;
 pub mod object;
 pub mod util;