Tied some blas loose strings

This commit is contained in:
Violeta Hernández 2021-07-16 00:27:16 -05:00
parent df9b6f5f64
commit 54e9750191
6 changed files with 48 additions and 47 deletions

View File

@ -6,7 +6,7 @@
//! that return an owned matrix that would otherwise result from setting a
//! parameter to zero in the other methods.
use crate::{OMatrix, OVector, SimdComplexField};
use crate::{OMatrix, SimdComplexField};
#[cfg(feature = "std")]
use matrixmultiply;
use num::{One, Zero};
@ -795,7 +795,7 @@ where
}
}
impl<T, R1: Dim, C1: Dim> OMatrix<T, R1, C1>
impl<T, R1: Dim, C1: Dim, S: StorageMut<MaybeUninit<T>, R1, C1>> Matrix<MaybeUninit<T>, R1, C1, S>
where
T: Scalar + Zero + One + ClosedAdd + ClosedMul,
DefaultAllocator: Allocator<T, R1, C1>,
@ -821,27 +821,18 @@ where
/// ```
#[inline]
pub fn gemm_z<R2: Dim, C2: Dim, R3: Dim, C3: Dim, SB, SC>(
&mut self,
alpha: T,
a: &Matrix<T, R2, C2, SB>,
b: &Matrix<T, R3, C3, SC>,
) -> Self
where
) where
SB: Storage<T, R2, C2>,
SC: Storage<T, R3, C3>,
ShapeConstraint: SameNumberOfRows<R1, R2>
+ SameNumberOfColumns<C1, C3>
+ AreMultipliable<R2, C2, R3, C3>,
{
let (nrows1, ncols1) = a.shape();
let (nrows2, ncols2) = b.shape();
assert_eq!(
ncols1, nrows2,
"gemm: dimensions mismatch for multiplication."
);
let mut res =
Matrix::new_uninitialized_generic(R1::from_usize(nrows1), C1::from_usize(ncols2));
let ncols1 = self.ncols();
#[cfg(feature = "std")]
{
@ -857,6 +848,9 @@ where
|| C3::is::<Dynamic>()
{
// matrixmultiply can be used only if the std feature is available.
let nrows1 = self.nrows();
let (nrows2, ncols2) = a.shape();
let (nrows3, ncols3) = b.shape();
// Threshold determined empirically.
const SMALL_DIM: usize = 5;
@ -866,29 +860,35 @@ where
&& nrows2 > SMALL_DIM
&& ncols2 > SMALL_DIM
{
assert_eq!(
ncols1, nrows2,
"gemm: dimensions mismatch for multiplication."
);
assert_eq!(
(nrows1, ncols1),
(nrows2, ncols3),
"gemm: dimensions mismatch for addition."
);
// NOTE: this case should never happen because we enter this
// codepath only when ncols2 > SMALL_DIM. We keep this check
// anyway in case the conditions for entering this codepath
// change in the future.
if ncols1 == 0 {
// NOTE: we can't just always multiply by beta
// because we documented the guarantee that `self` is
// never read if `beta` is zero.
// Safety: this buffer is empty.
return res.assume_init();
self.fill_fn(|| MaybeUninit::new(T::zero()));
return;
}
let (rsa, csa) = a.strides();
let (rsb, csb) = b.strides();
let (rsc, csc) = res.strides();
let (rsc, csc) = self.strides();
if T::is::<f32>() {
unsafe {
matrixmultiply::sgemm(
nrows1,
ncols1,
nrows2,
ncols2,
ncols3,
mem::transmute_copy(&alpha),
a.data.ptr() as *const f32,
rsa as isize,
@ -897,19 +897,19 @@ where
rsb as isize,
csb as isize,
0.0,
res.data.ptr_mut() as *mut f32,
self.data.ptr_mut() as *mut f32,
rsc as isize,
csc as isize,
);
return res.assume_init();
return;
}
} else if T::is::<f64>() {
unsafe {
matrixmultiply::dgemm(
nrows1,
ncols1,
nrows2,
ncols2,
ncols3,
mem::transmute_copy(&alpha),
a.data.ptr() as *const f64,
rsa as isize,
@ -918,12 +918,12 @@ where
rsb as isize,
csb as isize,
0.0,
res.data.ptr_mut() as *mut f64,
self.data.ptr_mut() as *mut f64,
rsc as isize,
csc as isize,
);
return res.assume_init();
return ;
}
}
}
@ -932,11 +932,9 @@ where
for j1 in 0..ncols1 {
// TODO: avoid bound checks.
res.column_mut(j1)
self.column_mut(j1)
.gemv_z(alpha.inlined_clone(), a, &b.column(j1));
}
unsafe { res.assume_init() }
}
}

View File

@ -633,13 +633,13 @@ where
); // Arguments for non-generic constructors.
}
impl<T, R: DimName, C: DimName> OMatrix<MaybeUninit<T>, R, C>
impl<T, R: DimName, C: DimName> OMatrix<T, R, C>
where
DefaultAllocator: Allocator<T, R, C>,
{
/// Creates a new uninitialized matrix or vector.
#[inline]
pub fn new_uninitialized() -> Self {
pub fn new_uninitialized() -> OMatrix<MaybeUninit<T>, R, C> {
Self::new_uninitialized_generic(R::name(), C::name())
}
}
@ -655,13 +655,13 @@ where
ncols);
}
impl<T, R: DimName> OMatrix<MaybeUninit<T>, R, Dynamic>
impl<T, R: DimName> OMatrix<T, R, Dynamic>
where
DefaultAllocator: Allocator<T, R, Dynamic>,
{
/// Creates a new uninitialized matrix or vector.
#[inline]
pub fn new_uninitialized(ncols: usize) -> Self {
pub fn new_uninitialized(ncols: usize) -> OMatrix<MaybeUninit<T>, R, Dynamic> {
Self::new_uninitialized_generic(R::name(), Dynamic::new(ncols))
}
}
@ -677,13 +677,13 @@ where
nrows);
}
impl<T, C: DimName> OMatrix<MaybeUninit<T>, Dynamic, C>
impl<T, C: DimName> OMatrix<T, Dynamic, C>
where
DefaultAllocator: Allocator<T, Dynamic, C>,
{
/// Creates a new uninitialized matrix or vector.
#[inline]
pub fn new_uninitialized(nrows: usize) -> Self {
pub fn new_uninitialized(nrows: usize) -> OMatrix<MaybeUninit<T>, Dynamic, C> {
Self::new_uninitialized_generic(Dynamic::new(nrows), C::name())
}
}
@ -699,13 +699,13 @@ where
nrows, ncols);
}
impl<T> OMatrix<MaybeUninit<T>, Dynamic, Dynamic>
impl<T> OMatrix<T, Dynamic, Dynamic>
where
DefaultAllocator: Allocator<T, Dynamic, Dynamic>,
{
/// Creates a new uninitialized matrix or vector.
#[inline]
pub fn new_uninitialized(nrows: usize, ncols: usize) -> Self {
pub fn new_uninitialized(nrows: usize, ncols: usize) -> OMatrix<MaybeUninit<T>, Dynamic, Dynamic> {
Self::new_uninitialized_generic(Dynamic::new(nrows), Dynamic::new(ncols))
}
}

View File

@ -4,7 +4,6 @@
//! heap-allocated buffers for matrices with at least one dimension unknown at compile-time.
use std::cmp;
use std::mem;
use std::mem::ManuallyDrop;
use std::mem::MaybeUninit;
use std::ptr;

View File

@ -53,7 +53,8 @@ impl<T: Scalar + Zero, R: Dim, C: Dim, S: Storage<T, R, C>> Matrix<T, R, C, S> {
{
let irows = irows.into_iter();
let ncols = self.data.shape().1;
let mut res = OMatrix::<T, Dynamic, C>::new_uninitialized_generic(Dynamic::new(irows.len()), ncols);
let mut res =
OMatrix::<T, Dynamic, C>::new_uninitialized_generic(Dynamic::new(irows.len()), ncols);
// First, check that all the indices from irows are valid.
// This will allow us to use unchecked access in the inner loop.

View File

@ -223,7 +223,7 @@ impl<'a, T, R: Dim, C: Dim, RStride: Dim, CStride: Dim>
SliceStorage<'a, MaybeUninit<T>, R, C, RStride, CStride>
{
pub unsafe fn assume_init(self) -> SliceStorage<'a, T, R, C, RStride, CStride> {
Self::from_raw_parts(self.ptr as *const T, self.shape, self.strides)
SliceStorage::from_raw_parts(self.ptr as *const T, self.shape, self.strides)
}
}
@ -231,7 +231,7 @@ impl<'a, T, R: Dim, C: Dim, RStride: Dim, CStride: Dim>
SliceStorageMut<'a, MaybeUninit<T>, R, C, RStride, CStride>
{
pub unsafe fn assume_init(self) -> SliceStorageMut<'a, T, R, C, RStride, CStride> {
Self::from_raw_parts(self.ptr as *mut T, self.shape, self.strides)
SliceStorageMut::from_raw_parts(self.ptr as *mut T, self.shape, self.strides)
}
}
@ -606,7 +606,7 @@ macro_rules! matrix_slice_impl(
/// Returns a slice containing the entire matrix.
pub fn $full_slice($me: $Me) -> $MatrixSlice<T, R, C, S::RStride, S::CStride> {
let (nrows, ncols) = $me.shape();
$me.generic_slice((0, 0), (R::from_usize(nrows), C::from_usize(ncols)))
$me.$generic_slice((0, 0), (R::from_usize(nrows), C::from_usize(ncols)))
}
/*

View File

@ -640,7 +640,8 @@ where
// TODO: this is too restrictive:
// we can't use `a *= b` when `a` is a mutable slice.
// we can't use `a *= b` when C2 is not equal to C1.
impl<T, R1: Dim, C1: Dim, R2: Dim, SA, SB> MulAssign<Matrix<T, R2, C1, SB>> for Matrix<T, R1, C1, SA>
impl<T, R1: Dim, C1: Dim, R2: Dim, SA, SB> MulAssign<Matrix<T, R2, C1, SB>>
for Matrix<T, R1, C1, SA>
where
T: Scalar + Zero + One + ClosedAdd + ClosedMul,
SB: Storage<T, R2, C1>,
@ -654,7 +655,8 @@ where
}
}
impl<'b, T, R1: Dim, C1: Dim, R2: Dim, SA, SB> MulAssign<&'b Matrix<T, R2, C1, SB>> for Matrix<T, R1, C1, SA>
impl<'b, T, R1: Dim, C1: Dim, R2: Dim, SA, SB> MulAssign<&'b Matrix<T, R2, C1, SB>>
for Matrix<T, R1, C1, SA>
where
T: Scalar + Zero + One + ClosedAdd + ClosedMul,
SB: Storage<T, R2, C1>,
@ -794,6 +796,7 @@ where
ShapeConstraint: SameNumberOfRows<R3, R1>
+ SameNumberOfColumns<C3, C2>
+ AreMultipliable<R1, C1, R2, C2>,
DefaultAllocator: Allocator<T, R3, C3>,
{
out.gemm_z(T::one(), self, rhs);
}