From d5e8df070adc1c91595881aacc382b79ab887310 Mon Sep 17 00:00:00 2001 From: David Mak Date: Tue, 17 Dec 2024 16:10:00 +0800 Subject: [PATCH] [core] Minor improvements to IRRT and add missing documentation --- nac3core/src/codegen/generator.rs | 1 + nac3core/src/codegen/irrt/list.rs | 52 ++++++---- nac3core/src/codegen/irrt/math.rs | 20 +++- nac3core/src/codegen/irrt/mod.rs | 11 ++- nac3core/src/codegen/irrt/ndarray/basic.rs | 88 ++++++++++++++--- nac3core/src/codegen/irrt/ndarray/indexing.rs | 5 + nac3core/src/codegen/irrt/ndarray/iter.rs | 20 +++- nac3core/src/codegen/irrt/ndarray/mod.rs | 99 ++++++++----------- nac3core/src/codegen/irrt/range.rs | 18 +++- nac3core/src/codegen/types/ndarray/nditer.rs | 19 ++-- nac3core/src/codegen/values/array.rs | 8 ++ nac3core/src/codegen/values/ndarray/mod.rs | 4 + nac3core/src/codegen/values/ndarray/nditer.rs | 4 + 13 files changed, 238 insertions(+), 111 deletions(-) diff --git a/nac3core/src/codegen/generator.rs b/nac3core/src/codegen/generator.rs index f277ec9..be007c2 100644 --- a/nac3core/src/codegen/generator.rs +++ b/nac3core/src/codegen/generator.rs @@ -17,6 +17,7 @@ pub trait CodeGenerator { /// Return the module name for the code generator. fn get_name(&self) -> &str; + /// Return an instance of [`IntType`] corresponding to the type of `size_t` for this instance. fn get_size_type<'ctx>(&self, ctx: &'ctx Context) -> IntType<'ctx>; /// Generate function call and returns the function return value. 
diff --git a/nac3core/src/codegen/irrt/list.rs b/nac3core/src/codegen/irrt/list.rs index a7fec59..2c57f8e 100644 --- a/nac3core/src/codegen/irrt/list.rs +++ b/nac3core/src/codegen/irrt/list.rs @@ -24,42 +24,52 @@ pub fn list_slice_assignment<'ctx, G: CodeGenerator + ?Sized>( src_arr: ListValue<'ctx>, src_idx: (IntValue<'ctx>, IntValue<'ctx>, IntValue<'ctx>), ) { - let size_ty = generator.get_size_type(ctx.ctx); - let int8_ptr = ctx.ctx.i8_type().ptr_type(AddressSpace::default()); - let int32 = ctx.ctx.i32_type(); - let (fun_symbol, elem_ptr_type) = ("__nac3_list_slice_assign_var_size", int8_ptr); + let llvm_usize = generator.get_size_type(ctx.ctx); + let llvm_pi8 = ctx.ctx.i8_type().ptr_type(AddressSpace::default()); + let llvm_i32 = ctx.ctx.i32_type(); + + assert_eq!(dest_idx.0.get_type(), llvm_i32); + assert_eq!(dest_idx.1.get_type(), llvm_i32); + assert_eq!(dest_idx.2.get_type(), llvm_i32); + assert_eq!(src_idx.0.get_type(), llvm_i32); + assert_eq!(src_idx.1.get_type(), llvm_i32); + assert_eq!(src_idx.2.get_type(), llvm_i32); + + let (fun_symbol, elem_ptr_type) = ("__nac3_list_slice_assign_var_size", llvm_pi8); let slice_assign_fun = { let ty_vec = vec![ - int32.into(), // dest start idx - int32.into(), // dest end idx - int32.into(), // dest step + llvm_i32.into(), // dest start idx + llvm_i32.into(), // dest end idx + llvm_i32.into(), // dest step elem_ptr_type.into(), // dest arr ptr - int32.into(), // dest arr len - int32.into(), // src start idx - int32.into(), // src end idx - int32.into(), // src step + llvm_i32.into(), // dest arr len + llvm_i32.into(), // src start idx + llvm_i32.into(), // src end idx + llvm_i32.into(), // src step elem_ptr_type.into(), // src arr ptr - int32.into(), // src arr len - int32.into(), // size + llvm_i32.into(), // src arr len + llvm_i32.into(), // size ]; ctx.module.get_function(fun_symbol).unwrap_or_else(|| { - let fn_t = int32.fn_type(ty_vec.as_slice(), false); + let fn_t = llvm_i32.fn_type(ty_vec.as_slice(), false); 
ctx.module.add_function(fun_symbol, fn_t, None) }) }; - let zero = int32.const_zero(); - let one = int32.const_int(1, false); + let zero = llvm_i32.const_zero(); + let one = llvm_i32.const_int(1, false); let dest_arr_ptr = dest_arr.data().base_ptr(ctx, generator); let dest_arr_ptr = ctx.builder.build_pointer_cast(dest_arr_ptr, elem_ptr_type, "dest_arr_ptr_cast").unwrap(); let dest_len = dest_arr.load_size(ctx, Some("dest.len")); - let dest_len = ctx.builder.build_int_truncate_or_bit_cast(dest_len, int32, "srclen32").unwrap(); + let dest_len = + ctx.builder.build_int_truncate_or_bit_cast(dest_len, llvm_i32, "srclen32").unwrap(); let src_arr_ptr = src_arr.data().base_ptr(ctx, generator); let src_arr_ptr = ctx.builder.build_pointer_cast(src_arr_ptr, elem_ptr_type, "src_arr_ptr_cast").unwrap(); let src_len = src_arr.load_size(ctx, Some("src.len")); - let src_len = ctx.builder.build_int_truncate_or_bit_cast(src_len, int32, "srclen32").unwrap(); + let src_len = + ctx.builder.build_int_truncate_or_bit_cast(src_len, llvm_i32, "srclen32").unwrap(); // index in bound and positive should be done // assert if dest.step == 1 then len(src) <= len(dest) else len(src) == len(dest), and @@ -136,7 +146,7 @@ pub fn list_slice_assignment<'ctx, G: CodeGenerator + ?Sized>( BasicTypeEnum::StructType(t) => t.size_of().unwrap(), _ => codegen_unreachable!(ctx), }; - ctx.builder.build_int_truncate_or_bit_cast(s, int32, "size").unwrap() + ctx.builder.build_int_truncate_or_bit_cast(s, llvm_i32, "size").unwrap() } .into(), ]; @@ -147,6 +157,7 @@ pub fn list_slice_assignment<'ctx, G: CodeGenerator + ?Sized>( .map(Either::unwrap_left) .unwrap() }; + // update length let need_update = ctx.builder.build_int_compare(IntPredicate::NE, new_len, dest_len, "need_update").unwrap(); @@ -155,7 +166,8 @@ pub fn list_slice_assignment<'ctx, G: CodeGenerator + ?Sized>( let cont_bb = ctx.ctx.append_basic_block(current, "cont"); ctx.builder.build_conditional_branch(need_update, update_bb, cont_bb).unwrap(); 
ctx.builder.position_at_end(update_bb); - let new_len = ctx.builder.build_int_z_extend_or_bit_cast(new_len, size_ty, "new_len").unwrap(); + let new_len = + ctx.builder.build_int_z_extend_or_bit_cast(new_len, llvm_usize, "new_len").unwrap(); dest_arr.store_size(ctx, generator, new_len); ctx.builder.build_unconditional_branch(cont_bb).unwrap(); ctx.builder.position_at_end(cont_bb); diff --git a/nac3core/src/codegen/irrt/math.rs b/nac3core/src/codegen/irrt/math.rs index 4bc9591..33445b2 100644 --- a/nac3core/src/codegen/irrt/math.rs +++ b/nac3core/src/codegen/irrt/math.rs @@ -62,8 +62,13 @@ pub fn call_isinf<'ctx, G: CodeGenerator + ?Sized>( ctx: &CodeGenContext<'ctx, '_>, v: FloatValue<'ctx>, ) -> IntValue<'ctx> { + let llvm_i32 = ctx.ctx.i32_type(); + let llvm_f64 = ctx.ctx.f64_type(); + + assert_eq!(v.get_type(), llvm_f64); + let intrinsic_fn = ctx.module.get_function("__nac3_isinf").unwrap_or_else(|| { - let fn_type = ctx.ctx.i32_type().fn_type(&[ctx.ctx.f64_type().into()], false); + let fn_type = llvm_i32.fn_type(&[llvm_f64.into()], false); ctx.module.add_function("__nac3_isinf", fn_type, None) }); @@ -84,8 +89,13 @@ pub fn call_isnan<'ctx, G: CodeGenerator + ?Sized>( ctx: &CodeGenContext<'ctx, '_>, v: FloatValue<'ctx>, ) -> IntValue<'ctx> { + let llvm_i32 = ctx.ctx.i32_type(); + let llvm_f64 = ctx.ctx.f64_type(); + + assert_eq!(v.get_type(), llvm_f64); + let intrinsic_fn = ctx.module.get_function("__nac3_isnan").unwrap_or_else(|| { - let fn_type = ctx.ctx.i32_type().fn_type(&[ctx.ctx.f64_type().into()], false); + let fn_type = llvm_i32.fn_type(&[llvm_f64.into()], false); ctx.module.add_function("__nac3_isnan", fn_type, None) }); @@ -104,6 +114,8 @@ pub fn call_isnan<'ctx, G: CodeGenerator + ?Sized>( pub fn call_gamma<'ctx>(ctx: &CodeGenContext<'ctx, '_>, v: FloatValue<'ctx>) -> FloatValue<'ctx> { let llvm_f64 = ctx.ctx.f64_type(); + assert_eq!(v.get_type(), llvm_f64); + let intrinsic_fn = ctx.module.get_function("__nac3_gamma").unwrap_or_else(|| { let fn_type = 
llvm_f64.fn_type(&[llvm_f64.into()], false); ctx.module.add_function("__nac3_gamma", fn_type, None) @@ -121,6 +133,8 @@ pub fn call_gamma<'ctx>(ctx: &CodeGenContext<'ctx, '_>, v: FloatValue<'ctx>) -> pub fn call_gammaln<'ctx>(ctx: &CodeGenContext<'ctx, '_>, v: FloatValue<'ctx>) -> FloatValue<'ctx> { let llvm_f64 = ctx.ctx.f64_type(); + assert_eq!(v.get_type(), llvm_f64); + let intrinsic_fn = ctx.module.get_function("__nac3_gammaln").unwrap_or_else(|| { let fn_type = llvm_f64.fn_type(&[llvm_f64.into()], false); ctx.module.add_function("__nac3_gammaln", fn_type, None) @@ -138,6 +152,8 @@ pub fn call_gammaln<'ctx>(ctx: &CodeGenContext<'ctx, '_>, v: FloatValue<'ctx>) - pub fn call_j0<'ctx>(ctx: &CodeGenContext<'ctx, '_>, v: FloatValue<'ctx>) -> FloatValue<'ctx> { let llvm_f64 = ctx.ctx.f64_type(); + assert_eq!(v.get_type(), llvm_f64); + let intrinsic_fn = ctx.module.get_function("__nac3_j0").unwrap_or_else(|| { let fn_type = llvm_f64.fn_type(&[llvm_f64.into()], false); ctx.module.add_function("__nac3_j0", fn_type, None) diff --git a/nac3core/src/codegen/irrt/mod.rs b/nac3core/src/codegen/irrt/mod.rs index 21a16bd..4cacdcc 100644 --- a/nac3core/src/codegen/irrt/mod.rs +++ b/nac3core/src/codegen/irrt/mod.rs @@ -132,10 +132,11 @@ pub fn handle_slice_indices<'ctx, G: CodeGenerator>( generator: &mut G, length: IntValue<'ctx>, ) -> Result<Option<(IntValue<'ctx>, IntValue<'ctx>, IntValue<'ctx>)>, String> { - let int32 = ctx.ctx.i32_type(); - let zero = int32.const_zero(); - let one = int32.const_int(1, false); - let length = ctx.builder.build_int_truncate_or_bit_cast(length, int32, "leni32").unwrap(); + let llvm_i32 = ctx.ctx.i32_type(); + + let zero = llvm_i32.const_zero(); + let one = llvm_i32.const_int(1, false); + let length = ctx.builder.build_int_truncate_or_bit_cast(length, llvm_i32, "leni32").unwrap(); Ok(Some(match (start, end, step) { (s, e, None) => ( if let Some(s) = s.as_ref() { @@ -144,7 +145,7 @@ pub fn handle_slice_indices<'ctx, G: CodeGenerator>( None => return Ok(None), } } else { 
- int32.const_zero() + llvm_i32.const_zero() }, { let e = if let Some(s) = e.as_ref() { diff --git a/nac3core/src/codegen/irrt/ndarray/basic.rs b/nac3core/src/codegen/irrt/ndarray/basic.rs index 0daea1c..d11c9b8 100644 --- a/nac3core/src/codegen/irrt/ndarray/basic.rs +++ b/nac3core/src/codegen/irrt/ndarray/basic.rs @@ -1,4 +1,5 @@ use inkwell::{ + types::BasicTypeEnum, values::{BasicValueEnum, IntValue, PointerValue}, AddressSpace, }; @@ -7,19 +8,26 @@ use crate::codegen::{ expr::{create_and_call_function, infer_and_call_function}, irrt::get_usize_dependent_function_name, types::ProxyType, - values::{ndarray::NDArrayValue, ProxyValue}, + values::{ndarray::NDArrayValue, ProxyValue, TypedArrayLikeAccessor}, CodeGenContext, CodeGenerator, }; +/// Generates a call to `__nac3_ndarray_util_assert_shape_no_negative`. +/// +/// Asserts that `shape` does not contain negative dimensions. pub fn call_nac3_ndarray_util_assert_shape_no_negative<'ctx, G: CodeGenerator + ?Sized>( generator: &G, ctx: &CodeGenContext<'ctx, '_>, - ndims: IntValue<'ctx>, - shape: PointerValue<'ctx>, + shape: &impl TypedArrayLikeAccessor<'ctx, G, IntValue<'ctx>>, ) { let llvm_usize = generator.get_size_type(ctx.ctx); let llvm_pusize = llvm_usize.ptr_type(AddressSpace::default()); + assert_eq!( + BasicTypeEnum::try_from(shape.element_type(ctx, generator)).unwrap(), + llvm_usize.into() + ); + let name = get_usize_dependent_function_name( generator, ctx, @@ -30,23 +38,37 @@ pub fn call_nac3_ndarray_util_assert_shape_no_negative<'ctx, G: CodeGenerator + ctx, &name, Some(llvm_usize.into()), - &[(llvm_usize.into(), ndims.into()), (llvm_pusize.into(), shape.into())], + &[ + (llvm_usize.into(), shape.size(ctx, generator).into()), + (llvm_pusize.into(), shape.base_ptr(ctx, generator).into()), + ], None, None, ); } +/// Generates a call to `__nac3_ndarray_util_assert_output_shape_same`. 
+/// +/// Asserts that `ndarray_shape` and `output_shape` are the same in the context of writing output to +/// an `ndarray`. pub fn call_nac3_ndarray_util_assert_output_shape_same<'ctx, G: CodeGenerator + ?Sized>( generator: &G, ctx: &CodeGenContext<'ctx, '_>, - ndarray_ndims: IntValue<'ctx>, - ndarray_shape: PointerValue<'ctx>, - output_ndims: IntValue<'ctx>, - output_shape: IntValue<'ctx>, + ndarray_shape: &impl TypedArrayLikeAccessor<'ctx, G, IntValue<'ctx>>, + output_shape: &impl TypedArrayLikeAccessor<'ctx, G, IntValue<'ctx>>, ) { let llvm_usize = generator.get_size_type(ctx.ctx); let llvm_pusize = llvm_usize.ptr_type(AddressSpace::default()); + assert_eq!( + BasicTypeEnum::try_from(ndarray_shape.element_type(ctx, generator)).unwrap(), + llvm_usize.into() + ); + assert_eq!( + BasicTypeEnum::try_from(output_shape.element_type(ctx, generator)).unwrap(), + llvm_usize.into() + ); + let name = get_usize_dependent_function_name( generator, ctx, @@ -58,16 +80,20 @@ pub fn call_nac3_ndarray_util_assert_output_shape_same<'ctx, G: CodeGenerator + &name, Some(llvm_usize.into()), &[ - (llvm_usize.into(), ndarray_ndims.into()), - (llvm_pusize.into(), ndarray_shape.into()), - (llvm_usize.into(), output_ndims.into()), - (llvm_pusize.into(), output_shape.into()), + (llvm_usize.into(), ndarray_shape.size(ctx, generator).into()), + (llvm_pusize.into(), ndarray_shape.base_ptr(ctx, generator).into()), + (llvm_usize.into(), output_shape.size(ctx, generator).into()), + (llvm_pusize.into(), output_shape.base_ptr(ctx, generator).into()), ], None, None, ); } +/// Generates a call to `__nac3_ndarray_size`. +/// +/// Returns a [`usize`][CodeGenerator::get_size_type] value of the number of elements of an +/// `ndarray`, corresponding to the value of `ndarray.size`. 
pub fn call_nac3_ndarray_size<'ctx, G: CodeGenerator + ?Sized>( generator: &G, ctx: &CodeGenContext<'ctx, '_>, @@ -90,6 +116,10 @@ pub fn call_nac3_ndarray_size<'ctx, G: CodeGenerator + ?Sized>( .unwrap() } +/// Generates a call to `__nac3_ndarray_nbytes`. +/// +/// Returns a [`usize`][CodeGenerator::get_size_type] value of the number of bytes consumed by the +/// data of the `ndarray`, corresponding to the value of `ndarray.nbytes`. pub fn call_nac3_ndarray_nbytes<'ctx, G: CodeGenerator + ?Sized>( generator: &G, ctx: &CodeGenContext<'ctx, '_>, @@ -112,6 +142,10 @@ pub fn call_nac3_ndarray_nbytes<'ctx, G: CodeGenerator + ?Sized>( .unwrap() } +/// Generates a call to `__nac3_ndarray_len`. +/// +/// Returns a [`usize`][CodeGenerator::get_size_type] value of the size of the topmost dimension of +/// the `ndarray`, corresponding to the value of `ndarray.__len__`. pub fn call_nac3_ndarray_len<'ctx, G: CodeGenerator + ?Sized>( generator: &G, ctx: &CodeGenContext<'ctx, '_>, @@ -134,6 +168,9 @@ pub fn call_nac3_ndarray_len<'ctx, G: CodeGenerator + ?Sized>( .unwrap() } +/// Generates a call to `__nac3_ndarray_is_c_contiguous`. +/// +/// Returns an `i1` value indicating whether the `ndarray` is C-contiguous. pub fn call_nac3_ndarray_is_c_contiguous<'ctx, G: CodeGenerator + ?Sized>( generator: &G, ctx: &CodeGenContext<'ctx, '_>, @@ -156,6 +193,9 @@ pub fn call_nac3_ndarray_is_c_contiguous<'ctx, G: CodeGenerator + ?Sized>( .unwrap() } +/// Generates a call to `__nac3_ndarray_get_nth_pelement`. +/// +/// Returns a [`PointerValue`] to the `index`-th flattened element of the `ndarray`. 
pub fn call_nac3_ndarray_get_nth_pelement<'ctx, G: CodeGenerator + ?Sized>( generator: &G, ctx: &CodeGenContext<'ctx, '_>, @@ -167,6 +207,8 @@ pub fn call_nac3_ndarray_get_nth_pelement<'ctx, G: CodeGenerator + ?Sized>( let llvm_usize = generator.get_size_type(ctx.ctx); let llvm_ndarray = ndarray.get_type().as_base_type(); + assert_eq!(index.get_type(), llvm_usize); + let name = get_usize_dependent_function_name(generator, ctx, "__nac3_ndarray_get_nth_pelement"); create_and_call_function( @@ -181,11 +223,16 @@ pub fn call_nac3_ndarray_get_nth_pelement<'ctx, G: CodeGenerator + ?Sized>( .unwrap() } +/// Generates a call to `__nac3_ndarray_get_pelement_by_indices`. +/// +/// `indices` must have the same number of elements as the number of dimensions in `ndarray`. +/// +/// Returns a [`PointerValue`] to the element indexed by `indices`. pub fn call_nac3_ndarray_get_pelement_by_indices<'ctx, G: CodeGenerator + ?Sized>( generator: &G, ctx: &CodeGenContext<'ctx, '_>, ndarray: NDArrayValue<'ctx>, - indices: PointerValue<'ctx>, + indices: &impl TypedArrayLikeAccessor<'ctx, G, IntValue<'ctx>>, ) -> PointerValue<'ctx> { let llvm_i8 = ctx.ctx.i8_type(); let llvm_pi8 = llvm_i8.ptr_type(AddressSpace::default()); @@ -193,6 +240,11 @@ pub fn call_nac3_ndarray_get_pelement_by_indices<'ctx, G: CodeGenerator + ?Sized let llvm_pusize = llvm_usize.ptr_type(AddressSpace::default()); let llvm_ndarray = ndarray.get_type().as_base_type(); + assert_eq!( + BasicTypeEnum::try_from(indices.element_type(ctx, generator)).unwrap(), + llvm_usize.into() + ); + let name = get_usize_dependent_function_name(generator, ctx, "__nac3_ndarray_get_pelement_by_indices"); @@ -202,7 +254,7 @@ pub fn call_nac3_ndarray_get_pelement_by_indices<'ctx, G: CodeGenerator + ?Sized Some(llvm_pi8.into()), &[ (llvm_ndarray.into(), ndarray.as_base_value().into()), - (llvm_pusize.into(), indices.into()), + (llvm_pusize.into(), indices.base_ptr(ctx, generator).into()), ], Some("pelement"), None, @@ -211,6 +263,9 @@ pub fn 
call_nac3_ndarray_get_pelement_by_indices<'ctx, G: CodeGenerator + ?Sized .unwrap() } +/// Generates a call to `__nac3_ndarray_set_strides_by_shape`. +/// +/// Sets `ndarray.strides` assuming that `ndarray.shape` is C-contiguous. pub fn call_nac3_ndarray_set_strides_by_shape<'ctx, G: CodeGenerator + ?Sized>( generator: &G, ctx: &CodeGenContext<'ctx, '_>, @@ -231,6 +286,11 @@ pub fn call_nac3_ndarray_set_strides_by_shape<'ctx, G: CodeGenerator + ?Sized>( ); } +/// Generates a call to `__nac3_ndarray_copy_data`. +/// +/// Copies all elements from `src_ndarray` to `dst_ndarray` using their flattened views. The number +/// of elements in `src_ndarray` must be greater than or equal to the number of elements in +/// `dst_ndarray`. pub fn call_nac3_ndarray_copy_data<'ctx, G: CodeGenerator + ?Sized>( generator: &G, ctx: &CodeGenContext<'ctx, '_>, diff --git a/nac3core/src/codegen/irrt/ndarray/indexing.rs b/nac3core/src/codegen/irrt/ndarray/indexing.rs index 0821b2c..3e2c908 100644 --- a/nac3core/src/codegen/irrt/ndarray/indexing.rs +++ b/nac3core/src/codegen/irrt/ndarray/indexing.rs @@ -5,6 +5,11 @@ use crate::codegen::{ CodeGenContext, CodeGenerator, }; +/// Generates a call to `__nac3_ndarray_index`. +/// +/// Performs [basic indexing](https://numpy.org/doc/stable/user/basics.indexing.html#basic-indexing) +/// on `src_ndarray` using `indices`, writing the result to `dst_ndarray`, corresponding to the +/// operation `dst_ndarray = src_ndarray[indices]`. 
pub fn call_nac3_ndarray_index<'ctx, G: CodeGenerator + ?Sized>( generator: &G, ctx: &CodeGenContext<'ctx, '_>, diff --git a/nac3core/src/codegen/irrt/ndarray/iter.rs b/nac3core/src/codegen/irrt/ndarray/iter.rs index 966d660..47cd5b2 100644 --- a/nac3core/src/codegen/irrt/ndarray/iter.rs +++ b/nac3core/src/codegen/irrt/ndarray/iter.rs @@ -1,4 +1,5 @@ use inkwell::{ + types::BasicTypeEnum, values::{BasicValueEnum, IntValue}, AddressSpace, }; @@ -9,21 +10,29 @@ use crate::codegen::{ types::ProxyType, values::{ ndarray::{NDArrayValue, NDIterValue}, - ArrayLikeValue, ArraySliceValue, ProxyValue, + ProxyValue, TypedArrayLikeAccessor, }, CodeGenContext, CodeGenerator, }; +/// Generates a call to `__nac3_nditer_initialize`. +/// +/// Initializes the `iter` object. pub fn call_nac3_nditer_initialize<'ctx, G: CodeGenerator + ?Sized>( generator: &G, ctx: &CodeGenContext<'ctx, '_>, iter: NDIterValue<'ctx>, ndarray: NDArrayValue<'ctx>, - indices: ArraySliceValue<'ctx>, + indices: &impl TypedArrayLikeAccessor<'ctx, G, IntValue<'ctx>>, ) { let llvm_usize = generator.get_size_type(ctx.ctx); let llvm_pusize = llvm_usize.ptr_type(AddressSpace::default()); + assert_eq!( + BasicTypeEnum::try_from(indices.element_type(ctx, generator)).unwrap(), + llvm_usize.into() + ); + let name = get_usize_dependent_function_name(generator, ctx, "__nac3_nditer_initialize"); create_and_call_function( @@ -40,6 +49,10 @@ pub fn call_nac3_nditer_initialize<'ctx, G: CodeGenerator + ?Sized>( ); } +/// Generates a call to `__nac3_nditer_has_element`. +/// +/// Returns an `i1` value indicating whether there are elements left to traverse for the `iter` +/// object. pub fn call_nac3_nditer_has_element<'ctx, G: CodeGenerator + ?Sized>( generator: &G, ctx: &CodeGenContext<'ctx, '_>, @@ -59,6 +72,9 @@ pub fn call_nac3_nditer_has_element<'ctx, G: CodeGenerator + ?Sized>( .unwrap() } +/// Generates a call to `__nac3_nditer_next`. +/// +/// Moves `iter` to point to the next element. 
pub fn call_nac3_nditer_next<'ctx, G: CodeGenerator + ?Sized>( generator: &G, ctx: &CodeGenContext<'ctx, '_>, diff --git a/nac3core/src/codegen/irrt/ndarray/mod.rs b/nac3core/src/codegen/irrt/ndarray/mod.rs index 56d9094..b74ace0 100644 --- a/nac3core/src/codegen/irrt/ndarray/mod.rs +++ b/nac3core/src/codegen/irrt/ndarray/mod.rs @@ -1,10 +1,11 @@ use inkwell::{ - types::IntType, + types::{BasicTypeEnum, IntType}, values::{BasicValueEnum, CallSiteValue, IntValue}, AddressSpace, IntPredicate, }; use itertools::Either; +use super::get_usize_dependent_function_name; use crate::codegen::{ llvm_intrinsics, macros::codegen_unreachable, @@ -23,8 +24,8 @@ mod basic; mod indexing; mod iter; -/// Generates a call to `__nac3_ndarray_calc_size`. Returns an [`IntValue`] representing the -/// calculated total size. +/// Generates a call to `__nac3_ndarray_calc_size`. Returns a +/// [`usize`][CodeGenerator::get_size_type] representing the calculated total size. /// /// * `dims` - An [`ArrayLikeIndexer`] containing the size of each dimension. 
/// * `range` - The dimension index to begin and end (exclusively) calculating the dimensions for, @@ -43,18 +44,22 @@ where let llvm_usize = generator.get_size_type(ctx.ctx); let llvm_pusize = llvm_usize.ptr_type(AddressSpace::default()); - let ndarray_calc_size_fn_name = match llvm_usize.get_bit_width() { - 32 => "__nac3_ndarray_calc_size", - 64 => "__nac3_ndarray_calc_size64", - bw => codegen_unreachable!(ctx, "Unsupported size type bit width: {}", bw), - }; + assert!(begin.is_none_or(|begin| begin.get_type() == llvm_usize)); + assert!(end.is_none_or(|end| end.get_type() == llvm_usize)); + assert_eq!( + BasicTypeEnum::try_from(dims.element_type(ctx, generator)).unwrap(), + llvm_usize.into() + ); + + let ndarray_calc_size_fn_name = + get_usize_dependent_function_name(generator, ctx, "__nac3_ndarray_calc_size"); let ndarray_calc_size_fn_t = llvm_usize.fn_type( &[llvm_pusize.into(), llvm_usize.into(), llvm_usize.into(), llvm_usize.into()], false, ); let ndarray_calc_size_fn = - ctx.module.get_function(ndarray_calc_size_fn_name).unwrap_or_else(|| { - ctx.module.add_function(ndarray_calc_size_fn_name, ndarray_calc_size_fn_t, None) + ctx.module.get_function(&ndarray_calc_size_fn_name).unwrap_or_else(|| { + ctx.module.add_function(&ndarray_calc_size_fn_name, ndarray_calc_size_fn_t, None) }); let begin = begin.unwrap_or_else(|| llvm_usize.const_zero()); @@ -76,10 +81,10 @@ where .unwrap() } -/// Generates a call to `__nac3_ndarray_calc_nd_indices`. Returns a [`TypeArrayLikeAdpater`] +/// Generates a call to `__nac3_ndarray_calc_nd_indices`. Returns a [`TypedArrayLikeAdapter`] /// containing `i32` indices of the flattened index. /// -/// * `index` - The index to compute the multidimensional index for. +/// * `index` - The `llvm_usize` index to compute the multidimensional index for. /// * `ndarray` - LLVM pointer to the `NDArray`. This value must be the LLVM representation of an /// `NDArray`. 
pub fn call_ndarray_calc_nd_indices<'ctx, G: CodeGenerator + ?Sized>( @@ -94,19 +99,18 @@ pub fn call_ndarray_calc_nd_indices<'ctx, G: CodeGenerator + ?Sized>( let llvm_pi32 = llvm_i32.ptr_type(AddressSpace::default()); let llvm_pusize = llvm_usize.ptr_type(AddressSpace::default()); - let ndarray_calc_nd_indices_fn_name = match llvm_usize.get_bit_width() { - 32 => "__nac3_ndarray_calc_nd_indices", - 64 => "__nac3_ndarray_calc_nd_indices64", - bw => codegen_unreachable!(ctx, "Unsupported size type bit width: {}", bw), - }; + assert_eq!(index.get_type(), llvm_usize); + + let ndarray_calc_nd_indices_fn_name = + get_usize_dependent_function_name(generator, ctx, "__nac3_ndarray_calc_nd_indices"); let ndarray_calc_nd_indices_fn = - ctx.module.get_function(ndarray_calc_nd_indices_fn_name).unwrap_or_else(|| { + ctx.module.get_function(&ndarray_calc_nd_indices_fn_name).unwrap_or_else(|| { let fn_type = llvm_void.fn_type( &[llvm_usize.into(), llvm_pusize.into(), llvm_usize.into(), llvm_pi32.into()], false, ); - ctx.module.add_function(ndarray_calc_nd_indices_fn_name, fn_type, None) + ctx.module.add_function(&ndarray_calc_nd_indices_fn_name, fn_type, None) }); let ndarray_num_dims = ndarray.load_ndims(ctx); @@ -134,15 +138,21 @@ pub fn call_ndarray_calc_nd_indices<'ctx, G: CodeGenerator + ?Sized>( ) } -fn call_ndarray_flatten_index_impl<'ctx, G, Indices>( +/// Generates a call to `__nac3_ndarray_flatten_index`. Returns a `usize` of the flattened index for +/// the multidimensional index. +/// +/// * `ndarray` - LLVM pointer to the `NDArray`. This value must be the LLVM representation of an +/// `NDArray`. +/// * `indices` - The multidimensional index to compute the flattened index for. 
+pub fn call_ndarray_flatten_index<'ctx, G, Index>( generator: &G, ctx: &CodeGenContext<'ctx, '_>, ndarray: NDArrayValue<'ctx>, - indices: &Indices, + indices: &Index, ) -> IntValue<'ctx> where G: CodeGenerator + ?Sized, - Indices: ArrayLikeIndexer<'ctx>, + Index: ArrayLikeIndexer<'ctx>, { let llvm_i32 = ctx.ctx.i32_type(); let llvm_usize = generator.get_size_type(ctx.ctx); @@ -163,19 +173,16 @@ where "Expected usize integer value for argument `indices_size` to `call_ndarray_flatten_index_impl`" ); - let ndarray_flatten_index_fn_name = match llvm_usize.get_bit_width() { - 32 => "__nac3_ndarray_flatten_index", - 64 => "__nac3_ndarray_flatten_index64", - bw => codegen_unreachable!(ctx, "Unsupported size type bit width: {}", bw), - }; + let ndarray_flatten_index_fn_name = + get_usize_dependent_function_name(generator, ctx, "__nac3_ndarray_flatten_index"); let ndarray_flatten_index_fn = - ctx.module.get_function(ndarray_flatten_index_fn_name).unwrap_or_else(|| { + ctx.module.get_function(&ndarray_flatten_index_fn_name).unwrap_or_else(|| { let fn_type = llvm_usize.fn_type( &[llvm_pusize.into(), llvm_usize.into(), llvm_pi32.into(), llvm_usize.into()], false, ); - ctx.module.add_function(ndarray_flatten_index_fn_name, fn_type, None) + ctx.module.add_function(&ndarray_flatten_index_fn_name, fn_type, None) }); let ndarray_num_dims = ndarray.load_ndims(ctx); @@ -201,27 +208,8 @@ where index } -/// Generates a call to `__nac3_ndarray_flatten_index`. Returns the flattened index for the -/// multidimensional index. -/// -/// * `ndarray` - LLVM pointer to the `NDArray`. This value must be the LLVM representation of an -/// `NDArray`. -/// * `indices` - The multidimensional index to compute the flattened index for. 
-pub fn call_ndarray_flatten_index<'ctx, G, Index>( - generator: &G, - ctx: &CodeGenContext<'ctx, '_>, - ndarray: NDArrayValue<'ctx>, - indices: &Index, -) -> IntValue<'ctx> -where - G: CodeGenerator + ?Sized, - Index: ArrayLikeIndexer<'ctx>, -{ - call_ndarray_flatten_index_impl(generator, ctx, ndarray, indices) -} - -/// Generates a call to `__nac3_ndarray_calc_broadcast`. Returns a tuple containing the number of -/// dimension and size of each dimension of the resultant `ndarray`. +/// Generates a call to `__nac3_ndarray_calc_broadcast`. Returns a [`TypedArrayLikeAdapter`] +/// containing the size of each dimension of the resultant `ndarray`. pub fn call_ndarray_calc_broadcast<'ctx, G: CodeGenerator + ?Sized>( generator: &mut G, ctx: &mut CodeGenContext<'ctx, '_>, @@ -231,13 +219,10 @@ pub fn call_ndarray_calc_broadcast<'ctx, G: CodeGenerator + ?Sized>( let llvm_usize = generator.get_size_type(ctx.ctx); let llvm_pusize = llvm_usize.ptr_type(AddressSpace::default()); - let ndarray_calc_broadcast_fn_name = match llvm_usize.get_bit_width() { - 32 => "__nac3_ndarray_calc_broadcast", - 64 => "__nac3_ndarray_calc_broadcast64", - bw => codegen_unreachable!(ctx, "Unsupported size type bit width: {}", bw), - }; + let ndarray_calc_broadcast_fn_name = + get_usize_dependent_function_name(generator, ctx, "__nac3_ndarray_calc_broadcast"); let ndarray_calc_broadcast_fn = - ctx.module.get_function(ndarray_calc_broadcast_fn_name).unwrap_or_else(|| { + ctx.module.get_function(&ndarray_calc_broadcast_fn_name).unwrap_or_else(|| { let fn_type = llvm_usize.fn_type( &[ llvm_pusize.into(), @@ -249,7 +234,7 @@ pub fn call_ndarray_calc_broadcast<'ctx, G: CodeGenerator + ?Sized>( false, ); - ctx.module.add_function(ndarray_calc_broadcast_fn_name, fn_type, None) + ctx.module.add_function(&ndarray_calc_broadcast_fn_name, fn_type, None) }); let lhs_ndims = lhs.load_ndims(ctx); diff --git a/nac3core/src/codegen/irrt/range.rs b/nac3core/src/codegen/irrt/range.rs index 47c63c4..3b6bc31 100644 
--- a/nac3core/src/codegen/irrt/range.rs +++ b/nac3core/src/codegen/irrt/range.rs @@ -6,6 +6,13 @@ use itertools::Either; use crate::codegen::{CodeGenContext, CodeGenerator}; +/// Invokes the `__nac3_range_slice_len` function in IRRT. +/// +/// - `start`: The `i32` start value for the slice. +/// - `end`: The `i32` end value for the slice. +/// - `step`: The `i32` step value for the slice. +/// +/// Returns an `i32` value of the length of the slice. pub fn calculate_len_for_slice_range<'ctx, G: CodeGenerator + ?Sized>( generator: &mut G, ctx: &mut CodeGenContext<'ctx, '_>, @@ -14,9 +21,15 @@ pub fn calculate_len_for_slice_range<'ctx, G: CodeGenerator + ?Sized>( step: IntValue<'ctx>, ) -> IntValue<'ctx> { const SYMBOL: &str = "__nac3_range_slice_len"; + + let llvm_i32 = ctx.ctx.i32_type(); + + assert_eq!(start.get_type(), llvm_i32); + assert_eq!(end.get_type(), llvm_i32); + assert_eq!(step.get_type(), llvm_i32); + let len_func = ctx.module.get_function(SYMBOL).unwrap_or_else(|| { - let i32_t = ctx.ctx.i32_type(); - let fn_t = i32_t.fn_type(&[i32_t.into(), i32_t.into(), i32_t.into()], false); + let fn_t = llvm_i32.fn_type(&[llvm_i32.into(), llvm_i32.into(), llvm_i32.into()], false); ctx.module.add_function(SYMBOL, fn_t, None) }); @@ -33,6 +46,7 @@ pub fn calculate_len_for_slice_range<'ctx, G: CodeGenerator + ?Sized>( [None, None, None], ctx.current_loc, ); + ctx.builder .build_call(len_func, &[start.into(), end.into(), step.into()], "calc_len") .map(CallSiteValue::try_as_basic_value) diff --git a/nac3core/src/codegen/types/ndarray/nditer.rs b/nac3core/src/codegen/types/ndarray/nditer.rs index c9b6b7d..772d5b2 100644 --- a/nac3core/src/codegen/types/ndarray/nditer.rs +++ b/nac3core/src/codegen/types/ndarray/nditer.rs @@ -14,7 +14,7 @@ use crate::codegen::{ types::structure::{check_struct_type_matches_fields, StructField, StructFields}, values::{ ndarray::{NDArrayValue, NDIterValue}, - ArraySliceValue, ProxyValue, + ArrayLikeValue, ArraySliceValue, ProxyValue, 
TypedArrayLikeAdapter, }, CodeGenContext, CodeGenerator, }; @@ -128,6 +128,11 @@ impl<'ctx> NDIterType<'ctx> { } /// Allocate an [`NDIter`] that iterates through the given `ndarray`. + /// + /// Note: This function allocates an array on the stack at the current builder location, which + /// may lead to stack explosion if called in a hot loop. Therefore, callers are recommended to + /// call `llvm.stacksave` before calling this function and call `llvm.stackrestore` after the + /// [`NDIter`] is no longer needed. #[must_use] pub fn construct( &self, @@ -141,16 +146,12 @@ impl<'ctx> NDIterType<'ctx> { // The caller has the responsibility to allocate 'indices' for `NDIter`. let indices = generator.gen_array_var_alloc(ctx, self.llvm_usize.into(), ndims, None).unwrap(); + let indices = + TypedArrayLikeAdapter::from(indices, |_, _, v| v.into_int_value(), |_, _, v| v.into()); - let nditer = <Self as ProxyType<'ctx>>::Value::from_pointer_value( - nditer, - ndarray, - indices, - self.llvm_usize, - None, - ); + let nditer = self.map_value(nditer, ndarray, indices.as_slice_value(ctx, generator), None); - irrt::ndarray::call_nac3_nditer_initialize(generator, ctx, nditer, ndarray, indices); + irrt::ndarray::call_nac3_nditer_initialize(generator, ctx, nditer, ndarray, &indices); nditer } diff --git a/nac3core/src/codegen/values/array.rs b/nac3core/src/codegen/values/array.rs index 9f3ec0e..55e91b2 100644 --- a/nac3core/src/codegen/values/array.rs +++ b/nac3core/src/codegen/values/array.rs @@ -265,6 +265,14 @@ where ) -> IntValue<'ctx> { self.adapted.size(ctx, generator) } + + fn as_slice_value<CG: CodeGenerator + ?Sized>( + &self, + ctx: &CodeGenContext<'ctx, '_>, + generator: &CG, + ) -> ArraySliceValue<'ctx> { + self.adapted.as_slice_value(ctx, generator) + } } impl<'ctx, G: CodeGenerator + ?Sized, T, Index, Adapted> ArrayLikeIndexer<'ctx, Index> diff --git a/nac3core/src/codegen/values/ndarray/mod.rs b/nac3core/src/codegen/values/ndarray/mod.rs index 0da3a2e..4c5be43 100644 --- a/nac3core/src/codegen/values/ndarray/mod.rs +++ 
b/nac3core/src/codegen/values/ndarray/mod.rs @@ -358,6 +358,10 @@ impl<'ctx> NDArrayValue<'ctx> { irrt::ndarray::call_nac3_ndarray_set_strides_by_shape(generator, ctx, *self); } + /// Clone/Copy this ndarray - Allocate a new ndarray with the same shape as this ndarray and + /// copy the contents over. + /// + /// The new ndarray will own its data and will be C-contiguous. #[must_use] pub fn make_copy( &self, diff --git a/nac3core/src/codegen/values/ndarray/nditer.rs b/nac3core/src/codegen/values/ndarray/nditer.rs index e29770e..4b4e07a 100644 --- a/nac3core/src/codegen/values/ndarray/nditer.rs +++ b/nac3core/src/codegen/values/ndarray/nditer.rs @@ -141,6 +141,10 @@ impl<'ctx> NDArrayValue<'ctx> { /// /// `body` has access to [`BreakContinueHooks`] to short-circuit and [`NDIterValue`] to /// get properties of the current iteration (e.g., the current element, indices, etc.) + /// + /// Note: The caller is recommended to call `llvm.stacksave` and `llvm.stackrestore` before and + /// after invoking this function respectively. See [`NDIterType::construct`] for an explanation + /// on why this is suggested. pub fn foreach<'a, G, F>( &self, generator: &mut G,