core/numpy: Implement ndarray_sliced_{copy,copyto_impl}
Performing copying with optional support for slicing. Also made copy_impl delegate to sliced_copy, as sliced_copy now performs a superset of operations that copy_impl can already do.
This commit is contained in:
parent
a176c3eb70
commit
135ef557f9
@ -9,6 +9,7 @@ use crate::{
|
|||||||
NDArrayValue,
|
NDArrayValue,
|
||||||
TypedArrayLikeAccessor,
|
TypedArrayLikeAccessor,
|
||||||
TypedArrayLikeAdapter,
|
TypedArrayLikeAdapter,
|
||||||
|
TypedArrayLikeMutator,
|
||||||
UntypedArrayLikeAccessor,
|
UntypedArrayLikeAccessor,
|
||||||
UntypedArrayLikeMutator,
|
UntypedArrayLikeMutator,
|
||||||
},
|
},
|
||||||
@ -16,6 +17,7 @@ use crate::{
|
|||||||
CodeGenerator,
|
CodeGenerator,
|
||||||
expr::gen_binop_expr_with_values,
|
expr::gen_binop_expr_with_values,
|
||||||
irrt::{
|
irrt::{
|
||||||
|
calculate_len_for_slice_range,
|
||||||
call_ndarray_calc_broadcast,
|
call_ndarray_calc_broadcast,
|
||||||
call_ndarray_calc_broadcast_index,
|
call_ndarray_calc_broadcast_index,
|
||||||
call_ndarray_calc_nd_indices,
|
call_ndarray_calc_nd_indices,
|
||||||
@ -23,7 +25,7 @@ use crate::{
|
|||||||
},
|
},
|
||||||
llvm_intrinsics,
|
llvm_intrinsics,
|
||||||
llvm_intrinsics::{call_memcpy_generic},
|
llvm_intrinsics::{call_memcpy_generic},
|
||||||
stmt::{gen_for_callback_incrementing, gen_if_else_expr_callback},
|
stmt::{gen_for_callback_incrementing, gen_for_range_callback, gen_if_else_expr_callback},
|
||||||
},
|
},
|
||||||
symbol_resolver::ValueEnum,
|
symbol_resolver::ValueEnum,
|
||||||
toplevel::{
|
toplevel::{
|
||||||
@ -645,6 +647,240 @@ fn call_ndarray_eye_impl<'ctx, G: CodeGenerator + ?Sized>(
|
|||||||
Ok(ndarray)
|
Ok(ndarray)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Copies a slice of an [`NDArrayValue`] to another.
|
||||||
|
///
|
||||||
|
/// - `dst_arr`: The [`NDArrayValue`] instance of the destination array. The `ndims` and `dim_sz`
|
||||||
|
/// fields should be populated before calling this function.
|
||||||
|
/// - `dst_slice_ptr`: The [`PointerValue`] to the first element of the currently processing
|
||||||
|
/// dimensional slice in the destination array.
|
||||||
|
/// - `src_arr`: The [`NDArrayValue`] instance of the source array.
|
||||||
|
/// - `src_slice_ptr`: The [`PointerValue`] to the first element of the currently processing
|
||||||
|
/// dimensional slice in the source array.
|
||||||
|
/// - `dim`: The index of the currently processing dimension.
|
||||||
|
/// - `slices`: List of all slices, with the first element corresponding to the slice applicable to
|
||||||
|
/// this dimension. The `start`/`stop` values of each slice must be non-negative indices.
|
||||||
|
fn ndarray_sliced_copyto_impl<'ctx, G: CodeGenerator + ?Sized>(
|
||||||
|
generator: &mut G,
|
||||||
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
|
elem_ty: Type,
|
||||||
|
(dst_arr, dst_slice_ptr): (NDArrayValue<'ctx>, PointerValue<'ctx>),
|
||||||
|
(src_arr, src_slice_ptr): (NDArrayValue<'ctx>, PointerValue<'ctx>),
|
||||||
|
dim: u64,
|
||||||
|
slices: &[(IntValue<'ctx>, IntValue<'ctx>, IntValue<'ctx>)],
|
||||||
|
) -> Result<(), String> {
|
||||||
|
let llvm_i1 = ctx.ctx.bool_type();
|
||||||
|
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||||
|
|
||||||
|
// If there are no (remaining) slice expressions, memcpy the entire dimension
|
||||||
|
if slices.is_empty() {
|
||||||
|
let stride = call_ndarray_calc_size(
|
||||||
|
generator,
|
||||||
|
ctx,
|
||||||
|
&src_arr.dim_sizes(),
|
||||||
|
(Some(llvm_usize.const_int(dim, false)), None),
|
||||||
|
);
|
||||||
|
let sizeof_elem = ctx.get_llvm_type(generator, elem_ty).size_of().unwrap();
|
||||||
|
let cpy_len = ctx.builder.build_int_mul(
|
||||||
|
stride,
|
||||||
|
sizeof_elem,
|
||||||
|
""
|
||||||
|
).unwrap();
|
||||||
|
|
||||||
|
call_memcpy_generic(
|
||||||
|
ctx,
|
||||||
|
dst_slice_ptr,
|
||||||
|
src_slice_ptr,
|
||||||
|
cpy_len,
|
||||||
|
llvm_i1.const_zero(),
|
||||||
|
);
|
||||||
|
|
||||||
|
return Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
// The stride of elements in this dimension, i.e. the number of elements between arr[i] and
|
||||||
|
// arr[i + 1] in this dimension
|
||||||
|
let src_stride = call_ndarray_calc_size(
|
||||||
|
generator,
|
||||||
|
ctx,
|
||||||
|
&src_arr.dim_sizes(),
|
||||||
|
(Some(llvm_usize.const_int(dim + 1, false)), None),
|
||||||
|
);
|
||||||
|
let dst_stride = call_ndarray_calc_size(
|
||||||
|
generator,
|
||||||
|
ctx,
|
||||||
|
&dst_arr.dim_sizes(),
|
||||||
|
(Some(llvm_usize.const_int(dim + 1, false)), None),
|
||||||
|
);
|
||||||
|
|
||||||
|
let (start, stop, step) = slices[0];
|
||||||
|
let start = ctx.builder.build_int_s_extend_or_bit_cast(start, llvm_usize, "").unwrap();
|
||||||
|
let stop = ctx.builder.build_int_s_extend_or_bit_cast(stop, llvm_usize, "").unwrap();
|
||||||
|
let step = ctx.builder.build_int_s_extend_or_bit_cast(step, llvm_usize, "").unwrap();
|
||||||
|
|
||||||
|
let dst_i_addr = generator.gen_var_alloc(ctx, start.get_type().into(), None).unwrap();
|
||||||
|
ctx.builder.build_store(dst_i_addr, start.get_type().const_zero()).unwrap();
|
||||||
|
|
||||||
|
gen_for_range_callback(
|
||||||
|
generator,
|
||||||
|
ctx,
|
||||||
|
false,
|
||||||
|
|_, _| Ok(start),
|
||||||
|
(|_, _| Ok(stop), true),
|
||||||
|
|_, _| Ok(step),
|
||||||
|
|generator, ctx, src_i| {
|
||||||
|
// Calculate the offset of the active slice
|
||||||
|
let src_data_offset = ctx.builder.build_int_mul(
|
||||||
|
src_stride,
|
||||||
|
src_i,
|
||||||
|
"",
|
||||||
|
).unwrap();
|
||||||
|
let dst_i = ctx.builder
|
||||||
|
.build_load(dst_i_addr, "")
|
||||||
|
.map(BasicValueEnum::into_int_value)
|
||||||
|
.unwrap();
|
||||||
|
let dst_data_offset = ctx.builder.build_int_mul(
|
||||||
|
dst_stride,
|
||||||
|
dst_i,
|
||||||
|
"",
|
||||||
|
).unwrap();
|
||||||
|
|
||||||
|
let (src_ptr, dst_ptr) = unsafe {
|
||||||
|
(
|
||||||
|
ctx.builder.build_gep(src_slice_ptr, &[src_data_offset], "").unwrap(),
|
||||||
|
ctx.builder.build_gep(dst_slice_ptr, &[dst_data_offset], "").unwrap(),
|
||||||
|
)
|
||||||
|
};
|
||||||
|
|
||||||
|
ndarray_sliced_copyto_impl(
|
||||||
|
generator,
|
||||||
|
ctx,
|
||||||
|
elem_ty,
|
||||||
|
(dst_arr, dst_ptr),
|
||||||
|
(src_arr, src_ptr),
|
||||||
|
dim + 1,
|
||||||
|
&slices[1..],
|
||||||
|
)?;
|
||||||
|
|
||||||
|
let dst_i = ctx.builder
|
||||||
|
.build_load(dst_i_addr, "")
|
||||||
|
.map(BasicValueEnum::into_int_value)
|
||||||
|
.unwrap();
|
||||||
|
let dst_i_add1 = ctx.builder
|
||||||
|
.build_int_add(dst_i, llvm_usize.const_int(1, false), "")
|
||||||
|
.unwrap();
|
||||||
|
ctx.builder.build_store(dst_i_addr, dst_i_add1).unwrap();
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
},
|
||||||
|
)?;
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Copies a [`NDArrayValue`] using slices.
|
||||||
|
///
|
||||||
|
/// * `elem_ty` - The element type of the `NDArray`.
|
||||||
|
/// - `slices`: List of all slices, with the first element corresponding to the slice applicable to
|
||||||
|
/// this dimension. The `start`/`stop` values of each slice must be positive indices.
|
||||||
|
pub fn ndarray_sliced_copy<'ctx, G: CodeGenerator + ?Sized>(
|
||||||
|
generator: &mut G,
|
||||||
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
|
elem_ty: Type,
|
||||||
|
this: NDArrayValue<'ctx>,
|
||||||
|
slices: &[(IntValue<'ctx>, IntValue<'ctx>, IntValue<'ctx>)],
|
||||||
|
) -> Result<NDArrayValue<'ctx>, String> {
|
||||||
|
let llvm_i32 = ctx.ctx.i32_type();
|
||||||
|
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||||
|
|
||||||
|
let ndarray = if slices.is_empty() {
|
||||||
|
create_ndarray_dyn_shape(
|
||||||
|
generator,
|
||||||
|
ctx,
|
||||||
|
elem_ty,
|
||||||
|
&this,
|
||||||
|
|_, ctx, shape| {
|
||||||
|
Ok(shape.load_ndims(ctx))
|
||||||
|
},
|
||||||
|
|generator, ctx, shape, idx| {
|
||||||
|
unsafe { Ok(shape.dim_sizes().get_typed_unchecked(ctx, generator, &idx, None)) }
|
||||||
|
},
|
||||||
|
)?
|
||||||
|
} else {
|
||||||
|
let ndarray = create_ndarray_uninitialized(generator, ctx, elem_ty)?;
|
||||||
|
ndarray.store_ndims(ctx, generator, this.load_ndims(ctx));
|
||||||
|
|
||||||
|
let ndims = this.load_ndims(ctx);
|
||||||
|
ndarray.create_dim_sizes(ctx, llvm_usize, ndims);
|
||||||
|
|
||||||
|
// Populate the first slices.len() dimensions by computing the size of each dim slice
|
||||||
|
for (i, (start, stop, step)) in slices.iter().enumerate() {
|
||||||
|
// HACK: workaround calculate_len_for_slice_range requiring exclusive stop
|
||||||
|
let stop = ctx.builder
|
||||||
|
.build_select(
|
||||||
|
ctx.builder.build_int_compare(
|
||||||
|
IntPredicate::SLT,
|
||||||
|
*step,
|
||||||
|
llvm_i32.const_zero(),
|
||||||
|
"is_neg",
|
||||||
|
).unwrap(),
|
||||||
|
ctx.builder.build_int_sub(*stop, llvm_i32.const_int(1, true), "e_min_one").unwrap(),
|
||||||
|
ctx.builder.build_int_add(*stop, llvm_i32.const_int(1, true), "e_add_one").unwrap(),
|
||||||
|
"final_e",
|
||||||
|
)
|
||||||
|
.map(BasicValueEnum::into_int_value)
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
let slice_len = calculate_len_for_slice_range(generator, ctx, *start, stop, *step);
|
||||||
|
let slice_len = ctx.builder.build_int_z_extend_or_bit_cast(
|
||||||
|
slice_len,
|
||||||
|
llvm_usize,
|
||||||
|
""
|
||||||
|
).unwrap();
|
||||||
|
|
||||||
|
unsafe {
|
||||||
|
ndarray.dim_sizes()
|
||||||
|
.set_typed_unchecked(
|
||||||
|
ctx,
|
||||||
|
generator,
|
||||||
|
&llvm_usize.const_int(i as u64, false),
|
||||||
|
slice_len,
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Populate the rest by directly copying the dim size from the source array
|
||||||
|
gen_for_callback_incrementing(
|
||||||
|
generator,
|
||||||
|
ctx,
|
||||||
|
llvm_usize.const_int(slices.len() as u64, false),
|
||||||
|
(this.load_ndims(ctx), false),
|
||||||
|
|generator, ctx, idx| {
|
||||||
|
unsafe {
|
||||||
|
let dim_sz = this.dim_sizes().get_typed_unchecked(ctx, generator, &idx, None);
|
||||||
|
ndarray.dim_sizes().set_typed_unchecked(ctx, generator, &idx, dim_sz);
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
},
|
||||||
|
llvm_usize.const_int(1, false),
|
||||||
|
).unwrap();
|
||||||
|
|
||||||
|
ndarray_init_data(generator, ctx, elem_ty, ndarray)
|
||||||
|
};
|
||||||
|
|
||||||
|
ndarray_sliced_copyto_impl(
|
||||||
|
generator,
|
||||||
|
ctx,
|
||||||
|
elem_ty,
|
||||||
|
(ndarray, ndarray.data().base_ptr(ctx, generator)),
|
||||||
|
(this, this.data().base_ptr(ctx, generator)),
|
||||||
|
0,
|
||||||
|
slices,
|
||||||
|
)?;
|
||||||
|
|
||||||
|
Ok(ndarray)
|
||||||
|
}
|
||||||
|
|
||||||
/// LLVM-typed implementation for generating the implementation for `ndarray.copy`.
|
/// LLVM-typed implementation for generating the implementation for `ndarray.copy`.
|
||||||
///
|
///
|
||||||
/// * `elem_ty` - The element type of the `NDArray`.
|
/// * `elem_ty` - The element type of the `NDArray`.
|
||||||
@ -654,45 +890,7 @@ fn ndarray_copy_impl<'ctx, G: CodeGenerator + ?Sized>(
|
|||||||
elem_ty: Type,
|
elem_ty: Type,
|
||||||
this: NDArrayValue<'ctx>,
|
this: NDArrayValue<'ctx>,
|
||||||
) -> Result<NDArrayValue<'ctx>, String> {
|
) -> Result<NDArrayValue<'ctx>, String> {
|
||||||
let llvm_i1 = ctx.ctx.bool_type();
|
ndarray_sliced_copy(generator, ctx, elem_ty, this, &[])
|
||||||
|
|
||||||
let ndarray = create_ndarray_dyn_shape(
|
|
||||||
generator,
|
|
||||||
ctx,
|
|
||||||
elem_ty,
|
|
||||||
&this,
|
|
||||||
|_, ctx, shape| {
|
|
||||||
Ok(shape.load_ndims(ctx))
|
|
||||||
},
|
|
||||||
|generator, ctx, shape, idx| {
|
|
||||||
unsafe { Ok(shape.dim_sizes().get_typed_unchecked(ctx, generator, &idx, None)) }
|
|
||||||
},
|
|
||||||
)?;
|
|
||||||
|
|
||||||
let len = call_ndarray_calc_size(
|
|
||||||
generator,
|
|
||||||
ctx,
|
|
||||||
&ndarray.dim_sizes().as_slice_value(ctx, generator),
|
|
||||||
(None, None),
|
|
||||||
);
|
|
||||||
let sizeof_ty = ctx.get_llvm_type(generator, elem_ty);
|
|
||||||
let len_bytes = ctx.builder
|
|
||||||
.build_int_mul(
|
|
||||||
len,
|
|
||||||
sizeof_ty.size_of().unwrap(),
|
|
||||||
"",
|
|
||||||
)
|
|
||||||
.unwrap();
|
|
||||||
|
|
||||||
call_memcpy_generic(
|
|
||||||
ctx,
|
|
||||||
ndarray.data().base_ptr(ctx, generator),
|
|
||||||
this.data().base_ptr(ctx, generator),
|
|
||||||
len_bytes,
|
|
||||||
llvm_i1.const_zero(),
|
|
||||||
);
|
|
||||||
|
|
||||||
Ok(ndarray)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn ndarray_elementwise_unaryop_impl<'ctx, 'a, G, MapFn>(
|
pub fn ndarray_elementwise_unaryop_impl<'ctx, 'a, G, MapFn>(
|
||||||
|
Loading…
Reference in New Issue
Block a user