forked from M-Labs/nac3
1
0
Fork 0

core: Add missing unchecked accessors for NDArrayDimsProxy

This commit is contained in:
David Mak 2024-03-18 15:51:01 +08:00
parent 1b77e62901
commit 50264e8750
3 changed files with 52 additions and 20 deletions

View File

@ -627,6 +627,26 @@ impl<'ctx> NDArrayDimsProxy<'ctx> {
.unwrap() .unwrap()
} }
/// # Safety
///
/// This function should be called with a valid index.
pub unsafe fn ptr_offset_unchecked(
&self,
ctx: &mut CodeGenContext<'ctx, '_>,
idx: IntValue<'ctx>,
name: Option<&str>,
) -> PointerValue<'ctx> {
let var_name = name
.map(|v| format!("{v}.addr"))
.unwrap_or_default();
ctx.builder.build_in_bounds_gep(
self.as_ptr_value(ctx),
&[idx],
var_name.as_str(),
).unwrap()
}
/// Returns the pointer to the size of the `idx`-th dimension. /// Returns the pointer to the size of the `idx`-th dimension.
pub fn ptr_offset( pub fn ptr_offset(
&self, &self,
@ -650,19 +670,26 @@ impl<'ctx> NDArrayDimsProxy<'ctx> {
ctx.current_loc, ctx.current_loc,
); );
let var_name = name
.map(|v| format!("{v}.addr"))
.unwrap_or_default();
unsafe { unsafe {
ctx.builder.build_in_bounds_gep( self.ptr_offset_unchecked(ctx, idx, name)
self.as_ptr_value(ctx),
&[idx],
var_name.as_str(),
).unwrap()
} }
} }
/// # Safety
///
/// This function should be called with a valid index.
pub unsafe fn get_unchecked(
&self,
ctx: &mut CodeGenContext<'ctx, '_>,
idx: IntValue<'ctx>,
name: Option<&str>,
) -> IntValue<'ctx> {
let ptr = self.ptr_offset_unchecked(ctx, idx, name);
ctx.builder.build_load(ptr, name.unwrap_or_default())
.map(BasicValueEnum::into_int_value)
.unwrap()
}
/// Returns the size of the `idx`-th dimension. /// Returns the size of the `idx`-th dimension.
pub fn get( pub fn get(
&self, &self,
@ -862,7 +889,9 @@ impl<'ctx> NDArrayDataProxy<'ctx> {
.map(BasicValueEnum::into_int_value) .map(BasicValueEnum::into_int_value)
.map(|v| ctx.builder.build_int_z_extend_or_bit_cast(v, llvm_usize, "").unwrap()) .map(|v| ctx.builder.build_int_z_extend_or_bit_cast(v, llvm_usize, "").unwrap())
.unwrap(); .unwrap();
let dim_sz = self.0.dim_sizes().get(ctx, generator, i, None); let dim_sz = unsafe {
self.0.dim_sizes().get_unchecked(ctx, i, None)
};
let dim_lt = ctx.builder.build_int_compare( let dim_lt = ctx.builder.build_int_compare(
IntPredicate::SLT, IntPredicate::SLT,
@ -938,7 +967,7 @@ impl<'ctx> NDArrayDataProxy<'ctx> {
let (dim_idx, dim_sz) = unsafe { let (dim_idx, dim_sz) = unsafe {
( (
indices.data().get_unchecked(ctx, i, None).into_int_value(), indices.data().get_unchecked(ctx, i, None).into_int_value(),
self.0.dim_sizes().get(ctx, generator, i, None), self.0.dim_sizes().get_unchecked(ctx, i, None),
) )
}; };

View File

@ -1300,12 +1300,13 @@ fn gen_ndarray_subscript_expr<'ctx, G: CodeGenerator>(
ndarray.create_dim_sizes(ctx, llvm_usize, ndarray_num_dims); ndarray.create_dim_sizes(ctx, llvm_usize, ndarray_num_dims);
let ndarray_num_dims = ndarray.load_ndims(ctx); let ndarray_num_dims = ndarray.load_ndims(ctx);
let v_dims_src_ptr = v.dim_sizes().ptr_offset( let v_dims_src_ptr = unsafe {
ctx, v.dim_sizes().ptr_offset_unchecked(
generator, ctx,
llvm_usize.const_int(1, false), llvm_usize.const_int(1, false),
None, None,
); )
};
call_memcpy_generic( call_memcpy_generic(
ctx, ctx,
ndarray.dim_sizes().as_ptr_value(ctx), ndarray.dim_sizes().as_ptr_value(ctx),

View File

@ -150,7 +150,9 @@ fn create_ndarray_dyn_shape<'ctx, 'a, V, LenFn, DataFn>(
.build_int_z_extend(shape_dim, llvm_usize, "") .build_int_z_extend(shape_dim, llvm_usize, "")
.unwrap(); .unwrap();
let ndarray_pdim = ndarray.dim_sizes().ptr_offset(ctx, generator, i, None); let ndarray_pdim = unsafe {
ndarray.dim_sizes().ptr_offset_unchecked(ctx, i, None)
};
ctx.builder.build_store(ndarray_pdim, shape_dim).unwrap(); ctx.builder.build_store(ndarray_pdim, shape_dim).unwrap();
@ -620,8 +622,8 @@ fn ndarray_copy_impl<'ctx>(
|_, ctx, shape| { |_, ctx, shape| {
Ok(shape.load_ndims(ctx)) Ok(shape.load_ndims(ctx))
}, },
|generator, ctx, shape, idx| { |_, ctx, shape, idx| {
Ok(shape.dim_sizes().get(ctx, generator, idx, None)) unsafe { Ok(shape.dim_sizes().get_unchecked(ctx, idx, None)) }
}, },
)?; )?;