GEMM on empty matrices: properly take the beta parameter into account.
This commit is contained in:
parent
f9f7ddd08f
commit
be41cb96e8
170
src/base/blas.rs
170
src/base/blas.rs
|
@ -565,7 +565,14 @@ where
|
||||||
);
|
);
|
||||||
|
|
||||||
if ncols2 == 0 {
|
if ncols2 == 0 {
|
||||||
self.fill(N::zero());
|
// NOTE: we can't just always multiply by beta
|
||||||
|
// because we documented the guaranty that `self` is
|
||||||
|
// never read if `beta` is zero.
|
||||||
|
if beta.is_zero() {
|
||||||
|
self.fill(N::zero());
|
||||||
|
} else {
|
||||||
|
*self *= beta;
|
||||||
|
}
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -992,98 +999,109 @@ where N: Scalar + Zero + ClosedAdd + ClosedMul
|
||||||
|
|
||||||
#[cfg(feature = "std")]
|
#[cfg(feature = "std")]
|
||||||
{
|
{
|
||||||
// matrixmultiply can be used only if the std feature is available.
|
|
||||||
let nrows1 = self.nrows();
|
|
||||||
let (nrows2, ncols2) = a.shape();
|
|
||||||
let (nrows3, ncols3) = b.shape();
|
|
||||||
|
|
||||||
assert_eq!(
|
|
||||||
ncols2, nrows3,
|
|
||||||
"gemm: dimensions mismatch for multiplication."
|
|
||||||
);
|
|
||||||
assert_eq!(
|
|
||||||
(nrows1, ncols1),
|
|
||||||
(nrows2, ncols3),
|
|
||||||
"gemm: dimensions mismatch for addition."
|
|
||||||
);
|
|
||||||
|
|
||||||
|
|
||||||
if a.ncols() == 0 {
|
|
||||||
self.fill(N::zero());
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
// We assume large matrices will be Dynamic but small matrices static.
|
// We assume large matrices will be Dynamic but small matrices static.
|
||||||
// We could use matrixmultiply for large statically-sized matrices but the performance
|
// We could use matrixmultiply for large statically-sized matrices but the performance
|
||||||
// threshold to activate it would be different from SMALL_DIM because our code optimizes
|
// threshold to activate it would be different from SMALL_DIM because our code optimizes
|
||||||
// better for statically-sized matrices.
|
// better for statically-sized matrices.
|
||||||
let is_dynamic = R1::is::<Dynamic>()
|
if R1::is::<Dynamic>()
|
||||||
|| C1::is::<Dynamic>()
|
|| C1::is::<Dynamic>()
|
||||||
|| R2::is::<Dynamic>()
|
|| R2::is::<Dynamic>()
|
||||||
|| C2::is::<Dynamic>()
|
|| C2::is::<Dynamic>()
|
||||||
|| R3::is::<Dynamic>()
|
|| R3::is::<Dynamic>()
|
||||||
|| C3::is::<Dynamic>();
|
|| C3::is::<Dynamic>() {
|
||||||
// Threshold determined empirically.
|
// matrixmultiply can be used only if the std feature is available.
|
||||||
const SMALL_DIM: usize = 5;
|
let nrows1 = self.nrows();
|
||||||
|
let (nrows2, ncols2) = a.shape();
|
||||||
|
let (nrows3, ncols3) = b.shape();
|
||||||
|
|
||||||
if is_dynamic
|
// Threshold determined empirically.
|
||||||
&& nrows1 > SMALL_DIM
|
const SMALL_DIM: usize = 5;
|
||||||
&& ncols1 > SMALL_DIM
|
|
||||||
&& nrows2 > SMALL_DIM
|
|
||||||
&& ncols2 > SMALL_DIM
|
|
||||||
{
|
|
||||||
if N::is::<f32>() {
|
|
||||||
let (rsa, csa) = a.strides();
|
|
||||||
let (rsb, csb) = b.strides();
|
|
||||||
let (rsc, csc) = self.strides();
|
|
||||||
|
|
||||||
unsafe {
|
if nrows1 > SMALL_DIM
|
||||||
matrixmultiply::sgemm(
|
&& ncols1 > SMALL_DIM
|
||||||
nrows2,
|
&& nrows2 > SMALL_DIM
|
||||||
ncols2,
|
&& ncols2 > SMALL_DIM
|
||||||
ncols3,
|
{
|
||||||
mem::transmute_copy(&alpha),
|
assert_eq!(
|
||||||
a.data.ptr() as *const f32,
|
ncols2, nrows3,
|
||||||
rsa as isize,
|
"gemm: dimensions mismatch for multiplication."
|
||||||
csa as isize,
|
);
|
||||||
b.data.ptr() as *const f32,
|
assert_eq!(
|
||||||
rsb as isize,
|
(nrows1, ncols1),
|
||||||
csb as isize,
|
(nrows2, ncols3),
|
||||||
mem::transmute_copy(&beta),
|
"gemm: dimensions mismatch for addition."
|
||||||
self.data.ptr_mut() as *mut f32,
|
);
|
||||||
rsc as isize,
|
|
||||||
csc as isize,
|
// NOTE: this case should never happen because we enter this
|
||||||
);
|
// codepath only when ncols2 > SMALL_DIM. Though we keep this
|
||||||
|
// here just in case if in the future we change the conditions to
|
||||||
|
// enter this codepath.
|
||||||
|
if ncols2 == 0 {
|
||||||
|
// NOTE: we can't just always multiply by beta
|
||||||
|
// because we documented the guaranty that `self` is
|
||||||
|
// never read if `beta` is zero.
|
||||||
|
if beta.is_zero() {
|
||||||
|
self.fill(N::zero());
|
||||||
|
} else {
|
||||||
|
*self *= beta;
|
||||||
|
}
|
||||||
|
return;
|
||||||
}
|
}
|
||||||
return;
|
|
||||||
} else if N::is::<f64>() {
|
|
||||||
let (rsa, csa) = a.strides();
|
|
||||||
let (rsb, csb) = b.strides();
|
|
||||||
let (rsc, csc) = self.strides();
|
|
||||||
|
|
||||||
unsafe {
|
if N::is::<f32>() {
|
||||||
matrixmultiply::dgemm(
|
let (rsa, csa) = a.strides();
|
||||||
nrows2,
|
let (rsb, csb) = b.strides();
|
||||||
ncols2,
|
let (rsc, csc) = self.strides();
|
||||||
ncols3,
|
|
||||||
mem::transmute_copy(&alpha),
|
unsafe {
|
||||||
a.data.ptr() as *const f64,
|
matrixmultiply::sgemm(
|
||||||
rsa as isize,
|
nrows2,
|
||||||
csa as isize,
|
ncols2,
|
||||||
b.data.ptr() as *const f64,
|
ncols3,
|
||||||
rsb as isize,
|
mem::transmute_copy(&alpha),
|
||||||
csb as isize,
|
a.data.ptr() as *const f32,
|
||||||
mem::transmute_copy(&beta),
|
rsa as isize,
|
||||||
self.data.ptr_mut() as *mut f64,
|
csa as isize,
|
||||||
rsc as isize,
|
b.data.ptr() as *const f32,
|
||||||
csc as isize,
|
rsb as isize,
|
||||||
);
|
csb as isize,
|
||||||
|
mem::transmute_copy(&beta),
|
||||||
|
self.data.ptr_mut() as *mut f32,
|
||||||
|
rsc as isize,
|
||||||
|
csc as isize,
|
||||||
|
);
|
||||||
|
}
|
||||||
|
return;
|
||||||
|
} else if N::is::<f64>() {
|
||||||
|
let (rsa, csa) = a.strides();
|
||||||
|
let (rsb, csb) = b.strides();
|
||||||
|
let (rsc, csc) = self.strides();
|
||||||
|
|
||||||
|
unsafe {
|
||||||
|
matrixmultiply::dgemm(
|
||||||
|
nrows2,
|
||||||
|
ncols2,
|
||||||
|
ncols3,
|
||||||
|
mem::transmute_copy(&alpha),
|
||||||
|
a.data.ptr() as *const f64,
|
||||||
|
rsa as isize,
|
||||||
|
csa as isize,
|
||||||
|
b.data.ptr() as *const f64,
|
||||||
|
rsb as isize,
|
||||||
|
csb as isize,
|
||||||
|
mem::transmute_copy(&beta),
|
||||||
|
self.data.ptr_mut() as *mut f64,
|
||||||
|
rsc as isize,
|
||||||
|
csc as isize,
|
||||||
|
);
|
||||||
|
}
|
||||||
|
return;
|
||||||
}
|
}
|
||||||
return;
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
for j1 in 0..ncols1 {
|
for j1 in 0..ncols1 {
|
||||||
// FIXME: avoid bound checks.
|
// FIXME: avoid bound checks.
|
||||||
self.column_mut(j1).gemv(alpha, a, &b.column(j1), beta);
|
self.column_mut(j1).gemv(alpha, a, &b.column(j1), beta);
|
||||||
|
|
|
@ -13,6 +13,11 @@ fn empty_matrix_mul_matrix() {
|
||||||
let m1 = DMatrix::<f32>::zeros(3, 0);
|
let m1 = DMatrix::<f32>::zeros(3, 0);
|
||||||
let m2 = DMatrix::<f32>::zeros(0, 4);
|
let m2 = DMatrix::<f32>::zeros(0, 4);
|
||||||
assert_eq!(m1 * m2, DMatrix::zeros(3, 4));
|
assert_eq!(m1 * m2, DMatrix::zeros(3, 4));
|
||||||
|
|
||||||
|
// Still works with larger matrices.
|
||||||
|
let m1 = DMatrix::<f32>::zeros(13, 0);
|
||||||
|
let m2 = DMatrix::<f32>::zeros(0, 14);
|
||||||
|
assert_eq!(m1 * m2, DMatrix::zeros(13, 14));
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
|
@ -28,3 +33,28 @@ fn empty_matrix_tr_mul_matrix() {
|
||||||
let m2 = DMatrix::<f32>::zeros(0, 4);
|
let m2 = DMatrix::<f32>::zeros(0, 4);
|
||||||
assert_eq!(m1.tr_mul(&m2), DMatrix::zeros(3, 4));
|
assert_eq!(m1.tr_mul(&m2), DMatrix::zeros(3, 4));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn empty_matrix_gemm() {
|
||||||
|
let mut res = DMatrix::repeat(3, 4, 1.0);
|
||||||
|
let m1 = DMatrix::<f32>::zeros(3, 0);
|
||||||
|
let m2 = DMatrix::<f32>::zeros(0, 4);
|
||||||
|
res.gemm(1.0, &m1, &m2, 0.5);
|
||||||
|
assert_eq!(res, DMatrix::repeat(3, 4, 0.5));
|
||||||
|
|
||||||
|
// Still works with lager matrices.
|
||||||
|
let mut res = DMatrix::repeat(13, 14, 1.0);
|
||||||
|
let m1 = DMatrix::<f32>::zeros(13, 0);
|
||||||
|
let m2 = DMatrix::<f32>::zeros(0, 14);
|
||||||
|
res.gemm(1.0, &m1, &m2, 0.5);
|
||||||
|
assert_eq!(res, DMatrix::repeat(13, 14, 0.5));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn empty_matrix_gemm_tr() {
|
||||||
|
let mut res = DMatrix::repeat(3, 4, 1.0);
|
||||||
|
let m1 = DMatrix::<f32>::zeros(0, 3);
|
||||||
|
let m2 = DMatrix::<f32>::zeros(0, 4);
|
||||||
|
res.gemm_tr(1.0, &m1, &m2, 0.5);
|
||||||
|
assert_eq!(res, DMatrix::repeat(3, 4, 0.5));
|
||||||
|
}
|
Loading…
Reference in New Issue