Add Vector::axcpy method

The added method `Vector::axcpy` generalises `Vector::gemv` to
noncommutative cases since it allows us to write for `gemv`
`self.axcpy(alpha, &col2, val, beta)`, instead the usual
`self.axpy(alpha * val, &col2, beta)`. Hence, `axcpy` preserves the
order of scalar multiplication which is important for applications where
commutativity is not guaranteed (e.g., matrices of quaternions, etc.).

This commmit also removes helpers `array_axpy` and `array_ax`, and
replaces them with `array_axcpy` and `array_axc` respectively, which
like above preserve the order of scalar multiplication.

Finally, `Vector::axpy` is preserved, however, now expressed in terms of
`Vector::axcpy` like so:

```
self.axcpy(alpha * val, &col2, beta)
```
This commit is contained in:
Jakub Konka 2019-11-04 10:27:57 +01:00 committed by Sébastien Crozet
parent e1c8e1bccf
commit fe65b1c129

View File

@ -468,21 +468,21 @@ where N: Scalar + Zero + ClosedAdd + ClosedMul
} }
} }
fn array_axpy<N>(y: &mut [N], a: N, x: &[N], beta: N, stride1: usize, stride2: usize, len: usize) fn array_axcpy<N>(y: &mut [N], a: N, x: &[N], c: N, beta: N, stride1: usize, stride2: usize, len: usize)
where N: Scalar + Zero + ClosedAdd + ClosedMul { where N: Scalar + Zero + ClosedAdd + ClosedMul {
for i in 0..len { for i in 0..len {
unsafe { unsafe {
let y = y.get_unchecked_mut(i * stride1); let y = y.get_unchecked_mut(i * stride1);
*y = *x.get_unchecked(i * stride2) * a + *y * beta; *y = a * *x.get_unchecked(i * stride2) * c + beta * *y;
} }
} }
} }
fn array_ax<N>(y: &mut [N], a: N, x: &[N], stride1: usize, stride2: usize, len: usize) fn array_axc<N>(y: &mut [N], a: N, x: &[N], c: N, stride1: usize, stride2: usize, len: usize)
where N: Scalar + Zero + ClosedAdd + ClosedMul { where N: Scalar + Zero + ClosedAdd + ClosedMul {
for i in 0..len { for i in 0..len {
unsafe { unsafe {
*y.get_unchecked_mut(i * stride1) = *x.get_unchecked(i * stride2) * a; *y.get_unchecked_mut(i * stride1) = a * *x.get_unchecked(i * stride2) * c;
} }
} }
} }
@ -492,6 +492,40 @@ where
N: Scalar + Zero + ClosedAdd + ClosedMul, N: Scalar + Zero + ClosedAdd + ClosedMul,
S: StorageMut<N, D>, S: StorageMut<N, D>,
{ {
/// Computes `self = a * x * c + b * self`.
///
/// If `b` is zero, `self` is never read from.
///
/// # Examples:
///
/// ```
/// # use nalgebra::Vector3;
/// let mut vec1 = Vector3::new(1.0, 2.0, 3.0);
/// let vec2 = Vector3::new(0.1, 0.2, 0.3);
/// vec1.axcpy(5.0, &vec2, 2.0, 5.0);
/// assert_eq!(vec1, Vector3::new(6.0, 12.0, 18.0));
/// ```
#[inline]
pub fn axcpy<D2: Dim, SB>(&mut self, a: N, x: &Vector<N, D2, SB>, c: N, b: N)
where
SB: Storage<N, D2>,
ShapeConstraint: DimEq<D, D2>,
{
assert_eq!(self.nrows(), x.nrows(), "Axcpy: mismatched vector shapes.");
let rstride1 = self.strides().0;
let rstride2 = x.strides().0;
let y = self.data.as_mut_slice();
let x = x.data.as_slice();
if !b.is_zero() {
array_axcpy(y, a, x, c, b, rstride1, rstride2, x.len());
} else {
array_axc(y, a, x, c, rstride1, rstride2, x.len());
}
}
/// Computes `self = a * x + b * self`. /// Computes `self = a * x + b * self`.
/// ///
/// If `b` is zero, `self` is never read from. /// If `b` is zero, `self` is never read from.
@ -508,22 +542,12 @@ where
#[inline] #[inline]
pub fn axpy<D2: Dim, SB>(&mut self, a: N, x: &Vector<N, D2, SB>, b: N) pub fn axpy<D2: Dim, SB>(&mut self, a: N, x: &Vector<N, D2, SB>, b: N)
where where
N: One,
SB: Storage<N, D2>, SB: Storage<N, D2>,
ShapeConstraint: DimEq<D, D2>, ShapeConstraint: DimEq<D, D2>,
{ {
assert_eq!(self.nrows(), x.nrows(), "Axpy: mismatched vector shapes."); assert_eq!(self.nrows(), x.nrows(), "Axpy: mismatched vector shapes.");
self.axcpy(a, x, N::one(), b)
let rstride1 = self.strides().0;
let rstride2 = x.strides().0;
let y = self.data.as_mut_slice();
let x = x.data.as_slice();
if !b.is_zero() {
array_axpy(y, a, x, b, rstride1, rstride2, x.len());
} else {
array_ax(y, a, x, rstride1, rstride2, x.len());
}
} }
/// Computes `self = alpha * a * x + beta * self`, where `a` is a matrix, `x` a vector, and /// Computes `self = alpha * a * x + beta * self`, where `a` is a matrix, `x` a vector, and
@ -579,13 +603,13 @@ where
// FIXME: avoid bound checks. // FIXME: avoid bound checks.
let col2 = a.column(0); let col2 = a.column(0);
let val = unsafe { *x.vget_unchecked(0) }; let val = unsafe { *x.vget_unchecked(0) };
self.axpy(val * alpha, &col2, beta); self.axcpy(alpha, &col2, val, beta);
for j in 1..ncols2 { for j in 1..ncols2 {
let col2 = a.column(j); let col2 = a.column(j);
let val = unsafe { *x.vget_unchecked(j) }; let val = unsafe { *x.vget_unchecked(j) };
self.axpy(val * alpha, &col2, N::one()); self.axcpy(alpha, &col2, val, N::one());
} }
} }
@ -624,7 +648,7 @@ where
// FIXME: avoid bound checks. // FIXME: avoid bound checks.
let col2 = a.column(0); let col2 = a.column(0);
let val = unsafe { *x.vget_unchecked(0) }; let val = unsafe { *x.vget_unchecked(0) };
self.axpy(val * alpha, &col2, beta); self.axpy(alpha * val, &col2, beta);
self[0] += alpha * dot(&a.slice_range(1.., 0), &x.rows_range(1..)); self[0] += alpha * dot(&a.slice_range(1.., 0), &x.rows_range(1..));
for j in 1..dim2 { for j in 1..dim2 {
@ -637,7 +661,7 @@ where
*self.vget_unchecked_mut(j) += alpha * dot; *self.vget_unchecked_mut(j) += alpha * dot;
} }
self.rows_range_mut(j + 1..) self.rows_range_mut(j + 1..)
.axpy(val * alpha, &col2.rows_range(j + 1..), N::one()); .axpy(alpha * val, &col2.rows_range(j + 1..), N::one());
} }
} }
@ -890,7 +914,7 @@ where N: Scalar + Zero + ClosedAdd + ClosedMul
for j in 0..ncols1 { for j in 0..ncols1 {
// FIXME: avoid bound checks. // FIXME: avoid bound checks.
let val = unsafe { conjugate(*y.vget_unchecked(j)) }; let val = unsafe { conjugate(*y.vget_unchecked(j)) };
self.column_mut(j).axpy(val * alpha, x, beta); self.column_mut(j).axpy(alpha * val, x, beta);
} }
} }
@ -1256,7 +1280,7 @@ where N: Scalar + Zero + ClosedAdd + ClosedMul
let subdim = Dynamic::new(dim1 - j); let subdim = Dynamic::new(dim1 - j);
// FIXME: avoid bound checks. // FIXME: avoid bound checks.
self.generic_slice_mut((j, j), (subdim, U1)).axpy( self.generic_slice_mut((j, j), (subdim, U1)).axpy(
val * alpha, alpha * val,
&x.rows_range(j..), &x.rows_range(j..),
beta, beta,
); );