From 90071be0a75c0e39a638ccbd45522982146c0f8f Mon Sep 17 00:00:00 2001 From: David Mak Date: Thu, 19 Dec 2024 12:21:08 +0800 Subject: [PATCH] [core] codegen/ndarray: Use IRRT for size() and indexing operations Also refactor some usages of call_ndarray_calc_size with ndarray.size(). --- nac3core/irrt/irrt/ndarray.hpp | 31 ------- nac3core/src/codegen/builtin_fns.rs | 14 +++- nac3core/src/codegen/irrt/ndarray/mod.rs | 76 +---------------- nac3core/src/codegen/numpy.rs | 98 +++++++++++++++++----- nac3core/src/codegen/values/ndarray/mod.rs | 75 +++-------------- 5 files changed, 102 insertions(+), 192 deletions(-) diff --git a/nac3core/irrt/irrt/ndarray.hpp b/nac3core/irrt/irrt/ndarray.hpp index 72ca0b9e..7fc9a63b 100644 --- a/nac3core/irrt/irrt/ndarray.hpp +++ b/nac3core/irrt/irrt/ndarray.hpp @@ -29,25 +29,6 @@ void __nac3_ndarray_calc_nd_indices_impl(SizeT index, const SizeT* dims, SizeT n } } -template -SizeT __nac3_ndarray_flatten_index_impl(const SizeT* dims, - SizeT num_dims, - const NDIndexInt* indices, - SizeT num_indices) { - SizeT idx = 0; - SizeT stride = 1; - for (SizeT i = 0; i < num_dims; ++i) { - SizeT ri = num_dims - i - 1; - if (ri < num_indices) { - idx += stride * indices[ri]; - } - - __builtin_assume(dims[i] > 0); - stride *= dims[ri]; - } - return idx; -} - template void __nac3_ndarray_calc_broadcast_impl(const SizeT* lhs_dims, SizeT lhs_ndims, @@ -107,18 +88,6 @@ void __nac3_ndarray_calc_nd_indices64(uint64_t index, const uint64_t* dims, uint __nac3_ndarray_calc_nd_indices_impl(index, dims, num_dims, idxs); } -uint32_t -__nac3_ndarray_flatten_index(const uint32_t* dims, uint32_t num_dims, const NDIndexInt* indices, uint32_t num_indices) { - return __nac3_ndarray_flatten_index_impl(dims, num_dims, indices, num_indices); -} - -uint64_t __nac3_ndarray_flatten_index64(const uint64_t* dims, - uint64_t num_dims, - const NDIndexInt* indices, - uint64_t num_indices) { - return __nac3_ndarray_flatten_index_impl(dims, num_dims, indices, num_indices); -} - void __nac3_ndarray_calc_broadcast(const uint32_t* lhs_dims, uint32_t lhs_ndims, const uint32_t* rhs_dims, diff --git a/nac3core/src/codegen/builtin_fns.rs b/nac3core/src/codegen/builtin_fns.rs index a41b9f55..b21e721e 100644 --- a/nac3core/src/codegen/builtin_fns.rs +++ b/nac3core/src/codegen/builtin_fns.rs @@ -877,8 +877,7 @@ pub fn call_numpy_max_min<'ctx, G: CodeGenerator + ?Sized>( let llvm_ndarray_ty = NDArrayType::from_unifier_type(generator, ctx, a_ty); let n = llvm_ndarray_ty.map_value(n, None); - let n_sz = - irrt::ndarray::call_ndarray_calc_size(generator, ctx, &n.shape(), (None, None)); + let n_sz = n.size(generator, ctx); if ctx.registry.llvm_options.opt_level == OptimizationLevel::None { let n_sz_eqz = ctx .builder @@ -913,7 +912,16 @@ pub fn call_numpy_max_min<'ctx, G: CodeGenerator + ?Sized>( llvm_int64.const_int(1, false), (n_sz, false), |generator, ctx, _, idx| { - let elem = unsafe { n.data().get_unchecked(ctx, generator, &idx, None) }; + let elem = unsafe { + n.data().get_unchecked( + ctx, + generator, + &ctx.builder + .build_int_truncate_or_bit_cast(idx, llvm_usize, "") + .unwrap(), + None, + ) + }; let accumulator = ctx.builder.build_load(accumulator_addr, "").unwrap(); let cur_idx = ctx.builder.build_load(res_idx, "").unwrap(); diff --git a/nac3core/src/codegen/irrt/ndarray/mod.rs b/nac3core/src/codegen/irrt/ndarray/mod.rs index b74ace0f..56017c94 100644 --- a/nac3core/src/codegen/irrt/ndarray/mod.rs +++ b/nac3core/src/codegen/irrt/ndarray/mod.rs @@ -1,5 +1,5 @@ use inkwell::{ - types::{BasicTypeEnum, IntType}, + types::BasicTypeEnum, values::{BasicValueEnum, CallSiteValue, IntValue}, AddressSpace, IntPredicate, }; @@ -138,78 +138,8 @@ pub fn call_ndarray_calc_nd_indices<'ctx, G: CodeGenerator + ?Sized>( ) } -/// Generates a call to `__nac3_ndarray_flatten_index`. Returns a `usize` of the flattened index for -/// the multidimensional index. -/// -/// * `ndarray` - LLVM pointer to the `NDArray`. This value must be the LLVM representation of an -/// `NDArray`. -/// * `indices` - The multidimensional index to compute the flattened index for. -pub fn call_ndarray_flatten_index<'ctx, G, Index>( - generator: &G, - ctx: &CodeGenContext<'ctx, '_>, - ndarray: NDArrayValue<'ctx>, - indices: &Index, -) -> IntValue<'ctx> -where - G: CodeGenerator + ?Sized, - Index: ArrayLikeIndexer<'ctx>, -{ - let llvm_i32 = ctx.ctx.i32_type(); - let llvm_usize = generator.get_size_type(ctx.ctx); - - let llvm_pi32 = llvm_i32.ptr_type(AddressSpace::default()); - let llvm_pusize = llvm_usize.ptr_type(AddressSpace::default()); - - debug_assert_eq!( - IntType::try_from(indices.element_type(ctx, generator)) - .map(IntType::get_bit_width) - .unwrap_or_default(), - llvm_i32.get_bit_width(), - "Expected i32 value for argument `indices` to `call_ndarray_flatten_index_impl`" - ); - debug_assert_eq!( - indices.size(ctx, generator).get_type().get_bit_width(), - llvm_usize.get_bit_width(), - "Expected usize integer value for argument `indices_size` to `call_ndarray_flatten_index_impl`" - ); - - let ndarray_flatten_index_fn_name = - get_usize_dependent_function_name(generator, ctx, "__nac3_ndarray_flatten_index"); - let ndarray_flatten_index_fn = - ctx.module.get_function(&ndarray_flatten_index_fn_name).unwrap_or_else(|| { - let fn_type = llvm_usize.fn_type( - &[llvm_pusize.into(), llvm_usize.into(), llvm_pi32.into(), llvm_usize.into()], - false, - ); - - ctx.module.add_function(&ndarray_flatten_index_fn_name, fn_type, None) - }); - - let ndarray_num_dims = ndarray.load_ndims(ctx); - let ndarray_dims = ndarray.shape(); - - let index = ctx - .builder - .build_call( - ndarray_flatten_index_fn, - &[ - ndarray_dims.base_ptr(ctx, generator).into(), - ndarray_num_dims.into(), - indices.base_ptr(ctx, generator).into(), - indices.size(ctx, generator).into(), - ], - "", - ) - .map(CallSiteValue::try_as_basic_value) - .map(|v| v.map_left(BasicValueEnum::into_int_value)) - .map(Either::unwrap_left) - .unwrap(); - - index -} - -/// Generates a call to `__nac3_ndarray_calc_broadcast`. Returns a [`TypedArrayLikeAdapter`] -/// containing the size of each dimension of the resultant `ndarray`. +/// Generates a call to `__nac3_ndarray_calc_broadcast`. Returns a tuple containing the number of +/// dimension and size of each dimension of the resultant `ndarray`. pub fn call_ndarray_calc_broadcast<'ctx, G: CodeGenerator + ?Sized>( generator: &mut G, ctx: &mut CodeGenContext<'ctx, '_>, diff --git a/nac3core/src/codegen/numpy.rs b/nac3core/src/codegen/numpy.rs index 30a33f08..9328bb83 100644 --- a/nac3core/src/codegen/numpy.rs +++ b/nac3core/src/codegen/numpy.rs @@ -21,8 +21,8 @@ use super::{ stmt::{gen_for_callback_incrementing, gen_for_range_callback, gen_if_else_expr_callback}, types::{ndarray::NDArrayType, ListType, ProxyType}, values::{ - ndarray::NDArrayValue, ArrayLikeIndexer, ArrayLikeValue, ListValue, ProxyValue, - TypedArrayLikeAccessor, TypedArrayLikeAdapter, TypedArrayLikeMutator, + ndarray::NDArrayValue, ArrayLikeIndexer, ArrayLikeValue, ArraySliceValue, ListValue, + ProxyValue, TypedArrayLikeAccessor, TypedArrayLikeAdapter, TypedArrayLikeMutator, UntypedArrayLikeAccessor, UntypedArrayLikeMutator, }, CodeGenContext, CodeGenerator, @@ -318,12 +318,7 @@ where { let llvm_usize = generator.get_size_type(ctx.ctx); - let ndarray_num_elems = call_ndarray_calc_size( - generator, - ctx, - &ndarray.shape().as_slice_value(ctx, generator), - (None, None), - ); + let ndarray_num_elems = ndarray.size(generator, ctx); gen_for_callback_incrementing( generator, @@ -434,6 +429,66 @@ where rhs_val.get_type() ); + // Returns the element of an ndarray indexed by the given indices, performing int-promotion on + // `indices` where necessary. + // + // Required for compatibility with `NDArrayType::get_unchecked`. + let get_data_by_indices_compat = + |generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + ndarray: NDArrayValue<'ctx>, + indices: TypedArrayLikeAdapter<'ctx, G, IntValue<'ctx>>| { + let llvm_usize = generator.get_size_type(ctx.ctx); + + // Workaround: Promote lhs_idx to usize* to make the array compatible with new IRRT + let stackptr = llvm_intrinsics::call_stacksave(ctx, None); + let indices = if llvm_usize == ctx.ctx.i32_type() { + indices + } else { + let indices_usize = TypedArrayLikeAdapter::>::from( + ArraySliceValue::from_ptr_val( + ctx.builder + .build_array_alloca(llvm_usize, indices.size(ctx, generator), "") + .unwrap(), + indices.size(ctx, generator), + None, + ), + |_, _, val| val.into_int_value(), + |_, _, val| val.into(), + ); + + gen_for_callback_incrementing( + generator, + ctx, + None, + llvm_usize.const_zero(), + (indices.size(ctx, generator), false), + |generator, ctx, _, i| { + let idx = unsafe { indices.get_typed_unchecked(ctx, generator, &i, None) }; + let idx = ctx + .builder + .build_int_z_extend_or_bit_cast(idx, llvm_usize, "") + .unwrap(); + unsafe { + indices_usize.set_typed_unchecked(ctx, generator, &i, idx); + } + + Ok(()) + }, + llvm_usize.const_int(1, false), + ) + .unwrap(); + + indices_usize + }; + + let elem = unsafe { ndarray.data().get_unchecked(ctx, generator, &indices, None) }; + + llvm_intrinsics::call_stackrestore(ctx, stackptr); + + elem + }; + // Assert that all ndarray operands are broadcastable to the target size if !lhs_scalar { let lhs_val = NDArrayType::from_unifier_type(generator, ctx, lhs_ty) @@ -455,7 +510,7 @@ where .map_value(lhs_val.into_pointer_value(), None); let lhs_idx = call_ndarray_calc_broadcast_index(generator, ctx, lhs, idx); - unsafe { lhs.data().get_unchecked(ctx, generator, &lhs_idx, None) } + get_data_by_indices_compat(generator, ctx, lhs, lhs_idx) }; let rhs_elem = if rhs_scalar { @@ -465,7 +520,7 @@ where .map_value(rhs_val.into_pointer_value(), None); let rhs_idx = call_ndarray_calc_broadcast_index(generator, ctx, rhs, idx); - unsafe { rhs.data().get_unchecked(ctx, generator, &rhs_idx, None) } + get_data_by_indices_compat(generator, ctx, rhs, rhs_idx) }; value_fn(generator, ctx, (lhs_elem, rhs_elem)) @@ -1408,7 +1463,6 @@ pub fn ndarray_matmul_2d<'ctx, G: CodeGenerator>( lhs: NDArrayValue<'ctx>, rhs: NDArrayValue<'ctx>, ) -> Result, String> { - let llvm_i32 = ctx.ctx.i32_type(); let llvm_usize = generator.get_size_type(ctx.ctx); if cfg!(debug_assertions) { @@ -1597,19 +1651,19 @@ pub fn ndarray_matmul_2d<'ctx, G: CodeGenerator>( let idx = llvm_intrinsics::call_expect(ctx, rhs_idx0, lhs_idx1, None); - ctx.builder.build_int_truncate(idx, llvm_i32, "").unwrap() + ctx.builder.build_int_z_extend_or_bit_cast(idx, llvm_usize, "").unwrap() }; let idx0 = unsafe { let idx0 = idx.get_typed_unchecked(ctx, generator, &llvm_usize.const_zero(), None); - ctx.builder.build_int_truncate(idx0, llvm_i32, "").unwrap() + ctx.builder.build_int_z_extend_or_bit_cast(idx0, llvm_usize, "").unwrap() }; let idx1 = unsafe { let idx1 = idx.get_typed_unchecked(ctx, generator, &llvm_usize.const_int(1, false), None); - ctx.builder.build_int_truncate(idx1, llvm_i32, "").unwrap() + ctx.builder.build_int_z_extend_or_bit_cast(idx1, llvm_usize, "").unwrap() }; let result_addr = generator.gen_var_alloc(ctx, llvm_ndarray_ty, None)?; @@ -1620,14 +1674,12 @@ pub fn ndarray_matmul_2d<'ctx, G: CodeGenerator>( generator, ctx, None, - llvm_i32.const_zero(), + llvm_usize.const_zero(), (common_dim, false), |generator, ctx, _, i| { - let i = ctx.builder.build_int_truncate(i, llvm_i32, "").unwrap(); - let ab_idx = generator.gen_array_var_alloc( ctx, - llvm_i32.into(), + llvm_usize.into(), llvm_usize.const_int(2, false), None, )?; @@ -2002,7 +2054,7 @@ pub fn ndarray_transpose<'ctx, G: CodeGenerator + ?Sized>( let llvm_ndarray_ty = NDArrayType::from_unifier_type(generator, ctx, x1_ty); let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty); let n1 = llvm_ndarray_ty.map_value(n1, None); - let n_sz = call_ndarray_calc_size(generator, ctx, &n1.shape(), (None, None)); + let n_sz = n1.size(generator, ctx); // Dimensions are reversed in the transposed array let out = create_ndarray_dyn_shape( @@ -2122,7 +2174,7 @@ pub fn ndarray_reshape<'ctx, G: CodeGenerator + ?Sized>( let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty); let llvm_ndarray_ty = NDArrayType::from_unifier_type(generator, ctx, x1_ty); let n1 = llvm_ndarray_ty.map_value(n1, None); - let n_sz = call_ndarray_calc_size(generator, ctx, &n1.shape(), (None, None)); + let n_sz = n1.size(generator, ctx); let acc = generator.gen_var_alloc(ctx, llvm_usize.into(), None)?; let num_neg = generator.gen_var_alloc(ctx, llvm_usize.into(), None)?; @@ -2350,7 +2402,7 @@ pub fn ndarray_reshape<'ctx, G: CodeGenerator + ?Sized>( ); // The new shape must be compatible with the old shape - let out_sz = call_ndarray_calc_size(generator, ctx, &out.shape(), (None, None)); + let out_sz = out.size(generator, ctx); ctx.make_assert( generator, ctx.builder.build_int_compare(IntPredicate::EQ, out_sz, n_sz, "").unwrap(), @@ -2407,8 +2459,8 @@ pub fn ndarray_dot<'ctx, G: CodeGenerator + ?Sized>( let n1 = NDArrayType::from_unifier_type(generator, ctx, x1_ty).map_value(n1, None); let n2 = NDArrayType::from_unifier_type(generator, ctx, x2_ty).map_value(n2, None); - let n1_sz = call_ndarray_calc_size(generator, ctx, &n1.shape(), (None, None)); - let n2_sz = call_ndarray_calc_size(generator, ctx, &n1.shape(), (None, None)); + let n1_sz = n1.size(generator, ctx); + let n2_sz = n2.size(generator, ctx); ctx.make_assert( generator, diff --git a/nac3core/src/codegen/values/ndarray/mod.rs b/nac3core/src/codegen/values/ndarray/mod.rs index 4c5be432..5140ae36 100644 --- a/nac3core/src/codegen/values/ndarray/mod.rs +++ b/nac3core/src/codegen/values/ndarray/mod.rs @@ -671,12 +671,7 @@ impl<'ctx> ArrayLikeValue<'ctx> for NDArrayDataProxy<'ctx, '_> { ctx: &CodeGenContext<'ctx, '_>, generator: &G, ) -> IntValue<'ctx> { - irrt::ndarray::call_ndarray_calc_size( - generator, - ctx, - &self.as_slice_value(ctx, generator), - (None, None), - ) + irrt::ndarray::call_nac3_ndarray_len(generator, ctx, *self.0) } } @@ -688,24 +683,7 @@ impl<'ctx> ArrayLikeIndexer<'ctx> for NDArrayDataProxy<'ctx, '_> { idx: &IntValue<'ctx>, name: Option<&str>, ) -> PointerValue<'ctx> { - let sizeof_elem = ctx - .builder - .build_int_truncate_or_bit_cast( - self.element_type(ctx, generator).size_of().unwrap(), - idx.get_type(), - "", - ) - .unwrap(); - let idx = ctx.builder.build_int_mul(*idx, sizeof_elem, "").unwrap(); - let ptr = unsafe { - ctx.builder - .build_in_bounds_gep( - self.base_ptr(ctx, generator), - &[idx], - name.unwrap_or_default(), - ) - .unwrap() - }; + let ptr = irrt::ndarray::call_nac3_ndarray_get_nth_pelement(generator, ctx, *self.0, *idx); // Current implementation is transparent - The returned pointer type is // already cast into the expected type, allowing for immediately @@ -716,7 +694,7 @@ impl<'ctx> ArrayLikeIndexer<'ctx> for NDArrayDataProxy<'ctx, '_> { BasicTypeEnum::try_from(self.element_type(ctx, generator)) .unwrap() .ptr_type(AddressSpace::default()), - "", + name.unwrap_or_default(), ) .unwrap() } @@ -769,52 +747,25 @@ impl<'ctx, Index: UntypedArrayLikeAccessor<'ctx>> ArrayLikeIndexer<'ctx, Index> indices: &Index, name: Option<&str>, ) -> PointerValue<'ctx> { - let llvm_usize = generator.get_size_type(ctx.ctx); + assert_eq!(indices.element_type(ctx, generator), generator.get_size_type(ctx.ctx).into()); - let indices_elem_ty = unsafe { - indices - .ptr_offset_unchecked(ctx, generator, &llvm_usize.const_zero(), None) - .get_type() - .get_element_type() - }; - let Ok(indices_elem_ty) = IntType::try_from(indices_elem_ty) else { - panic!("Expected list[int32] but got {indices_elem_ty}") - }; - assert_eq!( - indices_elem_ty.get_bit_width(), - 32, - "Expected list[int32] but got list[int{}]", - indices_elem_ty.get_bit_width() + let ptr = irrt::ndarray::call_nac3_ndarray_get_pelement_by_indices( + generator, + ctx, + *self.0, + indices.base_ptr(ctx, generator), ); - let index = irrt::ndarray::call_ndarray_flatten_index(generator, ctx, *self.0, indices); - let sizeof_elem = ctx - .builder - .build_int_truncate_or_bit_cast( - self.element_type(ctx, generator).size_of().unwrap(), - index.get_type(), - "", - ) - .unwrap(); - let index = ctx.builder.build_int_mul(index, sizeof_elem, "").unwrap(); - - let ptr = unsafe { - ctx.builder - .build_in_bounds_gep( - self.base_ptr(ctx, generator), - &[index], - name.unwrap_or_default(), - ) - .unwrap() - }; - // TODO: Current implementation is transparent + // Current implementation is transparent - The returned pointer type is + // already cast into the expected type, allowing for immediately + // load/store. ctx.builder .build_pointer_cast( ptr, BasicTypeEnum::try_from(self.element_type(ctx, generator)) .unwrap() .ptr_type(AddressSpace::default()), - "", + name.unwrap_or_default(), ) .unwrap() }