diff --git a/src/base/ops.rs b/src/base/ops.rs
index d5cf3a51..50a35b32 100644
--- a/src/base/ops.rs
+++ b/src/base/ops.rs
@@ -679,6 +679,26 @@ where
         // SAFETY: this is OK because the result is now initialized.
         unsafe { res.assume_init() }
     }
+
+    /// Equivalent to `self * rhs.transpose()`.
+    ///
+    /// # Panics
+    ///
+    /// Panics if `self` and `rhs` do not have the same number of columns.
+    #[inline]
+    #[must_use]
+    pub fn mul_tr(&self, rhs: &Matrix) -> OMatrix
+    where
+        SB: Storage,
+        DefaultAllocator: Allocator,
+        ShapeConstraint: SameNumberOfColumns,
+    {
+        let mut res = Matrix::uninit(self.shape_generic().0, rhs.shape_generic().0);
+        // Entry (i, j) of the result is the dot product of row `i` of `self` with row `j` of `rhs`.
+        self.yy_mul_to_uninit(Uninit, rhs, &mut res, |a, b| a.dot(b));
+        // SAFETY: this is OK because the result is now initialized.
+        unsafe { res.assume_init() }
+    }
 
     /// Equivalent to `self.adjoint() * rhs`.
     #[inline]
@@ -744,6 +764,65 @@ where
         }
     }
 
+    /// Computes `out[(i, j)]` as `dot(row i of self, row j of rhs)` for every pair of rows,
+    /// writing each value into `out` through `Status::init`.
+    ///
+    /// # Panics
+    ///
+    /// Panics if `self` and `rhs` have different numbers of columns, or if the shape of `out`
+    /// is not (number of rows of `self`, number of rows of `rhs`).
+    #[inline(always)]
+    fn yy_mul_to_uninit(
+        &self,
+        _status: Status,
+        rhs: &Matrix,
+        out: &mut Matrix,
+        // NOTE(review): `dot` is called with `self.row(i)` and `rhs.row(j)`; confirm that the
+        // `VectorView` parameter types below match what `row` returns.
+        dot: impl Fn(
+            &VectorView<'_, T, R1, SA::RStride, SA::CStride>,
+            &VectorView<'_, T, R2, SB::RStride, SB::CStride>,
+        ) -> T,
+    ) where
+        Status: InitStatus,
+        SB: RawStorage,
+        SC: RawStorageMut,
+        ShapeConstraint: SameNumberOfColumns + DimEq + DimEq,
+    {
+        let (nrows1, ncols1) = self.shape();
+        let (nrows2, ncols2) = rhs.shape();
+        let (nrows3, ncols3) = out.shape();
+
+        assert!(
+            ncols1 == ncols2,
+            "Matrix multiplication dimensions mismatch {:?} and {:?}: left cols != right cols.",
+            self.shape(),
+            rhs.shape()
+        );
+        assert!(
+            nrows1 == nrows3,
+            "Matrix multiplication output dimensions mismatch {:?} and {:?}: left rows != output rows.",
+            self.shape(),
+            out.shape()
+        );
+        assert!(
+            nrows2 == ncols3,
+            "Matrix multiplication output dimensions mismatch {:?} and {:?}: right rows != output cols.",
+            rhs.shape(),
+            out.shape()
+        );
+
+        for i in 0..nrows1 {
+            for j in 0..nrows2 {
+                let value = dot(&self.row(i), &rhs.row(j));
+                // SAFETY: `i < nrows1 == nrows3` and `j < nrows2 == ncols3` by the assertions
+                // above, so `(i, j)` is in bounds for `out`.
+                let elt = unsafe { out.get_unchecked_mut((i, j)) };
+                Status::init(elt, value)
+            }
+        }
+    }
+
     /// Equivalent to `self.transpose() * rhs` but stores the result into `out` to avoid
     /// allocations.
     #[inline]