1
0
forked from M-Labs/nac3

[core] codegen/ndarray: Implement np_reshape

Based on 926e7e93: core/ndstrides: implement np_reshape()
This commit is contained in:
David Mak 2024-12-18 11:40:23 +08:00
parent aae41eef6a
commit 8d975b5ff3
13 changed files with 308 additions and 373 deletions

View File

@ -9,4 +9,5 @@
#include "irrt/ndarray/def.hpp" #include "irrt/ndarray/def.hpp"
#include "irrt/ndarray/iter.hpp" #include "irrt/ndarray/iter.hpp"
#include "irrt/ndarray/indexing.hpp" #include "irrt/ndarray/indexing.hpp"
#include "irrt/ndarray/array.hpp" #include "irrt/ndarray/array.hpp"
#include "irrt/ndarray/reshape.hpp"

View File

@ -0,0 +1,97 @@
#pragma once
#include "irrt/exception.hpp"
#include "irrt/int_types.hpp"
#include "irrt/ndarray/def.hpp"
namespace {
namespace ndarray::reshape {
/**
* @brief Perform assertions on and resolve unknown dimensions in `new_shape` in `np.reshape(<ndarray>, new_shape)`
*
* If `new_shape` indeed contains unknown dimensions (specified with `-1`, just like numpy), `new_shape` will be
* modified to contain the resolved dimension.
*
* To perform assertions on and resolve unknown dimensions in `new_shape`, we don't need the actual
* `<ndarray>` object itself, but only the `.size` of the `<ndarray>`.
*
* @param size The `.size` of `<ndarray>`
* @param new_ndims Number of elements in `new_shape`
* @param new_shape Target shape to reshape to
*/
template<typename SizeT>
void resolve_and_check_new_shape(SizeT size, SizeT new_ndims, SizeT* new_shape) {
// Is there a -1 in `new_shape`?
bool neg1_exists = false;
// Location of -1, only initialized if `neg1_exists` is true
SizeT neg1_axis_i;
// The computed ndarray size of `new_shape`
SizeT new_size = 1;
for (SizeT axis_i = 0; axis_i < new_ndims; axis_i++) {
SizeT dim = new_shape[axis_i];
if (dim < 0) {
if (dim == -1) {
if (neg1_exists) {
// Multiple `-1` found. Throw an error.
raise_exception(SizeT, EXN_VALUE_ERROR, "can only specify one unknown dimension", NO_PARAM,
NO_PARAM, NO_PARAM);
} else {
neg1_exists = true;
neg1_axis_i = axis_i;
}
} else {
// TODO: What? In `np.reshape` any negative dimensions is
// treated like its `-1`.
//
// Try running `np.zeros((3, 4)).reshape((-999, 2))`
//
// It is not documented by numpy.
// Throw an error for now...
raise_exception(SizeT, EXN_VALUE_ERROR, "Found non -1 negative dimension {0} on axis {1}", dim, axis_i,
NO_PARAM);
}
} else {
new_size *= dim;
}
}
bool can_reshape;
if (neg1_exists) {
// Let `x` be the unknown dimension
// Solve `x * <new_size> = <size>`
if (new_size == 0 && size == 0) {
// `x` has infinitely many solutions
can_reshape = false;
} else if (new_size == 0 && size != 0) {
// `x` has no solutions
can_reshape = false;
} else if (size % new_size != 0) {
// `x` has no integer solutions
can_reshape = false;
} else {
can_reshape = true;
new_shape[neg1_axis_i] = size / new_size; // Resolve dimension
}
} else {
can_reshape = (new_size == size);
}
if (!can_reshape) {
raise_exception(SizeT, EXN_VALUE_ERROR, "cannot reshape array of size {0} into given shape", size, NO_PARAM,
NO_PARAM);
}
}
} // namespace ndarray::reshape
} // namespace
extern "C" {
void __nac3_ndarray_reshape_resolve_and_check_new_shape(int32_t size, int32_t new_ndims, int32_t* new_shape) {
ndarray::reshape::resolve_and_check_new_shape(size, new_ndims, new_shape);
}
void __nac3_ndarray_reshape_resolve_and_check_new_shape64(int64_t size, int64_t new_ndims, int64_t* new_shape) {
ndarray::reshape::resolve_and_check_new_shape(size, new_ndims, new_shape);
}
}

View File

@ -20,11 +20,13 @@ pub use array::*;
pub use basic::*; pub use basic::*;
pub use indexing::*; pub use indexing::*;
pub use iter::*; pub use iter::*;
pub use reshape::*;
mod array; mod array;
mod basic; mod basic;
mod indexing; mod indexing;
mod iter; mod iter;
mod reshape;
/// Generates a call to `__nac3_ndarray_calc_size`. Returns a /// Generates a call to `__nac3_ndarray_calc_size`. Returns a
/// [`usize`][CodeGenerator::get_size_type] representing the calculated total size. /// [`usize`][CodeGenerator::get_size_type] representing the calculated total size.

View File

