forked from M-Labs/nac3
core/numpy: Implement ndarray_sliced_{copy,copyto_impl}
Performing copying with optional support for slicing. Also made copy_impl delegate to sliced_copy, as sliced_copy now performs a superset of operations that copy_impl can already do.
This commit is contained in:
parent
a176c3eb70
commit
135ef557f9
@ -9,6 +9,7 @@ use crate::{
|
||||
NDArrayValue,
|
||||
TypedArrayLikeAccessor,
|
||||
TypedArrayLikeAdapter,
|
||||
TypedArrayLikeMutator,
|
||||
UntypedArrayLikeAccessor,
|
||||
UntypedArrayLikeMutator,
|
||||
},
|
||||
@ -16,6 +17,7 @@ use crate::{
|
||||
CodeGenerator,
|
||||
expr::gen_binop_expr_with_values,
|
||||
irrt::{
|
||||
calculate_len_for_slice_range,
|
||||
call_ndarray_calc_broadcast,
|
||||
call_ndarray_calc_broadcast_index,
|
||||
call_ndarray_calc_nd_indices,
|
||||
@ -23,7 +25,7 @@ use crate::{
|
||||
},
|
||||
llvm_intrinsics,
|
||||
llvm_intrinsics::{call_memcpy_generic},
|
||||
stmt::{gen_for_callback_incrementing, gen_if_else_expr_callback},
|
||||
stmt::{gen_for_callback_incrementing, gen_for_range_callback, gen_if_else_expr_callback},
|
||||
},
|
||||
symbol_resolver::ValueEnum,
|
||||
toplevel::{
|
||||
@ -645,6 +647,240 @@ fn call_ndarray_eye_impl<'ctx, G: CodeGenerator + ?Sized>(
|
||||
Ok(ndarray)
|
||||
}
|
||||
|
||||
/// Copies a slice of an [`NDArrayValue`] to another.
|
||||
///
|
||||
/// - `dst_arr`: The [`NDArrayValue`] instance of the destination array. The `ndims` and `dim_sz`
|
||||
/// fields should be populated before calling this function.
|
||||
/// - `dst_slice_ptr`: The [`PointerValue`] to the first element of the currently processing
|
||||
/// dimensional slice in the destination array.
|
||||
/// - `src_arr`: The [`NDArrayValue`] instance of the source array.
|
||||
/// - `src_slice_ptr`: The [`PointerValue`] to the first element of the currently processing
|
||||
/// dimensional slice in the source array.
|
||||
/// - `dim`: The index of the currently processing dimension.
|
||||
/// - `slices`: List of all slices, with the first element corresponding to the slice applicable to
|
||||
/// this dimension. The `start`/`stop` values of each slice must be non-negative indices.
|
||||
fn ndarray_sliced_copyto_impl<'ctx, G: CodeGenerator + ?Sized>(
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
elem_ty: Type,
|
||||
(dst_arr, dst_slice_ptr): (NDArrayValue<'ctx>, PointerValue<'ctx>),
|
||||
(src_arr, src_slice_ptr): (NDArrayValue<'ctx>, PointerValue<'ctx>),
|
||||
dim: u64,
|
||||
slices: &[(IntValue<'ctx>, IntValue<'ctx>, IntValue<'ctx>)],
|
||||
) -> Result<(), String> {
|
||||
let llvm_i1 = ctx.ctx.bool_type();
|
||||
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||
|
||||
// If there are no (remaining) slice expressions, memcpy the entire dimension
|
||||
if slices.is_empty() {
|
||||
let stride = call_ndarray_calc_size(
|
||||
generator,
|
||||
ctx,
|
||||
&src_arr.dim_sizes(),
|
||||
(Some(llvm_usize.const_int(dim, false)), None),
|
||||
);
|
||||
let sizeof_elem = ctx.get_llvm_type(generator, elem_ty).size_of().unwrap();
|
||||
let cpy_len = ctx.builder.build_int_mul(
|
||||
stride,
|
||||
sizeof_elem,
|
||||
""
|
||||
).unwrap();
|
||||
|
||||
call_memcpy_generic(
|
||||
ctx,
|
||||
dst_slice_ptr,
|
||||
src_slice_ptr,
|
||||
cpy_len,
|
||||
llvm_i1.const_zero(),
|
||||
);
|
||||
|
||||
return Ok(())
|
||||
}
|
||||
|
||||
// The stride of elements in this dimension, i.e. the number of elements between arr[i] and
|
||||
// arr[i + 1] in this dimension
|
||||
let src_stride = call_ndarray_calc_size(
|
||||
generator,
|
||||
ctx,
|
||||
&src_arr.dim_sizes(),
|
||||
(Some(llvm_usize.const_int(dim + 1, false)), None),
|
||||
);
|
||||
let dst_stride = call_ndarray_calc_size(
|
||||
generator,
|
||||
ctx,
|
||||
&dst_arr.dim_sizes(),
|
||||
(Some(llvm_usize.const_int(dim + 1, false)), None),
|
||||
);
|
||||
|
||||
let (start, stop, step) = slices[0];
|
||||
let start = ctx.builder.build_int_s_extend_or_bit_cast(start, llvm_usize, "").unwrap();
|
||||
let stop = ctx.builder.build_int_s_extend_or_bit_cast(stop, llvm_usize, "").unwrap();
|
||||
let step = ctx.builder.build_int_s_extend_or_bit_cast(step, llvm_usize, "").unwrap();
|
||||
|
||||
let dst_i_addr = generator.gen_var_alloc(ctx, start.get_type().into(), None).unwrap();
|
||||
ctx.builder.build_store(dst_i_addr, start.get_type().const_zero()).unwrap();
|
||||
|
||||
gen_for_range_callback(
|
||||
generator,
|
||||
ctx,
|
||||
false,
|
||||
|_, _| Ok(start),
|
||||
(|_, _| Ok(stop), true),
|
||||
|_, _| Ok(step),
|
||||
|generator, ctx, src_i| {
|
||||
// Calculate the offset of the active slice
|
||||
let src_data_offset = ctx.builder.build_int_mul(
|
||||
src_stride,
|
||||
src_i,
|
||||
"",
|
||||
).unwrap();
|
||||
let dst_i = ctx.builder
|
||||
.build_load(dst_i_addr, "")
|
||||
.map(BasicValueEnum::into_int_value)
|
||||
.unwrap();
|
||||
let dst_data_offset = ctx.builder.build_int_mul(
|
||||
dst_stride,
|
||||
dst_i,
|
||||
"",
|
||||
).unwrap();
|
||||
|
||||
let (src_ptr, dst_ptr) = unsafe {
|
||||
(
|
||||
ctx.builder.build_gep(src_slice_ptr, &[src_data_offset], "").unwrap(),
|
||||
ctx.builder.build_gep(dst_slice_ptr, &[dst_data_offset], "").unwrap(),
|
||||
)
|
||||
};
|
||||
|
||||
ndarray_sliced_copyto_impl(
|
||||
generator,
|
||||
ctx,
|
||||
elem_ty,
|
||||
(dst_arr, dst_ptr),
|
||||
(src_arr, src_ptr),
|
||||
dim + 1,
|
||||
&slices[1..],
|
||||
)?;
|
||||
|
||||
let dst_i = ctx.builder
|
||||
.build_load(dst_i_addr, "")
|
||||
.map(BasicValueEnum::into_int_value)
|
||||
.unwrap();
|
||||
let dst_i_add1 = ctx.builder
|
||||
.build_int_add(dst_i, llvm_usize.const_int(1, false), "")
|
||||
.unwrap();
|
||||
ctx.builder.build_store(dst_i_addr, dst_i_add1).unwrap();
|
||||
|
||||
Ok(())
|
||||
},
|
||||
)?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Copies a [`NDArrayValue`] using slices.
|
||||
///
|
||||
/// * `elem_ty` - The element type of the `NDArray`.
|
||||
/// - `slices`: List of all slices, with the first element corresponding to the slice applicable to
|
||||
/// this dimension. The `start`/`stop` values of each slice must be positive indices.
|
||||
pub fn ndarray_sliced_copy<'ctx, G: CodeGenerator + ?Sized>(
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
elem_ty: Type,
|
||||
this: NDArrayValue<'ctx>,
|
||||
slices: &[(IntValue<'ctx>, IntValue<'ctx>, IntValue<'ctx>)],
|
||||
) -> Result<NDArrayValue<'ctx>, String> {
|
||||
let llvm_i32 = ctx.ctx.i32_type();
|
||||
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||
|
||||
let ndarray = if slices.is_empty() {
|
||||
create_ndarray_dyn_shape(
|
||||
generator,
|
||||
ctx,
|
||||
elem_ty,
|
||||
&this,
|
||||
|_, ctx, shape| {
|
||||
Ok(shape.load_ndims(ctx))
|
||||
},
|
||||
|generator, ctx, shape, idx| {
|
||||
unsafe { Ok(shape.dim_sizes().get_typed_unchecked(ctx, generator, &idx, None)) }
|
||||
},
|
||||
)?
|
||||
} else {
|
||||
let ndarray = create_ndarray_uninitialized(generator, ctx, elem_ty)?;
|
||||
ndarray.store_ndims(ctx, generator, this.load_ndims(ctx));
|
||||
|
||||
let ndims = this.load_ndims(ctx);
|
||||
ndarray.create_dim_sizes(ctx, llvm_usize, ndims);
|
||||
|
||||
// Populate the first slices.len() dimensions by computing the size of each dim slice
|
||||
for (i, (start, stop, step)) in slices.iter().enumerate() {
|
||||
// HACK: workaround calculate_len_for_slice_range requiring exclusive stop
|
||||
let stop = ctx.builder
|
||||
.build_select(
|
||||
ctx.builder.build_int_compare(
|
||||
IntPredicate::SLT,
|
||||
*step,
|
||||
llvm_i32.const_zero(),
|
||||
"is_neg",
|
||||
).unwrap(),
|
||||
ctx.builder.build_int_sub(*stop, llvm_i32.const_int(1, true), "e_min_one").unwrap(),
|
||||
ctx.builder.build_int_add(*stop, llvm_i32.const_int(1, true), "e_add_one").unwrap(),
|
||||
"final_e",
|
||||
)
|
||||
.map(BasicValueEnum::into_int_value)
|
||||
.unwrap();
|
||||
|
||||
let slice_len = calculate_len_for_slice_range(generator, ctx, *start, stop, *step);
|
||||
let slice_len = ctx.builder.build_int_z_extend_or_bit_cast(
|
||||
slice_len,
|
||||
llvm_usize,
|
||||
""
|
||||
).unwrap();
|
||||
|
||||
unsafe {
|
||||
ndarray.dim_sizes()
|
||||
.set_typed_unchecked(
|
||||
ctx,
|
||||
generator,
|
||||
&llvm_usize.const_int(i as u64, false),
|
||||
slice_len,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
// Populate the rest by directly copying the dim size from the source array
|
||||
gen_for_callback_incrementing(
|
||||
generator,
|
||||
ctx,
|
||||
llvm_usize.const_int(slices.len() as u64, false),
|
||||
(this.load_ndims(ctx), false),
|
||||
|generator, ctx, idx| {
|
||||
unsafe {
|
||||
let dim_sz = this.dim_sizes().get_typed_unchecked(ctx, generator, &idx, None);
|
||||
ndarray.dim_sizes().set_typed_unchecked(ctx, generator, &idx, dim_sz);
|
||||
}
|
||||
|
||||
Ok(())
|
||||
},
|
||||
llvm_usize.const_int(1, false),
|
||||
).unwrap();
|
||||
|
||||
ndarray_init_data(generator, ctx, elem_ty, ndarray)
|
||||
};
|
||||
|
||||
ndarray_sliced_copyto_impl(
|
||||
generator,
|
||||
ctx,
|
||||
elem_ty,
|
||||
(ndarray, ndarray.data().base_ptr(ctx, generator)),
|
||||
(this, this.data().base_ptr(ctx, generator)),
|
||||
0,
|
||||
slices,
|
||||
)?;
|
||||
|
||||
Ok(ndarray)
|
||||
}
|
||||
|
||||
/// LLVM-typed implementation for generating the implementation for `ndarray.copy`.
|
||||
///
|
||||
/// * `elem_ty` - The element type of the `NDArray`.
|
||||
@ -654,45 +890,7 @@ fn ndarray_copy_impl<'ctx, G: CodeGenerator + ?Sized>(
|
||||
elem_ty: Type,
|
||||
this: NDArrayValue<'ctx>,
|
||||
) -> Result<NDArrayValue<'ctx>, String> {
|
||||
let llvm_i1 = ctx.ctx.bool_type();
|
||||
|
||||
let ndarray = create_ndarray_dyn_shape(
|
||||
generator,
|
||||
ctx,
|
||||
elem_ty,
|
||||
&this,
|
||||
|_, ctx, shape| {
|
||||
Ok(shape.load_ndims(ctx))
|
||||
},
|
||||
|generator, ctx, shape, idx| {
|
||||
unsafe { Ok(shape.dim_sizes().get_typed_unchecked(ctx, generator, &idx, None)) }
|
||||
},
|
||||
)?;
|
||||
|
||||
let len = call_ndarray_calc_size(
|
||||
generator,
|
||||
ctx,
|
||||
&ndarray.dim_sizes().as_slice_value(ctx, generator),
|
||||
(None, None),
|
||||
);
|
||||
let sizeof_ty = ctx.get_llvm_type(generator, elem_ty);
|
||||
let len_bytes = ctx.builder
|
||||
.build_int_mul(
|
||||
len,
|
||||
sizeof_ty.size_of().unwrap(),
|
||||
"",
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
call_memcpy_generic(
|
||||
ctx,
|
||||
ndarray.data().base_ptr(ctx, generator),
|
||||
this.data().base_ptr(ctx, generator),
|
||||
len_bytes,
|
||||
llvm_i1.const_zero(),
|
||||
);
|
||||
|
||||
Ok(ndarray)
|
||||
ndarray_sliced_copy(generator, ctx, elem_ty, this, &[])
|
||||
}
|
||||
|
||||
pub fn ndarray_elementwise_unaryop_impl<'ctx, 'a, G, MapFn>(
|
||||
|
Loading…
Reference in New Issue
Block a user