diff --git a/src/base/blas.rs b/src/base/blas.rs index dec28d07..622761fe 100644 --- a/src/base/blas.rs +++ b/src/base/blas.rs @@ -468,21 +468,21 @@ where N: Scalar + Zero + ClosedAdd + ClosedMul } } -fn array_axpy(y: &mut [N], a: N, x: &[N], beta: N, stride1: usize, stride2: usize, len: usize) +fn array_axcpy(y: &mut [N], a: N, x: &[N], c: N, beta: N, stride1: usize, stride2: usize, len: usize) where N: Scalar + Zero + ClosedAdd + ClosedMul { for i in 0..len { unsafe { let y = y.get_unchecked_mut(i * stride1); - *y = *x.get_unchecked(i * stride2) * a + *y * beta; + *y = a * *x.get_unchecked(i * stride2) * c + beta * *y; } } } -fn array_ax(y: &mut [N], a: N, x: &[N], stride1: usize, stride2: usize, len: usize) +fn array_axc(y: &mut [N], a: N, x: &[N], c: N, stride1: usize, stride2: usize, len: usize) where N: Scalar + Zero + ClosedAdd + ClosedMul { for i in 0..len { unsafe { - *y.get_unchecked_mut(i * stride1) = *x.get_unchecked(i * stride2) * a; + *y.get_unchecked_mut(i * stride1) = a * *x.get_unchecked(i * stride2) * c; } } } @@ -492,6 +492,40 @@ where N: Scalar + Zero + ClosedAdd + ClosedMul, S: StorageMut, { + /// Computes `self = a * x * c + b * self`. + /// + /// If `b` is zero, `self` is never read from. + /// + /// # Examples: + /// + /// ``` + /// # use nalgebra::Vector3; + /// let mut vec1 = Vector3::new(1.0, 2.0, 3.0); + /// let vec2 = Vector3::new(0.1, 0.2, 0.3); + /// vec1.axcpy(5.0, &vec2, 2.0, 5.0); + /// assert_eq!(vec1, Vector3::new(6.0, 12.0, 18.0)); + /// ``` + #[inline] + pub fn axcpy(&mut self, a: N, x: &Vector, c: N, b: N) + where + SB: Storage, + ShapeConstraint: DimEq, + { + assert_eq!(self.nrows(), x.nrows(), "Axcpy: mismatched vector shapes."); + + let rstride1 = self.strides().0; + let rstride2 = x.strides().0; + + let y = self.data.as_mut_slice(); + let x = x.data.as_slice(); + + if !b.is_zero() { + array_axcpy(y, a, x, c, b, rstride1, rstride2, x.len()); + } else { + array_axc(y, a, x, c, rstride1, rstride2, x.len()); + } + } + /// Computes `self = a * x + b * self`. /// /// If `b` is zero, `self` is never read from. @@ -508,22 +542,12 @@ where #[inline] pub fn axpy(&mut self, a: N, x: &Vector, b: N) where + N: One, SB: Storage, ShapeConstraint: DimEq, { assert_eq!(self.nrows(), x.nrows(), "Axpy: mismatched vector shapes."); - - let rstride1 = self.strides().0; - let rstride2 = x.strides().0; - - let y = self.data.as_mut_slice(); - let x = x.data.as_slice(); - - if !b.is_zero() { - array_axpy(y, a, x, b, rstride1, rstride2, x.len()); - } else { - array_ax(y, a, x, rstride1, rstride2, x.len()); - } + self.axcpy(a, x, N::one(), b) } /// Computes `self = alpha * a * x + beta * self`, where `a` is a matrix, `x` a vector, and @@ -579,13 +603,13 @@ where // FIXME: avoid bound checks. let col2 = a.column(0); let val = unsafe { *x.vget_unchecked(0) }; - self.axpy(val * alpha, &col2, beta); + self.axcpy(alpha, &col2, val, beta); for j in 1..ncols2 { let col2 = a.column(j); let val = unsafe { *x.vget_unchecked(j) }; - self.axpy(val * alpha, &col2, N::one()); + self.axcpy(alpha, &col2, val, N::one()); } } @@ -624,7 +648,7 @@ where // FIXME: avoid bound checks. let col2 = a.column(0); let val = unsafe { *x.vget_unchecked(0) }; - self.axpy(val * alpha, &col2, beta); + self.axpy(alpha * val, &col2, beta); self[0] += alpha * dot(&a.slice_range(1.., 0), &x.rows_range(1..)); for j in 1..dim2 { @@ -637,7 +661,7 @@ where *self.vget_unchecked_mut(j) += alpha * dot; } self.rows_range_mut(j + 1..) - .axpy(val * alpha, &col2.rows_range(j + 1..), N::one()); + .axpy(alpha * val, &col2.rows_range(j + 1..), N::one()); } } @@ -890,7 +914,7 @@ where N: Scalar + Zero + ClosedAdd + ClosedMul for j in 0..ncols1 { // FIXME: avoid bound checks. let val = unsafe { conjugate(*y.vget_unchecked(j)) }; - self.column_mut(j).axpy(val * alpha, x, beta); + self.column_mut(j).axpy(alpha * val, x, beta); } } @@ -1256,7 +1280,7 @@ where N: Scalar + Zero + ClosedAdd + ClosedMul let subdim = Dynamic::new(dim1 - j); // FIXME: avoid bound checks. self.generic_slice_mut((j, j), (subdim, U1)).axpy( - val * alpha, + alpha * val, &x.rows_range(j..), beta, );