diff --git a/nac3core/src/codegen/numpy.rs b/nac3core/src/codegen/numpy.rs index 1036d0451..68803f13d 100644 --- a/nac3core/src/codegen/numpy.rs +++ b/nac3core/src/codegen/numpy.rs @@ -9,6 +9,7 @@ use crate::{ NDArrayValue, TypedArrayLikeAccessor, TypedArrayLikeAdapter, + TypedArrayLikeMutator, UntypedArrayLikeAccessor, UntypedArrayLikeMutator, }, @@ -16,6 +17,7 @@ use crate::{ CodeGenerator, expr::gen_binop_expr_with_values, irrt::{ + calculate_len_for_slice_range, call_ndarray_calc_broadcast, call_ndarray_calc_broadcast_index, call_ndarray_calc_nd_indices, @@ -23,7 +25,7 @@ use crate::{ }, llvm_intrinsics, llvm_intrinsics::{call_memcpy_generic}, - stmt::{gen_for_callback_incrementing, gen_if_else_expr_callback}, + stmt::{gen_for_callback_incrementing, gen_for_range_callback, gen_if_else_expr_callback}, }, symbol_resolver::ValueEnum, toplevel::{ @@ -645,6 +647,240 @@ fn call_ndarray_eye_impl<'ctx, G: CodeGenerator + ?Sized>( Ok(ndarray) } +/// Copies a slice of an [`NDArrayValue`] to another. +/// +/// - `dst_arr`: The [`NDArrayValue`] instance of the destination array. The `ndims` and `dim_sz` +/// fields should be populated before calling this function. +/// - `dst_slice_ptr`: The [`PointerValue`] to the first element of the currently processing +/// dimensional slice in the destination array. +/// - `src_arr`: The [`NDArrayValue`] instance of the source array. +/// - `src_slice_ptr`: The [`PointerValue`] to the first element of the currently processing +/// dimensional slice in the source array. +/// - `dim`: The index of the currently processing dimension. +/// - `slices`: List of all slices, with the first element corresponding to the slice applicable to +/// this dimension. The `start`/`stop` values of each slice must be non-negative indices. +fn ndarray_sliced_copyto_impl<'ctx, G: CodeGenerator + ?Sized>( + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + elem_ty: Type, + (dst_arr, dst_slice_ptr): (NDArrayValue<'ctx>, PointerValue<'ctx>), + (src_arr, src_slice_ptr): (NDArrayValue<'ctx>, PointerValue<'ctx>), + dim: u64, + slices: &[(IntValue<'ctx>, IntValue<'ctx>, IntValue<'ctx>)], +) -> Result<(), String> { + let llvm_i1 = ctx.ctx.bool_type(); + let llvm_usize = generator.get_size_type(ctx.ctx); + + // If there are no (remaining) slice expressions, memcpy the entire dimension + if slices.is_empty() { + let stride = call_ndarray_calc_size( + generator, + ctx, + &src_arr.dim_sizes(), + (Some(llvm_usize.const_int(dim, false)), None), + ); + let sizeof_elem = ctx.get_llvm_type(generator, elem_ty).size_of().unwrap(); + let cpy_len = ctx.builder.build_int_mul( + stride, + sizeof_elem, + "" + ).unwrap(); + + call_memcpy_generic( + ctx, + dst_slice_ptr, + src_slice_ptr, + cpy_len, + llvm_i1.const_zero(), + ); + + return Ok(()) + } + + // The stride of elements in this dimension, i.e. the number of elements between arr[i] and + // arr[i + 1] in this dimension + let src_stride = call_ndarray_calc_size( + generator, + ctx, + &src_arr.dim_sizes(), + (Some(llvm_usize.const_int(dim + 1, false)), None), + ); + let dst_stride = call_ndarray_calc_size( + generator, + ctx, + &dst_arr.dim_sizes(), + (Some(llvm_usize.const_int(dim + 1, false)), None), + ); + + let (start, stop, step) = slices[0]; + let start = ctx.builder.build_int_s_extend_or_bit_cast(start, llvm_usize, "").unwrap(); + let stop = ctx.builder.build_int_s_extend_or_bit_cast(stop, llvm_usize, "").unwrap(); + let step = ctx.builder.build_int_s_extend_or_bit_cast(step, llvm_usize, "").unwrap(); + + let dst_i_addr = generator.gen_var_alloc(ctx, start.get_type().into(), None).unwrap(); + ctx.builder.build_store(dst_i_addr, start.get_type().const_zero()).unwrap(); + + gen_for_range_callback( + generator, + ctx, + false, + |_, _| Ok(start), + (|_, _| Ok(stop), true), + |_, _| Ok(step), + |generator, ctx, src_i| { + // Calculate the offset of the active slice + let src_data_offset = ctx.builder.build_int_mul( + src_stride, + src_i, + "", + ).unwrap(); + let dst_i = ctx.builder + .build_load(dst_i_addr, "") + .map(BasicValueEnum::into_int_value) + .unwrap(); + let dst_data_offset = ctx.builder.build_int_mul( + dst_stride, + dst_i, + "", + ).unwrap(); + + let (src_ptr, dst_ptr) = unsafe { + ( + ctx.builder.build_gep(src_slice_ptr, &[src_data_offset], "").unwrap(), + ctx.builder.build_gep(dst_slice_ptr, &[dst_data_offset], "").unwrap(), + ) + }; + + ndarray_sliced_copyto_impl( + generator, + ctx, + elem_ty, + (dst_arr, dst_ptr), + (src_arr, src_ptr), + dim + 1, + &slices[1..], + )?; + + let dst_i = ctx.builder + .build_load(dst_i_addr, "") + .map(BasicValueEnum::into_int_value) + .unwrap(); + let dst_i_add1 = ctx.builder + .build_int_add(dst_i, llvm_usize.const_int(1, false), "") + .unwrap(); + ctx.builder.build_store(dst_i_addr, dst_i_add1).unwrap(); + + Ok(()) + }, + )?; + + Ok(()) +} + +/// Copies a [`NDArrayValue`] using slices. +/// +/// * `elem_ty` - The element type of the `NDArray`. +/// - `slices`: List of all slices, with the first element corresponding to the slice applicable to +/// this dimension. The `start`/`stop` values of each slice must be positive indices. +pub fn ndarray_sliced_copy<'ctx, G: CodeGenerator + ?Sized>( + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + elem_ty: Type, + this: NDArrayValue<'ctx>, + slices: &[(IntValue<'ctx>, IntValue<'ctx>, IntValue<'ctx>)], +) -> Result, String> { + let llvm_i32 = ctx.ctx.i32_type(); + let llvm_usize = generator.get_size_type(ctx.ctx); + + let ndarray = if slices.is_empty() { + create_ndarray_dyn_shape( + generator, + ctx, + elem_ty, + &this, + |_, ctx, shape| { + Ok(shape.load_ndims(ctx)) + }, + |generator, ctx, shape, idx| { + unsafe { Ok(shape.dim_sizes().get_typed_unchecked(ctx, generator, &idx, None)) } + }, + )? + } else { + let ndarray = create_ndarray_uninitialized(generator, ctx, elem_ty)?; + ndarray.store_ndims(ctx, generator, this.load_ndims(ctx)); + + let ndims = this.load_ndims(ctx); + ndarray.create_dim_sizes(ctx, llvm_usize, ndims); + + // Populate the first slices.len() dimensions by computing the size of each dim slice + for (i, (start, stop, step)) in slices.iter().enumerate() { + // HACK: workaround calculate_len_for_slice_range requiring exclusive stop + let stop = ctx.builder + .build_select( + ctx.builder.build_int_compare( + IntPredicate::SLT, + *step, + llvm_i32.const_zero(), + "is_neg", + ).unwrap(), + ctx.builder.build_int_sub(*stop, llvm_i32.const_int(1, true), "e_min_one").unwrap(), + ctx.builder.build_int_add(*stop, llvm_i32.const_int(1, true), "e_add_one").unwrap(), + "final_e", + ) + .map(BasicValueEnum::into_int_value) + .unwrap(); + + let slice_len = calculate_len_for_slice_range(generator, ctx, *start, stop, *step); + let slice_len = ctx.builder.build_int_z_extend_or_bit_cast( + slice_len, + llvm_usize, + "" + ).unwrap(); + + unsafe { + ndarray.dim_sizes() + .set_typed_unchecked( + ctx, + generator, + &llvm_usize.const_int(i as u64, false), + slice_len, + ); + } + } + + // Populate the rest by directly copying the dim size from the source array + gen_for_callback_incrementing( + generator, + ctx, + llvm_usize.const_int(slices.len() as u64, false), + (this.load_ndims(ctx), false), + |generator, ctx, idx| { + unsafe { + let dim_sz = this.dim_sizes().get_typed_unchecked(ctx, generator, &idx, None); + ndarray.dim_sizes().set_typed_unchecked(ctx, generator, &idx, dim_sz); + } + + Ok(()) + }, + llvm_usize.const_int(1, false), + ).unwrap(); + + ndarray_init_data(generator, ctx, elem_ty, ndarray) + }; + + ndarray_sliced_copyto_impl( + generator, + ctx, + elem_ty, + (ndarray, ndarray.data().base_ptr(ctx, generator)), + (this, this.data().base_ptr(ctx, generator)), + 0, + slices, + )?; + + Ok(ndarray) +} + /// LLVM-typed implementation for generating the implementation for `ndarray.copy`. /// /// * `elem_ty` - The element type of the `NDArray`. @@ -654,45 +890,7 @@ fn ndarray_copy_impl<'ctx, G: CodeGenerator + ?Sized>( elem_ty: Type, this: NDArrayValue<'ctx>, ) -> Result, String> { - let llvm_i1 = ctx.ctx.bool_type(); - - let ndarray = create_ndarray_dyn_shape( - generator, - ctx, - elem_ty, - &this, - |_, ctx, shape| { - Ok(shape.load_ndims(ctx)) - }, - |generator, ctx, shape, idx| { - unsafe { Ok(shape.dim_sizes().get_typed_unchecked(ctx, generator, &idx, None)) } - }, - )?; - - let len = call_ndarray_calc_size( - generator, - ctx, - &ndarray.dim_sizes().as_slice_value(ctx, generator), - (None, None), - ); - let sizeof_ty = ctx.get_llvm_type(generator, elem_ty); - let len_bytes = ctx.builder - .build_int_mul( - len, - sizeof_ty.size_of().unwrap(), - "", - ) - .unwrap(); - - call_memcpy_generic( - ctx, - ndarray.data().base_ptr(ctx, generator), - this.data().base_ptr(ctx, generator), - len_bytes, - llvm_i1.const_zero(), - ); - - Ok(ndarray) + ndarray_sliced_copy(generator, ctx, elem_ty, this, &[]) } pub fn ndarray_elementwise_unaryop_impl<'ctx, 'a, G, MapFn>(