From 300b3d0452595083a088b4214d339e99fd45ca8f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?S=C3=A9bastien=20Crozet?= Date: Fri, 2 Feb 2018 12:26:14 +0100 Subject: [PATCH] Add transpose gemv. --- src/core/blas.rs | 39 +++++++++++++++++++++++++++++++++++++++ tests/core/blas.rs | 20 ++++++++++++++++++++ 2 files changed, 59 insertions(+) diff --git a/src/core/blas.rs b/src/core/blas.rs index ed4ea2f5..517e7793 100644 --- a/src/core/blas.rs +++ b/src/core/blas.rs @@ -311,6 +311,45 @@ impl Vector self.rows_range_mut(j + 1 ..).axpy(alpha * val, &col2.rows_range(j + 1 ..), N::one()); } } + + /// Computes `self = alpha * a.transpose() * x + beta * self`, where `a` is a matrix, `x` a vector, and + /// `alpha, beta` two scalars. + /// + /// If `beta` is zero, `self` is never read. + #[inline] + pub fn gemv_tr(&mut self, + alpha: N, + a: &Matrix, + x: &Vector, + beta: N) + where N: One, + SB: Storage, + SC: Storage, + ShapeConstraint: DimEq + + AreMultipliable { + let dim1 = self.nrows(); + let (nrows2, ncols2) = a.shape(); + let dim3 = x.nrows(); + + assert!(nrows2 == dim3 && dim1 == ncols2, "Gemv: dimensions mismatch."); + + if ncols2 == 0 { + return; + } + + if beta.is_zero() { + for j in 0 .. ncols2 { + let val = unsafe { self.vget_unchecked_mut(j) }; + *val = alpha * a.column(j).dot(x) + } + } + else { + for j in 0 .. ncols2 { + let val = unsafe { self.vget_unchecked_mut(j) }; + *val = alpha * a.column(j).dot(x) + beta * *val; + } + } + } } impl> Matrix diff --git a/tests/core/blas.rs b/tests/core/blas.rs index 019c73df..0db8bc88 100644 --- a/tests/core/blas.rs +++ b/tests/core/blas.rs @@ -31,6 +31,26 @@ quickcheck! { relative_eq!(y1, y2, epsilon = 1.0e-10) } + fn gemv_tr(n: usize, alpha: f64, beta: f64) -> bool { + let n = cmp::max(1, cmp::min(n, 50)); + let a = DMatrix::::new_random(n, n); + let x = DVector::new_random(n); + let mut y1 = DVector::new_random(n); + let mut y2 = y1.clone(); + + y1.gemv(alpha, &a, &x, beta); + y2.gemv_tr(alpha, &a.transpose(), &x, beta); + + if !relative_eq!(y1, y2, epsilon = 1.0e-10) { + return false; + } + + y1.gemv(alpha, &a, &x, 0.0); + y2.gemv_tr(alpha, &a.transpose(), &x, 0.0); + + relative_eq!(y1, y2, epsilon = 1.0e-10) + } + fn ger_symm(n: usize, alpha: f64, beta: f64) -> bool { let n = cmp::max(1, cmp::min(n, 50)); let a = DMatrix::::new_random(n, n);