forked from M-Labs/nac3
core/irrt: Add support for calculating partial size of NDArray
This commit is contained in:
parent
588c15f80d
commit
b6ff75dcaf
|
@ -737,7 +737,7 @@ pub fn call_numpy_min<'ctx, G: CodeGenerator + ?Sized>(
|
||||||
let llvm_ndarray_ty = ctx.get_llvm_type(generator, elem_ty);
|
let llvm_ndarray_ty = ctx.get_llvm_type(generator, elem_ty);
|
||||||
|
|
||||||
let n = NDArrayValue::from_ptr_val(n, llvm_usize, None);
|
let n = NDArrayValue::from_ptr_val(n, llvm_usize, None);
|
||||||
let n_sz = irrt::call_ndarray_calc_size(generator, ctx, &n.dim_sizes());
|
let n_sz = irrt::call_ndarray_calc_size(generator, ctx, &n.dim_sizes(), (None, None));
|
||||||
if ctx.registry.llvm_options.opt_level == OptimizationLevel::None {
|
if ctx.registry.llvm_options.opt_level == OptimizationLevel::None {
|
||||||
let n_sz_eqz = ctx.builder
|
let n_sz_eqz = ctx.builder
|
||||||
.build_int_compare(
|
.build_int_compare(
|
||||||
|
@ -955,7 +955,7 @@ pub fn call_numpy_max<'ctx, G: CodeGenerator + ?Sized>(
|
||||||
let llvm_ndarray_ty = ctx.get_llvm_type(generator, elem_ty);
|
let llvm_ndarray_ty = ctx.get_llvm_type(generator, elem_ty);
|
||||||
|
|
||||||
let n = NDArrayValue::from_ptr_val(n, llvm_usize, None);
|
let n = NDArrayValue::from_ptr_val(n, llvm_usize, None);
|
||||||
let n_sz = irrt::call_ndarray_calc_size(generator, ctx, &n.dim_sizes());
|
let n_sz = irrt::call_ndarray_calc_size(generator, ctx, &n.dim_sizes(), (None, None));
|
||||||
if ctx.registry.llvm_options.opt_level == OptimizationLevel::None {
|
if ctx.registry.llvm_options.opt_level == OptimizationLevel::None {
|
||||||
let n_sz_eqz = ctx.builder
|
let n_sz_eqz = ctx.builder
|
||||||
.build_int_compare(
|
.build_int_compare(
|
||||||
|
|
|
@ -1122,7 +1122,7 @@ impl<'ctx> ArrayLikeValue<'ctx> for NDArrayDataProxy<'ctx, '_> {
|
||||||
ctx: &CodeGenContext<'ctx, '_>,
|
ctx: &CodeGenContext<'ctx, '_>,
|
||||||
generator: &G,
|
generator: &G,
|
||||||
) -> IntValue<'ctx> {
|
) -> IntValue<'ctx> {
|
||||||
call_ndarray_calc_size(generator, ctx, &self.as_slice_value(ctx, generator))
|
call_ndarray_calc_size(generator, ctx, &self.as_slice_value(ctx, generator), (None, None))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -1819,6 +1819,7 @@ fn gen_ndarray_subscript_expr<'ctx, G: CodeGenerator>(
|
||||||
generator,
|
generator,
|
||||||
ctx,
|
ctx,
|
||||||
&ndarray.dim_sizes().as_slice_value(ctx, generator),
|
&ndarray.dim_sizes().as_slice_value(ctx, generator),
|
||||||
|
(None, None),
|
||||||
);
|
);
|
||||||
ndarray.create_data(ctx, llvm_ndarray_data_t, ndarray_num_elems);
|
ndarray.create_data(ctx, llvm_ndarray_data_t, ndarray_num_elems);
|
||||||
|
|
||||||
|
|
|
@ -202,10 +202,14 @@ double __nac3_j0(double x) {
|
||||||
|
|
||||||
uint32_t __nac3_ndarray_calc_size(
|
uint32_t __nac3_ndarray_calc_size(
|
||||||
const uint64_t *list_data,
|
const uint64_t *list_data,
|
||||||
uint32_t list_len
|
uint32_t list_len,
|
||||||
|
uint32_t begin_idx,
|
||||||
|
uint32_t end_idx
|
||||||
) {
|
) {
|
||||||
|
__builtin_assume(end_idx <= list_len);
|
||||||
|
|
||||||
uint32_t num_elems = 1;
|
uint32_t num_elems = 1;
|
||||||
for (uint32_t i = 0; i < list_len; ++i) {
|
for (uint32_t i = begin_idx; i < end_idx; ++i) {
|
||||||
uint64_t val = list_data[i];
|
uint64_t val = list_data[i];
|
||||||
__builtin_assume(val > 0);
|
__builtin_assume(val > 0);
|
||||||
num_elems *= val;
|
num_elems *= val;
|
||||||
|
@ -215,10 +219,14 @@ uint32_t __nac3_ndarray_calc_size(
|
||||||
|
|
||||||
uint64_t __nac3_ndarray_calc_size64(
|
uint64_t __nac3_ndarray_calc_size64(
|
||||||
const uint64_t *list_data,
|
const uint64_t *list_data,
|
||||||
uint64_t list_len
|
uint64_t list_len,
|
||||||
|
uint64_t begin_idx,
|
||||||
|
uint64_t end_idx
|
||||||
) {
|
) {
|
||||||
|
__builtin_assume(end_idx <= list_len);
|
||||||
|
|
||||||
uint64_t num_elems = 1;
|
uint64_t num_elems = 1;
|
||||||
for (uint64_t i = 0; i < list_len; ++i) {
|
for (uint64_t i = begin_idx; i < end_idx; ++i) {
|
||||||
uint64_t val = list_data[i];
|
uint64_t val = list_data[i];
|
||||||
__builtin_assume(val > 0);
|
__builtin_assume(val > 0);
|
||||||
num_elems *= val;
|
num_elems *= val;
|
||||||
|
|
|
@ -583,12 +583,14 @@ pub fn call_j0<'ctx>(
|
||||||
/// Generates a call to `__nac3_ndarray_calc_size`. Returns an [`IntValue`] representing the
|
/// Generates a call to `__nac3_ndarray_calc_size`. Returns an [`IntValue`] representing the
|
||||||
/// calculated total size.
|
/// calculated total size.
|
||||||
///
|
///
|
||||||
/// * `num_dims` - An [`IntValue`] containing the number of dimensions.
|
/// * `dims` - An [`ArrayLikeIndexer`] containing the size of each dimension.
|
||||||
/// * `dims` - A [`PointerValue`] to an array containing the size of each dimension.
|
/// * `range` - The dimension index to begin and end (exclusively) calculating the dimensions for,
|
||||||
|
/// or [`None`] if starting from the first dimension and ending at the last dimension respectively.
|
||||||
pub fn call_ndarray_calc_size<'ctx, G, Dims>(
|
pub fn call_ndarray_calc_size<'ctx, G, Dims>(
|
||||||
generator: &G,
|
generator: &G,
|
||||||
ctx: &CodeGenContext<'ctx, '_>,
|
ctx: &CodeGenContext<'ctx, '_>,
|
||||||
dims: &Dims,
|
dims: &Dims,
|
||||||
|
(begin, end): (Option<IntValue<'ctx>>, Option<IntValue<'ctx>>),
|
||||||
) -> IntValue<'ctx>
|
) -> IntValue<'ctx>
|
||||||
where
|
where
|
||||||
G: CodeGenerator + ?Sized,
|
G: CodeGenerator + ?Sized,
|
||||||
|
@ -607,6 +609,8 @@ pub fn call_ndarray_calc_size<'ctx, G, Dims>(
|
||||||
&[
|
&[
|
||||||
llvm_pi64.into(),
|
llvm_pi64.into(),
|
||||||
llvm_usize.into(),
|
llvm_usize.into(),
|
||||||
|
llvm_usize.into(),
|
||||||
|
llvm_usize.into(),
|
||||||
],
|
],
|
||||||
false,
|
false,
|
||||||
);
|
);
|
||||||
|
@ -615,12 +619,16 @@ pub fn call_ndarray_calc_size<'ctx, G, Dims>(
|
||||||
ctx.module.add_function(ndarray_calc_size_fn_name, ndarray_calc_size_fn_t, None)
|
ctx.module.add_function(ndarray_calc_size_fn_name, ndarray_calc_size_fn_t, None)
|
||||||
});
|
});
|
||||||
|
|
||||||
|
let begin = begin.unwrap_or_else(|| llvm_usize.const_zero());
|
||||||
|
let end = end.unwrap_or_else(|| dims.size(ctx, generator));
|
||||||
ctx.builder
|
ctx.builder
|
||||||
.build_call(
|
.build_call(
|
||||||
ndarray_calc_size_fn,
|
ndarray_calc_size_fn,
|
||||||
&[
|
&[
|
||||||
dims.base_ptr(ctx, generator).into(),
|
dims.base_ptr(ctx, generator).into(),
|
||||||
dims.size(ctx, generator).into(),
|
dims.size(ctx, generator).into(),
|
||||||
|
begin.into(),
|
||||||
|
end.into(),
|
||||||
],
|
],
|
||||||
"",
|
"",
|
||||||
)
|
)
|
||||||
|
|
|
@ -134,6 +134,7 @@ fn create_ndarray_dyn_shape<'ctx, 'a, G, V, LenFn, DataFn>(
|
||||||
generator,
|
generator,
|
||||||
ctx,
|
ctx,
|
||||||
&ndarray.dim_sizes().as_slice_value(ctx, generator),
|
&ndarray.dim_sizes().as_slice_value(ctx, generator),
|
||||||
|
(None, None),
|
||||||
);
|
);
|
||||||
ndarray.create_data(ctx, llvm_ndarray_data_t, ndarray_num_elems);
|
ndarray.create_data(ctx, llvm_ndarray_data_t, ndarray_num_elems);
|
||||||
|
|
||||||
|
@ -203,6 +204,7 @@ fn create_ndarray_const_shape<'ctx, G: CodeGenerator + ?Sized>(
|
||||||
generator,
|
generator,
|
||||||
ctx,
|
ctx,
|
||||||
&ndarray.dim_sizes().as_slice_value(ctx, generator),
|
&ndarray.dim_sizes().as_slice_value(ctx, generator),
|
||||||
|
(None, None),
|
||||||
);
|
);
|
||||||
ndarray.create_data(ctx, llvm_ndarray_data_t, ndarray_num_elems);
|
ndarray.create_data(ctx, llvm_ndarray_data_t, ndarray_num_elems);
|
||||||
|
|
||||||
|
@ -293,6 +295,7 @@ fn ndarray_fill_flattened<'ctx, 'a, G, ValueFn>(
|
||||||
generator,
|
generator,
|
||||||
ctx,
|
ctx,
|
||||||
&ndarray.dim_sizes().as_slice_value(ctx, generator),
|
&ndarray.dim_sizes().as_slice_value(ctx, generator),
|
||||||
|
(None, None),
|
||||||
);
|
);
|
||||||
|
|
||||||
gen_for_callback_incrementing(
|
gen_for_callback_incrementing(
|
||||||
|
@ -661,6 +664,7 @@ fn ndarray_copy_impl<'ctx, G: CodeGenerator + ?Sized>(
|
||||||
generator,
|
generator,
|
||||||
ctx,
|
ctx,
|
||||||
&ndarray.dim_sizes().as_slice_value(ctx, generator),
|
&ndarray.dim_sizes().as_slice_value(ctx, generator),
|
||||||
|
(None, None),
|
||||||
);
|
);
|
||||||
let sizeof_ty = ctx.get_llvm_type(generator, elem_ty);
|
let sizeof_ty = ctx.get_llvm_type(generator, elem_ty);
|
||||||
let len_bytes = ctx.builder
|
let len_bytes = ctx.builder
|
||||||
|
|
Loading…
Reference in New Issue