From bbd045d21602e43ff3945fdf0229471e9c20fc0a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Violeta=20Hern=C3=A1ndez?= Date: Wed, 14 Jul 2021 23:30:31 -0500 Subject: [PATCH] `blas.rs` should be sound now --- src/base/array_storage.rs | 6 +- src/base/blas.rs | 149 +++++++++++++++++++++++++++++--------- src/base/construction.rs | 8 +- src/base/edition.rs | 21 ++++-- src/base/matrix.rs | 4 +- src/base/matrix_slice.rs | 4 +- src/base/ops.rs | 49 ++++++------- 7 files changed, 162 insertions(+), 79 deletions(-) diff --git a/src/base/array_storage.rs b/src/base/array_storage.rs index 09ac8a4b..b87442a4 100644 --- a/src/base/array_storage.rs +++ b/src/base/array_storage.rs @@ -108,7 +108,7 @@ where unsafe impl StorageMut, Const> for ArrayStorage where - DefaultAllocator:InnerAllocator, Const, Buffer = Self>, + DefaultAllocator: InnerAllocator, Const, Buffer = Self>, { #[inline] fn ptr_mut(&mut self) -> *mut T { @@ -124,14 +124,14 @@ where unsafe impl ContiguousStorage, Const> for ArrayStorage where - DefaultAllocator:InnerAllocator, Const, Buffer = Self>, + DefaultAllocator: InnerAllocator, Const, Buffer = Self>, { } unsafe impl ContiguousStorageMut, Const> for ArrayStorage where - DefaultAllocator:InnerAllocator, Const, Buffer = Self>, + DefaultAllocator: InnerAllocator, Const, Buffer = Self>, { } diff --git a/src/base/blas.rs b/src/base/blas.rs index b705c6c1..3b8ac951 100644 --- a/src/base/blas.rs +++ b/src/base/blas.rs @@ -1,10 +1,11 @@ -use crate::SimdComplexField; +use crate::{OVector, SimdComplexField}; #[cfg(feature = "std")] use matrixmultiply; use num::{One, Zero}; use simba::scalar::{ClosedAdd, ClosedMul}; #[cfg(feature = "std")] use std::mem; +use std::mem::MaybeUninit; use crate::base::allocator::Allocator; use crate::base::constraint::{ @@ -315,6 +316,28 @@ where } } +fn array_axc_uninit( + y: &mut [MaybeUninit], + a: T, + x: &[T], + c: T, + stride1: usize, + stride2: usize, + len: usize, +) where + T: Scalar + Zero + ClosedAdd + ClosedMul, +{ + for i in 0..len { + unsafe { + *y.get_unchecked_mut(i * stride1) = MaybeUninit::new( + a.inlined_clone() + * x.get_unchecked(i * stride2).inlined_clone() + * c.inlined_clone(), + ); + } + } +} + /// # BLAS functions impl Vector where @@ -723,6 +746,80 @@ where } } +impl OVector, D> +where + T: Scalar + Zero + ClosedAdd + ClosedMul, + DefaultAllocator: Allocator, +{ + pub fn axc(&mut self, a: T, x: &Vector, c: T) -> OVector + where + SB: Storage, + ShapeConstraint: DimEq, + { + assert_eq!(self.nrows(), x.nrows(), "Axcpy: mismatched vector shapes."); + + let rstride1 = self.strides().0; + let rstride2 = x.strides().0; + + unsafe { + // SAFETY: the conversion to slices is OK because we access the + // elements taking the strides into account. + let y = self.data.as_mut_slice_unchecked(); + let x = x.data.as_slice_unchecked(); + + array_axc_uninit(y, a, x, c, rstride1, rstride2, x.len()); + self.assume_init() + } + } + + /// Computes `self = alpha * a * x, where `a` is a matrix, `x` a vector, and + /// `alpha` is a scalar. + /// + /// By the time this method returns, `self` will have been initialized. + #[inline] + pub fn gemv_uninit( + mut self, + alpha: T, + a: &Matrix, + x: &Vector, + beta: T, + ) -> OVector + where + T: One, + SB: Storage, + SC: Storage, + ShapeConstraint: DimEq + AreMultipliable, + { + let dim1 = self.nrows(); + let (nrows2, ncols2) = a.shape(); + let dim3 = x.nrows(); + + assert!( + ncols2 == dim3 && dim1 == nrows2, + "Gemv: dimensions mismatch." + ); + + if ncols2 == 0 { + self.fill_fn(|| MaybeUninit::new(T::zero())); + return self.assume_init(); + } + + // TODO: avoid bound checks. + let col2 = a.column(0); + let val = unsafe { x.vget_unchecked(0).inlined_clone() }; + let res = self.axc(alpha.inlined_clone(), &col2, val); + + for j in 1..ncols2 { + let col2 = a.column(j); + let val = unsafe { x.vget_unchecked(j).inlined_clone() }; + + res.axcpy(alpha.inlined_clone(), &col2, val, T::one()); + } + + res + } +} + impl> Matrix where T: Scalar + Zero + ClosedAdd + ClosedMul, @@ -1275,29 +1372,25 @@ where /// /// mat.quadform_tr_with_workspace(&mut workspace, 10.0, &lhs, &mid, 5.0); /// assert_relative_eq!(mat, expected); - pub fn quadform_tr_with_workspace( + pub fn quadform_tr_with_workspace( &mut self, - work: &mut Vector, + work: &mut OVector, D2>, alpha: T, lhs: &Matrix, mid: &SquareMatrix, beta: T, ) where - D2: Dim, - R3: Dim, - C3: Dim, - D4: Dim, - S2: StorageMut, S3: Storage, S4: Storage, ShapeConstraint: DimEq + DimEq + DimEq + DimEq, + DefaultAllocator: Allocator, { - work.gemv(T::one(), lhs, &mid.column(0), T::zero()); - self.ger(alpha.inlined_clone(), work, &lhs.column(0), beta); + let work = work.gemv_uninit(T::one(), lhs, &mid.column(0), T::zero()); + self.ger(alpha.inlined_clone(), &work, &lhs.column(0), beta); for j in 1..mid.ncols() { work.gemv(T::one(), lhs, &mid.column(j), T::zero()); - self.ger(alpha.inlined_clone(), work, &lhs.column(j), T::one()); + self.ger(alpha.inlined_clone(), &work, &lhs.column(j), T::one()); } } @@ -1322,24 +1415,19 @@ where /// /// mat.quadform_tr(10.0, &lhs, &mid, 5.0); /// assert_relative_eq!(mat, expected); - pub fn quadform_tr( + pub fn quadform_tr( &mut self, alpha: T, lhs: &Matrix, mid: &SquareMatrix, beta: T, ) where - R3: Dim, - C3: Dim, - D4: Dim, S3: Storage, S4: Storage, ShapeConstraint: DimEq + DimEq + DimEq, DefaultAllocator: Allocator, { - let mut work = unsafe { - crate::unimplemented_or_uninitialized_generic!(self.data.shape().0, Const::<1>) - }; + let mut work = Matrix::new_uninitialized_generic(self.data.shape().0, Const::<1>); self.quadform_tr_with_workspace(&mut work, alpha, lhs, mid, beta) } @@ -1368,32 +1456,28 @@ where /// /// mat.quadform_with_workspace(&mut workspace, 10.0, &mid, &rhs, 5.0); /// assert_relative_eq!(mat, expected); - pub fn quadform_with_workspace( + pub fn quadform_with_workspace( &mut self, - work: &mut Vector, + work: &mut OVector, D2>, alpha: T, mid: &SquareMatrix, rhs: &Matrix, beta: T, ) where - D2: Dim, - D3: Dim, - R4: Dim, - C4: Dim, - S2: StorageMut, S3: Storage, S4: Storage, ShapeConstraint: DimEq + DimEq + DimEq + AreMultipliable, + DefaultAllocator: Allocator, { - work.gemv(T::one(), mid, &rhs.column(0), T::zero()); + let work = work.gemv_uninit(T::one(), mid, &rhs.column(0), T::zero()); self.column_mut(0) - .gemv_tr(alpha.inlined_clone(), rhs, work, beta.inlined_clone()); + .gemv_tr(alpha.inlined_clone(), rhs, &work, beta.inlined_clone()); for j in 1..rhs.ncols() { work.gemv(T::one(), mid, &rhs.column(j), T::zero()); self.column_mut(j) - .gemv_tr(alpha.inlined_clone(), rhs, work, beta.inlined_clone()); + .gemv_tr(alpha.inlined_clone(), rhs, &work, beta.inlined_clone()); } } @@ -1417,24 +1501,19 @@ where /// /// mat.quadform(10.0, &mid, &rhs, 5.0); /// assert_relative_eq!(mat, expected); - pub fn quadform( + pub fn quadform( &mut self, alpha: T, mid: &SquareMatrix, rhs: &Matrix, beta: T, ) where - D2: Dim, - R3: Dim, - C3: Dim, S2: Storage, S3: Storage, ShapeConstraint: DimEq + DimEq + AreMultipliable, DefaultAllocator: Allocator, { - let mut work = unsafe { - crate::unimplemented_or_uninitialized_generic!(mid.data.shape().0, Const::<1>) - }; + let mut work = Matrix::new_uninitialized_generic(mid.data.shape().0, Const::<1>); self.quadform_with_workspace(&mut work, alpha, mid, rhs, beta) } } diff --git a/src/base/construction.rs b/src/base/construction.rs index bb12cd45..c040a9dc 100644 --- a/src/base/construction.rs +++ b/src/base/construction.rs @@ -18,7 +18,7 @@ use typenum::{self, Cmp, Greater}; use simba::scalar::{ClosedAdd, ClosedMul}; -use crate::base::allocator::Allocator; +use crate::{base::allocator::Allocator}; use crate::base::dimension::{Dim, DimName, Dynamic, ToTypenum}; use crate::base::storage::Storage; use crate::base::{ @@ -117,7 +117,7 @@ where /// Creates a matrix with its elements filled with the components provided by a slice. The /// components must have the same layout as the matrix data storage (i.e. column-major). #[inline] - pub fn from_column_slice_generic(nrows: R, ncols: C, slice: &[T]) -> Self { + pub fn from_column_slice_generic(nrows: R, ncols: C, slice: &[T]) -> Self where T:Clone{ Self::from_iterator_generic(nrows, ncols, slice.iter().cloned()) } @@ -139,7 +139,7 @@ where } // Safety: all entries have been initialized. - unsafe { Matrix::assume_init(res) } + unsafe { res.assume_init()} } /// Creates a new identity matrix. @@ -352,7 +352,7 @@ where #[inline] pub fn from_diagonal>(diag: &Vector) -> Self where - T: Zero, + T: Zero+Scalar, { let (dim, _) = diag.data.shape(); let mut res = Self::zeros_generic(dim, dim); diff --git a/src/base/edition.rs b/src/base/edition.rs index f403f9d3..81e10b48 100644 --- a/src/base/edition.rs +++ b/src/base/edition.rs @@ -158,12 +158,23 @@ impl> Matrix { } /// # In-place filling -impl> Matrix { +impl> Matrix { /// Sets all the elements of this matrix to `val`. #[inline] - pub fn fill(&mut self, val: T) { + pub fn fill(&mut self, val: T) + where + T: Clone, + { for e in self.iter_mut() { - *e = val.inlined_clone() + *e = val.clone() + } + } + + /// Sets all the elements of this matrix to `f()`. + #[inline] + pub fn fill_fn T>(&mut self, f: F) { + for e in self.iter_mut() { + *e = f(); } } @@ -171,7 +182,7 @@ impl> Matrix { #[inline] pub fn fill_with_identity(&mut self) where - T: Zero + One, + T: Zero + One + Scalar, { self.fill(T::zero()); self.fill_diagonal(T::one()); @@ -184,7 +195,7 @@ impl> Matrix { let n = cmp::min(nrows, ncols); for i in 0..n { - unsafe { *self.get_unchecked_mut((i, i)) = val.inlined_clone() } + unsafe { *self.get_unchecked_mut((i, i)) = val.clone() } } } diff --git a/src/base/matrix.rs b/src/base/matrix.rs index 90668044..7e8f79cc 100644 --- a/src/base/matrix.rs +++ b/src/base/matrix.rs @@ -657,7 +657,7 @@ impl> Matrix { } } - unsafe { Matrix::assume_init(res) } + unsafe { res.assume_init()} } /// Transposes `self` and store the result into `out`, which will become @@ -666,7 +666,7 @@ impl> Matrix { pub fn transpose_to(&self, out: &mut Matrix, R2, C2, SB>) where T: Clone, - SB: StorageMut, + SB: StorageMut, R2, C2>, ShapeConstraint: SameNumberOfRows + SameNumberOfColumns, { let (nrows, ncols) = self.shape(); diff --git a/src/base/matrix_slice.rs b/src/base/matrix_slice.rs index cb142b5b..5f6bfd6f 100644 --- a/src/base/matrix_slice.rs +++ b/src/base/matrix_slice.rs @@ -2,12 +2,12 @@ use std::marker::PhantomData; use std::ops::{Range, RangeFrom, RangeFull, RangeInclusive, RangeTo}; use std::slice; -use crate::base::allocator::Allocator; +use crate::base::allocator::{Allocator, InnerAllocator}; use crate::base::default_allocator::DefaultAllocator; use crate::base::dimension::{Const, Dim, DimName, Dynamic, IsNotStaticOne, U1}; use crate::base::iter::MatrixIter; use crate::base::storage::{ContiguousStorage, ContiguousStorageMut, Owned, Storage, StorageMut}; -use crate::base::{Matrix, Scalar}; +use crate::base::Matrix; macro_rules! slice_storage_impl( ($doc: expr; $Storage: ident as $SRef: ty; $T: ident.$get_addr: ident ($Ptr: ty as $Ref: ty)) => { diff --git a/src/base/ops.rs b/src/base/ops.rs index b52eb741..8da0249f 100644 --- a/src/base/ops.rs +++ b/src/base/ops.rs @@ -7,16 +7,17 @@ use std::ops::{ use simba::scalar::{ClosedAdd, ClosedDiv, ClosedMul, ClosedNeg, ClosedSub}; -use crate::allocator::InnerAllocator; -use crate::base::allocator::{Allocator, SameShapeAllocator, SameShapeC, SameShapeR}; +use crate::base::allocator::{ + Allocator, InnerAllocator, SameShapeAllocator, SameShapeC, SameShapeR, +}; use crate::base::constraint::{ AreMultipliable, DimEq, SameNumberOfColumns, SameNumberOfRows, ShapeConstraint, }; use crate::base::dimension::{Dim, DimMul, DimName, DimProd, Dynamic}; use crate::base::storage::{ContiguousStorageMut, Storage, StorageMut}; use crate::base::{DefaultAllocator, Matrix, MatrixSum, OMatrix, Scalar, VectorSlice}; -use crate::SimdComplexField; use crate::storage::Owned; +use crate::SimdComplexField; /* * @@ -431,7 +432,7 @@ where // TODO: we should take out this trait bound, as T: Clone should suffice. // The brute way to do it would be how it was already done: by adding this // trait bound on the associated type itself. - Owned: Clone, + Owned: Clone, { /// # Example /// ``` @@ -575,11 +576,9 @@ where #[inline] fn mul(self, rhs: &'b Matrix) -> Self::Output { - let mut res = unsafe { - crate::unimplemented_or_uninitialized_generic!(self.data.shape().0, rhs.data.shape().1) - }; - self.mul_to(rhs, &mut res); - res + let mut res =Matrix::new_uninitialized_generic(self.data.shape().0, rhs.data.shape().1); + self.mul_to(rhs, &mut res); + unsafe{ res.assume_init()} } } @@ -687,12 +686,9 @@ where DefaultAllocator: Allocator, ShapeConstraint: SameNumberOfRows, { - let mut res = unsafe { - crate::unimplemented_or_uninitialized_generic!(self.data.shape().1, rhs.data.shape().1) - }; - + let mut res = Matrix::new_uninitialized_generic(self.data.shape().1, rhs.data.shape().1); self.tr_mul_to(rhs, &mut res); - res + unsafe { res.assume_init() } } /// Equivalent to `self.adjoint() * rhs`. @@ -701,30 +697,27 @@ where pub fn ad_mul(&self, rhs: &Matrix) -> OMatrix where T: SimdComplexField, - SB: Storage, + SB: Storage, R2, C2>, DefaultAllocator: Allocator, ShapeConstraint: SameNumberOfRows, { - let mut res = unsafe { - crate::unimplemented_or_uninitialized_generic!(self.data.shape().1, rhs.data.shape().1) - }; - + let mut res = Matrix::new_uninitialized_generic(self.data.shape().1, rhs.data.shape().1); self.ad_mul_to(rhs, &mut res); - res + unsafe { res.assume_init() } } #[inline(always)] fn xx_mul_to( &self, rhs: &Matrix, - out: &mut Matrix, + out: &mut Matrix, R3, C3, SC>, dot: impl Fn( &VectorSlice, &VectorSlice, ) -> T, ) where SB: Storage, - SC: StorageMut, + SC: StorageMut, R3, C3>, ShapeConstraint: SameNumberOfRows + DimEq + DimEq, { let (nrows1, ncols1) = self.shape(); @@ -753,7 +746,7 @@ where for i in 0..ncols1 { for j in 0..ncols2 { let dot = dot(&self.column(i), &rhs.column(j)); - unsafe { *out.get_unchecked_mut((i, j)) = dot }; + unsafe { *out.get_unchecked_mut((i, j)) = MaybeUninit::new(dot) ;} } } } @@ -764,10 +757,10 @@ where pub fn tr_mul_to( &self, rhs: &Matrix, - out: &mut Matrix, + out: &mut Matrix, R3, C3, SC>, ) where SB: Storage, - SC: StorageMut, + SC: StorageMut, R3, C3>, ShapeConstraint: SameNumberOfRows + DimEq + DimEq, { self.xx_mul_to(rhs, out, |a, b| a.dot(b)) @@ -779,11 +772,11 @@ where pub fn ad_mul_to( &self, rhs: &Matrix, - out: &mut Matrix, + out: &mut Matrix, R3, C3, SC>, ) where T: SimdComplexField, SB: Storage, - SC: StorageMut, + SC: StorageMut, R3, C3>, ShapeConstraint: SameNumberOfRows + DimEq + DimEq, { self.xx_mul_to(rhs, out, |a, b| a.dotc(b)) @@ -793,7 +786,7 @@ where #[inline] pub fn mul_to( &self, - rhs: &Matrix, + rhs: &Matrix, R2, C2, SB>, out: &mut Matrix, ) where SB: Storage,