From 8d975b5ff3ecf04c71ba7971aac8c1f7c91dab81 Mon Sep 17 00:00:00 2001 From: David Mak Date: Wed, 18 Dec 2024 11:40:23 +0800 Subject: [PATCH] [core] codegen/ndarray: Implement np_reshape Based on 926e7e93: core/ndstrides: implement np_reshape() --- nac3core/irrt/irrt.cpp | 3 +- nac3core/irrt/irrt/ndarray/reshape.hpp | 97 +++++ nac3core/src/codegen/irrt/ndarray/mod.rs | 2 + nac3core/src/codegen/irrt/ndarray/reshape.rs | 40 +++ nac3core/src/codegen/numpy.rs | 334 +----------------- nac3core/src/codegen/values/ndarray/view.rs | 77 +++- nac3core/src/toplevel/builtins.rs | 59 +++- ...el__test__test_analyze__generic_class.snap | 2 +- ...t__test_analyze__inheritance_override.snap | 2 +- ...est__test_analyze__list_tuple_generic.snap | 4 +- ...__toplevel__test__test_analyze__self1.snap | 2 +- ...t__test_analyze__simple_class_compose.snap | 4 +- nac3standalone/demo/src/ndarray.py | 55 ++- 13 files changed, 308 insertions(+), 373 deletions(-) create mode 100644 nac3core/irrt/irrt/ndarray/reshape.hpp create mode 100644 nac3core/src/codegen/irrt/ndarray/reshape.rs diff --git a/nac3core/irrt/irrt.cpp b/nac3core/irrt/irrt.cpp index 57f60d52f..bb7fb3d4d 100644 --- a/nac3core/irrt/irrt.cpp +++ b/nac3core/irrt/irrt.cpp @@ -9,4 +9,5 @@ #include "irrt/ndarray/def.hpp" #include "irrt/ndarray/iter.hpp" #include "irrt/ndarray/indexing.hpp" -#include "irrt/ndarray/array.hpp" \ No newline at end of file +#include "irrt/ndarray/array.hpp" +#include "irrt/ndarray/reshape.hpp" \ No newline at end of file diff --git a/nac3core/irrt/irrt/ndarray/reshape.hpp b/nac3core/irrt/irrt/ndarray/reshape.hpp new file mode 100644 index 000000000..b2ad2a5b5 --- /dev/null +++ b/nac3core/irrt/irrt/ndarray/reshape.hpp @@ -0,0 +1,97 @@ +#pragma once + +#include "irrt/exception.hpp" +#include "irrt/int_types.hpp" +#include "irrt/ndarray/def.hpp" + +namespace { +namespace ndarray::reshape { +/** + * @brief Perform assertions on and resolve unknown dimensions in `new_shape` in `np.reshape(, new_shape)` + * + * If `new_shape` indeed contains unknown dimensions (specified with `-1`, just like numpy), `new_shape` will be + * modified to contain the resolved dimension. + * + * To perform assertions on and resolve unknown dimensions in `new_shape`, we don't need the actual + * `` object itself, but only the `.size` of the ``. + * + * @param size The `.size` of `` + * @param new_ndims Number of elements in `new_shape` + * @param new_shape Target shape to reshape to + */ +template +void resolve_and_check_new_shape(SizeT size, SizeT new_ndims, SizeT* new_shape) { + // Is there a -1 in `new_shape`? + bool neg1_exists = false; + // Location of -1, only initialized if `neg1_exists` is true + SizeT neg1_axis_i; + // The computed ndarray size of `new_shape` + SizeT new_size = 1; + + for (SizeT axis_i = 0; axis_i < new_ndims; axis_i++) { + SizeT dim = new_shape[axis_i]; + if (dim < 0) { + if (dim == -1) { + if (neg1_exists) { + // Multiple `-1` found. Throw an error. + raise_exception(SizeT, EXN_VALUE_ERROR, "can only specify one unknown dimension", NO_PARAM, + NO_PARAM, NO_PARAM); + } else { + neg1_exists = true; + neg1_axis_i = axis_i; + } + } else { + // TODO: What? In `np.reshape` any negative dimensions is + // treated like its `-1`. + // + // Try running `np.zeros((3, 4)).reshape((-999, 2))` + // + // It is not documented by numpy. + // Throw an error for now... + + raise_exception(SizeT, EXN_VALUE_ERROR, "Found non -1 negative dimension {0} on axis {1}", dim, axis_i, + NO_PARAM); + } + } else { + new_size *= dim; + } + } + + bool can_reshape; + if (neg1_exists) { + // Let `x` be the unknown dimension + // Solve `x * = ` + if (new_size == 0 && size == 0) { + // `x` has infinitely many solutions + can_reshape = false; + } else if (new_size == 0 && size != 0) { + // `x` has no solutions + can_reshape = false; + } else if (size % new_size != 0) { + // `x` has no integer solutions + can_reshape = false; + } else { + can_reshape = true; + new_shape[neg1_axis_i] = size / new_size; // Resolve dimension + } + } else { + can_reshape = (new_size == size); + } + + if (!can_reshape) { + raise_exception(SizeT, EXN_VALUE_ERROR, "cannot reshape array of size {0} into given shape", size, NO_PARAM, + NO_PARAM); + } +} +} // namespace ndarray::reshape +} // namespace + +extern "C" { +void __nac3_ndarray_reshape_resolve_and_check_new_shape(int32_t size, int32_t new_ndims, int32_t* new_shape) { + ndarray::reshape::resolve_and_check_new_shape(size, new_ndims, new_shape); +} + +void __nac3_ndarray_reshape_resolve_and_check_new_shape64(int64_t size, int64_t new_ndims, int64_t* new_shape) { + ndarray::reshape::resolve_and_check_new_shape(size, new_ndims, new_shape); +} +} diff --git a/nac3core/src/codegen/irrt/ndarray/mod.rs b/nac3core/src/codegen/irrt/ndarray/mod.rs index 307ec6bbd..f67566b9c 100644 --- a/nac3core/src/codegen/irrt/ndarray/mod.rs +++ b/nac3core/src/codegen/irrt/ndarray/mod.rs @@ -20,11 +20,13 @@ pub use array::*; pub use basic::*; pub use indexing::*; pub use iter::*; +pub use reshape::*; mod array; mod basic; mod indexing; mod iter; +mod reshape; /// Generates a call to `__nac3_ndarray_calc_size`. Returns a /// [`usize`][CodeGenerator::get_size_type] representing the calculated total size. diff --git a/nac3core/src/codegen/irrt/ndarray/reshape.rs b/nac3core/src/codegen/irrt/ndarray/reshape.rs new file mode 100644 index 000000000..32de2fa15 --- /dev/null +++ b/nac3core/src/codegen/irrt/ndarray/reshape.rs @@ -0,0 +1,40 @@ +use inkwell::values::IntValue; + +use crate::codegen::{ + expr::infer_and_call_function, + irrt::get_usize_dependent_function_name, + values::{ArrayLikeValue, ArraySliceValue}, + CodeGenContext, CodeGenerator, +}; + +/// Generates a call to `__nac3_ndarray_reshape_resolve_and_check_new_shape`. +/// +/// Resolves unknown dimensions in `new_shape` for `numpy.reshape(, new_shape)`, raising an +/// assertion if multiple dimensions are unknown (`-1`). +pub fn call_nac3_ndarray_reshape_resolve_and_check_new_shape<'ctx, G: CodeGenerator + ?Sized>( + generator: &G, + ctx: &CodeGenContext<'ctx, '_>, + size: IntValue<'ctx>, + new_ndims: IntValue<'ctx>, + new_shape: ArraySliceValue<'ctx>, +) { + let llvm_usize = generator.get_size_type(ctx.ctx); + + assert_eq!(size.get_type(), llvm_usize); + assert_eq!(new_ndims.get_type(), llvm_usize); + assert_eq!(new_shape.element_type(ctx, generator), llvm_usize.into()); + + let name = get_usize_dependent_function_name( + generator, + ctx, + "__nac3_ndarray_reshape_resolve_and_check_new_shape", + ); + infer_and_call_function( + ctx, + &name, + None, + &[size.into(), new_ndims.into(), new_shape.base_ptr(ctx, generator).into()], + None, + None, + ); +} diff --git a/nac3core/src/codegen/numpy.rs b/nac3core/src/codegen/numpy.rs index 2f899f95f..ce4e20598 100644 --- a/nac3core/src/codegen/numpy.rs +++ b/nac3core/src/codegen/numpy.rs @@ -21,9 +21,9 @@ use super::{ types::ndarray::{factory::ndarray_zero_value, NDArrayType}, values::{ ndarray::{shape::parse_numpy_int_sequence, NDArrayValue}, - ArrayLikeIndexer, ArrayLikeValue, ArraySliceValue, ListValue, ProxyValue, - TypedArrayLikeAccessor, TypedArrayLikeAdapter, TypedArrayLikeMutator, - UntypedArrayLikeAccessor, UntypedArrayLikeMutator, + ArrayLikeIndexer, ArrayLikeValue, ArraySliceValue, ProxyValue, TypedArrayLikeAccessor, + TypedArrayLikeAdapter, TypedArrayLikeMutator, UntypedArrayLikeAccessor, + UntypedArrayLikeMutator, }, CodeGenContext, CodeGenerator, }; @@ -134,46 +134,6 @@ where Ok(ndarray) } -/// Creates an `NDArray` instance from a constant shape. -/// -/// * `elem_ty` - The element type of the `NDArray`. -/// * `shape` - The shape of the `NDArray`, represented am array of [`IntValue`]s. -pub fn create_ndarray_const_shape<'ctx, G: CodeGenerator + ?Sized>( - generator: &mut G, - ctx: &mut CodeGenContext<'ctx, '_>, - elem_ty: Type, - shape: &[IntValue<'ctx>], -) -> Result, String> { - let llvm_usize = generator.get_size_type(ctx.ctx); - - for &shape_dim in shape { - let shape_dim = ctx.builder.build_int_z_extend(shape_dim, llvm_usize, "").unwrap(); - let shape_dim_gez = ctx - .builder - .build_int_compare(IntPredicate::SGE, shape_dim, llvm_usize.const_zero(), "") - .unwrap(); - - ctx.make_assert( - generator, - shape_dim_gez, - "0:ValueError", - "negative dimensions not supported", - [None, None, None], - ctx.current_loc, - ); - - // TODO: Disallow shape > u32_MAX - } - - let llvm_dtype = ctx.get_llvm_type(generator, elem_ty); - - let ndarray = NDArrayType::new(generator, ctx.ctx, llvm_dtype, Some(shape.len() as u64)) - .construct_dyn_shape(generator, ctx, shape, None); - unsafe { ndarray.create_data(generator, ctx) }; - - Ok(ndarray) -} - /// Generates LLVM IR for populating the entire `NDArray` using a lambda with its flattened index as /// its input. fn ndarray_fill_flattened<'ctx, 'a, G, ValueFn>( @@ -1455,294 +1415,6 @@ pub fn ndarray_transpose<'ctx, G: CodeGenerator + ?Sized>( } } -/// LLVM-typed implementation for generating the implementation for `ndarray.reshape`. -/// -/// * `x1` - `NDArray` to reshape. -/// * `shape` - The `shape` parameter used to construct the new `NDArray`. -/// Just like numpy, the `shape` argument can be: -/// 1. A list of `int32`; e.g., `np.reshape(arr, [600, -1, 3])` -/// 2. A tuple of `int32`; e.g., `np.reshape(arr, (-1, 800, 3))` -/// 3. A scalar `int32`; e.g., `np.reshape(arr, 3)` -/// -/// Note that unlike other generating functions, one of the dimensions in the shape can be negative. -pub fn ndarray_reshape<'ctx, G: CodeGenerator + ?Sized>( - generator: &mut G, - ctx: &mut CodeGenContext<'ctx, '_>, - x1: (Type, BasicValueEnum<'ctx>), - shape: (Type, BasicValueEnum<'ctx>), -) -> Result, String> { - const FN_NAME: &str = "ndarray_reshape"; - let (x1_ty, x1) = x1; - let (_, shape) = shape; - - let llvm_usize = generator.get_size_type(ctx.ctx); - - if let BasicValueEnum::PointerValue(n1) = x1 { - let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty); - let llvm_ndarray_ty = NDArrayType::from_unifier_type(generator, ctx, x1_ty); - let n1 = llvm_ndarray_ty.map_value(n1, None); - let n_sz = n1.size(generator, ctx); - - let acc = generator.gen_var_alloc(ctx, llvm_usize.into(), None)?; - let num_neg = generator.gen_var_alloc(ctx, llvm_usize.into(), None)?; - ctx.builder.build_store(acc, llvm_usize.const_int(1, false)).unwrap(); - ctx.builder.build_store(num_neg, llvm_usize.const_zero()).unwrap(); - - let out = match shape { - BasicValueEnum::PointerValue(shape_list_ptr) - if ListValue::is_representable(shape_list_ptr, llvm_usize).is_ok() => - { - // 1. A list of ints; e.g., `np.reshape(arr, [int64(600), int64(800, -1])` - - let shape_list = ListValue::from_pointer_value(shape_list_ptr, llvm_usize, None); - // Check for -1 in dimensions - gen_for_callback_incrementing( - generator, - ctx, - None, - llvm_usize.const_zero(), - (shape_list.load_size(ctx, None), false), - |generator, ctx, _, idx| { - let ele = - shape_list.data().get(ctx, generator, &idx, None).into_int_value(); - let ele = ctx.builder.build_int_s_extend(ele, llvm_usize, "").unwrap(); - - gen_if_else_expr_callback( - generator, - ctx, - |_, ctx| { - Ok(ctx - .builder - .build_int_compare( - IntPredicate::SLT, - ele, - llvm_usize.const_zero(), - "", - ) - .unwrap()) - }, - |_, ctx| -> Result, String> { - let num_neg_value = - ctx.builder.build_load(num_neg, "").unwrap().into_int_value(); - let num_neg_value = ctx - .builder - .build_int_add( - num_neg_value, - llvm_usize.const_int(1, false), - "", - ) - .unwrap(); - ctx.builder.build_store(num_neg, num_neg_value).unwrap(); - Ok(None) - }, - |_, ctx| { - let acc_value = - ctx.builder.build_load(acc, "").unwrap().into_int_value(); - let acc_value = - ctx.builder.build_int_mul(acc_value, ele, "").unwrap(); - ctx.builder.build_store(acc, acc_value).unwrap(); - Ok(None) - }, - )?; - Ok(()) - }, - llvm_usize.const_int(1, false), - )?; - let acc_val = ctx.builder.build_load(acc, "").unwrap().into_int_value(); - let rem = ctx.builder.build_int_unsigned_div(n_sz, acc_val, "").unwrap(); - // Generate the output shape by filling -1 with `rem` - create_ndarray_dyn_shape( - generator, - ctx, - elem_ty, - &shape_list, - |_, ctx, _| Ok(shape_list.load_size(ctx, None)), - |generator, ctx, shape_list, idx| { - let dim = - shape_list.data().get(ctx, generator, &idx, None).into_int_value(); - let dim = ctx.builder.build_int_s_extend(dim, llvm_usize, "").unwrap(); - - Ok(gen_if_else_expr_callback( - generator, - ctx, - |_, ctx| { - Ok(ctx - .builder - .build_int_compare( - IntPredicate::SLT, - dim, - llvm_usize.const_zero(), - "", - ) - .unwrap()) - }, - |_, _| Ok(Some(rem)), - |_, _| Ok(Some(dim)), - )? - .unwrap() - .into_int_value()) - }, - ) - } - BasicValueEnum::StructValue(shape_tuple) => { - // 2. A tuple of `int32`; e.g., `np.reshape(arr, (-1, 800, 3))` - - let ndims = shape_tuple.get_type().count_fields(); - // Check for -1 in dims - for dim_i in 0..ndims { - let dim = ctx - .builder - .build_extract_value(shape_tuple, dim_i, "") - .unwrap() - .into_int_value(); - let dim = ctx.builder.build_int_s_extend(dim, llvm_usize, "").unwrap(); - - gen_if_else_expr_callback( - generator, - ctx, - |_, ctx| { - Ok(ctx - .builder - .build_int_compare( - IntPredicate::SLT, - dim, - llvm_usize.const_zero(), - "", - ) - .unwrap()) - }, - |_, ctx| -> Result, String> { - let num_negs = - ctx.builder.build_load(num_neg, "").unwrap().into_int_value(); - let num_negs = ctx - .builder - .build_int_add(num_negs, llvm_usize.const_int(1, false), "") - .unwrap(); - ctx.builder.build_store(num_neg, num_negs).unwrap(); - Ok(None) - }, - |_, ctx| { - let acc_val = ctx.builder.build_load(acc, "").unwrap().into_int_value(); - let acc_val = ctx.builder.build_int_mul(acc_val, dim, "").unwrap(); - ctx.builder.build_store(acc, acc_val).unwrap(); - Ok(None) - }, - )?; - } - - let acc_val = ctx.builder.build_load(acc, "").unwrap().into_int_value(); - let rem = ctx.builder.build_int_unsigned_div(n_sz, acc_val, "").unwrap(); - let mut shape = Vec::with_capacity(ndims as usize); - - // Reconstruct shape filling negatives with rem - for dim_i in 0..ndims { - let dim = ctx - .builder - .build_extract_value(shape_tuple, dim_i, "") - .unwrap() - .into_int_value(); - let dim = ctx.builder.build_int_s_extend(dim, llvm_usize, "").unwrap(); - - let dim = gen_if_else_expr_callback( - generator, - ctx, - |_, ctx| { - Ok(ctx - .builder - .build_int_compare( - IntPredicate::SLT, - dim, - llvm_usize.const_zero(), - "", - ) - .unwrap()) - }, - |_, _| Ok(Some(rem)), - |_, _| Ok(Some(dim)), - )? - .unwrap() - .into_int_value(); - shape.push(dim); - } - create_ndarray_const_shape(generator, ctx, elem_ty, shape.as_slice()) - } - BasicValueEnum::IntValue(shape_int) => { - // 3. A scalar `int32`; e.g., `np.reshape(arr, 3)` - let shape_int = gen_if_else_expr_callback( - generator, - ctx, - |_, ctx| { - Ok(ctx - .builder - .build_int_compare( - IntPredicate::SLT, - shape_int, - llvm_usize.const_zero(), - "", - ) - .unwrap()) - }, - |_, _| Ok(Some(n_sz)), - |_, ctx| { - Ok(Some(ctx.builder.build_int_s_extend(shape_int, llvm_usize, "").unwrap())) - }, - )? - .unwrap() - .into_int_value(); - create_ndarray_const_shape(generator, ctx, elem_ty, &[shape_int]) - } - _ => codegen_unreachable!(ctx), - } - .unwrap(); - - // Only allow one dimension to be negative - let num_negs = ctx.builder.build_load(num_neg, "").unwrap().into_int_value(); - ctx.make_assert( - generator, - ctx.builder - .build_int_compare(IntPredicate::ULT, num_negs, llvm_usize.const_int(2, false), "") - .unwrap(), - "0:ValueError", - "can only specify one unknown dimension", - [None, None, None], - ctx.current_loc, - ); - - // The new shape must be compatible with the old shape - let out_sz = out.size(generator, ctx); - ctx.make_assert( - generator, - ctx.builder.build_int_compare(IntPredicate::EQ, out_sz, n_sz, "").unwrap(), - "0:ValueError", - "cannot reshape array of size {0} into provided shape of size {1}", - [Some(n_sz), Some(out_sz), None], - ctx.current_loc, - ); - - gen_for_callback_incrementing( - generator, - ctx, - None, - llvm_usize.const_zero(), - (n_sz, false), - |generator, ctx, _, idx| { - let elem = unsafe { n1.data().get_unchecked(ctx, generator, &idx, None) }; - unsafe { out.data().set_unchecked(ctx, generator, &idx, elem) }; - Ok(()) - }, - llvm_usize.const_int(1, false), - )?; - - Ok(out.as_base_value().into()) - } else { - codegen_unreachable!( - ctx, - "{FN_NAME}() not supported for '{}'", - format!("'{}'", ctx.unifier.stringify(x1_ty)) - ) - } -} - /// Generates LLVM IR for `ndarray.dot`. /// Calculate inner product of two vectors or literals /// For matrix multiplication use `np_matmul` diff --git a/nac3core/src/codegen/values/ndarray/view.rs b/nac3core/src/codegen/values/ndarray/view.rs index 70a9d6590..8334d379c 100644 --- a/nac3core/src/codegen/values/ndarray/view.rs +++ b/nac3core/src/codegen/values/ndarray/view.rs @@ -1,9 +1,16 @@ use std::iter::{once, repeat_n}; +use inkwell::values::IntValue; use itertools::Itertools; use crate::codegen::{ - values::ndarray::{NDArrayValue, RustNDIndex}, + irrt, + stmt::gen_if_callback, + types::ndarray::NDArrayType, + values::{ + ndarray::{NDArrayValue, RustNDIndex}, + ArrayLikeValue, ProxyValue, TypedArrayLikeAccessor, + }, CodeGenContext, CodeGenerator, }; @@ -33,4 +40,72 @@ impl<'ctx> NDArrayValue<'ctx> { *self } } + + /// Create a reshaped view on this ndarray like + /// [`np.reshape()`](https://numpy.org/doc/stable/reference/generated/numpy.reshape.html). + /// + /// If there is a `-1` in `new_shape`, it will be resolved; `new_shape` would **NOT** be + /// modified as a result. + /// + /// If reshape without copying is impossible, this function will allocate a new ndarray and copy + /// contents. + /// + /// * `new_ndims` - The number of dimensions of `new_shape` as a [`Type`]. + /// * `new_shape` - The target shape to do `np.reshape()`. + #[must_use] + pub fn reshape_or_copy( + &self, + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + new_ndims: u64, + new_shape: &impl TypedArrayLikeAccessor<'ctx, G, IntValue<'ctx>>, + ) -> Self { + assert_eq!(new_shape.element_type(ctx, generator), self.llvm_usize.into()); + + // TODO: The current criterion for whether to do a full copy or not is by checking + // `is_c_contiguous`, but this is not optimal - there are cases when the ndarray is + // not contiguous but could be reshaped without copying data. Look into how numpy does + // it. + + let dst_ndarray = NDArrayType::new(generator, ctx.ctx, self.dtype, Some(new_ndims)) + .construct_uninitialized(generator, ctx, None); + dst_ndarray.copy_shape_from_array(generator, ctx, new_shape.base_ptr(ctx, generator)); + + // Resolve negative indices + let size = self.size(generator, ctx); + let dst_ndims = self.llvm_usize.const_int(dst_ndarray.get_type().ndims().unwrap(), false); + let dst_shape = dst_ndarray.shape(); + irrt::ndarray::call_nac3_ndarray_reshape_resolve_and_check_new_shape( + generator, + ctx, + size, + dst_ndims, + dst_shape.as_slice_value(ctx, generator), + ); + + gen_if_callback( + generator, + ctx, + |generator, ctx| Ok(self.is_c_contiguous(generator, ctx)), + |generator, ctx| { + // Reshape is possible without copying + dst_ndarray.set_strides_contiguous(generator, ctx); + dst_ndarray.store_data(ctx, self.data().base_ptr(ctx, generator)); + + Ok(()) + }, + |generator, ctx| { + // Reshape is impossible without copying + unsafe { + dst_ndarray.create_data(generator, ctx); + } + dst_ndarray.copy_data_from(generator, ctx, *self); + + Ok(()) + }, + ) + .unwrap(); + + dst_ndarray + } } diff --git a/nac3core/src/toplevel/builtins.rs b/nac3core/src/toplevel/builtins.rs index 276c00c7d..db7acaf34 100644 --- a/nac3core/src/toplevel/builtins.rs +++ b/nac3core/src/toplevel/builtins.rs @@ -5,8 +5,10 @@ use inkwell::{values::BasicValue, IntPredicate}; use strum::IntoEnumIterator; use super::{ - helper::{debug_assert_prim_is_allowed, make_exception_fields, PrimDef, PrimDefDetails}, - numpy::make_ndarray_ty, + helper::{ + debug_assert_prim_is_allowed, extract_ndims, make_exception_fields, PrimDef, PrimDefDetails, + }, + numpy::{make_ndarray_ty, unpack_ndarray_var_tys}, *, }; use crate::{ @@ -15,7 +17,7 @@ use crate::{ numpy::*, stmt::exn_constructor, types::ndarray::NDArrayType, - values::{ProxyValue, RangeValue}, + values::{ndarray::shape::parse_numpy_int_sequence, ProxyValue, RangeValue}, }, symbol_resolver::SymbolValue, typecheck::typedef::{into_var_map, iter_type_vars, TypeVar, VarMap}, @@ -193,7 +195,6 @@ struct BuiltinBuilder<'a> { ndarray_float: Type, ndarray_float_2d: Type, - ndarray_num_ty: Type, float_or_ndarray_ty: TypeVar, float_or_ndarray_var_map: VarMap, @@ -307,7 +308,6 @@ impl<'a> BuiltinBuilder<'a> { ndarray_float, ndarray_float_2d, - ndarray_num_ty, float_or_ndarray_ty, float_or_ndarray_var_map, @@ -1356,20 +1356,41 @@ impl<'a> BuiltinBuilder<'a> { // Similar to `build_ndarray_from_shape_factory_function` we delegate the responsibility of typechecking // to [`typecheck::type_inferencer::Inferencer::fold_numpy_function_call_shape_argument`], // and use a dummy [`TypeVar`] `ndarray_factory_fn_shape_arg_tvar` as a placeholder for `param_ty`. - PrimDef::FunNpReshape => create_fn_by_codegen( - self.unifier, - &VarMap::new(), - prim.name(), - self.ndarray_num_ty, - &[(self.ndarray_num_ty, "x"), (self.ndarray_factory_fn_shape_arg_tvar.ty, "shape")], - Box::new(move |ctx, _, fun, args, generator| { - let x1_ty = fun.0.args[0].ty; - let x1_val = args[0].1.clone().to_basic_value_enum(ctx, generator, x1_ty)?; - let x2_ty = fun.0.args[1].ty; - let x2_val = args[1].1.clone().to_basic_value_enum(ctx, generator, x2_ty)?; - Ok(Some(ndarray_reshape(generator, ctx, (x1_ty, x1_val), (x2_ty, x2_val))?)) - }), - ), + PrimDef::FunNpReshape => { + let ret_ty = self.unifier.get_dummy_var().ty; // Handled by special holding + + create_fn_by_codegen( + self.unifier, + &VarMap::new(), + prim.name(), + ret_ty, + &[ + (in_ndarray_ty.ty, "x"), + (self.ndarray_factory_fn_shape_arg_tvar.ty, "shape"), // Handled by special folding + ], + Box::new(move |ctx, _, fun, args, generator| { + let ndarray_ty = fun.0.args[0].ty; + let ndarray_val = + args[0].1.clone().to_basic_value_enum(ctx, generator, ndarray_ty)?; + + let shape_ty = fun.0.args[1].ty; + let shape_val = + args[1].1.clone().to_basic_value_enum(ctx, generator, shape_ty)?; + + let ndarray = NDArrayType::from_unifier_type(generator, ctx, ndarray_ty) + .map_value(ndarray_val.into_pointer_value(), None); + + let shape = parse_numpy_int_sequence(generator, ctx, (shape_ty, shape_val)); + + // The ndims after reshaping is gotten from the return type of the call. + let (_, ndims) = unpack_ndarray_var_tys(&mut ctx.unifier, fun.0.ret); + let ndims = extract_ndims(&ctx.unifier, ndims); + + let new_ndarray = ndarray.reshape_or_copy(generator, ctx, ndims, &shape); + Ok(Some(new_ndarray.as_base_value().as_basic_value_enum())) + }), + ) + } _ => unreachable!(), } diff --git a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__generic_class.snap b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__generic_class.snap index b03b3616c..f53498907 100644 --- a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__generic_class.snap +++ b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__generic_class.snap @@ -8,5 +8,5 @@ expression: res_vec "Function {\nname: \"B.foo\",\nsig: \"fn[[b:T], none]\",\nvar_id: []\n}\n", "Class {\nname: \"Generic_A\",\nancestors: [\"Generic_A[V]\", \"B\"],\nfields: [\"aa\", \"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"foo\", \"fn[[b:T], none]\"), (\"fun\", \"fn[[a:int32], V]\")],\ntype_vars: [\"V\"]\n}\n", "Function {\nname: \"Generic_A.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", - "Function {\nname: \"Generic_A.fun\",\nsig: \"fn[[a:int32], V]\",\nvar_id: [TypeVarId(251)]\n}\n", + "Function {\nname: \"Generic_A.fun\",\nsig: \"fn[[a:int32], V]\",\nvar_id: [TypeVarId(252)]\n}\n", ] diff --git a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__inheritance_override.snap b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__inheritance_override.snap index b4df49c9c..054086834 100644 --- a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__inheritance_override.snap +++ b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__inheritance_override.snap @@ -7,7 +7,7 @@ expression: res_vec "Function {\nname: \"A.__init__\",\nsig: \"fn[[t:T], none]\",\nvar_id: []\n}\n", "Function {\nname: \"A.fun\",\nsig: \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\",\nvar_id: []\n}\n", "Function {\nname: \"A.foo\",\nsig: \"fn[[c:C], none]\",\nvar_id: []\n}\n", - "Class {\nname: \"B\",\nancestors: [\"B[typevar235]\", \"A[float]\"],\nfields: [\"a\", \"b\", \"c\", \"d\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\"), (\"foo\", \"fn[[c:C], none]\")],\ntype_vars: [\"typevar235\"]\n}\n", + "Class {\nname: \"B\",\nancestors: [\"B[typevar236]\", \"A[float]\"],\nfields: [\"a\", \"b\", \"c\", \"d\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\"), (\"foo\", \"fn[[c:C], none]\")],\ntype_vars: [\"typevar236\"]\n}\n", "Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", "Function {\nname: \"B.fun\",\nsig: \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\",\nvar_id: []\n}\n", "Class {\nname: \"C\",\nancestors: [\"C\", \"B[bool]\", \"A[float]\"],\nfields: [\"a\", \"b\", \"c\", \"d\", \"e\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\"), (\"foo\", \"fn[[c:C], none]\")],\ntype_vars: []\n}\n", diff --git a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__list_tuple_generic.snap b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__list_tuple_generic.snap index 65a6a8ac8..029815aa2 100644 --- a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__list_tuple_generic.snap +++ b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__list_tuple_generic.snap @@ -5,8 +5,8 @@ expression: res_vec [ "Function {\nname: \"foo\",\nsig: \"fn[[a:list[int32], b:tuple[T, float]], A[B, bool]]\",\nvar_id: []\n}\n", "Class {\nname: \"A\",\nancestors: [\"A[T, V]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[v:V], none]\"), (\"fun\", \"fn[[a:T], V]\")],\ntype_vars: [\"T\", \"V\"]\n}\n", - "Function {\nname: \"A.__init__\",\nsig: \"fn[[v:V], none]\",\nvar_id: [TypeVarId(248)]\n}\n", - "Function {\nname: \"A.fun\",\nsig: \"fn[[a:T], V]\",\nvar_id: [TypeVarId(253)]\n}\n", + "Function {\nname: \"A.__init__\",\nsig: \"fn[[v:V], none]\",\nvar_id: [TypeVarId(249)]\n}\n", + "Function {\nname: \"A.fun\",\nsig: \"fn[[a:T], V]\",\nvar_id: [TypeVarId(254)]\n}\n", "Function {\nname: \"gfun\",\nsig: \"fn[[a:A[list[float], int32]], none]\",\nvar_id: []\n}\n", "Class {\nname: \"B\",\nancestors: [\"B\"],\nfields: [],\nmethods: [(\"__init__\", \"fn[[], none]\")],\ntype_vars: []\n}\n", "Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", diff --git a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__self1.snap b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__self1.snap index cfedf1f6b..fc8f1aaf7 100644 --- a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__self1.snap +++ b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__self1.snap @@ -3,7 +3,7 @@ source: nac3core/src/toplevel/test.rs expression: res_vec --- [ - "Class {\nname: \"A\",\nancestors: [\"A[typevar234, typevar235]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[a:A[float, bool], b:B], none]\"), (\"fun\", \"fn[[a:A[float, bool]], A[bool, int32]]\")],\ntype_vars: [\"typevar234\", \"typevar235\"]\n}\n", + "Class {\nname: \"A\",\nancestors: [\"A[typevar235, typevar236]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[a:A[float, bool], b:B], none]\"), (\"fun\", \"fn[[a:A[float, bool]], A[bool, int32]]\")],\ntype_vars: [\"typevar235\", \"typevar236\"]\n}\n", "Function {\nname: \"A.__init__\",\nsig: \"fn[[a:A[float, bool], b:B], none]\",\nvar_id: []\n}\n", "Function {\nname: \"A.fun\",\nsig: \"fn[[a:A[float, bool]], A[bool, int32]]\",\nvar_id: []\n}\n", "Class {\nname: \"B\",\nancestors: [\"B\", \"A[int64, bool]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:A[float, bool]], A[bool, int32]]\"), (\"foo\", \"fn[[b:B], B]\"), (\"bar\", \"fn[[a:A[list[B], int32]], tuple[A[virtual[A[B, int32]], bool], B]]\")],\ntype_vars: []\n}\n", diff --git a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__simple_class_compose.snap b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__simple_class_compose.snap index 9a77a21c4..34e78a2bb 100644 --- a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__simple_class_compose.snap +++ b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__simple_class_compose.snap @@ -6,12 +6,12 @@ expression: res_vec "Class {\nname: \"A\",\nancestors: [\"A\"],\nfields: [\"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n", "Function {\nname: \"A.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", "Function {\nname: \"A.fun\",\nsig: \"fn[[b:B], none]\",\nvar_id: []\n}\n", - "Function {\nname: \"A.foo\",\nsig: \"fn[[a:T, b:V], none]\",\nvar_id: [TypeVarId(254)]\n}\n", + "Function {\nname: \"A.foo\",\nsig: \"fn[[a:T, b:V], none]\",\nvar_id: [TypeVarId(255)]\n}\n", "Class {\nname: \"C\",\nancestors: [\"C\", \"A\"],\nfields: [\"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n", "Function {\nname: \"C.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", "Function {\nname: \"C.fun\",\nsig: \"fn[[b:B], none]\",\nvar_id: []\n}\n", "Class {\nname: \"B\",\nancestors: [\"B\", \"C\", \"A\"],\nfields: [\"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n", "Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", "Function {\nname: \"foo\",\nsig: \"fn[[a:A], none]\",\nvar_id: []\n}\n", - "Function {\nname: \"ff\",\nsig: \"fn[[a:T], V]\",\nvar_id: [TypeVarId(262)]\n}\n", + "Function {\nname: \"ff\",\nsig: \"fn[[a:T], V]\",\nvar_id: [TypeVarId(263)]\n}\n", ] diff --git a/nac3standalone/demo/src/ndarray.py b/nac3standalone/demo/src/ndarray.py index d42f3b93a..a82cfbad6 100644 --- a/nac3standalone/demo/src/ndarray.py +++ b/nac3standalone/demo/src/ndarray.py @@ -68,6 +68,13 @@ def output_ndarray_float_2(n: ndarray[float, Literal[2]]): for c in range(len(n[r])): output_float64(n[r][c]) +def output_ndarray_float_4(n: ndarray[float, Literal[4]]): + for x in range(len(n)): + for y in range(len(n[x])): + for z in range(len(n[x][y])): + for w in range(len(n[x][y][z])): + output_float64(n[x][y][z][w]) + def consume_ndarray_1(n: ndarray[float, Literal[1]]): pass @@ -197,6 +204,38 @@ def test_ndarray_nd_idx(): output_float64(x[1, 0]) output_float64(x[1, 1]) +def test_ndarray_reshape(): + w: ndarray[float, 1] = np_array([1., 2., 3., 4., 5., 6., 7., 8., 9., 10.]) + x = np_reshape(w, (1, 2, 1, -1)) + y = np_reshape(x, [2, -1]) + z = np_reshape(y, 10) + + output_int32(np_shape(w)[0]) + output_ndarray_float_1(w) + + output_int32(np_shape(x)[0]) + output_int32(np_shape(x)[1]) + output_int32(np_shape(x)[2]) + output_int32(np_shape(x)[3]) + output_ndarray_float_4(x) + + output_int32(np_shape(y)[0]) + output_int32(np_shape(y)[1]) + output_ndarray_float_2(y) + + output_int32(np_shape(z)[0]) + output_ndarray_float_1(z) + + x1: ndarray[int32, 1] = np_array([1, 2, 3, 4]) + x2: ndarray[int32, 2] = np_reshape(x1, (2, 2)) + + output_int32(np_shape(x1)[0]) + output_ndarray_int32_1(x1) + + output_int32(np_shape(x2)[0]) + output_int32(np_shape(x2)[1]) + output_ndarray_int32_2(x2) + def test_ndarray_add(): x = np_identity(2) y = x + np_ones([2, 2]) @@ -1448,19 +1487,6 @@ def test_ndarray_transpose(): output_ndarray_float_2(x) output_ndarray_float_2(y) -def test_ndarray_reshape(): - w: ndarray[float, 1] = np_array([1., 2., 3., 4., 5., 6., 7., 8., 9., 10.]) - x = np_reshape(w, (1, 2, 1, -1)) - y = np_reshape(x, [2, -1]) - z = np_reshape(y, 10) - - x1: ndarray[int32, 1] = np_array([1, 2, 3, 4]) - x2: ndarray[int32, 2] = np_reshape(x1, (2, 2)) - - output_ndarray_float_1(w) - output_ndarray_float_2(y) - output_ndarray_float_1(z) - def test_ndarray_dot(): x1: ndarray[float, 1] = np_array([5.0, 1.0, 4.0, 2.0]) y1: ndarray[float, 1] = np_array([5.0, 1.0, 6.0, 6.0]) @@ -1592,6 +1618,8 @@ def run() -> int32: test_ndarray_slices() test_ndarray_nd_idx() + test_ndarray_reshape() + test_ndarray_add() test_ndarray_add_broadcast() test_ndarray_add_broadcast_lhs_scalar() @@ -1756,7 +1784,6 @@ def run() -> int32: test_ndarray_nextafter_broadcast_lhs_scalar() test_ndarray_nextafter_broadcast_rhs_scalar() test_ndarray_transpose() - test_ndarray_reshape() test_ndarray_dot() test_ndarray_cholesky()