Tied some blas loose strings

Violeta Hernández 2021-07-16 00:27:16 -05:00
parent df9b6f5f64
commit 54e9750191
6 changed files with 48 additions and 47 deletions

View File

@@ -6,7 +6,7 @@
 //! that return an owned matrix that would otherwise result from setting a
 //! parameter to zero in the other methods.
-use crate::{OMatrix, OVector, SimdComplexField};
+use crate::{OMatrix, SimdComplexField};
 #[cfg(feature = "std")]
 use matrixmultiply;
 use num::{One, Zero};
@@ -795,7 +795,7 @@ where
     }
 }
-impl<T, R1: Dim, C1: Dim> OMatrix<T, R1, C1>
+impl<T, R1: Dim, C1: Dim, S: StorageMut<MaybeUninit<T>, R1, C1>> Matrix<MaybeUninit<T>, R1, C1, S>
 where
     T: Scalar + Zero + One + ClosedAdd + ClosedMul,
     DefaultAllocator: Allocator<T, R1, C1>,
@@ -821,27 +821,18 @@ where
     /// ```
     #[inline]
     pub fn gemm_z<R2: Dim, C2: Dim, R3: Dim, C3: Dim, SB, SC>(
+        &mut self,
         alpha: T,
         a: &Matrix<T, R2, C2, SB>,
         b: &Matrix<T, R3, C3, SC>,
-    ) -> Self
-    where
+    ) where
         SB: Storage<T, R2, C2>,
         SC: Storage<T, R3, C3>,
         ShapeConstraint: SameNumberOfRows<R1, R2>
             + SameNumberOfColumns<C1, C3>
             + AreMultipliable<R2, C2, R3, C3>,
     {
-        let (nrows1, ncols1) = a.shape();
-        let (nrows2, ncols2) = b.shape();
-        assert_eq!(
-            ncols1, nrows2,
-            "gemm: dimensions mismatch for multiplication."
-        );
-        let mut res =
-            Matrix::new_uninitialized_generic(R1::from_usize(nrows1), C1::from_usize(ncols2));
+        let ncols1 = self.ncols();
         #[cfg(feature = "std")]
         {
@@ -857,6 +848,9 @@ where
                 || C3::is::<Dynamic>()
             {
                 // matrixmultiply can be used only if the std feature is available.
+                let nrows1 = self.nrows();
+                let (nrows2, ncols2) = a.shape();
+                let (nrows3, ncols3) = b.shape();
                 // Threshold determined empirically.
                 const SMALL_DIM: usize = 5;
@@ -866,29 +860,35 @@ where
                     && nrows2 > SMALL_DIM
                     && ncols2 > SMALL_DIM
                 {
+                    assert_eq!(
+                        ncols1, nrows2,
+                        "gemm: dimensions mismatch for multiplication."
+                    );
+                    assert_eq!(
+                        (nrows1, ncols1),
+                        (nrows2, ncols3),
+                        "gemm: dimensions mismatch for addition."
+                    );
                     // NOTE: this case should never happen because we enter this
                     // codepath only when ncols2 > SMALL_DIM. Though we keep this
                     // here just in case if in the future we change the conditions to
                     // enter this codepath.
                     if ncols1 == 0 {
-                        // NOTE: we can't just always multiply by beta
-                        // because we documented the guaranty that `self` is
-                        // never read if `beta` is zero.
-                        // Safety: this buffer is empty.
-                        return res.assume_init();
+                        self.fill_fn(|| MaybeUninit::new(T::zero()));
+                        return;
                     }
                     let (rsa, csa) = a.strides();
                     let (rsb, csb) = b.strides();
-                    let (rsc, csc) = res.strides();
+                    let (rsc, csc) = self.strides();
                     if T::is::<f32>() {
                         unsafe {
                             matrixmultiply::sgemm(
-                                nrows1,
-                                ncols1,
+                                nrows2,
                                 ncols2,
+                                ncols3,
                                 mem::transmute_copy(&alpha),
                                 a.data.ptr() as *const f32,
                                 rsa as isize,
@@ -897,19 +897,19 @@ where
                                 rsb as isize,
                                 csb as isize,
                                 0.0,
-                                res.data.ptr_mut() as *mut f32,
+                                self.data.ptr_mut() as *mut f32,
                                 rsc as isize,
                                 csc as isize,
                             );
-                            return res.assume_init();
+                            return;
                         }
                     } else if T::is::<f64>() {
                         unsafe {
                             matrixmultiply::dgemm(
-                                nrows1,
-                                ncols1,
+                                nrows2,
                                 ncols2,
+                                ncols3,
                                 mem::transmute_copy(&alpha),
                                 a.data.ptr() as *const f64,
                                 rsa as isize,
@@ -918,12 +918,12 @@ where
                                 rsb as isize,
                                 csb as isize,
                                 0.0,
-                                res.data.ptr_mut() as *mut f64,
+                                self.data.ptr_mut() as *mut f64,
                                 rsc as isize,
                                 csc as isize,
                             );
-                            return res.assume_init();
+                            return ;
                         }
                     }
                 }
@@ -932,11 +932,9 @@ where
         for j1 in 0..ncols1 {
             // TODO: avoid bound checks.
-            res.column_mut(j1)
+            self.column_mut(j1)
                 .gemv_z(alpha.inlined_clone(), a, &b.column(j1));
         }
-        unsafe { res.assume_init() }
     }
 }
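
The core change above inverts `gemm_z`'s ownership of the result: instead of allocating and returning an owned matrix, it now takes `&mut self` on a `MaybeUninit<T>` matrix, fills it, and leaves the caller to assert initialization. Below is a standalone sketch of that discipline using only `std`; the flat row-major slices and the `gemm_z_sketch` name are hypothetical stand-ins for nalgebra's storage, not this commit's code.

    use std::mem::MaybeUninit;

    // Write alpha * a * b into `out` (row-major nrows x ncols), filling every slot.
    fn gemm_z_sketch(
        alpha: f64,
        a: &[f64],
        b: &[f64],
        nrows: usize,
        k: usize,
        ncols: usize,
        out: &mut [MaybeUninit<f64>],
    ) {
        assert_eq!(a.len(), nrows * k, "lhs has the wrong number of elements");
        assert_eq!(b.len(), k * ncols, "rhs has the wrong number of elements");
        assert_eq!(out.len(), nrows * ncols, "output has the wrong number of elements");
        for i in 0..nrows {
            for j in 0..ncols {
                let mut acc = 0.0;
                for p in 0..k {
                    acc += a[i * k + p] * b[p * ncols + j];
                }
                // Every slot is written exactly once, even when k == 0,
                // mirroring the `ncols1 == 0` zero-fill in the commit.
                out[i * ncols + j] = MaybeUninit::new(alpha * acc);
            }
        }
    }

    fn main() {
        let a = [1.0, 2.0, 3.0, 4.0]; // 2x2, row-major
        let b = [5.0, 6.0, 7.0, 8.0]; // 2x2, row-major
        let mut out = vec![MaybeUninit::<f64>::uninit(); 4];
        gemm_z_sketch(1.0, &a, &b, 2, 2, 2, &mut out);
        // Sound only because gemm_z_sketch wrote every element above.
        let out: Vec<f64> = out.into_iter().map(|x| unsafe { x.assume_init() }).collect();
        assert_eq!(out, vec![19.0, 22.0, 43.0, 50.0]);
    }

The point mirrored from the diff is that every element of the output must be written on every path before anyone may assume it is initialized, which is why the `ncols1 == 0` case now fills the buffer with zeros instead of returning early.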

View File

@@ -633,13 +633,13 @@ where
 ); // Arguments for non-generic constructors.
 }
-impl<T, R: DimName, C: DimName> OMatrix<MaybeUninit<T>, R, C>
+impl<T, R: DimName, C: DimName> OMatrix<T, R, C>
 where
     DefaultAllocator: Allocator<T, R, C>,
 {
     /// Creates a new uninitialized matrix or vector.
     #[inline]
-    pub fn new_uninitialized() -> Self {
+    pub fn new_uninitialized() -> OMatrix<MaybeUninit<T>, R, C> {
         Self::new_uninitialized_generic(R::name(), C::name())
     }
 }
@@ -655,13 +655,13 @@ where
 ncols);
 }
-impl<T, R: DimName> OMatrix<MaybeUninit<T>, R, Dynamic>
+impl<T, R: DimName> OMatrix<T, R, Dynamic>
 where
     DefaultAllocator: Allocator<T, R, Dynamic>,
 {
     /// Creates a new uninitialized matrix or vector.
     #[inline]
-    pub fn new_uninitialized(ncols: usize) -> Self {
+    pub fn new_uninitialized(ncols: usize) -> OMatrix<MaybeUninit<T>, R, Dynamic> {
         Self::new_uninitialized_generic(R::name(), Dynamic::new(ncols))
     }
 }
@@ -677,13 +677,13 @@ where
 nrows);
 }
-impl<T, C: DimName> OMatrix<MaybeUninit<T>, Dynamic, C>
+impl<T, C: DimName> OMatrix<T, Dynamic, C>
 where
     DefaultAllocator: Allocator<T, Dynamic, C>,
 {
     /// Creates a new uninitialized matrix or vector.
     #[inline]
-    pub fn new_uninitialized(nrows: usize) -> Self {
+    pub fn new_uninitialized(nrows: usize) -> OMatrix<MaybeUninit<T>, Dynamic, C> {
         Self::new_uninitialized_generic(Dynamic::new(nrows), C::name())
     }
 }
@@ -699,13 +699,13 @@ where
 nrows, ncols);
 }
-impl<T> OMatrix<MaybeUninit<T>, Dynamic, Dynamic>
+impl<T> OMatrix<T, Dynamic, Dynamic>
 where
     DefaultAllocator: Allocator<T, Dynamic, Dynamic>,
 {
     /// Creates a new uninitialized matrix or vector.
     #[inline]
-    pub fn new_uninitialized(nrows: usize, ncols: usize) -> Self {
+    pub fn new_uninitialized(nrows: usize, ncols: usize) -> OMatrix<MaybeUninit<T>, Dynamic, Dynamic> {
         Self::new_uninitialized_generic(Dynamic::new(nrows), Dynamic::new(ncols))
     }
 }
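
These four hunks move `new_uninitialized` from the `OMatrix<MaybeUninit<T>, …>` impls onto the initialized `OMatrix<T, …>` impls, returning the `MaybeUninit`-wrapped matrix instead. A toy sketch of that API shape follows; `Grid` is a hypothetical stand-in, not a nalgebra type.

    use std::mem::MaybeUninit;

    // Stand-in for OMatrix: an owned, rectangular buffer.
    struct Grid<T> {
        data: Vec<T>,
        nrows: usize,
        ncols: usize,
    }

    impl<T> Grid<T> {
        // Mirrors the new shape of `new_uninitialized`: defined on the
        // initialized type, returning the MaybeUninit-wrapped counterpart.
        fn new_uninitialized(nrows: usize, ncols: usize) -> Grid<MaybeUninit<T>> {
            let mut data = Vec::with_capacity(nrows * ncols);
            // Sound for MaybeUninit elements: they carry no initialization invariant.
            unsafe { data.set_len(nrows * ncols) };
            Grid { data, nrows, ncols }
        }
    }

    fn main() {
        // The caller names the element type it ultimately wants (f32),
        // not MaybeUninit<f32>.
        let g: Grid<MaybeUninit<f32>> = Grid::<f32>::new_uninitialized(2, 3);
        assert_eq!((g.nrows, g.ncols), (2, 3));
    }

Hanging the constructor off the initialized type lets call sites name the scalar they ultimately want instead of spelling out `MaybeUninit<T>` in the path.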

View File

@@ -4,7 +4,6 @@
 //! heap-allocated buffers for matrices with at least one dimension unknown at compile-time.
 use std::cmp;
-use std::mem;
 use std::mem::ManuallyDrop;
 use std::mem::MaybeUninit;
 use std::ptr;

View File

@@ -53,7 +53,8 @@ impl<T: Scalar + Zero, R: Dim, C: Dim, S: Storage<T, R, C>> Matrix<T, R, C, S> {
     {
         let irows = irows.into_iter();
         let ncols = self.data.shape().1;
-        let mut res = OMatrix::<T, Dynamic, C>::new_uninitialized_generic(Dynamic::new(irows.len()), ncols);
+        let mut res =
+            OMatrix::<T, Dynamic, C>::new_uninitialized_generic(Dynamic::new(irows.len()), ncols);
         // First, check that all the indices from irows are valid.
         // This will allow us to use unchecked access in the inner loop.

View File

@@ -223,7 +223,7 @@ impl<'a, T, R: Dim, C: Dim, RStride: Dim, CStride: Dim>
     SliceStorage<'a, MaybeUninit<T>, R, C, RStride, CStride>
 {
     pub unsafe fn assume_init(self) -> SliceStorage<'a, T, R, C, RStride, CStride> {
-        Self::from_raw_parts(self.ptr as *const T, self.shape, self.strides)
+        SliceStorage::from_raw_parts(self.ptr as *const T, self.shape, self.strides)
     }
 }
@@ -231,7 +231,7 @@ impl<'a, T, R: Dim, C: Dim, RStride: Dim, CStride: Dim>
     SliceStorageMut<'a, MaybeUninit<T>, R, C, RStride, CStride>
 {
     pub unsafe fn assume_init(self) -> SliceStorageMut<'a, T, R, C, RStride, CStride> {
-        Self::from_raw_parts(self.ptr as *mut T, self.shape, self.strides)
+        SliceStorageMut::from_raw_parts(self.ptr as *mut T, self.shape, self.strides)
     }
 }
@@ -606,7 +606,7 @@ macro_rules! matrix_slice_impl(
         /// Returns a slice containing the entire matrix.
         pub fn $full_slice($me: $Me) -> $MatrixSlice<T, R, C, S::RStride, S::CStride> {
             let (nrows, ncols) = $me.shape();
-            $me.generic_slice((0, 0), (R::from_usize(nrows), C::from_usize(ncols)))
+            $me.$generic_slice((0, 0), (R::from_usize(nrows), C::from_usize(ncols)))
         }
         /*
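
The `assume_init` fixes above swap `Self::from_raw_parts` for the explicitly named `SliceStorage`/`SliceStorageMut` constructors; the likely reason is that `Self` still denotes the storage over `MaybeUninit<T>`, whose constructor would expect a `MaybeUninit<T>` pointer rather than the initialized element type being returned. A toy illustration of that type-inference point follows; `RawSlice` is hypothetical and not a nalgebra type.

    use std::marker::PhantomData;
    use std::mem::MaybeUninit;

    // Toy stand-in for SliceStorage: a raw pointer plus a length.
    struct RawSlice<'a, T> {
        ptr: *const T,
        len: usize,
        _marker: PhantomData<&'a T>,
    }

    impl<'a, T> RawSlice<'a, T> {
        fn from_raw_parts(ptr: *const T, len: usize) -> Self {
            Self { ptr, len, _marker: PhantomData }
        }
    }

    impl<'a, T> RawSlice<'a, MaybeUninit<T>> {
        unsafe fn assume_init(self) -> RawSlice<'a, T> {
            // Here `Self` means RawSlice<'a, MaybeUninit<T>>, so
            // `Self::from_raw_parts` would demand a *const MaybeUninit<T>.
            // Naming the bare type lets the element type come from the
            // return type instead, matching the fix in this commit.
            RawSlice::from_raw_parts(self.ptr as *const T, self.len)
        }
    }

    fn main() {
        let data = [MaybeUninit::new(1.0_f64), MaybeUninit::new(2.0)];
        let raw = RawSlice::from_raw_parts(data.as_ptr(), data.len());
        let init = unsafe { raw.assume_init() };
        assert_eq!(unsafe { *init.ptr }, 1.0);
        assert_eq!(init.len, 2);
    }

The last hunk is the same kind of fix at the macro level: calling the `$generic_slice` parameter rather than a hard-coded `generic_slice` should let the mutable expansion of `matrix_slice_impl!` dispatch to its own `_mut` variant.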

View File

@@ -640,7 +640,8 @@ where
 // TODO: this is too restrictive:
 // we can't use `a *= b` when `a` is a mutable slice.
 // we can't use `a *= b` when C2 is not equal to C1.
-impl<T, R1: Dim, C1: Dim, R2: Dim, SA, SB> MulAssign<Matrix<T, R2, C1, SB>> for Matrix<T, R1, C1, SA>
+impl<T, R1: Dim, C1: Dim, R2: Dim, SA, SB> MulAssign<Matrix<T, R2, C1, SB>>
+    for Matrix<T, R1, C1, SA>
 where
     T: Scalar + Zero + One + ClosedAdd + ClosedMul,
     SB: Storage<T, R2, C1>,
@@ -654,7 +655,8 @@ where
     }
 }
-impl<'b, T, R1: Dim, C1: Dim, R2: Dim, SA, SB> MulAssign<&'b Matrix<T, R2, C1, SB>> for Matrix<T, R1, C1, SA>
+impl<'b, T, R1: Dim, C1: Dim, R2: Dim, SA, SB> MulAssign<&'b Matrix<T, R2, C1, SB>>
+    for Matrix<T, R1, C1, SA>
 where
     T: Scalar + Zero + One + ClosedAdd + ClosedMul,
     SB: Storage<T, R2, C1>,
@@ -794,6 +796,7 @@ where
         ShapeConstraint: SameNumberOfRows<R3, R1>
             + SameNumberOfColumns<C3, C2>
             + AreMultipliable<R1, C1, R2, C2>,
+        DefaultAllocator: Allocator<T, R3, C3>,
     {
         out.gemm_z(T::one(), self, rhs);
     }
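
The only substantive change in this last file is the extra `DefaultAllocator: Allocator<T, R3, C3>` bound on the method that calls `out.gemm_z(...)`: since `gemm_z` is now only implemented where the output matrix's allocator bound holds (see the first file), a generic caller has to restate that bound. A minimal sketch of this propagation with stand-in traits; none of the names below are nalgebra's real items.

    // Toy traits mirroring only the *shape* of the bound, not nalgebra's
    // real Allocator machinery.
    trait Allocator<T> {}
    struct DefaultAllocator;
    impl Allocator<f64> for DefaultAllocator {}

    struct Out<T>(Vec<T>);

    impl<T> Out<T> {
        // Stand-in for gemm_z: only callable where the allocator bound holds.
        fn gemm_z(&mut self, _alpha: T)
        where
            DefaultAllocator: Allocator<T>,
        {
        }
    }

    // Stand-in for the caller: being generic, it must restate the callee's
    // bound, which is what the added where-clause line does in the diff.
    fn mul_to<T>(alpha: T, out: &mut Out<T>)
    where
        DefaultAllocator: Allocator<T>,
    {
        out.gemm_z(alpha);
    }

    fn main() {
        let mut out = Out(Vec::<f64>::new());
        mul_to(1.0, &mut out);
    }

Dropping the caller-side bound would make the `out.gemm_z(...)` call fail to resolve, which is presumably what motivated the addition.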