From e1c8e1bccfaca6c25fdfd534f72ea901f6cfe0e2 Mon Sep 17 00:00:00 2001 From: Jakub Konka Date: Wed, 4 Sep 2019 16:02:31 +0200 Subject: [PATCH] Fix Vector::axpy for noncommutative cases One example would be performing simple matrix multiplication over a division algebra such as quaternions. --- src/base/blas.rs | 16 ++--- tests/core/blas.rs | 172 ++++++++++++++++++++++++++------------------- 2 files changed, 106 insertions(+), 82 deletions(-) diff --git a/src/base/blas.rs b/src/base/blas.rs index cc8f2345..dec28d07 100644 --- a/src/base/blas.rs +++ b/src/base/blas.rs @@ -473,7 +473,7 @@ where N: Scalar + Zero + ClosedAdd + ClosedMul { for i in 0..len { unsafe { let y = y.get_unchecked_mut(i * stride1); - *y = a * *x.get_unchecked(i * stride2) + beta * *y; + *y = *x.get_unchecked(i * stride2) * a + *y * beta; } } } @@ -482,7 +482,7 @@ fn array_ax(y: &mut [N], a: N, x: &[N], stride1: usize, stride2: usize, len: where N: Scalar + Zero + ClosedAdd + ClosedMul { for i in 0..len { unsafe { - *y.get_unchecked_mut(i * stride1) = a * *x.get_unchecked(i * stride2); + *y.get_unchecked_mut(i * stride1) = *x.get_unchecked(i * stride2) * a; } } } @@ -579,13 +579,13 @@ where // FIXME: avoid bound checks. let col2 = a.column(0); let val = unsafe { *x.vget_unchecked(0) }; - self.axpy(alpha * val, &col2, beta); + self.axpy(val * alpha, &col2, beta); for j in 1..ncols2 { let col2 = a.column(j); let val = unsafe { *x.vget_unchecked(j) }; - self.axpy(alpha * val, &col2, N::one()); + self.axpy(val * alpha, &col2, N::one()); } } @@ -624,7 +624,7 @@ where // FIXME: avoid bound checks. let col2 = a.column(0); let val = unsafe { *x.vget_unchecked(0) }; - self.axpy(alpha * val, &col2, beta); + self.axpy(val * alpha, &col2, beta); self[0] += alpha * dot(&a.slice_range(1.., 0), &x.rows_range(1..)); for j in 1..dim2 { @@ -637,7 +637,7 @@ where *self.vget_unchecked_mut(j) += alpha * dot; } self.rows_range_mut(j + 1..) - .axpy(alpha * val, &col2.rows_range(j + 1..), N::one()); + .axpy(val * alpha, &col2.rows_range(j + 1..), N::one()); } } @@ -890,7 +890,7 @@ where N: Scalar + Zero + ClosedAdd + ClosedMul for j in 0..ncols1 { // FIXME: avoid bound checks. let val = unsafe { conjugate(*y.vget_unchecked(j)) }; - self.column_mut(j).axpy(alpha * val, x, beta); + self.column_mut(j).axpy(val * alpha, x, beta); } } @@ -1256,7 +1256,7 @@ where N: Scalar + Zero + ClosedAdd + ClosedMul let subdim = Dynamic::new(dim1 - j); // FIXME: avoid bound checks. self.generic_slice_mut((j, j), (subdim, U1)).axpy( - alpha * val, + val * alpha, &x.rows_range(j..), beta, ); diff --git a/tests/core/blas.rs b/tests/core/blas.rs index 38113c17..9b7be4af 100644 --- a/tests/core/blas.rs +++ b/tests/core/blas.rs @@ -1,105 +1,129 @@ -#![cfg(feature = "arbitrary")] +use na::{geometry::Quaternion, Matrix2, Vector3}; +use num_traits::{One, Zero}; -use na::{DMatrix, DVector}; -use std::cmp; +#[test] +fn gemm_noncommutative() { + type Qf64 = Quaternion; + let i = Qf64::from_imag(Vector3::new(1.0, 0.0, 0.0)); + let j = Qf64::from_imag(Vector3::new(0.0, 1.0, 0.0)); + let k = Qf64::from_imag(Vector3::new(0.0, 0.0, 1.0)); -quickcheck! { - /* - * - * Symmetric operators. - * - */ - fn gemv_symm(n: usize, alpha: f64, beta: f64) -> bool { - let n = cmp::max(1, cmp::min(n, 50)); - let a = DMatrix::::new_random(n, n); - let a = &a * a.transpose(); + let m1 = Matrix2::new(k, Qf64::zero(), j, i); + // this is the inverse of m1 + let m2 = Matrix2::new(-k, Qf64::zero(), Qf64::one(), -i); - let x = DVector::new_random(n); - let mut y1 = DVector::new_random(n); - let mut y2 = y1.clone(); + let mut res: Matrix2 = Matrix2::zero(); + res.gemm(Qf64::one(), &m1, &m2, Qf64::zero()); + assert_eq!(res, Matrix2::identity()); - y1.gemv(alpha, &a, &x, beta); - y2.sygemv(alpha, &a.lower_triangle(), &x, beta); + let mut res: Matrix2 = Matrix2::identity(); + res.gemm(k, &m1, &m2, -k); + assert_eq!(res, Matrix2::zero()); +} - if !relative_eq!(y1, y2, epsilon = 1.0e-10) { - return false; +#[cfg(feature = "arbitrary")] +mod blas_quickcheck { + use na::{DMatrix, DVector}; + use std::cmp; + + quickcheck! { + /* + * + * Symmetric operators. + * + */ + fn gemv_symm(n: usize, alpha: f64, beta: f64) -> bool { + let n = cmp::max(1, cmp::min(n, 50)); + let a = DMatrix::::new_random(n, n); + let a = &a * a.transpose(); + + let x = DVector::new_random(n); + let mut y1 = DVector::new_random(n); + let mut y2 = y1.clone(); + + y1.gemv(alpha, &a, &x, beta); + y2.sygemv(alpha, &a.lower_triangle(), &x, beta); + + if !relative_eq!(y1, y2, epsilon = 1.0e-10) { + return false; + } + + y1.gemv(alpha, &a, &x, 0.0); + y2.sygemv(alpha, &a.lower_triangle(), &x, 0.0); + + relative_eq!(y1, y2, epsilon = 1.0e-10) } - y1.gemv(alpha, &a, &x, 0.0); - y2.sygemv(alpha, &a.lower_triangle(), &x, 0.0); + fn gemv_tr(n: usize, alpha: f64, beta: f64) -> bool { + let n = cmp::max(1, cmp::min(n, 50)); + let a = DMatrix::::new_random(n, n); + let x = DVector::new_random(n); + let mut y1 = DVector::new_random(n); + let mut y2 = y1.clone(); - relative_eq!(y1, y2, epsilon = 1.0e-10) - } + y1.gemv(alpha, &a, &x, beta); + y2.gemv_tr(alpha, &a.transpose(), &x, beta); - fn gemv_tr(n: usize, alpha: f64, beta: f64) -> bool { - let n = cmp::max(1, cmp::min(n, 50)); - let a = DMatrix::::new_random(n, n); - let x = DVector::new_random(n); - let mut y1 = DVector::new_random(n); - let mut y2 = y1.clone(); + if !relative_eq!(y1, y2, epsilon = 1.0e-10) { + return false; + } - y1.gemv(alpha, &a, &x, beta); - y2.gemv_tr(alpha, &a.transpose(), &x, beta); + y1.gemv(alpha, &a, &x, 0.0); + y2.gemv_tr(alpha, &a.transpose(), &x, 0.0); - if !relative_eq!(y1, y2, epsilon = 1.0e-10) { - return false; + relative_eq!(y1, y2, epsilon = 1.0e-10) } - y1.gemv(alpha, &a, &x, 0.0); - y2.gemv_tr(alpha, &a.transpose(), &x, 0.0); + fn ger_symm(n: usize, alpha: f64, beta: f64) -> bool { + let n = cmp::max(1, cmp::min(n, 50)); + let a = DMatrix::::new_random(n, n); + let mut a1 = &a * a.transpose(); + let mut a2 = a1.lower_triangle(); - relative_eq!(y1, y2, epsilon = 1.0e-10) - } + let x = DVector::new_random(n); + let y = DVector::new_random(n); - fn ger_symm(n: usize, alpha: f64, beta: f64) -> bool { - let n = cmp::max(1, cmp::min(n, 50)); - let a = DMatrix::::new_random(n, n); - let mut a1 = &a * a.transpose(); - let mut a2 = a1.lower_triangle(); + a1.ger(alpha, &x, &y, beta); + a2.syger(alpha, &x, &y, beta); - let x = DVector::new_random(n); - let y = DVector::new_random(n); + if !relative_eq!(a1.lower_triangle(), a2) { + return false; + } - a1.ger(alpha, &x, &y, beta); - a2.syger(alpha, &x, &y, beta); + a1.ger(alpha, &x, &y, 0.0); + a2.syger(alpha, &x, &y, 0.0); - if !relative_eq!(a1.lower_triangle(), a2) { - return false; + relative_eq!(a1.lower_triangle(), a2) } - a1.ger(alpha, &x, &y, 0.0); - a2.syger(alpha, &x, &y, 0.0); + fn quadform(n: usize, alpha: f64, beta: f64) -> bool { + let n = cmp::max(1, cmp::min(n, 50)); + let rhs = DMatrix::::new_random(6, n); + let mid = DMatrix::::new_random(6, 6); + let mut res = DMatrix::new_random(n, n); - relative_eq!(a1.lower_triangle(), a2) - } + let expected = &res * beta + rhs.transpose() * &mid * &rhs * alpha; - fn quadform(n: usize, alpha: f64, beta: f64) -> bool { - let n = cmp::max(1, cmp::min(n, 50)); - let rhs = DMatrix::::new_random(6, n); - let mid = DMatrix::::new_random(6, 6); - let mut res = DMatrix::new_random(n, n); + res.quadform(alpha, &mid, &rhs, beta); - let expected = &res * beta + rhs.transpose() * &mid * &rhs * alpha; + println!("{}{}", res, expected); - res.quadform(alpha, &mid, &rhs, beta); + relative_eq!(res, expected, epsilon = 1.0e-7) + } - println!("{}{}", res, expected); + fn quadform_tr(n: usize, alpha: f64, beta: f64) -> bool { + let n = cmp::max(1, cmp::min(n, 50)); + let lhs = DMatrix::::new_random(6, n); + let mid = DMatrix::::new_random(n, n); + let mut res = DMatrix::new_random(6, 6); - relative_eq!(res, expected, epsilon = 1.0e-7) - } + let expected = &res * beta + &lhs * &mid * lhs.transpose() * alpha; - fn quadform_tr(n: usize, alpha: f64, beta: f64) -> bool { - let n = cmp::max(1, cmp::min(n, 50)); - let lhs = DMatrix::::new_random(6, n); - let mid = DMatrix::::new_random(n, n); - let mut res = DMatrix::new_random(6, 6); + res.quadform_tr(alpha, &lhs, &mid , beta); - let expected = &res * beta + &lhs * &mid * lhs.transpose() * alpha; + println!("{}{}", res, expected); - res.quadform_tr(alpha, &lhs, &mid , beta); - - println!("{}{}", res, expected); - - relative_eq!(res, expected, epsilon = 1.0e-7) + relative_eq!(res, expected, epsilon = 1.0e-7) + } } }