forked from M-Labs/nac3
1
0
Fork 0

core: Add missing unchecked accessors for NDArrayDimsProxy

This commit is contained in:
David Mak 2024-03-18 15:51:01 +08:00
parent 1b77e62901
commit 50264e8750
3 changed files with 52 additions and 20 deletions

View File

@ -627,6 +627,26 @@ impl<'ctx> NDArrayDimsProxy<'ctx> {
.unwrap()
}
/// # Safety
///
/// This function should be called with a valid index.
pub unsafe fn ptr_offset_unchecked(
&self,
ctx: &mut CodeGenContext<'ctx, '_>,
idx: IntValue<'ctx>,
name: Option<&str>,
) -> PointerValue<'ctx> {
let var_name = name
.map(|v| format!("{v}.addr"))
.unwrap_or_default();
ctx.builder.build_in_bounds_gep(
self.as_ptr_value(ctx),
&[idx],
var_name.as_str(),
).unwrap()
}
/// Returns the pointer to the size of the `idx`-th dimension.
pub fn ptr_offset(
&self,
@ -650,19 +670,26 @@ impl<'ctx> NDArrayDimsProxy<'ctx> {
ctx.current_loc,
);
let var_name = name
.map(|v| format!("{v}.addr"))
.unwrap_or_default();
unsafe {
ctx.builder.build_in_bounds_gep(
self.as_ptr_value(ctx),
&[idx],
var_name.as_str(),
).unwrap()
self.ptr_offset_unchecked(ctx, idx, name)
}
}
/// # Safety
///
/// This function should be called with a valid index.
pub unsafe fn get_unchecked(
&self,
ctx: &mut CodeGenContext<'ctx, '_>,
idx: IntValue<'ctx>,
name: Option<&str>,
) -> IntValue<'ctx> {
let ptr = self.ptr_offset_unchecked(ctx, idx, name);
ctx.builder.build_load(ptr, name.unwrap_or_default())
.map(BasicValueEnum::into_int_value)
.unwrap()
}
/// Returns the size of the `idx`-th dimension.
pub fn get(
&self,
@ -862,7 +889,9 @@ impl<'ctx> NDArrayDataProxy<'ctx> {
.map(BasicValueEnum::into_int_value)
.map(|v| ctx.builder.build_int_z_extend_or_bit_cast(v, llvm_usize, "").unwrap())
.unwrap();
let dim_sz = self.0.dim_sizes().get(ctx, generator, i, None);
let dim_sz = unsafe {
self.0.dim_sizes().get_unchecked(ctx, i, None)
};
let dim_lt = ctx.builder.build_int_compare(
IntPredicate::SLT,
@ -938,7 +967,7 @@ impl<'ctx> NDArrayDataProxy<'ctx> {
let (dim_idx, dim_sz) = unsafe {
(
indices.data().get_unchecked(ctx, i, None).into_int_value(),
self.0.dim_sizes().get(ctx, generator, i, None),
self.0.dim_sizes().get_unchecked(ctx, i, None),
)
};

View File

@ -1300,12 +1300,13 @@ fn gen_ndarray_subscript_expr<'ctx, G: CodeGenerator>(
ndarray.create_dim_sizes(ctx, llvm_usize, ndarray_num_dims);
let ndarray_num_dims = ndarray.load_ndims(ctx);
let v_dims_src_ptr = v.dim_sizes().ptr_offset(
let v_dims_src_ptr = unsafe {
v.dim_sizes().ptr_offset_unchecked(
ctx,
generator,
llvm_usize.const_int(1, false),
None,
);
)
};
call_memcpy_generic(
ctx,
ndarray.dim_sizes().as_ptr_value(ctx),

View File

@ -150,7 +150,9 @@ fn create_ndarray_dyn_shape<'ctx, 'a, V, LenFn, DataFn>(
.build_int_z_extend(shape_dim, llvm_usize, "")
.unwrap();
let ndarray_pdim = ndarray.dim_sizes().ptr_offset(ctx, generator, i, None);
let ndarray_pdim = unsafe {
ndarray.dim_sizes().ptr_offset_unchecked(ctx, i, None)
};
ctx.builder.build_store(ndarray_pdim, shape_dim).unwrap();
@ -620,8 +622,8 @@ fn ndarray_copy_impl<'ctx>(
|_, ctx, shape| {
Ok(shape.load_ndims(ctx))
},
|generator, ctx, shape, idx| {
Ok(shape.dim_sizes().get(ctx, generator, idx, None))
|_, ctx, shape, idx| {
unsafe { Ok(shape.dim_sizes().get_unchecked(ctx, idx, None)) }
},
)?;