diff --git a/src/base/blas.rs b/src/base/blas.rs index cc8f2345..622761fe 100644 --- a/src/base/blas.rs +++ b/src/base/blas.rs @@ -468,21 +468,21 @@ where N: Scalar + Zero + ClosedAdd + ClosedMul } } -fn array_axpy(y: &mut [N], a: N, x: &[N], beta: N, stride1: usize, stride2: usize, len: usize) +fn array_axcpy(y: &mut [N], a: N, x: &[N], c: N, beta: N, stride1: usize, stride2: usize, len: usize) where N: Scalar + Zero + ClosedAdd + ClosedMul { for i in 0..len { unsafe { let y = y.get_unchecked_mut(i * stride1); - *y = a * *x.get_unchecked(i * stride2) + beta * *y; + *y = a * *x.get_unchecked(i * stride2) * c + beta * *y; } } } -fn array_ax(y: &mut [N], a: N, x: &[N], stride1: usize, stride2: usize, len: usize) +fn array_axc(y: &mut [N], a: N, x: &[N], c: N, stride1: usize, stride2: usize, len: usize) where N: Scalar + Zero + ClosedAdd + ClosedMul { for i in 0..len { unsafe { - *y.get_unchecked_mut(i * stride1) = a * *x.get_unchecked(i * stride2); + *y.get_unchecked_mut(i * stride1) = a * *x.get_unchecked(i * stride2) * c; } } } @@ -492,6 +492,40 @@ where N: Scalar + Zero + ClosedAdd + ClosedMul, S: StorageMut, { + /// Computes `self = a * x * c + b * self`. + /// + /// If `b` is zero, `self` is never read from. + /// + /// # Examples: + /// + /// ``` + /// # use nalgebra::Vector3; + /// let mut vec1 = Vector3::new(1.0, 2.0, 3.0); + /// let vec2 = Vector3::new(0.1, 0.2, 0.3); + /// vec1.axcpy(5.0, &vec2, 2.0, 5.0); + /// assert_eq!(vec1, Vector3::new(6.0, 12.0, 18.0)); + /// ``` + #[inline] + pub fn axcpy(&mut self, a: N, x: &Vector, c: N, b: N) + where + SB: Storage, + ShapeConstraint: DimEq, + { + assert_eq!(self.nrows(), x.nrows(), "Axcpy: mismatched vector shapes."); + + let rstride1 = self.strides().0; + let rstride2 = x.strides().0; + + let y = self.data.as_mut_slice(); + let x = x.data.as_slice(); + + if !b.is_zero() { + array_axcpy(y, a, x, c, b, rstride1, rstride2, x.len()); + } else { + array_axc(y, a, x, c, rstride1, rstride2, x.len()); + } + } + /// Computes `self = a * x + b * self`. /// /// If `b` is zero, `self` is never read from. @@ -508,22 +542,12 @@ where #[inline] pub fn axpy(&mut self, a: N, x: &Vector, b: N) where + N: One, SB: Storage, ShapeConstraint: DimEq, { assert_eq!(self.nrows(), x.nrows(), "Axpy: mismatched vector shapes."); - - let rstride1 = self.strides().0; - let rstride2 = x.strides().0; - - let y = self.data.as_mut_slice(); - let x = x.data.as_slice(); - - if !b.is_zero() { - array_axpy(y, a, x, b, rstride1, rstride2, x.len()); - } else { - array_ax(y, a, x, rstride1, rstride2, x.len()); - } + self.axcpy(a, x, N::one(), b) } /// Computes `self = alpha * a * x + beta * self`, where `a` is a matrix, `x` a vector, and @@ -579,13 +603,13 @@ where // FIXME: avoid bound checks. let col2 = a.column(0); let val = unsafe { *x.vget_unchecked(0) }; - self.axpy(alpha * val, &col2, beta); + self.axcpy(alpha, &col2, val, beta); for j in 1..ncols2 { let col2 = a.column(j); let val = unsafe { *x.vget_unchecked(j) }; - self.axpy(alpha * val, &col2, N::one()); + self.axcpy(alpha, &col2, val, N::one()); } } diff --git a/tests/core/blas.rs b/tests/core/blas.rs index 38113c17..9b7be4af 100644 --- a/tests/core/blas.rs +++ b/tests/core/blas.rs @@ -1,105 +1,129 @@ -#![cfg(feature = "arbitrary")] +use na::{geometry::Quaternion, Matrix2, Vector3}; +use num_traits::{One, Zero}; -use na::{DMatrix, DVector}; -use std::cmp; +#[test] +fn gemm_noncommutative() { + type Qf64 = Quaternion; + let i = Qf64::from_imag(Vector3::new(1.0, 0.0, 0.0)); + let j = Qf64::from_imag(Vector3::new(0.0, 1.0, 0.0)); + let k = Qf64::from_imag(Vector3::new(0.0, 0.0, 1.0)); -quickcheck! { - /* - * - * Symmetric operators. - * - */ - fn gemv_symm(n: usize, alpha: f64, beta: f64) -> bool { - let n = cmp::max(1, cmp::min(n, 50)); - let a = DMatrix::::new_random(n, n); - let a = &a * a.transpose(); + let m1 = Matrix2::new(k, Qf64::zero(), j, i); + // this is the inverse of m1 + let m2 = Matrix2::new(-k, Qf64::zero(), Qf64::one(), -i); - let x = DVector::new_random(n); - let mut y1 = DVector::new_random(n); - let mut y2 = y1.clone(); + let mut res: Matrix2 = Matrix2::zero(); + res.gemm(Qf64::one(), &m1, &m2, Qf64::zero()); + assert_eq!(res, Matrix2::identity()); - y1.gemv(alpha, &a, &x, beta); - y2.sygemv(alpha, &a.lower_triangle(), &x, beta); + let mut res: Matrix2 = Matrix2::identity(); + res.gemm(k, &m1, &m2, -k); + assert_eq!(res, Matrix2::zero()); +} - if !relative_eq!(y1, y2, epsilon = 1.0e-10) { - return false; +#[cfg(feature = "arbitrary")] +mod blas_quickcheck { + use na::{DMatrix, DVector}; + use std::cmp; + + quickcheck! { + /* + * + * Symmetric operators. + * + */ + fn gemv_symm(n: usize, alpha: f64, beta: f64) -> bool { + let n = cmp::max(1, cmp::min(n, 50)); + let a = DMatrix::::new_random(n, n); + let a = &a * a.transpose(); + + let x = DVector::new_random(n); + let mut y1 = DVector::new_random(n); + let mut y2 = y1.clone(); + + y1.gemv(alpha, &a, &x, beta); + y2.sygemv(alpha, &a.lower_triangle(), &x, beta); + + if !relative_eq!(y1, y2, epsilon = 1.0e-10) { + return false; + } + + y1.gemv(alpha, &a, &x, 0.0); + y2.sygemv(alpha, &a.lower_triangle(), &x, 0.0); + + relative_eq!(y1, y2, epsilon = 1.0e-10) } - y1.gemv(alpha, &a, &x, 0.0); - y2.sygemv(alpha, &a.lower_triangle(), &x, 0.0); + fn gemv_tr(n: usize, alpha: f64, beta: f64) -> bool { + let n = cmp::max(1, cmp::min(n, 50)); + let a = DMatrix::::new_random(n, n); + let x = DVector::new_random(n); + let mut y1 = DVector::new_random(n); + let mut y2 = y1.clone(); - relative_eq!(y1, y2, epsilon = 1.0e-10) - } + y1.gemv(alpha, &a, &x, beta); + y2.gemv_tr(alpha, &a.transpose(), &x, beta); - fn gemv_tr(n: usize, alpha: f64, beta: f64) -> bool { - let n = cmp::max(1, cmp::min(n, 50)); - let a = DMatrix::::new_random(n, n); - let x = DVector::new_random(n); - let mut y1 = DVector::new_random(n); - let mut y2 = y1.clone(); + if !relative_eq!(y1, y2, epsilon = 1.0e-10) { + return false; + } - y1.gemv(alpha, &a, &x, beta); - y2.gemv_tr(alpha, &a.transpose(), &x, beta); + y1.gemv(alpha, &a, &x, 0.0); + y2.gemv_tr(alpha, &a.transpose(), &x, 0.0); - if !relative_eq!(y1, y2, epsilon = 1.0e-10) { - return false; + relative_eq!(y1, y2, epsilon = 1.0e-10) } - y1.gemv(alpha, &a, &x, 0.0); - y2.gemv_tr(alpha, &a.transpose(), &x, 0.0); + fn ger_symm(n: usize, alpha: f64, beta: f64) -> bool { + let n = cmp::max(1, cmp::min(n, 50)); + let a = DMatrix::::new_random(n, n); + let mut a1 = &a * a.transpose(); + let mut a2 = a1.lower_triangle(); - relative_eq!(y1, y2, epsilon = 1.0e-10) - } + let x = DVector::new_random(n); + let y = DVector::new_random(n); - fn ger_symm(n: usize, alpha: f64, beta: f64) -> bool { - let n = cmp::max(1, cmp::min(n, 50)); - let a = DMatrix::::new_random(n, n); - let mut a1 = &a * a.transpose(); - let mut a2 = a1.lower_triangle(); + a1.ger(alpha, &x, &y, beta); + a2.syger(alpha, &x, &y, beta); - let x = DVector::new_random(n); - let y = DVector::new_random(n); + if !relative_eq!(a1.lower_triangle(), a2) { + return false; + } - a1.ger(alpha, &x, &y, beta); - a2.syger(alpha, &x, &y, beta); + a1.ger(alpha, &x, &y, 0.0); + a2.syger(alpha, &x, &y, 0.0); - if !relative_eq!(a1.lower_triangle(), a2) { - return false; + relative_eq!(a1.lower_triangle(), a2) } - a1.ger(alpha, &x, &y, 0.0); - a2.syger(alpha, &x, &y, 0.0); + fn quadform(n: usize, alpha: f64, beta: f64) -> bool { + let n = cmp::max(1, cmp::min(n, 50)); + let rhs = DMatrix::::new_random(6, n); + let mid = DMatrix::::new_random(6, 6); + let mut res = DMatrix::new_random(n, n); - relative_eq!(a1.lower_triangle(), a2) - } + let expected = &res * beta + rhs.transpose() * &mid * &rhs * alpha; - fn quadform(n: usize, alpha: f64, beta: f64) -> bool { - let n = cmp::max(1, cmp::min(n, 50)); - let rhs = DMatrix::::new_random(6, n); - let mid = DMatrix::::new_random(6, 6); - let mut res = DMatrix::new_random(n, n); + res.quadform(alpha, &mid, &rhs, beta); - let expected = &res * beta + rhs.transpose() * &mid * &rhs * alpha; + println!("{}{}", res, expected); - res.quadform(alpha, &mid, &rhs, beta); + relative_eq!(res, expected, epsilon = 1.0e-7) + } - println!("{}{}", res, expected); + fn quadform_tr(n: usize, alpha: f64, beta: f64) -> bool { + let n = cmp::max(1, cmp::min(n, 50)); + let lhs = DMatrix::::new_random(6, n); + let mid = DMatrix::::new_random(n, n); + let mut res = DMatrix::new_random(6, 6); - relative_eq!(res, expected, epsilon = 1.0e-7) - } + let expected = &res * beta + &lhs * &mid * lhs.transpose() * alpha; - fn quadform_tr(n: usize, alpha: f64, beta: f64) -> bool { - let n = cmp::max(1, cmp::min(n, 50)); - let lhs = DMatrix::::new_random(6, n); - let mid = DMatrix::::new_random(n, n); - let mut res = DMatrix::new_random(6, 6); + res.quadform_tr(alpha, &lhs, &mid , beta); - let expected = &res * beta + &lhs * &mid * lhs.transpose() * alpha; + println!("{}{}", res, expected); - res.quadform_tr(alpha, &lhs, &mid , beta); - - println!("{}{}", res, expected); - - relative_eq!(res, expected, epsilon = 1.0e-7) + relative_eq!(res, expected, epsilon = 1.0e-7) + } } }