Add transpose gemv.
This commit is contained in:
parent
1a7f0dea9f
commit
300b3d0452
@ -311,6 +311,45 @@ impl<N, D: Dim, S> Vector<N, D, S>
|
||||
self.rows_range_mut(j + 1 ..).axpy(alpha * val, &col2.rows_range(j + 1 ..), N::one());
|
||||
}
|
||||
}
|
||||
|
||||
/// Computes `self = alpha * a.transpose() * x + beta * self`, where `a` is a matrix, `x` a vector, and
|
||||
/// `alpha, beta` two scalars.
|
||||
///
|
||||
/// If `beta` is zero, `self` is never read.
|
||||
#[inline]
|
||||
pub fn gemv_tr<R2: Dim, C2: Dim, D3: Dim, SB, SC>(&mut self,
|
||||
alpha: N,
|
||||
a: &Matrix<N, R2, C2, SB>,
|
||||
x: &Vector<N, D3, SC>,
|
||||
beta: N)
|
||||
where N: One,
|
||||
SB: Storage<N, R2, C2>,
|
||||
SC: Storage<N, D3>,
|
||||
ShapeConstraint: DimEq<D, C2> +
|
||||
AreMultipliable<C2, R2, D3, U1> {
|
||||
let dim1 = self.nrows();
|
||||
let (nrows2, ncols2) = a.shape();
|
||||
let dim3 = x.nrows();
|
||||
|
||||
assert!(nrows2 == dim3 && dim1 == ncols2, "Gemv: dimensions mismatch.");
|
||||
|
||||
if ncols2 == 0 {
|
||||
return;
|
||||
}
|
||||
|
||||
if beta.is_zero() {
|
||||
for j in 0 .. ncols2 {
|
||||
let val = unsafe { self.vget_unchecked_mut(j) };
|
||||
*val = alpha * a.column(j).dot(x)
|
||||
}
|
||||
}
|
||||
else {
|
||||
for j in 0 .. ncols2 {
|
||||
let val = unsafe { self.vget_unchecked_mut(j) };
|
||||
*val = alpha * a.column(j).dot(x) + beta * *val;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<N, R1: Dim, C1: Dim, S: StorageMut<N, R1, C1>> Matrix<N, R1, C1, S>
|
||||
|
@ -31,6 +31,26 @@ quickcheck! {
|
||||
relative_eq!(y1, y2, epsilon = 1.0e-10)
|
||||
}
|
||||
|
||||
fn gemv_tr(n: usize, alpha: f64, beta: f64) -> bool {
|
||||
let n = cmp::max(1, cmp::min(n, 50));
|
||||
let a = DMatrix::<f64>::new_random(n, n);
|
||||
let x = DVector::new_random(n);
|
||||
let mut y1 = DVector::new_random(n);
|
||||
let mut y2 = y1.clone();
|
||||
|
||||
y1.gemv(alpha, &a, &x, beta);
|
||||
y2.gemv_tr(alpha, &a.transpose(), &x, beta);
|
||||
|
||||
if !relative_eq!(y1, y2, epsilon = 1.0e-10) {
|
||||
return false;
|
||||
}
|
||||
|
||||
y1.gemv(alpha, &a, &x, 0.0);
|
||||
y2.gemv_tr(alpha, &a.transpose(), &x, 0.0);
|
||||
|
||||
relative_eq!(y1, y2, epsilon = 1.0e-10)
|
||||
}
|
||||
|
||||
fn ger_symm(n: usize, alpha: f64, beta: f64) -> bool {
|
||||
let n = cmp::max(1, cmp::min(n, 50));
|
||||
let a = DMatrix::<f64>::new_random(n, n);
|
||||
|
Loading…
Reference in New Issue
Block a user