From 54e9750191aec7f0a2dfca9444454aece0cc7e07 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Violeta=20Hern=C3=A1ndez?= Date: Fri, 16 Jul 2021 00:27:16 -0500 Subject: [PATCH] Tied some blas loose strings --- src/base/blas.rs | 62 +++++++++++++++++------------------ src/base/construction.rs | 16 ++++----- src/base/default_allocator.rs | 1 - src/base/edition.rs | 3 +- src/base/matrix_slice.rs | 6 ++-- src/base/ops.rs | 7 ++-- 6 files changed, 48 insertions(+), 47 deletions(-) diff --git a/src/base/blas.rs b/src/base/blas.rs index 2ef0dff7..57d93c87 100644 --- a/src/base/blas.rs +++ b/src/base/blas.rs @@ -6,7 +6,7 @@ //! that return an owned matrix that would otherwise result from setting a //! parameter to zero in the other methods. -use crate::{OMatrix, OVector, SimdComplexField}; +use crate::{OMatrix, SimdComplexField}; #[cfg(feature = "std")] use matrixmultiply; use num::{One, Zero}; @@ -795,7 +795,7 @@ where } } -impl OMatrix +impl, R1, C1>> Matrix, R1, C1, S> where T: Scalar + Zero + One + ClosedAdd + ClosedMul, DefaultAllocator: Allocator, @@ -821,27 +821,18 @@ where /// ``` #[inline] pub fn gemm_z( + &mut self, alpha: T, a: &Matrix, b: &Matrix, - ) -> Self - where + ) where SB: Storage, SC: Storage, ShapeConstraint: SameNumberOfRows + SameNumberOfColumns + AreMultipliable, { - let (nrows1, ncols1) = a.shape(); - let (nrows2, ncols2) = b.shape(); - - assert_eq!( - ncols1, nrows2, - "gemm: dimensions mismatch for multiplication." - ); - - let mut res = - Matrix::new_uninitialized_generic(R1::from_usize(nrows1), C1::from_usize(ncols2)); + let ncols1 = self.ncols(); #[cfg(feature = "std")] { @@ -857,6 +848,9 @@ where || C3::is::() { // matrixmultiply can be used only if the std feature is available. + let nrows1 = self.nrows(); + let (nrows2, ncols2) = a.shape(); + let (nrows3, ncols3) = b.shape(); // Threshold determined empirically. const SMALL_DIM: usize = 5; @@ -866,29 +860,35 @@ where && nrows2 > SMALL_DIM && ncols2 > SMALL_DIM { + assert_eq!( + ncols1, nrows2, + "gemm: dimensions mismatch for multiplication." + ); + assert_eq!( + (nrows1, ncols1), + (nrows2, ncols3), + "gemm: dimensions mismatch for addition." + ); + // NOTE: this case should never happen because we enter this // codepath only when ncols2 > SMALL_DIM. Though we keep this // here just in case if in the future we change the conditions to // enter this codepath. if ncols1 == 0 { - // NOTE: we can't just always multiply by beta - // because we documented the guaranty that `self` is - // never read if `beta` is zero. - - // Safety: this buffer is empty. - return res.assume_init(); + self.fill_fn(|| MaybeUninit::new(T::zero())); + return; } let (rsa, csa) = a.strides(); let (rsb, csb) = b.strides(); - let (rsc, csc) = res.strides(); + let (rsc, csc) = self.strides(); if T::is::() { unsafe { matrixmultiply::sgemm( - nrows1, - ncols1, + nrows2, ncols2, + ncols3, mem::transmute_copy(&alpha), a.data.ptr() as *const f32, rsa as isize, @@ -897,19 +897,19 @@ where rsb as isize, csb as isize, 0.0, - res.data.ptr_mut() as *mut f32, + self.data.ptr_mut() as *mut f32, rsc as isize, csc as isize, ); - return res.assume_init(); + return; } } else if T::is::() { unsafe { matrixmultiply::dgemm( - nrows1, - ncols1, + nrows2, ncols2, + ncols3, mem::transmute_copy(&alpha), a.data.ptr() as *const f64, rsa as isize, @@ -918,12 +918,12 @@ where rsb as isize, csb as isize, 0.0, - res.data.ptr_mut() as *mut f64, + self.data.ptr_mut() as *mut f64, rsc as isize, csc as isize, ); - return res.assume_init(); + return ; } } } @@ -932,11 +932,9 @@ where for j1 in 0..ncols1 { // TODO: avoid bound checks. - res.column_mut(j1) + self.column_mut(j1) .gemv_z(alpha.inlined_clone(), a, &b.column(j1)); } - - unsafe { res.assume_init() } } } diff --git a/src/base/construction.rs b/src/base/construction.rs index f0709917..6f4893ae 100644 --- a/src/base/construction.rs +++ b/src/base/construction.rs @@ -633,13 +633,13 @@ where ); // Arguments for non-generic constructors. } -impl OMatrix, R, C> +impl OMatrix where DefaultAllocator: Allocator, { /// Creates a new uninitialized matrix or vector. #[inline] - pub fn new_uninitialized() -> Self { + pub fn new_uninitialized() -> OMatrix, R, C> { Self::new_uninitialized_generic(R::name(), C::name()) } } @@ -655,13 +655,13 @@ where ncols); } -impl OMatrix, R, Dynamic> +impl OMatrix where DefaultAllocator: Allocator, { /// Creates a new uninitialized matrix or vector. #[inline] - pub fn new_uninitialized(ncols: usize) -> Self { + pub fn new_uninitialized(ncols: usize) -> OMatrix, R, Dynamic> { Self::new_uninitialized_generic(R::name(), Dynamic::new(ncols)) } } @@ -677,13 +677,13 @@ where nrows); } -impl OMatrix, Dynamic, C> +impl OMatrix where DefaultAllocator: Allocator, { /// Creates a new uninitialized matrix or vector. #[inline] - pub fn new_uninitialized(nrows: usize) -> Self { + pub fn new_uninitialized(nrows: usize) -> OMatrix, Dynamic, C> { Self::new_uninitialized_generic(Dynamic::new(nrows), C::name()) } } @@ -699,13 +699,13 @@ where nrows, ncols); } -impl OMatrix, Dynamic, Dynamic> +impl OMatrix where DefaultAllocator: Allocator, { /// Creates a new uninitialized matrix or vector. #[inline] - pub fn new_uninitialized(nrows: usize, ncols: usize) -> Self { + pub fn new_uninitialized(nrows: usize, ncols: usize) -> OMatrix, Dynamic, Dynamic> { Self::new_uninitialized_generic(Dynamic::new(nrows), Dynamic::new(ncols)) } } diff --git a/src/base/default_allocator.rs b/src/base/default_allocator.rs index b9cb793c..4991312e 100644 --- a/src/base/default_allocator.rs +++ b/src/base/default_allocator.rs @@ -4,7 +4,6 @@ //! heap-allocated buffers for matrices with at least one dimension unknown at compile-time. use std::cmp; -use std::mem; use std::mem::ManuallyDrop; use std::mem::MaybeUninit; use std::ptr; diff --git a/src/base/edition.rs b/src/base/edition.rs index f013ffd3..c9dc402e 100644 --- a/src/base/edition.rs +++ b/src/base/edition.rs @@ -53,7 +53,8 @@ impl> Matrix { { let irows = irows.into_iter(); let ncols = self.data.shape().1; - let mut res = OMatrix::::new_uninitialized_generic(Dynamic::new(irows.len()), ncols); + let mut res = + OMatrix::::new_uninitialized_generic(Dynamic::new(irows.len()), ncols); // First, check that all the indices from irows are valid. // This will allow us to use unchecked access in the inner loop. diff --git a/src/base/matrix_slice.rs b/src/base/matrix_slice.rs index d8ccb44f..30f30c41 100644 --- a/src/base/matrix_slice.rs +++ b/src/base/matrix_slice.rs @@ -223,7 +223,7 @@ impl<'a, T, R: Dim, C: Dim, RStride: Dim, CStride: Dim> SliceStorage<'a, MaybeUninit, R, C, RStride, CStride> { pub unsafe fn assume_init(self) -> SliceStorage<'a, T, R, C, RStride, CStride> { - Self::from_raw_parts(self.ptr as *const T, self.shape, self.strides) + SliceStorage::from_raw_parts(self.ptr as *const T, self.shape, self.strides) } } @@ -231,7 +231,7 @@ impl<'a, T, R: Dim, C: Dim, RStride: Dim, CStride: Dim> SliceStorageMut<'a, MaybeUninit, R, C, RStride, CStride> { pub unsafe fn assume_init(self) -> SliceStorageMut<'a, T, R, C, RStride, CStride> { - Self::from_raw_parts(self.ptr as *mut T, self.shape, self.strides) + SliceStorageMut::from_raw_parts(self.ptr as *mut T, self.shape, self.strides) } } @@ -606,7 +606,7 @@ macro_rules! matrix_slice_impl( /// Returns a slice containing the entire matrix. pub fn $full_slice($me: $Me) -> $MatrixSlice { let (nrows, ncols) = $me.shape(); - $me.generic_slice((0, 0), (R::from_usize(nrows), C::from_usize(ncols))) + $me.$generic_slice((0, 0), (R::from_usize(nrows), C::from_usize(ncols))) } /* diff --git a/src/base/ops.rs b/src/base/ops.rs index 44b1c7c5..a595a2b1 100644 --- a/src/base/ops.rs +++ b/src/base/ops.rs @@ -640,7 +640,8 @@ where // TODO: this is too restrictive: // − we can't use `a *= b` when `a` is a mutable slice. // − we can't use `a *= b` when C2 is not equal to C1. -impl MulAssign> for Matrix +impl MulAssign> + for Matrix where T: Scalar + Zero + One + ClosedAdd + ClosedMul, SB: Storage, @@ -654,7 +655,8 @@ where } } -impl<'b, T, R1: Dim, C1: Dim, R2: Dim, SA, SB> MulAssign<&'b Matrix> for Matrix +impl<'b, T, R1: Dim, C1: Dim, R2: Dim, SA, SB> MulAssign<&'b Matrix> + for Matrix where T: Scalar + Zero + One + ClosedAdd + ClosedMul, SB: Storage, @@ -794,6 +796,7 @@ where ShapeConstraint: SameNumberOfRows + SameNumberOfColumns + AreMultipliable, + DefaultAllocator: Allocator, { out.gemm_z(T::one(), self, rhs); }