Fix Vector::axpy for noncommutative cases

One example would be performing simple matrix multiplication
over a division algebra such as quaternions.
This commit is contained in:
Jakub Konka 2019-09-04 16:02:31 +02:00
parent bde8fbe10f
commit d0fa79f6e1
2 changed files with 106 additions and 82 deletions

View File

@ -473,7 +473,7 @@ where N: Scalar + Zero + ClosedAdd + ClosedMul {
for i in 0..len { for i in 0..len {
unsafe { unsafe {
let y = y.get_unchecked_mut(i * stride1); let y = y.get_unchecked_mut(i * stride1);
*y = a * *x.get_unchecked(i * stride2) + beta * *y; *y = *x.get_unchecked(i * stride2) * a + *y * beta;
} }
} }
} }
@ -482,7 +482,7 @@ fn array_ax<N>(y: &mut [N], a: N, x: &[N], stride1: usize, stride2: usize, len:
where N: Scalar + Zero + ClosedAdd + ClosedMul { where N: Scalar + Zero + ClosedAdd + ClosedMul {
for i in 0..len { for i in 0..len {
unsafe { unsafe {
*y.get_unchecked_mut(i * stride1) = a * *x.get_unchecked(i * stride2); *y.get_unchecked_mut(i * stride1) = *x.get_unchecked(i * stride2) * a;
} }
} }
} }
@ -579,13 +579,13 @@ where
// FIXME: avoid bound checks. // FIXME: avoid bound checks.
let col2 = a.column(0); let col2 = a.column(0);
let val = unsafe { *x.vget_unchecked(0) }; let val = unsafe { *x.vget_unchecked(0) };
self.axpy(alpha * val, &col2, beta); self.axpy(val * alpha, &col2, beta);
for j in 1..ncols2 { for j in 1..ncols2 {
let col2 = a.column(j); let col2 = a.column(j);
let val = unsafe { *x.vget_unchecked(j) }; let val = unsafe { *x.vget_unchecked(j) };
self.axpy(alpha * val, &col2, N::one()); self.axpy(val * alpha, &col2, N::one());
} }
} }
@ -624,7 +624,7 @@ where
// FIXME: avoid bound checks. // FIXME: avoid bound checks.
let col2 = a.column(0); let col2 = a.column(0);
let val = unsafe { *x.vget_unchecked(0) }; let val = unsafe { *x.vget_unchecked(0) };
self.axpy(alpha * val, &col2, beta); self.axpy(val * alpha, &col2, beta);
self[0] += alpha * dot(&a.slice_range(1.., 0), &x.rows_range(1..)); self[0] += alpha * dot(&a.slice_range(1.., 0), &x.rows_range(1..));
for j in 1..dim2 { for j in 1..dim2 {
@ -637,7 +637,7 @@ where
*self.vget_unchecked_mut(j) += alpha * dot; *self.vget_unchecked_mut(j) += alpha * dot;
} }
self.rows_range_mut(j + 1..) self.rows_range_mut(j + 1..)
.axpy(alpha * val, &col2.rows_range(j + 1..), N::one()); .axpy(val * alpha, &col2.rows_range(j + 1..), N::one());
} }
} }
@ -890,7 +890,7 @@ where N: Scalar + Zero + ClosedAdd + ClosedMul
for j in 0..ncols1 { for j in 0..ncols1 {
// FIXME: avoid bound checks. // FIXME: avoid bound checks.
let val = unsafe { conjugate(*y.vget_unchecked(j)) }; let val = unsafe { conjugate(*y.vget_unchecked(j)) };
self.column_mut(j).axpy(alpha * val, x, beta); self.column_mut(j).axpy(val * alpha, x, beta);
} }
} }
@ -1256,7 +1256,7 @@ where N: Scalar + Zero + ClosedAdd + ClosedMul
let subdim = Dynamic::new(dim1 - j); let subdim = Dynamic::new(dim1 - j);
// FIXME: avoid bound checks. // FIXME: avoid bound checks.
self.generic_slice_mut((j, j), (subdim, U1)).axpy( self.generic_slice_mut((j, j), (subdim, U1)).axpy(
alpha * val, val * alpha,
&x.rows_range(j..), &x.rows_range(j..),
beta, beta,
); );

