Fix Vector::axpy for noncommutative cases (#648)
Fix Vector::axpy for noncommutative cases
This commit is contained in:
commit
d09aa50a31
@ -468,21 +468,21 @@ where N: Scalar + Zero + ClosedAdd + ClosedMul
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn array_axpy<N>(y: &mut [N], a: N, x: &[N], beta: N, stride1: usize, stride2: usize, len: usize)
|
fn array_axcpy<N>(y: &mut [N], a: N, x: &[N], c: N, beta: N, stride1: usize, stride2: usize, len: usize)
|
||||||
where N: Scalar + Zero + ClosedAdd + ClosedMul {
|
where N: Scalar + Zero + ClosedAdd + ClosedMul {
|
||||||
for i in 0..len {
|
for i in 0..len {
|
||||||
unsafe {
|
unsafe {
|
||||||
let y = y.get_unchecked_mut(i * stride1);
|
let y = y.get_unchecked_mut(i * stride1);
|
||||||
*y = a * *x.get_unchecked(i * stride2) + beta * *y;
|
*y = a * *x.get_unchecked(i * stride2) * c + beta * *y;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn array_ax<N>(y: &mut [N], a: N, x: &[N], stride1: usize, stride2: usize, len: usize)
|
fn array_axc<N>(y: &mut [N], a: N, x: &[N], c: N, stride1: usize, stride2: usize, len: usize)
|
||||||
where N: Scalar + Zero + ClosedAdd + ClosedMul {
|
where N: Scalar + Zero + ClosedAdd + ClosedMul {
|
||||||
for i in 0..len {
|
for i in 0..len {
|
||||||
unsafe {
|
unsafe {
|
||||||
*y.get_unchecked_mut(i * stride1) = a * *x.get_unchecked(i * stride2);
|
*y.get_unchecked_mut(i * stride1) = a * *x.get_unchecked(i * stride2) * c;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -492,6 +492,40 @@ where
|
|||||||
N: Scalar + Zero + ClosedAdd + ClosedMul,
|
N: Scalar + Zero + ClosedAdd + ClosedMul,
|
||||||
S: StorageMut<N, D>,
|
S: StorageMut<N, D>,
|
||||||
{
|
{
|
||||||
|
/// Computes `self = a * x * c + b * self`.
|
||||||
|
///
|
||||||
|
/// If `b` is zero, `self` is never read from.
|
||||||
|
///
|
||||||
|
/// # Examples:
|
||||||
|
///
|
||||||
|
/// ```
|
||||||
|
/// # use nalgebra::Vector3;
|
||||||
|
/// let mut vec1 = Vector3::new(1.0, 2.0, 3.0);
|
||||||
|
/// let vec2 = Vector3::new(0.1, 0.2, 0.3);
|
||||||
|
/// vec1.axcpy(5.0, &vec2, 2.0, 5.0);
|
||||||
|
/// assert_eq!(vec1, Vector3::new(6.0, 12.0, 18.0));
|
||||||
|
/// ```
|
||||||
|
#[inline]
|
||||||
|
pub fn axcpy<D2: Dim, SB>(&mut self, a: N, x: &Vector<N, D2, SB>, c: N, b: N)
|
||||||
|
where
|
||||||
|
SB: Storage<N, D2>,
|
||||||
|
ShapeConstraint: DimEq<D, D2>,
|
||||||
|
{
|
||||||
|
assert_eq!(self.nrows(), x.nrows(), "Axcpy: mismatched vector shapes.");
|
||||||
|
|
||||||
|
let rstride1 = self.strides().0;
|
||||||
|
let rstride2 = x.strides().0;
|
||||||
|
|
||||||
|
let y = self.data.as_mut_slice();
|
||||||
|
let x = x.data.as_slice();
|
||||||
|
|
||||||
|
if !b.is_zero() {
|
||||||
|
array_axcpy(y, a, x, c, b, rstride1, rstride2, x.len());
|
||||||
|
} else {
|
||||||
|
array_axc(y, a, x, c, rstride1, rstride2, x.len());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/// Computes `self = a * x + b * self`.
|
/// Computes `self = a * x + b * self`.
|
||||||
///
|
///
|
||||||
/// If `b` is zero, `self` is never read from.
|
/// If `b` is zero, `self` is never read from.
|
||||||
@ -508,22 +542,12 @@ where
|
|||||||
#[inline]
|
#[inline]
|
||||||
pub fn axpy<D2: Dim, SB>(&mut self, a: N, x: &Vector<N, D2, SB>, b: N)
|
pub fn axpy<D2: Dim, SB>(&mut self, a: N, x: &Vector<N, D2, SB>, b: N)
|
||||||
where
|
where
|
||||||
|
N: One,
|
||||||
SB: Storage<N, D2>,
|
SB: Storage<N, D2>,
|
||||||
ShapeConstraint: DimEq<D, D2>,
|
ShapeConstraint: DimEq<D, D2>,
|
||||||
{
|
{
|
||||||
assert_eq!(self.nrows(), x.nrows(), "Axpy: mismatched vector shapes.");
|
assert_eq!(self.nrows(), x.nrows(), "Axpy: mismatched vector shapes.");
|
||||||
|
self.axcpy(a, x, N::one(), b)
|
||||||
let rstride1 = self.strides().0;
|
|
||||||
let rstride2 = x.strides().0;
|
|
||||||
|
|
||||||
let y = self.data.as_mut_slice();
|
|
||||||
let x = x.data.as_slice();
|
|
||||||
|
|
||||||
if !b.is_zero() {
|
|
||||||
array_axpy(y, a, x, b, rstride1, rstride2, x.len());
|
|
||||||
} else {
|
|
||||||
array_ax(y, a, x, rstride1, rstride2, x.len());
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Computes `self = alpha * a * x + beta * self`, where `a` is a matrix, `x` a vector, and
|
/// Computes `self = alpha * a * x + beta * self`, where `a` is a matrix, `x` a vector, and
|
||||||
@ -579,13 +603,13 @@ where
|
|||||||
// FIXME: avoid bound checks.
|
// FIXME: avoid bound checks.
|
||||||
let col2 = a.column(0);
|
let col2 = a.column(0);
|
||||||
let val = unsafe { *x.vget_unchecked(0) };
|
let val = unsafe { *x.vget_unchecked(0) };
|
||||||
self.axpy(alpha * val, &col2, beta);
|
self.axcpy(alpha, &col2, val, beta);
|
||||||
|
|
||||||
for j in 1..ncols2 {
|
for j in 1..ncols2 {
|
||||||
let col2 = a.column(j);
|
let col2 = a.column(j);
|
||||||
let val = unsafe { *x.vget_unchecked(j) };
|
let val = unsafe { *x.vget_unchecked(j) };
|
||||||
|
|
||||||
self.axpy(alpha * val, &col2, N::one());
|
self.axcpy(alpha, &col2, val, N::one());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1,105 +1,129 @@
|
|||||||
#![cfg(feature = "arbitrary")]
|
use na::{geometry::Quaternion, Matrix2, Vector3};
|
||||||
|
use num_traits::{One, Zero};
|
||||||
|
|
||||||
use na::{DMatrix, DVector};
|
#[test]
|
||||||
use std::cmp;
|
fn gemm_noncommutative() {
|
||||||
|
type Qf64 = Quaternion<f64>;
|
||||||
|
let i = Qf64::from_imag(Vector3::new(1.0, 0.0, 0.0));
|
||||||
|
let j = Qf64::from_imag(Vector3::new(0.0, 1.0, 0.0));
|
||||||
|
let k = Qf64::from_imag(Vector3::new(0.0, 0.0, 1.0));
|
||||||
|
|
||||||
quickcheck! {
|
let m1 = Matrix2::new(k, Qf64::zero(), j, i);
|
||||||
/*
|
// this is the inverse of m1
|
||||||
*
|
let m2 = Matrix2::new(-k, Qf64::zero(), Qf64::one(), -i);
|
||||||
* Symmetric operators.
|
|
||||||
*
|
|
||||||
*/
|
|
||||||
fn gemv_symm(n: usize, alpha: f64, beta: f64) -> bool {
|
|
||||||
let n = cmp::max(1, cmp::min(n, 50));
|
|
||||||
let a = DMatrix::<f64>::new_random(n, n);
|
|
||||||
let a = &a * a.transpose();
|
|
||||||
|
|
||||||
let x = DVector::new_random(n);
|
let mut res: Matrix2<Qf64> = Matrix2::zero();
|
||||||
let mut y1 = DVector::new_random(n);
|
res.gemm(Qf64::one(), &m1, &m2, Qf64::zero());
|
||||||
let mut y2 = y1.clone();
|
assert_eq!(res, Matrix2::identity());
|
||||||
|
|
||||||
y1.gemv(alpha, &a, &x, beta);
|
let mut res: Matrix2<Qf64> = Matrix2::identity();
|
||||||
y2.sygemv(alpha, &a.lower_triangle(), &x, beta);
|
res.gemm(k, &m1, &m2, -k);
|
||||||
|
assert_eq!(res, Matrix2::zero());
|
||||||
|
}
|
||||||
|
|
||||||
if !relative_eq!(y1, y2, epsilon = 1.0e-10) {
|
#[cfg(feature = "arbitrary")]
|
||||||
return false;
|
mod blas_quickcheck {
|
||||||
|
use na::{DMatrix, DVector};
|
||||||
|
use std::cmp;
|
||||||
|
|
||||||
|
quickcheck! {
|
||||||
|
/*
|
||||||
|
*
|
||||||
|
* Symmetric operators.
|
||||||
|
*
|
||||||
|
*/
|
||||||
|
fn gemv_symm(n: usize, alpha: f64, beta: f64) -> bool {
|
||||||
|
let n = cmp::max(1, cmp::min(n, 50));
|
||||||
|
let a = DMatrix::<f64>::new_random(n, n);
|
||||||
|
let a = &a * a.transpose();
|
||||||
|
|
||||||
|
let x = DVector::new_random(n);
|
||||||
|
let mut y1 = DVector::new_random(n);
|
||||||
|
let mut y2 = y1.clone();
|
||||||
|
|
||||||
|
y1.gemv(alpha, &a, &x, beta);
|
||||||
|
y2.sygemv(alpha, &a.lower_triangle(), &x, beta);
|
||||||
|
|
||||||
|
if !relative_eq!(y1, y2, epsilon = 1.0e-10) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
y1.gemv(alpha, &a, &x, 0.0);
|
||||||
|
y2.sygemv(alpha, &a.lower_triangle(), &x, 0.0);
|
||||||
|
|
||||||
|
relative_eq!(y1, y2, epsilon = 1.0e-10)
|
||||||
}
|
}
|
||||||
|
|
||||||
y1.gemv(alpha, &a, &x, 0.0);
|
fn gemv_tr(n: usize, alpha: f64, beta: f64) -> bool {
|
||||||
y2.sygemv(alpha, &a.lower_triangle(), &x, 0.0);
|
let n = cmp::max(1, cmp::min(n, 50));
|
||||||
|
let a = DMatrix::<f64>::new_random(n, n);
|
||||||
|
let x = DVector::new_random(n);
|
||||||
|
let mut y1 = DVector::new_random(n);
|
||||||
|
let mut y2 = y1.clone();
|
||||||
|
|
||||||
relative_eq!(y1, y2, epsilon = 1.0e-10)
|
y1.gemv(alpha, &a, &x, beta);
|
||||||
}
|
y2.gemv_tr(alpha, &a.transpose(), &x, beta);
|
||||||
|
|
||||||
fn gemv_tr(n: usize, alpha: f64, beta: f64) -> bool {
|
if !relative_eq!(y1, y2, epsilon = 1.0e-10) {
|
||||||
let n = cmp::max(1, cmp::min(n, 50));
|
return false;
|
||||||
let a = DMatrix::<f64>::new_random(n, n);
|
}
|
||||||
let x = DVector::new_random(n);
|
|
||||||
let mut y1 = DVector::new_random(n);
|
|
||||||
let mut y2 = y1.clone();
|
|
||||||
|
|
||||||
y1.gemv(alpha, &a, &x, beta);
|
y1.gemv(alpha, &a, &x, 0.0);
|
||||||
y2.gemv_tr(alpha, &a.transpose(), &x, beta);
|
y2.gemv_tr(alpha, &a.transpose(), &x, 0.0);
|
||||||
|
|
||||||
if !relative_eq!(y1, y2, epsilon = 1.0e-10) {
|
relative_eq!(y1, y2, epsilon = 1.0e-10)
|
||||||
return false;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
y1.gemv(alpha, &a, &x, 0.0);
|
fn ger_symm(n: usize, alpha: f64, beta: f64) -> bool {
|
||||||
y2.gemv_tr(alpha, &a.transpose(), &x, 0.0);
|
let n = cmp::max(1, cmp::min(n, 50));
|
||||||
|
let a = DMatrix::<f64>::new_random(n, n);
|
||||||
|
let mut a1 = &a * a.transpose();
|
||||||
|
let mut a2 = a1.lower_triangle();
|
||||||
|
|
||||||
relative_eq!(y1, y2, epsilon = 1.0e-10)
|
let x = DVector::new_random(n);
|
||||||
}
|
let y = DVector::new_random(n);
|
||||||
|
|
||||||
fn ger_symm(n: usize, alpha: f64, beta: f64) -> bool {
|
a1.ger(alpha, &x, &y, beta);
|
||||||
let n = cmp::max(1, cmp::min(n, 50));
|
a2.syger(alpha, &x, &y, beta);
|
||||||
let a = DMatrix::<f64>::new_random(n, n);
|
|
||||||
let mut a1 = &a * a.transpose();
|
|
||||||
let mut a2 = a1.lower_triangle();
|
|
||||||
|
|
||||||
let x = DVector::new_random(n);
|
if !relative_eq!(a1.lower_triangle(), a2) {
|
||||||
let y = DVector::new_random(n);
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
a1.ger(alpha, &x, &y, beta);
|
a1.ger(alpha, &x, &y, 0.0);
|
||||||
a2.syger(alpha, &x, &y, beta);
|
a2.syger(alpha, &x, &y, 0.0);
|
||||||
|
|
||||||
if !relative_eq!(a1.lower_triangle(), a2) {
|
relative_eq!(a1.lower_triangle(), a2)
|
||||||
return false;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
a1.ger(alpha, &x, &y, 0.0);
|
fn quadform(n: usize, alpha: f64, beta: f64) -> bool {
|
||||||
a2.syger(alpha, &x, &y, 0.0);
|
let n = cmp::max(1, cmp::min(n, 50));
|
||||||
|
let rhs = DMatrix::<f64>::new_random(6, n);
|
||||||
|
let mid = DMatrix::<f64>::new_random(6, 6);
|
||||||
|
let mut res = DMatrix::new_random(n, n);
|
||||||
|
|
||||||
relative_eq!(a1.lower_triangle(), a2)
|
let expected = &res * beta + rhs.transpose() * &mid * &rhs * alpha;
|
||||||
}
|
|
||||||
|
|
||||||
fn quadform(n: usize, alpha: f64, beta: f64) -> bool {
|
res.quadform(alpha, &mid, &rhs, beta);
|
||||||
let n = cmp::max(1, cmp::min(n, 50));
|
|
||||||
let rhs = DMatrix::<f64>::new_random(6, n);
|
|
||||||
let mid = DMatrix::<f64>::new_random(6, 6);
|
|
||||||
let mut res = DMatrix::new_random(n, n);
|
|
||||||
|
|
||||||
let expected = &res * beta + rhs.transpose() * &mid * &rhs * alpha;
|
println!("{}{}", res, expected);
|
||||||
|
|
||||||
res.quadform(alpha, &mid, &rhs, beta);
|
relative_eq!(res, expected, epsilon = 1.0e-7)
|
||||||
|
}
|
||||||
|
|
||||||
println!("{}{}", res, expected);
|
fn quadform_tr(n: usize, alpha: f64, beta: f64) -> bool {
|
||||||
|
let n = cmp::max(1, cmp::min(n, 50));
|
||||||
|
let lhs = DMatrix::<f64>::new_random(6, n);
|
||||||
|
let mid = DMatrix::<f64>::new_random(n, n);
|
||||||
|
let mut res = DMatrix::new_random(6, 6);
|
||||||
|
|
||||||
relative_eq!(res, expected, epsilon = 1.0e-7)
|
let expected = &res * beta + &lhs * &mid * lhs.transpose() * alpha;
|
||||||
}
|
|
||||||
|
|
||||||
fn quadform_tr(n: usize, alpha: f64, beta: f64) -> bool {
|
res.quadform_tr(alpha, &lhs, &mid , beta);
|
||||||
let n = cmp::max(1, cmp::min(n, 50));
|
|
||||||
let lhs = DMatrix::<f64>::new_random(6, n);
|
|
||||||
let mid = DMatrix::<f64>::new_random(n, n);
|
|
||||||
let mut res = DMatrix::new_random(6, 6);
|
|
||||||
|
|
||||||
let expected = &res * beta + &lhs * &mid * lhs.transpose() * alpha;
|
println!("{}{}", res, expected);
|
||||||
|
|
||||||
res.quadform_tr(alpha, &lhs, &mid , beta);
|
relative_eq!(res, expected, epsilon = 1.0e-7)
|
||||||
|
}
|
||||||
println!("{}{}", res, expected);
|
|
||||||
|
|
||||||
relative_eq!(res, expected, epsilon = 1.0e-7)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user