added mul_tr to ops.rs
added mut_tr and the corresponding hidden yy_mul_to_uninit function.
This commit is contained in:
parent
48d7b175a3
commit
a56c215073
@ -679,6 +679,20 @@ where
|
||||
// SAFETY: this is OK because the result is now initialized.
|
||||
unsafe { res.assume_init() }
|
||||
}
|
||||
|
||||
#[inline]
|
||||
#[must_use]
|
||||
pub fn mul_tr<R2: Dim, C2: Dim, SB>(&self, rhs: &Matrix<T, R2, C2, SB>) -> OMatrix<T, C1, C2>
|
||||
where
|
||||
SB: Storage<T, R2, C2>,
|
||||
DefaultAllocator: Allocator<T, R1, R2>,
|
||||
ShapeConstraint: SameNumberOfColumns<C1, C2>,
|
||||
{
|
||||
let mut res = Matrix::uninit(self.shape_generic().0, rhs.shape_generic().0);
|
||||
self.yy_mul_to_uninit(Uninit, rhs, &mut res, |a, b| a.dot(b)); //note: this was changed
|
||||
// SAFETY: this is OK because the result is now initialized.
|
||||
unsafe { res.assume_init() }
|
||||
}
|
||||
|
||||
/// Equivalent to `self.adjoint() * rhs`.
|
||||
#[inline]
|
||||
@ -744,6 +758,54 @@ where
|
||||
}
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
fn yy_mul_to_uninit<Status, R2: Dim, C2: Dim, SB, R3: Dim, C3: Dim, SC>(
|
||||
&self,
|
||||
_status: Status,
|
||||
rhs: &Matrix<T, R2, C2, SB>,
|
||||
out: &mut Matrix<Status::Value, R3, C3, SC>,
|
||||
dot: impl Fn(
|
||||
&VectorView<'_, T, R1, SA::RStride, SA::CStride>,
|
||||
&VectorView<'_, T, R2, SB::RStride, SB::CStride>,
|
||||
) -> T,
|
||||
) where
|
||||
Status: InitStatus<T>,
|
||||
SB: RawStorage<T, R2, C2>,
|
||||
SC: RawStorageMut<Status::Value, R3, C3>,
|
||||
ShapeConstraint: SameNumberOfColumns<C1, C2> + DimEq<R1, R3> + DimEq<R2, C3>,
|
||||
{
|
||||
let (nrows1, ncols1) = self.shape();
|
||||
let (nrows2, ncols2) = rhs.shape();
|
||||
let (nrows3, ncols3) = out.shape();
|
||||
|
||||
assert!(
|
||||
ncols1 == ncols2,
|
||||
"Matrix multiplication dimensions mismatch {:?} and {:?}: left rows != right rows.",
|
||||
self.shape(),
|
||||
rhs.shape()
|
||||
);
|
||||
assert!(
|
||||
nrows1 == nrows3,
|
||||
"Matrix multiplication output dimensions mismatch {:?} and {:?}: left cols != right rows.",
|
||||
self.shape(),
|
||||
out.shape()
|
||||
);
|
||||
assert!(
|
||||
nrows2 == ncols3,
|
||||
"Matrix multiplication output dimensions mismatch {:?} and {:?}: left cols != right cols",
|
||||
rhs.shape(),
|
||||
out.shape()
|
||||
);
|
||||
|
||||
for i in 0..nrows1 {
|
||||
for j in 0..nrows2 {
|
||||
let dot = dot(&self.row(i), &rhs.row(j));
|
||||
let elt = unsafe { out.get_unchecked_mut((i, j)) };
|
||||
Status::init(elt, dot)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Equivalent to `self.transpose() * rhs` but stores the result into `out` to avoid
|
||||
/// allocations.
|
||||
#[inline]
|
||||
|
Loading…
Reference in New Issue
Block a user