Fix Vector::axpy for noncommutative cases (#648)

Fix Vector::axpy for noncommutative cases
This commit is contained in:
Sébastien Crozet 2019-11-19 22:06:01 +01:00 committed by GitHub
commit d09aa50a31
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 140 additions and 92 deletions

View File

@ -468,21 +468,21 @@ where N: Scalar + Zero + ClosedAdd + ClosedMul
}
}
fn array_axpy<N>(y: &mut [N], a: N, x: &[N], beta: N, stride1: usize, stride2: usize, len: usize)
fn array_axcpy<N>(y: &mut [N], a: N, x: &[N], c: N, beta: N, stride1: usize, stride2: usize, len: usize)
where N: Scalar + Zero + ClosedAdd + ClosedMul {
for i in 0..len {
unsafe {
let y = y.get_unchecked_mut(i * stride1);
*y = a * *x.get_unchecked(i * stride2) + beta * *y;
*y = a * *x.get_unchecked(i * stride2) * c + beta * *y;
}
}
}
fn array_ax<N>(y: &mut [N], a: N, x: &[N], stride1: usize, stride2: usize, len: usize)
fn array_axc<N>(y: &mut [N], a: N, x: &[N], c: N, stride1: usize, stride2: usize, len: usize)
where N: Scalar + Zero + ClosedAdd + ClosedMul {
for i in 0..len {
unsafe {
*y.get_unchecked_mut(i * stride1) = a * *x.get_unchecked(i * stride2);
*y.get_unchecked_mut(i * stride1) = a * *x.get_unchecked(i * stride2) * c;
}
}
}
@ -492,6 +492,40 @@ where
N: Scalar + Zero + ClosedAdd + ClosedMul,
S: StorageMut<N, D>,
{
/// Computes `self = a * x * c + b * self`.
///
/// If `b` is zero, `self` is never read from.
///
/// # Examples:
///
/// ```
/// # use nalgebra::Vector3;
/// let mut vec1 = Vector3::new(1.0, 2.0, 3.0);
/// let vec2 = Vector3::new(0.1, 0.2, 0.3);
/// vec1.axcpy(5.0, &vec2, 2.0, 5.0);
/// assert_eq!(vec1, Vector3::new(6.0, 12.0, 18.0));
/// ```
#[inline]
pub fn axcpy<D2: Dim, SB>(&mut self, a: N, x: &Vector<N, D2, SB>, c: N, b: N)
where
SB: Storage<N, D2>,
ShapeConstraint: DimEq<D, D2>,
{
assert_eq!(self.nrows(), x.nrows(), "Axcpy: mismatched vector shapes.");
let rstride1 = self.strides().0;
let rstride2 = x.strides().0;
let y = self.data.as_mut_slice();
let x = x.data.as_slice();
if !b.is_zero() {
array_axcpy(y, a, x, c, b, rstride1, rstride2, x.len());
} else {
array_axc(y, a, x, c, rstride1, rstride2, x.len());
}
}
/// Computes `self = a * x + b * self`.
///
/// If `b` is zero, `self` is never read from.
@ -508,22 +542,12 @@ where
#[inline]
pub fn axpy<D2: Dim, SB>(&mut self, a: N, x: &Vector<N, D2, SB>, b: N)
where
N: One,
SB: Storage<N, D2>,
ShapeConstraint: DimEq<D, D2>,
{
assert_eq!(self.nrows(), x.nrows(), "Axpy: mismatched vector shapes.");
let rstride1 = self.strides().0;
let rstride2 = x.strides().0;
let y = self.data.as_mut_slice();
let x = x.data.as_slice();
if !b.is_zero() {
array_axpy(y, a, x, b, rstride1, rstride2, x.len());
} else {
array_ax(y, a, x, rstride1, rstride2, x.len());
}
self.axcpy(a, x, N::one(), b)
}
/// Computes `self = alpha * a * x + beta * self`, where `a` is a matrix, `x` a vector, and
@ -579,13 +603,13 @@ where
// FIXME: avoid bound checks.
let col2 = a.column(0);
let val = unsafe { *x.vget_unchecked(0) };
self.axpy(alpha * val, &col2, beta);
self.axcpy(alpha, &col2, val, beta);
for j in 1..ncols2 {
let col2 = a.column(j);
let val = unsafe { *x.vget_unchecked(j) };
self.axpy(alpha * val, &col2, N::one());
self.axcpy(alpha, &col2, val, N::one());
}
}

View File

@ -1,5 +1,28 @@
#![cfg(feature = "arbitrary")]
use na::{geometry::Quaternion, Matrix2, Vector3};
use num_traits::{One, Zero};
#[test]
fn gemm_noncommutative() {
type Qf64 = Quaternion<f64>;
let i = Qf64::from_imag(Vector3::new(1.0, 0.0, 0.0));
let j = Qf64::from_imag(Vector3::new(0.0, 1.0, 0.0));
let k = Qf64::from_imag(Vector3::new(0.0, 0.0, 1.0));
let m1 = Matrix2::new(k, Qf64::zero(), j, i);
// this is the inverse of m1
let m2 = Matrix2::new(-k, Qf64::zero(), Qf64::one(), -i);
let mut res: Matrix2<Qf64> = Matrix2::zero();
res.gemm(Qf64::one(), &m1, &m2, Qf64::zero());
assert_eq!(res, Matrix2::identity());
let mut res: Matrix2<Qf64> = Matrix2::identity();
res.gemm(k, &m1, &m2, -k);
assert_eq!(res, Matrix2::zero());
}
#[cfg(feature = "arbitrary")]
mod blas_quickcheck {
use na::{DMatrix, DVector};
use std::cmp;
@ -103,3 +126,4 @@ quickcheck! {
relative_eq!(res, expected, epsilon = 1.0e-7)
}
}
}