1
0
forked from M-Labs/nac3

core/irrt: Add support for calculating partial size of NDArray

This commit is contained in:
David Mak 2024-05-27 15:58:06 +08:00
parent 588c15f80d
commit b6ff75dcaf
6 changed files with 30 additions and 9 deletions

View File

@ -737,7 +737,7 @@ pub fn call_numpy_min<'ctx, G: CodeGenerator + ?Sized>(
let llvm_ndarray_ty = ctx.get_llvm_type(generator, elem_ty);
let n = NDArrayValue::from_ptr_val(n, llvm_usize, None);
let n_sz = irrt::call_ndarray_calc_size(generator, ctx, &n.dim_sizes());
let n_sz = irrt::call_ndarray_calc_size(generator, ctx, &n.dim_sizes(), (None, None));
if ctx.registry.llvm_options.opt_level == OptimizationLevel::None {
let n_sz_eqz = ctx.builder
.build_int_compare(
@ -955,7 +955,7 @@ pub fn call_numpy_max<'ctx, G: CodeGenerator + ?Sized>(
let llvm_ndarray_ty = ctx.get_llvm_type(generator, elem_ty);
let n = NDArrayValue::from_ptr_val(n, llvm_usize, None);
let n_sz = irrt::call_ndarray_calc_size(generator, ctx, &n.dim_sizes());
let n_sz = irrt::call_ndarray_calc_size(generator, ctx, &n.dim_sizes(), (None, None));
if ctx.registry.llvm_options.opt_level == OptimizationLevel::None {
let n_sz_eqz = ctx.builder
.build_int_compare(

View File

@ -1122,7 +1122,7 @@ impl<'ctx> ArrayLikeValue<'ctx> for NDArrayDataProxy<'ctx, '_> {
ctx: &CodeGenContext<'ctx, '_>,
generator: &G,
) -> IntValue<'ctx> {
call_ndarray_calc_size(generator, ctx, &self.as_slice_value(ctx, generator))
call_ndarray_calc_size(generator, ctx, &self.as_slice_value(ctx, generator), (None, None))
}
}

View File

@ -1819,6 +1819,7 @@ fn gen_ndarray_subscript_expr<'ctx, G: CodeGenerator>(
generator,
ctx,
&ndarray.dim_sizes().as_slice_value(ctx, generator),
(None, None),
);
ndarray.create_data(ctx, llvm_ndarray_data_t, ndarray_num_elems);

View File

@ -202,10 +202,14 @@ double __nac3_j0(double x) {
uint32_t __nac3_ndarray_calc_size(
const uint64_t *list_data,
uint32_t list_len
uint32_t list_len,
uint32_t begin_idx,
uint32_t end_idx
) {
__builtin_assume(end_idx <= list_len);
uint32_t num_elems = 1;
for (uint32_t i = 0; i < list_len; ++i) {
for (uint32_t i = begin_idx; i < end_idx; ++i) {
uint64_t val = list_data[i];
__builtin_assume(val > 0);
num_elems *= val;
@ -215,10 +219,14 @@ uint32_t __nac3_ndarray_calc_size(
uint64_t __nac3_ndarray_calc_size64(
const uint64_t *list_data,
uint64_t list_len
uint64_t list_len,
uint64_t begin_idx,
uint64_t end_idx
) {
__builtin_assume(end_idx <= list_len);
uint64_t num_elems = 1;
for (uint64_t i = 0; i < list_len; ++i) {
for (uint64_t i = begin_idx; i < end_idx; ++i) {
uint64_t val = list_data[i];
__builtin_assume(val > 0);
num_elems *= val;

View File

@ -583,12 +583,14 @@ pub fn call_j0<'ctx>(
/// Generates a call to `__nac3_ndarray_calc_size`. Returns an [`IntValue`] representing the
/// calculated total size.
///
/// * `num_dims` - An [`IntValue`] containing the number of dimensions.
/// * `dims` - A [`PointerValue`] to an array containing the size of each dimension.
/// * `dims` - An [`ArrayLikeIndexer`] containing the size of each dimension.
/// * `range` - The dimension index to begin and end (exclusively) calculating the dimensions for,
/// or [`None`] if starting from the first dimension and ending at the last dimension respectively.
pub fn call_ndarray_calc_size<'ctx, G, Dims>(
generator: &G,
ctx: &CodeGenContext<'ctx, '_>,
dims: &Dims,
(begin, end): (Option<IntValue<'ctx>>, Option<IntValue<'ctx>>),
) -> IntValue<'ctx>
where
G: CodeGenerator + ?Sized,
@ -607,6 +609,8 @@ pub fn call_ndarray_calc_size<'ctx, G, Dims>(
&[
llvm_pi64.into(),
llvm_usize.into(),
llvm_usize.into(),
llvm_usize.into(),
],
false,
);
@ -615,12 +619,16 @@ pub fn call_ndarray_calc_size<'ctx, G, Dims>(
ctx.module.add_function(ndarray_calc_size_fn_name, ndarray_calc_size_fn_t, None)
});
let begin = begin.unwrap_or_else(|| llvm_usize.const_zero());
let end = end.unwrap_or_else(|| dims.size(ctx, generator));
ctx.builder
.build_call(
ndarray_calc_size_fn,
&[
dims.base_ptr(ctx, generator).into(),
dims.size(ctx, generator).into(),
begin.into(),
end.into(),
],
"",
)

View File

@ -134,6 +134,7 @@ fn create_ndarray_dyn_shape<'ctx, 'a, G, V, LenFn, DataFn>(
generator,
ctx,
&ndarray.dim_sizes().as_slice_value(ctx, generator),
(None, None),
);
ndarray.create_data(ctx, llvm_ndarray_data_t, ndarray_num_elems);
@ -203,6 +204,7 @@ fn create_ndarray_const_shape<'ctx, G: CodeGenerator + ?Sized>(
generator,
ctx,
&ndarray.dim_sizes().as_slice_value(ctx, generator),
(None, None),
);
ndarray.create_data(ctx, llvm_ndarray_data_t, ndarray_num_elems);
@ -293,6 +295,7 @@ fn ndarray_fill_flattened<'ctx, 'a, G, ValueFn>(
generator,
ctx,
&ndarray.dim_sizes().as_slice_value(ctx, generator),
(None, None),
);
gen_for_callback_incrementing(
@ -661,6 +664,7 @@ fn ndarray_copy_impl<'ctx, G: CodeGenerator + ?Sized>(
generator,
ctx,
&ndarray.dim_sizes().as_slice_value(ctx, generator),
(None, None),
);
let sizeof_ty = ctx.get_llvm_type(generator, elem_ty);
let len_bytes = ctx.builder