GEMM on empty matrices: properly take the beta parameter into account.

This commit is contained in:
sebcrozet 2019-09-01 21:08:06 +02:00 committed by Sébastien Crozet
parent f9f7ddd08f
commit be41cb96e8
2 changed files with 124 additions and 76 deletions

View File

@ -565,7 +565,14 @@ where
); );
if ncols2 == 0 { if ncols2 == 0 {
// NOTE: we can't just always multiply by beta
// because we documented the guarantee that `self` is
// never read if `beta` is zero.
if beta.is_zero() {
self.fill(N::zero()); self.fill(N::zero());
} else {
*self *= beta;
}
return; return;
} }
@ -992,11 +999,29 @@ where N: Scalar + Zero + ClosedAdd + ClosedMul
#[cfg(feature = "std")] #[cfg(feature = "std")]
{ {
// We assume large matrices will be Dynamic but small matrices static.
// We could use matrixmultiply for large statically-sized matrices but the performance
// threshold to activate it would be different from SMALL_DIM because our code optimizes
// better for statically-sized matrices.
if R1::is::<Dynamic>()
|| C1::is::<Dynamic>()
|| R2::is::<Dynamic>()
|| C2::is::<Dynamic>()
|| R3::is::<Dynamic>()
|| C3::is::<Dynamic>() {
// matrixmultiply can be used only if the std feature is available. // matrixmultiply can be used only if the std feature is available.
let nrows1 = self.nrows(); let nrows1 = self.nrows();
let (nrows2, ncols2) = a.shape(); let (nrows2, ncols2) = a.shape();
let (nrows3, ncols3) = b.shape(); let (nrows3, ncols3) = b.shape();
// Threshold determined empirically.
const SMALL_DIM: usize = 5;
if nrows1 > SMALL_DIM
&& ncols1 > SMALL_DIM
&& nrows2 > SMALL_DIM
&& ncols2 > SMALL_DIM
{
assert_eq!( assert_eq!(
ncols2, nrows3, ncols2, nrows3,
"gemm: dimensions mismatch for multiplication." "gemm: dimensions mismatch for multiplication."
@ -1007,31 +1032,22 @@ where N: Scalar + Zero + ClosedAdd + ClosedMul
"gemm: dimensions mismatch for addition." "gemm: dimensions mismatch for addition."
); );
// NOTE: this case should never happen because we enter this
if a.ncols() == 0 { // codepath only when ncols2 > SMALL_DIM. Though we keep this
// here just in case if in the future we change the conditions to
// enter this codepath.
if ncols2 == 0 {
// NOTE: we can't just always multiply by beta
// because we documented the guarantee that `self` is
// never read if `beta` is zero.
if beta.is_zero() {
self.fill(N::zero()); self.fill(N::zero());
} else {
*self *= beta;
}
return; return;
} }
// We assume large matrices will be Dynamic but small matrices static.
// We could use matrixmultiply for large statically-sized matrices but the performance
// threshold to activate it would be different from SMALL_DIM because our code optimizes
// better for statically-sized matrices.
let is_dynamic = R1::is::<Dynamic>()
|| C1::is::<Dynamic>()
|| R2::is::<Dynamic>()
|| C2::is::<Dynamic>()
|| R3::is::<Dynamic>()
|| C3::is::<Dynamic>();
// Threshold determined empirically.
const SMALL_DIM: usize = 5;
if is_dynamic
&& nrows1 > SMALL_DIM
&& ncols1 > SMALL_DIM
&& nrows2 > SMALL_DIM
&& ncols2 > SMALL_DIM
{
if N::is::<f32>() { if N::is::<f32>() {
let (rsa, csa) = a.strides(); let (rsa, csa) = a.strides();
let (rsb, csb) = b.strides(); let (rsb, csb) = b.strides();
@ -1083,6 +1099,8 @@ where N: Scalar + Zero + ClosedAdd + ClosedMul
} }
} }
} }
}
for j1 in 0..ncols1 { for j1 in 0..ncols1 {
// FIXME: avoid bound checks. // FIXME: avoid bound checks.

View File

@ -13,6 +13,11 @@ fn empty_matrix_mul_matrix() {
let m1 = DMatrix::<f32>::zeros(3, 0); let m1 = DMatrix::<f32>::zeros(3, 0);
let m2 = DMatrix::<f32>::zeros(0, 4); let m2 = DMatrix::<f32>::zeros(0, 4);
assert_eq!(m1 * m2, DMatrix::zeros(3, 4)); assert_eq!(m1 * m2, DMatrix::zeros(3, 4));
// Still works with larger matrices.
let m1 = DMatrix::<f32>::zeros(13, 0);
let m2 = DMatrix::<f32>::zeros(0, 14);
assert_eq!(m1 * m2, DMatrix::zeros(13, 14));
} }
#[test] #[test]
@ -28,3 +33,28 @@ fn empty_matrix_tr_mul_matrix() {
let m2 = DMatrix::<f32>::zeros(0, 4); let m2 = DMatrix::<f32>::zeros(0, 4);
assert_eq!(m1.tr_mul(&m2), DMatrix::zeros(3, 4)); assert_eq!(m1.tr_mul(&m2), DMatrix::zeros(3, 4));
} }
#[test]
fn empty_matrix_gemm() {
    // `gemm` with operands whose shared (inner) dimension is zero: every
    // entry of the accumulator, initially 1.0, is expected to end up at 0.5
    // (the last argument passed to `gemm`).
    //
    // The second shape repeats the check with larger matrices.
    for &(nrows, ncols) in [(3, 4), (13, 14)].iter() {
        let lhs = DMatrix::<f32>::zeros(nrows, 0);
        let rhs = DMatrix::<f32>::zeros(0, ncols);
        let mut acc = DMatrix::repeat(nrows, ncols, 1.0);

        acc.gemm(1.0, &lhs, &rhs, 0.5);

        let expected = DMatrix::repeat(nrows, ncols, 0.5);
        assert_eq!(acc, expected);
    }
}
#[test]
fn empty_matrix_gemm_tr() {
    // `gemm_tr` with two operands that both have zero rows (0x3 and 0x4):
    // every entry of the 3x4 accumulator, initially 1.0, is expected to end
    // up at 0.5 (the last argument passed to `gemm_tr`).
    let lhs = DMatrix::<f32>::zeros(0, 3);
    let rhs = DMatrix::<f32>::zeros(0, 4);
    let mut acc = DMatrix::repeat(3, 4, 1.0);

    acc.gemm_tr(1.0, &lhs, &rhs, 0.5);

    let expected = DMatrix::repeat(3, 4, 0.5);
    assert_eq!(acc, expected);
}