GEMM on empty matrices: properly take the beta parameter into account.
This commit is contained in:
parent
f9f7ddd08f
commit
be41cb96e8
170
src/base/blas.rs
170
src/base/blas.rs
|
@ -565,7 +565,14 @@ where
|
|||
);
|
||||
|
||||
if ncols2 == 0 {
|
||||
self.fill(N::zero());
|
||||
// NOTE: we can't just always multiply by beta
|
||||
// because we documented the guarantee that `self` is
|
||||
// never read if `beta` is zero.
|
||||
if beta.is_zero() {
|
||||
self.fill(N::zero());
|
||||
} else {
|
||||
*self *= beta;
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
|
@ -992,98 +999,109 @@ where N: Scalar + Zero + ClosedAdd + ClosedMul
|
|||
|
||||
#[cfg(feature = "std")]
|
||||
{
|
||||
// matrixmultiply can be used only if the std feature is available.
|
||||
let nrows1 = self.nrows();
|
||||
let (nrows2, ncols2) = a.shape();
|
||||
let (nrows3, ncols3) = b.shape();
|
||||
|
||||
assert_eq!(
|
||||
ncols2, nrows3,
|
||||
"gemm: dimensions mismatch for multiplication."
|
||||
);
|
||||
assert_eq!(
|
||||
(nrows1, ncols1),
|
||||
(nrows2, ncols3),
|
||||
"gemm: dimensions mismatch for addition."
|
||||
);
|
||||
|
||||
|
||||
if a.ncols() == 0 {
|
||||
self.fill(N::zero());
|
||||
return;
|
||||
}
|
||||
|
||||
// We assume large matrices will be Dynamic but small matrices static.
|
||||
// We could use matrixmultiply for large statically-sized matrices but the performance
|
||||
// threshold to activate it would be different from SMALL_DIM because our code optimizes
|
||||
// better for statically-sized matrices.
|
||||
let is_dynamic = R1::is::<Dynamic>()
|
||||
if R1::is::<Dynamic>()
|
||||
|| C1::is::<Dynamic>()
|
||||
|| R2::is::<Dynamic>()
|
||||
|| C2::is::<Dynamic>()
|
||||
|| R3::is::<Dynamic>()
|
||||
|| C3::is::<Dynamic>();
|
||||
// Threshold determined empirically.
|
||||
const SMALL_DIM: usize = 5;
|
||||
|| C3::is::<Dynamic>() {
|
||||
// matrixmultiply can be used only if the std feature is available.
|
||||
let nrows1 = self.nrows();
|
||||
let (nrows2, ncols2) = a.shape();
|
||||
let (nrows3, ncols3) = b.shape();
|
||||
|
||||
if is_dynamic
|
||||
&& nrows1 > SMALL_DIM
|
||||
&& ncols1 > SMALL_DIM
|
||||
&& nrows2 > SMALL_DIM
|
||||
&& ncols2 > SMALL_DIM
|
||||
{
|
||||
if N::is::<f32>() {
|
||||
let (rsa, csa) = a.strides();
|
||||
let (rsb, csb) = b.strides();
|
||||
let (rsc, csc) = self.strides();
|
||||
// Threshold determined empirically.
|
||||
const SMALL_DIM: usize = 5;
|
||||
|
||||
unsafe {
|
||||
matrixmultiply::sgemm(
|
||||
nrows2,
|
||||
ncols2,
|
||||
ncols3,
|
||||
mem::transmute_copy(&alpha),
|
||||
a.data.ptr() as *const f32,
|
||||
rsa as isize,
|
||||
csa as isize,
|
||||
b.data.ptr() as *const f32,
|
||||
rsb as isize,
|
||||
csb as isize,
|
||||
mem::transmute_copy(&beta),
|
||||
self.data.ptr_mut() as *mut f32,
|
||||
rsc as isize,
|
||||
csc as isize,
|
||||
);
|
||||
if nrows1 > SMALL_DIM
|
||||
&& ncols1 > SMALL_DIM
|
||||
&& nrows2 > SMALL_DIM
|
||||
&& ncols2 > SMALL_DIM
|
||||
{
|
||||
assert_eq!(
|
||||
ncols2, nrows3,
|
||||
"gemm: dimensions mismatch for multiplication."
|
||||
);
|
||||
assert_eq!(
|
||||
(nrows1, ncols1),
|
||||
(nrows2, ncols3),
|
||||
"gemm: dimensions mismatch for addition."
|
||||
);
|
||||
|
||||
// NOTE: this case should never happen because we enter this
|
||||
// codepath only when ncols2 > SMALL_DIM. Though we keep this
|
||||
// here in case we later change the conditions required to
|
||||
// enter this codepath.
|
||||
if ncols2 == 0 {
|
||||
// NOTE: we can't just always multiply by beta
|
||||
// because we documented the guarantee that `self` is
|
||||
// never read if `beta` is zero.
|
||||
if beta.is_zero() {
|
||||
self.fill(N::zero());
|
||||
} else {
|
||||
*self *= beta;
|
||||
}
|
||||
return;
|
||||
}
|
||||
return;
|
||||
} else if N::is::<f64>() {
|
||||
let (rsa, csa) = a.strides();
|
||||
let (rsb, csb) = b.strides();
|
||||
let (rsc, csc) = self.strides();
|
||||
|
||||
unsafe {
|
||||
matrixmultiply::dgemm(
|
||||
nrows2,
|
||||
ncols2,
|
||||
ncols3,
|
||||
mem::transmute_copy(&alpha),
|
||||
a.data.ptr() as *const f64,
|
||||
rsa as isize,
|
||||
csa as isize,
|
||||
b.data.ptr() as *const f64,
|
||||
rsb as isize,
|
||||
csb as isize,
|
||||
mem::transmute_copy(&beta),
|
||||
self.data.ptr_mut() as *mut f64,
|
||||
rsc as isize,
|
||||
csc as isize,
|
||||
);
|
||||
if N::is::<f32>() {
|
||||
let (rsa, csa) = a.strides();
|
||||
let (rsb, csb) = b.strides();
|
||||
let (rsc, csc) = self.strides();
|
||||
|
||||
unsafe {
|
||||
matrixmultiply::sgemm(
|
||||
nrows2,
|
||||
ncols2,
|
||||
ncols3,
|
||||
mem::transmute_copy(&alpha),
|
||||
a.data.ptr() as *const f32,
|
||||
rsa as isize,
|
||||
csa as isize,
|
||||
b.data.ptr() as *const f32,
|
||||
rsb as isize,
|
||||
csb as isize,
|
||||
mem::transmute_copy(&beta),
|
||||
self.data.ptr_mut() as *mut f32,
|
||||
rsc as isize,
|
||||
csc as isize,
|
||||
);
|
||||
}
|
||||
return;
|
||||
} else if N::is::<f64>() {
|
||||
let (rsa, csa) = a.strides();
|
||||
let (rsb, csb) = b.strides();
|
||||
let (rsc, csc) = self.strides();
|
||||
|
||||
unsafe {
|
||||
matrixmultiply::dgemm(
|
||||
nrows2,
|
||||
ncols2,
|
||||
ncols3,
|
||||
mem::transmute_copy(&alpha),
|
||||
a.data.ptr() as *const f64,
|
||||
rsa as isize,
|
||||
csa as isize,
|
||||
b.data.ptr() as *const f64,
|
||||
rsb as isize,
|
||||
csb as isize,
|
||||
mem::transmute_copy(&beta),
|
||||
self.data.ptr_mut() as *mut f64,
|
||||
rsc as isize,
|
||||
csc as isize,
|
||||
);
|
||||
}
|
||||
return;
|
||||
}
|
||||
return;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
for j1 in 0..ncols1 {
|
||||
// FIXME: avoid bound checks.
|
||||
self.column_mut(j1).gemv(alpha, a, &b.column(j1), beta);
|
||||
|
|
|
@ -13,6 +13,11 @@ fn empty_matrix_mul_matrix() {
|
|||
let m1 = DMatrix::<f32>::zeros(3, 0);
|
||||
let m2 = DMatrix::<f32>::zeros(0, 4);
|
||||
assert_eq!(m1 * m2, DMatrix::zeros(3, 4));
|
||||
|
||||
// Still works with larger matrices.
|
||||
let m1 = DMatrix::<f32>::zeros(13, 0);
|
||||
let m2 = DMatrix::<f32>::zeros(0, 14);
|
||||
assert_eq!(m1 * m2, DMatrix::zeros(13, 14));
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
@ -28,3 +33,28 @@ fn empty_matrix_tr_mul_matrix() {
|
|||
let m2 = DMatrix::<f32>::zeros(0, 4);
|
||||
assert_eq!(m1.tr_mul(&m2), DMatrix::zeros(3, 4));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn empty_matrix_gemm() {
|
||||
let mut res = DMatrix::repeat(3, 4, 1.0);
|
||||
let m1 = DMatrix::<f32>::zeros(3, 0);
|
||||
let m2 = DMatrix::<f32>::zeros(0, 4);
|
||||
res.gemm(1.0, &m1, &m2, 0.5);
|
||||
assert_eq!(res, DMatrix::repeat(3, 4, 0.5));
|
||||
|
||||
// Still works with larger matrices.
|
||||
let mut res = DMatrix::repeat(13, 14, 1.0);
|
||||
let m1 = DMatrix::<f32>::zeros(13, 0);
|
||||
let m2 = DMatrix::<f32>::zeros(0, 14);
|
||||
res.gemm(1.0, &m1, &m2, 0.5);
|
||||
assert_eq!(res, DMatrix::repeat(13, 14, 0.5));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn empty_matrix_gemm_tr() {
|
||||
let mut res = DMatrix::repeat(3, 4, 1.0);
|
||||
let m1 = DMatrix::<f32>::zeros(0, 3);
|
||||
let m2 = DMatrix::<f32>::zeros(0, 4);
|
||||
res.gemm_tr(1.0, &m1, &m2, 0.5);
|
||||
assert_eq!(res, DMatrix::repeat(3, 4, 0.5));
|
||||
}
|
Loading…
Reference in New Issue