From 38add0b00df2bfba89b8e753e7cb16fcf4ae93c3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?S=C3=A9bastien=20Crozet?= Date: Thu, 17 Jun 2021 09:46:49 +0200 Subject: [PATCH] Fix potential undoundness with Storage::as_slice and Storage::as_mut_slice (#905) --- src/base/array_storage.rs | 10 +++--- src/base/blas.rs | 16 +++++---- src/base/default_allocator.rs | 2 +- src/base/edition.rs | 2 +- src/base/matrix_slice.rs | 14 ++++---- src/base/ops.rs | 61 ++++++++++++++++------------------- src/base/storage.rs | 31 +++++++++++++++--- src/base/vec_storage.rs | 12 +++---- src/linalg/inverse.rs | 2 +- 9 files changed, 85 insertions(+), 65 deletions(-) diff --git a/src/base/array_storage.rs b/src/base/array_storage.rs index 6d681ed5..00713fb4 100644 --- a/src/base/array_storage.rs +++ b/src/base/array_storage.rs @@ -79,7 +79,7 @@ where } #[inline] - fn is_contiguous(&self) -> bool { + unsafe fn is_contiguous(&self) -> bool { true } @@ -101,8 +101,8 @@ where } #[inline] - fn as_slice(&self) -> &[T] { - unsafe { std::slice::from_raw_parts(self.ptr(), R * C) } + unsafe fn as_slice_unchecked(&self) -> &[T] { + std::slice::from_raw_parts(self.ptr(), R * C) } } @@ -118,8 +118,8 @@ where } #[inline] - fn as_mut_slice(&mut self) -> &mut [T] { - unsafe { std::slice::from_raw_parts_mut(self.ptr_mut(), R * C) } + unsafe fn as_mut_slice_unchecked(&mut self) -> &mut [T] { + std::slice::from_raw_parts_mut(self.ptr_mut(), R * C) } } diff --git a/src/base/blas.rs b/src/base/blas.rs index 9617e46a..35133bb6 100644 --- a/src/base/blas.rs +++ b/src/base/blas.rs @@ -344,13 +344,17 @@ where let rstride1 = self.strides().0; let rstride2 = x.strides().0; - let y = self.data.as_mut_slice(); - let x = x.data.as_slice(); + unsafe { + // SAFETY: the conversion to slices is OK because we access the + // elements taking the strides into account. + let y = self.data.as_mut_slice_unchecked(); + let x = x.data.as_slice_unchecked(); - if !b.is_zero() { - array_axcpy(y, a, x, c, b, rstride1, rstride2, x.len()); - } else { - array_axc(y, a, x, c, rstride1, rstride2, x.len()); + if !b.is_zero() { + array_axcpy(y, a, x, c, b, rstride1, rstride2, x.len()); + } else { + array_axc(y, a, x, c, rstride1, rstride2, x.len()); + } } } diff --git a/src/base/default_allocator.rs b/src/base/default_allocator.rs index a9f994fe..4bfa11a8 100644 --- a/src/base/default_allocator.rs +++ b/src/base/default_allocator.rs @@ -16,7 +16,7 @@ use crate::base::array_storage::ArrayStorage; #[cfg(any(feature = "alloc", feature = "std"))] use crate::base::dimension::Dynamic; use crate::base::dimension::{Dim, DimName}; -use crate::base::storage::{Storage, StorageMut}; +use crate::base::storage::{ContiguousStorageMut, Storage, StorageMut}; #[cfg(any(feature = "std", feature = "alloc"))] use crate::base::vec_storage::VecStorage; use crate::base::Scalar; diff --git a/src/base/edition.rs b/src/base/edition.rs index 82c537fc..e7ef00d2 100644 --- a/src/base/edition.rs +++ b/src/base/edition.rs @@ -11,7 +11,7 @@ use crate::base::constraint::{DimEq, SameNumberOfColumns, SameNumberOfRows, Shap #[cfg(any(feature = "std", feature = "alloc"))] use crate::base::dimension::Dynamic; use crate::base::dimension::{Const, Dim, DimAdd, DimDiff, DimMin, DimMinimum, DimSub, DimSum, U1}; -use crate::base::storage::{ReshapableStorage, Storage, StorageMut}; +use crate::base::storage::{ContiguousStorageMut, ReshapableStorage, Storage, StorageMut}; use crate::base::{DefaultAllocator, Matrix, OMatrix, RowVector, Scalar, Vector}; /// # Rows and columns extraction diff --git a/src/base/matrix_slice.rs b/src/base/matrix_slice.rs index 129493d6..6cb713f1 100644 --- a/src/base/matrix_slice.rs +++ b/src/base/matrix_slice.rs @@ -132,7 +132,7 @@ macro_rules! storage_impl( } #[inline] - fn is_contiguous(&self) -> bool { + unsafe fn is_contiguous(&self) -> bool { // Common cases that can be deduced at compile-time even if one of the dimensions // is Dynamic. if (RStride::is::() && C::is::()) || // Column vector. @@ -162,14 +162,14 @@ macro_rules! storage_impl( } #[inline] - fn as_slice(&self) -> &[T] { + unsafe fn as_slice_unchecked(&self) -> &[T] { let (nrows, ncols) = self.shape(); if nrows.value() != 0 && ncols.value() != 0 { let sz = self.linear_index(nrows.value() - 1, ncols.value() - 1); - unsafe { slice::from_raw_parts(self.ptr, sz + 1) } + slice::from_raw_parts(self.ptr, sz + 1) } else { - unsafe { slice::from_raw_parts(self.ptr, 0) } + slice::from_raw_parts(self.ptr, 0) } } } @@ -187,13 +187,13 @@ unsafe impl<'a, T: Scalar, R: Dim, C: Dim, RStride: Dim, CStride: Dim> StorageMu } #[inline] - fn as_mut_slice(&mut self) -> &mut [T] { + unsafe fn as_mut_slice_unchecked(&mut self) -> &mut [T] { let (nrows, ncols) = self.shape(); if nrows.value() != 0 && ncols.value() != 0 { let sz = self.linear_index(nrows.value() - 1, ncols.value() - 1); - unsafe { slice::from_raw_parts_mut(self.ptr, sz + 1) } + slice::from_raw_parts_mut(self.ptr, sz + 1) } else { - unsafe { slice::from_raw_parts_mut(self.ptr, 0) } + slice::from_raw_parts_mut(self.ptr, 0) } } } diff --git a/src/base/ops.rs b/src/base/ops.rs index d9be71a7..f0780f7f 100644 --- a/src/base/ops.rs +++ b/src/base/ops.rs @@ -158,20 +158,17 @@ macro_rules! componentwise_binop_impl( // This is the most common case and should be deduced at compile-time. // TODO: use specialization instead? - if self.data.is_contiguous() && rhs.data.is_contiguous() && out.data.is_contiguous() { - let arr1 = self.data.as_slice(); - let arr2 = rhs.data.as_slice(); - let out = out.data.as_mut_slice(); - for i in 0 .. arr1.len() { - unsafe { + unsafe { + if self.data.is_contiguous() && rhs.data.is_contiguous() && out.data.is_contiguous() { + let arr1 = self.data.as_slice_unchecked(); + let arr2 = rhs.data.as_slice_unchecked(); + let out = out.data.as_mut_slice_unchecked(); + for i in 0 .. arr1.len() { *out.get_unchecked_mut(i) = arr1.get_unchecked(i).inlined_clone().$method(arr2.get_unchecked(i).inlined_clone()); } - } - } - else { - for j in 0 .. self.ncols() { - for i in 0 .. self.nrows() { - unsafe { + } else { + for j in 0 .. self.ncols() { + for i in 0 .. self.nrows() { let val = self.get_unchecked((i, j)).inlined_clone().$method(rhs.get_unchecked((i, j)).inlined_clone()); *out.get_unchecked_mut((i, j)) = val; } @@ -191,19 +188,17 @@ macro_rules! componentwise_binop_impl( // This is the most common case and should be deduced at compile-time. // TODO: use specialization instead? - if self.data.is_contiguous() && rhs.data.is_contiguous() { - let arr1 = self.data.as_mut_slice(); - let arr2 = rhs.data.as_slice(); - for i in 0 .. arr2.len() { - unsafe { + unsafe { + if self.data.is_contiguous() && rhs.data.is_contiguous() { + let arr1 = self.data.as_mut_slice_unchecked(); + let arr2 = rhs.data.as_slice_unchecked(); + + for i in 0 .. arr2.len() { arr1.get_unchecked_mut(i).$method_assign(arr2.get_unchecked(i).inlined_clone()); } - } - } - else { - for j in 0 .. rhs.ncols() { - for i in 0 .. rhs.nrows() { - unsafe { + } else { + for j in 0 .. rhs.ncols() { + for i in 0 .. rhs.nrows() { self.get_unchecked_mut((i, j)).$method_assign(rhs.get_unchecked((i, j)).inlined_clone()) } } @@ -221,20 +216,18 @@ macro_rules! componentwise_binop_impl( // This is the most common case and should be deduced at compile-time. // TODO: use specialization instead? - if self.data.is_contiguous() && rhs.data.is_contiguous() { - let arr1 = self.data.as_slice(); - let arr2 = rhs.data.as_mut_slice(); - for i in 0 .. arr1.len() { - unsafe { + unsafe { + if self.data.is_contiguous() && rhs.data.is_contiguous() { + let arr1 = self.data.as_slice_unchecked(); + let arr2 = rhs.data.as_mut_slice_unchecked(); + + for i in 0 .. arr1.len() { let res = arr1.get_unchecked(i).inlined_clone().$method(arr2.get_unchecked(i).inlined_clone()); *arr2.get_unchecked_mut(i) = res; } - } - } - else { - for j in 0 .. self.ncols() { - for i in 0 .. self.nrows() { - unsafe { + } else { + for j in 0 .. self.ncols() { + for i in 0 .. self.nrows() { let r = rhs.get_unchecked_mut((i, j)); *r = self.get_unchecked((i, j)).inlined_clone().$method(r.inlined_clone()) } diff --git a/src/base/storage.rs b/src/base/storage.rs index 0238b36c..1e8d5dfb 100644 --- a/src/base/storage.rs +++ b/src/base/storage.rs @@ -94,12 +94,19 @@ pub unsafe trait Storage: Debug + Sized { } /// Indicates whether this data buffer stores its elements contiguously. - fn is_contiguous(&self) -> bool; + /// + /// This method is unsafe because unsafe code relies on this properties to performe + /// some low-lever optimizations. + unsafe fn is_contiguous(&self) -> bool; /// Retrieves the data buffer as a contiguous slice. /// /// The matrix components may not be stored in a contiguous way, depending on the strides. - fn as_slice(&self) -> &[T]; + /// This method is unsafe because this can yield to invalid aliasing when called on some pairs + /// of matrix slices originating from the same matrix with strides. + /// + /// Call the safe alternative `matrix.as_slice()` instead. + unsafe fn as_slice_unchecked(&self) -> &[T]; /// Builds a matrix data storage that does not contain any reference. fn into_owned(self) -> Owned @@ -165,8 +172,12 @@ pub unsafe trait StorageMut: Storage { /// Retrieves the mutable data buffer as a contiguous slice. /// - /// Matrix components may not be contiguous, depending on its strides. - fn as_mut_slice(&mut self) -> &mut [T]; + /// Matrix components may not be contiguous, depending on its strides. + /// + /// The matrix components may not be stored in a contiguous way, depending on the strides. + /// This method is unsafe because this can yield to invalid aliasing when called on some pairs + /// of matrix slices originating from the same matrix with strides. + unsafe fn as_mut_slice_unchecked(&mut self) -> &mut [T]; } /// A matrix storage that is stored contiguously in memory. @@ -177,6 +188,12 @@ pub unsafe trait StorageMut: Storage { pub unsafe trait ContiguousStorage: Storage { + /// Converts this data storage to a contiguous slice. + fn as_slice(&self) -> &[T] { + // SAFETY: this is safe because this trait guarantees the fact + // that the data is stored contiguously. + unsafe { self.as_slice_unchecked() } + } } /// A mutable matrix storage that is stored contiguously in memory. @@ -187,6 +204,12 @@ pub unsafe trait ContiguousStorage: pub unsafe trait ContiguousStorageMut: ContiguousStorage + StorageMut { + /// Converts this data storage to a contiguous mutable slice. + fn as_mut_slice(&mut self) -> &mut [T] { + // SAFETY: this is safe because this trait guarantees the fact + // that the data is stored contiguously. + unsafe { self.as_mut_slice_unchecked() } + } } /// A matrix storage that can be reshaped in-place. diff --git a/src/base/vec_storage.rs b/src/base/vec_storage.rs index 14490674..cc5f0ab3 100644 --- a/src/base/vec_storage.rs +++ b/src/base/vec_storage.rs @@ -178,7 +178,7 @@ where } #[inline] - fn is_contiguous(&self) -> bool { + unsafe fn is_contiguous(&self) -> bool { true } @@ -199,7 +199,7 @@ where } #[inline] - fn as_slice(&self) -> &[T] { + unsafe fn as_slice_unchecked(&self) -> &[T] { &self.data } } @@ -227,7 +227,7 @@ where } #[inline] - fn is_contiguous(&self) -> bool { + unsafe fn is_contiguous(&self) -> bool { true } @@ -248,7 +248,7 @@ where } #[inline] - fn as_slice(&self) -> &[T] { + unsafe fn as_slice_unchecked(&self) -> &[T] { &self.data } } @@ -268,7 +268,7 @@ where } #[inline] - fn as_mut_slice(&mut self) -> &mut [T] { + unsafe fn as_mut_slice_unchecked(&mut self) -> &mut [T] { &mut self.data[..] } } @@ -329,7 +329,7 @@ where } #[inline] - fn as_mut_slice(&mut self) -> &mut [T] { + unsafe fn as_mut_slice_unchecked(&mut self) -> &mut [T] { &mut self.data[..] } } diff --git a/src/linalg/inverse.rs b/src/linalg/inverse.rs index f56a95ec..28b148a1 100644 --- a/src/linalg/inverse.rs +++ b/src/linalg/inverse.rs @@ -127,7 +127,7 @@ fn do_inverse4>( where DefaultAllocator: Allocator, { - let m = m.data.as_slice(); + let m = m.as_slice(); out[(0, 0)] = m[5] * m[10] * m[15] - m[5] * m[11] * m[14] - m[9] * m[6] * m[15] + m[9] * m[7] * m[14]