Merge pull request #741 from jenanwise/dim-mismatch-verbose-errors

More verbose DMatrix dim asserts where possible.
This commit is contained in:
Sébastien Crozet 2020-06-23 01:46:09 -07:00 committed by GitHub
commit b1b18d17ee
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 68 additions and 39 deletions

View File

@ -284,7 +284,9 @@ where
{
assert!(
self.nrows() == rhs.nrows(),
"Dot product dimensions mismatch."
"Dot product dimensions mismatch for shapes {:?} and {:?}: left rows != right rows.",
self.shape(),
rhs.shape(),
);
// So we do some special cases for common fixed-size vectors of dimension lower than 8
@ -496,8 +498,9 @@ where
ShapeConstraint: DimEq<C, R2> + DimEq<R, C2>,
{
let (nrows, ncols) = self.shape();
assert!(
(ncols, nrows) == rhs.shape(),
assert_eq!(
(ncols, nrows),
rhs.shape(),
"Transposed dot product dimension mismatch."
);

View File

@ -538,8 +538,9 @@ impl<N: Scalar, R: Dim, C: Dim, S: Storage<N, R, C>> Matrix<N, R, C, S> {
let mut res = unsafe { MatrixMN::new_uninitialized_generic(nrows, ncols) };
assert!(
(nrows.value(), ncols.value()) == rhs.shape(),
assert_eq!(
(nrows.value(), ncols.value()),
rhs.shape(),
"Matrix simultaneous traversal error: dimension mismatch."
);
@ -578,9 +579,14 @@ impl<N: Scalar, R: Dim, C: Dim, S: Storage<N, R, C>> Matrix<N, R, C, S> {
let mut res = unsafe { MatrixMN::new_uninitialized_generic(nrows, ncols) };
assert!(
(nrows.value(), ncols.value()) == b.shape()
&& (nrows.value(), ncols.value()) == c.shape(),
assert_eq!(
(nrows.value(), ncols.value()),
b.shape(),
"Matrix simultaneous traversal error: dimension mismatch."
);
assert_eq!(
(nrows.value(), ncols.value()),
c.shape(),
"Matrix simultaneous traversal error: dimension mismatch."
);
@ -636,8 +642,9 @@ impl<N: Scalar, R: Dim, C: Dim, S: Storage<N, R, C>> Matrix<N, R, C, S> {
let mut res = init;
assert!(
(nrows.value(), ncols.value()) == rhs.shape(),
assert_eq!(
(nrows.value(), ncols.value()),
rhs.shape(),
"Matrix simultaneous traversal error: dimension mismatch."
);
@ -884,8 +891,9 @@ impl<N: Scalar, R: Dim, C: Dim, S: StorageMut<N, R, C>> Matrix<N, R, C, S> {
{
let (nrows, ncols) = self.shape();
assert!(
(nrows, ncols) == rhs.shape(),
assert_eq!(
(nrows, ncols),
rhs.shape(),
"Matrix simultaneous traversal error: dimension mismatch."
);
@ -922,12 +930,14 @@ impl<N: Scalar, R: Dim, C: Dim, S: StorageMut<N, R, C>> Matrix<N, R, C, S> {
{
let (nrows, ncols) = self.shape();
assert!(
(nrows, ncols) == b.shape(),
assert_eq!(
(nrows, ncols),
b.shape(),
"Matrix simultaneous traversal error: dimension mismatch."
);
assert!(
(nrows, ncols) == c.shape(),
assert_eq!(
(nrows, ncols),
c.shape(),
"Matrix simultaneous traversal error: dimension mismatch."
);
@ -1427,8 +1437,9 @@ where
#[inline]
fn lt(&self, right: &Self) -> bool {
assert!(
self.shape() == right.shape(),
assert_eq!(
self.shape(),
right.shape(),
"Matrix comparison error: dimensions mismatch."
);
self.iter().zip(right.iter()).all(|(a, b)| a.lt(b))
@ -1436,8 +1447,9 @@ where
#[inline]
fn le(&self, right: &Self) -> bool {
assert!(
self.shape() == right.shape(),
assert_eq!(
self.shape(),
right.shape(),
"Matrix comparison error: dimensions mismatch."
);
self.iter().zip(right.iter()).all(|(a, b)| a.le(b))
@ -1445,8 +1457,9 @@ where
#[inline]
fn gt(&self, right: &Self) -> bool {
assert!(
self.shape() == right.shape(),
assert_eq!(
self.shape(),
right.shape(),
"Matrix comparison error: dimensions mismatch."
);
self.iter().zip(right.iter()).all(|(a, b)| a.gt(b))
@ -1454,8 +1467,9 @@ where
#[inline]
fn ge(&self, right: &Self) -> bool {
assert!(
self.shape() == right.shape(),
assert_eq!(
self.shape(),
right.shape(),
"Matrix comparison error: dimensions mismatch."
);
self.iter().zip(right.iter()).all(|(a, b)| a.ge(b))
@ -1602,7 +1616,11 @@ impl<N: Scalar + ClosedAdd + ClosedSub + ClosedMul, R: Dim, C: Dim, S: Storage<N
+ SameNumberOfRows<R2, U2>
+ SameNumberOfColumns<C2, U1>,
{
assert!(self.shape() == (2, 1), "2D perpendicular product ");
assert!(
self.shape() == (2, 1),
"2D perpendicular product requires (2, 1) vector but found {:?}",
self.shape()
);
unsafe {
self.get_unchecked((0, 0)).inlined_clone() * b.get_unchecked((1, 0)).inlined_clone()
@ -1626,13 +1644,11 @@ impl<N: Scalar + ClosedAdd + ClosedSub + ClosedMul, R: Dim, C: Dim, S: Storage<N
ShapeConstraint: SameNumberOfRows<R, R2> + SameNumberOfColumns<C, C2>,
{
let shape = self.shape();
assert!(
shape == b.shape(),
"Vector cross product dimension mismatch."
);
assert_eq!(shape, b.shape(), "Vector cross product dimension mismatch.");
assert!(
(shape.0 == 3 && shape.1 == 1) || (shape.0 == 1 && shape.1 == 3),
"Vector cross product dimension mismatch."
"Vector cross product dimension mismatch: must be (3, 1) or (1, 3) but found {:?}.",
shape
);
if shape.0 == 3 {

View File

@ -154,8 +154,8 @@ macro_rules! componentwise_binop_impl(
out: &mut Matrix<N, R3, C3, SC>)
where SB: Storage<N, R2, C2>,
SC: StorageMut<N, R3, C3> {
assert!(self.shape() == rhs.shape(), "Matrix addition/subtraction dimensions mismatch.");
assert!(self.shape() == out.shape(), "Matrix addition/subtraction output dimensions mismatch.");
assert_eq!(self.shape(), rhs.shape(), "Matrix addition/subtraction dimensions mismatch.");
assert_eq!(self.shape(), out.shape(), "Matrix addition/subtraction output dimensions mismatch.");
// This is the most common case and should be deduced at compile-time.
// FIXME: use specialization instead?
@ -188,7 +188,7 @@ macro_rules! componentwise_binop_impl(
C2: Dim,
SA: StorageMut<N, R1, C1>,
SB: Storage<N, R2, C2> {
assert!(self.shape() == rhs.shape(), "Matrix addition/subtraction dimensions mismatch.");
assert_eq!(self.shape(), rhs.shape(), "Matrix addition/subtraction dimensions mismatch.");
// This is the most common case and should be deduced at compile-time.
// FIXME: use specialization instead?
@ -218,7 +218,7 @@ macro_rules! componentwise_binop_impl(
where R2: Dim,
C2: Dim,
SB: StorageMut<N, R2, C2> {
assert!(self.shape() == rhs.shape(), "Matrix addition/subtraction dimensions mismatch.");
assert_eq!(self.shape(), rhs.shape(), "Matrix addition/subtraction dimensions mismatch.");
// This is the most common case and should be deduced at compile-time.
// FIXME: use specialization instead?
@ -277,7 +277,7 @@ macro_rules! componentwise_binop_impl(
#[inline]
fn $method(self, rhs: &'b Matrix<N, R2, C2, SB>) -> Self::Output {
assert!(self.shape() == rhs.shape(), "Matrix addition/subtraction dimensions mismatch.");
assert_eq!(self.shape(), rhs.shape(), "Matrix addition/subtraction dimensions mismatch.");
let mut res = self.into_owned_sum::<R2, C2>();
res.$method_assign_statically_unchecked(rhs);
res
@ -296,7 +296,7 @@ macro_rules! componentwise_binop_impl(
#[inline]
fn $method(self, rhs: Matrix<N, R2, C2, SB>) -> Self::Output {
let mut rhs = rhs.into_owned_sum::<R1, C1>();
assert!(self.shape() == rhs.shape(), "Matrix addition/subtraction dimensions mismatch.");
assert_eq!(self.shape(), rhs.shape(), "Matrix addition/subtraction dimensions mismatch.");
self.$method_assign_statically_unchecked_rhs(&mut rhs);
rhs
}
@ -728,11 +728,21 @@ where
assert!(
nrows1 == nrows2,
"Matrix multiplication dimensions mismatch."
"Matrix multiplication dimensions mismatch {:?} and {:?}: left rows != right rows.",
self.shape(),
rhs.shape()
);
assert!(
nrows3 == ncols1 && ncols3 == ncols2,
"Matrix multiplication output dimensions mismatch."
ncols1 == nrows3,
"Matrix multiplication output dimensions mismatch {:?} and {:?}: left cols != right rows.",
self.shape(),
out.shape()
);
assert!(
ncols2 == ncols3,
"Matrix multiplication output dimensions mismatch {:?} and {:?}: left cols != right cols",
rhs.shape(),
out.shape()
);
for i in 0..ncols1 {