From 32e1d55de9d74102e25e1112f0bd600ed6c67165 Mon Sep 17 00:00:00 2001 From: David Mak Date: Wed, 18 Dec 2024 15:23:41 +0800 Subject: [PATCH] [core] codegen/ndarray: Reimplement broadcasting Based on 9359ed96: core/ndstrides: implement broadcasting & np_broadcast_to() --- nac3core/irrt/irrt.cpp | 3 +- nac3core/irrt/irrt/ndarray/broadcast.hpp | 165 ++++++++++++ .../src/codegen/irrt/ndarray/broadcast.rs | 69 +++++ nac3core/src/codegen/irrt/ndarray/mod.rs | 2 + .../src/codegen/types/ndarray/broadcast.rs | 176 +++++++++++++ nac3core/src/codegen/types/ndarray/mod.rs | 16 ++ nac3core/src/codegen/values/array.rs | 1 + .../src/codegen/values/ndarray/broadcast.rs | 245 ++++++++++++++++++ nac3core/src/codegen/values/ndarray/mod.rs | 2 + nac3core/src/toplevel/builtins.rs | 25 +- nac3core/src/toplevel/helper.rs | 2 + ...el__test__test_analyze__generic_class.snap | 2 +- ...t__test_analyze__inheritance_override.snap | 2 +- ...est__test_analyze__list_tuple_generic.snap | 4 +- ...__toplevel__test__test_analyze__self1.snap | 2 +- ...t__test_analyze__simple_class_compose.snap | 4 +- nac3core/src/typecheck/type_inferencer/mod.rs | 2 +- nac3standalone/demo/interpret_demo.py | 1 + nac3standalone/demo/src/ndarray.py | 24 ++ 19 files changed, 734 insertions(+), 13 deletions(-) create mode 100644 nac3core/irrt/irrt/ndarray/broadcast.hpp create mode 100644 nac3core/src/codegen/irrt/ndarray/broadcast.rs create mode 100644 nac3core/src/codegen/types/ndarray/broadcast.rs create mode 100644 nac3core/src/codegen/values/ndarray/broadcast.rs diff --git a/nac3core/irrt/irrt.cpp b/nac3core/irrt/irrt.cpp index 2615082c..1fbc037f 100644 --- a/nac3core/irrt/irrt.cpp +++ b/nac3core/irrt/irrt.cpp @@ -9,4 +9,5 @@ #include "irrt/ndarray/iter.hpp" #include "irrt/ndarray/indexing.hpp" #include "irrt/ndarray/array.hpp" -#include "irrt/ndarray/reshape.hpp" \ No newline at end of file +#include "irrt/ndarray/reshape.hpp" +#include "irrt/ndarray/broadcast.hpp" \ No newline at end of file diff --git a/nac3core/irrt/irrt/ndarray/broadcast.hpp b/nac3core/irrt/irrt/ndarray/broadcast.hpp new file mode 100644 index 00000000..e419081c --- /dev/null +++ b/nac3core/irrt/irrt/ndarray/broadcast.hpp @@ -0,0 +1,165 @@ +#pragma once + +#include "irrt/int_types.hpp" +#include "irrt/ndarray/def.hpp" +#include "irrt/slice.hpp" + +namespace { +template +struct ShapeEntry { + SizeT ndims; + SizeT* shape; +}; +} // namespace + +namespace { +namespace ndarray { +namespace broadcast { +/** + * @brief Return true if `src_shape` can broadcast to `dst_shape`. + * + * See https://numpy.org/doc/stable/user/basics.broadcasting.html + */ +template +bool can_broadcast_shape_to(SizeT target_ndims, const SizeT* target_shape, SizeT src_ndims, const SizeT* src_shape) { + if (src_ndims > target_ndims) { + return false; + } + + for (SizeT i = 0; i < src_ndims; i++) { + SizeT target_dim = target_shape[target_ndims - i - 1]; + SizeT src_dim = src_shape[src_ndims - i - 1]; + if (!(src_dim == 1 || target_dim == src_dim)) { + return false; + } + } + return true; +} + +/** + * @brief Performs `np.broadcast_shapes()` + * + * @param num_shapes Number of entries in `shapes` + * @param shapes The list of shape to do `np.broadcast_shapes` on. + * @param dst_ndims The length of `dst_shape`. + * `dst_ndims` must be `max([shape.ndims for shape in shapes])`, but the caller has to calculate it/provide it. + * for this function since they should already know in order to allocate `dst_shape` in the first place. + * @param dst_shape The resulting shape. Must be pre-allocated by the caller. This function calculate the result + * of `np.broadcast_shapes` and write it here. + */ +template +void broadcast_shapes(SizeT num_shapes, const ShapeEntry* shapes, SizeT dst_ndims, SizeT* dst_shape) { + for (SizeT dst_axis = 0; dst_axis < dst_ndims; dst_axis++) { + dst_shape[dst_axis] = 1; + } + +#ifdef IRRT_DEBUG_ASSERT + SizeT max_ndims_found = 0; +#endif + + for (SizeT i = 0; i < num_shapes; i++) { + ShapeEntry entry = shapes[i]; + + // Check pre-condition: `dst_ndims` must be `max([shape.ndims for shape in shapes])` + debug_assert(SizeT, entry.ndims <= dst_ndims); + +#ifdef IRRT_DEBUG_ASSERT + max_ndims_found = max(max_ndims_found, entry.ndims); +#endif + + for (SizeT j = 0; j < entry.ndims; j++) { + SizeT entry_axis = entry.ndims - j - 1; + SizeT dst_axis = dst_ndims - j - 1; + + SizeT entry_dim = entry.shape[entry_axis]; + SizeT dst_dim = dst_shape[dst_axis]; + + if (dst_dim == 1) { + dst_shape[dst_axis] = entry_dim; + } else if (entry_dim == 1 || entry_dim == dst_dim) { + // Do nothing + } else { + raise_exception(SizeT, EXN_VALUE_ERROR, + "shape mismatch: objects cannot be broadcast " + "to a single shape.", + NO_PARAM, NO_PARAM, NO_PARAM); + } + } + } + + // Check pre-condition: `dst_ndims` must be `max([shape.ndims for shape in shapes])` + debug_assert_eq(SizeT, max_ndims_found, dst_ndims); +} + +/** + * @brief Perform `np.broadcast_to(, )` and appropriate assertions. + * + * This function attempts to broadcast `src_ndarray` to a new shape defined by `dst_ndarray.shape`, + * and return the result by modifying `dst_ndarray`. + * + * # Notes on `dst_ndarray` + * The caller is responsible for allocating space for the resulting ndarray. + * Here is what this function expects from `dst_ndarray` when called: + * - `dst_ndarray->data` does not have to be initialized. + * - `dst_ndarray->itemsize` does not have to be initialized. + * - `dst_ndarray->ndims` must be initialized, determining the length of `dst_ndarray->shape` + * - `dst_ndarray->shape` must be allocated, and must contain the desired target broadcast shape. + * - `dst_ndarray->strides` must be allocated, through it can contain uninitialized values. + * When this function call ends: + * - `dst_ndarray->data` is set to `src_ndarray->data` (`dst_ndarray` is just a view to `src_ndarray`) + * - `dst_ndarray->itemsize` is set to `src_ndarray->itemsize` + * - `dst_ndarray->ndims` is unchanged. + * - `dst_ndarray->shape` is unchanged. + * - `dst_ndarray->strides` is updated accordingly by how ndarray broadcast_to works. + */ +template +void broadcast_to(const NDArray* src_ndarray, NDArray* dst_ndarray) { + if (!ndarray::broadcast::can_broadcast_shape_to(dst_ndarray->ndims, dst_ndarray->shape, src_ndarray->ndims, + src_ndarray->shape)) { + raise_exception(SizeT, EXN_VALUE_ERROR, "operands could not be broadcast together", NO_PARAM, NO_PARAM, + NO_PARAM); + } + + dst_ndarray->data = src_ndarray->data; + dst_ndarray->itemsize = src_ndarray->itemsize; + + for (SizeT i = 0; i < dst_ndarray->ndims; i++) { + SizeT src_axis = src_ndarray->ndims - i - 1; + SizeT dst_axis = dst_ndarray->ndims - i - 1; + if (src_axis < 0 || (src_ndarray->shape[src_axis] == 1 && dst_ndarray->shape[dst_axis] != 1)) { + // Freeze the steps in-place + dst_ndarray->strides[dst_axis] = 0; + } else { + dst_ndarray->strides[dst_axis] = src_ndarray->strides[src_axis]; + } + } +} +} // namespace broadcast +} // namespace ndarray +} // namespace + +extern "C" { +using namespace ndarray::broadcast; + +void __nac3_ndarray_broadcast_to(NDArray* src_ndarray, NDArray* dst_ndarray) { + broadcast_to(src_ndarray, dst_ndarray); +} + +void __nac3_ndarray_broadcast_to64(NDArray* src_ndarray, NDArray* dst_ndarray) { + broadcast_to(src_ndarray, dst_ndarray); +} + +void __nac3_ndarray_broadcast_shapes(int32_t num_shapes, + const ShapeEntry* shapes, + int32_t dst_ndims, + int32_t* dst_shape) { + broadcast_shapes(num_shapes, shapes, dst_ndims, dst_shape); +} + +void __nac3_ndarray_broadcast_shapes64(int64_t num_shapes, + const ShapeEntry* shapes, + int64_t dst_ndims, + int64_t* dst_shape) { + broadcast_shapes(num_shapes, shapes, dst_ndims, dst_shape); +} +} \ No newline at end of file diff --git a/nac3core/src/codegen/irrt/ndarray/broadcast.rs b/nac3core/src/codegen/irrt/ndarray/broadcast.rs new file mode 100644 index 00000000..120aaaa6 --- /dev/null +++ b/nac3core/src/codegen/irrt/ndarray/broadcast.rs @@ -0,0 +1,69 @@ +use inkwell::values::IntValue; + +use crate::codegen::{ + expr::infer_and_call_function, + irrt::get_usize_dependent_function_name, + types::{ndarray::ShapeEntryType, ProxyType}, + values::{ + ndarray::NDArrayValue, ArrayLikeValue, ArraySliceValue, ProxyValue, TypedArrayLikeAccessor, + TypedArrayLikeMutator, + }, + CodeGenContext, CodeGenerator, +}; + +pub fn call_nac3_ndarray_broadcast_to<'ctx, G: CodeGenerator + ?Sized>( + generator: &G, + ctx: &CodeGenContext<'ctx, '_>, + src_ndarray: NDArrayValue<'ctx>, + dst_ndarray: NDArrayValue<'ctx>, +) { + let name = get_usize_dependent_function_name(generator, ctx, "__nac3_ndarray_broadcast_to"); + infer_and_call_function( + ctx, + &name, + None, + &[src_ndarray.as_base_value().into(), dst_ndarray.as_base_value().into()], + None, + None, + ); +} + +pub fn call_nac3_ndarray_broadcast_shapes<'ctx, G, Shape>( + generator: &G, + ctx: &CodeGenContext<'ctx, '_>, + num_shape_entries: IntValue<'ctx>, + shape_entries: ArraySliceValue<'ctx>, + dst_ndims: IntValue<'ctx>, + dst_shape: &Shape, +) where + G: CodeGenerator + ?Sized, + Shape: TypedArrayLikeAccessor<'ctx, G, IntValue<'ctx>> + + TypedArrayLikeMutator<'ctx, G, IntValue<'ctx>>, +{ + let llvm_usize = generator.get_size_type(ctx.ctx); + + assert_eq!(num_shape_entries.get_type(), llvm_usize); + assert!(ShapeEntryType::is_type( + generator, + ctx.ctx, + shape_entries.base_ptr(ctx, generator).get_type() + ) + .is_ok()); + assert_eq!(dst_ndims.get_type(), llvm_usize); + assert_eq!(dst_shape.element_type(ctx, generator), llvm_usize.into()); + + let name = get_usize_dependent_function_name(generator, ctx, "__nac3_ndarray_broadcast_shapes"); + infer_and_call_function( + ctx, + &name, + None, + &[ + num_shape_entries.into(), + shape_entries.base_ptr(ctx, generator).into(), + dst_ndims.into(), + dst_shape.base_ptr(ctx, generator).into(), + ], + None, + None, + ); +} diff --git a/nac3core/src/codegen/irrt/ndarray/mod.rs b/nac3core/src/codegen/irrt/ndarray/mod.rs index f67566b9..c640042b 100644 --- a/nac3core/src/codegen/irrt/ndarray/mod.rs +++ b/nac3core/src/codegen/irrt/ndarray/mod.rs @@ -18,12 +18,14 @@ use crate::codegen::{ }; pub use array::*; pub use basic::*; +pub use broadcast::*; pub use indexing::*; pub use iter::*; pub use reshape::*; mod array; mod basic; +mod broadcast; mod indexing; mod iter; mod reshape; diff --git a/nac3core/src/codegen/types/ndarray/broadcast.rs b/nac3core/src/codegen/types/ndarray/broadcast.rs new file mode 100644 index 00000000..7e0579d2 --- /dev/null +++ b/nac3core/src/codegen/types/ndarray/broadcast.rs @@ -0,0 +1,176 @@ +use inkwell::{ + context::{AsContextRef, Context}, + types::{AnyTypeEnum, BasicType, BasicTypeEnum, IntType, PointerType}, + values::{IntValue, PointerValue}, + AddressSpace, +}; +use itertools::Itertools; + +use nac3core_derive::StructFields; + +use crate::codegen::{ + types::{ + structure::{check_struct_type_matches_fields, StructField, StructFields}, + ProxyType, + }, + values::{ndarray::ShapeEntryValue, ProxyValue}, + CodeGenContext, CodeGenerator, +}; + +#[derive(Debug, PartialEq, Eq, Clone, Copy)] +pub struct ShapeEntryType<'ctx> { + ty: PointerType<'ctx>, + llvm_usize: IntType<'ctx>, +} + +#[derive(PartialEq, Eq, Clone, Copy, StructFields)] +pub struct ShapeEntryStructFields<'ctx> { + #[value_type(usize)] + pub ndims: StructField<'ctx, IntValue<'ctx>>, + #[value_type(usize.ptr_type(AddressSpace::default()))] + pub shape: StructField<'ctx, PointerValue<'ctx>>, +} + +impl<'ctx> ShapeEntryType<'ctx> { + /// Checks whether `llvm_ty` represents a [`ShapeEntryType`], returning [Err] if it does not. + pub fn is_representable( + llvm_ty: PointerType<'ctx>, + llvm_usize: IntType<'ctx>, + ) -> Result<(), String> { + let ctx = llvm_ty.get_context(); + + let llvm_ndarray_ty = llvm_ty.get_element_type(); + let AnyTypeEnum::StructType(llvm_ndarray_ty) = llvm_ndarray_ty else { + return Err(format!( + "Expected struct type for `ShapeEntry` type, got {llvm_ndarray_ty}" + )); + }; + + check_struct_type_matches_fields( + Self::fields(ctx, llvm_usize), + llvm_ndarray_ty, + "NDArray", + &[], + ) + } + + /// Returns an instance of [`StructFields`] containing all field accessors for this type. + #[must_use] + fn fields( + ctx: impl AsContextRef<'ctx>, + llvm_usize: IntType<'ctx>, + ) -> ShapeEntryStructFields<'ctx> { + ShapeEntryStructFields::new(ctx, llvm_usize) + } + + /// See [`ShapeEntryStructFields::fields`]. + // TODO: Move this into e.g. StructProxyType + #[must_use] + pub fn get_fields(&self, ctx: impl AsContextRef<'ctx>) -> ShapeEntryStructFields<'ctx> { + Self::fields(ctx, self.llvm_usize) + } + + /// Creates an LLVM type corresponding to the expected structure of a `ShapeEntry`. + #[must_use] + fn llvm_type(ctx: &'ctx Context, llvm_usize: IntType<'ctx>) -> PointerType<'ctx> { + let field_tys = + Self::fields(ctx, llvm_usize).into_iter().map(|field| field.1).collect_vec(); + + ctx.struct_type(&field_tys, false).ptr_type(AddressSpace::default()) + } + + /// Creates an instance of [`ShapeEntryType`]. + #[must_use] + pub fn new(generator: &G, ctx: &'ctx Context) -> Self { + let llvm_usize = generator.get_size_type(ctx); + let llvm_ty = Self::llvm_type(ctx, llvm_usize); + + Self { ty: llvm_ty, llvm_usize } + } + + /// Creates a [`ShapeEntryType`] from a [`PointerType`] representing an `NDArray`. + #[must_use] + pub fn from_type(ptr_ty: PointerType<'ctx>, llvm_usize: IntType<'ctx>) -> Self { + debug_assert!(Self::is_representable(ptr_ty, llvm_usize).is_ok()); + + Self { ty: ptr_ty, llvm_usize } + } + + /// Allocates an instance of [`ShapeEntryValue`] as if by calling `alloca` on the base type. + #[must_use] + pub fn alloca( + &self, + ctx: &mut CodeGenContext<'ctx, '_>, + name: Option<&'ctx str>, + ) -> >::Value { + >::Value::from_pointer_value( + self.raw_alloca(ctx, name), + self.llvm_usize, + name, + ) + } + + /// Allocates an instance of [`ShapeEntryValue`] as if by calling `alloca` on the base type. + #[must_use] + pub fn alloca_var( + &self, + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + name: Option<&'ctx str>, + ) -> >::Value { + >::Value::from_pointer_value( + self.raw_alloca_var(generator, ctx, name), + self.llvm_usize, + name, + ) + } + + /// Converts an existing value into a [`ShapeEntryValue`]. + #[must_use] + pub fn map_value( + &self, + value: <>::Value as ProxyValue<'ctx>>::Base, + name: Option<&'ctx str>, + ) -> >::Value { + >::Value::from_pointer_value(value, self.llvm_usize, name) + } +} + +impl<'ctx> ProxyType<'ctx> for ShapeEntryType<'ctx> { + type Base = PointerType<'ctx>; + type Value = ShapeEntryValue<'ctx>; + + fn is_type( + generator: &G, + ctx: &'ctx Context, + llvm_ty: impl BasicType<'ctx>, + ) -> Result<(), String> { + if let BasicTypeEnum::PointerType(ty) = llvm_ty.as_basic_type_enum() { + >::is_representable(generator, ctx, ty) + } else { + Err(format!("Expected pointer type, got {llvm_ty:?}")) + } + } + + fn is_representable( + generator: &G, + ctx: &'ctx Context, + llvm_ty: Self::Base, + ) -> Result<(), String> { + Self::is_representable(llvm_ty, generator.get_size_type(ctx)) + } + + fn alloca_type(&self) -> impl BasicType<'ctx> { + self.as_base_type().get_element_type().into_struct_type() + } + + fn as_base_type(&self) -> Self::Base { + self.ty + } +} + +impl<'ctx> From> for PointerType<'ctx> { + fn from(value: ShapeEntryType<'ctx>) -> Self { + value.as_base_type() + } +} diff --git a/nac3core/src/codegen/types/ndarray/mod.rs b/nac3core/src/codegen/types/ndarray/mod.rs index dd41df67..51d851f4 100644 --- a/nac3core/src/codegen/types/ndarray/mod.rs +++ b/nac3core/src/codegen/types/ndarray/mod.rs @@ -20,11 +20,13 @@ use crate::{ toplevel::{helper::extract_ndims, numpy::unpack_ndarray_var_tys}, typecheck::typedef::Type, }; +pub use broadcast::*; pub use contiguous::*; pub use indexing::*; pub use nditer::*; mod array; +mod broadcast; mod contiguous; pub mod factory; mod indexing; @@ -118,6 +120,20 @@ impl<'ctx> NDArrayType<'ctx> { NDArrayType { ty: llvm_ndarray, dtype, ndims, llvm_usize } } + /// Creates an instance of [`NDArrayType`] as a result of a broadcast operation over one or more + /// `ndarray` operands. + #[must_use] + pub fn new_broadcast( + generator: &G, + ctx: &'ctx Context, + dtype: BasicTypeEnum<'ctx>, + inputs: &[NDArrayType<'ctx>], + ) -> Self { + assert!(!inputs.is_empty()); + + Self::new(generator, ctx, dtype, inputs.iter().filter_map(NDArrayType::ndims).max()) + } + /// Creates an instance of [`NDArrayType`] with `ndims` of 0. #[must_use] pub fn new_unsized( diff --git a/nac3core/src/codegen/values/array.rs b/nac3core/src/codegen/values/array.rs index 55e91b21..6808f696 100644 --- a/nac3core/src/codegen/values/array.rs +++ b/nac3core/src/codegen/values/array.rs @@ -208,6 +208,7 @@ pub trait TypedArrayLikeMutator<'ctx, G: CodeGenerator + ?Sized, T, Index = IntV } /// An adapter for constraining untyped array values as typed values. +#[derive(Clone)] pub struct TypedArrayLikeAdapter< 'ctx, G: CodeGenerator + ?Sized, diff --git a/nac3core/src/codegen/values/ndarray/broadcast.rs b/nac3core/src/codegen/values/ndarray/broadcast.rs new file mode 100644 index 00000000..12861dc2 --- /dev/null +++ b/nac3core/src/codegen/values/ndarray/broadcast.rs @@ -0,0 +1,245 @@ +use inkwell::{ + types::IntType, + values::{IntValue, PointerValue}, +}; +use itertools::Itertools; + +use crate::codegen::values::TypedArrayLikeMutator; +use crate::codegen::{ + irrt, + types::{ + ndarray::{NDArrayType, ShapeEntryType}, + structure::StructField, + ProxyType, + }, + values::{ + ndarray::NDArrayValue, ArrayLikeIndexer, ArrayLikeValue, ArraySliceValue, ProxyValue, + TypedArrayLikeAccessor, TypedArrayLikeAdapter, + }, + CodeGenContext, CodeGenerator, +}; + +#[derive(Copy, Clone)] +pub struct ShapeEntryValue<'ctx> { + value: PointerValue<'ctx>, + llvm_usize: IntType<'ctx>, + name: Option<&'ctx str>, +} + +impl<'ctx> ShapeEntryValue<'ctx> { + /// Checks whether `value` is an instance of `ShapeEntry`, returning [Err] if `value` is + /// not an instance. + pub fn is_representable( + value: PointerValue<'ctx>, + llvm_usize: IntType<'ctx>, + ) -> Result<(), String> { + >::Type::is_representable(value.get_type(), llvm_usize) + } + + /// Creates an [`ShapeEntryValue`] from a [`PointerValue`]. + #[must_use] + pub fn from_pointer_value( + ptr: PointerValue<'ctx>, + llvm_usize: IntType<'ctx>, + name: Option<&'ctx str>, + ) -> Self { + debug_assert!(Self::is_representable(ptr, llvm_usize).is_ok()); + + Self { value: ptr, llvm_usize, name } + } + + fn ndims_field(&self) -> StructField<'ctx, IntValue<'ctx>> { + self.get_type().get_fields(self.value.get_type().get_context()).ndims + } + + pub fn store_ndims(&self, ctx: &CodeGenContext<'ctx, '_>, value: IntValue<'ctx>) { + self.ndims_field().set(ctx, self.value, value, self.name); + } + + fn shape_field(&self) -> StructField<'ctx, PointerValue<'ctx>> { + self.get_type().get_fields(self.value.get_type().get_context()).shape + } + + pub fn store_shape(&self, ctx: &CodeGenContext<'ctx, '_>, value: PointerValue<'ctx>) { + self.shape_field().set(ctx, self.value, value, self.name); + } +} + +impl<'ctx> ProxyValue<'ctx> for ShapeEntryValue<'ctx> { + type Base = PointerValue<'ctx>; + type Type = ShapeEntryType<'ctx>; + + fn get_type(&self) -> Self::Type { + Self::Type::from_type(self.value.get_type(), self.llvm_usize) + } + + fn as_base_value(&self) -> Self::Base { + self.value + } +} + +impl<'ctx> From> for PointerValue<'ctx> { + fn from(value: ShapeEntryValue<'ctx>) -> Self { + value.as_base_value() + } +} + +impl<'ctx> NDArrayValue<'ctx> { + /// Create a broadcast view on this ndarray with a target shape. + /// + /// The input shape will be checked to make sure that it contains no negative values. + /// + /// * `target_ndims` - The ndims type after broadcasting to the given shape. + /// The caller has to figure this out for this function. + /// * `target_shape` - An array pointer pointing to the target shape. + #[must_use] + pub fn broadcast_to( + &self, + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + target_ndims: u64, + target_shape: &impl TypedArrayLikeAccessor<'ctx, G, IntValue<'ctx>>, + ) -> Self { + assert!(self.ndims.is_none_or(|ndims| ndims <= target_ndims)); + assert_eq!(target_shape.element_type(ctx, generator), self.llvm_usize.into()); + + let broadcast_ndarray = + NDArrayType::new(generator, ctx.ctx, self.dtype, Some(target_ndims)) + .construct_uninitialized(generator, ctx, None); + broadcast_ndarray.copy_shape_from_array( + generator, + ctx, + target_shape.base_ptr(ctx, generator), + ); + + irrt::ndarray::call_nac3_ndarray_broadcast_to(generator, ctx, *self, broadcast_ndarray); + broadcast_ndarray + } +} + +/// A result produced by [`broadcast_all_ndarrays`] +#[derive(Clone)] +pub struct BroadcastAllResult<'ctx, G: CodeGenerator + ?Sized> { + /// The statically known `ndims` of the broadcast result. + pub ndims: u64, + + /// The broadcasting shape. + pub shape: TypedArrayLikeAdapter<'ctx, G, IntValue<'ctx>>, + + /// Broadcasted views on the inputs. + /// + /// All of them will have `shape` [`BroadcastAllResult::shape`] and + /// `ndims` [`BroadcastAllResult::ndims`]. The length of the vector + /// is the same as the input. + pub ndarrays: Vec>, +} + +/// Helper function to call `call_nac3_ndarray_broadcast_shapes` +fn broadcast_shapes<'ctx, G, Shape>( + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + in_shape_entries: &[(ArraySliceValue<'ctx>, u64)], // (shape, shape's length/ndims) + broadcast_ndims: u64, + broadcast_shape: &Shape, +) where + G: CodeGenerator + ?Sized, + Shape: TypedArrayLikeAccessor<'ctx, G, IntValue<'ctx>> + + TypedArrayLikeMutator<'ctx, G, IntValue<'ctx>>, +{ + let llvm_usize = generator.get_size_type(ctx.ctx); + let llvm_shape_ty = ShapeEntryType::new(generator, ctx.ctx); + + assert!(in_shape_entries + .iter() + .all(|entry| entry.0.element_type(ctx, generator) == llvm_usize.into())); + assert_eq!(broadcast_shape.element_type(ctx, generator), llvm_usize.into()); + + // Prepare input shape entries to be passed to `call_nac3_ndarray_broadcast_shapes`. + let num_shape_entries = + llvm_usize.const_int(u64::try_from(in_shape_entries.len()).unwrap(), false); + let shape_entries = llvm_shape_ty.array_alloca(ctx, num_shape_entries, None); + for (i, (in_shape, in_ndims)) in in_shape_entries.iter().enumerate() { + let pshape_entry = unsafe { + shape_entries.ptr_offset_unchecked( + ctx, + generator, + &llvm_usize.const_int(i as u64, false), + None, + ) + }; + let shape_entry = llvm_shape_ty.map_value(pshape_entry, None); + + let in_ndims = llvm_usize.const_int(*in_ndims, false); + shape_entry.store_ndims(ctx, in_ndims); + + shape_entry.store_shape(ctx, in_shape.base_ptr(ctx, generator)); + } + + let broadcast_ndims = llvm_usize.const_int(broadcast_ndims, false); + irrt::ndarray::call_nac3_ndarray_broadcast_shapes( + generator, + ctx, + num_shape_entries, + shape_entries, + broadcast_ndims, + broadcast_shape, + ); +} + +impl<'ctx> NDArrayType<'ctx> { + /// Broadcast all ndarrays according to `np.broadcast()` and return a [`BroadcastAllResult`] + /// containing all the information of the result of the broadcast operation. + pub fn broadcast( + &self, + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + ndarrays: &[NDArrayValue<'ctx>], + ) -> BroadcastAllResult<'ctx, G> { + assert!(!ndarrays.is_empty()); + assert!(ndarrays.iter().all(|ndarray| ndarray.get_type().ndims().is_some())); + + let llvm_usize = generator.get_size_type(ctx.ctx); + + // Infer the broadcast output ndims. + let broadcast_ndims_int = + ndarrays.iter().map(|ndarray| ndarray.get_type().ndims().unwrap()).max().unwrap(); + assert!(self.ndims().is_none_or(|ndims| ndims >= broadcast_ndims_int)); + + let broadcast_ndims = llvm_usize.const_int(broadcast_ndims_int, false); + let broadcast_shape = ArraySliceValue::from_ptr_val( + ctx.builder.build_array_alloca(llvm_usize, broadcast_ndims, "").unwrap(), + broadcast_ndims, + None, + ); + let broadcast_shape = TypedArrayLikeAdapter::from( + broadcast_shape, + |_, _, val| val.into_int_value(), + |_, _, val| val.into(), + ); + + let shape_entries = ndarrays + .iter() + .map(|ndarray| { + ( + ndarray.shape().as_slice_value(ctx, generator), + ndarray.get_type().ndims().unwrap(), + ) + }) + .collect_vec(); + broadcast_shapes(generator, ctx, &shape_entries, broadcast_ndims_int, &broadcast_shape); + + // Broadcast all the inputs to shape `dst_shape`. + let broadcast_ndarrays = ndarrays + .iter() + .map(|ndarray| { + ndarray.broadcast_to(generator, ctx, broadcast_ndims_int, &broadcast_shape) + }) + .collect_vec(); + + BroadcastAllResult { + ndims: broadcast_ndims_int, + shape: broadcast_shape, + ndarrays: broadcast_ndarrays, + } + } +} diff --git a/nac3core/src/codegen/values/ndarray/mod.rs b/nac3core/src/codegen/values/ndarray/mod.rs index dc3b5c8d..449b1a61 100644 --- a/nac3core/src/codegen/values/ndarray/mod.rs +++ b/nac3core/src/codegen/values/ndarray/mod.rs @@ -20,10 +20,12 @@ use crate::codegen::{ types::{ndarray::NDArrayType, structure::StructField, TupleType}, CodeGenContext, CodeGenerator, }; +pub use broadcast::*; pub use contiguous::*; pub use indexing::*; pub use nditer::*; +mod broadcast; mod contiguous; mod indexing; mod nditer; diff --git a/nac3core/src/toplevel/builtins.rs b/nac3core/src/toplevel/builtins.rs index db7acaf3..8bd2ef54 100644 --- a/nac3core/src/toplevel/builtins.rs +++ b/nac3core/src/toplevel/builtins.rs @@ -373,7 +373,7 @@ impl<'a> BuiltinBuilder<'a> { self.build_ndarray_property_getter_function(prim) } - PrimDef::FunNpTranspose | PrimDef::FunNpReshape => { + PrimDef::FunNpBroadcastTo | PrimDef::FunNpTranspose | PrimDef::FunNpReshape => { self.build_ndarray_view_function(prim) } @@ -1328,7 +1328,10 @@ impl<'a> BuiltinBuilder<'a> { /// Build np/sp functions that take as input `NDArray` only fn build_ndarray_view_function(&mut self, prim: PrimDef) -> TopLevelDef { - debug_assert_prim_is_allowed(prim, &[PrimDef::FunNpTranspose, PrimDef::FunNpReshape]); + debug_assert_prim_is_allowed( + prim, + &[PrimDef::FunNpBroadcastTo, PrimDef::FunNpTranspose, PrimDef::FunNpReshape], + ); let in_ndarray_ty = self.unifier.get_fresh_var_with_range( &[self.primitives.ndarray], @@ -1356,7 +1359,10 @@ impl<'a> BuiltinBuilder<'a> { // Similar to `build_ndarray_from_shape_factory_function` we delegate the responsibility of typechecking // to [`typecheck::type_inferencer::Inferencer::fold_numpy_function_call_shape_argument`], // and use a dummy [`TypeVar`] `ndarray_factory_fn_shape_arg_tvar` as a placeholder for `param_ty`. - PrimDef::FunNpReshape => { + PrimDef::FunNpBroadcastTo | PrimDef::FunNpReshape => { + // These two functions have the same function signature. + // Mixed together for convenience. + let ret_ty = self.unifier.get_dummy_var().ty; // Handled by special holding create_fn_by_codegen( @@ -1386,7 +1392,18 @@ impl<'a> BuiltinBuilder<'a> { let (_, ndims) = unpack_ndarray_var_tys(&mut ctx.unifier, fun.0.ret); let ndims = extract_ndims(&ctx.unifier, ndims); - let new_ndarray = ndarray.reshape_or_copy(generator, ctx, ndims, &shape); + // let new_ndarray = ndarray.reshape_or_copy(generator, ctx, ndims, &shape); + let new_ndarray = match prim { + PrimDef::FunNpBroadcastTo => { + ndarray.broadcast_to(generator, ctx, ndims, &shape) + } + + PrimDef::FunNpReshape => { + ndarray.reshape_or_copy(generator, ctx, ndims, &shape) + } + + _ => unreachable!(), + }; Ok(Some(new_ndarray.as_base_value().as_basic_value_enum())) }), ) diff --git a/nac3core/src/toplevel/helper.rs b/nac3core/src/toplevel/helper.rs index 9313b13b..de90a41b 100644 --- a/nac3core/src/toplevel/helper.rs +++ b/nac3core/src/toplevel/helper.rs @@ -60,6 +60,7 @@ pub enum PrimDef { FunNpStrides, // NumPy ndarray view functions + FunNpBroadcastTo, FunNpTranspose, FunNpReshape, @@ -253,6 +254,7 @@ impl PrimDef { PrimDef::FunNpStrides => fun("np_strides", None), // NumPy NDArray view functions + PrimDef::FunNpBroadcastTo => fun("np_broadcast_to", None), PrimDef::FunNpTranspose => fun("np_transpose", None), PrimDef::FunNpReshape => fun("np_reshape", None), diff --git a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__generic_class.snap b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__generic_class.snap index 263407f5..41b39bb8 100644 --- a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__generic_class.snap +++ b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__generic_class.snap @@ -8,5 +8,5 @@ expression: res_vec "Function {\nname: \"B.foo\",\nsig: \"fn[[b:T], none]\",\nvar_id: []\n}\n", "Class {\nname: \"Generic_A\",\nancestors: [\"Generic_A[V]\", \"B\"],\nfields: [\"aa\", \"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"foo\", \"fn[[b:T], none]\"), (\"fun\", \"fn[[a:int32], V]\")],\ntype_vars: [\"V\"]\n}\n", "Function {\nname: \"Generic_A.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", - "Function {\nname: \"Generic_A.fun\",\nsig: \"fn[[a:int32], V]\",\nvar_id: [TypeVarId(253)]\n}\n", + "Function {\nname: \"Generic_A.fun\",\nsig: \"fn[[a:int32], V]\",\nvar_id: [TypeVarId(254)]\n}\n", ] diff --git a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__inheritance_override.snap b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__inheritance_override.snap index 284fdc12..90408d91 100644 --- a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__inheritance_override.snap +++ b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__inheritance_override.snap @@ -7,7 +7,7 @@ expression: res_vec "Function {\nname: \"A.__init__\",\nsig: \"fn[[t:T], none]\",\nvar_id: []\n}\n", "Function {\nname: \"A.fun\",\nsig: \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\",\nvar_id: []\n}\n", "Function {\nname: \"A.foo\",\nsig: \"fn[[c:C], none]\",\nvar_id: []\n}\n", - "Class {\nname: \"B\",\nancestors: [\"B[typevar237]\", \"A[float]\"],\nfields: [\"a\", \"b\", \"c\", \"d\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\"), (\"foo\", \"fn[[c:C], none]\")],\ntype_vars: [\"typevar237\"]\n}\n", + "Class {\nname: \"B\",\nancestors: [\"B[typevar238]\", \"A[float]\"],\nfields: [\"a\", \"b\", \"c\", \"d\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\"), (\"foo\", \"fn[[c:C], none]\")],\ntype_vars: [\"typevar238\"]\n}\n", "Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", "Function {\nname: \"B.fun\",\nsig: \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\",\nvar_id: []\n}\n", "Class {\nname: \"C\",\nancestors: [\"C\", \"B[bool]\", \"A[float]\"],\nfields: [\"a\", \"b\", \"c\", \"d\", \"e\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\"), (\"foo\", \"fn[[c:C], none]\")],\ntype_vars: []\n}\n", diff --git a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__list_tuple_generic.snap b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__list_tuple_generic.snap index 7228fc3d..f0418889 100644 --- a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__list_tuple_generic.snap +++ b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__list_tuple_generic.snap @@ -5,8 +5,8 @@ expression: res_vec [ "Function {\nname: \"foo\",\nsig: \"fn[[a:list[int32], b:tuple[T, float]], A[B, bool]]\",\nvar_id: []\n}\n", "Class {\nname: \"A\",\nancestors: [\"A[T, V]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[v:V], none]\"), (\"fun\", \"fn[[a:T], V]\")],\ntype_vars: [\"T\", \"V\"]\n}\n", - "Function {\nname: \"A.__init__\",\nsig: \"fn[[v:V], none]\",\nvar_id: [TypeVarId(250)]\n}\n", - "Function {\nname: \"A.fun\",\nsig: \"fn[[a:T], V]\",\nvar_id: [TypeVarId(255)]\n}\n", + "Function {\nname: \"A.__init__\",\nsig: \"fn[[v:V], none]\",\nvar_id: [TypeVarId(251)]\n}\n", + "Function {\nname: \"A.fun\",\nsig: \"fn[[a:T], V]\",\nvar_id: [TypeVarId(256)]\n}\n", "Function {\nname: \"gfun\",\nsig: \"fn[[a:A[list[float], int32]], none]\",\nvar_id: []\n}\n", "Class {\nname: \"B\",\nancestors: [\"B\"],\nfields: [],\nmethods: [(\"__init__\", \"fn[[], none]\")],\ntype_vars: []\n}\n", "Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", diff --git a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__self1.snap b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__self1.snap index c4216511..72e54e02 100644 --- a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__self1.snap +++ b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__self1.snap @@ -3,7 +3,7 @@ source: nac3core/src/toplevel/test.rs expression: res_vec --- [ - "Class {\nname: \"A\",\nancestors: [\"A[typevar236, typevar237]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[a:A[float, bool], b:B], none]\"), (\"fun\", \"fn[[a:A[float, bool]], A[bool, int32]]\")],\ntype_vars: [\"typevar236\", \"typevar237\"]\n}\n", + "Class {\nname: \"A\",\nancestors: [\"A[typevar237, typevar238]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[a:A[float, bool], b:B], none]\"), (\"fun\", \"fn[[a:A[float, bool]], A[bool, int32]]\")],\ntype_vars: [\"typevar237\", \"typevar238\"]\n}\n", "Function {\nname: \"A.__init__\",\nsig: \"fn[[a:A[float, bool], b:B], none]\",\nvar_id: []\n}\n", "Function {\nname: \"A.fun\",\nsig: \"fn[[a:A[float, bool]], A[bool, int32]]\",\nvar_id: []\n}\n", "Class {\nname: \"B\",\nancestors: [\"B\", \"A[int64, bool]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:A[float, bool]], A[bool, int32]]\"), (\"foo\", \"fn[[b:B], B]\"), (\"bar\", \"fn[[a:A[list[B], int32]], tuple[A[virtual[A[B, int32]], bool], B]]\")],\ntype_vars: []\n}\n", diff --git a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__simple_class_compose.snap b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__simple_class_compose.snap index ddb0bfaa..a8a534cd 100644 --- a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__simple_class_compose.snap +++ b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__simple_class_compose.snap @@ -6,12 +6,12 @@ expression: res_vec "Class {\nname: \"A\",\nancestors: [\"A\"],\nfields: [\"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n", "Function {\nname: \"A.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", "Function {\nname: \"A.fun\",\nsig: \"fn[[b:B], none]\",\nvar_id: []\n}\n", - "Function {\nname: \"A.foo\",\nsig: \"fn[[a:T, b:V], none]\",\nvar_id: [TypeVarId(256)]\n}\n", + "Function {\nname: \"A.foo\",\nsig: \"fn[[a:T, b:V], none]\",\nvar_id: [TypeVarId(257)]\n}\n", "Class {\nname: \"C\",\nancestors: [\"C\", \"A\"],\nfields: [\"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n", "Function {\nname: \"C.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", "Function {\nname: \"C.fun\",\nsig: \"fn[[b:B], none]\",\nvar_id: []\n}\n", "Class {\nname: \"B\",\nancestors: [\"B\", \"C\", \"A\"],\nfields: [\"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n", "Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", "Function {\nname: \"foo\",\nsig: \"fn[[a:A], none]\",\nvar_id: []\n}\n", - "Function {\nname: \"ff\",\nsig: \"fn[[a:T], V]\",\nvar_id: [TypeVarId(264)]\n}\n", + "Function {\nname: \"ff\",\nsig: \"fn[[a:T], V]\",\nvar_id: [TypeVarId(265)]\n}\n", ] diff --git a/nac3core/src/typecheck/type_inferencer/mod.rs b/nac3core/src/typecheck/type_inferencer/mod.rs index 8f1c54fc..87692114 100644 --- a/nac3core/src/typecheck/type_inferencer/mod.rs +++ b/nac3core/src/typecheck/type_inferencer/mod.rs @@ -1594,7 +1594,7 @@ impl<'a> Inferencer<'a> { })); } // 2-argument ndarray n-dimensional factory functions - if id == &"np_reshape".into() && args.len() == 2 { + if ["np_reshape".into(), "np_broadcast_to".into()].contains(id) && args.len() == 2 { let arg0 = self.fold_expr(args.remove(0))?; let shape_expr = args.remove(0); diff --git a/nac3standalone/demo/interpret_demo.py b/nac3standalone/demo/interpret_demo.py index 56c6126d..8784ce53 100755 --- a/nac3standalone/demo/interpret_demo.py +++ b/nac3standalone/demo/interpret_demo.py @@ -180,6 +180,7 @@ def patch(module): module.np_array = np.array # NumPy NDArray view functions + module.np_broadcast_to = np.broadcast_to module.np_transpose = np.transpose module.np_reshape = np.reshape diff --git a/nac3standalone/demo/src/ndarray.py b/nac3standalone/demo/src/ndarray.py index a82cfbad..374bcf73 100644 --- a/nac3standalone/demo/src/ndarray.py +++ b/nac3standalone/demo/src/ndarray.py @@ -68,6 +68,12 @@ def output_ndarray_float_2(n: ndarray[float, Literal[2]]): for c in range(len(n[r])): output_float64(n[r][c]) +def output_ndarray_float_3(n: ndarray[float, Literal[3]]): + for d in range(len(n)): + for r in range(len(n[d])): + for c in range(len(n[d][r])): + output_float64(n[d][r][c]) + def output_ndarray_float_4(n: ndarray[float, Literal[4]]): for x in range(len(n)): for y in range(len(n[x])): @@ -236,6 +242,23 @@ def test_ndarray_reshape(): output_int32(np_shape(x2)[1]) output_ndarray_int32_2(x2) +def test_ndarray_broadcast_to(): + xs = np_array([1.0, 2.0, 3.0]) + ys = np_broadcast_to(xs, (1, 3)) + zs = np_broadcast_to(ys, (2, 4, 3)) + + output_int32(np_shape(xs)[0]) + output_ndarray_float_1(xs) + + output_int32(np_shape(ys)[0]) + output_int32(np_shape(ys)[1]) + output_ndarray_float_2(ys) + + output_int32(np_shape(zs)[0]) + output_int32(np_shape(zs)[1]) + output_int32(np_shape(zs)[2]) + output_ndarray_float_3(zs) + def test_ndarray_add(): x = np_identity(2) y = x + np_ones([2, 2]) @@ -1619,6 +1642,7 @@ def run() -> int32: test_ndarray_nd_idx() test_ndarray_reshape() + test_ndarray_broadcast_to() test_ndarray_add() test_ndarray_add_broadcast()