GEMM on empty matrices: properly take the beta parameter into account.

This commit is contained in:
sebcrozet 2019-09-01 21:08:06 +02:00
parent 49abb14e1b
commit 94dd355cad
2 changed files with 124 additions and 76 deletions

View File

@ -565,7 +565,14 @@ where
);
if ncols2 == 0 {
self.fill(N::zero());
// NOTE: we can't just always multiply by beta
// because we documented the guaranty that `self` is
// never read if `beta` is zero.
if beta.is_zero() {
self.fill(N::zero());
} else {
*self *= beta;
}
return;
}
@ -992,98 +999,109 @@ where N: Scalar + Zero + ClosedAdd + ClosedMul
#[cfg(feature = "std")]
{
// matrixmultiply can be used only if the std feature is available.
let nrows1 = self.nrows();
let (nrows2, ncols2) = a.shape();
let (nrows3, ncols3) = b.shape();
assert_eq!(
ncols2, nrows3,
"gemm: dimensions mismatch for multiplication."
);
assert_eq!(
(nrows1, ncols1),
(nrows2, ncols3),
"gemm: dimensions mismatch for addition."
);
if a.ncols() == 0 {
self.fill(N::zero());
return;
}
// We assume large matrices will be Dynamic but small matrices static.
// We could use matrixmultiply for large statically-sized matrices but the performance
// threshold to activate it would be different from SMALL_DIM because our code optimizes
// better for statically-sized matrices.
let is_dynamic = R1::is::<Dynamic>()
if R1::is::<Dynamic>()
|| C1::is::<Dynamic>()
|| R2::is::<Dynamic>()
|| C2::is::<Dynamic>()
|| R3::is::<Dynamic>()
|| C3::is::<Dynamic>();
// Threshold determined empirically.
const SMALL_DIM: usize = 5;
|| C3::is::<Dynamic>() {
// matrixmultiply can be used only if the std feature is available.
let nrows1 = self.nrows();
let (nrows2, ncols2) = a.shape();
let (nrows3, ncols3) = b.shape();
if is_dynamic
&& nrows1 > SMALL_DIM
&& ncols1 > SMALL_DIM
&& nrows2 > SMALL_DIM
&& ncols2 > SMALL_DIM
{
if N::is::<f32>() {
let (rsa, csa) = a.strides();
let (rsb, csb) = b.strides();
let (rsc, csc) = self.strides();
// Threshold determined empirically.
const SMALL_DIM: usize = 5;
unsafe {
matrixmultiply::sgemm(
nrows2,
ncols2,
ncols3,
mem::transmute_copy(&alpha),
a.data.ptr() as *const f32,
rsa as isize,
csa as isize,
b.data.ptr() as *const f32,
rsb as isize,
csb as isize,
mem::transmute_copy(&beta),
self.data.ptr_mut() as *mut f32,
rsc as isize,
csc as isize,
);
if nrows1 > SMALL_DIM
&& ncols1 > SMALL_DIM
&& nrows2 > SMALL_DIM
&& ncols2 > SMALL_DIM
{
assert_eq!(
ncols2, nrows3,
"gemm: dimensions mismatch for multiplication."
);
assert_eq!(
(nrows1, ncols1),
(nrows2, ncols3),
"gemm: dimensions mismatch for addition."
);
// NOTE: this case should never happen because we enter this
// codepath only when ncols2 > SMALL_DIM. Though we keep this
// here just in case if in the future we change the conditions to
// enter this codepath.
if ncols2 == 0 {
// NOTE: we can't just always multiply by beta
// because we documented the guaranty that `self` is
// never read if `beta` is zero.
if beta.is_zero() {
self.fill(N::zero());
} else {
*self *= beta;
}
return;
}
return;
} else if N::is::<f64>() {
let (rsa, csa) = a.strides();
let (rsb, csb) = b.strides();
let (rsc, csc) = self.strides();
unsafe {
matrixmultiply::dgemm(
nrows2,
ncols2,
ncols3,
mem::transmute_copy(&alpha),
a.data.ptr() as *const f64,
rsa as isize,
csa as isize,
b.data.ptr() as *const f64,
rsb as isize,
csb as isize,
mem::transmute_copy(&beta),
self.data.ptr_mut() as *mut f64,
rsc as isize,
csc as isize,
);
if N::is::<f32>() {
let (rsa, csa) = a.strides();
let (rsb, csb) = b.strides();
let (rsc, csc) = self.strides();
unsafe {
matrixmultiply::sgemm(
nrows2,
ncols2,
ncols3,
mem::transmute_copy(&alpha),
a.data.ptr() as *const f32,
rsa as isize,
csa as isize,
b.data.ptr() as *const f32,
rsb as isize,
csb as isize,
mem::transmute_copy(&beta),
self.data.ptr_mut() as *mut f32,
rsc as isize,
csc as isize,
);
}
return;
} else if N::is::<f64>() {
let (rsa, csa) = a.strides();
let (rsb, csb) = b.strides();
let (rsc, csc) = self.strides();
unsafe {
matrixmultiply::dgemm(
nrows2,
ncols2,
ncols3,
mem::transmute_copy(&alpha),
a.data.ptr() as *const f64,
rsa as isize,
csa as isize,
b.data.ptr() as *const f64,
rsb as isize,
csb as isize,
mem::transmute_copy(&beta),
self.data.ptr_mut() as *mut f64,
rsc as isize,
csc as isize,
);
}
return;
}
return;
}
}
}
for j1 in 0..ncols1 {
// FIXME: avoid bound checks.
self.column_mut(j1).gemv(alpha, a, &b.column(j1), beta);

View File

@ -13,6 +13,11 @@ fn empty_matrix_mul_matrix() {
let m1 = DMatrix::<f32>::zeros(3, 0);
let m2 = DMatrix::<f32>::zeros(0, 4);
assert_eq!(m1 * m2, DMatrix::zeros(3, 4));
// Still works with larger matrices.
let m1 = DMatrix::<f32>::zeros(13, 0);
let m2 = DMatrix::<f32>::zeros(0, 14);
assert_eq!(m1 * m2, DMatrix::zeros(13, 14));
}
#[test]
@ -28,3 +33,28 @@ fn empty_matrix_tr_mul_matrix() {
let m2 = DMatrix::<f32>::zeros(0, 4);
assert_eq!(m1.tr_mul(&m2), DMatrix::zeros(3, 4));
}
#[test]
fn empty_matrix_gemm() {
let mut res = DMatrix::repeat(3, 4, 1.0);
let m1 = DMatrix::<f32>::zeros(3, 0);
let m2 = DMatrix::<f32>::zeros(0, 4);
res.gemm(1.0, &m1, &m2, 0.5);
assert_eq!(res, DMatrix::repeat(3, 4, 0.5));
// Still works with lager matrices.
let mut res = DMatrix::repeat(13, 14, 1.0);
let m1 = DMatrix::<f32>::zeros(13, 0);
let m2 = DMatrix::<f32>::zeros(0, 14);
res.gemm(1.0, &m1, &m2, 0.5);
assert_eq!(res, DMatrix::repeat(13, 14, 0.5));
}
#[test]
fn empty_matrix_gemm_tr() {
let mut res = DMatrix::repeat(3, 4, 1.0);
let m1 = DMatrix::<f32>::zeros(0, 3);
let m2 = DMatrix::<f32>::zeros(0, 4);
res.gemm_tr(1.0, &m1, &m2, 0.5);
assert_eq!(res, DMatrix::repeat(3, 4, 0.5));
}