core/ndstrides: fix and rewrite is_c_contiguous

This commit is contained in:
lyken 2024-07-29 13:35:43 +08:00
parent d5880b119a
commit dfb8bf9748
2 changed files with 5 additions and 6 deletions

View File

@ -223,10 +223,9 @@ bool is_c_contiguous(const NDArray<SizeT>* ndarray) {
return false;
}
for (SizeT i = 1; i < ndarray->ndims; i++) {
SizeT axis_i = ndarray->ndims - i - 1;
if (ndarray->strides[axis_i] !=
ndarray->shape[axis_i + 1] + ndarray->strides[axis_i + 1]) {
for (SizeT i = 0; i < ndarray->ndims - 1; i++) {
if (ndarray->strides[i] !=
ndarray->shape[i + 1] + ndarray->strides[i + 1]) {
return false;
}
}

View File

@ -121,7 +121,7 @@ pub fn call_nac3_ndarray_copy_data<'ctx, G: CodeGenerator + ?Sized>(
ctx: &mut CodeGenContext<'ctx, '_>,
src_ndarray: Ptr<'ctx, StructModel<NpArray>>,
dst_ndarray: Ptr<'ctx, StructModel<NpArray>>,
) -> Int<'ctx, Bool> {
) {
let tyctx = generator.type_context(ctx.ctx);
CallFunction::begin(
@ -131,7 +131,7 @@ pub fn call_nac3_ndarray_copy_data<'ctx, G: CodeGenerator + ?Sized>(
)
.arg("src_ndarray", src_ndarray)
.arg("dst_ndarray", dst_ndarray)
.returning("is_c_contiguous")
.returning_void();
}
pub fn call_nac3_ndarray_get_nth_pelement<'ctx, G: CodeGenerator + ?Sized>(