@ -0,0 +1,40 @@
use inkwell::values::IntValue;
use crate::codegen::{
expr::infer_and_call_function,
irrt::get_usize_dependent_function_name,
values::{ArrayLikeValue, ArraySliceValue},
CodeGenContext, CodeGenerator,
};
/// Generates a call to `__nac3_ndarray_reshape_resolve_and_check_new_shape`.
///
/// Resolves unknown dimensions in `new_shape` for `numpy.reshape(<ndarray>, new_shape)`, raising an
/// assertion if multiple dimensions are unknown (`-1`).
pub fn call_nac3_ndarray_reshape_resolve_and_check_new_shape<'ctx, G: CodeGenerator + ?Sized>(
generator: &G,
ctx: &CodeGenContext<'ctx, '_>,
size: IntValue<'ctx>,
new_ndims: IntValue<'ctx>,
new_shape: ArraySliceValue<'ctx>,
) {
let llvm_usize = generator.get_size_type(ctx.ctx);
assert_eq!(size.get_type(), llvm_usize);
assert_eq!(new_ndims.get_type(), llvm_usize);
assert_eq!(new_shape.element_type(ctx, generator), llvm_usize.into());
let name = get_usize_dependent_function_name(
generator,
ctx,
"__nac3_ndarray_reshape_resolve_and_check_new_shape",
);
infer_and_call_function(
ctx,
&name,
None,
&[size.into(), new_ndims.into(), new_shape.base_ptr(ctx, generator).into()],
None,
None,
);
}

View File

