From 12fddc3533dcc4c88bac0dfd91a46cf1930723b6 Mon Sep 17 00:00:00 2001 From: David Mak Date: Thu, 19 Dec 2024 12:48:00 +0800 Subject: [PATCH] [core] codegen/ndarray: Make ndims non-optional Now that everything is ported to use strided impl, dynamic-ndim ndarray instances do not exist anymore. --- nac3artiq/src/codegen.rs | 4 +- nac3artiq/src/symbol_resolver.rs | 2 +- nac3core/src/codegen/builtin_fns.rs | 28 +++---- nac3core/src/codegen/mod.rs | 2 +- nac3core/src/codegen/numpy.rs | 23 +++--- nac3core/src/codegen/stmt.rs | 6 +- nac3core/src/codegen/test.rs | 2 +- nac3core/src/codegen/types/ndarray/array.rs | 14 ++-- nac3core/src/codegen/types/ndarray/map.rs | 2 +- nac3core/src/codegen/types/ndarray/mod.rs | 54 ++++--------- nac3core/src/codegen/types/ndarray/nditer.rs | 7 +- .../src/codegen/values/ndarray/broadcast.rs | 17 ++--- .../src/codegen/values/ndarray/contiguous.rs | 6 +- .../src/codegen/values/ndarray/indexing.rs | 8 +- nac3core/src/codegen/values/ndarray/matmul.rs | 33 +++----- nac3core/src/codegen/values/ndarray/mod.rs | 76 +++---------------- nac3core/src/codegen/values/ndarray/view.rs | 11 +-- 17 files changed, 94 insertions(+), 201 deletions(-) diff --git a/nac3artiq/src/codegen.rs b/nac3artiq/src/codegen.rs index 653f41a3..156ba23e 100644 --- a/nac3artiq/src/codegen.rs +++ b/nac3artiq/src/codegen.rs @@ -464,7 +464,7 @@ fn format_rpc_arg<'ctx>( let (elem_ty, ndims) = unpack_ndarray_var_tys(&mut ctx.unifier, arg_ty); let ndims = extract_ndims(&ctx.unifier, ndims); let dtype = ctx.get_llvm_type(generator, elem_ty); - let ndarray = NDArrayType::new(generator, ctx.ctx, dtype, Some(ndims)) + let ndarray = NDArrayType::new(generator, ctx.ctx, dtype, ndims) .map_value(arg.into_pointer_value(), None); let ndims = llvm_usize.const_int(ndims, false); @@ -597,7 +597,7 @@ fn format_rpc_ret<'ctx>( let (dtype, ndims) = unpack_ndarray_var_tys(&mut ctx.unifier, ret_ty); let dtype_llvm = ctx.get_llvm_type(generator, dtype); let ndims = extract_ndims(&ctx.unifier, ndims); - let ndarray = NDArrayType::new(generator, ctx.ctx, dtype_llvm, Some(ndims)) + let ndarray = NDArrayType::new(generator, ctx.ctx, dtype_llvm, ndims) .construct_uninitialized(generator, ctx, None); // NOTE: Current content of `ndarray`: diff --git a/nac3artiq/src/symbol_resolver.rs b/nac3artiq/src/symbol_resolver.rs index 6507dc20..8e9cd10c 100644 --- a/nac3artiq/src/symbol_resolver.rs +++ b/nac3artiq/src/symbol_resolver.rs @@ -1107,7 +1107,7 @@ impl InnerResolver { self.global_value_ids.write().insert(id, obj.into()); } - let ndims = llvm_ndarray.ndims().unwrap(); + let ndims = llvm_ndarray.ndims(); // Obtain the shape of the ndarray let shape_tuple: &PyTuple = obj.getattr("shape")?.downcast()?; diff --git a/nac3core/src/codegen/builtin_fns.rs b/nac3core/src/codegen/builtin_fns.rs index 7c8ad7a8..9d368070 100644 --- a/nac3core/src/codegen/builtin_fns.rs +++ b/nac3core/src/codegen/builtin_fns.rs @@ -1652,7 +1652,7 @@ pub fn call_np_linalg_cholesky<'ctx, G: CodeGenerator + ?Sized>( unsupported_type(ctx, FN_NAME, &[x1_ty]); } - let out = NDArrayType::new(generator, ctx.ctx, ctx.ctx.f64_type().into(), Some(2)) + let out = NDArrayType::new(generator, ctx.ctx, ctx.ctx.f64_type().into(), 2) .construct_uninitialized(generator, ctx, None); out.copy_shape_from_ndarray(generator, ctx, x1); unsafe { out.create_data(generator, ctx) }; @@ -1694,7 +1694,7 @@ pub fn call_np_linalg_qr<'ctx, G: CodeGenerator + ?Sized>( }; let dk = llvm_intrinsics::call_int_smin(ctx, d0, d1, None); - let out_ndarray_ty = NDArrayType::new(generator, ctx.ctx, ctx.ctx.f64_type().into(), Some(2)); + let out_ndarray_ty = NDArrayType::new(generator, ctx.ctx, ctx.ctx.f64_type().into(), 2); let q = out_ndarray_ty.construct_dyn_shape(generator, ctx, &[d0, dk], None); unsafe { q.create_data(generator, ctx) }; @@ -1746,8 +1746,8 @@ pub fn call_np_linalg_svd<'ctx, G: CodeGenerator + ?Sized>( }; let dk = llvm_intrinsics::call_int_smin(ctx, d0, d1, None); - let out_ndarray1_ty = NDArrayType::new(generator, ctx.ctx, ctx.ctx.f64_type().into(), Some(1)); - let out_ndarray2_ty = NDArrayType::new(generator, ctx.ctx, ctx.ctx.f64_type().into(), Some(2)); + let out_ndarray1_ty = NDArrayType::new(generator, ctx.ctx, ctx.ctx.f64_type().into(), 1); + let out_ndarray2_ty = NDArrayType::new(generator, ctx.ctx, ctx.ctx.f64_type().into(), 2); let u = out_ndarray2_ty.construct_dyn_shape(generator, ctx, &[d0, d0], None); unsafe { u.create_data(generator, ctx) }; @@ -1796,7 +1796,7 @@ pub fn call_np_linalg_inv<'ctx, G: CodeGenerator + ?Sized>( unsupported_type(ctx, FN_NAME, &[x1_ty]); } - let out = NDArrayType::new(generator, ctx.ctx, ctx.ctx.f64_type().into(), Some(2)) + let out = NDArrayType::new(generator, ctx.ctx, ctx.ctx.f64_type().into(), 2) .construct_uninitialized(generator, ctx, None); out.copy_shape_from_ndarray(generator, ctx, x1); unsafe { out.create_data(generator, ctx) }; @@ -1838,7 +1838,7 @@ pub fn call_np_linalg_pinv<'ctx, G: CodeGenerator + ?Sized>( x1_shape.get_typed_unchecked(ctx, generator, &llvm_usize.const_int(1, false), None) }; - let out = NDArrayType::new(generator, ctx.ctx, ctx.ctx.f64_type().into(), Some(2)) + let out = NDArrayType::new(generator, ctx.ctx, ctx.ctx.f64_type().into(), 2) .construct_dyn_shape(generator, ctx, &[d0, d1], None); unsafe { out.create_data(generator, ctx) }; @@ -1880,7 +1880,7 @@ pub fn call_sp_linalg_lu<'ctx, G: CodeGenerator + ?Sized>( }; let dk = llvm_intrinsics::call_int_smin(ctx, d0, d1, None); - let out_ndarray_ty = NDArrayType::new(generator, ctx.ctx, ctx.ctx.f64_type().into(), Some(2)); + let out_ndarray_ty = NDArrayType::new(generator, ctx.ctx, ctx.ctx.f64_type().into(), 2); let l = out_ndarray_ty.construct_dyn_shape(generator, ctx, &[d0, dk], None); unsafe { l.create_data(generator, ctx) }; @@ -1924,7 +1924,7 @@ pub fn call_np_linalg_matrix_power<'ctx, G: CodeGenerator + ?Sized>( let (elem_ty, ndims) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty); let ndims = extract_ndims(&ctx.unifier, ndims); let x1_elem_ty = ctx.get_llvm_type(generator, elem_ty); - let x1 = NDArrayValue::from_pointer_value(x1, x1_elem_ty, Some(ndims), llvm_usize, None); + let x1 = NDArrayValue::from_pointer_value(x1, x1_elem_ty, ndims, llvm_usize, None); if !x1.get_type().element_type().is_float_type() { unsupported_type(ctx, FN_NAME, &[x1_ty]); @@ -1940,7 +1940,7 @@ pub fn call_np_linalg_matrix_power<'ctx, G: CodeGenerator + ?Sized>( .construct_unsized(generator, ctx, &x2, None); // x2.shape == [] let x2 = x2.atleast_nd(generator, ctx, 1); // x2.shape == [1] - let out = NDArrayType::new(generator, ctx.ctx, ctx.ctx.f64_type().into(), Some(2)) + let out = NDArrayType::new(generator, ctx.ctx, ctx.ctx.f64_type().into(), 2) .construct_uninitialized(generator, ctx, None); out.copy_shape_from_ndarray(generator, ctx, x1); unsafe { out.create_data(generator, ctx) }; @@ -1979,7 +1979,7 @@ pub fn call_np_linalg_det<'ctx, G: CodeGenerator + ?Sized>( } // The output is a float64, but we are using an ndarray (shape == [1]) for uniformity in function call. - let det = NDArrayType::new(generator, ctx.ctx, ctx.ctx.f64_type().into(), Some(1)) + let det = NDArrayType::new(generator, ctx.ctx, ctx.ctx.f64_type().into(), 1) .construct_const_shape(generator, ctx, &[1], None); unsafe { det.create_data(generator, ctx) }; @@ -2008,13 +2008,13 @@ pub fn call_sp_linalg_schur<'ctx, G: CodeGenerator + ?Sized>( let BasicValueEnum::PointerValue(x1) = x1 else { unsupported_type(ctx, FN_NAME, &[x1_ty]) }; let x1 = NDArrayType::from_unifier_type(generator, ctx, x1_ty).map_value(x1, None); - assert_eq!(x1.get_type().ndims(), Some(2)); + assert_eq!(x1.get_type().ndims(), 2); if !x1.get_type().element_type().is_float_type() { unsupported_type(ctx, FN_NAME, &[x1_ty]); } - let out_ndarray_ty = NDArrayType::new(generator, ctx.ctx, ctx.ctx.f64_type().into(), Some(2)); + let out_ndarray_ty = NDArrayType::new(generator, ctx.ctx, ctx.ctx.f64_type().into(), 2); let t = out_ndarray_ty.construct_uninitialized(generator, ctx, None); t.copy_shape_from_ndarray(generator, ctx, x1); @@ -2053,13 +2053,13 @@ pub fn call_sp_linalg_hessenberg<'ctx, G: CodeGenerator + ?Sized>( let BasicValueEnum::PointerValue(x1) = x1 else { unsupported_type(ctx, FN_NAME, &[x1_ty]) }; let x1 = NDArrayType::from_unifier_type(generator, ctx, x1_ty).map_value(x1, None); - assert_eq!(x1.get_type().ndims(), Some(2)); + assert_eq!(x1.get_type().ndims(), 2); if !x1.get_type().element_type().is_float_type() { unsupported_type(ctx, FN_NAME, &[x1_ty]); } - let out_ndarray_ty = NDArrayType::new(generator, ctx.ctx, ctx.ctx.f64_type().into(), Some(2)); + let out_ndarray_ty = NDArrayType::new(generator, ctx.ctx, ctx.ctx.f64_type().into(), 2); let h = out_ndarray_ty.construct_uninitialized(generator, ctx, None); h.copy_shape_from_ndarray(generator, ctx, x1); diff --git a/nac3core/src/codegen/mod.rs b/nac3core/src/codegen/mod.rs index 2ce3c9ab..4e83d530 100644 --- a/nac3core/src/codegen/mod.rs +++ b/nac3core/src/codegen/mod.rs @@ -520,7 +520,7 @@ fn get_llvm_type<'ctx, G: CodeGenerator + ?Sized>( ctx, module, generator, unifier, top_level, type_cache, dtype, ); - NDArrayType::new(generator, ctx, element_type, Some(ndims)).as_base_type().into() + NDArrayType::new(generator, ctx, element_type, ndims).as_base_type().into() } _ => unreachable!( diff --git a/nac3core/src/codegen/numpy.rs b/nac3core/src/codegen/numpy.rs index e5a893c9..6c16be9f 100644 --- a/nac3core/src/codegen/numpy.rs +++ b/nac3core/src/codegen/numpy.rs @@ -42,7 +42,7 @@ pub fn gen_ndarray_empty<'ctx>( let shape = parse_numpy_int_sequence(generator, context, (shape_ty, shape_arg)); - let ndarray = NDArrayType::new(generator, context.ctx, llvm_dtype, Some(ndims)) + let ndarray = NDArrayType::new(generator, context.ctx, llvm_dtype, ndims) .construct_numpy_empty(generator, context, &shape, None); Ok(ndarray.as_base_value()) } @@ -67,7 +67,7 @@ pub fn gen_ndarray_zeros<'ctx>( let shape = parse_numpy_int_sequence(generator, context, (shape_ty, shape_arg)); - let ndarray = NDArrayType::new(generator, context.ctx, llvm_dtype, Some(ndims)) + let ndarray = NDArrayType::new(generator, context.ctx, llvm_dtype, ndims) .construct_numpy_zeros(generator, context, dtype, &shape, None); Ok(ndarray.as_base_value()) } @@ -92,7 +92,7 @@ pub fn gen_ndarray_ones<'ctx>( let shape = parse_numpy_int_sequence(generator, context, (shape_ty, shape_arg)); - let ndarray = NDArrayType::new(generator, context.ctx, llvm_dtype, Some(ndims)) + let ndarray = NDArrayType::new(generator, context.ctx, llvm_dtype, ndims) .construct_numpy_ones(generator, context, dtype, &shape, None); Ok(ndarray.as_base_value()) } @@ -120,8 +120,13 @@ pub fn gen_ndarray_full<'ctx>( let shape = parse_numpy_int_sequence(generator, context, (shape_ty, shape_arg)); - let ndarray = NDArrayType::new(generator, context.ctx, llvm_dtype, Some(ndims)) - .construct_numpy_full(generator, context, &shape, fill_value_arg, None); + let ndarray = NDArrayType::new(generator, context.ctx, llvm_dtype, ndims).construct_numpy_full( + generator, + context, + &shape, + fill_value_arg, + None, + ); Ok(ndarray.as_base_value()) } @@ -218,7 +223,7 @@ pub fn gen_ndarray_eye<'ctx>( .build_int_s_extend_or_bit_cast(offset_arg.into_int_value(), llvm_usize, "") .unwrap(); - let ndarray = NDArrayType::new(generator, context.ctx, llvm_dtype, Some(2)) + let ndarray = NDArrayType::new(generator, context.ctx, llvm_dtype, 2) .construct_numpy_eye(generator, context, dtype, nrows, ncols, offset, None); Ok(ndarray.as_base_value()) } @@ -246,7 +251,7 @@ pub fn gen_ndarray_identity<'ctx>( .builder .build_int_s_extend_or_bit_cast(n_arg.into_int_value(), llvm_usize, "") .unwrap(); - let ndarray = NDArrayType::new(generator, context.ctx, llvm_dtype, Some(2)) + let ndarray = NDArrayType::new(generator, context.ctx, llvm_dtype, 2) .construct_numpy_identity(generator, context, dtype, n, None); Ok(ndarray.as_base_value()) } @@ -315,8 +320,8 @@ pub fn ndarray_dot<'ctx, G: CodeGenerator + ?Sized>( let b = NDArrayType::from_unifier_type(generator, ctx, x2_ty).map_value(n2, None); // TODO: General `np.dot()` https://numpy.org/doc/stable/reference/generated/numpy.dot.html. - assert!(a.get_type().ndims().is_some_and(|ndims| ndims == 1)); - assert!(b.get_type().ndims().is_some_and(|ndims| ndims == 1)); + assert_eq!(a.get_type().ndims(), 1); + assert_eq!(b.get_type().ndims(), 1); let common_dtype = arraylike_flatten_element_type(&mut ctx.unifier, x1_ty); // Check shapes. diff --git a/nac3core/src/codegen/stmt.rs b/nac3core/src/codegen/stmt.rs index edebb4f0..2a3bd066 100644 --- a/nac3core/src/codegen/stmt.rs +++ b/nac3core/src/codegen/stmt.rs @@ -447,10 +447,8 @@ pub fn gen_setitem<'ctx, G: CodeGenerator>( let value = ScalarOrNDArray::from_value(generator, ctx, (value_ty, value)) .to_ndarray(generator, ctx); - let broadcast_ndims = [target.get_type().ndims(), value.get_type().ndims()] - .iter() - .filter_map(|ndims| *ndims) - .max(); + let broadcast_ndims = + [target.get_type().ndims(), value.get_type().ndims()].into_iter().max().unwrap(); let broadcast_result = NDArrayType::new( generator, ctx.ctx, diff --git a/nac3core/src/codegen/test.rs b/nac3core/src/codegen/test.rs index 97bd3f09..2701e138 100644 --- a/nac3core/src/codegen/test.rs +++ b/nac3core/src/codegen/test.rs @@ -464,6 +464,6 @@ fn test_classes_ndarray_type_new() { let llvm_i32 = ctx.i32_type(); let llvm_usize = generator.get_size_type(&ctx); - let llvm_ndarray = NDArrayType::new(&generator, &ctx, llvm_i32.into(), None); + let llvm_ndarray = NDArrayType::new(&generator, &ctx, llvm_i32.into(), 2); assert!(NDArrayType::is_representable(llvm_ndarray.as_base_type(), llvm_usize).is_ok()); } diff --git a/nac3core/src/codegen/types/ndarray/array.rs b/nac3core/src/codegen/types/ndarray/array.rs index 87cd002a..0f30f0eb 100644 --- a/nac3core/src/codegen/types/ndarray/array.rs +++ b/nac3core/src/codegen/types/ndarray/array.rs @@ -41,7 +41,7 @@ impl<'ctx> NDArrayType<'ctx> { name: Option<&'ctx str>, ) -> >::Value { let (dtype, ndims_int) = get_list_object_dtype_and_ndims(generator, ctx, list_ty); - assert!(self.ndims.is_none_or(|self_ndims| self_ndims >= ndims_int)); + assert!(self.ndims >= ndims_int); assert_eq!(dtype, self.dtype); let list_value = list.as_i8_list(generator, ctx); @@ -61,7 +61,7 @@ impl<'ctx> NDArrayType<'ctx> { generator, ctx, list_value, ndims, &shape, ); - let ndarray = Self::new(generator, ctx.ctx, dtype, Some(ndims_int)) + let ndarray = Self::new(generator, ctx.ctx, dtype, ndims_int) .construct_uninitialized(generator, ctx, name); ndarray.copy_shape_from_array(generator, ctx, shape.base_ptr(ctx, generator)); unsafe { ndarray.create_data(generator, ctx) }; @@ -93,12 +93,12 @@ impl<'ctx> NDArrayType<'ctx> { if ndims == 1 { // `list` is not nested assert_eq!(ndims, 1); - assert!(self.ndims.is_none_or(|self_ndims| self_ndims >= ndims)); + assert!(self.ndims >= ndims); assert_eq!(dtype, self.dtype); let llvm_pi8 = ctx.ctx.i8_type().ptr_type(AddressSpace::default()); - let ndarray = Self::new(generator, ctx.ctx, dtype, Some(1)) + let ndarray = Self::new(generator, ctx.ctx, dtype, 1) .construct_uninitialized(generator, ctx, name); // Set data @@ -170,7 +170,7 @@ impl<'ctx> NDArrayType<'ctx> { .map(BasicValueEnum::into_pointer_value) .unwrap(); - NDArrayType::new(generator, ctx.ctx, dtype, Some(ndims)).map_value(ndarray, None) + NDArrayType::new(generator, ctx.ctx, dtype, ndims).map_value(ndarray, None) } /// Implementation of `np_array(, copy=copy)`. @@ -183,9 +183,7 @@ impl<'ctx> NDArrayType<'ctx> { name: Option<&'ctx str>, ) -> >::Value { assert_eq!(ndarray.get_type().dtype, self.dtype); - assert!(ndarray.get_type().ndims.is_none_or(|ndarray_ndims| self - .ndims - .is_none_or(|self_ndims| self_ndims >= ndarray_ndims))); + assert!(self.ndims >= ndarray.get_type().ndims); assert_eq!(copy.get_type(), ctx.ctx.bool_type()); let ndarray_val = gen_if_else_expr_callback( diff --git a/nac3core/src/codegen/types/ndarray/map.rs b/nac3core/src/codegen/types/ndarray/map.rs index 0d63b22f..bf82b4da 100644 --- a/nac3core/src/codegen/types/ndarray/map.rs +++ b/nac3core/src/codegen/types/ndarray/map.rs @@ -47,7 +47,7 @@ impl<'ctx> NDArrayType<'ctx> { NDArrayOut::NewNDArray { dtype } => { // Create a new ndarray based on the broadcast shape. let result_ndarray = - NDArrayType::new(generator, ctx.ctx, dtype, Some(broadcast_result.ndims)) + NDArrayType::new(generator, ctx.ctx, dtype, broadcast_result.ndims) .construct_uninitialized(generator, ctx, None); result_ndarray.copy_shape_from_array( generator, diff --git a/nac3core/src/codegen/types/ndarray/mod.rs b/nac3core/src/codegen/types/ndarray/mod.rs index 43712be1..353ace33 100644 --- a/nac3core/src/codegen/types/ndarray/mod.rs +++ b/nac3core/src/codegen/types/ndarray/mod.rs @@ -38,7 +38,7 @@ mod nditer; pub struct NDArrayType<'ctx> { ty: PointerType<'ctx>, dtype: BasicTypeEnum<'ctx>, - ndims: Option, + ndims: u64, llvm_usize: IntType<'ctx>, } @@ -113,7 +113,7 @@ impl<'ctx> NDArrayType<'ctx> { generator: &G, ctx: &'ctx Context, dtype: BasicTypeEnum<'ctx>, - ndims: Option, + ndims: u64, ) -> Self { let llvm_usize = generator.get_size_type(ctx); let llvm_ndarray = Self::llvm_type(ctx, llvm_usize); @@ -132,7 +132,7 @@ impl<'ctx> NDArrayType<'ctx> { ) -> Self { assert!(!inputs.is_empty()); - Self::new(generator, ctx, dtype, inputs.iter().filter_map(NDArrayType::ndims).max()) + Self::new(generator, ctx, dtype, inputs.iter().map(NDArrayType::ndims).max().unwrap()) } /// Creates an instance of [`NDArrayType`] with `ndims` of 0. @@ -145,7 +145,7 @@ impl<'ctx> NDArrayType<'ctx> { let llvm_usize = generator.get_size_type(ctx); let llvm_ndarray = Self::llvm_type(ctx, llvm_usize); - NDArrayType { ty: llvm_ndarray, dtype, ndims: Some(0), llvm_usize } + NDArrayType { ty: llvm_ndarray, dtype, ndims: 0, llvm_usize } } /// Creates an [`NDArrayType`] from a [unifier type][Type]. @@ -164,7 +164,7 @@ impl<'ctx> NDArrayType<'ctx> { NDArrayType { ty: Self::llvm_type(ctx.ctx, llvm_usize), dtype: llvm_dtype, - ndims: Some(ndims), + ndims, llvm_usize, } } @@ -174,7 +174,7 @@ impl<'ctx> NDArrayType<'ctx> { pub fn from_type( ptr_ty: PointerType<'ctx>, dtype: BasicTypeEnum<'ctx>, - ndims: Option, + ndims: u64, llvm_usize: IntType<'ctx>, ) -> Self { debug_assert!(Self::is_representable(ptr_ty, llvm_usize).is_ok()); @@ -196,7 +196,7 @@ impl<'ctx> NDArrayType<'ctx> { /// Returns the number of dimensions of this `ndarray` type. #[must_use] - pub fn ndims(&self) -> Option { + pub fn ndims(&self) -> u64 { self.ndims } @@ -286,35 +286,7 @@ impl<'ctx> NDArrayType<'ctx> { ctx: &mut CodeGenContext<'ctx, '_>, name: Option<&'ctx str>, ) -> >::Value { - assert!(self.ndims.is_some(), "NDArrayType::construct can only be called on an instance with compile-time known ndims (self.ndims = Some(ndims))"); - - let Some(ndims) = self.ndims.map(|ndims| self.llvm_usize.const_int(ndims, false)) else { - unreachable!() - }; - - self.construct_impl(generator, ctx, ndims, name) - } - - /// Allocate an [`NDArrayValue`] on the stack given its `ndims` and `dtype`. - /// - /// `shape` and `strides` will be automatically allocated onto the stack. - /// - /// The returned ndarray's content will be: - /// - `data`: uninitialized. - /// - `itemsize`: set to the size of `dtype`. - /// - `ndims`: set to the value of `ndims`. - /// - `shape`: allocated with an array of length `ndims` with uninitialized values. - /// - `strides`: allocated with an array of length `ndims` with uninitialized values. - #[deprecated = "Prefer construct_uninitialized or construct_*_shape."] - #[must_use] - pub fn construct_dyn_ndims( - &self, - generator: &mut G, - ctx: &mut CodeGenContext<'ctx, '_>, - ndims: IntValue<'ctx>, - name: Option<&'ctx str>, - ) -> >::Value { - assert!(self.ndims.is_none(), "NDArrayType::construct_dyn_ndims can only be called on an instance with compile-time unknown ndims (self.ndims = None)"); + let ndims = self.llvm_usize.const_int(self.ndims, false); self.construct_impl(generator, ctx, ndims, name) } @@ -330,9 +302,9 @@ impl<'ctx> NDArrayType<'ctx> { shape: &[u64], name: Option<&'ctx str>, ) -> >::Value { - assert!(self.ndims.is_none_or(|ndims| shape.len() as u64 == ndims)); + assert_eq!(shape.len() as u64, self.ndims); - let ndarray = Self::new(generator, ctx.ctx, self.dtype, Some(shape.len() as u64)) + let ndarray = Self::new(generator, ctx.ctx, self.dtype, shape.len() as u64) .construct_uninitialized(generator, ctx, name); let llvm_usize = generator.get_size_type(ctx.ctx); @@ -365,9 +337,9 @@ impl<'ctx> NDArrayType<'ctx> { shape: &[IntValue<'ctx>], name: Option<&'ctx str>, ) -> >::Value { - assert!(self.ndims.is_none_or(|ndims| shape.len() as u64 == ndims)); + assert_eq!(shape.len() as u64, self.ndims); - let ndarray = Self::new(generator, ctx.ctx, self.dtype, Some(shape.len() as u64)) + let ndarray = Self::new(generator, ctx.ctx, self.dtype, shape.len() as u64) .construct_uninitialized(generator, ctx, name); let llvm_usize = generator.get_size_type(ctx.ctx); @@ -407,7 +379,7 @@ impl<'ctx> NDArrayType<'ctx> { let value = value.as_basic_value_enum(); assert_eq!(value.get_type(), self.dtype); - assert!(self.ndims.is_none_or(|ndims| ndims == 0)); + assert_eq!(self.ndims, 0); // We have to put the value on the stack to get a data pointer. let data = ctx.builder.build_alloca(value.get_type(), "construct_unsized").unwrap(); diff --git a/nac3core/src/codegen/types/ndarray/nditer.rs b/nac3core/src/codegen/types/ndarray/nditer.rs index 9b71693a..c77e4571 100644 --- a/nac3core/src/codegen/types/ndarray/nditer.rs +++ b/nac3core/src/codegen/types/ndarray/nditer.rs @@ -163,13 +163,8 @@ impl<'ctx> NDIterType<'ctx> { ctx: &mut CodeGenContext<'ctx, '_>, ndarray: NDArrayValue<'ctx>, ) -> >::Value { - assert!( - ndarray.get_type().ndims().is_some(), - "NDIter requires ndims of NDArray to be known." - ); - let nditer = self.raw_alloca_var(generator, ctx, None); - let ndims = self.llvm_usize.const_int(ndarray.get_type().ndims().unwrap(), false); + let ndims = self.llvm_usize.const_int(ndarray.get_type().ndims(), false); // The caller has the responsibility to allocate 'indices' for `NDIter`. let indices = diff --git a/nac3core/src/codegen/values/ndarray/broadcast.rs b/nac3core/src/codegen/values/ndarray/broadcast.rs index 0c84b05e..1b99f464 100644 --- a/nac3core/src/codegen/values/ndarray/broadcast.rs +++ b/nac3core/src/codegen/values/ndarray/broadcast.rs @@ -101,12 +101,11 @@ impl<'ctx> NDArrayValue<'ctx> { target_ndims: u64, target_shape: &impl TypedArrayLikeAccessor<'ctx, G, IntValue<'ctx>>, ) -> Self { - assert!(self.ndims.is_none_or(|ndims| ndims <= target_ndims)); + assert!(self.ndims <= target_ndims); assert_eq!(target_shape.element_type(ctx, generator), self.llvm_usize.into()); - let broadcast_ndarray = - NDArrayType::new(generator, ctx.ctx, self.dtype, Some(target_ndims)) - .construct_uninitialized(generator, ctx, None); + let broadcast_ndarray = NDArrayType::new(generator, ctx.ctx, self.dtype, target_ndims) + .construct_uninitialized(generator, ctx, None); broadcast_ndarray.copy_shape_from_array( generator, ctx, @@ -199,14 +198,13 @@ impl<'ctx> NDArrayType<'ctx> { ndarrays: &[NDArrayValue<'ctx>], ) -> BroadcastAllResult<'ctx, G> { assert!(!ndarrays.is_empty()); - assert!(ndarrays.iter().all(|ndarray| ndarray.get_type().ndims().is_some())); let llvm_usize = generator.get_size_type(ctx.ctx); // Infer the broadcast output ndims. let broadcast_ndims_int = - ndarrays.iter().map(|ndarray| ndarray.get_type().ndims().unwrap()).max().unwrap(); - assert!(self.ndims().is_none_or(|ndims| ndims >= broadcast_ndims_int)); + ndarrays.iter().map(|ndarray| ndarray.get_type().ndims()).max().unwrap(); + assert!(self.ndims() >= broadcast_ndims_int); let broadcast_ndims = llvm_usize.const_int(broadcast_ndims_int, false); let broadcast_shape = ArraySliceValue::from_ptr_val( @@ -223,10 +221,7 @@ impl<'ctx> NDArrayType<'ctx> { let shape_entries = ndarrays .iter() .map(|ndarray| { - ( - ndarray.shape().as_slice_value(ctx, generator), - ndarray.get_type().ndims().unwrap(), - ) + (ndarray.shape().as_slice_value(ctx, generator), ndarray.get_type().ndims()) }) .collect_vec(); broadcast_shapes(generator, ctx, &shape_entries, broadcast_ndims_int, &broadcast_shape); diff --git a/nac3core/src/codegen/values/ndarray/contiguous.rs b/nac3core/src/codegen/values/ndarray/contiguous.rs index f3b03dd1..8eb700b9 100644 --- a/nac3core/src/codegen/values/ndarray/contiguous.rs +++ b/nac3core/src/codegen/values/ndarray/contiguous.rs @@ -121,9 +121,7 @@ impl<'ctx> NDArrayValue<'ctx> { .alloca_var(generator, ctx, self.name); // Set ndims and shape. - let ndims = self - .ndims - .map_or_else(|| self.load_ndims(ctx), |ndims| self.llvm_usize.const_int(ndims, false)); + let ndims = self.llvm_usize.const_int(self.ndims, false); result.store_ndims(ctx, ndims); let shape = self.shape(); @@ -180,7 +178,7 @@ impl<'ctx> NDArrayValue<'ctx> { // TODO: Debug assert `ndims == carray.ndims` to catch bugs. // Allocate the resulting ndarray. - let ndarray = NDArrayType::new(generator, ctx.ctx, carray.item, Some(ndims)) + let ndarray = NDArrayType::new(generator, ctx.ctx, carray.item, ndims) .construct_uninitialized(generator, ctx, carray.name); // Copy shape and update strides diff --git a/nac3core/src/codegen/values/ndarray/indexing.rs b/nac3core/src/codegen/values/ndarray/indexing.rs index 3d575028..3821f232 100644 --- a/nac3core/src/codegen/values/ndarray/indexing.rs +++ b/nac3core/src/codegen/values/ndarray/indexing.rs @@ -98,8 +98,8 @@ impl<'ctx> From> for PointerValue<'ctx> { impl<'ctx> NDArrayValue<'ctx> { /// Get the expected `ndims` after indexing with `indices`. #[must_use] - fn deduce_ndims_after_indexing_with(&self, indices: &[RustNDIndex<'ctx>]) -> Option { - let mut ndims = self.ndims?; + fn deduce_ndims_after_indexing_with(&self, indices: &[RustNDIndex<'ctx>]) -> u64 { + let mut ndims = self.ndims; for index in indices { match index { @@ -113,7 +113,7 @@ impl<'ctx> NDArrayValue<'ctx> { } } - Some(ndims) + ndims } /// Index into the ndarray, and return a newly-allocated view on this ndarray. @@ -127,8 +127,6 @@ impl<'ctx> NDArrayValue<'ctx> { ctx: &mut CodeGenContext<'ctx, '_>, indices: &[RustNDIndex<'ctx>], ) -> Self { - assert!(self.ndims.is_some(), "NDArrayValue::index is only supported for instances with compile-time known ndims (self.ndims = Some(...))"); - let dst_ndims = self.deduce_ndims_after_indexing_with(indices); let dst_ndarray = NDArrayType::new(generator, ctx.ctx, self.dtype, dst_ndims) .construct_uninitialized(generator, ctx, None); diff --git a/nac3core/src/codegen/values/ndarray/matmul.rs b/nac3core/src/codegen/values/ndarray/matmul.rs index 88a94394..f802c0c0 100644 --- a/nac3core/src/codegen/values/ndarray/matmul.rs +++ b/nac3core/src/codegen/values/ndarray/matmul.rs @@ -29,16 +29,8 @@ fn matmul_at_least_2d<'ctx, G: CodeGenerator>( (in_a_ty, in_a): (Type, NDArrayValue<'ctx>), (in_b_ty, in_b): (Type, NDArrayValue<'ctx>), ) -> NDArrayValue<'ctx> { - assert!( - in_a.ndims.is_some_and(|ndims| ndims >= 2), - "in_a (which is {:?}) must be compile-time known and >= 2", - in_a.ndims - ); - assert!( - in_b.ndims.is_some_and(|ndims| ndims >= 2), - "in_b (which is {:?}) must be compile-time known and >= 2", - in_b.ndims - ); + assert!(in_a.ndims >= 2, "in_a (which is {}) must be >= 2", in_a.ndims); + assert!(in_b.ndims >= 2, "in_b (which is {}) must be >= 2", in_b.ndims); let lhs_dtype = arraylike_flatten_element_type(&mut ctx.unifier, in_a_ty); let rhs_dtype = arraylike_flatten_element_type(&mut ctx.unifier, in_b_ty); @@ -47,13 +39,13 @@ fn matmul_at_least_2d<'ctx, G: CodeGenerator>( let llvm_dst_dtype = ctx.get_llvm_type(generator, dst_dtype); // Deduce ndims of the result of matmul. - let ndims_int = max(in_a.ndims.unwrap(), in_b.ndims.unwrap()); + let ndims_int = max(in_a.ndims, in_b.ndims); let ndims = llvm_usize.const_int(ndims_int, false); // Broadcasts `in_a.shape[:-2]` and `in_b.shape[:-2]` together and allocate the // destination ndarray to store the result of matmul. let (lhs, rhs, dst) = { - let in_lhs_ndims = llvm_usize.const_int(in_a.ndims.unwrap(), false); + let in_lhs_ndims = llvm_usize.const_int(in_a.ndims, false); let in_lhs_shape = TypedArrayLikeAdapter::from( ArraySliceValue::from_ptr_val( in_a.shape().base_ptr(ctx, generator), @@ -63,7 +55,7 @@ fn matmul_at_least_2d<'ctx, G: CodeGenerator>( |_, _, val| val.into_int_value(), |_, _, val| val.into(), ); - let in_rhs_ndims = llvm_usize.const_int(in_b.ndims.unwrap(), false); + let in_rhs_ndims = llvm_usize.const_int(in_b.ndims, false); let in_rhs_shape = TypedArrayLikeAdapter::from( ArraySliceValue::from_ptr_val( in_b.shape().base_ptr(ctx, generator), @@ -116,7 +108,7 @@ fn matmul_at_least_2d<'ctx, G: CodeGenerator>( let lhs = in_a.broadcast_to(generator, ctx, ndims_int, &lhs_shape); let rhs = in_b.broadcast_to(generator, ctx, ndims_int, &rhs_shape); - let dst = NDArrayType::new(generator, ctx.ctx, llvm_dst_dtype, Some(ndims_int)) + let dst = NDArrayType::new(generator, ctx.ctx, llvm_dst_dtype, ndims_int) .construct_uninitialized(generator, ctx, None); dst.copy_shape_from_array(generator, ctx, dst_shape.base_ptr(ctx, generator)); unsafe { @@ -266,10 +258,7 @@ impl<'ctx> NDArrayValue<'ctx> { (out_dtype, out): (Type, NDArrayOut<'ctx>), ) -> Self { // Sanity check, but type inference should prevent this. - assert!( - self.ndims.is_some_and(|ndims| ndims > 0) && other.ndims.is_some_and(|ndims| ndims > 0), - "np.matmul disallows scalar input" - ); + assert!(self.ndims > 0 && other.ndims > 0, "np.matmul disallows scalar input"); // If both arguments are 2-D they are multiplied like conventional matrices. // @@ -282,14 +271,14 @@ impl<'ctx> NDArrayValue<'ctx> { // If the second argument is 1-D, it is promoted to a matrix by appending a 1 to its // dimensions. After matrix multiplication the appended 1 is removed. - let new_a = if self.ndims.unwrap() == 1 { + let new_a = if self.ndims == 1 { // Prepend 1 to its dimensions self.index(generator, ctx, &[RustNDIndex::NewAxis, RustNDIndex::Ellipsis]) } else { *self }; - let new_b = if other.ndims.unwrap() == 1 { + let new_b = if other.ndims == 1 { // Append 1 to its dimensions other.index(generator, ctx, &[RustNDIndex::Ellipsis, RustNDIndex::NewAxis]) } else { @@ -305,12 +294,12 @@ impl<'ctx> NDArrayValue<'ctx> { let mut postindices = vec![]; let zero = ctx.ctx.i32_type().const_zero(); - if self.ndims.unwrap() == 1 { + if self.ndims == 1 { // Remove the prepended 1 postindices.push(RustNDIndex::SingleElement(zero)); } - if other.ndims.unwrap() == 1 { + if other.ndims == 1 { // Remove the appended 1 postindices.push(RustNDIndex::Ellipsis); postindices.push(RustNDIndex::SingleElement(zero)); diff --git a/nac3core/src/codegen/values/ndarray/mod.rs b/nac3core/src/codegen/values/ndarray/mod.rs index 707c79a2..ad50d32e 100644 --- a/nac3core/src/codegen/values/ndarray/mod.rs +++ b/nac3core/src/codegen/values/ndarray/mod.rs @@ -42,7 +42,7 @@ mod view; pub struct NDArrayValue<'ctx> { value: PointerValue<'ctx>, dtype: BasicTypeEnum<'ctx>, - ndims: Option, + ndims: u64, llvm_usize: IntType<'ctx>, name: Option<&'ctx str>, } @@ -62,7 +62,7 @@ impl<'ctx> NDArrayValue<'ctx> { pub fn from_pointer_value( ptr: PointerValue<'ctx>, dtype: BasicTypeEnum<'ctx>, - ndims: Option, + ndims: u64, llvm_usize: IntType<'ctx>, name: Option<&'ctx str>, ) -> Self { @@ -245,26 +245,7 @@ impl<'ctx> NDArrayValue<'ctx> { ctx: &mut CodeGenContext<'ctx, '_>, src_ndarray: NDArrayValue<'ctx>, ) { - if self.ndims.is_some() && src_ndarray.ndims.is_some() { - assert_eq!(self.ndims, src_ndarray.ndims); - } else { - let self_ndims = self.load_ndims(ctx); - let src_ndims = src_ndarray.load_ndims(ctx); - - ctx.make_assert( - generator, - ctx.builder.build_int_compare( - IntPredicate::EQ, - self_ndims, - src_ndims, - "" - ).unwrap(), - "0:AssertionError", - "NDArrayValue::copy_shape_from_ndarray: Expected self.ndims ({0}) == src_ndarray.ndims ({1})", - [Some(self_ndims), Some(src_ndims), None], - ctx.current_loc - ); - } + assert_eq!(self.ndims, src_ndarray.ndims); let src_shape = src_ndarray.shape().base_ptr(ctx, generator); self.copy_shape_from_array(generator, ctx, src_shape); @@ -296,26 +277,7 @@ impl<'ctx> NDArrayValue<'ctx> { ctx: &mut CodeGenContext<'ctx, '_>, src_ndarray: NDArrayValue<'ctx>, ) { - if self.ndims.is_some() && src_ndarray.ndims.is_some() { - assert_eq!(self.ndims, src_ndarray.ndims); - } else { - let self_ndims = self.load_ndims(ctx); - let src_ndims = src_ndarray.load_ndims(ctx); - - ctx.make_assert( - generator, - ctx.builder.build_int_compare( - IntPredicate::EQ, - self_ndims, - src_ndims, - "" - ).unwrap(), - "0:AssertionError", - "NDArrayValue::copy_shape_from_ndarray: Expected self.ndims ({0}) == src_ndarray.ndims ({1})", - [Some(self_ndims), Some(src_ndims), None], - ctx.current_loc - ); - } + assert_eq!(self.ndims, src_ndarray.ndims); let src_strides = src_ndarray.strides().base_ptr(ctx, generator); self.copy_strides_from_array(generator, ctx, src_strides); @@ -380,11 +342,7 @@ impl<'ctx> NDArrayValue<'ctx> { generator: &mut G, ctx: &mut CodeGenContext<'ctx, '_>, ) -> Self { - let clone = if self.ndims.is_some() { - self.get_type().construct_uninitialized(generator, ctx, None) - } else { - self.get_type().construct_dyn_ndims(generator, ctx, self.load_ndims(ctx), None) - }; + let clone = self.get_type().construct_uninitialized(generator, ctx, None); let shape = self.shape(); clone.copy_shape_from_array(generator, ctx, shape.base_ptr(ctx, generator)); @@ -437,11 +395,9 @@ impl<'ctx> NDArrayValue<'ctx> { generator: &mut G, ctx: &mut CodeGenContext<'ctx, '_>, ) -> TupleValue<'ctx> { - assert!(self.ndims.is_some(), "NDArrayValue::make_shape_tuple can only be called on an instance with compile-time known ndims (self.ndims = Some(ndims))"); - let llvm_i32 = ctx.ctx.i32_type(); - let objects = (0..self.ndims.unwrap()) + let objects = (0..self.ndims) .map(|i| { let dim = unsafe { self.shape().get_typed_unchecked( @@ -459,7 +415,7 @@ impl<'ctx> NDArrayValue<'ctx> { TupleType::new( generator, ctx.ctx, - &repeat_n(llvm_i32.into(), self.ndims.unwrap() as usize).collect_vec(), + &repeat_n(llvm_i32.into(), self.ndims as usize).collect_vec(), ) .construct_from_objects(ctx, objects, None) } @@ -473,11 +429,9 @@ impl<'ctx> NDArrayValue<'ctx> { generator: &mut G, ctx: &mut CodeGenContext<'ctx, '_>, ) -> TupleValue<'ctx> { - assert!(self.ndims.is_some(), "NDArrayValue::make_strides_tuple can only be called on an instance with compile-time known ndims (self.ndims = Some(ndims))"); - let llvm_i32 = ctx.ctx.i32_type(); - let objects = (0..self.ndims.unwrap()) + let objects = (0..self.ndims) .map(|i| { let dim = unsafe { self.strides().get_typed_unchecked( @@ -495,15 +449,15 @@ impl<'ctx> NDArrayValue<'ctx> { TupleType::new( generator, ctx.ctx, - &repeat_n(llvm_i32.into(), self.ndims.unwrap() as usize).collect_vec(), + &repeat_n(llvm_i32.into(), self.ndims as usize).collect_vec(), ) .construct_from_objects(ctx, objects, None) } /// Returns true if this ndarray is unsized - `ndims == 0` and only contains a scalar. #[must_use] - pub fn is_unsized(&self) -> Option { - self.ndims.map(|ndims| ndims == 0) + pub fn is_unsized(&self) -> bool { + self.ndims == 0 } /// Returns the element present in this `ndarray` if this is unsized. @@ -512,11 +466,7 @@ impl<'ctx> NDArrayValue<'ctx> { generator: &mut G, ctx: &mut CodeGenContext<'ctx, '_>, ) -> Option> { - let Some(is_unsized) = self.is_unsized() else { - panic!("NDArrayValue::get_unsized_element can only be called on an instance with compile-time known ndims (self.ndims = Some(ndims))"); - }; - - if is_unsized { + if self.is_unsized() { // NOTE: `np.size(self) == 0` here is never possible. let zero = generator.get_size_type(ctx.ctx).const_zero(); let value = unsafe { self.data().get_unchecked(ctx, generator, &zero, None) }; @@ -534,8 +484,6 @@ impl<'ctx> NDArrayValue<'ctx> { generator: &mut G, ctx: &mut CodeGenContext<'ctx, '_>, ) -> ScalarOrNDArray<'ctx> { - assert!(self.ndims.is_some(), "NDArrayValue::split_unsized can only be called on an instance with compile-time known ndims (self.ndims = Some(ndims))"); - if let Some(unsized_elem) = self.get_unsized_element(generator, ctx) { ScalarOrNDArray::Scalar(unsized_elem) } else { diff --git a/nac3core/src/codegen/values/ndarray/view.rs b/nac3core/src/codegen/values/ndarray/view.rs index 450f7444..5027be58 100644 --- a/nac3core/src/codegen/values/ndarray/view.rs +++ b/nac3core/src/codegen/values/ndarray/view.rs @@ -26,9 +26,7 @@ impl<'ctx> NDArrayValue<'ctx> { ctx: &mut CodeGenContext<'ctx, '_>, ndmin: u64, ) -> Self { - assert!(self.ndims.is_some(), "NDArrayValue::atleast_nd is only supported for instances with compile-time known ndims (self.ndims = Some(...))"); - - let ndims = self.ndims.unwrap(); + let ndims = self.ndims; if ndims < ndmin { // Extend the dimensions with np.newaxis. @@ -67,13 +65,13 @@ impl<'ctx> NDArrayValue<'ctx> { // not contiguous but could be reshaped without copying data. Look into how numpy does // it. - let dst_ndarray = NDArrayType::new(generator, ctx.ctx, self.dtype, Some(new_ndims)) + let dst_ndarray = NDArrayType::new(generator, ctx.ctx, self.dtype, new_ndims) .construct_uninitialized(generator, ctx, None); dst_ndarray.copy_shape_from_array(generator, ctx, new_shape.base_ptr(ctx, generator)); // Resolve negative indices let size = self.size(generator, ctx); - let dst_ndims = self.llvm_usize.const_int(dst_ndarray.get_type().ndims().unwrap(), false); + let dst_ndims = self.llvm_usize.const_int(dst_ndarray.get_type().ndims(), false); let dst_shape = dst_ndarray.shape(); irrt::ndarray::call_nac3_ndarray_reshape_resolve_and_check_new_shape( generator, @@ -121,7 +119,6 @@ impl<'ctx> NDArrayValue<'ctx> { ctx: &mut CodeGenContext<'ctx, '_>, axes: Option>, ) -> Self { - assert!(self.ndims.is_some(), "NDArrayValue::transpose is only supported for instances with compile-time known ndims (self.ndims = Some(...))"); assert!( axes.is_none_or(|axes| axes.get_type().get_element_type() == self.llvm_usize.into()) ); @@ -130,7 +127,7 @@ impl<'ctx> NDArrayValue<'ctx> { let transposed_ndarray = self.get_type().construct_uninitialized(generator, ctx, None); let axes = if let Some(axes) = axes { - let num_axes = self.llvm_usize.const_int(self.ndims.unwrap(), false); + let num_axes = self.llvm_usize.const_int(self.ndims, false); // `axes = nullptr` if `axes` is unspecified. let axes = ArraySliceValue::from_ptr_val(axes, num_axes, None);