added mul_tr to ops.rs
Added mul_tr and the corresponding hidden yy_mul_to_uninit function.
This commit is contained in:
parent
48d7b175a3
commit
a56c215073
|
@ -680,6 +680,20 @@ where
|
||||||
unsafe { res.assume_init() }
|
unsafe { res.assume_init() }
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[inline]
|
||||||
|
#[must_use]
|
||||||
|
pub fn mul_tr<R2: Dim, C2: Dim, SB>(&self, rhs: &Matrix<T, R2, C2, SB>) -> OMatrix<T, C1, C2>
|
||||||
|
where
|
||||||
|
SB: Storage<T, R2, C2>,
|
||||||
|
DefaultAllocator: Allocator<T, R1, R2>,
|
||||||
|
ShapeConstraint: SameNumberOfColumns<C1, C2>,
|
||||||
|
{
|
||||||
|
let mut res = Matrix::uninit(self.shape_generic().0, rhs.shape_generic().0);
|
||||||
|
self.yy_mul_to_uninit(Uninit, rhs, &mut res, |a, b| a.dot(b)); //note: this was changed
|
||||||
|
// SAFETY: this is OK because the result is now initialized.
|
||||||
|
unsafe { res.assume_init() }
|
||||||
|
}
|
||||||
|
|
||||||
/// Equivalent to `self.adjoint() * rhs`.
|
/// Equivalent to `self.adjoint() * rhs`.
|
||||||
#[inline]
|
#[inline]
|
||||||
#[must_use]
|
#[must_use]
|
||||||
|
@ -744,6 +758,54 @@ where
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[inline(always)]
|
||||||
|
fn yy_mul_to_uninit<Status, R2: Dim, C2: Dim, SB, R3: Dim, C3: Dim, SC>(
|
||||||
|
&self,
|
||||||
|
_status: Status,
|
||||||
|
rhs: &Matrix<T, R2, C2, SB>,
|
||||||
|
out: &mut Matrix<Status::Value, R3, C3, SC>,
|
||||||
|
dot: impl Fn(
|
||||||
|
&VectorView<'_, T, R1, SA::RStride, SA::CStride>,
|
||||||
|
&VectorView<'_, T, R2, SB::RStride, SB::CStride>,
|
||||||
|
) -> T,
|
||||||
|
) where
|
||||||
|
Status: InitStatus<T>,
|
||||||
|
SB: RawStorage<T, R2, C2>,
|
||||||
|
SC: RawStorageMut<Status::Value, R3, C3>,
|
||||||
|
ShapeConstraint: SameNumberOfColumns<C1, C2> + DimEq<R1, R3> + DimEq<R2, C3>,
|
||||||
|
{
|
||||||
|
let (nrows1, ncols1) = self.shape();
|
||||||
|
let (nrows2, ncols2) = rhs.shape();
|
||||||
|
let (nrows3, ncols3) = out.shape();
|
||||||
|
|
||||||
|
assert!(
|
||||||
|
ncols1 == ncols2,
|
||||||
|
"Matrix multiplication dimensions mismatch {:?} and {:?}: left rows != right rows.",
|
||||||
|
self.shape(),
|
||||||
|
rhs.shape()
|
||||||
|
);
|
||||||
|
assert!(
|
||||||
|
nrows1 == nrows3,
|
||||||
|
"Matrix multiplication output dimensions mismatch {:?} and {:?}: left cols != right rows.",
|
||||||
|
self.shape(),
|
||||||
|
out.shape()
|
||||||
|
);
|
||||||
|
assert!(
|
||||||
|
nrows2 == ncols3,
|
||||||
|
"Matrix multiplication output dimensions mismatch {:?} and {:?}: left cols != right cols",
|
||||||
|
rhs.shape(),
|
||||||
|
out.shape()
|
||||||
|
);
|
||||||
|
|
||||||
|
for i in 0..nrows1 {
|
||||||
|
for j in 0..nrows2 {
|
||||||
|
let dot = dot(&self.row(i), &rhs.row(j));
|
||||||
|
let elt = unsafe { out.get_unchecked_mut((i, j)) };
|
||||||
|
Status::init(elt, dot)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/// Equivalent to `self.transpose() * rhs` but stores the result into `out` to avoid
|
/// Equivalent to `self.transpose() * rhs` but stores the result into `out` to avoid
|
||||||
/// allocations.
|
/// allocations.
|
||||||
#[inline]
|
#[inline]
|
||||||
|
|
Loading…
Reference in New Issue