[core] codegen/ndarray: Reimplement broadcasting
Based on 9359ed96
: core/ndstrides: implement broadcasting &
np_broadcast_to()
This commit is contained in:
parent
8d975b5ff3
commit
43e440d2fd
@ -10,4 +10,5 @@
|
||||
#include "irrt/ndarray/iter.hpp"
|
||||
#include "irrt/ndarray/indexing.hpp"
|
||||
#include "irrt/ndarray/array.hpp"
|
||||
#include "irrt/ndarray/reshape.hpp"
|
||||
#include "irrt/ndarray/reshape.hpp"
|
||||
#include "irrt/ndarray/broadcast.hpp"
|
165
nac3core/irrt/irrt/ndarray/broadcast.hpp
Normal file
165
nac3core/irrt/irrt/ndarray/broadcast.hpp
Normal file
@ -0,0 +1,165 @@
|
||||
#pragma once
|
||||
|
||||
#include "irrt/int_types.hpp"
|
||||
#include "irrt/ndarray/def.hpp"
|
||||
#include "irrt/slice.hpp"
|
||||
|
||||
namespace {
|
||||
template<typename SizeT>
|
||||
struct ShapeEntry {
|
||||
SizeT ndims;
|
||||
SizeT* shape;
|
||||
};
|
||||
} // namespace
|
||||
|
||||
namespace {
|
||||
namespace ndarray::broadcast {
|
||||
/**
|
||||
* @brief Return true if `src_shape` can broadcast to `dst_shape`.
|
||||
*
|
||||
* See https://numpy.org/doc/stable/user/basics.broadcasting.html
|
||||
*/
|
||||
template<typename SizeT>
|
||||
bool can_broadcast_shape_to(SizeT target_ndims, const SizeT* target_shape, SizeT src_ndims, const SizeT* src_shape) {
|
||||
if (src_ndims > target_ndims) {
|
||||
return false;
|
||||
}
|
||||
|
||||
for (SizeT i = 0; i < src_ndims; i++) {
|
||||
SizeT target_dim = target_shape[target_ndims - i - 1];
|
||||
SizeT src_dim = src_shape[src_ndims - i - 1];
|
||||
if (!(src_dim == 1 || target_dim == src_dim)) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Performs `np.broadcast_shapes(<shapes>)`
|
||||
*
|
||||
* @param num_shapes Number of entries in `shapes`
|
||||
* @param shapes The list of shape to do `np.broadcast_shapes` on.
|
||||
* @param dst_ndims The length of `dst_shape`.
|
||||
* `dst_ndims` must be `max([shape.ndims for shape in shapes])`, but the caller has to calculate it/provide it.
|
||||
* for this function since they should already know in order to allocate `dst_shape` in the first place.
|
||||
* @param dst_shape The resulting shape. Must be pre-allocated by the caller. This function calculate the result
|
||||
* of `np.broadcast_shapes` and write it here.
|
||||
*/
|
||||
template<typename SizeT>
|
||||
void broadcast_shapes(SizeT num_shapes, const ShapeEntry<SizeT>* shapes, SizeT dst_ndims, SizeT* dst_shape) {
|
||||
for (SizeT dst_axis = 0; dst_axis < dst_ndims; dst_axis++) {
|
||||
dst_shape[dst_axis] = 1;
|
||||
}
|
||||
|
||||
#ifdef IRRT_DEBUG_ASSERT
|
||||
SizeT max_ndims_found = 0;
|
||||
#endif
|
||||
|
||||
for (SizeT i = 0; i < num_shapes; i++) {
|
||||
ShapeEntry<SizeT> entry = shapes[i];
|
||||
|
||||
// Check pre-condition: `dst_ndims` must be `max([shape.ndims for shape in shapes])`
|
||||
debug_assert(SizeT, entry.ndims <= dst_ndims);
|
||||
|
||||
#ifdef IRRT_DEBUG_ASSERT
|
||||
max_ndims_found = max(max_ndims_found, entry.ndims);
|
||||
#endif
|
||||
|
||||
for (SizeT j = 0; j < entry.ndims; j++) {
|
||||
SizeT entry_axis = entry.ndims - j - 1;
|
||||
SizeT dst_axis = dst_ndims - j - 1;
|
||||
|
||||
SizeT entry_dim = entry.shape[entry_axis];
|
||||
SizeT dst_dim = dst_shape[dst_axis];
|
||||
|
||||
if (dst_dim == 1) {
|
||||
dst_shape[dst_axis] = entry_dim;
|
||||
} else if (entry_dim == 1 || entry_dim == dst_dim) {
|
||||
// Do nothing
|
||||
} else {
|
||||
raise_exception(SizeT, EXN_VALUE_ERROR,
|
||||
"shape mismatch: objects cannot be broadcast "
|
||||
"to a single shape.",
|
||||
NO_PARAM, NO_PARAM, NO_PARAM);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#ifdef IRRT_DEBUG_ASSERT
|
||||
// Check pre-condition: `dst_ndims` must be `max([shape.ndims for shape in shapes])`
|
||||
debug_assert_eq(SizeT, max_ndims_found, dst_ndims);
|
||||
#endif
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Perform `np.broadcast_to(<ndarray>, <target_shape>)` and appropriate assertions.
|
||||
*
|
||||
* This function attempts to broadcast `src_ndarray` to a new shape defined by `dst_ndarray.shape`,
|
||||
* and return the result by modifying `dst_ndarray`.
|
||||
*
|
||||
* # Notes on `dst_ndarray`
|
||||
* The caller is responsible for allocating space for the resulting ndarray.
|
||||
* Here is what this function expects from `dst_ndarray` when called:
|
||||
* - `dst_ndarray->data` does not have to be initialized.
|
||||
* - `dst_ndarray->itemsize` does not have to be initialized.
|
||||
* - `dst_ndarray->ndims` must be initialized, determining the length of `dst_ndarray->shape`
|
||||
* - `dst_ndarray->shape` must be allocated, and must contain the desired target broadcast shape.
|
||||
* - `dst_ndarray->strides` must be allocated, through it can contain uninitialized values.
|
||||
* When this function call ends:
|
||||
* - `dst_ndarray->data` is set to `src_ndarray->data` (`dst_ndarray` is just a view to `src_ndarray`)
|
||||
* - `dst_ndarray->itemsize` is set to `src_ndarray->itemsize`
|
||||
* - `dst_ndarray->ndims` is unchanged.
|
||||
* - `dst_ndarray->shape` is unchanged.
|
||||
* - `dst_ndarray->strides` is updated accordingly by how ndarray broadcast_to works.
|
||||
*/
|
||||
template<typename SizeT>
|
||||
void broadcast_to(const NDArray<SizeT>* src_ndarray, NDArray<SizeT>* dst_ndarray) {
|
||||
if (!ndarray::broadcast::can_broadcast_shape_to(dst_ndarray->ndims, dst_ndarray->shape, src_ndarray->ndims,
|
||||
src_ndarray->shape)) {
|
||||
raise_exception(SizeT, EXN_VALUE_ERROR, "operands could not be broadcast together", NO_PARAM, NO_PARAM,
|
||||
NO_PARAM);
|
||||
}
|
||||
|
||||
dst_ndarray->data = src_ndarray->data;
|
||||
dst_ndarray->itemsize = src_ndarray->itemsize;
|
||||
|
||||
for (SizeT i = 0; i < dst_ndarray->ndims; i++) {
|
||||
SizeT src_axis = src_ndarray->ndims - i - 1;
|
||||
SizeT dst_axis = dst_ndarray->ndims - i - 1;
|
||||
if (src_axis < 0 || (src_ndarray->shape[src_axis] == 1 && dst_ndarray->shape[dst_axis] != 1)) {
|
||||
// Freeze the steps in-place
|
||||
dst_ndarray->strides[dst_axis] = 0;
|
||||
} else {
|
||||
dst_ndarray->strides[dst_axis] = src_ndarray->strides[src_axis];
|
||||
}
|
||||
}
|
||||
}
|
||||
} // namespace ndarray::broadcast
|
||||
} // namespace
|
||||
|
||||
extern "C" {
|
||||
using namespace ndarray::broadcast;
|
||||
|
||||
void __nac3_ndarray_broadcast_to(NDArray<int32_t>* src_ndarray, NDArray<int32_t>* dst_ndarray) {
|
||||
broadcast_to(src_ndarray, dst_ndarray);
|
||||
}
|
||||
|
||||
void __nac3_ndarray_broadcast_to64(NDArray<int64_t>* src_ndarray, NDArray<int64_t>* dst_ndarray) {
|
||||
broadcast_to(src_ndarray, dst_ndarray);
|
||||
}
|
||||
|
||||
void __nac3_ndarray_broadcast_shapes(int32_t num_shapes,
|
||||
const ShapeEntry<int32_t>* shapes,
|
||||
int32_t dst_ndims,
|
||||
int32_t* dst_shape) {
|
||||
broadcast_shapes(num_shapes, shapes, dst_ndims, dst_shape);
|
||||
}
|
||||
|
||||
void __nac3_ndarray_broadcast_shapes64(int64_t num_shapes,
|
||||
const ShapeEntry<int64_t>* shapes,
|
||||
int64_t dst_ndims,
|
||||
int64_t* dst_shape) {
|
||||
broadcast_shapes(num_shapes, shapes, dst_ndims, dst_shape);
|
||||
}
|
||||
}
|
82
nac3core/src/codegen/irrt/ndarray/broadcast.rs
Normal file
82
nac3core/src/codegen/irrt/ndarray/broadcast.rs
Normal file
@ -0,0 +1,82 @@
|
||||
use inkwell::values::IntValue;
|
||||
|
||||
use crate::codegen::{
|
||||
expr::infer_and_call_function,
|
||||
irrt::get_usize_dependent_function_name,
|
||||
types::{ndarray::ShapeEntryType, ProxyType},
|
||||
values::{
|
||||
ndarray::NDArrayValue, ArrayLikeValue, ArraySliceValue, ProxyValue, TypedArrayLikeAccessor,
|
||||
TypedArrayLikeMutator,
|
||||
},
|
||||
CodeGenContext, CodeGenerator,
|
||||
};
|
||||
|
||||
/// Generates a call to `__nac3_ndarray_broadcast_to`.
|
||||
///
|
||||
/// Attempts to broadcast `src_ndarray` to the new shape defined by `dst_ndarray`.
|
||||
///
|
||||
/// `dst_ndarray` must meet the following preconditions:
|
||||
///
|
||||
/// - `dst_ndarray.ndims` must be initialized and matching the length of `dst_ndarray.shape`.
|
||||
/// - `dst_ndarray.shape` must be initialized and contains the target broadcast shape.
|
||||
/// - `dst_ndarray.strides` must be allocated and may contain uninitialized values.
|
||||
pub fn call_nac3_ndarray_broadcast_to<'ctx, G: CodeGenerator + ?Sized>(
|
||||
generator: &G,
|
||||
ctx: &CodeGenContext<'ctx, '_>,
|
||||
src_ndarray: NDArrayValue<'ctx>,
|
||||
dst_ndarray: NDArrayValue<'ctx>,
|
||||
) {
|
||||
let name = get_usize_dependent_function_name(generator, ctx, "__nac3_ndarray_broadcast_to");
|
||||
infer_and_call_function(
|
||||
ctx,
|
||||
&name,
|
||||
None,
|
||||
&[src_ndarray.as_base_value().into(), dst_ndarray.as_base_value().into()],
|
||||
None,
|
||||
None,
|
||||
);
|
||||
}
|
||||
|
||||
/// Generates a call to `__nac3_ndarray_broadcast_shapes`.
|
||||
///
|
||||
/// Attempts to calculate the resultant shape from broadcasting all shapes in `shape_entries`,
|
||||
/// writing the result to `dst_shape`.
|
||||
pub fn call_nac3_ndarray_broadcast_shapes<'ctx, G, Shape>(
|
||||
generator: &G,
|
||||
ctx: &CodeGenContext<'ctx, '_>,
|
||||
num_shape_entries: IntValue<'ctx>,
|
||||
shape_entries: ArraySliceValue<'ctx>,
|
||||
dst_ndims: IntValue<'ctx>,
|
||||
dst_shape: &Shape,
|
||||
) where
|
||||
G: CodeGenerator + ?Sized,
|
||||
Shape: TypedArrayLikeAccessor<'ctx, G, IntValue<'ctx>>
|
||||
+ TypedArrayLikeMutator<'ctx, G, IntValue<'ctx>>,
|
||||
{
|
||||
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||
|
||||
assert_eq!(num_shape_entries.get_type(), llvm_usize);
|
||||
assert!(ShapeEntryType::is_type(
|
||||
generator,
|
||||
ctx.ctx,
|
||||
shape_entries.base_ptr(ctx, generator).get_type()
|
||||
)
|
||||
.is_ok());
|
||||
assert_eq!(dst_ndims.get_type(), llvm_usize);
|
||||
assert_eq!(dst_shape.element_type(ctx, generator), llvm_usize.into());
|
||||
|
||||
let name = get_usize_dependent_function_name(generator, ctx, "__nac3_ndarray_broadcast_shapes");
|
||||
infer_and_call_function(
|
||||
ctx,
|
||||
&name,
|
||||
None,
|
||||
&[
|
||||
num_shape_entries.into(),
|
||||
shape_entries.base_ptr(ctx, generator).into(),
|
||||
dst_ndims.into(),
|
||||
dst_shape.base_ptr(ctx, generator).into(),
|
||||
],
|
||||
None,
|
||||
None,
|
||||
);
|
||||
}
|
@ -18,12 +18,14 @@ use crate::codegen::{
|
||||
};
|
||||
pub use array::*;
|
||||
pub use basic::*;
|
||||
pub use broadcast::*;
|
||||
pub use indexing::*;
|
||||
pub use iter::*;
|
||||
pub use reshape::*;
|
||||
|
||||
mod array;
|
||||
mod basic;
|
||||
mod broadcast;
|
||||
mod indexing;
|
||||
mod iter;
|
||||
mod reshape;
|
||||
|
176
nac3core/src/codegen/types/ndarray/broadcast.rs
Normal file
176
nac3core/src/codegen/types/ndarray/broadcast.rs
Normal file
@ -0,0 +1,176 @@
|
||||
use inkwell::{
|
||||
context::{AsContextRef, Context},
|
||||
types::{AnyTypeEnum, BasicType, BasicTypeEnum, IntType, PointerType},
|
||||
values::{IntValue, PointerValue},
|
||||
AddressSpace,
|
||||
};
|
||||
use itertools::Itertools;
|
||||
|
||||
use nac3core_derive::StructFields;
|
||||
|
||||
use crate::codegen::{
|
||||
types::{
|
||||
structure::{check_struct_type_matches_fields, StructField, StructFields},
|
||||
ProxyType,
|
||||
},
|
||||
values::{ndarray::ShapeEntryValue, ProxyValue},
|
||||
CodeGenContext, CodeGenerator,
|
||||
};
|
||||
|
||||
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
|
||||
pub struct ShapeEntryType<'ctx> {
|
||||
ty: PointerType<'ctx>,
|
||||
llvm_usize: IntType<'ctx>,
|
||||
}
|
||||
|
||||
#[derive(PartialEq, Eq, Clone, Copy, StructFields)]
|
||||
pub struct ShapeEntryStructFields<'ctx> {
|
||||
#[value_type(usize)]
|
||||
pub ndims: StructField<'ctx, IntValue<'ctx>>,
|
||||
#[value_type(usize.ptr_type(AddressSpace::default()))]
|
||||
pub shape: StructField<'ctx, PointerValue<'ctx>>,
|
||||
}
|
||||
|
||||
impl<'ctx> ShapeEntryType<'ctx> {
|
||||
/// Checks whether `llvm_ty` represents a [`ShapeEntryType`], returning [Err] if it does not.
|
||||
pub fn is_representable(
|
||||
llvm_ty: PointerType<'ctx>,
|
||||
llvm_usize: IntType<'ctx>,
|
||||
) -> Result<(), String> {
|
||||
let ctx = llvm_ty.get_context();
|
||||
|
||||
let llvm_ndarray_ty = llvm_ty.get_element_type();
|
||||
let AnyTypeEnum::StructType(llvm_ndarray_ty) = llvm_ndarray_ty else {
|
||||
return Err(format!(
|
||||
"Expected struct type for `ShapeEntry` type, got {llvm_ndarray_ty}"
|
||||
));
|
||||
};
|
||||
|
||||
check_struct_type_matches_fields(
|
||||
Self::fields(ctx, llvm_usize),
|
||||
llvm_ndarray_ty,
|
||||
"NDArray",
|
||||
&[],
|
||||
)
|
||||
}
|
||||
|
||||
/// Returns an instance of [`StructFields`] containing all field accessors for this type.
|
||||
#[must_use]
|
||||
fn fields(
|
||||
ctx: impl AsContextRef<'ctx>,
|
||||
llvm_usize: IntType<'ctx>,
|
||||
) -> ShapeEntryStructFields<'ctx> {
|
||||
ShapeEntryStructFields::new(ctx, llvm_usize)
|
||||
}
|
||||
|
||||
/// See [`ShapeEntryStructFields::fields`].
|
||||
// TODO: Move this into e.g. StructProxyType
|
||||
#[must_use]
|
||||
pub fn get_fields(&self, ctx: impl AsContextRef<'ctx>) -> ShapeEntryStructFields<'ctx> {
|
||||
Self::fields(ctx, self.llvm_usize)
|
||||
}
|
||||
|
||||
/// Creates an LLVM type corresponding to the expected structure of a `ShapeEntry`.
|
||||
#[must_use]
|
||||
fn llvm_type(ctx: &'ctx Context, llvm_usize: IntType<'ctx>) -> PointerType<'ctx> {
|
||||
let field_tys =
|
||||
Self::fields(ctx, llvm_usize).into_iter().map(|field| field.1).collect_vec();
|
||||
|
||||
ctx.struct_type(&field_tys, false).ptr_type(AddressSpace::default())
|
||||
}
|
||||
|
||||
/// Creates an instance of [`ShapeEntryType`].
|
||||
#[must_use]
|
||||
pub fn new<G: CodeGenerator + ?Sized>(generator: &G, ctx: &'ctx Context) -> Self {
|
||||
let llvm_usize = generator.get_size_type(ctx);
|
||||
let llvm_ty = Self::llvm_type(ctx, llvm_usize);
|
||||
|
||||
Self { ty: llvm_ty, llvm_usize }
|
||||
}
|
||||
|
||||
/// Creates a [`ShapeEntryType`] from a [`PointerType`] representing an `ShapeEntry`.
|
||||
#[must_use]
|
||||
pub fn from_type(ptr_ty: PointerType<'ctx>, llvm_usize: IntType<'ctx>) -> Self {
|
||||
debug_assert!(Self::is_representable(ptr_ty, llvm_usize).is_ok());
|
||||
|
||||
Self { ty: ptr_ty, llvm_usize }
|
||||
}
|
||||
|
||||
/// Allocates an instance of [`ShapeEntryValue`] as if by calling `alloca` on the base type.
|
||||
#[must_use]
|
||||
pub fn alloca(
|
||||
&self,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
name: Option<&'ctx str>,
|
||||
) -> <Self as ProxyType<'ctx>>::Value {
|
||||
<Self as ProxyType<'ctx>>::Value::from_pointer_value(
|
||||
self.raw_alloca(ctx, name),
|
||||
self.llvm_usize,
|
||||
name,
|
||||
)
|
||||
}
|
||||
|
||||
/// Allocates an instance of [`ShapeEntryValue`] as if by calling `alloca` on the base type.
|
||||
#[must_use]
|
||||
pub fn alloca_var<G: CodeGenerator + ?Sized>(
|
||||
&self,
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
name: Option<&'ctx str>,
|
||||
) -> <Self as ProxyType<'ctx>>::Value {
|
||||
<Self as ProxyType<'ctx>>::Value::from_pointer_value(
|
||||
self.raw_alloca_var(generator, ctx, name),
|
||||
self.llvm_usize,
|
||||
name,
|
||||
)
|
||||
}
|
||||
|
||||
/// Converts an existing value into a [`ShapeEntryValue`].
|
||||
#[must_use]
|
||||
pub fn map_value(
|
||||
&self,
|
||||
value: <<Self as ProxyType<'ctx>>::Value as ProxyValue<'ctx>>::Base,
|
||||
name: Option<&'ctx str>,
|
||||
) -> <Self as ProxyType<'ctx>>::Value {
|
||||
<Self as ProxyType<'ctx>>::Value::from_pointer_value(value, self.llvm_usize, name)
|
||||
}
|
||||
}
|
||||
|
||||
impl<'ctx> ProxyType<'ctx> for ShapeEntryType<'ctx> {
|
||||
type Base = PointerType<'ctx>;
|
||||
type Value = ShapeEntryValue<'ctx>;
|
||||
|
||||
fn is_type<G: CodeGenerator + ?Sized>(
|
||||
generator: &G,
|
||||
ctx: &'ctx Context,
|
||||
llvm_ty: impl BasicType<'ctx>,
|
||||
) -> Result<(), String> {
|
||||
if let BasicTypeEnum::PointerType(ty) = llvm_ty.as_basic_type_enum() {
|
||||
<Self as ProxyType<'ctx>>::is_representable(generator, ctx, ty)
|
||||
} else {
|
||||
Err(format!("Expected pointer type, got {llvm_ty:?}"))
|
||||
}
|
||||
}
|
||||
|
||||
fn is_representable<G: CodeGenerator + ?Sized>(
|
||||
generator: &G,
|
||||
ctx: &'ctx Context,
|
||||
llvm_ty: Self::Base,
|
||||
) -> Result<(), String> {
|
||||
Self::is_representable(llvm_ty, generator.get_size_type(ctx))
|
||||
}
|
||||
|
||||
fn alloca_type(&self) -> impl BasicType<'ctx> {
|
||||
self.as_base_type().get_element_type().into_struct_type()
|
||||
}
|
||||
|
||||
fn as_base_type(&self) -> Self::Base {
|
||||
self.ty
|
||||
}
|
||||
}
|
||||
|
||||
impl<'ctx> From<ShapeEntryType<'ctx>> for PointerType<'ctx> {
|
||||
fn from(value: ShapeEntryType<'ctx>) -> Self {
|
||||
value.as_base_type()
|
||||
}
|
||||
}
|
@ -20,11 +20,13 @@ use crate::{
|
||||
toplevel::{helper::extract_ndims, numpy::unpack_ndarray_var_tys},
|
||||
typecheck::typedef::Type,
|
||||
};
|
||||
pub use broadcast::*;
|
||||
pub use contiguous::*;
|
||||
pub use indexing::*;
|
||||
pub use nditer::*;
|
||||
|
||||
mod array;
|
||||
mod broadcast;
|
||||
mod contiguous;
|
||||
pub mod factory;
|
||||
mod indexing;
|
||||
@ -118,6 +120,20 @@ impl<'ctx> NDArrayType<'ctx> {
|
||||
NDArrayType { ty: llvm_ndarray, dtype, ndims, llvm_usize }
|
||||
}
|
||||
|
||||
/// Creates an instance of [`NDArrayType`] as a result of a broadcast operation over one or more
|
||||
/// `ndarray` operands.
|
||||
#[must_use]
|
||||
pub fn new_broadcast<G: CodeGenerator + ?Sized>(
|
||||
generator: &G,
|
||||
ctx: &'ctx Context,
|
||||
dtype: BasicTypeEnum<'ctx>,
|
||||
inputs: &[NDArrayType<'ctx>],
|
||||
) -> Self {
|
||||
assert!(!inputs.is_empty());
|
||||
|
||||
Self::new(generator, ctx, dtype, inputs.iter().filter_map(NDArrayType::ndims).max())
|
||||
}
|
||||
|
||||
/// Creates an instance of [`NDArrayType`] with `ndims` of 0.
|
||||
#[must_use]
|
||||
pub fn new_unsized<G: CodeGenerator + ?Sized>(
|
||||
|
248
nac3core/src/codegen/values/ndarray/broadcast.rs
Normal file
248
nac3core/src/codegen/values/ndarray/broadcast.rs
Normal file
@ -0,0 +1,248 @@
|
||||
use inkwell::{
|
||||
types::IntType,
|
||||
values::{IntValue, PointerValue},
|
||||
};
|
||||
use itertools::Itertools;
|
||||
|
||||
use crate::codegen::{
|
||||
irrt,
|
||||
types::{
|
||||
ndarray::{NDArrayType, ShapeEntryType},
|
||||
structure::StructField,
|
||||
ProxyType,
|
||||
},
|
||||
values::{
|
||||
ndarray::NDArrayValue, ArrayLikeIndexer, ArrayLikeValue, ArraySliceValue, ProxyValue,
|
||||
TypedArrayLikeAccessor, TypedArrayLikeAdapter, TypedArrayLikeMutator,
|
||||
},
|
||||
CodeGenContext, CodeGenerator,
|
||||
};
|
||||
|
||||
#[derive(Copy, Clone)]
|
||||
pub struct ShapeEntryValue<'ctx> {
|
||||
value: PointerValue<'ctx>,
|
||||
llvm_usize: IntType<'ctx>,
|
||||
name: Option<&'ctx str>,
|
||||
}
|
||||
|
||||
impl<'ctx> ShapeEntryValue<'ctx> {
|
||||
/// Checks whether `value` is an instance of `ShapeEntry`, returning [Err] if `value` is
|
||||
/// not an instance.
|
||||
pub fn is_representable(
|
||||
value: PointerValue<'ctx>,
|
||||
llvm_usize: IntType<'ctx>,
|
||||
) -> Result<(), String> {
|
||||
<Self as ProxyValue<'ctx>>::Type::is_representable(value.get_type(), llvm_usize)
|
||||
}
|
||||
|
||||
/// Creates an [`ShapeEntryValue`] from a [`PointerValue`].
|
||||
#[must_use]
|
||||
pub fn from_pointer_value(
|
||||
ptr: PointerValue<'ctx>,
|
||||
llvm_usize: IntType<'ctx>,
|
||||
name: Option<&'ctx str>,
|
||||
) -> Self {
|
||||
debug_assert!(Self::is_representable(ptr, llvm_usize).is_ok());
|
||||
|
||||
Self { value: ptr, llvm_usize, name }
|
||||
}
|
||||
|
||||
fn ndims_field(&self) -> StructField<'ctx, IntValue<'ctx>> {
|
||||
self.get_type().get_fields(self.value.get_type().get_context()).ndims
|
||||
}
|
||||
|
||||
/// Stores the number of dimensions into this value.
|
||||
pub fn store_ndims(&self, ctx: &CodeGenContext<'ctx, '_>, value: IntValue<'ctx>) {
|
||||
self.ndims_field().set(ctx, self.value, value, self.name);
|
||||
}
|
||||
|
||||
fn shape_field(&self) -> StructField<'ctx, PointerValue<'ctx>> {
|
||||
self.get_type().get_fields(self.value.get_type().get_context()).shape
|
||||
}
|
||||
|
||||
/// Stores the shape into this value.
|
||||
pub fn store_shape(&self, ctx: &CodeGenContext<'ctx, '_>, value: PointerValue<'ctx>) {
|
||||
self.shape_field().set(ctx, self.value, value, self.name);
|
||||
}
|
||||
}
|
||||
|
||||
impl<'ctx> ProxyValue<'ctx> for ShapeEntryValue<'ctx> {
|
||||
type Base = PointerValue<'ctx>;
|
||||
type Type = ShapeEntryType<'ctx>;
|
||||
|
||||
fn get_type(&self) -> Self::Type {
|
||||
Self::Type::from_type(self.value.get_type(), self.llvm_usize)
|
||||
}
|
||||
|
||||
fn as_base_value(&self) -> Self::Base {
|
||||
self.value
|
||||
}
|
||||
}
|
||||
|
||||
impl<'ctx> From<ShapeEntryValue<'ctx>> for PointerValue<'ctx> {
|
||||
fn from(value: ShapeEntryValue<'ctx>) -> Self {
|
||||
value.as_base_value()
|
||||
}
|
||||
}
|
||||
|
||||
impl<'ctx> NDArrayValue<'ctx> {
|
||||
/// Create a broadcast view on this ndarray with a target shape.
|
||||
///
|
||||
/// The input shape will be checked to make sure that it contains no negative values.
|
||||
///
|
||||
/// * `target_ndims` - The ndims type after broadcasting to the given shape.
|
||||
/// The caller has to figure this out for this function.
|
||||
/// * `target_shape` - An array pointer pointing to the target shape.
|
||||
#[must_use]
|
||||
pub fn broadcast_to<G: CodeGenerator + ?Sized>(
|
||||
&self,
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
target_ndims: u64,
|
||||
target_shape: &impl TypedArrayLikeAccessor<'ctx, G, IntValue<'ctx>>,
|
||||
) -> Self {
|
||||
assert!(self.ndims.is_none_or(|ndims| ndims <= target_ndims));
|
||||
assert_eq!(target_shape.element_type(ctx, generator), self.llvm_usize.into());
|
||||
|
||||
let broadcast_ndarray =
|
||||
NDArrayType::new(generator, ctx.ctx, self.dtype, Some(target_ndims))
|
||||
.construct_uninitialized(generator, ctx, None);
|
||||
broadcast_ndarray.copy_shape_from_array(
|
||||
generator,
|
||||
ctx,
|
||||
target_shape.base_ptr(ctx, generator),
|
||||
);
|
||||
|
||||
irrt::ndarray::call_nac3_ndarray_broadcast_to(generator, ctx, *self, broadcast_ndarray);
|
||||
broadcast_ndarray
|
||||
}
|
||||
}
|
||||
|
||||
/// A result produced by [`broadcast_all_ndarrays`]
|
||||
#[derive(Clone)]
|
||||
pub struct BroadcastAllResult<'ctx, G: CodeGenerator + ?Sized> {
|
||||
/// The statically known `ndims` of the broadcast result.
|
||||
pub ndims: u64,
|
||||
|
||||
/// The broadcasting shape.
|
||||
pub shape: TypedArrayLikeAdapter<'ctx, G, IntValue<'ctx>>,
|
||||
|
||||
/// Broadcasted views on the inputs.
|
||||
///
|
||||
/// All of them will have `shape` [`BroadcastAllResult::shape`] and
|
||||
/// `ndims` [`BroadcastAllResult::ndims`]. The length of the vector
|
||||
/// is the same as the input.
|
||||
pub ndarrays: Vec<NDArrayValue<'ctx>>,
|
||||
}
|
||||
|
||||
/// Helper function to call [`irrt::ndarray::call_nac3_ndarray_broadcast_shapes`].
|
||||
fn broadcast_shapes<'ctx, G, Shape>(
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
in_shape_entries: &[(ArraySliceValue<'ctx>, u64)], // (shape, shape's length/ndims)
|
||||
broadcast_ndims: u64,
|
||||
broadcast_shape: &Shape,
|
||||
) where
|
||||
G: CodeGenerator + ?Sized,
|
||||
Shape: TypedArrayLikeAccessor<'ctx, G, IntValue<'ctx>>
|
||||
+ TypedArrayLikeMutator<'ctx, G, IntValue<'ctx>>,
|
||||
{
|
||||
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||
let llvm_shape_ty = ShapeEntryType::new(generator, ctx.ctx);
|
||||
|
||||
assert!(in_shape_entries
|
||||
.iter()
|
||||
.all(|entry| entry.0.element_type(ctx, generator) == llvm_usize.into()));
|
||||
assert_eq!(broadcast_shape.element_type(ctx, generator), llvm_usize.into());
|
||||
|
||||
// Prepare input shape entries to be passed to `call_nac3_ndarray_broadcast_shapes`.
|
||||
let num_shape_entries =
|
||||
llvm_usize.const_int(u64::try_from(in_shape_entries.len()).unwrap(), false);
|
||||
let shape_entries = llvm_shape_ty.array_alloca(ctx, num_shape_entries, None);
|
||||
for (i, (in_shape, in_ndims)) in in_shape_entries.iter().enumerate() {
|
||||
let pshape_entry = unsafe {
|
||||
shape_entries.ptr_offset_unchecked(
|
||||
ctx,
|
||||
generator,
|
||||
&llvm_usize.const_int(i as u64, false),
|
||||
None,
|
||||
)
|
||||
};
|
||||
let shape_entry = llvm_shape_ty.map_value(pshape_entry, None);
|
||||
|
||||
let in_ndims = llvm_usize.const_int(*in_ndims, false);
|
||||
shape_entry.store_ndims(ctx, in_ndims);
|
||||
|
||||
shape_entry.store_shape(ctx, in_shape.base_ptr(ctx, generator));
|
||||
}
|
||||
|
||||
let broadcast_ndims = llvm_usize.const_int(broadcast_ndims, false);
|
||||
irrt::ndarray::call_nac3_ndarray_broadcast_shapes(
|
||||
generator,
|
||||
ctx,
|
||||
num_shape_entries,
|
||||
shape_entries,
|
||||
broadcast_ndims,
|
||||
broadcast_shape,
|
||||
);
|
||||
}
|
||||
|
||||
impl<'ctx> NDArrayType<'ctx> {
|
||||
/// Broadcast all ndarrays according to
|
||||
/// [`np.broadcast()`](https://numpy.org/doc/stable/reference/generated/numpy.broadcast.html)
|
||||
/// and return a [`BroadcastAllResult`] containing all the information of the result of the
|
||||
/// broadcast operation.
|
||||
pub fn broadcast<G: CodeGenerator + ?Sized>(
|
||||
&self,
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
ndarrays: &[NDArrayValue<'ctx>],
|
||||
) -> BroadcastAllResult<'ctx, G> {
|
||||
assert!(!ndarrays.is_empty());
|
||||
assert!(ndarrays.iter().all(|ndarray| ndarray.get_type().ndims().is_some()));
|
||||
|
||||
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||
|
||||
// Infer the broadcast output ndims.
|
||||
let broadcast_ndims_int =
|
||||
ndarrays.iter().map(|ndarray| ndarray.get_type().ndims().unwrap()).max().unwrap();
|
||||
assert!(self.ndims().is_none_or(|ndims| ndims >= broadcast_ndims_int));
|
||||
|
||||
let broadcast_ndims = llvm_usize.const_int(broadcast_ndims_int, false);
|
||||
let broadcast_shape = ArraySliceValue::from_ptr_val(
|
||||
ctx.builder.build_array_alloca(llvm_usize, broadcast_ndims, "").unwrap(),
|
||||
broadcast_ndims,
|
||||
None,
|
||||
);
|
||||
let broadcast_shape = TypedArrayLikeAdapter::from(
|
||||
broadcast_shape,
|
||||
|_, _, val| val.into_int_value(),
|
||||
|_, _, val| val.into(),
|
||||
);
|
||||
|
||||
let shape_entries = ndarrays
|
||||
.iter()
|
||||
.map(|ndarray| {
|
||||
(
|
||||
ndarray.shape().as_slice_value(ctx, generator),
|
||||
ndarray.get_type().ndims().unwrap(),
|
||||
)
|
||||
})
|
||||
.collect_vec();
|
||||
broadcast_shapes(generator, ctx, &shape_entries, broadcast_ndims_int, &broadcast_shape);
|
||||
|
||||
// Broadcast all the inputs to shape `dst_shape`.
|
||||
let broadcast_ndarrays = ndarrays
|
||||
.iter()
|
||||
.map(|ndarray| {
|
||||
ndarray.broadcast_to(generator, ctx, broadcast_ndims_int, &broadcast_shape)
|
||||
})
|
||||
.collect_vec();
|
||||
|
||||
BroadcastAllResult {
|
||||
ndims: broadcast_ndims_int,
|
||||
shape: broadcast_shape,
|
||||
ndarrays: broadcast_ndarrays,
|
||||
}
|
||||
}
|
||||
}
|
@ -20,10 +20,12 @@ use crate::codegen::{
|
||||
types::{ndarray::NDArrayType, structure::StructField, TupleType},
|
||||
CodeGenContext, CodeGenerator,
|
||||
};
|
||||
pub use broadcast::*;
|
||||
pub use contiguous::*;
|
||||
pub use indexing::*;
|
||||
pub use nditer::*;
|
||||
|
||||
mod broadcast;
|
||||
mod contiguous;
|
||||
mod indexing;
|
||||
mod nditer;
|
||||
|
@ -373,7 +373,7 @@ impl<'a> BuiltinBuilder<'a> {
|
||||
self.build_ndarray_property_getter_function(prim)
|
||||
}
|
||||
|
||||
PrimDef::FunNpTranspose | PrimDef::FunNpReshape => {
|
||||
PrimDef::FunNpBroadcastTo | PrimDef::FunNpTranspose | PrimDef::FunNpReshape => {
|
||||
self.build_ndarray_view_function(prim)
|
||||
}
|
||||
|
||||
@ -1328,7 +1328,10 @@ impl<'a> BuiltinBuilder<'a> {
|
||||
|
||||
/// Build np/sp functions that take as input `NDArray` only
|
||||
fn build_ndarray_view_function(&mut self, prim: PrimDef) -> TopLevelDef {
|
||||
debug_assert_prim_is_allowed(prim, &[PrimDef::FunNpTranspose, PrimDef::FunNpReshape]);
|
||||
debug_assert_prim_is_allowed(
|
||||
prim,
|
||||
&[PrimDef::FunNpBroadcastTo, PrimDef::FunNpTranspose, PrimDef::FunNpReshape],
|
||||
);
|
||||
|
||||
let in_ndarray_ty = self.unifier.get_fresh_var_with_range(
|
||||
&[self.primitives.ndarray],
|
||||
@ -1356,7 +1359,10 @@ impl<'a> BuiltinBuilder<'a> {
|
||||
// Similar to `build_ndarray_from_shape_factory_function` we delegate the responsibility of typechecking
|
||||
// to [`typecheck::type_inferencer::Inferencer::fold_numpy_function_call_shape_argument`],
|
||||
// and use a dummy [`TypeVar`] `ndarray_factory_fn_shape_arg_tvar` as a placeholder for `param_ty`.
|
||||
PrimDef::FunNpReshape => {
|
||||
PrimDef::FunNpBroadcastTo | PrimDef::FunNpReshape => {
|
||||
// These two functions have the same function signature.
|
||||
// Mixed together for convenience.
|
||||
|
||||
let ret_ty = self.unifier.get_dummy_var().ty; // Handled by special holding
|
||||
|
||||
create_fn_by_codegen(
|
||||
@ -1386,7 +1392,17 @@ impl<'a> BuiltinBuilder<'a> {
|
||||
let (_, ndims) = unpack_ndarray_var_tys(&mut ctx.unifier, fun.0.ret);
|
||||
let ndims = extract_ndims(&ctx.unifier, ndims);
|
||||
|
||||
let new_ndarray = ndarray.reshape_or_copy(generator, ctx, ndims, &shape);
|
||||
let new_ndarray = match prim {
|
||||
PrimDef::FunNpBroadcastTo => {
|
||||
ndarray.broadcast_to(generator, ctx, ndims, &shape)
|
||||
}
|
||||
|
||||
PrimDef::FunNpReshape => {
|
||||
ndarray.reshape_or_copy(generator, ctx, ndims, &shape)
|
||||
}
|
||||
|
||||
_ => unreachable!(),
|
||||
};
|
||||
Ok(Some(new_ndarray.as_base_value().as_basic_value_enum()))
|
||||
}),
|
||||
)
|
||||
|
@ -60,6 +60,7 @@ pub enum PrimDef {
|
||||
FunNpStrides,
|
||||
|
||||
// NumPy ndarray view functions
|
||||
FunNpBroadcastTo,
|
||||
FunNpTranspose,
|
||||
FunNpReshape,
|
||||
|
||||
@ -253,6 +254,7 @@ impl PrimDef {
|
||||
PrimDef::FunNpStrides => fun("np_strides", None),
|
||||
|
||||
// NumPy NDArray view functions
|
||||
PrimDef::FunNpBroadcastTo => fun("np_broadcast_to", None),
|
||||
PrimDef::FunNpTranspose => fun("np_transpose", None),
|
||||
PrimDef::FunNpReshape => fun("np_reshape", None),
|
||||
|
||||
|
@ -8,5 +8,5 @@ expression: res_vec
|
||||
"Function {\nname: \"B.foo\",\nsig: \"fn[[b:T], none]\",\nvar_id: []\n}\n",
|
||||
"Class {\nname: \"Generic_A\",\nancestors: [\"Generic_A[V]\", \"B\"],\nfields: [\"aa\", \"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"foo\", \"fn[[b:T], none]\"), (\"fun\", \"fn[[a:int32], V]\")],\ntype_vars: [\"V\"]\n}\n",
|
||||
"Function {\nname: \"Generic_A.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
|
||||
"Function {\nname: \"Generic_A.fun\",\nsig: \"fn[[a:int32], V]\",\nvar_id: [TypeVarId(252)]\n}\n",
|
||||
"Function {\nname: \"Generic_A.fun\",\nsig: \"fn[[a:int32], V]\",\nvar_id: [TypeVarId(254)]\n}\n",
|
||||
]
|
||||
|
@ -7,7 +7,7 @@ expression: res_vec
|
||||
"Function {\nname: \"A.__init__\",\nsig: \"fn[[t:T], none]\",\nvar_id: []\n}\n",
|
||||
"Function {\nname: \"A.fun\",\nsig: \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\",\nvar_id: []\n}\n",
|
||||
"Function {\nname: \"A.foo\",\nsig: \"fn[[c:C], none]\",\nvar_id: []\n}\n",
|
||||
"Class {\nname: \"B\",\nancestors: [\"B[typevar236]\", \"A[float]\"],\nfields: [\"a\", \"b\", \"c\", \"d\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\"), (\"foo\", \"fn[[c:C], none]\")],\ntype_vars: [\"typevar236\"]\n}\n",
|
||||
"Class {\nname: \"B\",\nancestors: [\"B[typevar238]\", \"A[float]\"],\nfields: [\"a\", \"b\", \"c\", \"d\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\"), (\"foo\", \"fn[[c:C], none]\")],\ntype_vars: [\"typevar238\"]\n}\n",
|
||||
"Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
|
||||
"Function {\nname: \"B.fun\",\nsig: \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\",\nvar_id: []\n}\n",
|
||||
"Class {\nname: \"C\",\nancestors: [\"C\", \"B[bool]\", \"A[float]\"],\nfields: [\"a\", \"b\", \"c\", \"d\", \"e\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\"), (\"foo\", \"fn[[c:C], none]\")],\ntype_vars: []\n}\n",
|
||||
|
@ -5,8 +5,8 @@ expression: res_vec
|
||||
[
|
||||
"Function {\nname: \"foo\",\nsig: \"fn[[a:list[int32], b:tuple[T, float]], A[B, bool]]\",\nvar_id: []\n}\n",
|
||||
"Class {\nname: \"A\",\nancestors: [\"A[T, V]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[v:V], none]\"), (\"fun\", \"fn[[a:T], V]\")],\ntype_vars: [\"T\", \"V\"]\n}\n",
|
||||
"Function {\nname: \"A.__init__\",\nsig: \"fn[[v:V], none]\",\nvar_id: [TypeVarId(249)]\n}\n",
|
||||
"Function {\nname: \"A.fun\",\nsig: \"fn[[a:T], V]\",\nvar_id: [TypeVarId(254)]\n}\n",
|
||||
"Function {\nname: \"A.__init__\",\nsig: \"fn[[v:V], none]\",\nvar_id: [TypeVarId(251)]\n}\n",
|
||||
"Function {\nname: \"A.fun\",\nsig: \"fn[[a:T], V]\",\nvar_id: [TypeVarId(256)]\n}\n",
|
||||
"Function {\nname: \"gfun\",\nsig: \"fn[[a:A[list[float], int32]], none]\",\nvar_id: []\n}\n",
|
||||
"Class {\nname: \"B\",\nancestors: [\"B\"],\nfields: [],\nmethods: [(\"__init__\", \"fn[[], none]\")],\ntype_vars: []\n}\n",
|
||||
"Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
|
||||
|
@ -3,7 +3,7 @@ source: nac3core/src/toplevel/test.rs
|
||||
expression: res_vec
|
||||
---
|
||||
[
|
||||
"Class {\nname: \"A\",\nancestors: [\"A[typevar235, typevar236]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[a:A[float, bool], b:B], none]\"), (\"fun\", \"fn[[a:A[float, bool]], A[bool, int32]]\")],\ntype_vars: [\"typevar235\", \"typevar236\"]\n}\n",
|
||||
"Class {\nname: \"A\",\nancestors: [\"A[typevar237, typevar238]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[a:A[float, bool], b:B], none]\"), (\"fun\", \"fn[[a:A[float, bool]], A[bool, int32]]\")],\ntype_vars: [\"typevar237\", \"typevar238\"]\n}\n",
|
||||
"Function {\nname: \"A.__init__\",\nsig: \"fn[[a:A[float, bool], b:B], none]\",\nvar_id: []\n}\n",
|
||||
"Function {\nname: \"A.fun\",\nsig: \"fn[[a:A[float, bool]], A[bool, int32]]\",\nvar_id: []\n}\n",
|
||||
"Class {\nname: \"B\",\nancestors: [\"B\", \"A[int64, bool]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:A[float, bool]], A[bool, int32]]\"), (\"foo\", \"fn[[b:B], B]\"), (\"bar\", \"fn[[a:A[list[B], int32]], tuple[A[virtual[A[B, int32]], bool], B]]\")],\ntype_vars: []\n}\n",
|
||||
|
@ -6,12 +6,12 @@ expression: res_vec
|
||||
"Class {\nname: \"A\",\nancestors: [\"A\"],\nfields: [\"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n",
|
||||
"Function {\nname: \"A.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
|
||||
"Function {\nname: \"A.fun\",\nsig: \"fn[[b:B], none]\",\nvar_id: []\n}\n",
|
||||
"Function {\nname: \"A.foo\",\nsig: \"fn[[a:T, b:V], none]\",\nvar_id: [TypeVarId(255)]\n}\n",
|
||||
"Function {\nname: \"A.foo\",\nsig: \"fn[[a:T, b:V], none]\",\nvar_id: [TypeVarId(257)]\n}\n",
|
||||
"Class {\nname: \"C\",\nancestors: [\"C\", \"A\"],\nfields: [\"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n",
|
||||
"Function {\nname: \"C.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
|
||||
"Function {\nname: \"C.fun\",\nsig: \"fn[[b:B], none]\",\nvar_id: []\n}\n",
|
||||
"Class {\nname: \"B\",\nancestors: [\"B\", \"C\", \"A\"],\nfields: [\"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n",
|
||||
"Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
|
||||
"Function {\nname: \"foo\",\nsig: \"fn[[a:A], none]\",\nvar_id: []\n}\n",
|
||||
"Function {\nname: \"ff\",\nsig: \"fn[[a:T], V]\",\nvar_id: [TypeVarId(263)]\n}\n",
|
||||
"Function {\nname: \"ff\",\nsig: \"fn[[a:T], V]\",\nvar_id: [TypeVarId(265)]\n}\n",
|
||||
]
|
||||
|
@ -1594,7 +1594,7 @@ impl<'a> Inferencer<'a> {
|
||||
}));
|
||||
}
|
||||
// 2-argument ndarray n-dimensional factory functions
|
||||
if id == &"np_reshape".into() && args.len() == 2 {
|
||||
if ["np_reshape".into(), "np_broadcast_to".into()].contains(id) && args.len() == 2 {
|
||||
let arg0 = self.fold_expr(args.remove(0))?;
|
||||
|
||||
let shape_expr = args.remove(0);
|
||||
|
@ -180,6 +180,7 @@ def patch(module):
|
||||
module.np_array = np.array
|
||||
|
||||
# NumPy NDArray view functions
|
||||
module.np_broadcast_to = np.broadcast_to
|
||||
module.np_transpose = np.transpose
|
||||
module.np_reshape = np.reshape
|
||||
|
||||
|
@ -68,6 +68,12 @@ def output_ndarray_float_2(n: ndarray[float, Literal[2]]):
|
||||
for c in range(len(n[r])):
|
||||
output_float64(n[r][c])
|
||||
|
||||
def output_ndarray_float_3(n: ndarray[float, Literal[3]]):
|
||||
for d in range(len(n)):
|
||||
for r in range(len(n[d])):
|
||||
for c in range(len(n[d][r])):
|
||||
output_float64(n[d][r][c])
|
||||
|
||||
def output_ndarray_float_4(n: ndarray[float, Literal[4]]):
|
||||
for x in range(len(n)):
|
||||
for y in range(len(n[x])):
|
||||
@ -236,6 +242,23 @@ def test_ndarray_reshape():
|
||||
output_int32(np_shape(x2)[1])
|
||||
output_ndarray_int32_2(x2)
|
||||
|
||||
def test_ndarray_broadcast_to():
|
||||
xs = np_array([1.0, 2.0, 3.0])
|
||||
ys = np_broadcast_to(xs, (1, 3))
|
||||
zs = np_broadcast_to(ys, (2, 4, 3))
|
||||
|
||||
output_int32(np_shape(xs)[0])
|
||||
output_ndarray_float_1(xs)
|
||||
|
||||
output_int32(np_shape(ys)[0])
|
||||
output_int32(np_shape(ys)[1])
|
||||
output_ndarray_float_2(ys)
|
||||
|
||||
output_int32(np_shape(zs)[0])
|
||||
output_int32(np_shape(zs)[1])
|
||||
output_int32(np_shape(zs)[2])
|
||||
output_ndarray_float_3(zs)
|
||||
|
||||
def test_ndarray_add():
|
||||
x = np_identity(2)
|
||||
y = x + np_ones([2, 2])
|
||||
@ -1619,6 +1642,7 @@ def run() -> int32:
|
||||
test_ndarray_nd_idx()
|
||||
|
||||
test_ndarray_reshape()
|
||||
test_ndarray_broadcast_to()
|
||||
|
||||
test_ndarray_add()
|
||||
test_ndarray_add_broadcast()
|
||||
|
Loading…
Reference in New Issue
Block a user