From be41cb96e8587412419671482b1862cf47889238 Mon Sep 17 00:00:00 2001 From: sebcrozet Date: Sun, 1 Sep 2019 21:08:06 +0200 Subject: [PATCH] GEMM on empty matrices: properly take the beta parameter into account. --- src/base/blas.rs | 170 ++++++++++++++++++++++++-------------------- tests/core/empty.rs | 30 ++++++++ 2 files changed, 124 insertions(+), 76 deletions(-) diff --git a/src/base/blas.rs b/src/base/blas.rs index 1875e734..cc8f2345 100644 --- a/src/base/blas.rs +++ b/src/base/blas.rs @@ -565,7 +565,14 @@ where ); if ncols2 == 0 { - self.fill(N::zero()); + // NOTE: we can't just always multiply by beta + // because we documented the guaranty that `self` is + // never read if `beta` is zero. + if beta.is_zero() { + self.fill(N::zero()); + } else { + *self *= beta; + } return; } @@ -992,98 +999,109 @@ where N: Scalar + Zero + ClosedAdd + ClosedMul #[cfg(feature = "std")] { - // matrixmultiply can be used only if the std feature is available. - let nrows1 = self.nrows(); - let (nrows2, ncols2) = a.shape(); - let (nrows3, ncols3) = b.shape(); - - assert_eq!( - ncols2, nrows3, - "gemm: dimensions mismatch for multiplication." - ); - assert_eq!( - (nrows1, ncols1), - (nrows2, ncols3), - "gemm: dimensions mismatch for addition." - ); - - - if a.ncols() == 0 { - self.fill(N::zero()); - return; - } - // We assume large matrices will be Dynamic but small matrices static. // We could use matrixmultiply for large statically-sized matrices but the performance // threshold to activate it would be different from SMALL_DIM because our code optimizes // better for statically-sized matrices. - let is_dynamic = R1::is::() + if R1::is::() || C1::is::() || R2::is::() || C2::is::() || R3::is::() - || C3::is::(); - // Threshold determined empirically. - const SMALL_DIM: usize = 5; + || C3::is::() { + // matrixmultiply can be used only if the std feature is available. + let nrows1 = self.nrows(); + let (nrows2, ncols2) = a.shape(); + let (nrows3, ncols3) = b.shape(); - if is_dynamic - && nrows1 > SMALL_DIM - && ncols1 > SMALL_DIM - && nrows2 > SMALL_DIM - && ncols2 > SMALL_DIM - { - if N::is::() { - let (rsa, csa) = a.strides(); - let (rsb, csb) = b.strides(); - let (rsc, csc) = self.strides(); + // Threshold determined empirically. + const SMALL_DIM: usize = 5; - unsafe { - matrixmultiply::sgemm( - nrows2, - ncols2, - ncols3, - mem::transmute_copy(&alpha), - a.data.ptr() as *const f32, - rsa as isize, - csa as isize, - b.data.ptr() as *const f32, - rsb as isize, - csb as isize, - mem::transmute_copy(&beta), - self.data.ptr_mut() as *mut f32, - rsc as isize, - csc as isize, - ); + if nrows1 > SMALL_DIM + && ncols1 > SMALL_DIM + && nrows2 > SMALL_DIM + && ncols2 > SMALL_DIM + { + assert_eq!( + ncols2, nrows3, + "gemm: dimensions mismatch for multiplication." + ); + assert_eq!( + (nrows1, ncols1), + (nrows2, ncols3), + "gemm: dimensions mismatch for addition." + ); + + // NOTE: this case should never happen because we enter this + // codepath only when ncols2 > SMALL_DIM. Though we keep this + // here just in case if in the future we change the conditions to + // enter this codepath. + if ncols2 == 0 { + // NOTE: we can't just always multiply by beta + // because we documented the guaranty that `self` is + // never read if `beta` is zero. + if beta.is_zero() { + self.fill(N::zero()); + } else { + *self *= beta; + } + return; } - return; - } else if N::is::() { - let (rsa, csa) = a.strides(); - let (rsb, csb) = b.strides(); - let (rsc, csc) = self.strides(); - unsafe { - matrixmultiply::dgemm( - nrows2, - ncols2, - ncols3, - mem::transmute_copy(&alpha), - a.data.ptr() as *const f64, - rsa as isize, - csa as isize, - b.data.ptr() as *const f64, - rsb as isize, - csb as isize, - mem::transmute_copy(&beta), - self.data.ptr_mut() as *mut f64, - rsc as isize, - csc as isize, - ); + if N::is::() { + let (rsa, csa) = a.strides(); + let (rsb, csb) = b.strides(); + let (rsc, csc) = self.strides(); + + unsafe { + matrixmultiply::sgemm( + nrows2, + ncols2, + ncols3, + mem::transmute_copy(&alpha), + a.data.ptr() as *const f32, + rsa as isize, + csa as isize, + b.data.ptr() as *const f32, + rsb as isize, + csb as isize, + mem::transmute_copy(&beta), + self.data.ptr_mut() as *mut f32, + rsc as isize, + csc as isize, + ); + } + return; + } else if N::is::() { + let (rsa, csa) = a.strides(); + let (rsb, csb) = b.strides(); + let (rsc, csc) = self.strides(); + + unsafe { + matrixmultiply::dgemm( + nrows2, + ncols2, + ncols3, + mem::transmute_copy(&alpha), + a.data.ptr() as *const f64, + rsa as isize, + csa as isize, + b.data.ptr() as *const f64, + rsb as isize, + csb as isize, + mem::transmute_copy(&beta), + self.data.ptr_mut() as *mut f64, + rsc as isize, + csc as isize, + ); + } + return; } - return; } } } + for j1 in 0..ncols1 { // FIXME: avoid bound checks. self.column_mut(j1).gemv(alpha, a, &b.column(j1), beta); diff --git a/tests/core/empty.rs b/tests/core/empty.rs index a2377f2f..3e17ee1b 100644 --- a/tests/core/empty.rs +++ b/tests/core/empty.rs @@ -13,6 +13,11 @@ fn empty_matrix_mul_matrix() { let m1 = DMatrix::::zeros(3, 0); let m2 = DMatrix::::zeros(0, 4); assert_eq!(m1 * m2, DMatrix::zeros(3, 4)); + + // Still works with larger matrices. + let m1 = DMatrix::::zeros(13, 0); + let m2 = DMatrix::::zeros(0, 14); + assert_eq!(m1 * m2, DMatrix::zeros(13, 14)); } #[test] @@ -27,4 +32,29 @@ fn empty_matrix_tr_mul_matrix() { let m1 = DMatrix::::zeros(0, 3); let m2 = DMatrix::::zeros(0, 4); assert_eq!(m1.tr_mul(&m2), DMatrix::zeros(3, 4)); +} + +#[test] +fn empty_matrix_gemm() { + let mut res = DMatrix::repeat(3, 4, 1.0); + let m1 = DMatrix::::zeros(3, 0); + let m2 = DMatrix::::zeros(0, 4); + res.gemm(1.0, &m1, &m2, 0.5); + assert_eq!(res, DMatrix::repeat(3, 4, 0.5)); + + // Still works with lager matrices. + let mut res = DMatrix::repeat(13, 14, 1.0); + let m1 = DMatrix::::zeros(13, 0); + let m2 = DMatrix::::zeros(0, 14); + res.gemm(1.0, &m1, &m2, 0.5); + assert_eq!(res, DMatrix::repeat(13, 14, 0.5)); +} + +#[test] +fn empty_matrix_gemm_tr() { + let mut res = DMatrix::repeat(3, 4, 1.0); + let m1 = DMatrix::::zeros(0, 3); + let m2 = DMatrix::::zeros(0, 4); + res.gemm_tr(1.0, &m1, &m2, 0.5); + assert_eq!(res, DMatrix::repeat(3, 4, 0.5)); } \ No newline at end of file