From b6ff75dcaff3b6792f16bcf75c812ca5418a15b0 Mon Sep 17 00:00:00 2001
From: David Mak <chmakac@connect.ust.hk>
Date: Mon, 27 May 2024 15:58:06 +0800
Subject: [PATCH] core/irrt: Add support for calculating partial size of
 NDArray

---
 nac3core/src/codegen/builtin_fns.rs |  4 ++--
 nac3core/src/codegen/classes.rs     |  2 +-
 nac3core/src/codegen/expr.rs        |  1 +
 nac3core/src/codegen/irrt/irrt.c    | 16 ++++++++++++----
 nac3core/src/codegen/irrt/mod.rs    | 12 ++++++++++--
 nac3core/src/codegen/numpy.rs       |  4 ++++
 6 files changed, 30 insertions(+), 9 deletions(-)

diff --git a/nac3core/src/codegen/builtin_fns.rs b/nac3core/src/codegen/builtin_fns.rs
index 1fbfd712..c35018f5 100644
--- a/nac3core/src/codegen/builtin_fns.rs
+++ b/nac3core/src/codegen/builtin_fns.rs
@@ -737,7 +737,7 @@ pub fn call_numpy_min<'ctx, G: CodeGenerator + ?Sized>(
             let llvm_ndarray_ty = ctx.get_llvm_type(generator, elem_ty);
 
             let n = NDArrayValue::from_ptr_val(n, llvm_usize, None);
