diff --git a/src/core/blas.rs b/src/core/blas.rs index a3f983cc..ed4ea2f5 100644 --- a/src/core/blas.rs +++ b/src/core/blas.rs @@ -462,36 +462,41 @@ impl> SquareMatrix where N: Scalar + Zero + One + ClosedAdd + ClosedMul { /// Computes the quadratic form `self = alpha * lrs * mid * lhs.transpose() + beta * self`. - pub fn quadform_symm(&mut self, - scratch: &mut Vector, - alpha: N, - mid: &SquareMatrix, - rhs: &Matrix, - beta: N) - where D2: Dim, R3: Dim, C3: Dim, - S2: StorageMut, - S3: Storage, - S4: Storage, - ShapeConstraint: DimEq + DimEq + DimEq, // FIXME: why is this one necessary? - DefaultAllocator: Allocator { - - - /* - scratch.gemv(N::one(), lhs, &mid.column(0), N::zero()); - self.ger_symm(alpha, &scratch, &lhs.column(0), beta); + pub fn quadform_with_workspace(&mut self, + work: &mut Vector, + alpha: N, + lhs: &Matrix, + mid: &SquareMatrix, + beta: N) + where D2: Dim, R3: Dim, C3: Dim, D4: Dim, + S2: StorageMut, + S3: Storage, + S4: Storage, + ShapeConstraint: DimEq + + DimEq + + DimEq + + DimEq { + work.gemv(N::one(), lhs, &mid.column(0), N::zero()); + self.ger(alpha, work, &lhs.column(0), beta); for j in 1 .. mid.ncols() { - scratch.gemv(N::one(), lhs, &mid.column(j), N::zero()); - self.ger_symm(alpha, &scratch, &lhs.column(j), N::one()); - } - */ - - scratch.gemv_symm(N::one(), mid, &rhs.column(0), N::zero()); - self.column_mut(0).gemv(alpha, rhs, &scratch, beta); - - for j in 1 .. mid.ncols() { - scratch.gemv_symm(N::one(), mid, &rhs.column(j), N::zero()); - self.slice_range_mut(j .., j).gemv(alpha, &rhs.rows_range(j ..), &scratch, N::one()); + work.gemv(N::one(), lhs, &mid.column(j), N::zero()); + self.ger(alpha, work, &lhs.column(j), N::one()); } } + + /// Computes the quadratic form `self = alpha * lrs * mid * lhs.transpose() + beta * self`. + pub fn quadform(&mut self, + alpha: N, + lhs: &Matrix, + mid: &SquareMatrix, + beta: N) + where R3: Dim, C3: Dim, D4: Dim, + S3: Storage, + S4: Storage, + ShapeConstraint: DimEq + DimEq + DimEq, + DefaultAllocator: Allocator { + let mut work = unsafe { Vector::new_uninitialized_generic(self.data.shape().0, U1) }; + self.quadform_with_workspace(&mut work, alpha, lhs, mid, beta) + } } diff --git a/tests/core/blas.rs b/tests/core/blas.rs index 79e5bffc..019c73df 100644 --- a/tests/core/blas.rs +++ b/tests/core/blas.rs @@ -53,22 +53,18 @@ quickcheck! { relative_eq!(a1.lower_triangle(), a2) } - fn quadform_symm(n: usize, alpha: f64, beta: f64) -> bool { - let n = cmp::max(1, cmp::min(n, 50)); - let lhs = DMatrix::::new_random(6, n); - let mut mid = DMatrix::::new_random(n, n); - let mut res = DMatrix::new_random(6, 6); - let mut scratch = Vector6::zeros(); - - mid.fill_upper_triangle_with_lower_triangle(); + fn quadform(n: usize, alpha: f64, beta: f64) -> bool { + let n = cmp::max(1, cmp::min(n, 50)); + let lhs = DMatrix::::new_random(6, n); + let mid = DMatrix::::new_random(n, n); + let mut res = DMatrix::new_random(6, 6); let expected = &res * beta + &lhs * &mid * lhs.transpose() * alpha; - res.quadform_symm(&mut scratch, alpha, &lhs, &mid, beta); - res.fill_upper_triangle_with_lower_triangle(); + res.quadform(alpha, &lhs, &mid , beta); println!("{}{}", res, expected); - relative_eq!(res.lower_triangle(), expected.lower_triangle(), epsilon = 1.0e-7) + relative_eq!(res, expected, epsilon = 1.0e-7) } }