@ -21,9 +21,9 @@ use super::{
types::ndarray::{factory::ndarray_zero_value, NDArrayType}, types::ndarray::{factory::ndarray_zero_value, NDArrayType},
values::{ values::{
ndarray::{shape::parse_numpy_int_sequence, NDArrayValue}, ndarray::{shape::parse_numpy_int_sequence, NDArrayValue},
ArrayLikeIndexer, ArrayLikeValue, ArraySliceValue, ListValue, ProxyValue, ArrayLikeIndexer, ArrayLikeValue, ArraySliceValue, ProxyValue, TypedArrayLikeAccessor,
TypedArrayLikeAccessor, TypedArrayLikeAdapter, TypedArrayLikeMutator, TypedArrayLikeAdapter, TypedArrayLikeMutator, UntypedArrayLikeAccessor,
UntypedArrayLikeAccessor, UntypedArrayLikeMutator, UntypedArrayLikeMutator,
}, },
CodeGenContext, CodeGenerator, CodeGenContext, CodeGenerator,
}; };
@ -134,46 +134,6 @@ where
Ok(ndarray) Ok(ndarray)
} }
/// Creates an `NDArray` instance from a constant shape.
///
/// * `elem_ty` - The element type of the `NDArray`.
/// * `shape` - The shape of the `NDArray`, represented am array of [`IntValue`]s.
pub fn create_ndarray_const_shape<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
elem_ty: Type,
shape: &[IntValue<'ctx>],
) -> Result<NDArrayValue<'ctx>, String> {
let llvm_usize = generator.get_size_type(ctx.ctx);
for &shape_dim in shape {
let shape_dim = ctx.builder.build_int_z_extend(shape_dim, llvm_usize, "").unwrap();
let shape_dim_gez = ctx
.builder
.build_int_compare(IntPredicate::SGE, shape_dim, llvm_usize.const_zero(), "")
.unwrap();
ctx.make_assert(
generator,
shape_dim_gez,
"0:ValueError",
"negative dimensions not supported",
[None, None, None],
ctx.current_loc,
);
// TODO: Disallow shape > u32_MAX
}
let llvm_dtype = ctx.get_llvm_type(generator, elem_ty);
let ndarray = NDArrayType::new(generator, ctx.ctx, llvm_dtype, Some(shape.len() as u64))
.construct_dyn_shape(generator, ctx, shape, None);
unsafe { ndarray.create_data(generator, ctx) };
Ok(ndarray)
}
/// Generates LLVM IR for populating the entire `NDArray` using a lambda with its flattened index as /// Generates LLVM IR for populating the entire `NDArray` using a lambda with its flattened index as
/// its input. /// its input.
fn ndarray_fill_flattened<'ctx, 'a, G, ValueFn>( fn ndarray_fill_flattened<'ctx, 'a, G, ValueFn>(
@ -1455,294 +1415,6 @@ pub fn ndarray_transpose<'ctx, G: CodeGenerator + ?Sized>(
} }
} }
/// LLVM-typed implementation for generating the implementation for `ndarray.reshape`.
///
/// * `x1` - `NDArray` to reshape.
/// * `shape` - The `shape` parameter used to construct the new `NDArray`.
/// Just like numpy, the `shape` argument can be:
/// 1. A list of `int32`; e.g., `np.reshape(arr, [600, -1, 3])`
/// 2. A tuple of `int32`; e.g., `np.reshape(arr, (-1, 800, 3))`
/// 3. A scalar `int32`; e.g., `np.reshape(arr, 3)`
///
/// Note that unlike other generating functions, one of the dimensions in the shape can be negative.
pub fn ndarray_reshape<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
x1: (Type, BasicValueEnum<'ctx>),
shape: (Type, BasicValueEnum<'ctx>),
) -> Result<BasicValueEnum<'ctx>, String> {
const FN_NAME: &str = "ndarray_reshape";
let (x1_ty, x1) = x1;
let (_, shape) = shape;
let llvm_usize = generator.get_size_type(ctx.ctx);
if let BasicValueEnum::PointerValue(n1) = x1 {
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty);
let llvm_ndarray_ty = NDArrayType::from_unifier_type(generator, ctx, x1_ty);
let n1 = llvm_ndarray_ty.map_value(n1, None);
let n_sz = n1.size(generator, ctx);
let acc = generator.gen_var_alloc(ctx, llvm_usize.into(), None)?;
let num_neg = generator.gen_var_alloc(ctx, llvm_usize.into(), None)?;
ctx.builder.build_store(acc, llvm_usize.const_int(1, false)).unwrap();
ctx.builder.build_store(num_neg, llvm_usize.const_zero()).unwrap();
let out = match shape {
BasicValueEnum::PointerValue(shape_list_ptr)
if ListValue::is_representable(shape_list_ptr, llvm_usize).is_ok() =>
{
// 1. A list of ints; e.g., `np.reshape(arr, [int64(600), int64(800, -1])`
let shape_list = ListValue::from_pointer_value(shape_list_ptr, llvm_usize, None);
// Check for -1 in dimensions
gen_for_callback_incrementing(
generator,
ctx,
None,
llvm_usize.const_zero(),
(shape_list.load_size(ctx, None), false),
|generator, ctx, _, idx| {
let ele =
shape_list.data().get(ctx, generator, &idx, None).into_int_value();
let ele = ctx.builder.build_int_s_extend(ele, llvm_usize, "").unwrap();
gen_if_else_expr_callback(
generator,
ctx,
|_, ctx| {
Ok(ctx
.builder
.build_int_compare(
IntPredicate::SLT,
ele,
llvm_usize.const_zero(),
"",
)
.unwrap())
},
|_, ctx| -> Result<Option<IntValue>, String> {
let num_neg_value =
ctx.builder.build_load(num_neg, "").unwrap().into_int_value();
let num_neg_value = ctx
.builder
.build_int_add(
num_neg_value,
llvm_usize.const_int(1, false),
"",
)
.unwrap();
ctx.builder.build_store(num_neg, num_neg_value).unwrap();
Ok(None)
},
|_, ctx| {
let acc_value =
ctx.builder.build_load(acc, "").unwrap().into_int_value();
let acc_value =
ctx.builder.build_int_mul(acc_value, ele, "").unwrap();
ctx.builder.build_store(acc, acc_value).unwrap();
Ok(None)
},
)?;
Ok(())
},
llvm_usize.const_int(1, false),
)?;
let acc_val = ctx.builder.build_load(acc, "").unwrap().into_int_value();
let rem = ctx.builder.build_int_unsigned_div(n_sz, acc_val, "").unwrap();
// Generate the output shape by filling -1 with `rem`
create_ndarray_dyn_shape(
generator,
ctx,
elem_ty,
&shape_list,
|_, ctx, _| Ok(shape_list.load_size(ctx, None)),
|generator, ctx, shape_list, idx| {
let dim =
shape_list.data().get(ctx, generator, &idx, None).into_int_value();
let dim = ctx.builder.build_int_s_extend(dim, llvm_usize, "").unwrap();
Ok(gen_if_else_expr_callback(
generator,
ctx,
|_, ctx| {
Ok(ctx
.builder
.build_int_compare(
IntPredicate::SLT,
dim,
llvm_usize.const_zero(),
"",
)
.unwrap())
},
|_, _| Ok(Some(rem)),
|_, _| Ok(Some(dim)),
)?
.unwrap()
.into_int_value())
},
)
}
BasicValueEnum::StructValue(shape_tuple) => {
// 2. A tuple of `int32`; e.g., `np.reshape(arr, (-1, 800, 3))`
let ndims = shape_tuple.get_type().count_fields();
// Check for -1 in dims
for dim_i in 0..ndims {
let dim = ctx
.builder
.build_extract_value(shape_tuple, dim_i, "")
.unwrap()
.into_int_value();
let dim = ctx.builder.build_int_s_extend(dim, llvm_usize, "").unwrap();
gen_if_else_expr_callback(
generator,
ctx,
|_, ctx| {
Ok(ctx
.builder
.build_int_compare(
IntPredicate::SLT,
dim,
llvm_usize.const_zero(),
"",
)
.unwrap())
},
|_, ctx| -> Result<Option<IntValue>, String> {
let num_negs =
ctx.builder.build_load(num_neg, "").unwrap().into_int_value();
let num_negs = ctx
.builder
.build_int_add(num_negs, llvm_usize.const_int(1, false), "")
.unwrap();
ctx.builder.build_store(num_neg, num_negs).unwrap();
Ok(None)
},
|_, ctx| {
let acc_val = ctx.builder.build_load(acc, "").unwrap().into_int_value();
let acc_val = ctx.builder.build_int_mul(acc_val, dim, "").unwrap();
ctx.builder.build_store(acc, acc_val).unwrap();
Ok(None)
},
)?;
}
let acc_val = ctx.builder.build_load(acc, "").unwrap().into_int_value();
let rem = ctx.builder.build_int_unsigned_div(n_sz, acc_val, "").unwrap();
let mut shape = Vec::with_capacity(ndims as usize);
// Reconstruct shape filling negatives with rem
for dim_i in 0..ndims {
let dim = ctx
.builder
.build_extract_value(shape_tuple, dim_i, "")
.unwrap()
.into_int_value();
let dim = ctx.builder.build_int_s_extend(dim, llvm_usize, "").unwrap();
let dim = gen_if_else_expr_callback(
generator,
ctx,
|_, ctx| {
Ok(ctx
.builder
.build_int_compare(
IntPredicate::SLT,
dim,
llvm_usize.const_zero(),
"",
)
.unwrap())
},
|_, _| Ok(Some(rem)),
|_, _| Ok(Some(dim)),
)?
.unwrap()
.into_int_value();
shape.push(dim);
}
create_ndarray_const_shape(generator, ctx, elem_ty, shape.as_slice())
}
BasicValueEnum::IntValue(shape_int) => {
// 3. A scalar `int32`; e.g., `np.reshape(arr, 3)`
let shape_int = gen_if_else_expr_callback(
generator,
ctx,
|_, ctx| {
Ok(ctx
.builder
.build_int_compare(
IntPredicate::SLT,
shape_int,
llvm_usize.const_zero(),
"",
)
.unwrap())
},
|_, _| Ok(Some(n_sz)),
|_, ctx| {
Ok(Some(ctx.builder.build_int_s_extend(shape_int, llvm_usize, "").unwrap()))
},
)?
.unwrap()
.into_int_value();
create_ndarray_const_shape(generator, ctx, elem_ty, &[shape_int])
}
_ => codegen_unreachable!(ctx),
}
.unwrap();
// Only allow one dimension to be negative
let num_negs = ctx.builder.build_load(num_neg, "").unwrap().into_int_value();
ctx.make_assert(
generator,
ctx.builder
.build_int_compare(IntPredicate::ULT, num_negs, llvm_usize.const_int(2, false), "")
.unwrap(),
"0:ValueError",
"can only specify one unknown dimension",
[None, None, None],
ctx.current_loc,
);
// The new shape must be compatible with the old shape
let out_sz = out.size(generator, ctx);
ctx.make_assert(
generator,
ctx.builder.build_int_compare(IntPredicate::EQ, out_sz, n_sz, "").unwrap(),
"0:ValueError",
"cannot reshape array of size {0} into provided shape of size {1}",
[Some(n_sz), Some(out_sz), None],
ctx.current_loc,
);
gen_for_callback_incrementing(
generator,
ctx,
None,
llvm_usize.const_zero(),
(n_sz, false),
|generator, ctx, _, idx| {
let elem = unsafe { n1.data().get_unchecked(ctx, generator, &idx, None) };
unsafe { out.data().set_unchecked(ctx, generator, &idx, elem) };
Ok(())
},
llvm_usize.const_int(1, false),
)?;
Ok(out.as_base_value().into())
} else {
codegen_unreachable!(
ctx,
"{FN_NAME}() not supported for '{}'",
format!("'{}'", ctx.unifier.stringify(x1_ty))
)
}
}
/// Generates LLVM IR for `ndarray.dot`. /// Generates LLVM IR for `ndarray.dot`.
/// Calculate inner product of two vectors or literals /// Calculate inner product of two vectors or literals
/// For matrix multiplication use `np_matmul` /// For matrix multiplication use `np_matmul`

