From e1c8e1bccfaca6c25fdfd534f72ea901f6cfe0e2 Mon Sep 17 00:00:00 2001 From: Jakub Konka Date: Wed, 4 Sep 2019 16:02:31 +0200 Subject: [PATCH 1/2] Fix Vector::axpy for noncommutative cases One example would be performing simple matrix multiplication over a division algebra such as quaternions. --- src/base/blas.rs | 16 ++--- tests/core/blas.rs | 172 ++++++++++++++++++++++++++------------------- 2 files changed, 106 insertions(+), 82 deletions(-) diff --git a/src/base/blas.rs b/src/base/blas.rs index cc8f2345..dec28d07 100644 --- a/src/base/blas.rs +++ b/src/base/blas.rs @@ -473,7 +473,7 @@ where N: Scalar + Zero + ClosedAdd + ClosedMul { for i in 0..len { unsafe { let y = y.get_unchecked_mut(i * stride1); - *y = a * *x.get_unchecked(i * stride2) + beta * *y; + *y = *x.get_unchecked(i * stride2) * a + *y * beta; } } } @@ -482,7 +482,7 @@ fn array_ax(y: &mut [N], a: N, x: &[N], stride1: usize, stride2: usize, len: where N: Scalar + Zero + ClosedAdd + ClosedMul { for i in 0..len { unsafe { - *y.get_unchecked_mut(i * stride1) = a * *x.get_unchecked(i * stride2); + *y.get_unchecked_mut(i * stride1) = *x.get_unchecked(i * stride2) * a; } } } @@ -579,13 +579,13 @@ where // FIXME: avoid bound checks. let col2 = a.column(0); let val = unsafe { *x.vget_unchecked(0) }; - self.axpy(alpha * val, &col2, beta); + self.axpy(val * alpha, &col2, beta); for j in 1..ncols2 { let col2 = a.column(j); let val = unsafe { *x.vget_unchecked(j) }; - self.axpy(alpha * val, &col2, N::one()); + self.axpy(val * alpha, &col2, N::one()); } } @@ -624,7 +624,7 @@ where // FIXME: avoid bound checks. let col2 = a.column(0); let val = unsafe { *x.vget_unchecked(0) }; - self.axpy(alpha * val, &col2, beta); + self.axpy(val * alpha, &col2, beta); self[0] += alpha * dot(&a.slice_range(1.., 0), &x.rows_range(1..)); for j in 1..dim2 { @@ -637,7 +637,7 @@ where *self.vget_unchecked_mut(j) += alpha * dot; } self.rows_range_mut(j + 1..) - .axpy(alpha * val, &col2.rows_range(j + 1..), N::one()); + .axpy(val * alpha, &col2.rows_range(j + 1..), N::one()); } } @@ -890,7 +890,7 @@ where N: Scalar + Zero + ClosedAdd + ClosedMul for j in 0..ncols1 { // FIXME: avoid bound checks. let val = unsafe { conjugate(*y.vget_unchecked(j)) }; - self.column_mut(j).axpy(alpha * val, x, beta); + self.column_mut(j).axpy(val * alpha, x, beta); } } @@ -1256,7 +1256,7 @@ where N: Scalar + Zero + ClosedAdd + ClosedMul let subdim = Dynamic::new(dim1 - j); // FIXME: avoid bound checks. self.generic_slice_mut((j, j), (subdim, U1)).axpy( - alpha * val, + val * alpha, &x.rows_range(j..), beta, ); diff --git a/tests/core/blas.rs b/tests/core/blas.rs index 38113c17..9b7be4af 100644 --- a/tests/core/blas.rs +++ b/tests/core/blas.rs @@ -1,105 +1,129 @@ -#![cfg(feature = "arbitrary")] +use na::{geometry::Quaternion, Matrix2, Vector3}; +use num_traits::{One, Zero}; -use na::{DMatrix, DVector}; -use std::cmp; +#[test] +fn gemm_noncommutative() { + type Qf64 = Quaternion; + let i = Qf64::from_imag(Vector3::new(1.0, 0.0, 0.0)); + let j = Qf64::from_imag(Vector3::new(0.0, 1.0, 0.0)); + let k = Qf64::from_imag(Vector3::new(0.0, 0.0, 1.0)); -quickcheck! { - /* - * - * Symmetric operators. - * - */ - fn gemv_symm(n: usize, alpha: f64, beta: f64) -> bool { - let n = cmp::max(1, cmp::min(n, 50)); - let a = DMatrix::::new_random(n, n); - let a = &a * a.transpose(); + let m1 = Matrix2::new(k, Qf64::zero(), j, i); + // this is the inverse of m1 + let m2 = Matrix2::new(-k, Qf64::zero(), Qf64::one(), -i); - let x = DVector::new_random(n); - let mut y1 = DVector::new_random(n); - let mut y2 = y1.clone(); + let mut res: Matrix2 = Matrix2::zero(); + res.gemm(Qf64::one(), &m1, &m2, Qf64::zero()); + assert_eq!(res, Matrix2::identity()); - y1.gemv(alpha, &a, &x, beta); - y2.sygemv(alpha, &a.lower_triangle(), &x, beta); + let mut res: Matrix2 = Matrix2::identity(); + res.gemm(k, &m1, &m2, -k); + assert_eq!(res, Matrix2::zero()); +} - if !relative_eq!(y1, y2, epsilon = 1.0e-10) { - return false; +#[cfg(feature = "arbitrary")] +mod blas_quickcheck { + use na::{DMatrix, DVector}; + use std::cmp; + + quickcheck! { + /* + * + * Symmetric operators. + * + */ + fn gemv_symm(n: usize, alpha: f64, beta: f64) -> bool { + let n = cmp::max(1, cmp::min(n, 50)); + let a = DMatrix::::new_random(n, n); + let a = &a * a.transpose(); + + let x = DVector::new_random(n); + let mut y1 = DVector::new_random(n); + let mut y2 = y1.clone(); + + y1.gemv(alpha, &a, &x, beta); + y2.sygemv(alpha, &a.lower_triangle(), &x, beta); + + if !relative_eq!(y1, y2, epsilon = 1.0e-10) { + return false; + } + + y1.gemv(alpha, &a, &x, 0.0); + y2.sygemv(alpha, &a.lower_triangle(), &x, 0.0); + + relative_eq!(y1, y2, epsilon = 1.0e-10) } - y1.gemv(alpha, &a, &x, 0.0); - y2.sygemv(alpha, &a.lower_triangle(), &x, 0.0); + fn gemv_tr(n: usize, alpha: f64, beta: f64) -> bool { + let n = cmp::max(1, cmp::min(n, 50)); + let a = DMatrix::::new_random(n, n); + let x = DVector::new_random(n); + let mut y1 = DVector::new_random(n); + let mut y2 = y1.clone(); - relative_eq!(y1, y2, epsilon = 1.0e-10) - } + y1.gemv(alpha, &a, &x, beta); + y2.gemv_tr(alpha, &a.transpose(), &x, beta); - fn gemv_tr(n: usize, alpha: f64, beta: f64) -> bool { - let n = cmp::max(1, cmp::min(n, 50)); - let a = DMatrix::::new_random(n, n); - let x = DVector::new_random(n); - let mut y1 = DVector::new_random(n); - let mut y2 = y1.clone(); + if !relative_eq!(y1, y2, epsilon = 1.0e-10) { + return false; + } - y1.gemv(alpha, &a, &x, beta); - y2.gemv_tr(alpha, &a.transpose(), &x, beta); + y1.gemv(alpha, &a, &x, 0.0); + y2.gemv_tr(alpha, &a.transpose(), &x, 0.0); - if !relative_eq!(y1, y2, epsilon = 1.0e-10) { - return false; + relative_eq!(y1, y2, epsilon = 1.0e-10) } - y1.gemv(alpha, &a, &x, 0.0); - y2.gemv_tr(alpha, &a.transpose(), &x, 0.0); + fn ger_symm(n: usize, alpha: f64, beta: f64) -> bool { + let n = cmp::max(1, cmp::min(n, 50)); + let a = DMatrix::::new_random(n, n); + let mut a1 = &a * a.transpose(); + let mut a2 = a1.lower_triangle(); - relative_eq!(y1, y2, epsilon = 1.0e-10) - } + let x = DVector::new_random(n); + let y = DVector::new_random(n); - fn ger_symm(n: usize, alpha: f64, beta: f64) -> bool { - let n = cmp::max(1, cmp::min(n, 50)); - let a = DMatrix::::new_random(n, n); - let mut a1 = &a * a.transpose(); - let mut a2 = a1.lower_triangle(); + a1.ger(alpha, &x, &y, beta); + a2.syger(alpha, &x, &y, beta); - let x = DVector::new_random(n); - let y = DVector::new_random(n); + if !relative_eq!(a1.lower_triangle(), a2) { + return false; + } - a1.ger(alpha, &x, &y, beta); - a2.syger(alpha, &x, &y, beta); + a1.ger(alpha, &x, &y, 0.0); + a2.syger(alpha, &x, &y, 0.0); - if !relative_eq!(a1.lower_triangle(), a2) { - return false; + relative_eq!(a1.lower_triangle(), a2) } - a1.ger(alpha, &x, &y, 0.0); - a2.syger(alpha, &x, &y, 0.0); + fn quadform(n: usize, alpha: f64, beta: f64) -> bool { + let n = cmp::max(1, cmp::min(n, 50)); + let rhs = DMatrix::::new_random(6, n); + let mid = DMatrix::::new_random(6, 6); + let mut res = DMatrix::new_random(n, n); - relative_eq!(a1.lower_triangle(), a2) - } + let expected = &res * beta + rhs.transpose() * &mid * &rhs * alpha; - fn quadform(n: usize, alpha: f64, beta: f64) -> bool { - let n = cmp::max(1, cmp::min(n, 50)); - let rhs = DMatrix::::new_random(6, n); - let mid = DMatrix::::new_random(6, 6); - let mut res = DMatrix::new_random(n, n); + res.quadform(alpha, &mid, &rhs, beta); - let expected = &res * beta + rhs.transpose() * &mid * &rhs * alpha; + println!("{}{}", res, expected); - res.quadform(alpha, &mid, &rhs, beta); + relative_eq!(res, expected, epsilon = 1.0e-7) + } - println!("{}{}", res, expected); + fn quadform_tr(n: usize, alpha: f64, beta: f64) -> bool { + let n = cmp::max(1, cmp::min(n, 50)); + let lhs = DMatrix::::new_random(6, n); + let mid = DMatrix::::new_random(n, n); + let mut res = DMatrix::new_random(6, 6); - relative_eq!(res, expected, epsilon = 1.0e-7) - } + let expected = &res * beta + &lhs * &mid * lhs.transpose() * alpha; - fn quadform_tr(n: usize, alpha: f64, beta: f64) -> bool { - let n = cmp::max(1, cmp::min(n, 50)); - let lhs = DMatrix::::new_random(6, n); - let mid = DMatrix::::new_random(n, n); - let mut res = DMatrix::new_random(6, 6); + res.quadform_tr(alpha, &lhs, &mid , beta); - let expected = &res * beta + &lhs * &mid * lhs.transpose() * alpha; + println!("{}{}", res, expected); - res.quadform_tr(alpha, &lhs, &mid , beta); - - println!("{}{}", res, expected); - - relative_eq!(res, expected, epsilon = 1.0e-7) + relative_eq!(res, expected, epsilon = 1.0e-7) + } } } From fe65b1c129a16dfff13a1541f590260fcc939251 Mon Sep 17 00:00:00 2001 From: Jakub Konka Date: Mon, 4 Nov 2019 10:27:57 +0100 Subject: [PATCH 2/2] Add Vector::axcpy method The added method `Vector::axcpy` generalises `Vector::gemv` to noncommutative cases since it allows us to write for `gemv` `self.axcpy(alpha, &col2, val, beta)`, instead the usual `self.axpy(alpha * val, &col2, beta)`. Hence, `axcpy` preserves the order of scalar multiplication which is important for applications where commutativity is not guaranteed (e.g., matrices of quaternions, etc.). This commmit also removes helpers `array_axpy` and `array_ax`, and replaces them with `array_axcpy` and `array_axc` respectively, which like above preserve the order of scalar multiplication. Finally, `Vector::axpy` is preserved, however, now expressed in terms of `Vector::axcpy` like so: ``` self.axcpy(alpha * val, &col2, beta) ``` --- src/base/blas.rs | 68 ++++++++++++++++++++++++++++++++---------------- 1 file changed, 46 insertions(+), 22 deletions(-) diff --git a/src/base/blas.rs b/src/base/blas.rs index dec28d07..622761fe 100644 --- a/src/base/blas.rs +++ b/src/base/blas.rs @@ -468,21 +468,21 @@ where N: Scalar + Zero + ClosedAdd + ClosedMul } } -fn array_axpy(y: &mut [N], a: N, x: &[N], beta: N, stride1: usize, stride2: usize, len: usize) +fn array_axcpy(y: &mut [N], a: N, x: &[N], c: N, beta: N, stride1: usize, stride2: usize, len: usize) where N: Scalar + Zero + ClosedAdd + ClosedMul { for i in 0..len { unsafe { let y = y.get_unchecked_mut(i * stride1); - *y = *x.get_unchecked(i * stride2) * a + *y * beta; + *y = a * *x.get_unchecked(i * stride2) * c + beta * *y; } } } -fn array_ax(y: &mut [N], a: N, x: &[N], stride1: usize, stride2: usize, len: usize) +fn array_axc(y: &mut [N], a: N, x: &[N], c: N, stride1: usize, stride2: usize, len: usize) where N: Scalar + Zero + ClosedAdd + ClosedMul { for i in 0..len { unsafe { - *y.get_unchecked_mut(i * stride1) = *x.get_unchecked(i * stride2) * a; + *y.get_unchecked_mut(i * stride1) = a * *x.get_unchecked(i * stride2) * c; } } } @@ -492,6 +492,40 @@ where N: Scalar + Zero + ClosedAdd + ClosedMul, S: StorageMut, { + /// Computes `self = a * x * c + b * self`. + /// + /// If `b` is zero, `self` is never read from. + /// + /// # Examples: + /// + /// ``` + /// # use nalgebra::Vector3; + /// let mut vec1 = Vector3::new(1.0, 2.0, 3.0); + /// let vec2 = Vector3::new(0.1, 0.2, 0.3); + /// vec1.axcpy(5.0, &vec2, 2.0, 5.0); + /// assert_eq!(vec1, Vector3::new(6.0, 12.0, 18.0)); + /// ``` + #[inline] + pub fn axcpy(&mut self, a: N, x: &Vector, c: N, b: N) + where + SB: Storage, + ShapeConstraint: DimEq, + { + assert_eq!(self.nrows(), x.nrows(), "Axcpy: mismatched vector shapes."); + + let rstride1 = self.strides().0; + let rstride2 = x.strides().0; + + let y = self.data.as_mut_slice(); + let x = x.data.as_slice(); + + if !b.is_zero() { + array_axcpy(y, a, x, c, b, rstride1, rstride2, x.len()); + } else { + array_axc(y, a, x, c, rstride1, rstride2, x.len()); + } + } + /// Computes `self = a * x + b * self`. /// /// If `b` is zero, `self` is never read from. @@ -508,22 +542,12 @@ where #[inline] pub fn axpy(&mut self, a: N, x: &Vector, b: N) where + N: One, SB: Storage, ShapeConstraint: DimEq, { assert_eq!(self.nrows(), x.nrows(), "Axpy: mismatched vector shapes."); - - let rstride1 = self.strides().0; - let rstride2 = x.strides().0; - - let y = self.data.as_mut_slice(); - let x = x.data.as_slice(); - - if !b.is_zero() { - array_axpy(y, a, x, b, rstride1, rstride2, x.len()); - } else { - array_ax(y, a, x, rstride1, rstride2, x.len()); - } + self.axcpy(a, x, N::one(), b) } /// Computes `self = alpha * a * x + beta * self`, where `a` is a matrix, `x` a vector, and @@ -579,13 +603,13 @@ where // FIXME: avoid bound checks. let col2 = a.column(0); let val = unsafe { *x.vget_unchecked(0) }; - self.axpy(val * alpha, &col2, beta); + self.axcpy(alpha, &col2, val, beta); for j in 1..ncols2 { let col2 = a.column(j); let val = unsafe { *x.vget_unchecked(j) }; - self.axpy(val * alpha, &col2, N::one()); + self.axcpy(alpha, &col2, val, N::one()); } } @@ -624,7 +648,7 @@ where // FIXME: avoid bound checks. let col2 = a.column(0); let val = unsafe { *x.vget_unchecked(0) }; - self.axpy(val * alpha, &col2, beta); + self.axpy(alpha * val, &col2, beta); self[0] += alpha * dot(&a.slice_range(1.., 0), &x.rows_range(1..)); for j in 1..dim2 { @@ -637,7 +661,7 @@ where *self.vget_unchecked_mut(j) += alpha * dot; } self.rows_range_mut(j + 1..) - .axpy(val * alpha, &col2.rows_range(j + 1..), N::one()); + .axpy(alpha * val, &col2.rows_range(j + 1..), N::one()); } } @@ -890,7 +914,7 @@ where N: Scalar + Zero + ClosedAdd + ClosedMul for j in 0..ncols1 { // FIXME: avoid bound checks. let val = unsafe { conjugate(*y.vget_unchecked(j)) }; - self.column_mut(j).axpy(val * alpha, x, beta); + self.column_mut(j).axpy(alpha * val, x, beta); } } @@ -1256,7 +1280,7 @@ where N: Scalar + Zero + ClosedAdd + ClosedMul let subdim = Dynamic::new(dim1 - j); // FIXME: avoid bound checks. self.generic_slice_mut((j, j), (subdim, U1)).axpy( - val * alpha, + alpha * val, &x.rows_range(j..), beta, );