From 50264e875099e698f2ff5e39fd9fdb04f5ea83db Mon Sep 17 00:00:00 2001 From: David Mak Date: Mon, 18 Mar 2024 15:51:01 +0800 Subject: [PATCH] core: Add missing unchecked accessors for NDArrayDimsProxy --- nac3core/src/codegen/classes.rs | 51 ++++++++++++++++++++++++++------- nac3core/src/codegen/expr.rs | 13 +++++---- nac3core/src/codegen/numpy.rs | 8 ++++-- 3 files changed, 52 insertions(+), 20 deletions(-) diff --git a/nac3core/src/codegen/classes.rs b/nac3core/src/codegen/classes.rs index 2867b8f..9c7c262 100644 --- a/nac3core/src/codegen/classes.rs +++ b/nac3core/src/codegen/classes.rs @@ -627,6 +627,26 @@ impl<'ctx> NDArrayDimsProxy<'ctx> { .unwrap() } + /// # Safety + /// + /// This function should be called with a valid index. + pub unsafe fn ptr_offset_unchecked( + &self, + ctx: &mut CodeGenContext<'ctx, '_>, + idx: IntValue<'ctx>, + name: Option<&str>, + ) -> PointerValue<'ctx> { + let var_name = name + .map(|v| format!("{v}.addr")) + .unwrap_or_default(); + + ctx.builder.build_in_bounds_gep( + self.as_ptr_value(ctx), + &[idx], + var_name.as_str(), + ).unwrap() + } + /// Returns the pointer to the size of the `idx`-th dimension. pub fn ptr_offset( &self, @@ -650,19 +670,26 @@ impl<'ctx> NDArrayDimsProxy<'ctx> { ctx.current_loc, ); - let var_name = name - .map(|v| format!("{v}.addr")) - .unwrap_or_default(); - unsafe { - ctx.builder.build_in_bounds_gep( - self.as_ptr_value(ctx), - &[idx], - var_name.as_str(), - ).unwrap() + self.ptr_offset_unchecked(ctx, idx, name) } } + /// # Safety + /// + /// This function should be called with a valid index. + pub unsafe fn get_unchecked( + &self, + ctx: &mut CodeGenContext<'ctx, '_>, + idx: IntValue<'ctx>, + name: Option<&str>, + ) -> IntValue<'ctx> { + let ptr = self.ptr_offset_unchecked(ctx, idx, name); + ctx.builder.build_load(ptr, name.unwrap_or_default()) + .map(BasicValueEnum::into_int_value) + .unwrap() + } + /// Returns the size of the `idx`-th dimension. pub fn get( &self, @@ -862,7 +889,9 @@ impl<'ctx> NDArrayDataProxy<'ctx> { .map(BasicValueEnum::into_int_value) .map(|v| ctx.builder.build_int_z_extend_or_bit_cast(v, llvm_usize, "").unwrap()) .unwrap(); - let dim_sz = self.0.dim_sizes().get(ctx, generator, i, None); + let dim_sz = unsafe { + self.0.dim_sizes().get_unchecked(ctx, i, None) + }; let dim_lt = ctx.builder.build_int_compare( IntPredicate::SLT, @@ -938,7 +967,7 @@ impl<'ctx> NDArrayDataProxy<'ctx> { let (dim_idx, dim_sz) = unsafe { ( indices.data().get_unchecked(ctx, i, None).into_int_value(), - self.0.dim_sizes().get(ctx, generator, i, None), + self.0.dim_sizes().get_unchecked(ctx, i, None), ) }; diff --git a/nac3core/src/codegen/expr.rs b/nac3core/src/codegen/expr.rs index b4ccef9..5711a6a 100644 --- a/nac3core/src/codegen/expr.rs +++ b/nac3core/src/codegen/expr.rs @@ -1300,12 +1300,13 @@ fn gen_ndarray_subscript_expr<'ctx, G: CodeGenerator>( ndarray.create_dim_sizes(ctx, llvm_usize, ndarray_num_dims); let ndarray_num_dims = ndarray.load_ndims(ctx); - let v_dims_src_ptr = v.dim_sizes().ptr_offset( - ctx, - generator, - llvm_usize.const_int(1, false), - None, - ); + let v_dims_src_ptr = unsafe { + v.dim_sizes().ptr_offset_unchecked( + ctx, + llvm_usize.const_int(1, false), + None, + ) + }; call_memcpy_generic( ctx, ndarray.dim_sizes().as_ptr_value(ctx), diff --git a/nac3core/src/codegen/numpy.rs b/nac3core/src/codegen/numpy.rs index 7029890..3f227fe 100644 --- a/nac3core/src/codegen/numpy.rs +++ b/nac3core/src/codegen/numpy.rs @@ -150,7 +150,9 @@ fn create_ndarray_dyn_shape<'ctx, 'a, V, LenFn, DataFn>( .build_int_z_extend(shape_dim, llvm_usize, "") .unwrap(); - let ndarray_pdim = ndarray.dim_sizes().ptr_offset(ctx, generator, i, None); + let ndarray_pdim = unsafe { + ndarray.dim_sizes().ptr_offset_unchecked(ctx, i, None) + }; ctx.builder.build_store(ndarray_pdim, shape_dim).unwrap(); @@ -620,8 +622,8 @@ fn ndarray_copy_impl<'ctx>( |_, ctx, shape| { Ok(shape.load_ndims(ctx)) }, - |generator, ctx, shape, idx| { - Ok(shape.dim_sizes().get(ctx, generator, idx, None)) + |_, ctx, shape, idx| { + unsafe { Ok(shape.dim_sizes().get_unchecked(ctx, idx, None)) } }, )?;