forked from M-Labs/nac3
1
0
Fork 0

core/numpy: Implement ndarray_sliced_{copy,copyto_impl}

Performing copying with optional support for slicing. Also made
copy_impl delegate to sliced_copy, as sliced_copy now performs a
superset of operations that copy_impl can already do.
This commit is contained in:
David Mak 2024-05-30 14:25:56 +08:00
parent a176c3eb70
commit 135ef557f9
1 changed files with 238 additions and 40 deletions

View File

@ -9,6 +9,7 @@ use crate::{
NDArrayValue, NDArrayValue,
TypedArrayLikeAccessor, TypedArrayLikeAccessor,
TypedArrayLikeAdapter, TypedArrayLikeAdapter,
TypedArrayLikeMutator,
UntypedArrayLikeAccessor, UntypedArrayLikeAccessor,
UntypedArrayLikeMutator, UntypedArrayLikeMutator,
}, },
@ -16,6 +17,7 @@ use crate::{
CodeGenerator, CodeGenerator,
expr::gen_binop_expr_with_values, expr::gen_binop_expr_with_values,
irrt::{ irrt::{
calculate_len_for_slice_range,
call_ndarray_calc_broadcast, call_ndarray_calc_broadcast,
call_ndarray_calc_broadcast_index, call_ndarray_calc_broadcast_index,
call_ndarray_calc_nd_indices, call_ndarray_calc_nd_indices,
@ -23,7 +25,7 @@ use crate::{
}, },
llvm_intrinsics, llvm_intrinsics,
llvm_intrinsics::{call_memcpy_generic}, llvm_intrinsics::{call_memcpy_generic},
stmt::{gen_for_callback_incrementing, gen_if_else_expr_callback}, stmt::{gen_for_callback_incrementing, gen_for_range_callback, gen_if_else_expr_callback},
}, },
symbol_resolver::ValueEnum, symbol_resolver::ValueEnum,
toplevel::{ toplevel::{
@ -645,6 +647,240 @@ fn call_ndarray_eye_impl<'ctx, G: CodeGenerator + ?Sized>(
Ok(ndarray) Ok(ndarray)
} }
/// Copies a slice of an [`NDArrayValue`] to another.
///
/// - `dst_arr`: The [`NDArrayValue`] instance of the destination array. The `ndims` and `dim_sz`
/// fields should be populated before calling this function.
/// - `dst_slice_ptr`: The [`PointerValue`] to the first element of the currently processing
/// dimensional slice in the destination array.
/// - `src_arr`: The [`NDArrayValue`] instance of the source array.
/// - `src_slice_ptr`: The [`PointerValue`] to the first element of the currently processing
/// dimensional slice in the source array.
/// - `dim`: The index of the currently processing dimension.
/// - `slices`: List of all slices, with the first element corresponding to the slice applicable to
/// this dimension. The `start`/`stop` values of each slice must be non-negative indices.
fn ndarray_sliced_copyto_impl<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
elem_ty: Type,
(dst_arr, dst_slice_ptr): (NDArrayValue<'ctx>, PointerValue<'ctx>),
(src_arr, src_slice_ptr): (NDArrayValue<'ctx>, PointerValue<'ctx>),
dim: u64,
slices: &[(IntValue<'ctx>, IntValue<'ctx>, IntValue<'ctx>)],
) -> Result<(), String> {
let llvm_i1 = ctx.ctx.bool_type();
let llvm_usize = generator.get_size_type(ctx.ctx);
// If there are no (remaining) slice expressions, memcpy the entire dimension
if slices.is_empty() {
let stride = call_ndarray_calc_size(
generator,
ctx,
&src_arr.dim_sizes(),
(Some(llvm_usize.const_int(dim, false)), None),
);
let sizeof_elem = ctx.get_llvm_type(generator, elem_ty).size_of().unwrap();
let cpy_len = ctx.builder.build_int_mul(
stride,
sizeof_elem,
""
).unwrap();
call_memcpy_generic(
ctx,
dst_slice_ptr,
src_slice_ptr,
cpy_len,
llvm_i1.const_zero(),
);
return Ok(())
}
// The stride of elements in this dimension, i.e. the number of elements between arr[i] and
// arr[i + 1] in this dimension
let src_stride = call_ndarray_calc_size(
generator,
ctx,
&src_arr.dim_sizes(),
(Some(llvm_usize.const_int(dim + 1, false)), None),
);
let dst_stride = call_ndarray_calc_size(
generator,
ctx,
&dst_arr.dim_sizes(),
(Some(llvm_usize.const_int(dim + 1, false)), None),
);
let (start, stop, step) = slices[0];
let start = ctx.builder.build_int_s_extend_or_bit_cast(start, llvm_usize, "").unwrap();
let stop = ctx.builder.build_int_s_extend_or_bit_cast(stop, llvm_usize, "").unwrap();
let step = ctx.builder.build_int_s_extend_or_bit_cast(step, llvm_usize, "").unwrap();
let dst_i_addr = generator.gen_var_alloc(ctx, start.get_type().into(), None).unwrap();
ctx.builder.build_store(dst_i_addr, start.get_type().const_zero()).unwrap();
gen_for_range_callback(
generator,
ctx,
false,
|_, _| Ok(start),
(|_, _| Ok(stop), true),
|_, _| Ok(step),
|generator, ctx, src_i| {
// Calculate the offset of the active slice
let src_data_offset = ctx.builder.build_int_mul(
src_stride,
src_i,
"",
).unwrap();
let dst_i = ctx.builder
.build_load(dst_i_addr, "")
.map(BasicValueEnum::into_int_value)
.unwrap();
let dst_data_offset = ctx.builder.build_int_mul(
dst_stride,
dst_i,
"",
).unwrap();
let (src_ptr, dst_ptr) = unsafe {
(
ctx.builder.build_gep(src_slice_ptr, &[src_data_offset], "").unwrap(),
ctx.builder.build_gep(dst_slice_ptr, &[dst_data_offset], "").unwrap(),
)
};
ndarray_sliced_copyto_impl(
generator,
ctx,
elem_ty,
(dst_arr, dst_ptr),
(src_arr, src_ptr),
dim + 1,
&slices[1..],
)?;
let dst_i = ctx.builder
.build_load(dst_i_addr, "")
.map(BasicValueEnum::into_int_value)
.unwrap();
let dst_i_add1 = ctx.builder
.build_int_add(dst_i, llvm_usize.const_int(1, false), "")
.unwrap();
ctx.builder.build_store(dst_i_addr, dst_i_add1).unwrap();
Ok(())
},
)?;
Ok(())
}
/// Copies a [`NDArrayValue`] using slices.
///
/// * `elem_ty` - The element type of the `NDArray`.
/// - `slices`: List of all slices, with the first element corresponding to the slice applicable to
/// this dimension. The `start`/`stop` values of each slice must be positive indices.
pub fn ndarray_sliced_copy<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
elem_ty: Type,
this: NDArrayValue<'ctx>,
slices: &[(IntValue<'ctx>, IntValue<'ctx>, IntValue<'ctx>)],
) -> Result<NDArrayValue<'ctx>, String> {
let llvm_i32 = ctx.ctx.i32_type();
let llvm_usize = generator.get_size_type(ctx.ctx);
let ndarray = if slices.is_empty() {
create_ndarray_dyn_shape(
generator,
ctx,
elem_ty,
&this,
|_, ctx, shape| {
Ok(shape.load_ndims(ctx))
},
|generator, ctx, shape, idx| {
unsafe { Ok(shape.dim_sizes().get_typed_unchecked(ctx, generator, &idx, None)) }
},
)?
} else {
let ndarray = create_ndarray_uninitialized(generator, ctx, elem_ty)?;
ndarray.store_ndims(ctx, generator, this.load_ndims(ctx));
let ndims = this.load_ndims(ctx);
ndarray.create_dim_sizes(ctx, llvm_usize, ndims);
// Populate the first slices.len() dimensions by computing the size of each dim slice
for (i, (start, stop, step)) in slices.iter().enumerate() {
// HACK: workaround calculate_len_for_slice_range requiring exclusive stop
let stop = ctx.builder
.build_select(
ctx.builder.build_int_compare(
IntPredicate::SLT,
*step,
llvm_i32.const_zero(),
"is_neg",
).unwrap(),
ctx.builder.build_int_sub(*stop, llvm_i32.const_int(1, true), "e_min_one").unwrap(),
ctx.builder.build_int_add(*stop, llvm_i32.const_int(1, true), "e_add_one").unwrap(),
"final_e",
)
.map(BasicValueEnum::into_int_value)
.unwrap();
let slice_len = calculate_len_for_slice_range(generator, ctx, *start, stop, *step);
let slice_len = ctx.builder.build_int_z_extend_or_bit_cast(
slice_len,
llvm_usize,
""
).unwrap();
unsafe {
ndarray.dim_sizes()
.set_typed_unchecked(
ctx,
generator,
&llvm_usize.const_int(i as u64, false),
slice_len,
);
}
}
// Populate the rest by directly copying the dim size from the source array
gen_for_callback_incrementing(
generator,
ctx,
llvm_usize.const_int(slices.len() as u64, false),
(this.load_ndims(ctx), false),
|generator, ctx, idx| {
unsafe {
let dim_sz = this.dim_sizes().get_typed_unchecked(ctx, generator, &idx, None);
ndarray.dim_sizes().set_typed_unchecked(ctx, generator, &idx, dim_sz);
}
Ok(())
},
llvm_usize.const_int(1, false),
).unwrap();
ndarray_init_data(generator, ctx, elem_ty, ndarray)
};
ndarray_sliced_copyto_impl(
generator,
ctx,
elem_ty,
(ndarray, ndarray.data().base_ptr(ctx, generator)),
(this, this.data().base_ptr(ctx, generator)),
0,
slices,
)?;
Ok(ndarray)
}
/// LLVM-typed implementation for generating the implementation for `ndarray.copy`. /// LLVM-typed implementation for generating the implementation for `ndarray.copy`.
/// ///
/// * `elem_ty` - The element type of the `NDArray`. /// * `elem_ty` - The element type of the `NDArray`.
@ -654,45 +890,7 @@ fn ndarray_copy_impl<'ctx, G: CodeGenerator + ?Sized>(
elem_ty: Type, elem_ty: Type,
this: NDArrayValue<'ctx>, this: NDArrayValue<'ctx>,
) -> Result<NDArrayValue<'ctx>, String> { ) -> Result<NDArrayValue<'ctx>, String> {
let llvm_i1 = ctx.ctx.bool_type(); ndarray_sliced_copy(generator, ctx, elem_ty, this, &[])
let ndarray = create_ndarray_dyn_shape(
generator,
ctx,
elem_ty,
&this,
|_, ctx, shape| {
Ok(shape.load_ndims(ctx))
},
|generator, ctx, shape, idx| {
unsafe { Ok(shape.dim_sizes().get_typed_unchecked(ctx, generator, &idx, None)) }
},
)?;
let len = call_ndarray_calc_size(
generator,
ctx,
&ndarray.dim_sizes().as_slice_value(ctx, generator),
(None, None),
);
let sizeof_ty = ctx.get_llvm_type(generator, elem_ty);
let len_bytes = ctx.builder
.build_int_mul(
len,
sizeof_ty.size_of().unwrap(),
"",
)
.unwrap();
call_memcpy_generic(
ctx,
ndarray.data().base_ptr(ctx, generator),
this.data().base_ptr(ctx, generator),
len_bytes,
llvm_i1.const_zero(),
);
Ok(ndarray)
} }
pub fn ndarray_elementwise_unaryop_impl<'ctx, 'a, G, MapFn>( pub fn ndarray_elementwise_unaryop_impl<'ctx, 'a, G, MapFn>(