View File

@ -1,9 +1,16 @@
use std::iter::{once, repeat_n}; use std::iter::{once, repeat_n};
use inkwell::values::IntValue;
use itertools::Itertools; use itertools::Itertools;
use crate::codegen::{ use crate::codegen::{
values::ndarray::{NDArrayValue, RustNDIndex}, irrt,
stmt::gen_if_callback,
types::ndarray::NDArrayType,
values::{
ndarray::{NDArrayValue, RustNDIndex},
ArrayLikeValue, ProxyValue, TypedArrayLikeAccessor,
},
CodeGenContext, CodeGenerator, CodeGenContext, CodeGenerator,
}; };
@ -33,4 +40,72 @@ impl<'ctx> NDArrayValue<'ctx> {
*self *self
} }
} }
/// Create a reshaped view on this ndarray like
/// [`np.reshape()`](https://numpy.org/doc/stable/reference/generated/numpy.reshape.html).
///
/// If there is a `-1` in `new_shape`, it will be resolved; `new_shape` would **NOT** be
/// modified as a result.
///
/// If reshape without copying is impossible, this function will allocate a new ndarray and copy
/// contents.
///
/// * `new_ndims` - The number of dimensions of `new_shape` as a [`Type`].
/// * `new_shape` - The target shape to do `np.reshape()`.
#[must_use]
pub fn reshape_or_copy<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
new_ndims: u64,
new_shape: &impl TypedArrayLikeAccessor<'ctx, G, IntValue<'ctx>>,
) -> Self {
assert_eq!(new_shape.element_type(ctx, generator), self.llvm_usize.into());
// TODO: The current criterion for whether to do a full copy or not is by checking
// `is_c_contiguous`, but this is not optimal - there are cases when the ndarray is
// not contiguous but could be reshaped without copying data. Look into how numpy does
// it.
let dst_ndarray = NDArrayType::new(generator, ctx.ctx, self.dtype, Some(new_ndims))
.construct_uninitialized(generator, ctx, None);
dst_ndarray.copy_shape_from_array(generator, ctx, new_shape.base_ptr(ctx, generator));
// Resolve negative indices
let size = self.size(generator, ctx);
let dst_ndims = self.llvm_usize.const_int(dst_ndarray.get_type().ndims().unwrap(), false);
let dst_shape = dst_ndarray.shape();
irrt::ndarray::call_nac3_ndarray_reshape_resolve_and_check_new_shape(
generator,
ctx,
size,
dst_ndims,
dst_shape.as_slice_value(ctx, generator),
);
gen_if_callback(
generator,
ctx,
|generator, ctx| Ok(self.is_c_contiguous(generator, ctx)),
|generator, ctx| {
// Reshape is possible without copying
dst_ndarray.set_strides_contiguous(generator, ctx);
dst_ndarray.store_data(ctx, self.data().base_ptr(ctx, generator));
Ok(())
},
|generator, ctx| {
// Reshape is impossible without copying
unsafe {
dst_ndarray.create_data(generator, ctx);
}
dst_ndarray.copy_data_from(generator, ctx, *self);
Ok(())
},
)
.unwrap();
dst_ndarray
}
} }

