1
0
forked from M-Labs/nac3

core/ndstrides: rewrite and fix np_reshape() bug

Data content should be copied and strides should be updated after
negative indices are resolved.
This commit is contained in:
lyken 2024-07-29 13:40:10 +08:00
parent dfb8bf9748
commit 28e6f23034

View File

@ -5,8 +5,9 @@ use crate::{
codegen::{
irrt::ndarray::{
basic::{
call_nac3_ndarray_is_c_contiguous, call_nac3_ndarray_nbytes,
call_nac3_ndarray_set_strides_by_shape, call_nac3_ndarray_size,
call_nac3_ndarray_copy_data, call_nac3_ndarray_is_c_contiguous,
call_nac3_ndarray_nbytes, call_nac3_ndarray_set_strides_by_shape,
call_nac3_ndarray_size,
},
reshape::call_nac3_ndarray_resolve_and_check_new_shape,
},
@ -27,51 +28,24 @@ fn gen_reshape_ndarray_or_copy<'ctx, G: CodeGenerator + ?Sized>(
src_ndarray: Ptr<'ctx, StructModel<NpArray>>,
new_shape: &ArrayWriter<'ctx, G, SizeT, IntModel<SizeT>>,
) -> Result<Ptr<'ctx, StructModel<NpArray>>, String> {
/*
Reference pseudo-code:
```c
NDArray<SizeT>* src_ndarray;
NDArray<SizeT>* dst_ndarray = __builtin_alloca(...);
dst_ndarray->ndims = ...
dst_ndarray->strides = __builtin_alloca(...);
dst_ndarray->shape = ... // Directly set by user, may contain -1, or even illegal values.
dst_ndarray->itemsize = src_ndarray->itemsize;
set_strides_by_shape(dst_ndarray);
// Do assertions on `dst_ndarray->shape` and resolve -1
resolve_and_check_new_shape(ndarray_size(src_ndarray), dst_ndarray->shape);
if (is_c_contiguous(src_ndarray)) {
dst_ndarray->data = src_ndarray->data;
} else {
dst_ndarray->data = __builtin_alloca( ndarray_nbytes(dst_ndarray) );
copy_data(src_ndarray, dst_ndarray);
}
return dst_ndarray;
```
*/
let tyctx = generator.type_context(ctx.ctx);
let byte_model = IntModel(Byte);
let current_bb = ctx.builder.get_insert_block().unwrap();
let then_bb = ctx.ctx.insert_basic_block_after(current_bb, "then");
let then_bb = ctx.ctx.insert_basic_block_after(current_bb, "then_bb");
let else_bb = ctx.ctx.insert_basic_block_after(then_bb, "else_bb");
let end_bb = ctx.ctx.insert_basic_block_after(else_bb, "end_bb");
// Inserting into current_bb
let dst_ndarray = alloca_ndarray(generator, ctx, new_shape.len, "ndarray").unwrap();
// Set shape - directly from user input
init_ndarray_shape(generator, ctx, dst_ndarray, new_shape)?;
dst_ndarray
.gep(ctx, |f| f.itemsize)
.store(ctx, src_ndarray.gep(ctx, |f| f.itemsize).load(tyctx, ctx, "itemsize"));
call_nac3_ndarray_set_strides_by_shape(generator, ctx, dst_ndarray);
// Resolve shape input from user
let src_ndarray_size = call_nac3_ndarray_size(generator, ctx, src_ndarray);
call_nac3_ndarray_resolve_and_check_new_shape(
generator,
@ -81,6 +55,9 @@ fn gen_reshape_ndarray_or_copy<'ctx, G: CodeGenerator + ?Sized>(
dst_ndarray.gep(ctx, |f| f.shape).load(tyctx, ctx, "shape"),
);
// Update strides
call_nac3_ndarray_set_strides_by_shape(generator, ctx, dst_ndarray);
let is_c_contiguous = call_nac3_ndarray_is_c_contiguous(generator, ctx, src_ndarray);
ctx.builder.build_conditional_branch(is_c_contiguous.value, then_bb, else_bb).unwrap();
@ -93,9 +70,12 @@ fn gen_reshape_ndarray_or_copy<'ctx, G: CodeGenerator + ?Sized>(
// Inserting into else_bb: reshape is impossible without copying
ctx.builder.position_at_end(else_bb);
// Allocate data
let dst_ndarray_nbytes = call_nac3_ndarray_nbytes(generator, ctx, dst_ndarray);
let data = byte_model.array_alloca(tyctx, ctx, dst_ndarray_nbytes.value, "new_data");
dst_ndarray.gep(ctx, |f| f.data).store(ctx, data);
// Copy content
call_nac3_ndarray_copy_data(generator, ctx, src_ndarray, dst_ndarray);
ctx.builder.build_unconditional_branch(end_bb).unwrap();
// Reposition for continuation