Add transpose gemv.

This commit is contained in:
Sébastien Crozet 2018-02-02 12:26:14 +01:00
parent 1a7f0dea9f
commit 300b3d0452
2 changed files with 59 additions and 0 deletions

View File

@ -311,6 +311,45 @@ impl<N, D: Dim, S> Vector<N, D, S>
self.rows_range_mut(j + 1 ..).axpy(alpha * val, &col2.rows_range(j + 1 ..), N::one()); self.rows_range_mut(j + 1 ..).axpy(alpha * val, &col2.rows_range(j + 1 ..), N::one());
} }
} }
/// Computes `self = alpha * a.transpose() * x + beta * self`, where `a` is a matrix, `x` a vector, and
/// `alpha, beta` two scalars.
///
/// If `beta` is zero, `self` is never read.
#[inline]
pub fn gemv_tr<R2: Dim, C2: Dim, D3: Dim, SB, SC>(&mut self,
alpha: N,
a: &Matrix<N, R2, C2, SB>,
x: &Vector<N, D3, SC>,
beta: N)
where N: One,
SB: Storage<N, R2, C2>,
SC: Storage<N, D3>,
ShapeConstraint: DimEq<D, C2> +
AreMultipliable<C2, R2, D3, U1> {
let dim1 = self.nrows();
let (nrows2, ncols2) = a.shape();
let dim3 = x.nrows();
assert!(nrows2 == dim3 && dim1 == ncols2, "Gemv: dimensions mismatch.");
if ncols2 == 0 {
return;
}
if beta.is_zero() {
for j in 0 .. ncols2 {
let val = unsafe { self.vget_unchecked_mut(j) };
*val = alpha * a.column(j).dot(x)
}
}
else {
for j in 0 .. ncols2 {
let val = unsafe { self.vget_unchecked_mut(j) };
*val = alpha * a.column(j).dot(x) + beta * *val;
}
}
}
} }
impl<N, R1: Dim, C1: Dim, S: StorageMut<N, R1, C1>> Matrix<N, R1, C1, S> impl<N, R1: Dim, C1: Dim, S: StorageMut<N, R1, C1>> Matrix<N, R1, C1, S>

View File

@ -31,6 +31,26 @@ quickcheck! {
relative_eq!(y1, y2, epsilon = 1.0e-10) relative_eq!(y1, y2, epsilon = 1.0e-10)
} }
fn gemv_tr(n: usize, alpha: f64, beta: f64) -> bool {
let n = cmp::max(1, cmp::min(n, 50));
let a = DMatrix::<f64>::new_random(n, n);
let x = DVector::new_random(n);
let mut y1 = DVector::new_random(n);
let mut y2 = y1.clone();
y1.gemv(alpha, &a, &x, beta);
y2.gemv_tr(alpha, &a.transpose(), &x, beta);
if !relative_eq!(y1, y2, epsilon = 1.0e-10) {
return false;
}
y1.gemv(alpha, &a, &x, 0.0);
y2.gemv_tr(alpha, &a.transpose(), &x, 0.0);
relative_eq!(y1, y2, epsilon = 1.0e-10)
}
fn ger_symm(n: usize, alpha: f64, beta: f64) -> bool { fn ger_symm(n: usize, alpha: f64, beta: f64) -> bool {
let n = cmp::max(1, cmp::min(n, 50)); let n = cmp::max(1, cmp::min(n, 50));
let a = DMatrix::<f64>::new_random(n, n); let a = DMatrix::<f64>::new_random(n, n);