View File

@ -5,8 +5,10 @@ use inkwell::{values::BasicValue, IntPredicate};
use strum::IntoEnumIterator; use strum::IntoEnumIterator;
use super::{ use super::{
helper::{debug_assert_prim_is_allowed, make_exception_fields, PrimDef, PrimDefDetails}, helper::{
numpy::make_ndarray_ty, debug_assert_prim_is_allowed, extract_ndims, make_exception_fields, PrimDef, PrimDefDetails,
},
numpy::{make_ndarray_ty, unpack_ndarray_var_tys},
*, *,
}; };
use crate::{ use crate::{
@ -15,7 +17,7 @@ use crate::{
numpy::*, numpy::*,
stmt::exn_constructor, stmt::exn_constructor,
types::ndarray::NDArrayType, types::ndarray::NDArrayType,
values::{ProxyValue, RangeValue}, values::{ndarray::shape::parse_numpy_int_sequence, ProxyValue, RangeValue},
}, },
symbol_resolver::SymbolValue, symbol_resolver::SymbolValue,
typecheck::typedef::{into_var_map, iter_type_vars, TypeVar, VarMap}, typecheck::typedef::{into_var_map, iter_type_vars, TypeVar, VarMap},
@ -193,7 +195,6 @@ struct BuiltinBuilder<'a> {
ndarray_float: Type, ndarray_float: Type,
ndarray_float_2d: Type, ndarray_float_2d: Type,
ndarray_num_ty: Type,
float_or_ndarray_ty: TypeVar, float_or_ndarray_ty: TypeVar,
float_or_ndarray_var_map: VarMap, float_or_ndarray_var_map: VarMap,
@ -307,7 +308,6 @@ impl<'a> BuiltinBuilder<'a> {
ndarray_float, ndarray_float,
ndarray_float_2d, ndarray_float_2d,
ndarray_num_ty,
float_or_ndarray_ty, float_or_ndarray_ty,
float_or_ndarray_var_map, float_or_ndarray_var_map,
@ -1356,20 +1356,41 @@ impl<'a> BuiltinBuilder<'a> {
// Similar to `build_ndarray_from_shape_factory_function` we delegate the responsibility of typechecking // Similar to `build_ndarray_from_shape_factory_function` we delegate the responsibility of typechecking
// to [`typecheck::type_inferencer::Inferencer::fold_numpy_function_call_shape_argument`], // to [`typecheck::type_inferencer::Inferencer::fold_numpy_function_call_shape_argument`],
// and use a dummy [`TypeVar`] `ndarray_factory_fn_shape_arg_tvar` as a placeholder for `param_ty`. // and use a dummy [`TypeVar`] `ndarray_factory_fn_shape_arg_tvar` as a placeholder for `param_ty`.
PrimDef::FunNpReshape => create_fn_by_codegen( PrimDef::FunNpReshape => {
self.unifier, let ret_ty = self.unifier.get_dummy_var().ty; // Handled by special holding
&VarMap::new(),
prim.name(), create_fn_by_codegen(
self.ndarray_num_ty, self.unifier,
&[(self.ndarray_num_ty, "x"), (self.ndarray_factory_fn_shape_arg_tvar.ty, "shape")], &VarMap::new(),
Box::new(move |ctx, _, fun, args, generator| { prim.name(),
let x1_ty = fun.0.args[0].ty; ret_ty,
let x1_val = args[0].1.clone().to_basic_value_enum(ctx, generator, x1_ty)?; &[
let x2_ty = fun.0.args[1].ty; (in_ndarray_ty.ty, "x"),
let x2_val = args[1].1.clone().to_basic_value_enum(ctx, generator, x2_ty)?; (self.ndarray_factory_fn_shape_arg_tvar.ty, "shape"), // Handled by special folding
Ok(Some(ndarray_reshape(generator, ctx, (x1_ty, x1_val), (x2_ty, x2_val))?)) ],
}), Box::new(move |ctx, _, fun, args, generator| {
), let ndarray_ty = fun.0.args[0].ty;
let ndarray_val =
args[0].1.clone().to_basic_value_enum(ctx, generator, ndarray_ty)?;
let shape_ty = fun.0.args[1].ty;
let shape_val =
args[1].1.clone().to_basic_value_enum(ctx, generator, shape_ty)?;
let ndarray = NDArrayType::from_unifier_type(generator, ctx, ndarray_ty)
.map_value(ndarray_val.into_pointer_value(), None);
let shape = parse_numpy_int_sequence(generator, ctx, (shape_ty, shape_val));
// The ndims after reshaping is gotten from the return type of the call.
let (_, ndims) = unpack_ndarray_var_tys(&mut ctx.unifier, fun.0.ret);
let ndims = extract_ndims(&ctx.unifier, ndims);
let new_ndarray = ndarray.reshape_or_copy(generator, ctx, ndims, &shape);
Ok(Some(new_ndarray.as_base_value().as_basic_value_enum()))
}),
)
}
_ => unreachable!(), _ => unreachable!(),
} }

View File

@ -8,5 +8,5 @@ expression: res_vec
"Function {\nname: \"B.foo\",\nsig: \"fn[[b:T], none]\",\nvar_id: []\n}\n", "Function {\nname: \"B.foo\",\nsig: \"fn[[b:T], none]\",\nvar_id: []\n}\n",
"Class {\nname: \"Generic_A\",\nancestors: [\"Generic_A[V]\", \"B\"],\nfields: [\"aa\", \"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"foo\", \"fn[[b:T], none]\"), (\"fun\", \"fn[[a:int32], V]\")],\ntype_vars: [\"V\"]\n}\n", "Class {\nname: \"Generic_A\",\nancestors: [\"Generic_A[V]\", \"B\"],\nfields: [\"aa\", \"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"foo\", \"fn[[b:T], none]\"), (\"fun\", \"fn[[a:int32], V]\")],\ntype_vars: [\"V\"]\n}\n",
"Function {\nname: \"Generic_A.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", "Function {\nname: \"Generic_A.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
"Function {\nname: \"Generic_A.fun\",\nsig: \"fn[[a:int32], V]\",\nvar_id: [TypeVarId(251)]\n}\n", "Function {\nname: \"Generic_A.fun\",\nsig: \"fn[[a:int32], V]\",\nvar_id: [TypeVarId(252)]\n}\n",
] ]