View File

@ -1,105 +1,129 @@
#![cfg(feature = "arbitrary")] use na::{geometry::Quaternion, Matrix2, Vector3};
use num_traits::{One, Zero};
use na::{DMatrix, DVector}; #[test]
use std::cmp; fn gemm_noncommutative() {
type Qf64 = Quaternion<f64>;
let i = Qf64::from_imag(Vector3::new(1.0, 0.0, 0.0));
let j = Qf64::from_imag(Vector3::new(0.0, 1.0, 0.0));
let k = Qf64::from_imag(Vector3::new(0.0, 0.0, 1.0));
quickcheck! { let m1 = Matrix2::new(k, Qf64::zero(), j, i);
/* // this is the inverse of m1
* let m2 = Matrix2::new(-k, Qf64::zero(), Qf64::one(), -i);
* Symmetric operators.
*
*/
fn gemv_symm(n: usize, alpha: f64, beta: f64) -> bool {
let n = cmp::max(1, cmp::min(n, 50));
let a = DMatrix::<f64>::new_random(n, n);
let a = &a * a.transpose();
let x = DVector::new_random(n); let mut res: Matrix2<Qf64> = Matrix2::zero();
let mut y1 = DVector::new_random(n); res.gemm(Qf64::one(), &m1, &m2, Qf64::zero());
let mut y2 = y1.clone(); assert_eq!(res, Matrix2::identity());
y1.gemv(alpha, &a, &x, beta); let mut res: Matrix2<Qf64> = Matrix2::identity();
y2.sygemv(alpha, &a.lower_triangle(), &x, beta); res.gemm(k, &m1, &m2, -k);
assert_eq!(res, Matrix2::zero());
}
if !relative_eq!(y1, y2, epsilon = 1.0e-10) { #[cfg(feature = "arbitrary")]
return false; mod blas_quickcheck {
use na::{DMatrix, DVector};
use std::cmp;
quickcheck! {
/*
*
* Symmetric operators.
*
*/
fn gemv_symm(n: usize, alpha: f64, beta: f64) -> bool {
let n = cmp::max(1, cmp::min(n, 50));
let a = DMatrix::<f64>::new_random(n, n);
let a = &a * a.transpose();
let x = DVector::new_random(n);
let mut y1 = DVector::new_random(n);
let mut y2 = y1.clone();
y1.gemv(alpha, &a, &x, beta);
y2.sygemv(alpha, &a.lower_triangle(), &x, beta);
if !relative_eq!(y1, y2, epsilon = 1.0e-10) {
return false;
}
y1.gemv(alpha, &a, &x, 0.0);
y2.sygemv(alpha, &a.lower_triangle(), &x, 0.0);
relative_eq!(y1, y2, epsilon = 1.0e-10)
} }
y1.gemv(alpha, &a, &x, 0.0); fn gemv_tr(n: usize, alpha: f64, beta: f64) -> bool {
y2.sygemv(alpha, &a.lower_triangle(), &x, 0.0); let n = cmp::max(1, cmp::min(n, 50));
let a = DMatrix::<f64>::new_random(n, n);
let x = DVector::new_random(n);
let mut y1 = DVector::new_random(n);
let mut y2 = y1.clone();
relative_eq!(y1, y2, epsilon = 1.0e-10) y1.gemv(alpha, &a, &x, beta);
} y2.gemv_tr(alpha, &a.transpose(), &x, beta);
fn gemv_tr(n: usize, alpha: f64, beta: f64) -> bool { if !relative_eq!(y1, y2, epsilon = 1.0e-10) {
let n = cmp::max(1, cmp::min(n, 50)); return false;
let a = DMatrix::<f64>::new_random(n, n); }
let x = DVector::new_random(n);
let mut y1 = DVector::new_random(n);
let mut y2 = y1.clone();
y1.gemv(alpha, &a, &x, beta); y1.gemv(alpha, &a, &x, 0.0);
y2.gemv_tr(alpha, &a.transpose(), &x, beta); y2.gemv_tr(alpha, &a.transpose(), &x, 0.0);
if !relative_eq!(y1, y2, epsilon = 1.0e-10) { relative_eq!(y1, y2, epsilon = 1.0e-10)
return false;
} }
y1.gemv(alpha, &a, &x, 0.0); fn ger_symm(n: usize, alpha: f64, beta: f64) -> bool {
y2.gemv_tr(alpha, &a.transpose(), &x, 0.0); let n = cmp::max(1, cmp::min(n, 50));
let a = DMatrix::<f64>::new_random(n, n);
let mut a1 = &a * a.transpose();
let mut a2 = a1.lower_triangle();
relative_eq!(y1, y2, epsilon = 1.0e-10) let x = DVector::new_random(n);
} let y = DVector::new_random(n);
fn ger_symm(n: usize, alpha: f64, beta: f64) -> bool { a1.ger(alpha, &x, &y, beta);
let n = cmp::max(1, cmp::min(n, 50)); a2.syger(alpha, &x, &y, beta);
let a = DMatrix::<f64>::new_random(n, n);
let mut a1 = &a * a.transpose();
let mut a2 = a1.lower_triangle();
let x = DVector::new_random(n); if !relative_eq!(a1.lower_triangle(), a2) {
let y = DVector::new_random(n); return false;
}
a1.ger(alpha, &x, &y, beta); a1.ger(alpha, &x, &y, 0.0);
a2.syger(alpha, &x, &y, beta); a2.syger(alpha, &x, &y, 0.0);
if !relative_eq!(a1.lower_triangle(), a2) { relative_eq!(a1.lower_triangle(), a2)
return false;
} }
a1.ger(alpha, &x, &y, 0.0); fn quadform(n: usize, alpha: f64, beta: f64) -> bool {
a2.syger(alpha, &x, &y, 0.0); let n = cmp::max(1, cmp::min(n, 50));
let rhs = DMatrix::<f64>::new_random(6, n);
let mid = DMatrix::<f64>::new_random(6, 6);
let mut res = DMatrix::new_random(n, n);
relative_eq!(a1.lower_triangle(), a2) let expected = &res * beta + rhs.transpose() * &mid * &rhs * alpha;
}
fn quadform(n: usize, alpha: f64, beta: f64) -> bool { res.quadform(alpha, &mid, &rhs, beta);
let n = cmp::max(1, cmp::min(n, 50));
let rhs = DMatrix::<f64>::new_random(6, n);
let mid = DMatrix::<f64>::new_random(6, 6);
let mut res = DMatrix::new_random(n, n);
let expected = &res * beta + rhs.transpose() * &mid * &rhs * alpha; println!("{}{}", res, expected);
res.quadform(alpha, &mid, &rhs, beta); relative_eq!(res, expected, epsilon = 1.0e-7)
}
println!("{}{}", res, expected); fn quadform_tr(n: usize, alpha: f64, beta: f64) -> bool {
let n = cmp::max(1, cmp::min(n, 50));
let lhs = DMatrix::<f64>::new_random(6, n);
let mid = DMatrix::<f64>::new_random(n, n);
let mut res = DMatrix::new_random(6, 6);
relative_eq!(res, expected, epsilon = 1.0e-7) let expected = &res * beta + &lhs * &mid * lhs.transpose() * alpha;
}
fn quadform_tr(n: usize, alpha: f64, beta: f64) -> bool { res.quadform_tr(alpha, &lhs, &mid , beta);
let n = cmp::max(1, cmp::min(n, 50));
let lhs = DMatrix::<f64>::new_random(6, n);
let mid = DMatrix::<f64>::new_random(n, n);
let mut res = DMatrix::new_random(6, 6);
let expected = &res * beta + &lhs * &mid * lhs.transpose() * alpha; println!("{}{}", res, expected);
res.quadform_tr(alpha, &lhs, &mid , beta); relative_eq!(res, expected, epsilon = 1.0e-7)
}
println!("{}{}", res, expected);
relative_eq!(res, expected, epsilon = 1.0e-7)
} }
} }