-            let n_sz = irrt::call_ndarray_calc_size(generator, ctx, &n.dim_sizes());
+            let n_sz = irrt::call_ndarray_calc_size(generator, ctx, &n.dim_sizes(), (None, None));
             if ctx.registry.llvm_options.opt_level == OptimizationLevel::None {
                 let n_sz_eqz = ctx.builder
                     .build_int_compare(
@@ -955,7 +955,7 @@ pub fn call_numpy_max<'ctx, G: CodeGenerator + ?Sized>(
             let llvm_ndarray_ty = ctx.get_llvm_type(generator, elem_ty);
 
             let n = NDArrayValue::from_ptr_val(n, llvm_usize, None);
-            let n_sz = irrt::call_ndarray_calc_size(generator, ctx, &n.dim_sizes());
+            let n_sz = irrt::call_ndarray_calc_size(generator, ctx, &n.dim_sizes(), (None, None));
             if ctx.registry.llvm_options.opt_level == OptimizationLevel::None {
                 let n_sz_eqz = ctx.builder
                     .build_int_compare(
diff --git a/nac3core/src/codegen/classes.rs b/nac3core/src/codegen/classes.rs
index b3b6da43..6bbb230c 100644
--- a/nac3core/src/codegen/classes.rs
+++ b/nac3core/src/codegen/classes.rs
@@ -1122,7 +1122,7 @@ impl<'ctx> ArrayLikeValue<'ctx> for NDArrayDataProxy<'ctx, '_> {
         ctx: &CodeGenContext<'ctx, '_>,
         generator: &G,
     ) -> IntValue<'ctx> {
-        call_ndarray_calc_size(generator, ctx, &self.as_slice_value(ctx, generator))
+        call_ndarray_calc_size(generator, ctx, &self.as_slice_value(ctx, generator), (None, None))
     }
 }
 
diff --git a/nac3core/src/codegen/expr.rs b/nac3core/src/codegen/expr.rs
index 8812c802..6d8b4ac1 100644
--- a/nac3core/src/codegen/expr.rs
+++ b/nac3core/src/codegen/expr.rs
@@ -1819,6 +1819,7 @@ fn gen_ndarray_subscript_expr<'ctx, G: CodeGenerator>(
             generator,
             ctx,
             &ndarray.dim_sizes().as_slice_value(ctx, generator),
+            (None, None),
         );
         ndarray.create_data(ctx, llvm_ndarray_data_t, ndarray_num_elems);
 
diff --git a/nac3core/src/codegen/irrt/irrt.c b/nac3core/src/codegen/irrt/irrt.c
index 59c481f5..1436447b 100644
--- a/nac3core/src/codegen/irrt/irrt.c
+++ b/nac3core/src/codegen/irrt/irrt.c
@@ -202,10 +202,14 @@ double __nac3_j0(double x) {
 
 uint32_t __nac3_ndarray_calc_size(
     const uint64_t *list_data,
-    uint32_t list_len
+    uint32_t list_len,
+    uint32_t begin_idx,
+    uint32_t end_idx
 ) {
+    __builtin_assume(end_idx <= list_len);
+
     uint32_t num_elems = 1;
-    for (uint32_t i = 0; i < list_len; ++i) {
+    for (uint32_t i = begin_idx; i < end_idx; ++i) {
         uint64_t val = list_data[i];
         __builtin_assume(val > 0);
         num_elems *= val;
@@ -215,10 +219,14 @@ uint32_t __nac3_ndarray_calc_size(
 
 uint64_t __nac3_ndarray_calc_size64(
     const uint64_t *list_data,
-    uint64_t list_len
+    uint64_t list_len,
+    uint64_t begin_idx,
+    uint64_t end_idx
 ) {
+    __builtin_assume(end_idx <= list_len);
+
     uint64_t num_elems = 1;
-    for (uint64_t i = 0; i < list_len; ++i) {
+    for (uint64_t i = begin_idx; i < end_idx; ++i) {
         uint64_t val = list_data[i];
         __builtin_assume(val > 0);
        num_elems *= val;
diff --git a/nac3core/src/codegen/irrt/mod.rs b/nac3core/src/codegen/irrt/mod.rs
index 086cdb4f..fbf0edc5 100644
--- a/nac3core/src/codegen/irrt/mod.rs
+++ b/nac3core/src/codegen/irrt/mod.rs
@@ -583,12 +583,14 @@ pub fn call_j0<'ctx>(
 /// Generates a call to `__nac3_ndarray_calc_size`. Returns an [`IntValue`] representing the
 /// calculated total size.
 ///
-/// * `num_dims` - An [`IntValue`] containing the number of dimensions.
-/// * `dims` - A [`PointerValue`] to an array containing the size of each dimension.
+/// * `dims` - An [`ArrayLikeIndexer`] containing the size of each dimension.
+/// * `range` - The `(begin, end)` dimension indices to compute the size over, with `end`
+/// exclusive. [`None`] defaults `begin` to the first dimension and `end` to one past the last.
 pub fn call_ndarray_calc_size<'ctx, G, Dims>(
     generator: &G,
     ctx: &CodeGenContext<'ctx, '_>,
     dims: &Dims,
+    (begin, end): (Option<IntValue<'ctx>>, Option<IntValue<'ctx>>),
 ) -> IntValue<'ctx>
     where
         G: CodeGenerator + ?Sized,
@@ -607,6 +609,8 @@ pub fn call_ndarray_calc_size<'ctx, G, Dims>(
         &[
             llvm_pi64.into(),
             llvm_usize.into(),
+            llvm_usize.into(),
+            llvm_usize.into(),
         ],
         false,
     );
@@ -615,12 +619,16 @@ pub fn call_ndarray_calc_size<'ctx, G, Dims>(
             ctx.module.add_function(ndarray_calc_size_fn_name, ndarray_calc_size_fn_t, None)
         });
 
+    let begin = begin.unwrap_or_else(|| llvm_usize.const_zero());
+    let end = end.unwrap_or_else(|| dims.size(ctx, generator));
     ctx.builder
         .build_call(
             ndarray_calc_size_fn,
             &[
                 dims.base_ptr(ctx, generator).into(),
                 dims.size(ctx, generator).into(),
+                begin.into(),
+                end.into(),
             ],
             "",
         )
diff --git a/nac3core/src/codegen/numpy.rs b/nac3core/src/codegen/numpy.rs
index f22c721e..e44232fb 100644
--- a/nac3core/src/codegen/numpy.rs
+++ b/nac3core/src/codegen/numpy.rs
@@ -134,6 +134,7 @@ fn create_ndarray_dyn_shape<'ctx, 'a, G, V, LenFn, DataFn>(
         generator,
         ctx,
         &ndarray.dim_sizes().as_slice_value(ctx, generator),
+        (None, None),
     );
     ndarray.create_data(ctx, llvm_ndarray_data_t, ndarray_num_elems);
 
@@ -203,6 +204,7 @@ fn create_ndarray_const_shape<'ctx, G: CodeGenerator + ?Sized>(
         generator,
         ctx,
         &ndarray.dim_sizes().as_slice_value(ctx, generator),
+        (None, None),
     );
     ndarray.create_data(ctx, llvm_ndarray_data_t, ndarray_num_elems);
 
@@ -293,6 +295,7 @@ fn ndarray_fill_flattened<'ctx, 'a, G, ValueFn>(
         generator,
         ctx,
         &ndarray.dim_sizes().as_slice_value(ctx, generator),
+        (None, None),
     );
 
     gen_for_callback_incrementing(
@@ -661,6 +664,7 @@ fn ndarray_copy_impl<'ctx, G: CodeGenerator + ?Sized>(
         generator,
         ctx,
         &ndarray.dim_sizes().as_slice_value(ctx, generator),
+        (None, None),
     );
     let sizeof_ty = ctx.get_llvm_type(generator, elem_ty);
     let len_bytes = ctx.builder