From b6ff75dcaff3b6792f16bcf75c812ca5418a15b0 Mon Sep 17 00:00:00 2001 From: David Mak <chmakac@connect.ust.hk> Date: Mon, 27 May 2024 15:58:06 +0800 Subject: [PATCH] core/irrt: Add support for calculating partial size of NDArray --- nac3core/src/codegen/builtin_fns.rs | 4 ++-- nac3core/src/codegen/classes.rs | 2 +- nac3core/src/codegen/expr.rs | 1 + nac3core/src/codegen/irrt/irrt.c | 16 ++++++++++++---- nac3core/src/codegen/irrt/mod.rs | 12 ++++++++++-- nac3core/src/codegen/numpy.rs | 4 ++++ 6 files changed, 30 insertions(+), 9 deletions(-) diff --git a/nac3core/src/codegen/builtin_fns.rs b/nac3core/src/codegen/builtin_fns.rs index 1fbfd712..c35018f5 100644 --- a/nac3core/src/codegen/builtin_fns.rs +++ b/nac3core/src/codegen/builtin_fns.rs @@ -737,7 +737,7 @@ pub fn call_numpy_min<'ctx, G: CodeGenerator + ?Sized>( let llvm_ndarray_ty = ctx.get_llvm_type(generator, elem_ty); let n = NDArrayValue::from_ptr_val(n, llvm_usize, None); - let n_sz = irrt::call_ndarray_calc_size(generator, ctx, &n.dim_sizes()); + let n_sz = irrt::call_ndarray_calc_size(generator, ctx, &n.dim_sizes(), (None, None)); if ctx.registry.llvm_options.opt_level == OptimizationLevel::None { let n_sz_eqz = ctx.builder .build_int_compare( @@ -955,7 +955,7 @@ pub fn call_numpy_max<'ctx, G: CodeGenerator + ?Sized>( let llvm_ndarray_ty = ctx.get_llvm_type(generator, elem_ty); let n = NDArrayValue::from_ptr_val(n, llvm_usize, None); - let n_sz = irrt::call_ndarray_calc_size(generator, ctx, &n.dim_sizes()); + let n_sz = irrt::call_ndarray_calc_size(generator, ctx, &n.dim_sizes(), (None, None)); if ctx.registry.llvm_options.opt_level == OptimizationLevel::None { let n_sz_eqz = ctx.builder .build_int_compare( diff --git a/nac3core/src/codegen/classes.rs b/nac3core/src/codegen/classes.rs index b3b6da43..6bbb230c 100644 --- a/nac3core/src/codegen/classes.rs +++ b/nac3core/src/codegen/classes.rs @@ -1122,7 +1122,7 @@ impl<'ctx> ArrayLikeValue<'ctx> for NDArrayDataProxy<'ctx, '_> { ctx: &CodeGenContext<'ctx, '_>, generator: &G, ) -> IntValue<'ctx> { - call_ndarray_calc_size(generator, ctx, &self.as_slice_value(ctx, generator)) + call_ndarray_calc_size(generator, ctx, &self.as_slice_value(ctx, generator), (None, None)) } } diff --git a/nac3core/src/codegen/expr.rs b/nac3core/src/codegen/expr.rs index 8812c802..6d8b4ac1 100644 --- a/nac3core/src/codegen/expr.rs +++ b/nac3core/src/codegen/expr.rs @@ -1819,6 +1819,7 @@ fn gen_ndarray_subscript_expr<'ctx, G: CodeGenerator>( generator, ctx, &ndarray.dim_sizes().as_slice_value(ctx, generator), + (None, None), ); ndarray.create_data(ctx, llvm_ndarray_data_t, ndarray_num_elems); diff --git a/nac3core/src/codegen/irrt/irrt.c b/nac3core/src/codegen/irrt/irrt.c index 59c481f5..1436447b 100644 --- a/nac3core/src/codegen/irrt/irrt.c +++ b/nac3core/src/codegen/irrt/irrt.c @@ -202,10 +202,14 @@ double __nac3_j0(double x) { uint32_t __nac3_ndarray_calc_size( const uint64_t *list_data, - uint32_t list_len + uint32_t list_len, + uint32_t begin_idx, + uint32_t end_idx ) { + __builtin_assume(end_idx <= list_len); + uint32_t num_elems = 1; - for (uint32_t i = 0; i < list_len; ++i) { + for (uint32_t i = begin_idx; i < end_idx; ++i) { uint64_t val = list_data[i]; __builtin_assume(val > 0); num_elems *= val; @@ -215,10 +219,14 @@ uint32_t __nac3_ndarray_calc_size( uint64_t __nac3_ndarray_calc_size64( const uint64_t *list_data, - uint64_t list_len + uint64_t list_len, + uint64_t begin_idx, + uint64_t end_idx ) { + __builtin_assume(end_idx <= list_len); + uint64_t num_elems = 1; - for (uint64_t i = 0; i < list_len; ++i) { + for (uint64_t i = begin_idx; i < end_idx; ++i) { uint64_t val = list_data[i]; __builtin_assume(val > 0); num_elems *= val; diff --git a/nac3core/src/codegen/irrt/mod.rs b/nac3core/src/codegen/irrt/mod.rs index 086cdb4f..fbf0edc5 100644 --- a/nac3core/src/codegen/irrt/mod.rs +++ b/nac3core/src/codegen/irrt/mod.rs @@ -583,12 +583,14 @@ pub fn call_j0<'ctx>( /// Generates a call to `__nac3_ndarray_calc_size`. Returns an [`IntValue`] representing the /// calculated total size. /// -/// * `num_dims` - An [`IntValue`] containing the number of dimensions. -/// * `dims` - A [`PointerValue`] to an array containing the size of each dimension. +/// * `dims` - An [`ArrayLikeIndexer`] containing the size of each dimension. +/// * `range` - The dimension index to begin and end (exclusively) calculating the dimensions for, +/// or [`None`] if starting from the first dimension and ending at the last dimension respectively. pub fn call_ndarray_calc_size<'ctx, G, Dims>( generator: &G, ctx: &CodeGenContext<'ctx, '_>, dims: &Dims, + (begin, end): (Option<IntValue<'ctx>>, Option<IntValue<'ctx>>), ) -> IntValue<'ctx> where G: CodeGenerator + ?Sized, @@ -607,6 +609,8 @@ pub fn call_ndarray_calc_size<'ctx, G, Dims>( &[ llvm_pi64.into(), llvm_usize.into(), + llvm_usize.into(), + llvm_usize.into(), ], false, ); @@ -615,12 +619,16 @@ pub fn call_ndarray_calc_size<'ctx, G, Dims>( ctx.module.add_function(ndarray_calc_size_fn_name, ndarray_calc_size_fn_t, None) }); + let begin = begin.unwrap_or_else(|| llvm_usize.const_zero()); + let end = end.unwrap_or_else(|| dims.size(ctx, generator)); ctx.builder .build_call( ndarray_calc_size_fn, &[ dims.base_ptr(ctx, generator).into(), dims.size(ctx, generator).into(), + begin.into(), + end.into(), ], "", ) diff --git a/nac3core/src/codegen/numpy.rs b/nac3core/src/codegen/numpy.rs index f22c721e..e44232fb 100644 --- a/nac3core/src/codegen/numpy.rs +++ b/nac3core/src/codegen/numpy.rs @@ -134,6 +134,7 @@ fn create_ndarray_dyn_shape<'ctx, 'a, G, V, LenFn, DataFn>( generator, ctx, &ndarray.dim_sizes().as_slice_value(ctx, generator), + (None, None), ); ndarray.create_data(ctx, llvm_ndarray_data_t, ndarray_num_elems); @@ -203,6 +204,7 @@ fn create_ndarray_const_shape<'ctx, G: CodeGenerator + ?Sized>( generator, ctx, &ndarray.dim_sizes().as_slice_value(ctx, generator), + (None, None), ); ndarray.create_data(ctx, llvm_ndarray_data_t, ndarray_num_elems); @@ -293,6 +295,7 @@ fn ndarray_fill_flattened<'ctx, 'a, G, ValueFn>( generator, ctx, &ndarray.dim_sizes().as_slice_value(ctx, generator), + (None, None), ); gen_for_callback_incrementing( @@ -661,6 +664,7 @@ fn ndarray_copy_impl<'ctx, G: CodeGenerator + ?Sized>( generator, ctx, &ndarray.dim_sizes().as_slice_value(ctx, generator), + (None, None), ); let sizeof_ty = ctx.get_llvm_type(generator, elem_ty); let len_bytes = ctx.builder