`blas.rs` should be sound now

This commit is contained in:
Violeta Hernández 2021-07-14 23:30:31 -05:00
parent 775917142b
commit bbd045d216
7 changed files with 162 additions and 79 deletions

View File

@ -108,7 +108,7 @@ where
unsafe impl<T, const R: usize, const C: usize> StorageMut<T, Const<R>, Const<C>> unsafe impl<T, const R: usize, const C: usize> StorageMut<T, Const<R>, Const<C>>
for ArrayStorage<T, R, C> for ArrayStorage<T, R, C>
where where
DefaultAllocator:InnerAllocator<T, Const<R>, Const<C>, Buffer = Self>, DefaultAllocator: InnerAllocator<T, Const<R>, Const<C>, Buffer = Self>,
{ {
#[inline] #[inline]
fn ptr_mut(&mut self) -> *mut T { fn ptr_mut(&mut self) -> *mut T {
@ -124,14 +124,14 @@ where
unsafe impl<T, const R: usize, const C: usize> ContiguousStorage<T, Const<R>, Const<C>> unsafe impl<T, const R: usize, const C: usize> ContiguousStorage<T, Const<R>, Const<C>>
for ArrayStorage<T, R, C> for ArrayStorage<T, R, C>
where where
DefaultAllocator:InnerAllocator<T, Const<R>, Const<C>, Buffer = Self>, DefaultAllocator: InnerAllocator<T, Const<R>, Const<C>, Buffer = Self>,
{ {
} }
unsafe impl<T, const R: usize, const C: usize> ContiguousStorageMut<T, Const<R>, Const<C>> unsafe impl<T, const R: usize, const C: usize> ContiguousStorageMut<T, Const<R>, Const<C>>
for ArrayStorage<T, R, C> for ArrayStorage<T, R, C>
where where
DefaultAllocator:InnerAllocator<T, Const<R>, Const<C>, Buffer = Self>, DefaultAllocator: InnerAllocator<T, Const<R>, Const<C>, Buffer = Self>,
{ {
} }

View File

@ -1,10 +1,11 @@
use crate::SimdComplexField; use crate::{OVector, SimdComplexField};
#[cfg(feature = "std")] #[cfg(feature = "std")]
use matrixmultiply; use matrixmultiply;
use num::{One, Zero}; use num::{One, Zero};
use simba::scalar::{ClosedAdd, ClosedMul}; use simba::scalar::{ClosedAdd, ClosedMul};
#[cfg(feature = "std")] #[cfg(feature = "std")]
use std::mem; use std::mem;
use std::mem::MaybeUninit;
use crate::base::allocator::Allocator; use crate::base::allocator::Allocator;
use crate::base::constraint::{ use crate::base::constraint::{
@ -315,6 +316,28 @@ where
} }
} }
fn array_axc_uninit<T>(
y: &mut [MaybeUninit<T>],
a: T,
x: &[T],
c: T,
stride1: usize,
stride2: usize,
len: usize,
) where
T: Scalar + Zero + ClosedAdd + ClosedMul,
{
for i in 0..len {
unsafe {
*y.get_unchecked_mut(i * stride1) = MaybeUninit::new(
a.inlined_clone()
* x.get_unchecked(i * stride2).inlined_clone()
* c.inlined_clone(),
);
}
}
}
/// # BLAS functions /// # BLAS functions
impl<T, D: Dim, S> Vector<T, D, S> impl<T, D: Dim, S> Vector<T, D, S>
where where
@ -723,6 +746,80 @@ where
} }
} }
impl<T, D: Dim> OVector<MaybeUninit<T>, D>
where
T: Scalar + Zero + ClosedAdd + ClosedMul,
DefaultAllocator: Allocator<T, D>,
{
pub fn axc<D2: Dim, SB>(&mut self, a: T, x: &Vector<T, D2, SB>, c: T) -> OVector<T, D>
where
SB: Storage<T, D2>,
ShapeConstraint: DimEq<D, D2>,
{
assert_eq!(self.nrows(), x.nrows(), "Axcpy: mismatched vector shapes.");
let rstride1 = self.strides().0;
let rstride2 = x.strides().0;
unsafe {
// SAFETY: the conversion to slices is OK because we access the
// elements taking the strides into account.
let y = self.data.as_mut_slice_unchecked();
let x = x.data.as_slice_unchecked();
array_axc_uninit(y, a, x, c, rstride1, rstride2, x.len());
self.assume_init()
}
}
/// Computes `self = alpha * a * x, where `a` is a matrix, `x` a vector, and
/// `alpha` is a scalar.
///
/// By the time this method returns, `self` will have been initialized.
#[inline]
pub fn gemv_uninit<R2: Dim, C2: Dim, D3: Dim, SB, SC>(
mut self,
alpha: T,
a: &Matrix<T, R2, C2, SB>,
x: &Vector<T, D3, SC>,
beta: T,
) -> OVector<T, D>
where
T: One,
SB: Storage<T, R2, C2>,
SC: Storage<T, D3>,
ShapeConstraint: DimEq<D, R2> + AreMultipliable<R2, C2, D3, U1>,
{
let dim1 = self.nrows();
let (nrows2, ncols2) = a.shape();
let dim3 = x.nrows();
assert!(
ncols2 == dim3 && dim1 == nrows2,
"Gemv: dimensions mismatch."
);
if ncols2 == 0 {
self.fill_fn(|| MaybeUninit::new(T::zero()));
return self.assume_init();
}
// TODO: avoid bound checks.
let col2 = a.column(0);
let val = unsafe { x.vget_unchecked(0).inlined_clone() };
let res = self.axc(alpha.inlined_clone(), &col2, val);
for j in 1..ncols2 {
let col2 = a.column(j);
let val = unsafe { x.vget_unchecked(j).inlined_clone() };
res.axcpy(alpha.inlined_clone(), &col2, val, T::one());
}
res
}
}
impl<T, R1: Dim, C1: Dim, S: StorageMut<T, R1, C1>> Matrix<T, R1, C1, S> impl<T, R1: Dim, C1: Dim, S: StorageMut<T, R1, C1>> Matrix<T, R1, C1, S>
where where
T: Scalar + Zero + ClosedAdd + ClosedMul, T: Scalar + Zero + ClosedAdd + ClosedMul,
@ -1275,29 +1372,25 @@ where
/// ///
/// mat.quadform_tr_with_workspace(&mut workspace, 10.0, &lhs, &mid, 5.0); /// mat.quadform_tr_with_workspace(&mut workspace, 10.0, &lhs, &mid, 5.0);
/// assert_relative_eq!(mat, expected); /// assert_relative_eq!(mat, expected);
pub fn quadform_tr_with_workspace<D2, S2, R3, C3, S3, D4, S4>( pub fn quadform_tr_with_workspace<D2: Dim, R3: Dim, C3: Dim, S3, D4: Dim, S4>(
&mut self, &mut self,
work: &mut Vector<T, D2, S2>, work: &mut OVector<MaybeUninit<T>, D2>,
alpha: T, alpha: T,
lhs: &Matrix<T, R3, C3, S3>, lhs: &Matrix<T, R3, C3, S3>,
mid: &SquareMatrix<T, D4, S4>, mid: &SquareMatrix<T, D4, S4>,
beta: T, beta: T,
) where ) where
D2: Dim,
R3: Dim,
C3: Dim,
D4: Dim,
S2: StorageMut<T, D2>,
S3: Storage<T, R3, C3>, S3: Storage<T, R3, C3>,
S4: Storage<T, D4, D4>, S4: Storage<T, D4, D4>,
ShapeConstraint: DimEq<D1, D2> + DimEq<D1, R3> + DimEq<D2, R3> + DimEq<C3, D4>, ShapeConstraint: DimEq<D1, D2> + DimEq<D1, R3> + DimEq<D2, R3> + DimEq<C3, D4>,
DefaultAllocator: Allocator<T, D2>,
{ {
work.gemv(T::one(), lhs, &mid.column(0), T::zero()); let work = work.gemv_uninit(T::one(), lhs, &mid.column(0), T::zero());
self.ger(alpha.inlined_clone(), work, &lhs.column(0), beta); self.ger(alpha.inlined_clone(), &work, &lhs.column(0), beta);
for j in 1..mid.ncols() { for j in 1..mid.ncols() {
work.gemv(T::one(), lhs, &mid.column(j), T::zero()); work.gemv(T::one(), lhs, &mid.column(j), T::zero());
self.ger(alpha.inlined_clone(), work, &lhs.column(j), T::one()); self.ger(alpha.inlined_clone(), &work, &lhs.column(j), T::one());
} }
} }
@ -1322,24 +1415,19 @@ where
/// ///
/// mat.quadform_tr(10.0, &lhs, &mid, 5.0); /// mat.quadform_tr(10.0, &lhs, &mid, 5.0);
/// assert_relative_eq!(mat, expected); /// assert_relative_eq!(mat, expected);
pub fn quadform_tr<R3, C3, S3, D4, S4>( pub fn quadform_tr<R3: Dim, C3: Dim, S3, D4: Dim, S4>(
&mut self, &mut self,
alpha: T, alpha: T,
lhs: &Matrix<T, R3, C3, S3>, lhs: &Matrix<T, R3, C3, S3>,
mid: &SquareMatrix<T, D4, S4>, mid: &SquareMatrix<T, D4, S4>,
beta: T, beta: T,
) where ) where
R3: Dim,
C3: Dim,
D4: Dim,
S3: Storage<T, R3, C3>, S3: Storage<T, R3, C3>,
S4: Storage<T, D4, D4>, S4: Storage<T, D4, D4>,
ShapeConstraint: DimEq<D1, D1> + DimEq<D1, R3> + DimEq<C3, D4>, ShapeConstraint: DimEq<D1, D1> + DimEq<D1, R3> + DimEq<C3, D4>,
DefaultAllocator: Allocator<T, D1>, DefaultAllocator: Allocator<T, D1>,
{ {
let mut work = unsafe { let mut work = Matrix::new_uninitialized_generic(self.data.shape().0, Const::<1>);
crate::unimplemented_or_uninitialized_generic!(self.data.shape().0, Const::<1>)
};
self.quadform_tr_with_workspace(&mut work, alpha, lhs, mid, beta) self.quadform_tr_with_workspace(&mut work, alpha, lhs, mid, beta)
} }
@ -1368,32 +1456,28 @@ where
/// ///
/// mat.quadform_with_workspace(&mut workspace, 10.0, &mid, &rhs, 5.0); /// mat.quadform_with_workspace(&mut workspace, 10.0, &mid, &rhs, 5.0);
/// assert_relative_eq!(mat, expected); /// assert_relative_eq!(mat, expected);
pub fn quadform_with_workspace<D2, S2, D3, S3, R4, C4, S4>( pub fn quadform_with_workspace<D2: Dim, D3: Dim, S3, R4: Dim, C4: Dim, S4>(
&mut self, &mut self,
work: &mut Vector<T, D2, S2>, work: &mut OVector<MaybeUninit<T>, D2>,
alpha: T, alpha: T,
mid: &SquareMatrix<T, D3, S3>, mid: &SquareMatrix<T, D3, S3>,
rhs: &Matrix<T, R4, C4, S4>, rhs: &Matrix<T, R4, C4, S4>,
beta: T, beta: T,
) where ) where
D2: Dim,
D3: Dim,
R4: Dim,
C4: Dim,
S2: StorageMut<T, D2>,
S3: Storage<T, D3, D3>, S3: Storage<T, D3, D3>,
S4: Storage<T, R4, C4>, S4: Storage<T, R4, C4>,
ShapeConstraint: ShapeConstraint:
DimEq<D3, R4> + DimEq<D1, C4> + DimEq<D2, D3> + AreMultipliable<C4, R4, D2, U1>, DimEq<D3, R4> + DimEq<D1, C4> + DimEq<D2, D3> + AreMultipliable<C4, R4, D2, U1>,
DefaultAllocator: Allocator<T, D2>,
{ {
work.gemv(T::one(), mid, &rhs.column(0), T::zero()); let work = work.gemv_uninit(T::one(), mid, &rhs.column(0), T::zero());
self.column_mut(0) self.column_mut(0)
.gemv_tr(alpha.inlined_clone(), rhs, work, beta.inlined_clone()); .gemv_tr(alpha.inlined_clone(), rhs, &work, beta.inlined_clone());
for j in 1..rhs.ncols() { for j in 1..rhs.ncols() {
work.gemv(T::one(), mid, &rhs.column(j), T::zero()); work.gemv(T::one(), mid, &rhs.column(j), T::zero());
self.column_mut(j) self.column_mut(j)
.gemv_tr(alpha.inlined_clone(), rhs, work, beta.inlined_clone()); .gemv_tr(alpha.inlined_clone(), rhs, &work, beta.inlined_clone());
} }
} }
@ -1417,24 +1501,19 @@ where
/// ///
/// mat.quadform(10.0, &mid, &rhs, 5.0); /// mat.quadform(10.0, &mid, &rhs, 5.0);
/// assert_relative_eq!(mat, expected); /// assert_relative_eq!(mat, expected);
pub fn quadform<D2, S2, R3, C3, S3>( pub fn quadform<D2: Dim, S2, R3: Dim, C3: Dim, S3>(
&mut self, &mut self,
alpha: T, alpha: T,
mid: &SquareMatrix<T, D2, S2>, mid: &SquareMatrix<T, D2, S2>,
rhs: &Matrix<T, R3, C3, S3>, rhs: &Matrix<T, R3, C3, S3>,
beta: T, beta: T,
) where ) where
D2: Dim,
R3: Dim,
C3: Dim,
S2: Storage<T, D2, D2>, S2: Storage<T, D2, D2>,
S3: Storage<T, R3, C3>, S3: Storage<T, R3, C3>,
ShapeConstraint: DimEq<D2, R3> + DimEq<D1, C3> + AreMultipliable<C3, R3, D2, U1>, ShapeConstraint: DimEq<D2, R3> + DimEq<D1, C3> + AreMultipliable<C3, R3, D2, U1>,
DefaultAllocator: Allocator<T, D2>, DefaultAllocator: Allocator<T, D2>,
{ {
let mut work = unsafe { let mut work = Matrix::new_uninitialized_generic(mid.data.shape().0, Const::<1>);
crate::unimplemented_or_uninitialized_generic!(mid.data.shape().0, Const::<1>)
};
self.quadform_with_workspace(&mut work, alpha, mid, rhs, beta) self.quadform_with_workspace(&mut work, alpha, mid, rhs, beta)
} }
} }

