GEMM on empty matrices: properly take the beta parameter into account.

Author: sebcrozet
Date:   2019-09-01 21:08:06 +02:00
Parent: 49abb14e1b
Commit: 94dd355cad
2 changed files with 124 additions and 76 deletions
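
In nalgebra, `gemm(alpha, a, b, beta)` computes `self = alpha * a * b + beta * self`. When `a` and `b` are empty along their common dimension, the product contributes nothing, so the result should reduce to `beta * self`; the old early-out instead zeroed `self` unconditionally. A minimal sketch of the intended behavior, mirroring the tests added in this commit:

use nalgebra::DMatrix;

fn main() {
    // `res` starts at 1.0 everywhere. `m1 * m2` has inner dimension 0, so the
    // product is a zero matrix and gemm should leave `res` scaled by beta.
    let mut res = DMatrix::repeat(3, 4, 1.0);
    let m1 = DMatrix::<f32>::zeros(3, 0);
    let m2 = DMatrix::<f32>::zeros(0, 4);

    // res = alpha * m1 * m2 + beta * res = 0.5 * res
    res.gemm(1.0, &m1, &m2, 0.5);
    assert_eq!(res, DMatrix::repeat(3, 4, 0.5));
}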


@@ -565,7 +565,14 @@ where
         );

         if ncols2 == 0 {
-            self.fill(N::zero());
+            // NOTE: we can't just always multiply by beta
+            // because we documented the guarantee that `self` is
+            // never read if `beta` is zero.
+            if beta.is_zero() {
+                self.fill(N::zero());
+            } else {
+                *self *= beta;
+            }
             return;
         }
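
The branch above is needed because scaling by a zero `beta` is not the same as never reading `self`: with IEEE 754 floats, `0.0 * NaN` is still `NaN`, so `*self *= beta` would let stale or garbage values leak into the result. A small standalone illustration of the distinction (plain Rust, not nalgebra code):

fn main() {
    // Pretend this is an output buffer whose current contents are garbage.
    let mut c = [f32::NAN; 4];
    let beta = 0.0_f32;

    // Scaling reads the old values: 0.0 * NaN is still NaN.
    let scaled: Vec<f32> = c.iter().map(|x| *x * beta).collect();
    assert!(scaled.iter().all(|x| x.is_nan()));

    // Overwriting honors the "`self` is never read when `beta` is zero" rule.
    for x in c.iter_mut() {
        *x = 0.0;
    }
    assert!(c.iter().all(|x| *x == 0.0));
}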
@@ -992,98 +999,109 @@ where N: Scalar + Zero + ClosedAdd + ClosedMul
         #[cfg(feature = "std")]
         {
-            // matrixmultiply can be used only if the std feature is available.
-            let nrows1 = self.nrows();
-            let (nrows2, ncols2) = a.shape();
-            let (nrows3, ncols3) = b.shape();
-
-            assert_eq!(
-                ncols2, nrows3,
-                "gemm: dimensions mismatch for multiplication."
-            );
-            assert_eq!(
-                (nrows1, ncols1),
-                (nrows2, ncols3),
-                "gemm: dimensions mismatch for addition."
-            );
-
-            if a.ncols() == 0 {
-                self.fill(N::zero());
-                return;
-            }
-
             // We assume large matrices will be Dynamic but small matrices static.
             // We could use matrixmultiply for large statically-sized matrices but the performance
             // threshold to activate it would be different from SMALL_DIM because our code optimizes
             // better for statically-sized matrices.
-            let is_dynamic = R1::is::<Dynamic>()
+            if R1::is::<Dynamic>()
                 || C1::is::<Dynamic>()
                 || R2::is::<Dynamic>()
                 || C2::is::<Dynamic>()
                 || R3::is::<Dynamic>()
-                || C3::is::<Dynamic>();
-            // Threshold determined empirically.
-            const SMALL_DIM: usize = 5;
-
-            if is_dynamic
-                && nrows1 > SMALL_DIM
-                && ncols1 > SMALL_DIM
-                && nrows2 > SMALL_DIM
-                && ncols2 > SMALL_DIM
-            {
-                if N::is::<f32>() {
-                    let (rsa, csa) = a.strides();
-                    let (rsb, csb) = b.strides();
-                    let (rsc, csc) = self.strides();
-
-                    unsafe {
-                        matrixmultiply::sgemm(
-                            nrows2,
-                            ncols2,
-                            ncols3,
-                            mem::transmute_copy(&alpha),
-                            a.data.ptr() as *const f32,
-                            rsa as isize,
-                            csa as isize,
-                            b.data.ptr() as *const f32,
-                            rsb as isize,
-                            csb as isize,
-                            mem::transmute_copy(&beta),
-                            self.data.ptr_mut() as *mut f32,
-                            rsc as isize,
-                            csc as isize,
-                        );
-                    }
-                    return;
-                } else if N::is::<f64>() {
-                    let (rsa, csa) = a.strides();
-                    let (rsb, csb) = b.strides();
-                    let (rsc, csc) = self.strides();
-
-                    unsafe {
-                        matrixmultiply::dgemm(
-                            nrows2,
-                            ncols2,
-                            ncols3,
-                            mem::transmute_copy(&alpha),
-                            a.data.ptr() as *const f64,
-                            rsa as isize,
-                            csa as isize,
-                            b.data.ptr() as *const f64,
-                            rsb as isize,
-                            csb as isize,
-                            mem::transmute_copy(&beta),
-                            self.data.ptr_mut() as *mut f64,
-                            rsc as isize,
-                            csc as isize,
-                        );
-                    }
-                    return;
-                }
-            }
+                || C3::is::<Dynamic>() {
+                // matrixmultiply can be used only if the std feature is available.
+                let nrows1 = self.nrows();
+                let (nrows2, ncols2) = a.shape();
+                let (nrows3, ncols3) = b.shape();
+
+                // Threshold determined empirically.
+                const SMALL_DIM: usize = 5;
+
+                if nrows1 > SMALL_DIM
+                    && ncols1 > SMALL_DIM
+                    && nrows2 > SMALL_DIM
+                    && ncols2 > SMALL_DIM
+                {
+                    assert_eq!(
+                        ncols2, nrows3,
+                        "gemm: dimensions mismatch for multiplication."
+                    );
+                    assert_eq!(
+                        (nrows1, ncols1),
+                        (nrows2, ncols3),
+                        "gemm: dimensions mismatch for addition."
+                    );
+
+                    // NOTE: this case should never happen because we enter this
+                    // codepath only when ncols2 > SMALL_DIM. Though we keep this
+                    // here just in case, in the future, we change the conditions
+                    // to enter this codepath.
+                    if ncols2 == 0 {
+                        // NOTE: we can't just always multiply by beta
+                        // because we documented the guarantee that `self` is
+                        // never read if `beta` is zero.
+                        if beta.is_zero() {
+                            self.fill(N::zero());
+                        } else {
+                            *self *= beta;
+                        }
+
+                        return;
+                    }
+
+                    if N::is::<f32>() {
+                        let (rsa, csa) = a.strides();
+                        let (rsb, csb) = b.strides();
+                        let (rsc, csc) = self.strides();
+
+                        unsafe {
+                            matrixmultiply::sgemm(
+                                nrows2,
+                                ncols2,
+                                ncols3,
+                                mem::transmute_copy(&alpha),
+                                a.data.ptr() as *const f32,
+                                rsa as isize,
+                                csa as isize,
+                                b.data.ptr() as *const f32,
+                                rsb as isize,
+                                csb as isize,
+                                mem::transmute_copy(&beta),
+                                self.data.ptr_mut() as *mut f32,
+                                rsc as isize,
+                                csc as isize,
+                            );
+                        }
+
+                        return;
+                    } else if N::is::<f64>() {
+                        let (rsa, csa) = a.strides();
+                        let (rsb, csb) = b.strides();
+                        let (rsc, csc) = self.strides();
+
+                        unsafe {
+                            matrixmultiply::dgemm(
+                                nrows2,
+                                ncols2,
+                                ncols3,
+                                mem::transmute_copy(&alpha),
+                                a.data.ptr() as *const f64,
+                                rsa as isize,
+                                csa as isize,
+                                b.data.ptr() as *const f64,
+                                rsb as isize,
+                                csb as isize,
+                                mem::transmute_copy(&beta),
+                                self.data.ptr_mut() as *mut f64,
+                                rsc as isize,
+                                csc as isize,
+                            );
+                        }
+
+                        return;
+                    }
+                }
+            }
         }

         for j1 in 0..ncols1 {
             // FIXME: avoid bound checks.
             self.column_mut(j1).gemv(alpha, a, &b.column(j1), beta);


@@ -13,6 +13,11 @@ fn empty_matrix_mul_matrix() {
     let m1 = DMatrix::<f32>::zeros(3, 0);
     let m2 = DMatrix::<f32>::zeros(0, 4);
     assert_eq!(m1 * m2, DMatrix::zeros(3, 4));
+
+    // Still works with larger matrices.
+    let m1 = DMatrix::<f32>::zeros(13, 0);
+    let m2 = DMatrix::<f32>::zeros(0, 14);
+    assert_eq!(m1 * m2, DMatrix::zeros(13, 14));
 }

 #[test]
@@ -27,4 +32,29 @@ fn empty_matrix_tr_mul_matrix() {
     let m1 = DMatrix::<f32>::zeros(0, 3);
     let m2 = DMatrix::<f32>::zeros(0, 4);
     assert_eq!(m1.tr_mul(&m2), DMatrix::zeros(3, 4));
 }
+
+#[test]
+fn empty_matrix_gemm() {
+    let mut res = DMatrix::repeat(3, 4, 1.0);
+    let m1 = DMatrix::<f32>::zeros(3, 0);
+    let m2 = DMatrix::<f32>::zeros(0, 4);
+    res.gemm(1.0, &m1, &m2, 0.5);
+    assert_eq!(res, DMatrix::repeat(3, 4, 0.5));
+
+    // Still works with larger matrices.
+    let mut res = DMatrix::repeat(13, 14, 1.0);
+    let m1 = DMatrix::<f32>::zeros(13, 0);
+    let m2 = DMatrix::<f32>::zeros(0, 14);
+    res.gemm(1.0, &m1, &m2, 0.5);
+    assert_eq!(res, DMatrix::repeat(13, 14, 0.5));
+}
+
+#[test]
+fn empty_matrix_gemm_tr() {
+    let mut res = DMatrix::repeat(3, 4, 1.0);
+    let m1 = DMatrix::<f32>::zeros(0, 3);
+    let m2 = DMatrix::<f32>::zeros(0, 4);
+    res.gemm_tr(1.0, &m1, &m2, 0.5);
+    assert_eq!(res, DMatrix::repeat(3, 4, 0.5));
+}