fixed mul_tr, reverted test change, commented out the new test for mul_tr.

This commit is contained in:
fangs124 2024-04-24 13:43:49 +07:00
parent acd2fc38fb
commit 95830ff9b4
3 changed files with 23 additions and 16 deletions

View File

@ -379,6 +379,14 @@ pub type MatrixViewXx6<'a, T, RStride = U1, CStride = Dyn> =
pub type VectorView<'a, T, D, RStride = U1, CStride = D> = pub type VectorView<'a, T, D, RStride = U1, CStride = D> =
Matrix<T, D, U1, ViewStorage<'a, T, D, U1, RStride, CStride>>; Matrix<T, D, U1, ViewStorage<'a, T, D, U1, RStride, CStride>>;
/// An immutable row vector view with dimensions known at compile-time.
///
///
///
/// **Because this is an alias, not all its methods are listed here. See the [`Matrix`](crate::base::Matrix) type too.**
pub type RowVectorView<'a, T, D, RStride = D, CStride = U1> =
Matrix<T, U1, D, ViewStorage<'a, T, U1, D, RStride, CStride>>;
/// An immutable column vector view with dimensions known at compile-time. /// An immutable column vector view with dimensions known at compile-time.
/// ///
/// See [`SVectorViewMut`] for a mutable version of this type. /// See [`SVectorViewMut`] for a mutable version of this type.
@ -806,7 +814,6 @@ pub type MatrixViewMutXx5<'a, T, RStride = U1, CStride = Dyn> =
/// **Because this is an alias, not all its methods are listed here. See the [`Matrix`](crate::base::Matrix) type too.** /// **Because this is an alias, not all its methods are listed here. See the [`Matrix`](crate::base::Matrix) type too.**
pub type MatrixViewMutXx6<'a, T, RStride = U1, CStride = Dyn> = pub type MatrixViewMutXx6<'a, T, RStride = U1, CStride = Dyn> =
Matrix<T, Dyn, U6, ViewStorageMut<'a, T, Dyn, U6, RStride, CStride>>; Matrix<T, Dyn, U6, ViewStorageMut<'a, T, Dyn, U6, RStride, CStride>>;
/// A mutable column vector view with dimensions known at compile-time. /// A mutable column vector view with dimensions known at compile-time.
/// ///
/// See [`VectorView`] for an immutable version of this type. /// See [`VectorView`] for an immutable version of this type.

View File

@ -14,7 +14,7 @@ use crate::base::constraint::{
use crate::base::dimension::{Dim, DimMul, DimName, DimProd, Dyn}; use crate::base::dimension::{Dim, DimMul, DimName, DimProd, Dyn};
use crate::base::storage::{Storage, StorageMut}; use crate::base::storage::{Storage, StorageMut};
use crate::base::uninit::Uninit; use crate::base::uninit::Uninit;
use crate::base::{DefaultAllocator, Matrix, MatrixSum, OMatrix, Scalar, VectorView}; use crate::base::{DefaultAllocator, Matrix, MatrixSum, OMatrix, Scalar, VectorView, RowVectorView};
use crate::storage::IsContiguous; use crate::storage::IsContiguous;
use crate::uninit::{Init, InitStatus}; use crate::uninit::{Init, InitStatus};
use crate::{RawStorage, RawStorageMut, SimdComplexField}; use crate::{RawStorage, RawStorageMut, SimdComplexField};
@ -690,7 +690,7 @@ where
ShapeConstraint: SameNumberOfColumns<C1, C2>, ShapeConstraint: SameNumberOfColumns<C1, C2>,
{ {
let mut res = Matrix::uninit(self.shape_generic().0, rhs.shape_generic().0); let mut res = Matrix::uninit(self.shape_generic().0, rhs.shape_generic().0);
self.yy_mul_to_uninit(Uninit, rhs, &mut res, |a, b| a.dot(b)); //note: this was changed self.yy_mul_to_uninit(Uninit, rhs, &mut res, |a, b| a.dot(b));
// SAFETY: this is OK because the result is now initialized. // SAFETY: this is OK because the result is now initialized.
unsafe { res.assume_init() } unsafe { res.assume_init() }
} }
@ -766,8 +766,8 @@ where
rhs: &Matrix<T, R2, C2, SB>, rhs: &Matrix<T, R2, C2, SB>,
out: &mut Matrix<Status::Value, R3, C3, SC>, out: &mut Matrix<Status::Value, R3, C3, SC>,
dot: impl Fn( dot: impl Fn(
&VectorView<'_, T, R1, SA::RStride, SA::CStride>, &RowVectorView<'_, T, C1, SA::RStride, SA::CStride>,
&VectorView<'_, T, R2, SB::RStride, SB::CStride>, &RowVectorView<'_, T, C2, SB::RStride, SB::CStride>,
) -> T, ) -> T,
) where ) where
Status: InitStatus<T>, Status: InitStatus<T>,
@ -781,19 +781,19 @@ where
assert!( assert!(
ncols1 == ncols2, ncols1 == ncols2,
"Matrix multiplication dimensions mismatch {:?} and {:?}: left rows != right rows.", "Matrix multiplication dimensions mismatch {:?} and {:?}: left cols != right cols.",
self.shape(), self.shape(),
rhs.shape() rhs.shape()
); );
assert!( assert!(
nrows1 == nrows3, nrows1 == nrows3,
"Matrix multiplication output dimensions mismatch {:?} and {:?}: left cols != right rows.", "Matrix multiplication output dimensions mismatch {:?} and {:?}: left rows != right rows.",
self.shape(), self.shape(),
out.shape() out.shape()
); );
assert!( assert!(
nrows2 == ncols3, nrows2 == ncols3,
"Matrix multiplication output dimensions mismatch {:?} and {:?}: left cols != right cols", "Matrix multiplication output dimensions mismatch {:?} and {:?}: left rows != right cols",
rhs.shape(), rhs.shape(),
out.shape() out.shape()
); );

View File

@ -920,13 +920,13 @@ mod transposition_tests {
} }
#[test] #[test]
fn tr_mul_is_transpose_lhs_then_mul(m in matrix(PROPTEST_F64, Const::<4>, Const::<6>), v in vector4()) { fn tr_mul_is_transpose_then_mul(m in matrix(PROPTEST_F64, Const::<4>, Const::<6>), v in vector4()) {
prop_assert!(relative_eq!(m.transpose() * v, m.tr_mul(&v), epsilon = 1.0e-7)) prop_assert!(relative_eq!(m.transpose() * v, m.tr_mul(&v), epsilon = 1.0e-7))
} }
#[test] //#[test]
fn mul_tr_is_transpose_rhs_then_mul(m in matrix(PROPTEST_F64, Const::<4>, Const::<6>), v in vector4()) { //fn mul_tr_is_transpose_rhs_then_mul(m in matrix(PROPTEST_F64, Const::<4>, Const::<6>), v in vector4()) {
prop_assert!(relative_eq!(m*v.transpose(), m.mul_tr(&v), epsilon = 1.0e-7)) // prop_assert!(relative_eq!(m * v.transpose(), m.mul_tr(&v), epsilon = 1.0e-7))
} //}
} }
} }