View File

@ -18,7 +18,7 @@ use typenum::{self, Cmp, Greater};
use simba::scalar::{ClosedAdd, ClosedMul}; use simba::scalar::{ClosedAdd, ClosedMul};
use crate::base::allocator::Allocator; use crate::{base::allocator::Allocator};
use crate::base::dimension::{Dim, DimName, Dynamic, ToTypenum}; use crate::base::dimension::{Dim, DimName, Dynamic, ToTypenum};
use crate::base::storage::Storage; use crate::base::storage::Storage;
use crate::base::{ use crate::base::{
@ -117,7 +117,7 @@ where
/// Creates a matrix with its elements filled with the components provided by a slice. The /// Creates a matrix with its elements filled with the components provided by a slice. The
/// components must have the same layout as the matrix data storage (i.e. column-major). /// components must have the same layout as the matrix data storage (i.e. column-major).
#[inline] #[inline]
pub fn from_column_slice_generic(nrows: R, ncols: C, slice: &[T]) -> Self { pub fn from_column_slice_generic(nrows: R, ncols: C, slice: &[T]) -> Self where T:Clone{
Self::from_iterator_generic(nrows, ncols, slice.iter().cloned()) Self::from_iterator_generic(nrows, ncols, slice.iter().cloned())
} }
@ -139,7 +139,7 @@ where
} }
// Safety: all entries have been initialized. // Safety: all entries have been initialized.
unsafe { Matrix::assume_init(res) } unsafe { res.assume_init()}
} }
/// Creates a new identity matrix. /// Creates a new identity matrix.
@ -352,7 +352,7 @@ where
#[inline] #[inline]
pub fn from_diagonal<SB: Storage<T, D>>(diag: &Vector<T, D, SB>) -> Self pub fn from_diagonal<SB: Storage<T, D>>(diag: &Vector<T, D, SB>) -> Self
where where
T: Zero, T: Zero+Scalar,
{ {
let (dim, _) = diag.data.shape(); let (dim, _) = diag.data.shape();
let mut res = Self::zeros_generic(dim, dim); let mut res = Self::zeros_generic(dim, dim);

View File

@ -158,12 +158,23 @@ impl<T: Scalar, R: Dim, C: Dim, S: StorageMut<T, R, C>> Matrix<T, R, C, S> {
} }
/// # In-place filling /// # In-place filling
impl<T: Scalar, R: Dim, C: Dim, S: StorageMut<T, R, C>> Matrix<T, R, C, S> { impl<T, R: Dim, C: Dim, S: StorageMut<T, R, C>> Matrix<T, R, C, S> {
/// Sets all the elements of this matrix to `val`. /// Sets all the elements of this matrix to `val`.
#[inline] #[inline]
pub fn fill(&mut self, val: T) { pub fn fill(&mut self, val: T)
where
T: Clone,
{
for e in self.iter_mut() { for e in self.iter_mut() {
*e = val.inlined_clone() *e = val.clone()
}
}
/// Sets all the elements of this matrix to `f()`.
#[inline]
pub fn fill_fn<F: FnMut() -> T>(&mut self, f: F) {
for e in self.iter_mut() {
*e = f();
} }
} }
@ -171,7 +182,7 @@ impl<T: Scalar, R: Dim, C: Dim, S: StorageMut<T, R, C>> Matrix<T, R, C, S> {
#[inline] #[inline]
pub fn fill_with_identity(&mut self) pub fn fill_with_identity(&mut self)
where where
T: Zero + One, T: Zero + One + Scalar,
{ {
self.fill(T::zero()); self.fill(T::zero());
self.fill_diagonal(T::one()); self.fill_diagonal(T::one());
@ -184,7 +195,7 @@ impl<T: Scalar, R: Dim, C: Dim, S: StorageMut<T, R, C>> Matrix<T, R, C, S> {
let n = cmp::min(nrows, ncols); let n = cmp::min(nrows, ncols);
for i in 0..n { for i in 0..n {
unsafe { *self.get_unchecked_mut((i, i)) = val.inlined_clone() } unsafe { *self.get_unchecked_mut((i, i)) = val.clone() }
} }
} }

View File

@ -657,7 +657,7 @@ impl<T, R: Dim, C: Dim, S: Storage<T, R, C>> Matrix<T, R, C, S> {
} }
} }
unsafe { Matrix::assume_init(res) } unsafe { res.assume_init()}
} }
/// Transposes `self` and store the result into `out`, which will become /// Transposes `self` and store the result into `out`, which will become
@ -666,7 +666,7 @@ impl<T, R: Dim, C: Dim, S: Storage<T, R, C>> Matrix<T, R, C, S> {
pub fn transpose_to<R2: Dim, C2: Dim, SB>(&self, out: &mut Matrix<MaybeUninit<T>, R2, C2, SB>) pub fn transpose_to<R2: Dim, C2: Dim, SB>(&self, out: &mut Matrix<MaybeUninit<T>, R2, C2, SB>)
where where
T: Clone, T: Clone,
SB: StorageMut<T, R2, C2>, SB: StorageMut<MaybeUninit<T>, R2, C2>,
ShapeConstraint: SameNumberOfRows<R, C2> + SameNumberOfColumns<C, R2>, ShapeConstraint: SameNumberOfRows<R, C2> + SameNumberOfColumns<C, R2>,
{ {
let (nrows, ncols) = self.shape(); let (nrows, ncols) = self.shape();

View File

@ -2,12 +2,12 @@ use std::marker::PhantomData;
use std::ops::{Range, RangeFrom, RangeFull, RangeInclusive, RangeTo}; use std::ops::{Range, RangeFrom, RangeFull, RangeInclusive, RangeTo};
use std::slice; use std::slice;
use crate::base::allocator::Allocator; use crate::base::allocator::{Allocator, InnerAllocator};
use crate::base::default_allocator::DefaultAllocator; use crate::base::default_allocator::DefaultAllocator;
use crate::base::dimension::{Const, Dim, DimName, Dynamic, IsNotStaticOne, U1}; use crate::base::dimension::{Const, Dim, DimName, Dynamic, IsNotStaticOne, U1};
use crate::base::iter::MatrixIter; use crate::base::iter::MatrixIter;
use crate::base::storage::{ContiguousStorage, ContiguousStorageMut, Owned, Storage, StorageMut}; use crate::base::storage::{ContiguousStorage, ContiguousStorageMut, Owned, Storage, StorageMut};
use crate::base::{Matrix, Scalar}; use crate::base::Matrix;
macro_rules! slice_storage_impl( macro_rules! slice_storage_impl(
($doc: expr; $Storage: ident as $SRef: ty; $T: ident.$get_addr: ident ($Ptr: ty as $Ref: ty)) => { ($doc: expr; $Storage: ident as $SRef: ty; $T: ident.$get_addr: ident ($Ptr: ty as $Ref: ty)) => {

View File

@ -7,16 +7,17 @@ use std::ops::{
use simba::scalar::{ClosedAdd, ClosedDiv, ClosedMul, ClosedNeg, ClosedSub}; use simba::scalar::{ClosedAdd, ClosedDiv, ClosedMul, ClosedNeg, ClosedSub};
use crate::allocator::InnerAllocator; use crate::base::allocator::{
use crate::base::allocator::{Allocator, SameShapeAllocator, SameShapeC, SameShapeR}; Allocator, InnerAllocator, SameShapeAllocator, SameShapeC, SameShapeR,
};
use crate::base::constraint::{ use crate::base::constraint::{
AreMultipliable, DimEq, SameNumberOfColumns, SameNumberOfRows, ShapeConstraint, AreMultipliable, DimEq, SameNumberOfColumns, SameNumberOfRows, ShapeConstraint,
}; };
use crate::base::dimension::{Dim, DimMul, DimName, DimProd, Dynamic}; use crate::base::dimension::{Dim, DimMul, DimName, DimProd, Dynamic};
use crate::base::storage::{ContiguousStorageMut, Storage, StorageMut}; use crate::base::storage::{ContiguousStorageMut, Storage, StorageMut};
use crate::base::{DefaultAllocator, Matrix, MatrixSum, OMatrix, Scalar, VectorSlice}; use crate::base::{DefaultAllocator, Matrix, MatrixSum, OMatrix, Scalar, VectorSlice};
use crate::SimdComplexField;
use crate::storage::Owned; use crate::storage::Owned;
use crate::SimdComplexField;
/* /*
* *
@ -431,7 +432,7 @@ where
// TODO: we should take out this trait bound, as T: Clone should suffice. // TODO: we should take out this trait bound, as T: Clone should suffice.
// The brute way to do it would be how it was already done: by adding this // The brute way to do it would be how it was already done: by adding this
// trait bound on the associated type itself. // trait bound on the associated type itself.
Owned<T,Dynamic,C>: Clone, Owned<T, Dynamic, C>: Clone,
{ {
/// # Example /// # Example
/// ``` /// ```
@ -575,11 +576,9 @@ where
#[inline] #[inline]
fn mul(self, rhs: &'b Matrix<T, R2, C2, SB>) -> Self::Output { fn mul(self, rhs: &'b Matrix<T, R2, C2, SB>) -> Self::Output {
let mut res = unsafe { let mut res =Matrix::new_uninitialized_generic(self.data.shape().0, rhs.data.shape().1);
crate::unimplemented_or_uninitialized_generic!(self.data.shape().0, rhs.data.shape().1) self.mul_to(rhs, &mut res);
}; unsafe{ res.assume_init()}
self.mul_to(rhs, &mut res);
res
} }
} }
@ -687,12 +686,9 @@ where
DefaultAllocator: Allocator<T, C1, C2>, DefaultAllocator: Allocator<T, C1, C2>,
ShapeConstraint: SameNumberOfRows<R1, R2>, ShapeConstraint: SameNumberOfRows<R1, R2>,
{ {
let mut res = unsafe { let mut res = Matrix::new_uninitialized_generic(self.data.shape().1, rhs.data.shape().1);
crate::unimplemented_or_uninitialized_generic!(self.data.shape().1, rhs.data.shape().1)
};
self.tr_mul_to(rhs, &mut res); self.tr_mul_to(rhs, &mut res);
res unsafe { res.assume_init() }
} }
/// Equivalent to `self.adjoint() * rhs`. /// Equivalent to `self.adjoint() * rhs`.
@ -701,30 +697,27 @@ where
pub fn ad_mul<R2: Dim, C2: Dim, SB>(&self, rhs: &Matrix<T, R2, C2, SB>) -> OMatrix<T, C1, C2> pub fn ad_mul<R2: Dim, C2: Dim, SB>(&self, rhs: &Matrix<T, R2, C2, SB>) -> OMatrix<T, C1, C2>
where where
T: SimdComplexField, T: SimdComplexField,
SB: Storage<T, R2, C2>, SB: Storage<MaybeUninit<T>, R2, C2>,
DefaultAllocator: Allocator<T, C1, C2>, DefaultAllocator: Allocator<T, C1, C2>,
ShapeConstraint: SameNumberOfRows<R1, R2>, ShapeConstraint: SameNumberOfRows<R1, R2>,
{ {
let mut res = unsafe { let mut res = Matrix::new_uninitialized_generic(self.data.shape().1, rhs.data.shape().1);
crate::unimplemented_or_uninitialized_generic!(self.data.shape().1, rhs.data.shape().1)
};
self.ad_mul_to(rhs, &mut res); self.ad_mul_to(rhs, &mut res);
res unsafe { res.assume_init() }
} }
#[inline(always)] #[inline(always)]
fn xx_mul_to<R2: Dim, C2: Dim, SB, R3: Dim, C3: Dim, SC>( fn xx_mul_to<R2: Dim, C2: Dim, SB, R3: Dim, C3: Dim, SC>(
&self, &self,
rhs: &Matrix<T, R2, C2, SB>, rhs: &Matrix<T, R2, C2, SB>,
out: &mut Matrix<T, R3, C3, SC>, out: &mut Matrix<MaybeUninit<T>, R3, C3, SC>,
dot: impl Fn( dot: impl Fn(
&VectorSlice<T, R1, SA::RStride, SA::CStride>, &VectorSlice<T, R1, SA::RStride, SA::CStride>,
&VectorSlice<T, R2, SB::RStride, SB::CStride>, &VectorSlice<T, R2, SB::RStride, SB::CStride>,
) -> T, ) -> T,
) where ) where
SB: Storage<T, R2, C2>, SB: Storage<T, R2, C2>,
SC: StorageMut<T, R3, C3>, SC: StorageMut<MaybeUninit<T>, R3, C3>,
ShapeConstraint: SameNumberOfRows<R1, R2> + DimEq<C1, R3> + DimEq<C2, C3>, ShapeConstraint: SameNumberOfRows<R1, R2> + DimEq<C1, R3> + DimEq<C2, C3>,
{ {
let (nrows1, ncols1) = self.shape(); let (nrows1, ncols1) = self.shape();
@ -753,7 +746,7 @@ where
for i in 0..ncols1 { for i in 0..ncols1 {
for j in 0..ncols2 { for j in 0..ncols2 {
let dot = dot(&self.column(i), &rhs.column(j)); let dot = dot(&self.column(i), &rhs.column(j));
unsafe { *out.get_unchecked_mut((i, j)) = dot }; unsafe { *out.get_unchecked_mut((i, j)) = MaybeUninit::new(dot) ;}
} }
} }
} }
@ -764,10 +757,10 @@ where
pub fn tr_mul_to<R2: Dim, C2: Dim, SB, R3: Dim, C3: Dim, SC>( pub fn tr_mul_to<R2: Dim, C2: Dim, SB, R3: Dim, C3: Dim, SC>(
&self, &self,
rhs: &Matrix<T, R2, C2, SB>, rhs: &Matrix<T, R2, C2, SB>,
out: &mut Matrix<T, R3, C3, SC>, out: &mut Matrix<MaybeUninit<T>, R3, C3, SC>,
) where ) where
SB: Storage<T, R2, C2>, SB: Storage<T, R2, C2>,
SC: StorageMut<T, R3, C3>, SC: StorageMut<MaybeUninit<T>, R3, C3>,
ShapeConstraint: SameNumberOfRows<R1, R2> + DimEq<C1, R3> + DimEq<C2, C3>, ShapeConstraint: SameNumberOfRows<R1, R2> + DimEq<C1, R3> + DimEq<C2, C3>,
{ {
self.xx_mul_to(rhs, out, |a, b| a.dot(b)) self.xx_mul_to(rhs, out, |a, b| a.dot(b))
@ -779,11 +772,11 @@ where
pub fn ad_mul_to<R2: Dim, C2: Dim, SB, R3: Dim, C3: Dim, SC>( pub fn ad_mul_to<R2: Dim, C2: Dim, SB, R3: Dim, C3: Dim, SC>(
&self, &self,
rhs: &Matrix<T, R2, C2, SB>, rhs: &Matrix<T, R2, C2, SB>,
out: &mut Matrix<T, R3, C3, SC>, out: &mut Matrix<MaybeUninit<T>, R3, C3, SC>,
) where ) where
T: SimdComplexField, T: SimdComplexField,
SB: Storage<T, R2, C2>, SB: Storage<T, R2, C2>,
SC: StorageMut<T, R3, C3>, SC: StorageMut<MaybeUninit<T>, R3, C3>,
ShapeConstraint: SameNumberOfRows<R1, R2> + DimEq<C1, R3> + DimEq<C2, C3>, ShapeConstraint: SameNumberOfRows<R1, R2> + DimEq<C1, R3> + DimEq<C2, C3>,
{ {
self.xx_mul_to(rhs, out, |a, b| a.dotc(b)) self.xx_mul_to(rhs, out, |a, b| a.dotc(b))
@ -793,7 +786,7 @@ where
#[inline] #[inline]
pub fn mul_to<R2: Dim, C2: Dim, SB, R3: Dim, C3: Dim, SC>( pub fn mul_to<R2: Dim, C2: Dim, SB, R3: Dim, C3: Dim, SC>(
&self, &self,
rhs: &Matrix<T, R2, C2, SB>, rhs: &Matrix<MaybeUninit<T>, R2, C2, SB>,
out: &mut Matrix<T, R3, C3, SC>, out: &mut Matrix<T, R3, C3, SC>,
) where ) where
SB: Storage<T, R2, C2>, SB: Storage<T, R2, C2>,