[core] codegen/ndarray: Reimplement broadcasting
Based on 9359ed96
: core/ndstrides: implement broadcasting &
np_broadcast_to()
This commit is contained in:
parent
936749ae5f
commit
32e1d55de9
@ -9,4 +9,5 @@
|
|||||||
#include "irrt/ndarray/iter.hpp"
|
#include "irrt/ndarray/iter.hpp"
|
||||||
#include "irrt/ndarray/indexing.hpp"
|
#include "irrt/ndarray/indexing.hpp"
|
||||||
#include "irrt/ndarray/array.hpp"
|
#include "irrt/ndarray/array.hpp"
|
||||||
#include "irrt/ndarray/reshape.hpp"
|
#include "irrt/ndarray/reshape.hpp"
|
||||||
|
#include "irrt/ndarray/broadcast.hpp"
|
165
nac3core/irrt/irrt/ndarray/broadcast.hpp
Normal file
165
nac3core/irrt/irrt/ndarray/broadcast.hpp
Normal file
@ -0,0 +1,165 @@
|
|||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include "irrt/int_types.hpp"
|
||||||
|
#include "irrt/ndarray/def.hpp"
|
||||||
|
#include "irrt/slice.hpp"
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
template<typename SizeT>
|
||||||
|
struct ShapeEntry {
|
||||||
|
SizeT ndims;
|
||||||
|
SizeT* shape;
|
||||||
|
};
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
namespace ndarray {
|
||||||
|
namespace broadcast {
|
||||||
|
/**
|
||||||
|
* @brief Return true if `src_shape` can broadcast to `dst_shape`.
|
||||||
|
*
|
||||||
|
* See https://numpy.org/doc/stable/user/basics.broadcasting.html
|
||||||
|
*/
|
||||||
|
template<typename SizeT>
|
||||||
|
bool can_broadcast_shape_to(SizeT target_ndims, const SizeT* target_shape, SizeT src_ndims, const SizeT* src_shape) {
|
||||||
|
if (src_ndims > target_ndims) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
for (SizeT i = 0; i < src_ndims; i++) {
|
||||||
|
SizeT target_dim = target_shape[target_ndims - i - 1];
|
||||||
|
SizeT src_dim = src_shape[src_ndims - i - 1];
|
||||||
|
if (!(src_dim == 1 || target_dim == src_dim)) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Performs `np.broadcast_shapes(<shapes>)`
|
||||||
|
*
|
||||||
|
* @param num_shapes Number of entries in `shapes`
|
||||||
|
* @param shapes The list of shape to do `np.broadcast_shapes` on.
|
||||||
|
* @param dst_ndims The length of `dst_shape`.
|
||||||
|
* `dst_ndims` must be `max([shape.ndims for shape in shapes])`, but the caller has to calculate it/provide it.
|
||||||
|
* for this function since they should already know in order to allocate `dst_shape` in the first place.
|
||||||
|
* @param dst_shape The resulting shape. Must be pre-allocated by the caller. This function calculate the result
|
||||||
|
* of `np.broadcast_shapes` and write it here.
|
||||||
|
*/
|
||||||
|
template<typename SizeT>
|
||||||
|
void broadcast_shapes(SizeT num_shapes, const ShapeEntry<SizeT>* shapes, SizeT dst_ndims, SizeT* dst_shape) {
|
||||||
|
for (SizeT dst_axis = 0; dst_axis < dst_ndims; dst_axis++) {
|
||||||
|
dst_shape[dst_axis] = 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
#ifdef IRRT_DEBUG_ASSERT
|
||||||
|
SizeT max_ndims_found = 0;
|
||||||
|
#endif
|
||||||
|
|
||||||
|
for (SizeT i = 0; i < num_shapes; i++) {
|
||||||
|
ShapeEntry<SizeT> entry = shapes[i];
|
||||||
|
|
||||||
|
// Check pre-condition: `dst_ndims` must be `max([shape.ndims for shape in shapes])`
|
||||||
|
debug_assert(SizeT, entry.ndims <= dst_ndims);
|
||||||
|
|
||||||
|
#ifdef IRRT_DEBUG_ASSERT
|
||||||
|
max_ndims_found = max(max_ndims_found, entry.ndims);
|
||||||
|
#endif
|
||||||
|
|
||||||
|
for (SizeT j = 0; j < entry.ndims; j++) {
|
||||||
|
SizeT entry_axis = entry.ndims - j - 1;
|
||||||
|
SizeT dst_axis = dst_ndims - j - 1;
|
||||||
|
|
||||||
|
SizeT entry_dim = entry.shape[entry_axis];
|
||||||
|
SizeT dst_dim = dst_shape[dst_axis];
|
||||||
|
|
||||||
|
if (dst_dim == 1) {
|
||||||
|
dst_shape[dst_axis] = entry_dim;
|
||||||
|
} else if (entry_dim == 1 || entry_dim == dst_dim) {
|
||||||
|
// Do nothing
|
||||||
|
} else {
|
||||||
|
raise_exception(SizeT, EXN_VALUE_ERROR,
|
||||||
|
"shape mismatch: objects cannot be broadcast "
|
||||||
|
"to a single shape.",
|
||||||
|
NO_PARAM, NO_PARAM, NO_PARAM);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check pre-condition: `dst_ndims` must be `max([shape.ndims for shape in shapes])`
|
||||||
|
debug_assert_eq(SizeT, max_ndims_found, dst_ndims);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Perform `np.broadcast_to(<ndarray>, <target_shape>)` and appropriate assertions.
|
||||||
|
*
|
||||||
|
* This function attempts to broadcast `src_ndarray` to a new shape defined by `dst_ndarray.shape`,
|
||||||
|
* and return the result by modifying `dst_ndarray`.
|
||||||
|
*
|
||||||
|
* # Notes on `dst_ndarray`
|
||||||
|
* The caller is responsible for allocating space for the resulting ndarray.
|
||||||
|
* Here is what this function expects from `dst_ndarray` when called:
|
||||||
|
* - `dst_ndarray->data` does not have to be initialized.
|
||||||
|
* - `dst_ndarray->itemsize` does not have to be initialized.
|
||||||
|
* - `dst_ndarray->ndims` must be initialized, determining the length of `dst_ndarray->shape`
|
||||||
|
* - `dst_ndarray->shape` must be allocated, and must contain the desired target broadcast shape.
|
||||||
|
* - `dst_ndarray->strides` must be allocated, through it can contain uninitialized values.
|
||||||
|
* When this function call ends:
|
||||||
|
* - `dst_ndarray->data` is set to `src_ndarray->data` (`dst_ndarray` is just a view to `src_ndarray`)
|
||||||
|
* - `dst_ndarray->itemsize` is set to `src_ndarray->itemsize`
|
||||||
|
* - `dst_ndarray->ndims` is unchanged.
|
||||||
|
* - `dst_ndarray->shape` is unchanged.
|
||||||
|
* - `dst_ndarray->strides` is updated accordingly by how ndarray broadcast_to works.
|
||||||
|
*/
|
||||||
|
template<typename SizeT>
|
||||||
|
void broadcast_to(const NDArray<SizeT>* src_ndarray, NDArray<SizeT>* dst_ndarray) {
|
||||||
|
if (!ndarray::broadcast::can_broadcast_shape_to(dst_ndarray->ndims, dst_ndarray->shape, src_ndarray->ndims,
|
||||||
|
src_ndarray->shape)) {
|
||||||
|
raise_exception(SizeT, EXN_VALUE_ERROR, "operands could not be broadcast together", NO_PARAM, NO_PARAM,
|
||||||
|
NO_PARAM);
|
||||||
|
}
|
||||||
|
|
||||||
|
dst_ndarray->data = src_ndarray->data;
|
||||||
|
dst_ndarray->itemsize = src_ndarray->itemsize;
|
||||||
|
|
||||||
|
for (SizeT i = 0; i < dst_ndarray->ndims; i++) {
|
||||||
|
SizeT src_axis = src_ndarray->ndims - i - 1;
|
||||||
|
SizeT dst_axis = dst_ndarray->ndims - i - 1;
|
||||||
|
if (src_axis < 0 || (src_ndarray->shape[src_axis] == 1 && dst_ndarray->shape[dst_axis] != 1)) {
|
||||||
|
// Freeze the steps in-place
|
||||||
|
dst_ndarray->strides[dst_axis] = 0;
|
||||||
|
} else {
|
||||||
|
dst_ndarray->strides[dst_axis] = src_ndarray->strides[src_axis];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} // namespace broadcast
|
||||||
|
} // namespace ndarray
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
extern "C" {
|
||||||
|
using namespace ndarray::broadcast;
|
||||||
|
|
||||||
|
void __nac3_ndarray_broadcast_to(NDArray<int32_t>* src_ndarray, NDArray<int32_t>* dst_ndarray) {
|
||||||
|
broadcast_to(src_ndarray, dst_ndarray);
|
||||||
|
}
|
||||||
|
|
||||||
|
void __nac3_ndarray_broadcast_to64(NDArray<int64_t>* src_ndarray, NDArray<int64_t>* dst_ndarray) {
|
||||||
|
broadcast_to(src_ndarray, dst_ndarray);
|
||||||
|
}
|
||||||
|
|
||||||
|
void __nac3_ndarray_broadcast_shapes(int32_t num_shapes,
|
||||||
|
const ShapeEntry<int32_t>* shapes,
|
||||||
|
int32_t dst_ndims,
|
||||||
|
int32_t* dst_shape) {
|
||||||
|
broadcast_shapes(num_shapes, shapes, dst_ndims, dst_shape);
|
||||||
|
}
|
||||||
|
|
||||||
|
void __nac3_ndarray_broadcast_shapes64(int64_t num_shapes,
|
||||||
|
const ShapeEntry<int64_t>* shapes,
|
||||||
|
int64_t dst_ndims,
|
||||||
|
int64_t* dst_shape) {
|
||||||
|
broadcast_shapes(num_shapes, shapes, dst_ndims, dst_shape);
|
||||||
|
}
|
||||||
|
}
|
69
nac3core/src/codegen/irrt/ndarray/broadcast.rs
Normal file
69
nac3core/src/codegen/irrt/ndarray/broadcast.rs
Normal file
@ -0,0 +1,69 @@
|
|||||||
|
use inkwell::values::IntValue;
|
||||||
|
|
||||||
|
use crate::codegen::{
|
||||||
|
expr::infer_and_call_function,
|
||||||
|
irrt::get_usize_dependent_function_name,
|
||||||
|
types::{ndarray::ShapeEntryType, ProxyType},
|
||||||
|
values::{
|
||||||
|
ndarray::NDArrayValue, ArrayLikeValue, ArraySliceValue, ProxyValue, TypedArrayLikeAccessor,
|
||||||
|
TypedArrayLikeMutator,
|
||||||
|
},
|
||||||
|
CodeGenContext, CodeGenerator,
|
||||||
|
};
|
||||||
|
|
||||||
|
pub fn call_nac3_ndarray_broadcast_to<'ctx, G: CodeGenerator + ?Sized>(
|
||||||
|
generator: &G,
|
||||||
|
ctx: &CodeGenContext<'ctx, '_>,
|
||||||
|
src_ndarray: NDArrayValue<'ctx>,
|
||||||
|
dst_ndarray: NDArrayValue<'ctx>,
|
||||||
|
) {
|
||||||
|
let name = get_usize_dependent_function_name(generator, ctx, "__nac3_ndarray_broadcast_to");
|
||||||
|
infer_and_call_function(
|
||||||
|
ctx,
|
||||||
|
&name,
|
||||||
|
None,
|
||||||
|
&[src_ndarray.as_base_value().into(), dst_ndarray.as_base_value().into()],
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn call_nac3_ndarray_broadcast_shapes<'ctx, G, Shape>(
|
||||||
|
generator: &G,
|
||||||
|
ctx: &CodeGenContext<'ctx, '_>,
|
||||||
|
num_shape_entries: IntValue<'ctx>,
|
||||||
|
shape_entries: ArraySliceValue<'ctx>,
|
||||||
|
dst_ndims: IntValue<'ctx>,
|
||||||
|
dst_shape: &Shape,
|
||||||
|
) where
|
||||||
|
G: CodeGenerator + ?Sized,
|
||||||
|
Shape: TypedArrayLikeAccessor<'ctx, G, IntValue<'ctx>>
|
||||||
|
+ TypedArrayLikeMutator<'ctx, G, IntValue<'ctx>>,
|
||||||
|
{
|
||||||
|
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||||
|
|
||||||
|
assert_eq!(num_shape_entries.get_type(), llvm_usize);
|
||||||
|
assert!(ShapeEntryType::is_type(
|
||||||
|
generator,
|
||||||
|
ctx.ctx,
|
||||||
|
shape_entries.base_ptr(ctx, generator).get_type()
|
||||||
|
)
|
||||||
|
.is_ok());
|
||||||
|
assert_eq!(dst_ndims.get_type(), llvm_usize);
|
||||||
|
assert_eq!(dst_shape.element_type(ctx, generator), llvm_usize.into());
|
||||||
|
|
||||||
|
let name = get_usize_dependent_function_name(generator, ctx, "__nac3_ndarray_broadcast_shapes");
|
||||||
|
infer_and_call_function(
|
||||||
|
ctx,
|
||||||
|
&name,
|
||||||
|
None,
|
||||||
|
&[
|
||||||
|
num_shape_entries.into(),
|
||||||
|
shape_entries.base_ptr(ctx, generator).into(),
|
||||||
|
dst_ndims.into(),
|
||||||
|
dst_shape.base_ptr(ctx, generator).into(),
|
||||||
|
],
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
);
|
||||||
|
}
|
@ -18,12 +18,14 @@ use crate::codegen::{
|
|||||||
};
|
};
|
||||||
pub use array::*;
|
pub use array::*;
|
||||||
pub use basic::*;
|
pub use basic::*;
|
||||||
|
pub use broadcast::*;
|
||||||
pub use indexing::*;
|
pub use indexing::*;
|
||||||
pub use iter::*;
|
pub use iter::*;
|
||||||
pub use reshape::*;
|
pub use reshape::*;
|
||||||
|
|
||||||
mod array;
|
mod array;
|
||||||
mod basic;
|
mod basic;
|
||||||
|
mod broadcast;
|
||||||
mod indexing;
|
mod indexing;
|
||||||
mod iter;
|
mod iter;
|
||||||
mod reshape;
|
mod reshape;
|
||||||
|
176
nac3core/src/codegen/types/ndarray/broadcast.rs
Normal file
176
nac3core/src/codegen/types/ndarray/broadcast.rs
Normal file
@ -0,0 +1,176 @@
|
|||||||
|
use inkwell::{
|
||||||
|
context::{AsContextRef, Context},
|
||||||
|
types::{AnyTypeEnum, BasicType, BasicTypeEnum, IntType, PointerType},
|
||||||
|
values::{IntValue, PointerValue},
|
||||||
|
AddressSpace,
|
||||||
|
};
|
||||||
|
use itertools::Itertools;
|
||||||
|
|
||||||
|
use nac3core_derive::StructFields;
|
||||||
|
|
||||||
|
use crate::codegen::{
|
||||||
|
types::{
|
||||||
|
structure::{check_struct_type_matches_fields, StructField, StructFields},
|
||||||
|
ProxyType,
|
||||||
|
},
|
||||||
|
values::{ndarray::ShapeEntryValue, ProxyValue},
|
||||||
|
CodeGenContext, CodeGenerator,
|
||||||
|
};
|
||||||
|
|
||||||
|
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
|
||||||
|
pub struct ShapeEntryType<'ctx> {
|
||||||
|
ty: PointerType<'ctx>,
|
||||||
|
llvm_usize: IntType<'ctx>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(PartialEq, Eq, Clone, Copy, StructFields)]
|
||||||
|
pub struct ShapeEntryStructFields<'ctx> {
|
||||||
|
#[value_type(usize)]
|
||||||
|
pub ndims: StructField<'ctx, IntValue<'ctx>>,
|
||||||
|
#[value_type(usize.ptr_type(AddressSpace::default()))]
|
||||||
|
pub shape: StructField<'ctx, PointerValue<'ctx>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'ctx> ShapeEntryType<'ctx> {
|
||||||
|
/// Checks whether `llvm_ty` represents a [`ShapeEntryType`], returning [Err] if it does not.
|
||||||
|
pub fn is_representable(
|
||||||
|
llvm_ty: PointerType<'ctx>,
|
||||||
|
llvm_usize: IntType<'ctx>,
|
||||||
|
) -> Result<(), String> {
|
||||||
|
let ctx = llvm_ty.get_context();
|
||||||
|
|
||||||
|
let llvm_ndarray_ty = llvm_ty.get_element_type();
|
||||||
|
let AnyTypeEnum::StructType(llvm_ndarray_ty) = llvm_ndarray_ty else {
|
||||||
|
return Err(format!(
|
||||||
|
"Expected struct type for `ShapeEntry` type, got {llvm_ndarray_ty}"
|
||||||
|
));
|
||||||
|
};
|
||||||
|
|
||||||
|
check_struct_type_matches_fields(
|
||||||
|
Self::fields(ctx, llvm_usize),
|
||||||
|
llvm_ndarray_ty,
|
||||||
|
"NDArray",
|
||||||
|
&[],
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Returns an instance of [`StructFields`] containing all field accessors for this type.
|
||||||
|
#[must_use]
|
||||||
|
fn fields(
|
||||||
|
ctx: impl AsContextRef<'ctx>,
|
||||||
|
llvm_usize: IntType<'ctx>,
|
||||||
|
) -> ShapeEntryStructFields<'ctx> {
|
||||||
|
ShapeEntryStructFields::new(ctx, llvm_usize)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// See [`ShapeEntryStructFields::fields`].
|
||||||
|
// TODO: Move this into e.g. StructProxyType
|
||||||
|
#[must_use]
|
||||||
|
pub fn get_fields(&self, ctx: impl AsContextRef<'ctx>) -> ShapeEntryStructFields<'ctx> {
|
||||||
|
Self::fields(ctx, self.llvm_usize)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Creates an LLVM type corresponding to the expected structure of a `ShapeEntry`.
|
||||||
|
#[must_use]
|
||||||
|
fn llvm_type(ctx: &'ctx Context, llvm_usize: IntType<'ctx>) -> PointerType<'ctx> {
|
||||||
|
let field_tys =
|
||||||
|
Self::fields(ctx, llvm_usize).into_iter().map(|field| field.1).collect_vec();
|
||||||
|
|
||||||
|
ctx.struct_type(&field_tys, false).ptr_type(AddressSpace::default())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Creates an instance of [`ShapeEntryType`].
|
||||||
|
#[must_use]
|
||||||
|
pub fn new<G: CodeGenerator + ?Sized>(generator: &G, ctx: &'ctx Context) -> Self {
|
||||||
|
let llvm_usize = generator.get_size_type(ctx);
|
||||||
|
let llvm_ty = Self::llvm_type(ctx, llvm_usize);
|
||||||
|
|
||||||
|
Self { ty: llvm_ty, llvm_usize }
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Creates a [`ShapeEntryType`] from a [`PointerType`] representing an `NDArray`.
|
||||||
|
#[must_use]
|
||||||
|
pub fn from_type(ptr_ty: PointerType<'ctx>, llvm_usize: IntType<'ctx>) -> Self {
|
||||||
|
debug_assert!(Self::is_representable(ptr_ty, llvm_usize).is_ok());
|
||||||
|
|
||||||
|
Self { ty: ptr_ty, llvm_usize }
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Allocates an instance of [`ShapeEntryValue`] as if by calling `alloca` on the base type.
|
||||||
|
#[must_use]
|
||||||
|
pub fn alloca(
|
||||||
|
&self,
|
||||||
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
|
name: Option<&'ctx str>,
|
||||||
|
) -> <Self as ProxyType<'ctx>>::Value {
|
||||||
|
<Self as ProxyType<'ctx>>::Value::from_pointer_value(
|
||||||
|
self.raw_alloca(ctx, name),
|
||||||
|
self.llvm_usize,
|
||||||
|
name,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Allocates an instance of [`ShapeEntryValue`] as if by calling `alloca` on the base type.
|
||||||
|
#[must_use]
|
||||||
|
pub fn alloca_var<G: CodeGenerator + ?Sized>(
|
||||||
|
&self,
|
||||||
|
generator: &mut G,
|
||||||
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
|
name: Option<&'ctx str>,
|
||||||
|
) -> <Self as ProxyType<'ctx>>::Value {
|
||||||
|
<Self as ProxyType<'ctx>>::Value::from_pointer_value(
|
||||||
|
self.raw_alloca_var(generator, ctx, name),
|
||||||
|
self.llvm_usize,
|
||||||
|
name,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Converts an existing value into a [`ShapeEntryValue`].
|
||||||
|
#[must_use]
|
||||||
|
pub fn map_value(
|
||||||
|
&self,
|
||||||
|
value: <<Self as ProxyType<'ctx>>::Value as ProxyValue<'ctx>>::Base,
|
||||||
|
name: Option<&'ctx str>,
|
||||||
|
) -> <Self as ProxyType<'ctx>>::Value {
|
||||||
|
<Self as ProxyType<'ctx>>::Value::from_pointer_value(value, self.llvm_usize, name)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'ctx> ProxyType<'ctx> for ShapeEntryType<'ctx> {
|
||||||
|
type Base = PointerType<'ctx>;
|
||||||
|
type Value = ShapeEntryValue<'ctx>;
|
||||||
|
|
||||||
|
fn is_type<G: CodeGenerator + ?Sized>(
|
||||||
|
generator: &G,
|
||||||
|
ctx: &'ctx Context,
|
||||||
|
llvm_ty: impl BasicType<'ctx>,
|
||||||
|
) -> Result<(), String> {
|
||||||
|
if let BasicTypeEnum::PointerType(ty) = llvm_ty.as_basic_type_enum() {
|
||||||
|
<Self as ProxyType<'ctx>>::is_representable(generator, ctx, ty)
|
||||||
|
} else {
|
||||||
|
Err(format!("Expected pointer type, got {llvm_ty:?}"))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn is_representable<G: CodeGenerator + ?Sized>(
|
||||||
|
generator: &G,
|
||||||
|
ctx: &'ctx Context,
|
||||||
|
llvm_ty: Self::Base,
|
||||||
|
) -> Result<(), String> {
|
||||||
|
Self::is_representable(llvm_ty, generator.get_size_type(ctx))
|
||||||
|
}
|
||||||
|
|
||||||
|
fn alloca_type(&self) -> impl BasicType<'ctx> {
|
||||||
|
self.as_base_type().get_element_type().into_struct_type()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn as_base_type(&self) -> Self::Base {
|
||||||
|
self.ty
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'ctx> From<ShapeEntryType<'ctx>> for PointerType<'ctx> {
|
||||||
|
fn from(value: ShapeEntryType<'ctx>) -> Self {
|
||||||
|
value.as_base_type()
|
||||||
|
}
|
||||||
|
}
|
@ -20,11 +20,13 @@ use crate::{
|
|||||||
toplevel::{helper::extract_ndims, numpy::unpack_ndarray_var_tys},
|
toplevel::{helper::extract_ndims, numpy::unpack_ndarray_var_tys},
|
||||||
typecheck::typedef::Type,
|
typecheck::typedef::Type,
|
||||||
};
|
};
|
||||||
|
pub use broadcast::*;
|
||||||
pub use contiguous::*;
|
pub use contiguous::*;
|
||||||
pub use indexing::*;
|
pub use indexing::*;
|
||||||
pub use nditer::*;
|
pub use nditer::*;
|
||||||
|
|
||||||
mod array;
|
mod array;
|
||||||
|
mod broadcast;
|
||||||
mod contiguous;
|
mod contiguous;
|
||||||
pub mod factory;
|
pub mod factory;
|
||||||
mod indexing;
|
mod indexing;
|
||||||
@ -118,6 +120,20 @@ impl<'ctx> NDArrayType<'ctx> {
|
|||||||
NDArrayType { ty: llvm_ndarray, dtype, ndims, llvm_usize }
|
NDArrayType { ty: llvm_ndarray, dtype, ndims, llvm_usize }
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Creates an instance of [`NDArrayType`] as a result of a broadcast operation over one or more
|
||||||
|
/// `ndarray` operands.
|
||||||
|
#[must_use]
|
||||||
|
pub fn new_broadcast<G: CodeGenerator + ?Sized>(
|
||||||
|
generator: &G,
|
||||||
|
ctx: &'ctx Context,
|
||||||
|
dtype: BasicTypeEnum<'ctx>,
|
||||||
|
inputs: &[NDArrayType<'ctx>],
|
||||||
|
) -> Self {
|
||||||
|
assert!(!inputs.is_empty());
|
||||||
|
|
||||||
|
Self::new(generator, ctx, dtype, inputs.iter().filter_map(NDArrayType::ndims).max())
|
||||||
|
}
|
||||||
|
|
||||||
/// Creates an instance of [`NDArrayType`] with `ndims` of 0.
|
/// Creates an instance of [`NDArrayType`] with `ndims` of 0.
|
||||||
#[must_use]
|
#[must_use]
|
||||||
pub fn new_unsized<G: CodeGenerator + ?Sized>(
|
pub fn new_unsized<G: CodeGenerator + ?Sized>(
|
||||||
|
@ -208,6 +208,7 @@ pub trait TypedArrayLikeMutator<'ctx, G: CodeGenerator + ?Sized, T, Index = IntV
|
|||||||
}
|
}
|
||||||
|
|
||||||
/// An adapter for constraining untyped array values as typed values.
|
/// An adapter for constraining untyped array values as typed values.
|
||||||
|
#[derive(Clone)]
|
||||||
pub struct TypedArrayLikeAdapter<
|
pub struct TypedArrayLikeAdapter<
|
||||||
'ctx,
|
'ctx,
|
||||||
G: CodeGenerator + ?Sized,
|
G: CodeGenerator + ?Sized,
|
||||||
|
245
nac3core/src/codegen/values/ndarray/broadcast.rs
Normal file
245
nac3core/src/codegen/values/ndarray/broadcast.rs
Normal file
@ -0,0 +1,245 @@
|
|||||||
|
use inkwell::{
|
||||||
|
types::IntType,
|
||||||
|
values::{IntValue, PointerValue},
|
||||||
|
};
|
||||||
|
use itertools::Itertools;
|
||||||
|
|
||||||
|
use crate::codegen::values::TypedArrayLikeMutator;
|
||||||
|
use crate::codegen::{
|
||||||
|
irrt,
|
||||||
|
types::{
|
||||||
|
ndarray::{NDArrayType, ShapeEntryType},
|
||||||
|
structure::StructField,
|
||||||
|
ProxyType,
|
||||||
|
},
|
||||||
|
values::{
|
||||||
|
ndarray::NDArrayValue, ArrayLikeIndexer, ArrayLikeValue, ArraySliceValue, ProxyValue,
|
||||||
|
TypedArrayLikeAccessor, TypedArrayLikeAdapter,
|
||||||
|
},
|
||||||
|
CodeGenContext, CodeGenerator,
|
||||||
|
};
|
||||||
|
|
||||||
|
#[derive(Copy, Clone)]
|
||||||
|
pub struct ShapeEntryValue<'ctx> {
|
||||||
|
value: PointerValue<'ctx>,
|
||||||
|
llvm_usize: IntType<'ctx>,
|
||||||
|
name: Option<&'ctx str>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'ctx> ShapeEntryValue<'ctx> {
|
||||||
|
/// Checks whether `value` is an instance of `ShapeEntry`, returning [Err] if `value` is
|
||||||
|
/// not an instance.
|
||||||
|
pub fn is_representable(
|
||||||
|
value: PointerValue<'ctx>,
|
||||||
|
llvm_usize: IntType<'ctx>,
|
||||||
|
) -> Result<(), String> {
|
||||||
|
<Self as ProxyValue<'ctx>>::Type::is_representable(value.get_type(), llvm_usize)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Creates an [`ShapeEntryValue`] from a [`PointerValue`].
|
||||||
|
#[must_use]
|
||||||
|
pub fn from_pointer_value(
|
||||||
|
ptr: PointerValue<'ctx>,
|
||||||
|
llvm_usize: IntType<'ctx>,
|
||||||
|
name: Option<&'ctx str>,
|
||||||
|
) -> Self {
|
||||||
|
debug_assert!(Self::is_representable(ptr, llvm_usize).is_ok());
|
||||||
|
|
||||||
|
Self { value: ptr, llvm_usize, name }
|
||||||
|
}
|
||||||
|
|
||||||
|
fn ndims_field(&self) -> StructField<'ctx, IntValue<'ctx>> {
|
||||||
|
self.get_type().get_fields(self.value.get_type().get_context()).ndims
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn store_ndims(&self, ctx: &CodeGenContext<'ctx, '_>, value: IntValue<'ctx>) {
|
||||||
|
self.ndims_field().set(ctx, self.value, value, self.name);
|
||||||
|
}
|
||||||
|
|
||||||
|
fn shape_field(&self) -> StructField<'ctx, PointerValue<'ctx>> {
|
||||||
|
self.get_type().get_fields(self.value.get_type().get_context()).shape
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn store_shape(&self, ctx: &CodeGenContext<'ctx, '_>, value: PointerValue<'ctx>) {
|
||||||
|
self.shape_field().set(ctx, self.value, value, self.name);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'ctx> ProxyValue<'ctx> for ShapeEntryValue<'ctx> {
|
||||||
|
type Base = PointerValue<'ctx>;
|
||||||
|
type Type = ShapeEntryType<'ctx>;
|
||||||
|
|
||||||
|
fn get_type(&self) -> Self::Type {
|
||||||
|
Self::Type::from_type(self.value.get_type(), self.llvm_usize)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn as_base_value(&self) -> Self::Base {
|
||||||
|
self.value
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'ctx> From<ShapeEntryValue<'ctx>> for PointerValue<'ctx> {
|
||||||
|
fn from(value: ShapeEntryValue<'ctx>) -> Self {
|
||||||
|
value.as_base_value()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'ctx> NDArrayValue<'ctx> {
|
||||||
|
/// Create a broadcast view on this ndarray with a target shape.
|
||||||
|
///
|
||||||
|
/// The input shape will be checked to make sure that it contains no negative values.
|
||||||
|
///
|
||||||
|
/// * `target_ndims` - The ndims type after broadcasting to the given shape.
|
||||||
|
/// The caller has to figure this out for this function.
|
||||||
|
/// * `target_shape` - An array pointer pointing to the target shape.
|
||||||
|
#[must_use]
|
||||||
|
pub fn broadcast_to<G: CodeGenerator + ?Sized>(
|
||||||
|
&self,
|
||||||
|
generator: &mut G,
|
||||||
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
|
target_ndims: u64,
|
||||||
|
target_shape: &impl TypedArrayLikeAccessor<'ctx, G, IntValue<'ctx>>,
|
||||||
|
) -> Self {
|
||||||
|
assert!(self.ndims.is_none_or(|ndims| ndims <= target_ndims));
|
||||||
|
assert_eq!(target_shape.element_type(ctx, generator), self.llvm_usize.into());
|
||||||
|
|
||||||
|
let broadcast_ndarray =
|
||||||
|
NDArrayType::new(generator, ctx.ctx, self.dtype, Some(target_ndims))
|
||||||
|
.construct_uninitialized(generator, ctx, None);
|
||||||
|
broadcast_ndarray.copy_shape_from_array(
|
||||||
|
generator,
|
||||||
|
ctx,
|
||||||
|
target_shape.base_ptr(ctx, generator),
|
||||||
|
);
|
||||||
|
|
||||||
|
irrt::ndarray::call_nac3_ndarray_broadcast_to(generator, ctx, *self, broadcast_ndarray);
|
||||||
|
broadcast_ndarray
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// A result produced by [`broadcast_all_ndarrays`]
|
||||||
|
#[derive(Clone)]
|
||||||
|
pub struct BroadcastAllResult<'ctx, G: CodeGenerator + ?Sized> {
|
||||||
|
/// The statically known `ndims` of the broadcast result.
|
||||||
|
pub ndims: u64,
|
||||||
|
|
||||||
|
/// The broadcasting shape.
|
||||||
|
pub shape: TypedArrayLikeAdapter<'ctx, G, IntValue<'ctx>>,
|
||||||
|
|
||||||
|
/// Broadcasted views on the inputs.
|
||||||
|
///
|
||||||
|
/// All of them will have `shape` [`BroadcastAllResult::shape`] and
|
||||||
|
/// `ndims` [`BroadcastAllResult::ndims`]. The length of the vector
|
||||||
|
/// is the same as the input.
|
||||||
|
pub ndarrays: Vec<NDArrayValue<'ctx>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Helper function to call `call_nac3_ndarray_broadcast_shapes`
|
||||||
|
fn broadcast_shapes<'ctx, G, Shape>(
|
||||||
|
generator: &mut G,
|
||||||
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
|
in_shape_entries: &[(ArraySliceValue<'ctx>, u64)], // (shape, shape's length/ndims)
|
||||||
|
broadcast_ndims: u64,
|
||||||
|
broadcast_shape: &Shape,
|
||||||
|
) where
|
||||||
|
G: CodeGenerator + ?Sized,
|
||||||
|
Shape: TypedArrayLikeAccessor<'ctx, G, IntValue<'ctx>>
|
||||||
|
+ TypedArrayLikeMutator<'ctx, G, IntValue<'ctx>>,
|
||||||
|
{
|
||||||
|
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||||
|
let llvm_shape_ty = ShapeEntryType::new(generator, ctx.ctx);
|
||||||
|
|
||||||
|
assert!(in_shape_entries
|
||||||
|
.iter()
|
||||||
|
.all(|entry| entry.0.element_type(ctx, generator) == llvm_usize.into()));
|
||||||
|
assert_eq!(broadcast_shape.element_type(ctx, generator), llvm_usize.into());
|
||||||
|
|
||||||
|
// Prepare input shape entries to be passed to `call_nac3_ndarray_broadcast_shapes`.
|
||||||
|
let num_shape_entries =
|
||||||
|
llvm_usize.const_int(u64::try_from(in_shape_entries.len()).unwrap(), false);
|
||||||
|
let shape_entries = llvm_shape_ty.array_alloca(ctx, num_shape_entries, None);
|
||||||
|
for (i, (in_shape, in_ndims)) in in_shape_entries.iter().enumerate() {
|
||||||
|
let pshape_entry = unsafe {
|
||||||
|
shape_entries.ptr_offset_unchecked(
|
||||||
|
ctx,
|
||||||
|
generator,
|
||||||
|
&llvm_usize.const_int(i as u64, false),
|
||||||
|
None,
|
||||||
|
)
|
||||||
|
};
|
||||||
|
let shape_entry = llvm_shape_ty.map_value(pshape_entry, None);
|
||||||
|
|
||||||
|
let in_ndims = llvm_usize.const_int(*in_ndims, false);
|
||||||
|
shape_entry.store_ndims(ctx, in_ndims);
|
||||||
|
|
||||||
|
shape_entry.store_shape(ctx, in_shape.base_ptr(ctx, generator));
|
||||||
|
}
|
||||||
|
|
||||||
|
let broadcast_ndims = llvm_usize.const_int(broadcast_ndims, false);
|
||||||
|
irrt::ndarray::call_nac3_ndarray_broadcast_shapes(
|
||||||
|
generator,
|
||||||
|
ctx,
|
||||||
|
num_shape_entries,
|
||||||
|
shape_entries,
|
||||||
|
broadcast_ndims,
|
||||||
|
broadcast_shape,
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'ctx> NDArrayType<'ctx> {
|
||||||
|
/// Broadcast all ndarrays according to `np.broadcast()` and return a [`BroadcastAllResult`]
|
||||||
|
/// containing all the information of the result of the broadcast operation.
|
||||||
|
pub fn broadcast<G: CodeGenerator + ?Sized>(
|
||||||
|
&self,
|
||||||
|
generator: &mut G,
|
||||||
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
|
ndarrays: &[NDArrayValue<'ctx>],
|
||||||
|
) -> BroadcastAllResult<'ctx, G> {
|
||||||
|
assert!(!ndarrays.is_empty());
|
||||||
|
assert!(ndarrays.iter().all(|ndarray| ndarray.get_type().ndims().is_some()));
|
||||||
|
|
||||||
|
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||||
|
|
||||||
|
// Infer the broadcast output ndims.
|
||||||
|
let broadcast_ndims_int =
|
||||||
|
ndarrays.iter().map(|ndarray| ndarray.get_type().ndims().unwrap()).max().unwrap();
|
||||||
|
assert!(self.ndims().is_none_or(|ndims| ndims >= broadcast_ndims_int));
|
||||||
|
|
||||||
|
let broadcast_ndims = llvm_usize.const_int(broadcast_ndims_int, false);
|
||||||
|
let broadcast_shape = ArraySliceValue::from_ptr_val(
|
||||||
|
ctx.builder.build_array_alloca(llvm_usize, broadcast_ndims, "").unwrap(),
|
||||||
|
broadcast_ndims,
|
||||||
|
None,
|
||||||
|
);
|
||||||
|
let broadcast_shape = TypedArrayLikeAdapter::from(
|
||||||
|
broadcast_shape,
|
||||||
|
|_, _, val| val.into_int_value(),
|
||||||
|
|_, _, val| val.into(),
|
||||||
|
);
|
||||||
|
|
||||||
|
let shape_entries = ndarrays
|
||||||
|
.iter()
|
||||||
|
.map(|ndarray| {
|
||||||
|
(
|
||||||
|
ndarray.shape().as_slice_value(ctx, generator),
|
||||||
|
ndarray.get_type().ndims().unwrap(),
|
||||||
|
)
|
||||||
|
})
|
||||||
|
.collect_vec();
|
||||||
|
broadcast_shapes(generator, ctx, &shape_entries, broadcast_ndims_int, &broadcast_shape);
|
||||||
|
|
||||||
|
// Broadcast all the inputs to shape `dst_shape`.
|
||||||
|
let broadcast_ndarrays = ndarrays
|
||||||
|
.iter()
|
||||||
|
.map(|ndarray| {
|
||||||
|
ndarray.broadcast_to(generator, ctx, broadcast_ndims_int, &broadcast_shape)
|
||||||
|
})
|
||||||
|
.collect_vec();
|
||||||
|
|
||||||
|
BroadcastAllResult {
|
||||||
|
ndims: broadcast_ndims_int,
|
||||||
|
shape: broadcast_shape,
|
||||||
|
ndarrays: broadcast_ndarrays,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
@ -20,10 +20,12 @@ use crate::codegen::{
|
|||||||
types::{ndarray::NDArrayType, structure::StructField, TupleType},
|
types::{ndarray::NDArrayType, structure::StructField, TupleType},
|
||||||
CodeGenContext, CodeGenerator,
|
CodeGenContext, CodeGenerator,
|
||||||
};
|
};
|
||||||
|
pub use broadcast::*;
|
||||||
pub use contiguous::*;
|
pub use contiguous::*;
|
||||||
pub use indexing::*;
|
pub use indexing::*;
|
||||||
pub use nditer::*;
|
pub use nditer::*;
|
||||||
|
|
||||||
|
mod broadcast;
|
||||||
mod contiguous;
|
mod contiguous;
|
||||||
mod indexing;
|
mod indexing;
|
||||||
mod nditer;
|
mod nditer;
|
||||||
|
@ -373,7 +373,7 @@ impl<'a> BuiltinBuilder<'a> {
|
|||||||
self.build_ndarray_property_getter_function(prim)
|
self.build_ndarray_property_getter_function(prim)
|
||||||
}
|
}
|
||||||
|
|
||||||
PrimDef::FunNpTranspose | PrimDef::FunNpReshape => {
|
PrimDef::FunNpBroadcastTo | PrimDef::FunNpTranspose | PrimDef::FunNpReshape => {
|
||||||
self.build_ndarray_view_function(prim)
|
self.build_ndarray_view_function(prim)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -1328,7 +1328,10 @@ impl<'a> BuiltinBuilder<'a> {
|
|||||||
|
|
||||||
/// Build np/sp functions that take as input `NDArray` only
|
/// Build np/sp functions that take as input `NDArray` only
|
||||||
fn build_ndarray_view_function(&mut self, prim: PrimDef) -> TopLevelDef {
|
fn build_ndarray_view_function(&mut self, prim: PrimDef) -> TopLevelDef {
|
||||||
debug_assert_prim_is_allowed(prim, &[PrimDef::FunNpTranspose, PrimDef::FunNpReshape]);
|
debug_assert_prim_is_allowed(
|
||||||
|
prim,
|
||||||
|
&[PrimDef::FunNpBroadcastTo, PrimDef::FunNpTranspose, PrimDef::FunNpReshape],
|
||||||
|
);
|
||||||
|
|
||||||
let in_ndarray_ty = self.unifier.get_fresh_var_with_range(
|
let in_ndarray_ty = self.unifier.get_fresh_var_with_range(
|
||||||
&[self.primitives.ndarray],
|
&[self.primitives.ndarray],
|
||||||
@ -1356,7 +1359,10 @@ impl<'a> BuiltinBuilder<'a> {
|
|||||||
// Similar to `build_ndarray_from_shape_factory_function` we delegate the responsibility of typechecking
|
// Similar to `build_ndarray_from_shape_factory_function` we delegate the responsibility of typechecking
|
||||||
// to [`typecheck::type_inferencer::Inferencer::fold_numpy_function_call_shape_argument`],
|
// to [`typecheck::type_inferencer::Inferencer::fold_numpy_function_call_shape_argument`],
|
||||||
// and use a dummy [`TypeVar`] `ndarray_factory_fn_shape_arg_tvar` as a placeholder for `param_ty`.
|
// and use a dummy [`TypeVar`] `ndarray_factory_fn_shape_arg_tvar` as a placeholder for `param_ty`.
|
||||||
PrimDef::FunNpReshape => {
|
PrimDef::FunNpBroadcastTo | PrimDef::FunNpReshape => {
|
||||||
|
// These two functions have the same function signature.
|
||||||
|
// Mixed together for convenience.
|
||||||
|
|
||||||
let ret_ty = self.unifier.get_dummy_var().ty; // Handled by special holding
|
let ret_ty = self.unifier.get_dummy_var().ty; // Handled by special holding
|
||||||
|
|
||||||
create_fn_by_codegen(
|
create_fn_by_codegen(
|
||||||
@ -1386,7 +1392,18 @@ impl<'a> BuiltinBuilder<'a> {
|
|||||||
let (_, ndims) = unpack_ndarray_var_tys(&mut ctx.unifier, fun.0.ret);
|
let (_, ndims) = unpack_ndarray_var_tys(&mut ctx.unifier, fun.0.ret);
|
||||||
let ndims = extract_ndims(&ctx.unifier, ndims);
|
let ndims = extract_ndims(&ctx.unifier, ndims);
|
||||||
|
|
||||||
let new_ndarray = ndarray.reshape_or_copy(generator, ctx, ndims, &shape);
|
// let new_ndarray = ndarray.reshape_or_copy(generator, ctx, ndims, &shape);
|
||||||
|
let new_ndarray = match prim {
|
||||||
|
PrimDef::FunNpBroadcastTo => {
|
||||||
|
ndarray.broadcast_to(generator, ctx, ndims, &shape)
|
||||||
|
}
|
||||||
|
|
||||||
|
PrimDef::FunNpReshape => {
|
||||||
|
ndarray.reshape_or_copy(generator, ctx, ndims, &shape)
|
||||||
|
}
|
||||||
|
|
||||||
|
_ => unreachable!(),
|
||||||
|
};
|
||||||
Ok(Some(new_ndarray.as_base_value().as_basic_value_enum()))
|
Ok(Some(new_ndarray.as_base_value().as_basic_value_enum()))
|
||||||
}),
|
}),
|
||||||
)
|
)
|
||||||
|
@ -60,6 +60,7 @@ pub enum PrimDef {
|
|||||||
FunNpStrides,
|
FunNpStrides,
|
||||||
|
|
||||||
// NumPy ndarray view functions
|
// NumPy ndarray view functions
|
||||||
|
FunNpBroadcastTo,
|
||||||
FunNpTranspose,
|
FunNpTranspose,
|
||||||
FunNpReshape,
|
FunNpReshape,
|
||||||
|
|
||||||
@ -253,6 +254,7 @@ impl PrimDef {
|
|||||||
PrimDef::FunNpStrides => fun("np_strides", None),
|
PrimDef::FunNpStrides => fun("np_strides", None),
|
||||||
|
|
||||||
// NumPy NDArray view functions
|
// NumPy NDArray view functions
|
||||||
|
PrimDef::FunNpBroadcastTo => fun("np_broadcast_to", None),
|
||||||
PrimDef::FunNpTranspose => fun("np_transpose", None),
|
PrimDef::FunNpTranspose => fun("np_transpose", None),
|
||||||
PrimDef::FunNpReshape => fun("np_reshape", None),
|
PrimDef::FunNpReshape => fun("np_reshape", None),
|
||||||
|
|
||||||
|
@ -8,5 +8,5 @@ expression: res_vec
|
|||||||
"Function {\nname: \"B.foo\",\nsig: \"fn[[b:T], none]\",\nvar_id: []\n}\n",
|
"Function {\nname: \"B.foo\",\nsig: \"fn[[b:T], none]\",\nvar_id: []\n}\n",
|
||||||
"Class {\nname: \"Generic_A\",\nancestors: [\"Generic_A[V]\", \"B\"],\nfields: [\"aa\", \"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"foo\", \"fn[[b:T], none]\"), (\"fun\", \"fn[[a:int32], V]\")],\ntype_vars: [\"V\"]\n}\n",
|
"Class {\nname: \"Generic_A\",\nancestors: [\"Generic_A[V]\", \"B\"],\nfields: [\"aa\", \"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"foo\", \"fn[[b:T], none]\"), (\"fun\", \"fn[[a:int32], V]\")],\ntype_vars: [\"V\"]\n}\n",
|
||||||
"Function {\nname: \"Generic_A.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
|
"Function {\nname: \"Generic_A.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
|
||||||
"Function {\nname: \"Generic_A.fun\",\nsig: \"fn[[a:int32], V]\",\nvar_id: [TypeVarId(253)]\n}\n",
|
"Function {\nname: \"Generic_A.fun\",\nsig: \"fn[[a:int32], V]\",\nvar_id: [TypeVarId(254)]\n}\n",
|
||||||
]
|
]
|
||||||
|
@ -7,7 +7,7 @@ expression: res_vec
|
|||||||
"Function {\nname: \"A.__init__\",\nsig: \"fn[[t:T], none]\",\nvar_id: []\n}\n",
|
"Function {\nname: \"A.__init__\",\nsig: \"fn[[t:T], none]\",\nvar_id: []\n}\n",
|
||||||
"Function {\nname: \"A.fun\",\nsig: \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\",\nvar_id: []\n}\n",
|
"Function {\nname: \"A.fun\",\nsig: \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\",\nvar_id: []\n}\n",
|
||||||
"Function {\nname: \"A.foo\",\nsig: \"fn[[c:C], none]\",\nvar_id: []\n}\n",
|
"Function {\nname: \"A.foo\",\nsig: \"fn[[c:C], none]\",\nvar_id: []\n}\n",
|
||||||
"Class {\nname: \"B\",\nancestors: [\"B[typevar237]\", \"A[float]\"],\nfields: [\"a\", \"b\", \"c\", \"d\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\"), (\"foo\", \"fn[[c:C], none]\")],\ntype_vars: [\"typevar237\"]\n}\n",
|
"Class {\nname: \"B\",\nancestors: [\"B[typevar238]\", \"A[float]\"],\nfields: [\"a\", \"b\", \"c\", \"d\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\"), (\"foo\", \"fn[[c:C], none]\")],\ntype_vars: [\"typevar238\"]\n}\n",
|
||||||
"Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
|
"Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
|
||||||
"Function {\nname: \"B.fun\",\nsig: \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\",\nvar_id: []\n}\n",
|
"Function {\nname: \"B.fun\",\nsig: \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\",\nvar_id: []\n}\n",
|
||||||
"Class {\nname: \"C\",\nancestors: [\"C\", \"B[bool]\", \"A[float]\"],\nfields: [\"a\", \"b\", \"c\", \"d\", \"e\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\"), (\"foo\", \"fn[[c:C], none]\")],\ntype_vars: []\n}\n",
|
"Class {\nname: \"C\",\nancestors: [\"C\", \"B[bool]\", \"A[float]\"],\nfields: [\"a\", \"b\", \"c\", \"d\", \"e\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\"), (\"foo\", \"fn[[c:C], none]\")],\ntype_vars: []\n}\n",
|
||||||
|
@ -5,8 +5,8 @@ expression: res_vec
|
|||||||
[
|
[
|
||||||
"Function {\nname: \"foo\",\nsig: \"fn[[a:list[int32], b:tuple[T, float]], A[B, bool]]\",\nvar_id: []\n}\n",
|
"Function {\nname: \"foo\",\nsig: \"fn[[a:list[int32], b:tuple[T, float]], A[B, bool]]\",\nvar_id: []\n}\n",
|
||||||
"Class {\nname: \"A\",\nancestors: [\"A[T, V]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[v:V], none]\"), (\"fun\", \"fn[[a:T], V]\")],\ntype_vars: [\"T\", \"V\"]\n}\n",
|
"Class {\nname: \"A\",\nancestors: [\"A[T, V]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[v:V], none]\"), (\"fun\", \"fn[[a:T], V]\")],\ntype_vars: [\"T\", \"V\"]\n}\n",
|
||||||
"Function {\nname: \"A.__init__\",\nsig: \"fn[[v:V], none]\",\nvar_id: [TypeVarId(250)]\n}\n",
|
"Function {\nname: \"A.__init__\",\nsig: \"fn[[v:V], none]\",\nvar_id: [TypeVarId(251)]\n}\n",
|
||||||
"Function {\nname: \"A.fun\",\nsig: \"fn[[a:T], V]\",\nvar_id: [TypeVarId(255)]\n}\n",
|
"Function {\nname: \"A.fun\",\nsig: \"fn[[a:T], V]\",\nvar_id: [TypeVarId(256)]\n}\n",
|
||||||
"Function {\nname: \"gfun\",\nsig: \"fn[[a:A[list[float], int32]], none]\",\nvar_id: []\n}\n",
|
"Function {\nname: \"gfun\",\nsig: \"fn[[a:A[list[float], int32]], none]\",\nvar_id: []\n}\n",
|
||||||
"Class {\nname: \"B\",\nancestors: [\"B\"],\nfields: [],\nmethods: [(\"__init__\", \"fn[[], none]\")],\ntype_vars: []\n}\n",
|
"Class {\nname: \"B\",\nancestors: [\"B\"],\nfields: [],\nmethods: [(\"__init__\", \"fn[[], none]\")],\ntype_vars: []\n}\n",
|
||||||
"Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
|
"Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
|
||||||
|
@ -3,7 +3,7 @@ source: nac3core/src/toplevel/test.rs
|
|||||||
expression: res_vec
|
expression: res_vec
|
||||||
---
|
---
|
||||||
[
|
[
|
||||||
"Class {\nname: \"A\",\nancestors: [\"A[typevar236, typevar237]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[a:A[float, bool], b:B], none]\"), (\"fun\", \"fn[[a:A[float, bool]], A[bool, int32]]\")],\ntype_vars: [\"typevar236\", \"typevar237\"]\n}\n",
|
"Class {\nname: \"A\",\nancestors: [\"A[typevar237, typevar238]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[a:A[float, bool], b:B], none]\"), (\"fun\", \"fn[[a:A[float, bool]], A[bool, int32]]\")],\ntype_vars: [\"typevar237\", \"typevar238\"]\n}\n",
|
||||||
"Function {\nname: \"A.__init__\",\nsig: \"fn[[a:A[float, bool], b:B], none]\",\nvar_id: []\n}\n",
|
"Function {\nname: \"A.__init__\",\nsig: \"fn[[a:A[float, bool], b:B], none]\",\nvar_id: []\n}\n",
|
||||||
"Function {\nname: \"A.fun\",\nsig: \"fn[[a:A[float, bool]], A[bool, int32]]\",\nvar_id: []\n}\n",
|
"Function {\nname: \"A.fun\",\nsig: \"fn[[a:A[float, bool]], A[bool, int32]]\",\nvar_id: []\n}\n",
|
||||||
"Class {\nname: \"B\",\nancestors: [\"B\", \"A[int64, bool]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:A[float, bool]], A[bool, int32]]\"), (\"foo\", \"fn[[b:B], B]\"), (\"bar\", \"fn[[a:A[list[B], int32]], tuple[A[virtual[A[B, int32]], bool], B]]\")],\ntype_vars: []\n}\n",
|
"Class {\nname: \"B\",\nancestors: [\"B\", \"A[int64, bool]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:A[float, bool]], A[bool, int32]]\"), (\"foo\", \"fn[[b:B], B]\"), (\"bar\", \"fn[[a:A[list[B], int32]], tuple[A[virtual[A[B, int32]], bool], B]]\")],\ntype_vars: []\n}\n",
|
||||||
|
@ -6,12 +6,12 @@ expression: res_vec
|
|||||||
"Class {\nname: \"A\",\nancestors: [\"A\"],\nfields: [\"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n",
|
"Class {\nname: \"A\",\nancestors: [\"A\"],\nfields: [\"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n",
|
||||||
"Function {\nname: \"A.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
|
"Function {\nname: \"A.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
|
||||||
"Function {\nname: \"A.fun\",\nsig: \"fn[[b:B], none]\",\nvar_id: []\n}\n",
|
"Function {\nname: \"A.fun\",\nsig: \"fn[[b:B], none]\",\nvar_id: []\n}\n",
|
||||||
"Function {\nname: \"A.foo\",\nsig: \"fn[[a:T, b:V], none]\",\nvar_id: [TypeVarId(256)]\n}\n",
|
"Function {\nname: \"A.foo\",\nsig: \"fn[[a:T, b:V], none]\",\nvar_id: [TypeVarId(257)]\n}\n",
|
||||||
"Class {\nname: \"C\",\nancestors: [\"C\", \"A\"],\nfields: [\"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n",
|
"Class {\nname: \"C\",\nancestors: [\"C\", \"A\"],\nfields: [\"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n",
|
||||||
"Function {\nname: \"C.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
|
"Function {\nname: \"C.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
|
||||||
"Function {\nname: \"C.fun\",\nsig: \"fn[[b:B], none]\",\nvar_id: []\n}\n",
|
"Function {\nname: \"C.fun\",\nsig: \"fn[[b:B], none]\",\nvar_id: []\n}\n",
|
||||||
"Class {\nname: \"B\",\nancestors: [\"B\", \"C\", \"A\"],\nfields: [\"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n",
|
"Class {\nname: \"B\",\nancestors: [\"B\", \"C\", \"A\"],\nfields: [\"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n",
|
||||||
"Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
|
"Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
|
||||||
"Function {\nname: \"foo\",\nsig: \"fn[[a:A], none]\",\nvar_id: []\n}\n",
|
"Function {\nname: \"foo\",\nsig: \"fn[[a:A], none]\",\nvar_id: []\n}\n",
|
||||||
"Function {\nname: \"ff\",\nsig: \"fn[[a:T], V]\",\nvar_id: [TypeVarId(264)]\n}\n",
|
"Function {\nname: \"ff\",\nsig: \"fn[[a:T], V]\",\nvar_id: [TypeVarId(265)]\n}\n",
|
||||||
]
|
]
|
||||||
|
@ -1594,7 +1594,7 @@ impl<'a> Inferencer<'a> {
|
|||||||
}));
|
}));
|
||||||
}
|
}
|
||||||
// 2-argument ndarray n-dimensional factory functions
|
// 2-argument ndarray n-dimensional factory functions
|
||||||
if id == &"np_reshape".into() && args.len() == 2 {
|
if ["np_reshape".into(), "np_broadcast_to".into()].contains(id) && args.len() == 2 {
|
||||||
let arg0 = self.fold_expr(args.remove(0))?;
|
let arg0 = self.fold_expr(args.remove(0))?;
|
||||||
|
|
||||||
let shape_expr = args.remove(0);
|
let shape_expr = args.remove(0);
|
||||||
|
@ -180,6 +180,7 @@ def patch(module):
|
|||||||
module.np_array = np.array
|
module.np_array = np.array
|
||||||
|
|
||||||
# NumPy NDArray view functions
|
# NumPy NDArray view functions
|
||||||
|
module.np_broadcast_to = np.broadcast_to
|
||||||
module.np_transpose = np.transpose
|
module.np_transpose = np.transpose
|
||||||
module.np_reshape = np.reshape
|
module.np_reshape = np.reshape
|
||||||
|
|
||||||
|
@ -68,6 +68,12 @@ def output_ndarray_float_2(n: ndarray[float, Literal[2]]):
|
|||||||
for c in range(len(n[r])):
|
for c in range(len(n[r])):
|
||||||
output_float64(n[r][c])
|
output_float64(n[r][c])
|
||||||
|
|
||||||
|
def output_ndarray_float_3(n: ndarray[float, Literal[3]]):
|
||||||
|
for d in range(len(n)):
|
||||||
|
for r in range(len(n[d])):
|
||||||
|
for c in range(len(n[d][r])):
|
||||||
|
output_float64(n[d][r][c])
|
||||||
|
|
||||||
def output_ndarray_float_4(n: ndarray[float, Literal[4]]):
|
def output_ndarray_float_4(n: ndarray[float, Literal[4]]):
|
||||||
for x in range(len(n)):
|
for x in range(len(n)):
|
||||||
for y in range(len(n[x])):
|
for y in range(len(n[x])):
|
||||||
@ -236,6 +242,23 @@ def test_ndarray_reshape():
|
|||||||
output_int32(np_shape(x2)[1])
|
output_int32(np_shape(x2)[1])
|
||||||
output_ndarray_int32_2(x2)
|
output_ndarray_int32_2(x2)
|
||||||
|
|
||||||
|
def test_ndarray_broadcast_to():
|
||||||
|
xs = np_array([1.0, 2.0, 3.0])
|
||||||
|
ys = np_broadcast_to(xs, (1, 3))
|
||||||
|
zs = np_broadcast_to(ys, (2, 4, 3))
|
||||||
|
|
||||||
|
output_int32(np_shape(xs)[0])
|
||||||
|
output_ndarray_float_1(xs)
|
||||||
|
|
||||||
|
output_int32(np_shape(ys)[0])
|
||||||
|
output_int32(np_shape(ys)[1])
|
||||||
|
output_ndarray_float_2(ys)
|
||||||
|
|
||||||
|
output_int32(np_shape(zs)[0])
|
||||||
|
output_int32(np_shape(zs)[1])
|
||||||
|
output_int32(np_shape(zs)[2])
|
||||||
|
output_ndarray_float_3(zs)
|
||||||
|
|
||||||
def test_ndarray_add():
|
def test_ndarray_add():
|
||||||
x = np_identity(2)
|
x = np_identity(2)
|
||||||
y = x + np_ones([2, 2])
|
y = x + np_ones([2, 2])
|
||||||
@ -1619,6 +1642,7 @@ def run() -> int32:
|
|||||||
test_ndarray_nd_idx()
|
test_ndarray_nd_idx()
|
||||||
|
|
||||||
test_ndarray_reshape()
|
test_ndarray_reshape()
|
||||||
|
test_ndarray_broadcast_to()
|
||||||
|
|
||||||
test_ndarray_add()
|
test_ndarray_add()
|
||||||
test_ndarray_add_broadcast()
|
test_ndarray_add_broadcast()
|
||||||
|
Loading…
Reference in New Issue
Block a user