View File

@ -7,7 +7,7 @@ expression: res_vec
"Function {\nname: \"A.__init__\",\nsig: \"fn[[t:T], none]\",\nvar_id: []\n}\n", "Function {\nname: \"A.__init__\",\nsig: \"fn[[t:T], none]\",\nvar_id: []\n}\n",
"Function {\nname: \"A.fun\",\nsig: \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\",\nvar_id: []\n}\n", "Function {\nname: \"A.fun\",\nsig: \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\",\nvar_id: []\n}\n",
"Function {\nname: \"A.foo\",\nsig: \"fn[[c:C], none]\",\nvar_id: []\n}\n", "Function {\nname: \"A.foo\",\nsig: \"fn[[c:C], none]\",\nvar_id: []\n}\n",
"Class {\nname: \"B\",\nancestors: [\"B[typevar235]\", \"A[float]\"],\nfields: [\"a\", \"b\", \"c\", \"d\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\"), (\"foo\", \"fn[[c:C], none]\")],\ntype_vars: [\"typevar235\"]\n}\n", "Class {\nname: \"B\",\nancestors: [\"B[typevar236]\", \"A[float]\"],\nfields: [\"a\", \"b\", \"c\", \"d\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\"), (\"foo\", \"fn[[c:C], none]\")],\ntype_vars: [\"typevar236\"]\n}\n",
"Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", "Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
"Function {\nname: \"B.fun\",\nsig: \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\",\nvar_id: []\n}\n", "Function {\nname: \"B.fun\",\nsig: \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\",\nvar_id: []\n}\n",
"Class {\nname: \"C\",\nancestors: [\"C\", \"B[bool]\", \"A[float]\"],\nfields: [\"a\", \"b\", \"c\", \"d\", \"e\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\"), (\"foo\", \"fn[[c:C], none]\")],\ntype_vars: []\n}\n", "Class {\nname: \"C\",\nancestors: [\"C\", \"B[bool]\", \"A[float]\"],\nfields: [\"a\", \"b\", \"c\", \"d\", \"e\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\"), (\"foo\", \"fn[[c:C], none]\")],\ntype_vars: []\n}\n",

View File

