forked from M-Labs/nac3
core/ndstrides: rewrite and fix np_reshape() bug
Data content should be copied and strides should be updated after negative indices are resolved.
This commit is contained in:
parent
dfb8bf9748
commit
28e6f23034
|
@ -5,8 +5,9 @@ use crate::{
|
|||
codegen::{
|
||||
irrt::ndarray::{
|
||||
basic::{
|
||||
call_nac3_ndarray_is_c_contiguous, call_nac3_ndarray_nbytes,
|
||||
call_nac3_ndarray_set_strides_by_shape, call_nac3_ndarray_size,
|
||||
call_nac3_ndarray_copy_data, call_nac3_ndarray_is_c_contiguous,
|
||||
call_nac3_ndarray_nbytes, call_nac3_ndarray_set_strides_by_shape,
|
||||
call_nac3_ndarray_size,
|
||||
},
|
||||
reshape::call_nac3_ndarray_resolve_and_check_new_shape,
|
||||
},
|
||||
|
@ -27,51 +28,24 @@ fn gen_reshape_ndarray_or_copy<'ctx, G: CodeGenerator + ?Sized>(
|
|||
src_ndarray: Ptr<'ctx, StructModel<NpArray>>,
|
||||
new_shape: &ArrayWriter<'ctx, G, SizeT, IntModel<SizeT>>,
|
||||
) -> Result<Ptr<'ctx, StructModel<NpArray>>, String> {
|
||||
/*
|
||||
Reference pseudo-code:
|
||||
```c
|
||||
NDArray<SizeT>* src_ndarray;
|
||||
|
||||
NDArray<SizeT>* dst_ndarray = __builtin_alloca(...);
|
||||
dst_ndarray->ndims = ...
|
||||
dst_ndarray->strides = __builtin_alloca(...);
|
||||
dst_ndarray->shape = ... // Directly set by user, may contain -1, or even illegal values.
|
||||
dst_ndarray->itemsize = src_ndarray->itemsize;
|
||||
set_strides_by_shape(dst_ndarray);
|
||||
|
||||
// Do assertions on `dst_ndarray->shape` and resolve -1
|
||||
|
||||
resolve_and_check_new_shape(ndarray_size(src_ndarray), dst_ndarray->shape);
|
||||
|
||||
if (is_c_contiguous(src_ndarray)) {
|
||||
dst_ndarray->data = src_ndarray->data;
|
||||
} else {
|
||||
dst_ndarray->data = __builtin_alloca( ndarray_nbytes(dst_ndarray) );
|
||||
copy_data(src_ndarray, dst_ndarray);
|
||||
}
|
||||
|
||||
return dst_ndarray;
|
||||
```
|
||||
*/
|
||||
|
||||
let tyctx = generator.type_context(ctx.ctx);
|
||||
let byte_model = IntModel(Byte);
|
||||
|
||||
let current_bb = ctx.builder.get_insert_block().unwrap();
|
||||
let then_bb = ctx.ctx.insert_basic_block_after(current_bb, "then");
|
||||
let then_bb = ctx.ctx.insert_basic_block_after(current_bb, "then_bb");
|
||||
let else_bb = ctx.ctx.insert_basic_block_after(then_bb, "else_bb");
|
||||
let end_bb = ctx.ctx.insert_basic_block_after(else_bb, "end_bb");
|
||||
|
||||
// Inserting into current_bb
|
||||
let dst_ndarray = alloca_ndarray(generator, ctx, new_shape.len, "ndarray").unwrap();
|
||||
|
||||
// Set shape - directly from user input
|
||||
init_ndarray_shape(generator, ctx, dst_ndarray, new_shape)?;
|
||||
dst_ndarray
|
||||
.gep(ctx, |f| f.itemsize)
|
||||
.store(ctx, src_ndarray.gep(ctx, |f| f.itemsize).load(tyctx, ctx, "itemsize"));
|
||||
|
||||
call_nac3_ndarray_set_strides_by_shape(generator, ctx, dst_ndarray);
|
||||
|
||||
// Resolve shape input from user
|
||||
let src_ndarray_size = call_nac3_ndarray_size(generator, ctx, src_ndarray);
|
||||
call_nac3_ndarray_resolve_and_check_new_shape(
|
||||
generator,
|
||||
|
@ -81,6 +55,9 @@ fn gen_reshape_ndarray_or_copy<'ctx, G: CodeGenerator + ?Sized>(
|
|||
dst_ndarray.gep(ctx, |f| f.shape).load(tyctx, ctx, "shape"),
|
||||
);
|
||||
|
||||
// Update strides
|
||||
call_nac3_ndarray_set_strides_by_shape(generator, ctx, dst_ndarray);
|
||||
|
||||
let is_c_contiguous = call_nac3_ndarray_is_c_contiguous(generator, ctx, src_ndarray);
|
||||
ctx.builder.build_conditional_branch(is_c_contiguous.value, then_bb, else_bb).unwrap();
|
||||
|
||||
|
@ -93,9 +70,12 @@ fn gen_reshape_ndarray_or_copy<'ctx, G: CodeGenerator + ?Sized>(
|
|||
|
||||
// Inserting into else_bb: reshape is impossible without copying
|
||||
ctx.builder.position_at_end(else_bb);
|
||||
// Allocate data
|
||||
let dst_ndarray_nbytes = call_nac3_ndarray_nbytes(generator, ctx, dst_ndarray);
|
||||
let data = byte_model.array_alloca(tyctx, ctx, dst_ndarray_nbytes.value, "new_data");
|
||||
dst_ndarray.gep(ctx, |f| f.data).store(ctx, data);
|
||||
// Copy content
|
||||
call_nac3_ndarray_copy_data(generator, ctx, src_ndarray, dst_ndarray);
|
||||
ctx.builder.build_unconditional_branch(end_bb).unwrap();
|
||||
|
||||
// Reposition for continuation
|
||||
|
|
Loading…
Reference in New Issue