`blas.rs` should be sound now

This commit is contained in:
Violeta Hernández 2021-07-14 23:30:31 -05:00
parent 775917142b
commit bbd045d216
7 changed files with 162 additions and 79 deletions

View File

@ -108,7 +108,7 @@ where
unsafe impl<T, const R: usize, const C: usize> StorageMut<T, Const<R>, Const<C>>
for ArrayStorage<T, R, C>
where
DefaultAllocator:InnerAllocator<T, Const<R>, Const<C>, Buffer = Self>,
DefaultAllocator: InnerAllocator<T, Const<R>, Const<C>, Buffer = Self>,
{
#[inline]
fn ptr_mut(&mut self) -> *mut T {
@ -124,14 +124,14 @@ where
unsafe impl<T, const R: usize, const C: usize> ContiguousStorage<T, Const<R>, Const<C>>
for ArrayStorage<T, R, C>
where
DefaultAllocator:InnerAllocator<T, Const<R>, Const<C>, Buffer = Self>,
DefaultAllocator: InnerAllocator<T, Const<R>, Const<C>, Buffer = Self>,
{
}
unsafe impl<T, const R: usize, const C: usize> ContiguousStorageMut<T, Const<R>, Const<C>>
for ArrayStorage<T, R, C>
where
DefaultAllocator:InnerAllocator<T, Const<R>, Const<C>, Buffer = Self>,
DefaultAllocator: InnerAllocator<T, Const<R>, Const<C>, Buffer = Self>,
{
}

View File

@ -1,10 +1,11 @@
use crate::SimdComplexField;
use crate::{OVector, SimdComplexField};
#[cfg(feature = "std")]
use matrixmultiply;
use num::{One, Zero};
use simba::scalar::{ClosedAdd, ClosedMul};
#[cfg(feature = "std")]
use std::mem;
use std::mem::MaybeUninit;
use crate::base::allocator::Allocator;
use crate::base::constraint::{
@ -315,6 +316,28 @@ where
}
}
fn array_axc_uninit<T>(
y: &mut [MaybeUninit<T>],
a: T,
x: &[T],
c: T,
stride1: usize,
stride2: usize,
len: usize,
) where
T: Scalar + Zero + ClosedAdd + ClosedMul,
{
for i in 0..len {
unsafe {
*y.get_unchecked_mut(i * stride1) = MaybeUninit::new(
a.inlined_clone()
* x.get_unchecked(i * stride2).inlined_clone()
* c.inlined_clone(),
);
}
}
}
/// # BLAS functions
impl<T, D: Dim, S> Vector<T, D, S>
where
@ -723,6 +746,80 @@ where
}
}
impl<T, D: Dim> OVector<MaybeUninit<T>, D>
where
T: Scalar + Zero + ClosedAdd + ClosedMul,
DefaultAllocator: Allocator<T, D>,
{
pub fn axc<D2: Dim, SB>(&mut self, a: T, x: &Vector<T, D2, SB>, c: T) -> OVector<T, D>
where
SB: Storage<T, D2>,
ShapeConstraint: DimEq<D, D2>,
{
assert_eq!(self.nrows(), x.nrows(), "Axcpy: mismatched vector shapes.");
let rstride1 = self.strides().0;
let rstride2 = x.strides().0;
unsafe {
// SAFETY: the conversion to slices is OK because we access the
// elements taking the strides into account.
let y = self.data.as_mut_slice_unchecked();
let x = x.data.as_slice_unchecked();
array_axc_uninit(y, a, x, c, rstride1, rstride2, x.len());
self.assume_init()
}
}
/// Computes `self = alpha * a * x, where `a` is a matrix, `x` a vector, and
/// `alpha` is a scalar.
///
/// By the time this method returns, `self` will have been initialized.
#[inline]
pub fn gemv_uninit<R2: Dim, C2: Dim, D3: Dim, SB, SC>(
mut self,
alpha: T,
a: &Matrix<T, R2, C2, SB>,
x: &Vector<T, D3, SC>,
beta: T,
) -> OVector<T, D>
where
T: One,
SB: Storage<T, R2, C2>,
SC: Storage<T, D3>,
ShapeConstraint: DimEq<D, R2> + AreMultipliable<R2, C2, D3, U1>,
{
let dim1 = self.nrows();
let (nrows2, ncols2) = a.shape();
let dim3 = x.nrows();
assert!(
ncols2 == dim3 && dim1 == nrows2,
"Gemv: dimensions mismatch."
);
if ncols2 == 0 {
self.fill_fn(|| MaybeUninit::new(T::zero()));
return self.assume_init();
}
// TODO: avoid bound checks.
let col2 = a.column(0);
let val = unsafe { x.vget_unchecked(0).inlined_clone() };
let res = self.axc(alpha.inlined_clone(), &col2, val);
for j in 1..ncols2 {
let col2 = a.column(j);
let val = unsafe { x.vget_unchecked(j).inlined_clone() };
res.axcpy(alpha.inlined_clone(), &col2, val, T::one());
}
res
}
}
impl<T, R1: Dim, C1: Dim, S: StorageMut<T, R1, C1>> Matrix<T, R1, C1, S>
where
T: Scalar + Zero + ClosedAdd + ClosedMul,
@ -1275,29 +1372,25 @@ where
///
/// mat.quadform_tr_with_workspace(&mut workspace, 10.0, &lhs, &mid, 5.0);
/// assert_relative_eq!(mat, expected);
pub fn quadform_tr_with_workspace<D2, S2, R3, C3, S3, D4, S4>(
pub fn quadform_tr_with_workspace<D2: Dim, R3: Dim, C3: Dim, S3, D4: Dim, S4>(
&mut self,
work: &mut Vector<T, D2, S2>,
work: &mut OVector<MaybeUninit<T>, D2>,
alpha: T,
lhs: &Matrix<T, R3, C3, S3>,
mid: &SquareMatrix<T, D4, S4>,
beta: T,
) where
D2: Dim,
R3: Dim,
C3: Dim,
D4: Dim,
S2: StorageMut<T, D2>,
S3: Storage<T, R3, C3>,
S4: Storage<T, D4, D4>,
ShapeConstraint: DimEq<D1, D2> + DimEq<D1, R3> + DimEq<D2, R3> + DimEq<C3, D4>,
DefaultAllocator: Allocator<T, D2>,
{
work.gemv(T::one(), lhs, &mid.column(0), T::zero());
self.ger(alpha.inlined_clone(), work, &lhs.column(0), beta);
let work = work.gemv_uninit(T::one(), lhs, &mid.column(0), T::zero());
self.ger(alpha.inlined_clone(), &work, &lhs.column(0), beta);
for j in 1..mid.ncols() {
work.gemv(T::one(), lhs, &mid.column(j), T::zero());
self.ger(alpha.inlined_clone(), work, &lhs.column(j), T::one());
self.ger(alpha.inlined_clone(), &work, &lhs.column(j), T::one());
}
}
@ -1322,24 +1415,19 @@ where
///
/// mat.quadform_tr(10.0, &lhs, &mid, 5.0);
/// assert_relative_eq!(mat, expected);
pub fn quadform_tr<R3, C3, S3, D4, S4>(
pub fn quadform_tr<R3: Dim, C3: Dim, S3, D4: Dim, S4>(
&mut self,
alpha: T,
lhs: &Matrix<T, R3, C3, S3>,
mid: &SquareMatrix<T, D4, S4>,
beta: T,
) where
R3: Dim,
C3: Dim,
D4: Dim,
S3: Storage<T, R3, C3>,
S4: Storage<T, D4, D4>,
ShapeConstraint: DimEq<D1, D1> + DimEq<D1, R3> + DimEq<C3, D4>,
DefaultAllocator: Allocator<T, D1>,
{
let mut work = unsafe {
crate::unimplemented_or_uninitialized_generic!(self.data.shape().0, Const::<1>)
};
let mut work = Matrix::new_uninitialized_generic(self.data.shape().0, Const::<1>);
self.quadform_tr_with_workspace(&mut work, alpha, lhs, mid, beta)
}
@ -1368,32 +1456,28 @@ where
///
/// mat.quadform_with_workspace(&mut workspace, 10.0, &mid, &rhs, 5.0);
/// assert_relative_eq!(mat, expected);
pub fn quadform_with_workspace<D2, S2, D3, S3, R4, C4, S4>(
pub fn quadform_with_workspace<D2: Dim, D3: Dim, S3, R4: Dim, C4: Dim, S4>(
&mut self,
work: &mut Vector<T, D2, S2>,
work: &mut OVector<MaybeUninit<T>, D2>,
alpha: T,
mid: &SquareMatrix<T, D3, S3>,
rhs: &Matrix<T, R4, C4, S4>,
beta: T,
) where
D2: Dim,
D3: Dim,
R4: Dim,
C4: Dim,
S2: StorageMut<T, D2>,
S3: Storage<T, D3, D3>,
S4: Storage<T, R4, C4>,
ShapeConstraint:
DimEq<D3, R4> + DimEq<D1, C4> + DimEq<D2, D3> + AreMultipliable<C4, R4, D2, U1>,
DefaultAllocator: Allocator<T, D2>,
{
work.gemv(T::one(), mid, &rhs.column(0), T::zero());
let work = work.gemv_uninit(T::one(), mid, &rhs.column(0), T::zero());
self.column_mut(0)
.gemv_tr(alpha.inlined_clone(), rhs, work, beta.inlined_clone());
.gemv_tr(alpha.inlined_clone(), rhs, &work, beta.inlined_clone());
for j in 1..rhs.ncols() {
work.gemv(T::one(), mid, &rhs.column(j), T::zero());
self.column_mut(j)
.gemv_tr(alpha.inlined_clone(), rhs, work, beta.inlined_clone());
.gemv_tr(alpha.inlined_clone(), rhs, &work, beta.inlined_clone());
}
}
@ -1417,24 +1501,19 @@ where
///
/// mat.quadform(10.0, &mid, &rhs, 5.0);
/// assert_relative_eq!(mat, expected);
pub fn quadform<D2, S2, R3, C3, S3>(
pub fn quadform<D2: Dim, S2, R3: Dim, C3: Dim, S3>(
&mut self,
alpha: T,
mid: &SquareMatrix<T, D2, S2>,
rhs: &Matrix<T, R3, C3, S3>,
beta: T,
) where
D2: Dim,
R3: Dim,
C3: Dim,
S2: Storage<T, D2, D2>,
S3: Storage<T, R3, C3>,
ShapeConstraint: DimEq<D2, R3> + DimEq<D1, C3> + AreMultipliable<C3, R3, D2, U1>,
DefaultAllocator: Allocator<T, D2>,
{
let mut work = unsafe {
crate::unimplemented_or_uninitialized_generic!(mid.data.shape().0, Const::<1>)
};
let mut work = Matrix::new_uninitialized_generic(mid.data.shape().0, Const::<1>);
self.quadform_with_workspace(&mut work, alpha, mid, rhs, beta)
}
}

View File

@ -18,7 +18,7 @@ use typenum::{self, Cmp, Greater};
use simba::scalar::{ClosedAdd, ClosedMul};
use crate::base::allocator::Allocator;
use crate::{base::allocator::Allocator};
use crate::base::dimension::{Dim, DimName, Dynamic, ToTypenum};
use crate::base::storage::Storage;
use crate::base::{
@ -117,7 +117,7 @@ where
/// Creates a matrix with its elements filled with the components provided by a slice. The
/// components must have the same layout as the matrix data storage (i.e. column-major).
#[inline]
pub fn from_column_slice_generic(nrows: R, ncols: C, slice: &[T]) -> Self {
pub fn from_column_slice_generic(nrows: R, ncols: C, slice: &[T]) -> Self where T:Clone{
Self::from_iterator_generic(nrows, ncols, slice.iter().cloned())
}
@ -139,7 +139,7 @@ where
}
// Safety: all entries have been initialized.
unsafe { Matrix::assume_init(res) }
unsafe { res.assume_init()}
}
/// Creates a new identity matrix.
@ -352,7 +352,7 @@ where
#[inline]
pub fn from_diagonal<SB: Storage<T, D>>(diag: &Vector<T, D, SB>) -> Self
where
T: Zero,
T: Zero+Scalar,
{
let (dim, _) = diag.data.shape();
let mut res = Self::zeros_generic(dim, dim);

View File

@ -158,12 +158,23 @@ impl<T: Scalar, R: Dim, C: Dim, S: StorageMut<T, R, C>> Matrix<T, R, C, S> {
}
/// # In-place filling
impl<T: Scalar, R: Dim, C: Dim, S: StorageMut<T, R, C>> Matrix<T, R, C, S> {
impl<T, R: Dim, C: Dim, S: StorageMut<T, R, C>> Matrix<T, R, C, S> {
/// Sets all the elements of this matrix to `val`.
#[inline]
pub fn fill(&mut self, val: T) {
pub fn fill(&mut self, val: T)
where
T: Clone,
{
for e in self.iter_mut() {
*e = val.inlined_clone()
*e = val.clone()
}
}
/// Sets all the elements of this matrix to `f()`.
#[inline]
pub fn fill_fn<F: FnMut() -> T>(&mut self, f: F) {
for e in self.iter_mut() {
*e = f();
}
}
@ -171,7 +182,7 @@ impl<T: Scalar, R: Dim, C: Dim, S: StorageMut<T, R, C>> Matrix<T, R, C, S> {
#[inline]
pub fn fill_with_identity(&mut self)
where
T: Zero + One,
T: Zero + One + Scalar,
{
self.fill(T::zero());
self.fill_diagonal(T::one());
@ -184,7 +195,7 @@ impl<T: Scalar, R: Dim, C: Dim, S: StorageMut<T, R, C>> Matrix<T, R, C, S> {
let n = cmp::min(nrows, ncols);
for i in 0..n {
unsafe { *self.get_unchecked_mut((i, i)) = val.inlined_clone() }
unsafe { *self.get_unchecked_mut((i, i)) = val.clone() }
}
}

View File

@ -657,7 +657,7 @@ impl<T, R: Dim, C: Dim, S: Storage<T, R, C>> Matrix<T, R, C, S> {
}
}
unsafe { Matrix::assume_init(res) }
unsafe { res.assume_init()}
}
/// Transposes `self` and store the result into `out`, which will become
@ -666,7 +666,7 @@ impl<T, R: Dim, C: Dim, S: Storage<T, R, C>> Matrix<T, R, C, S> {
pub fn transpose_to<R2: Dim, C2: Dim, SB>(&self, out: &mut Matrix<MaybeUninit<T>, R2, C2, SB>)
where
T: Clone,
SB: StorageMut<T, R2, C2>,
SB: StorageMut<MaybeUninit<T>, R2, C2>,
ShapeConstraint: SameNumberOfRows<R, C2> + SameNumberOfColumns<C, R2>,
{
let (nrows, ncols) = self.shape();

View File

@ -2,12 +2,12 @@ use std::marker::PhantomData;
use std::ops::{Range, RangeFrom, RangeFull, RangeInclusive, RangeTo};
use std::slice;
use crate::base::allocator::Allocator;
use crate::base::allocator::{Allocator, InnerAllocator};
use crate::base::default_allocator::DefaultAllocator;
use crate::base::dimension::{Const, Dim, DimName, Dynamic, IsNotStaticOne, U1};
use crate::base::iter::MatrixIter;
use crate::base::storage::{ContiguousStorage, ContiguousStorageMut, Owned, Storage, StorageMut};
use crate::base::{Matrix, Scalar};
use crate::base::Matrix;
macro_rules! slice_storage_impl(
($doc: expr; $Storage: ident as $SRef: ty; $T: ident.$get_addr: ident ($Ptr: ty as $Ref: ty)) => {

View File

@ -7,16 +7,17 @@ use std::ops::{
use simba::scalar::{ClosedAdd, ClosedDiv, ClosedMul, ClosedNeg, ClosedSub};
use crate::allocator::InnerAllocator;
use crate::base::allocator::{Allocator, SameShapeAllocator, SameShapeC, SameShapeR};
use crate::base::allocator::{
Allocator, InnerAllocator, SameShapeAllocator, SameShapeC, SameShapeR,
};
use crate::base::constraint::{
AreMultipliable, DimEq, SameNumberOfColumns, SameNumberOfRows, ShapeConstraint,
};
use crate::base::dimension::{Dim, DimMul, DimName, DimProd, Dynamic};
use crate::base::storage::{ContiguousStorageMut, Storage, StorageMut};
use crate::base::{DefaultAllocator, Matrix, MatrixSum, OMatrix, Scalar, VectorSlice};
use crate::SimdComplexField;
use crate::storage::Owned;
use crate::SimdComplexField;
/*
*
@ -431,7 +432,7 @@ where
// TODO: we should take out this trait bound, as T: Clone should suffice.
// The brute way to do it would be how it was already done: by adding this
// trait bound on the associated type itself.
Owned<T,Dynamic,C>: Clone,
Owned<T, Dynamic, C>: Clone,
{
/// # Example
/// ```
@ -575,11 +576,9 @@ where
#[inline]
fn mul(self, rhs: &'b Matrix<T, R2, C2, SB>) -> Self::Output {
let mut res = unsafe {
crate::unimplemented_or_uninitialized_generic!(self.data.shape().0, rhs.data.shape().1)
};
self.mul_to(rhs, &mut res);
res
let mut res =Matrix::new_uninitialized_generic(self.data.shape().0, rhs.data.shape().1);
self.mul_to(rhs, &mut res);
unsafe{ res.assume_init()}
}
}
@ -687,12 +686,9 @@ where
DefaultAllocator: Allocator<T, C1, C2>,
ShapeConstraint: SameNumberOfRows<R1, R2>,
{
let mut res = unsafe {
crate::unimplemented_or_uninitialized_generic!(self.data.shape().1, rhs.data.shape().1)
};
let mut res = Matrix::new_uninitialized_generic(self.data.shape().1, rhs.data.shape().1);
self.tr_mul_to(rhs, &mut res);
res
unsafe { res.assume_init() }
}
/// Equivalent to `self.adjoint() * rhs`.
@ -701,30 +697,27 @@ where
pub fn ad_mul<R2: Dim, C2: Dim, SB>(&self, rhs: &Matrix<T, R2, C2, SB>) -> OMatrix<T, C1, C2>
where
T: SimdComplexField,
SB: Storage<T, R2, C2>,
SB: Storage<MaybeUninit<T>, R2, C2>,
DefaultAllocator: Allocator<T, C1, C2>,
ShapeConstraint: SameNumberOfRows<R1, R2>,
{
let mut res = unsafe {
crate::unimplemented_or_uninitialized_generic!(self.data.shape().1, rhs.data.shape().1)
};
let mut res = Matrix::new_uninitialized_generic(self.data.shape().1, rhs.data.shape().1);
self.ad_mul_to(rhs, &mut res);
res
unsafe { res.assume_init() }
}
#[inline(always)]
fn xx_mul_to<R2: Dim, C2: Dim, SB, R3: Dim, C3: Dim, SC>(
&self,
rhs: &Matrix<T, R2, C2, SB>,
out: &mut Matrix<T, R3, C3, SC>,
out: &mut Matrix<MaybeUninit<T>, R3, C3, SC>,
dot: impl Fn(
&VectorSlice<T, R1, SA::RStride, SA::CStride>,
&VectorSlice<T, R2, SB::RStride, SB::CStride>,
) -> T,
) where
SB: Storage<T, R2, C2>,
SC: StorageMut<T, R3, C3>,
SC: StorageMut<MaybeUninit<T>, R3, C3>,
ShapeConstraint: SameNumberOfRows<R1, R2> + DimEq<C1, R3> + DimEq<C2, C3>,
{
let (nrows1, ncols1) = self.shape();
@ -753,7 +746,7 @@ where
for i in 0..ncols1 {
for j in 0..ncols2 {
let dot = dot(&self.column(i), &rhs.column(j));
unsafe { *out.get_unchecked_mut((i, j)) = dot };
unsafe { *out.get_unchecked_mut((i, j)) = MaybeUninit::new(dot) ;}
}
}
}
@ -764,10 +757,10 @@ where
pub fn tr_mul_to<R2: Dim, C2: Dim, SB, R3: Dim, C3: Dim, SC>(
&self,
rhs: &Matrix<T, R2, C2, SB>,
out: &mut Matrix<T, R3, C3, SC>,
out: &mut Matrix<MaybeUninit<T>, R3, C3, SC>,
) where
SB: Storage<T, R2, C2>,
SC: StorageMut<T, R3, C3>,
SC: StorageMut<MaybeUninit<T>, R3, C3>,
ShapeConstraint: SameNumberOfRows<R1, R2> + DimEq<C1, R3> + DimEq<C2, C3>,
{
self.xx_mul_to(rhs, out, |a, b| a.dot(b))
@ -779,11 +772,11 @@ where
pub fn ad_mul_to<R2: Dim, C2: Dim, SB, R3: Dim, C3: Dim, SC>(
&self,
rhs: &Matrix<T, R2, C2, SB>,
out: &mut Matrix<T, R3, C3, SC>,
out: &mut Matrix<MaybeUninit<T>, R3, C3, SC>,
) where
T: SimdComplexField,
SB: Storage<T, R2, C2>,
SC: StorageMut<T, R3, C3>,
SC: StorageMut<MaybeUninit<T>, R3, C3>,
ShapeConstraint: SameNumberOfRows<R1, R2> + DimEq<C1, R3> + DimEq<C2, C3>,
{
self.xx_mul_to(rhs, out, |a, b| a.dotc(b))
@ -793,7 +786,7 @@ where
#[inline]
pub fn mul_to<R2: Dim, C2: Dim, SB, R3: Dim, C3: Dim, SC>(
&self,
rhs: &Matrix<T, R2, C2, SB>,
rhs: &Matrix<MaybeUninit<T>, R2, C2, SB>,
out: &mut Matrix<T, R3, C3, SC>,
) where
SB: Storage<T, R2, C2>,