From 28e6f23034d8016860ba680aa3e5c3e2a9537fe7 Mon Sep 17 00:00:00 2001 From: lyken Date: Mon, 29 Jul 2024 13:40:10 +0800 Subject: [PATCH] core/ndstrides: rewrite and fix np_reshape() bug Data content should be copied and strides should be updated after negative indices are resolved. --- nac3core/src/codegen/numpy_new/view.rs | 44 +++++++------------------- 1 file changed, 12 insertions(+), 32 deletions(-) diff --git a/nac3core/src/codegen/numpy_new/view.rs b/nac3core/src/codegen/numpy_new/view.rs index df16431a..368adbfd 100644 --- a/nac3core/src/codegen/numpy_new/view.rs +++ b/nac3core/src/codegen/numpy_new/view.rs @@ -5,8 +5,9 @@ use crate::{ codegen::{ irrt::ndarray::{ basic::{ - call_nac3_ndarray_is_c_contiguous, call_nac3_ndarray_nbytes, - call_nac3_ndarray_set_strides_by_shape, call_nac3_ndarray_size, + call_nac3_ndarray_copy_data, call_nac3_ndarray_is_c_contiguous, + call_nac3_ndarray_nbytes, call_nac3_ndarray_set_strides_by_shape, + call_nac3_ndarray_size, }, reshape::call_nac3_ndarray_resolve_and_check_new_shape, }, @@ -27,51 +28,24 @@ fn gen_reshape_ndarray_or_copy<'ctx, G: CodeGenerator + ?Sized>( src_ndarray: Ptr<'ctx, StructModel>, new_shape: &ArrayWriter<'ctx, G, SizeT, IntModel>, ) -> Result>, String> { - /* - Reference pseudo-code: - ```c - NDArray* src_ndarray; - - NDArray* dst_ndarray = __builtin_alloca(...); - dst_ndarray->ndims = ... - dst_ndarray->strides = __builtin_alloca(...); - dst_ndarray->shape = ... // Directly set by user, may contain -1, or even illegal values. - dst_ndarray->itemsize = src_ndarray->itemsize; - set_strides_by_shape(dst_ndarray); - - // Do assertions on `dst_ndarray->shape` and resolve -1 - - resolve_and_check_new_shape(ndarray_size(src_ndarray), dst_ndarray->shape); - - if (is_c_contiguous(src_ndarray)) { - dst_ndarray->data = src_ndarray->data; - } else { - dst_ndarray->data = __builtin_alloca( ndarray_nbytes(dst_ndarray) ); - copy_data(src_ndarray, dst_ndarray); - } - - return dst_ndarray; - ``` - */ - let tyctx = generator.type_context(ctx.ctx); let byte_model = IntModel(Byte); let current_bb = ctx.builder.get_insert_block().unwrap(); - let then_bb = ctx.ctx.insert_basic_block_after(current_bb, "then"); + let then_bb = ctx.ctx.insert_basic_block_after(current_bb, "then_bb"); let else_bb = ctx.ctx.insert_basic_block_after(then_bb, "else_bb"); let end_bb = ctx.ctx.insert_basic_block_after(else_bb, "end_bb"); // Inserting into current_bb let dst_ndarray = alloca_ndarray(generator, ctx, new_shape.len, "ndarray").unwrap(); + // Set shape - directly from user input init_ndarray_shape(generator, ctx, dst_ndarray, new_shape)?; dst_ndarray .gep(ctx, |f| f.itemsize) .store(ctx, src_ndarray.gep(ctx, |f| f.itemsize).load(tyctx, ctx, "itemsize")); - call_nac3_ndarray_set_strides_by_shape(generator, ctx, dst_ndarray); - + // Resolve shape input from user let src_ndarray_size = call_nac3_ndarray_size(generator, ctx, src_ndarray); call_nac3_ndarray_resolve_and_check_new_shape( generator, @@ -81,6 +55,9 @@ fn gen_reshape_ndarray_or_copy<'ctx, G: CodeGenerator + ?Sized>( dst_ndarray.gep(ctx, |f| f.shape).load(tyctx, ctx, "shape"), ); + // Update strides + call_nac3_ndarray_set_strides_by_shape(generator, ctx, dst_ndarray); + let is_c_contiguous = call_nac3_ndarray_is_c_contiguous(generator, ctx, src_ndarray); ctx.builder.build_conditional_branch(is_c_contiguous.value, then_bb, else_bb).unwrap(); @@ -93,9 +70,12 @@ fn gen_reshape_ndarray_or_copy<'ctx, G: CodeGenerator + ?Sized>( // Inserting into else_bb: reshape is impossible without copying ctx.builder.position_at_end(else_bb); + // Allocate data let dst_ndarray_nbytes = call_nac3_ndarray_nbytes(generator, ctx, dst_ndarray); let data = byte_model.array_alloca(tyctx, ctx, dst_ndarray_nbytes.value, "new_data"); dst_ndarray.gep(ctx, |f| f.data).store(ctx, data); + // Copy content + call_nac3_ndarray_copy_data(generator, ctx, src_ndarray, dst_ndarray); ctx.builder.build_unconditional_branch(end_bb).unwrap(); // Reposition for continuation