@ -5,8 +5,8 @@ expression: res_vec
[ [
"Function {\nname: \"foo\",\nsig: \"fn[[a:list[int32], b:tuple[T, float]], A[B, bool]]\",\nvar_id: []\n}\n", "Function {\nname: \"foo\",\nsig: \"fn[[a:list[int32], b:tuple[T, float]], A[B, bool]]\",\nvar_id: []\n}\n",
"Class {\nname: \"A\",\nancestors: [\"A[T, V]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[v:V], none]\"), (\"fun\", \"fn[[a:T], V]\")],\ntype_vars: [\"T\", \"V\"]\n}\n", "Class {\nname: \"A\",\nancestors: [\"A[T, V]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[v:V], none]\"), (\"fun\", \"fn[[a:T], V]\")],\ntype_vars: [\"T\", \"V\"]\n}\n",
"Function {\nname: \"A.__init__\",\nsig: \"fn[[v:V], none]\",\nvar_id: [TypeVarId(248)]\n}\n", "Function {\nname: \"A.__init__\",\nsig: \"fn[[v:V], none]\",\nvar_id: [TypeVarId(249)]\n}\n",
"Function {\nname: \"A.fun\",\nsig: \"fn[[a:T], V]\",\nvar_id: [TypeVarId(253)]\n}\n", "Function {\nname: \"A.fun\",\nsig: \"fn[[a:T], V]\",\nvar_id: [TypeVarId(254)]\n}\n",
"Function {\nname: \"gfun\",\nsig: \"fn[[a:A[list[float], int32]], none]\",\nvar_id: []\n}\n", "Function {\nname: \"gfun\",\nsig: \"fn[[a:A[list[float], int32]], none]\",\nvar_id: []\n}\n",
"Class {\nname: \"B\",\nancestors: [\"B\"],\nfields: [],\nmethods: [(\"__init__\", \"fn[[], none]\")],\ntype_vars: []\n}\n", "Class {\nname: \"B\",\nancestors: [\"B\"],\nfields: [],\nmethods: [(\"__init__\", \"fn[[], none]\")],\ntype_vars: []\n}\n",
"Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", "Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",

View File

@ -3,7 +3,7 @@ source: nac3core/src/toplevel/test.rs
expression: res_vec expression: res_vec
--- ---
[ [
"Class {\nname: \"A\",\nancestors: [\"A[typevar234, typevar235]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[a:A[float, bool], b:B], none]\"), (\"fun\", \"fn[[a:A[float, bool]], A[bool, int32]]\")],\ntype_vars: [\"typevar234\", \"typevar235\"]\n}\n", "Class {\nname: \"A\",\nancestors: [\"A[typevar235, typevar236]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[a:A[float, bool], b:B], none]\"), (\"fun\", \"fn[[a:A[float, bool]], A[bool, int32]]\")],\ntype_vars: [\"typevar235\", \"typevar236\"]\n}\n",
"Function {\nname: \"A.__init__\",\nsig: \"fn[[a:A[float, bool], b:B], none]\",\nvar_id: []\n}\n", "Function {\nname: \"A.__init__\",\nsig: \"fn[[a:A[float, bool], b:B], none]\",\nvar_id: []\n}\n",
"Function {\nname: \"A.fun\",\nsig: \"fn[[a:A[float, bool]], A[bool, int32]]\",\nvar_id: []\n}\n", "Function {\nname: \"A.fun\",\nsig: \"fn[[a:A[float, bool]], A[bool, int32]]\",\nvar_id: []\n}\n",
"Class {\nname: \"B\",\nancestors: [\"B\", \"A[int64, bool]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:A[float, bool]], A[bool, int32]]\"), (\"foo\", \"fn[[b:B], B]\"), (\"bar\", \"fn[[a:A[list[B], int32]], tuple[A[virtual[A[B, int32]], bool], B]]\")],\ntype_vars: []\n}\n", "Class {\nname: \"B\",\nancestors: [\"B\", \"A[int64, bool]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:A[float, bool]], A[bool, int32]]\"), (\"foo\", \"fn[[b:B], B]\"), (\"bar\", \"fn[[a:A[list[B], int32]], tuple[A[virtual[A[B, int32]], bool], B]]\")],\ntype_vars: []\n}\n",

View File

@ -6,12 +6,12 @@ expression: res_vec
"Class {\nname: \"A\",\nancestors: [\"A\"],\nfields: [\"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n", "Class {\nname: \"A\",\nancestors: [\"A\"],\nfields: [\"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n",
"Function {\nname: \"A.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", "Function {\nname: \"A.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
"Function {\nname: \"A.fun\",\nsig: \"fn[[b:B], none]\",\nvar_id: []\n}\n", "Function {\nname: \"A.fun\",\nsig: \"fn[[b:B], none]\",\nvar_id: []\n}\n",
"Function {\nname: \"A.foo\",\nsig: \"fn[[a:T, b:V], none]\",\nvar_id: [TypeVarId(254)]\n}\n", "Function {\nname: \"A.foo\",\nsig: \"fn[[a:T, b:V], none]\",\nvar_id: [TypeVarId(255)]\n}\n",
"Class {\nname: \"C\",\nancestors: [\"C\", \"A\"],\nfields: [\"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n", "Class {\nname: \"C\",\nancestors: [\"C\", \"A\"],\nfields: [\"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n",
"Function {\nname: \"C.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", "Function {\nname: \"C.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
"Function {\nname: \"C.fun\",\nsig: \"fn[[b:B], none]\",\nvar_id: []\n}\n", "Function {\nname: \"C.fun\",\nsig: \"fn[[b:B], none]\",\nvar_id: []\n}\n",
"Class {\nname: \"B\",\nancestors: [\"B\", \"C\", \"A\"],\nfields: [\"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n", "Class {\nname: \"B\",\nancestors: [\"B\", \"C\", \"A\"],\nfields: [\"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n",
"Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", "Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
"Function {\nname: \"foo\",\nsig: \"fn[[a:A], none]\",\nvar_id: []\n}\n", "Function {\nname: \"foo\",\nsig: \"fn[[a:A], none]\",\nvar_id: []\n}\n",
"Function {\nname: \"ff\",\nsig: \"fn[[a:T], V]\",\nvar_id: [TypeVarId(262)]\n}\n", "Function {\nname: \"ff\",\nsig: \"fn[[a:T], V]\",\nvar_id: [TypeVarId(263)]\n}\n",
] ]

View File

@ -68,6 +68,13 @@ def output_ndarray_float_2(n: ndarray[float, Literal[2]]):
for c in range(len(n[r])): for c in range(len(n[r])):
output_float64(n[r][c]) output_float64(n[r][c])
def output_ndarray_float_4(n: ndarray[float, Literal[4]]):
for x in range(len(n)):
for y in range(len(n[x])):
for z in range(len(n[x][y])):
for w in range(len(n[x][y][z])):
output_float64(n[x][y][z][w])
def consume_ndarray_1(n: ndarray[float, Literal[1]]): def consume_ndarray_1(n: ndarray[float, Literal[1]]):
pass pass
@ -197,6 +204,38 @@ def test_ndarray_nd_idx():
output_float64(x[1, 0]) output_float64(x[1, 0])
output_float64(x[1, 1]) output_float64(x[1, 1])
def test_ndarray_reshape():
w: ndarray[float, 1] = np_array([1., 2., 3., 4., 5., 6., 7., 8., 9., 10.])
x = np_reshape(w, (1, 2, 1, -1))
y = np_reshape(x, [2, -1])
z = np_reshape(y, 10)
output_int32(np_shape(w)[0])
output_ndarray_float_1(w)
output_int32(np_shape(x)[0])
output_int32(np_shape(x)[1])
output_int32(np_shape(x)[2])
output_int32(np_shape(x)[3])
output_ndarray_float_4(x)
output_int32(np_shape(y)[0])
output_int32(np_shape(y)[1])
output_ndarray_float_2(y)
output_int32(np_shape(z)[0])
output_ndarray_float_1(z)
x1: ndarray[int32, 1] = np_array([1, 2, 3, 4])
x2: ndarray[int32, 2] = np_reshape(x1, (2, 2))
output_int32(np_shape(x1)[0])
output_ndarray_int32_1(x1)
output_int32(np_shape(x2)[0])
output_int32(np_shape(x2)[1])
output_ndarray_int32_2(x2)
def test_ndarray_add(): def test_ndarray_add():
x = np_identity(2) x = np_identity(2)
y = x + np_ones([2, 2]) y = x + np_ones([2, 2])
@ -1448,19 +1487,6 @@ def test_ndarray_transpose():
output_ndarray_float_2(x) output_ndarray_float_2(x)
output_ndarray_float_2(y) output_ndarray_float_2(y)
def test_ndarray_reshape():
w: ndarray[float, 1] = np_array([1., 2., 3., 4., 5., 6., 7., 8., 9., 10.])
x = np_reshape(w, (1, 2, 1, -1))
y = np_reshape(x, [2, -1])
z = np_reshape(y, 10)
x1: ndarray[int32, 1] = np_array([1, 2, 3, 4])
x2: ndarray[int32, 2] = np_reshape(x1, (2, 2))
output_ndarray_float_1(w)
output_ndarray_float_2(y)
output_ndarray_float_1(z)
def test_ndarray_dot(): def test_ndarray_dot():
x1: ndarray[float, 1] = np_array([5.0, 1.0, 4.0, 2.0]) x1: ndarray[float, 1] = np_array([5.0, 1.0, 4.0, 2.0])
y1: ndarray[float, 1] = np_array([5.0, 1.0, 6.0, 6.0]) y1: ndarray[float, 1] = np_array([5.0, 1.0, 6.0, 6.0])
@ -1592,6 +1618,8 @@ def run() -> int32:
test_ndarray_slices() test_ndarray_slices()
test_ndarray_nd_idx() test_ndarray_nd_idx()
test_ndarray_reshape()
test_ndarray_add() test_ndarray_add()
test_ndarray_add_broadcast() test_ndarray_add_broadcast()
test_ndarray_add_broadcast_lhs_scalar() test_ndarray_add_broadcast_lhs_scalar()
@ -1756,7 +1784,6 @@ def run() -> int32:
test_ndarray_nextafter_broadcast_lhs_scalar() test_ndarray_nextafter_broadcast_lhs_scalar()
test_ndarray_nextafter_broadcast_rhs_scalar() test_ndarray_nextafter_broadcast_rhs_scalar()
test_ndarray_transpose() test_ndarray_transpose()
test_ndarray_reshape()
test_ndarray_dot() test_ndarray_dot()
test_ndarray_cholesky() test_ndarray_cholesky()