From 976a9512c1baacf0716607bc356471993577f1b7 Mon Sep 17 00:00:00 2001 From: David Mak Date: Thu, 15 Feb 2024 15:10:12 +0800 Subject: [PATCH] core: Add const variants to NDArray element getters --- nac3core/src/codegen/classes.rs | 115 ++++++++++++++++++++++++++++++- nac3core/src/codegen/irrt/mod.rs | 98 ++++++++++++++++++++++---- 2 files changed, 198 insertions(+), 15 deletions(-) diff --git a/nac3core/src/codegen/classes.rs b/nac3core/src/codegen/classes.rs index 86822eef..30928378 100644 --- a/nac3core/src/codegen/classes.rs +++ b/nac3core/src/codegen/classes.rs @@ -1,12 +1,12 @@ use inkwell::{ IntPredicate, types::{AnyTypeEnum, BasicTypeEnum, IntType, PointerType}, - values::{BasicValueEnum, IntValue, PointerValue}, + values::{ArrayValue, BasicValueEnum, IntValue, PointerValue}, }; use crate::codegen::{ CodeGenContext, CodeGenerator, - irrt::{call_ndarray_calc_size, call_ndarray_flatten_index}, + irrt::{call_ndarray_calc_size, call_ndarray_flatten_index, call_ndarray_flatten_index_const}, stmt::gen_for_callback, }; @@ -725,7 +725,7 @@ impl<'ctx> NDArrayDataProxy<'ctx> { let Ok(indices_elem_ty) = IntType::try_from(indices_elem_ty) else { panic!("Expected list[int32] but got {indices_elem_ty}") }; - assert_eq!(indices_elem_ty.get_bit_width(), 32, "Expected list[int32] but got {indices_elem_ty}"); + debug_assert_eq!(indices_elem_ty.get_bit_width(), 32, "Expected list[int32] but got {indices_elem_ty}"); let index = call_ndarray_flatten_index( generator, @@ -743,6 +743,92 @@ impl<'ctx> NDArrayDataProxy<'ctx> { } } + pub unsafe fn ptr_offset_unchecked_const( + &self, + ctx: &mut CodeGenContext<'ctx, '_>, + generator: &mut dyn CodeGenerator, + indices: ArrayValue<'ctx>, + name: Option<&str>, + ) -> PointerValue<'ctx> { + let index = call_ndarray_flatten_index_const( + generator, + ctx, + self.0, + indices, + ).unwrap(); + + unsafe { + ctx.builder.build_in_bounds_gep( + self.get_ptr(ctx), + &[index], + name.unwrap_or_default(), + ) + } + } + + /// Returns the pointer to the data at the index specified by `indices`. + pub fn ptr_offset_const( + &self, + ctx: &mut CodeGenContext<'ctx, '_>, + generator: &mut dyn CodeGenerator, + indices: ArrayValue<'ctx>, + name: Option<&str>, + ) -> PointerValue<'ctx> { + let llvm_usize = generator.get_size_type(ctx.ctx); + + let indices_elem_ty = indices.get_type().get_element_type(); + let Ok(indices_elem_ty) = IntType::try_from(indices_elem_ty) else { + panic!("Expected [int32] but got [{indices_elem_ty}]") + }; + assert_eq!(indices_elem_ty.get_bit_width(), 32, "Expected [int32] but got [{indices_elem_ty}]"); + + let nidx_leq_ndims = ctx.builder.build_int_compare( + IntPredicate::SLE, + llvm_usize.const_int(indices.get_type().len() as u64, false), + self.0.load_ndims(ctx), + "" + ); + ctx.make_assert( + generator, + nidx_leq_ndims, + "0:IndexError", + "invalid index to scalar variable", + [None, None, None], + ctx.current_loc, + ); + + for idx in 0..indices.get_type().len() { + let i = llvm_usize.const_int(idx as u64, false); + + let dim_idx = ctx.builder + .build_extract_value(indices, idx, "") + .map(|v| v.into_int_value()) + .map(|v| ctx.builder.build_int_z_extend_or_bit_cast(v, llvm_usize, "")) + .unwrap(); + let dim_sz = self.0.get_dims().get(ctx, generator, i, None); + + let dim_lt = ctx.builder.build_int_compare( + IntPredicate::SLT, + dim_idx, + dim_sz, + "" + ); + + ctx.make_assert( + generator, + dim_lt, + "0:IndexError", + "index {0} is out of bounds for axis 0 with size {1}", + [Some(dim_idx), Some(dim_sz), None], + ctx.current_loc, + ); + } + + unsafe { + self.ptr_offset_unchecked_const(ctx, generator, indices, name) + } + } + /// Returns the pointer to the data at the index specified by `indices`. pub fn ptr_offset( &self, @@ -844,6 +930,17 @@ impl<'ctx> NDArrayDataProxy<'ctx> { } } + pub unsafe fn get_unsafe_const( + &self, + ctx: &mut CodeGenContext<'ctx, '_>, + generator: &mut dyn CodeGenerator, + indices: ArrayValue<'ctx>, + name: Option<&str>, + ) -> BasicValueEnum<'ctx> { + let ptr = self.ptr_offset_unchecked_const(ctx, generator, indices, name); + ctx.builder.build_load(ptr, name.unwrap_or_default()) + } + pub unsafe fn get_unsafe( &self, ctx: &mut CodeGenContext<'ctx, '_>, @@ -855,6 +952,18 @@ impl<'ctx> NDArrayDataProxy<'ctx> { ctx.builder.build_load(ptr, name.unwrap_or_default()) } + /// Returns the data at the index specified by `indices`. + pub fn get_const( + &self, + ctx: &mut CodeGenContext<'ctx, '_>, + generator: &mut dyn CodeGenerator, + indices: ArrayValue<'ctx>, + name: Option<&str>, + ) -> BasicValueEnum<'ctx> { + let ptr = self.ptr_offset_const(ctx, generator, indices, name); + ctx.builder.build_load(ptr, name.unwrap_or_default()) + } + /// Returns the data at the index specified by `indices`. pub fn get( &self, diff --git a/nac3core/src/codegen/irrt/mod.rs b/nac3core/src/codegen/irrt/mod.rs index 20ec0272..a6f030fe 100644 --- a/nac3core/src/codegen/irrt/mod.rs +++ b/nac3core/src/codegen/irrt/mod.rs @@ -10,8 +10,8 @@ use inkwell::{ context::Context, memory_buffer::MemoryBuffer, module::Module, - types::BasicTypeEnum, - values::{FloatValue, IntValue, PointerValue}, + types::{BasicTypeEnum, IntType}, + values::{ArrayValue, FloatValue, IntValue, PointerValue}, AddressSpace, IntPredicate, }; use nac3parser::ast::Expr; @@ -707,16 +707,12 @@ pub fn call_ndarray_calc_nd_indices<'ctx>( Ok(indices) } -/// Generates a call to `__nac3_ndarray_flatten_index`. -/// -/// * `ndarray` - LLVM pointer to the NDArray. This value must be the LLVM representation of an -/// `NDArray`. -/// * `indices` - The multidimensional index to compute the flattened index for. -pub fn call_ndarray_flatten_index<'ctx>( +fn call_ndarray_flatten_index_impl<'ctx>( generator: &dyn CodeGenerator, ctx: &CodeGenContext<'ctx, '_>, ndarray: NDArrayValue<'ctx>, - indices: ListValue<'ctx>, + indices: PointerValue<'ctx>, + indices_size: IntValue<'ctx>, ) -> Result, String> { let llvm_i32 = ctx.ctx.i32_type(); let llvm_usize = generator.get_size_type(ctx.ctx); @@ -724,6 +720,19 @@ pub fn call_ndarray_flatten_index<'ctx>( let llvm_pi32 = llvm_i32.ptr_type(AddressSpace::default()); let llvm_pusize = llvm_usize.ptr_type(AddressSpace::default()); + debug_assert_eq!( + IntType::try_from(indices.get_type().get_element_type()) + .map(|itype| itype.get_bit_width()) + .unwrap_or_default(), + llvm_i32.get_bit_width(), + "Expected i32 value for argument `indices` to `call_ndarray_flatten_index_impl`" + ); + debug_assert_eq!( + indices_size.get_type().get_bit_width(), + llvm_usize.get_bit_width(), + "Expected usize integer value for argument `indices_size` to `call_ndarray_flatten_index_impl`" + ); + let ndarray_flatten_index_fn_name = match llvm_usize.get_bit_width() { 32 => "__nac3_ndarray_flatten_index", 64 => "__nac3_ndarray_flatten_index64", @@ -745,8 +754,6 @@ pub fn call_ndarray_flatten_index<'ctx>( let ndarray_num_dims = ndarray.load_ndims(ctx); let ndarray_dims = ndarray.get_dims(); - let indices_size = indices.load_size(ctx, None); - let indices_data = indices.get_data(); let index = ctx.builder .build_call( @@ -765,4 +772,71 @@ pub fn call_ndarray_flatten_index<'ctx>( .unwrap(); Ok(index) -} \ No newline at end of file +} + +/// Generates a call to `__nac3_ndarray_flatten_index`. Returns the flattened index for the +/// multidimensional index. +/// +/// * `ndarray` - LLVM pointer to the NDArray. This value must be the LLVM representation of an +/// `NDArray`. +/// * `indices` - The multidimensional index to compute the flattened index for. +pub fn call_ndarray_flatten_index<'ctx>( + generator: &dyn CodeGenerator, + ctx: &CodeGenContext<'ctx, '_>, + ndarray: NDArrayValue<'ctx>, + indices: ListValue<'ctx>, +) -> Result, String> { + let indices_size = indices.load_size(ctx, None); + let indices_data = indices.get_data(); + + call_ndarray_flatten_index_impl( + generator, + ctx, + ndarray, + indices_data.get_ptr(ctx), + indices_size, + ) +} +/// Generates a call to `__nac3_ndarray_flatten_index`. Returns the flattened index for the +/// multidimensional index. +/// +/// * `ndarray` - LLVM pointer to the NDArray. This value must be the LLVM representation of an +/// `NDArray`. +/// * `indices` - The multidimensional index to compute the flattened index for. +pub fn call_ndarray_flatten_index_const<'ctx>( + generator: &mut dyn CodeGenerator, + ctx: &mut CodeGenContext<'ctx, '_>, + ndarray: NDArrayValue<'ctx>, + indices: ArrayValue<'ctx>, +) -> Result, String> { + let llvm_usize = generator.get_size_type(ctx.ctx); + + let indices_size = indices.get_type().len(); + let indices_alloca = generator.gen_array_var_alloc( + ctx, + indices.get_type().get_element_type(), + llvm_usize.const_int(indices_size as u64, false), + None + )?; + for i in 0..indices_size { + let v = ctx.builder.build_extract_value(indices, i, "") + .unwrap() + .into_int_value(); + let elem_ptr = unsafe { + ctx.builder.build_in_bounds_gep( + indices_alloca, + &[ctx.ctx.i32_type().const_int(i as u64, false)], + "" + ) + }; + ctx.builder.build_store(elem_ptr, v); + } + + call_ndarray_flatten_index_impl( + generator, + ctx, + ndarray, + indices_alloca, + llvm_usize.const_int(indices_size as u64, false), + ) +}