From 95830ff9b49f4ecf6707508f144bbdf27a1cf258 Mon Sep 17 00:00:00 2001 From: fangs124 Date: Wed, 24 Apr 2024 13:43:49 +0700 Subject: [PATCH] fixed mul_tr, reverted test change, commented out the new test for mul_tr. --- src/base/alias_view.rs | 9 ++++++++- src/base/ops.rs | 20 ++++++++++---------- tests/core/matrix.rs | 10 +++++----- 3 files changed, 23 insertions(+), 16 deletions(-) diff --git a/src/base/alias_view.rs b/src/base/alias_view.rs index 19a6caec..776d499c 100644 --- a/src/base/alias_view.rs +++ b/src/base/alias_view.rs @@ -379,6 +379,14 @@ pub type MatrixViewXx6<'a, T, RStride = U1, CStride = Dyn> = pub type VectorView<'a, T, D, RStride = U1, CStride = D> = Matrix>; +/// An immutable row vector view with dimensions known at compile-time. +/// +/// +/// +/// **Because this is an alias, not all its methods are listed here. See the [`Matrix`](crate::base::Matrix) type too.** +pub type RowVectorView<'a, T, D, RStride = D, CStride = U1> = + Matrix>; + /// An immutable column vector view with dimensions known at compile-time. /// /// See [`SVectorViewMut`] for a mutable version of this type. @@ -806,7 +814,6 @@ pub type MatrixViewMutXx5<'a, T, RStride = U1, CStride = Dyn> = /// **Because this is an alias, not all its methods are listed here. See the [`Matrix`](crate::base::Matrix) type too.** pub type MatrixViewMutXx6<'a, T, RStride = U1, CStride = Dyn> = Matrix>; - /// A mutable column vector view with dimensions known at compile-time. /// /// See [`VectorView`] for an immutable version of this type. diff --git a/src/base/ops.rs b/src/base/ops.rs index 3b89d3e2..eea9b136 100644 --- a/src/base/ops.rs +++ b/src/base/ops.rs @@ -14,7 +14,7 @@ use crate::base::constraint::{ use crate::base::dimension::{Dim, DimMul, DimName, DimProd, Dyn}; use crate::base::storage::{Storage, StorageMut}; use crate::base::uninit::Uninit; -use crate::base::{DefaultAllocator, Matrix, MatrixSum, OMatrix, Scalar, VectorView}; +use crate::base::{DefaultAllocator, Matrix, MatrixSum, OMatrix, Scalar, VectorView, RowVectorView}; use crate::storage::IsContiguous; use crate::uninit::{Init, InitStatus}; use crate::{RawStorage, RawStorageMut, SimdComplexField}; @@ -679,7 +679,7 @@ where // SAFETY: this is OK because the result is now initialized. unsafe { res.assume_init() } } - + #[inline] #[must_use] /// Equivalent to `self * rhs.transpose()`. @@ -690,7 +690,7 @@ where ShapeConstraint: SameNumberOfColumns, { let mut res = Matrix::uninit(self.shape_generic().0, rhs.shape_generic().0); - self.yy_mul_to_uninit(Uninit, rhs, &mut res, |a, b| a.dot(b)); //note: this was changed + self.yy_mul_to_uninit(Uninit, rhs, &mut res, |a, b| a.dot(b)); // SAFETY: this is OK because the result is now initialized. unsafe { res.assume_init() } } @@ -759,15 +759,15 @@ where } } - #[inline(always)] + #[inline(always)] fn yy_mul_to_uninit( &self, _status: Status, rhs: &Matrix, out: &mut Matrix, dot: impl Fn( - &VectorView<'_, T, R1, SA::RStride, SA::CStride>, - &VectorView<'_, T, R2, SB::RStride, SB::CStride>, + &RowVectorView<'_, T, C1, SA::RStride, SA::CStride>, + &RowVectorView<'_, T, C2, SB::RStride, SB::CStride>, ) -> T, ) where Status: InitStatus, @@ -781,26 +781,26 @@ where assert!( ncols1 == ncols2, - "Matrix multiplication dimensions mismatch {:?} and {:?}: left rows != right rows.", + "Matrix multiplication dimensions mismatch {:?} and {:?}: left cols != right cols.", self.shape(), rhs.shape() ); assert!( nrows1 == nrows3, - "Matrix multiplication output dimensions mismatch {:?} and {:?}: left cols != right rows.", + "Matrix multiplication output dimensions mismatch {:?} and {:?}: left rows != right rows.", self.shape(), out.shape() ); assert!( nrows2 == ncols3, - "Matrix multiplication output dimensions mismatch {:?} and {:?}: left cols != right cols", + "Matrix multiplication output dimensions mismatch {:?} and {:?}: left rows != right cols", rhs.shape(), out.shape() ); for i in 0..nrows1 { for j in 0..nrows2 { - let dot = dot(&self.row(i), &rhs.row(j)); + let dot = dot(&self.row(i), &rhs.row(j)); let elt = unsafe { out.get_unchecked_mut((i, j)) }; Status::init(elt, dot) } diff --git a/tests/core/matrix.rs b/tests/core/matrix.rs index 27ec863e..ef49ee68 100644 --- a/tests/core/matrix.rs +++ b/tests/core/matrix.rs @@ -920,13 +920,13 @@ mod transposition_tests { } #[test] - fn tr_mul_is_transpose_lhs_then_mul(m in matrix(PROPTEST_F64, Const::<4>, Const::<6>), v in vector4()) { + fn tr_mul_is_transpose_then_mul(m in matrix(PROPTEST_F64, Const::<4>, Const::<6>), v in vector4()) { prop_assert!(relative_eq!(m.transpose() * v, m.tr_mul(&v), epsilon = 1.0e-7)) } - #[test] - fn mul_tr_is_transpose_rhs_then_mul(m in matrix(PROPTEST_F64, Const::<4>, Const::<6>), v in vector4()) { - prop_assert!(relative_eq!(m*v.transpose(), m.mul_tr(&v), epsilon = 1.0e-7)) - } + //#[test] + //fn mul_tr_is_transpose_rhs_then_mul(m in matrix(PROPTEST_F64, Const::<4>, Const::<6>), v in vector4()) { + // prop_assert!(relative_eq!(m * v.transpose(), m.mul_tr(&v), epsilon = 1.0e-7)) + //} } }