Add transpose gemv.
This commit is contained in:
parent
1a7f0dea9f
commit
300b3d0452
|
@ -311,6 +311,45 @@ impl<N, D: Dim, S> Vector<N, D, S>
|
||||||
self.rows_range_mut(j + 1 ..).axpy(alpha * val, &col2.rows_range(j + 1 ..), N::one());
|
self.rows_range_mut(j + 1 ..).axpy(alpha * val, &col2.rows_range(j + 1 ..), N::one());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Computes `self = alpha * a.transpose() * x + beta * self`, where `a` is a matrix, `x` a vector, and
|
||||||
|
/// `alpha, beta` two scalars.
|
||||||
|
///
|
||||||
|
/// If `beta` is zero, `self` is never read.
|
||||||
|
#[inline]
|
||||||
|
pub fn gemv_tr<R2: Dim, C2: Dim, D3: Dim, SB, SC>(&mut self,
|
||||||
|
alpha: N,
|
||||||
|
a: &Matrix<N, R2, C2, SB>,
|
||||||
|
x: &Vector<N, D3, SC>,
|
||||||
|
beta: N)
|
||||||
|
where N: One,
|
||||||
|
SB: Storage<N, R2, C2>,
|
||||||
|
SC: Storage<N, D3>,
|
||||||
|
ShapeConstraint: DimEq<D, C2> +
|
||||||
|
AreMultipliable<C2, R2, D3, U1> {
|
||||||
|
let dim1 = self.nrows();
|
||||||
|
let (nrows2, ncols2) = a.shape();
|
||||||
|
let dim3 = x.nrows();
|
||||||
|
|
||||||
|
assert!(nrows2 == dim3 && dim1 == ncols2, "Gemv: dimensions mismatch.");
|
||||||
|
|
||||||
|
if ncols2 == 0 {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
if beta.is_zero() {
|
||||||
|
for j in 0 .. ncols2 {
|
||||||
|
let val = unsafe { self.vget_unchecked_mut(j) };
|
||||||
|
*val = alpha * a.column(j).dot(x)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
for j in 0 .. ncols2 {
|
||||||
|
let val = unsafe { self.vget_unchecked_mut(j) };
|
||||||
|
*val = alpha * a.column(j).dot(x) + beta * *val;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<N, R1: Dim, C1: Dim, S: StorageMut<N, R1, C1>> Matrix<N, R1, C1, S>
|
impl<N, R1: Dim, C1: Dim, S: StorageMut<N, R1, C1>> Matrix<N, R1, C1, S>
|
||||||
|
|
|
@ -31,6 +31,26 @@ quickcheck! {
|
||||||
relative_eq!(y1, y2, epsilon = 1.0e-10)
|
relative_eq!(y1, y2, epsilon = 1.0e-10)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn gemv_tr(n: usize, alpha: f64, beta: f64) -> bool {
|
||||||
|
let n = cmp::max(1, cmp::min(n, 50));
|
||||||
|
let a = DMatrix::<f64>::new_random(n, n);
|
||||||
|
let x = DVector::new_random(n);
|
||||||
|
let mut y1 = DVector::new_random(n);
|
||||||
|
let mut y2 = y1.clone();
|
||||||
|
|
||||||
|
y1.gemv(alpha, &a, &x, beta);
|
||||||
|
y2.gemv_tr(alpha, &a.transpose(), &x, beta);
|
||||||
|
|
||||||
|
if !relative_eq!(y1, y2, epsilon = 1.0e-10) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
y1.gemv(alpha, &a, &x, 0.0);
|
||||||
|
y2.gemv_tr(alpha, &a.transpose(), &x, 0.0);
|
||||||
|
|
||||||
|
relative_eq!(y1, y2, epsilon = 1.0e-10)
|
||||||
|
}
|
||||||
|
|
||||||
fn ger_symm(n: usize, alpha: f64, beta: f64) -> bool {
|
fn ger_symm(n: usize, alpha: f64, beta: f64) -> bool {
|
||||||
let n = cmp::max(1, cmp::min(n, 50));
|
let n = cmp::max(1, cmp::min(n, 50));
|
||||||
let a = DMatrix::<f64>::new_random(n, n);
|
let a = DMatrix::<f64>::new_random(n, n);
|
||||||
|
|
Loading…
Reference in New Issue