added mul_tr to ops.rs

added mut_tr and the corresponding hidden yy_mul_to_uninit function.
This commit is contained in:
Fangs 2024-04-24 11:07:21 +07:00 committed by GitHub
parent 48d7b175a3
commit a56c215073
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed files with 62 additions and 0 deletions

View File

@ -679,6 +679,20 @@ where
// SAFETY: this is OK because the result is now initialized. // SAFETY: this is OK because the result is now initialized.
unsafe { res.assume_init() } unsafe { res.assume_init() }
} }
#[inline]
#[must_use]
pub fn mul_tr<R2: Dim, C2: Dim, SB>(&self, rhs: &Matrix<T, R2, C2, SB>) -> OMatrix<T, C1, C2>
where
SB: Storage<T, R2, C2>,
DefaultAllocator: Allocator<T, R1, R2>,
ShapeConstraint: SameNumberOfColumns<C1, C2>,
{
let mut res = Matrix::uninit(self.shape_generic().0, rhs.shape_generic().0);
self.yy_mul_to_uninit(Uninit, rhs, &mut res, |a, b| a.dot(b)); //note: this was changed
// SAFETY: this is OK because the result is now initialized.
unsafe { res.assume_init() }
}
/// Equivalent to `self.adjoint() * rhs`. /// Equivalent to `self.adjoint() * rhs`.
#[inline] #[inline]
@ -744,6 +758,54 @@ where
} }
} }
#[inline(always)]
fn yy_mul_to_uninit<Status, R2: Dim, C2: Dim, SB, R3: Dim, C3: Dim, SC>(
&self,
_status: Status,
rhs: &Matrix<T, R2, C2, SB>,
out: &mut Matrix<Status::Value, R3, C3, SC>,
dot: impl Fn(
&VectorView<'_, T, R1, SA::RStride, SA::CStride>,
&VectorView<'_, T, R2, SB::RStride, SB::CStride>,
) -> T,
) where
Status: InitStatus<T>,
SB: RawStorage<T, R2, C2>,
SC: RawStorageMut<Status::Value, R3, C3>,
ShapeConstraint: SameNumberOfColumns<C1, C2> + DimEq<R1, R3> + DimEq<R2, C3>,
{
let (nrows1, ncols1) = self.shape();
let (nrows2, ncols2) = rhs.shape();
let (nrows3, ncols3) = out.shape();
assert!(
ncols1 == ncols2,
"Matrix multiplication dimensions mismatch {:?} and {:?}: left rows != right rows.",
self.shape(),
rhs.shape()
);
assert!(
nrows1 == nrows3,
"Matrix multiplication output dimensions mismatch {:?} and {:?}: left cols != right rows.",
self.shape(),
out.shape()
);
assert!(
nrows2 == ncols3,
"Matrix multiplication output dimensions mismatch {:?} and {:?}: left cols != right cols",
rhs.shape(),
out.shape()
);
for i in 0..nrows1 {
for j in 0..nrows2 {
let dot = dot(&self.row(i), &rhs.row(j));
let elt = unsafe { out.get_unchecked_mut((i, j)) };
Status::init(elt, dot)
}
}
}
/// Equivalent to `self.transpose() * rhs` but stores the result into `out` to avoid /// Equivalent to `self.transpose() * rhs` but stores the result into `out` to avoid
/// allocations. /// allocations.
#[inline] #[inline]