Merge pull request #741 from jenanwise/dim-mismatch-verbose-errors

More verbose DMatrix dim asserts where possible.
This commit is contained in:
Sébastien Crozet 2020-06-23 01:46:09 -07:00 committed by GitHub
commit b1b18d17ee
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 68 additions and 39 deletions

View File

@ -284,7 +284,9 @@ where
{ {
assert!( assert!(
self.nrows() == rhs.nrows(), self.nrows() == rhs.nrows(),
"Dot product dimensions mismatch." "Dot product dimensions mismatch for shapes {:?} and {:?}: left rows != right rows.",
self.shape(),
rhs.shape(),
); );
// So we do some special cases for common fixed-size vectors of dimension lower than 8 // So we do some special cases for common fixed-size vectors of dimension lower than 8
@ -496,8 +498,9 @@ where
ShapeConstraint: DimEq<C, R2> + DimEq<R, C2>, ShapeConstraint: DimEq<C, R2> + DimEq<R, C2>,
{ {
let (nrows, ncols) = self.shape(); let (nrows, ncols) = self.shape();
assert!( assert_eq!(
(ncols, nrows) == rhs.shape(), (ncols, nrows),
rhs.shape(),
"Transposed dot product dimension mismatch." "Transposed dot product dimension mismatch."
); );

View File

@ -538,8 +538,9 @@ impl<N: Scalar, R: Dim, C: Dim, S: Storage<N, R, C>> Matrix<N, R, C, S> {
let mut res = unsafe { MatrixMN::new_uninitialized_generic(nrows, ncols) }; let mut res = unsafe { MatrixMN::new_uninitialized_generic(nrows, ncols) };
assert!( assert_eq!(
(nrows.value(), ncols.value()) == rhs.shape(), (nrows.value(), ncols.value()),
rhs.shape(),
"Matrix simultaneous traversal error: dimension mismatch." "Matrix simultaneous traversal error: dimension mismatch."
); );
@ -578,9 +579,14 @@ impl<N: Scalar, R: Dim, C: Dim, S: Storage<N, R, C>> Matrix<N, R, C, S> {
let mut res = unsafe { MatrixMN::new_uninitialized_generic(nrows, ncols) }; let mut res = unsafe { MatrixMN::new_uninitialized_generic(nrows, ncols) };
assert!( assert_eq!(
(nrows.value(), ncols.value()) == b.shape() (nrows.value(), ncols.value()),
&& (nrows.value(), ncols.value()) == c.shape(), b.shape(),
"Matrix simultaneous traversal error: dimension mismatch."
);
assert_eq!(
(nrows.value(), ncols.value()),
c.shape(),
"Matrix simultaneous traversal error: dimension mismatch." "Matrix simultaneous traversal error: dimension mismatch."
); );
@ -636,8 +642,9 @@ impl<N: Scalar, R: Dim, C: Dim, S: Storage<N, R, C>> Matrix<N, R, C, S> {
let mut res = init; let mut res = init;
assert!( assert_eq!(
(nrows.value(), ncols.value()) == rhs.shape(), (nrows.value(), ncols.value()),
rhs.shape(),
"Matrix simultaneous traversal error: dimension mismatch." "Matrix simultaneous traversal error: dimension mismatch."
); );
@ -884,8 +891,9 @@ impl<N: Scalar, R: Dim, C: Dim, S: StorageMut<N, R, C>> Matrix<N, R, C, S> {
{ {
let (nrows, ncols) = self.shape(); let (nrows, ncols) = self.shape();
assert!( assert_eq!(
(nrows, ncols) == rhs.shape(), (nrows, ncols),
rhs.shape(),
"Matrix simultaneous traversal error: dimension mismatch." "Matrix simultaneous traversal error: dimension mismatch."
); );
@ -922,12 +930,14 @@ impl<N: Scalar, R: Dim, C: Dim, S: StorageMut<N, R, C>> Matrix<N, R, C, S> {
{ {
let (nrows, ncols) = self.shape(); let (nrows, ncols) = self.shape();
assert!( assert_eq!(
(nrows, ncols) == b.shape(), (nrows, ncols),
b.shape(),
"Matrix simultaneous traversal error: dimension mismatch." "Matrix simultaneous traversal error: dimension mismatch."
); );
assert!( assert_eq!(
(nrows, ncols) == c.shape(), (nrows, ncols),
c.shape(),
"Matrix simultaneous traversal error: dimension mismatch." "Matrix simultaneous traversal error: dimension mismatch."
); );
@ -1427,8 +1437,9 @@ where
#[inline] #[inline]
fn lt(&self, right: &Self) -> bool { fn lt(&self, right: &Self) -> bool {
assert!( assert_eq!(
self.shape() == right.shape(), self.shape(),
right.shape(),
"Matrix comparison error: dimensions mismatch." "Matrix comparison error: dimensions mismatch."
); );
self.iter().zip(right.iter()).all(|(a, b)| a.lt(b)) self.iter().zip(right.iter()).all(|(a, b)| a.lt(b))
@ -1436,8 +1447,9 @@ where
#[inline] #[inline]
fn le(&self, right: &Self) -> bool { fn le(&self, right: &Self) -> bool {
assert!( assert_eq!(
self.shape() == right.shape(), self.shape(),
right.shape(),
"Matrix comparison error: dimensions mismatch." "Matrix comparison error: dimensions mismatch."
); );
self.iter().zip(right.iter()).all(|(a, b)| a.le(b)) self.iter().zip(right.iter()).all(|(a, b)| a.le(b))
@ -1445,8 +1457,9 @@ where
#[inline] #[inline]
fn gt(&self, right: &Self) -> bool { fn gt(&self, right: &Self) -> bool {
assert!( assert_eq!(
self.shape() == right.shape(), self.shape(),
right.shape(),
"Matrix comparison error: dimensions mismatch." "Matrix comparison error: dimensions mismatch."
); );
self.iter().zip(right.iter()).all(|(a, b)| a.gt(b)) self.iter().zip(right.iter()).all(|(a, b)| a.gt(b))
@ -1454,8 +1467,9 @@ where
#[inline] #[inline]
fn ge(&self, right: &Self) -> bool { fn ge(&self, right: &Self) -> bool {
assert!( assert_eq!(
self.shape() == right.shape(), self.shape(),
right.shape(),
"Matrix comparison error: dimensions mismatch." "Matrix comparison error: dimensions mismatch."
); );
self.iter().zip(right.iter()).all(|(a, b)| a.ge(b)) self.iter().zip(right.iter()).all(|(a, b)| a.ge(b))
@ -1602,7 +1616,11 @@ impl<N: Scalar + ClosedAdd + ClosedSub + ClosedMul, R: Dim, C: Dim, S: Storage<N
+ SameNumberOfRows<R2, U2> + SameNumberOfRows<R2, U2>
+ SameNumberOfColumns<C2, U1>, + SameNumberOfColumns<C2, U1>,
{ {
assert!(self.shape() == (2, 1), "2D perpendicular product "); assert!(
self.shape() == (2, 1),
"2D perpendicular product requires (2, 1) vector but found {:?}",
self.shape()
);
unsafe { unsafe {
self.get_unchecked((0, 0)).inlined_clone() * b.get_unchecked((1, 0)).inlined_clone() self.get_unchecked((0, 0)).inlined_clone() * b.get_unchecked((1, 0)).inlined_clone()
@ -1626,13 +1644,11 @@ impl<N: Scalar + ClosedAdd + ClosedSub + ClosedMul, R: Dim, C: Dim, S: Storage<N
ShapeConstraint: SameNumberOfRows<R, R2> + SameNumberOfColumns<C, C2>, ShapeConstraint: SameNumberOfRows<R, R2> + SameNumberOfColumns<C, C2>,
{ {
let shape = self.shape(); let shape = self.shape();
assert!( assert_eq!(shape, b.shape(), "Vector cross product dimension mismatch.");
shape == b.shape(),
"Vector cross product dimension mismatch."
);
assert!( assert!(
(shape.0 == 3 && shape.1 == 1) || (shape.0 == 1 && shape.1 == 3), (shape.0 == 3 && shape.1 == 1) || (shape.0 == 1 && shape.1 == 3),
"Vector cross product dimension mismatch." "Vector cross product dimension mismatch: must be (3, 1) or (1, 3) but found {:?}.",
shape
); );
if shape.0 == 3 { if shape.0 == 3 {

View File

@ -154,8 +154,8 @@ macro_rules! componentwise_binop_impl(
out: &mut Matrix<N, R3, C3, SC>) out: &mut Matrix<N, R3, C3, SC>)
where SB: Storage<N, R2, C2>, where SB: Storage<N, R2, C2>,
SC: StorageMut<N, R3, C3> { SC: StorageMut<N, R3, C3> {
assert!(self.shape() == rhs.shape(), "Matrix addition/subtraction dimensions mismatch."); assert_eq!(self.shape(), rhs.shape(), "Matrix addition/subtraction dimensions mismatch.");
assert!(self.shape() == out.shape(), "Matrix addition/subtraction output dimensions mismatch."); assert_eq!(self.shape(), out.shape(), "Matrix addition/subtraction output dimensions mismatch.");
// This is the most common case and should be deduced at compile-time. // This is the most common case and should be deduced at compile-time.
// FIXME: use specialization instead? // FIXME: use specialization instead?
@ -188,7 +188,7 @@ macro_rules! componentwise_binop_impl(
C2: Dim, C2: Dim,
SA: StorageMut<N, R1, C1>, SA: StorageMut<N, R1, C1>,
SB: Storage<N, R2, C2> { SB: Storage<N, R2, C2> {
assert!(self.shape() == rhs.shape(), "Matrix addition/subtraction dimensions mismatch."); assert_eq!(self.shape(), rhs.shape(), "Matrix addition/subtraction dimensions mismatch.");
// This is the most common case and should be deduced at compile-time. // This is the most common case and should be deduced at compile-time.
// FIXME: use specialization instead? // FIXME: use specialization instead?
@ -218,7 +218,7 @@ macro_rules! componentwise_binop_impl(
where R2: Dim, where R2: Dim,
C2: Dim, C2: Dim,
SB: StorageMut<N, R2, C2> { SB: StorageMut<N, R2, C2> {
assert!(self.shape() == rhs.shape(), "Matrix addition/subtraction dimensions mismatch."); assert_eq!(self.shape(), rhs.shape(), "Matrix addition/subtraction dimensions mismatch.");
// This is the most common case and should be deduced at compile-time. // This is the most common case and should be deduced at compile-time.
// FIXME: use specialization instead? // FIXME: use specialization instead?
@ -277,7 +277,7 @@ macro_rules! componentwise_binop_impl(
#[inline] #[inline]
fn $method(self, rhs: &'b Matrix<N, R2, C2, SB>) -> Self::Output { fn $method(self, rhs: &'b Matrix<N, R2, C2, SB>) -> Self::Output {
assert!(self.shape() == rhs.shape(), "Matrix addition/subtraction dimensions mismatch."); assert_eq!(self.shape(), rhs.shape(), "Matrix addition/subtraction dimensions mismatch.");
let mut res = self.into_owned_sum::<R2, C2>(); let mut res = self.into_owned_sum::<R2, C2>();
res.$method_assign_statically_unchecked(rhs); res.$method_assign_statically_unchecked(rhs);
res res
@ -296,7 +296,7 @@ macro_rules! componentwise_binop_impl(
#[inline] #[inline]
fn $method(self, rhs: Matrix<N, R2, C2, SB>) -> Self::Output { fn $method(self, rhs: Matrix<N, R2, C2, SB>) -> Self::Output {
let mut rhs = rhs.into_owned_sum::<R1, C1>(); let mut rhs = rhs.into_owned_sum::<R1, C1>();
assert!(self.shape() == rhs.shape(), "Matrix addition/subtraction dimensions mismatch."); assert_eq!(self.shape(), rhs.shape(), "Matrix addition/subtraction dimensions mismatch.");
self.$method_assign_statically_unchecked_rhs(&mut rhs); self.$method_assign_statically_unchecked_rhs(&mut rhs);
rhs rhs
} }
@ -728,11 +728,21 @@ where
assert!( assert!(
nrows1 == nrows2, nrows1 == nrows2,
"Matrix multiplication dimensions mismatch." "Matrix multiplication dimensions mismatch {:?} and {:?}: left rows != right rows.",
self.shape(),
rhs.shape()
); );
assert!( assert!(
nrows3 == ncols1 && ncols3 == ncols2, ncols1 == nrows3,
"Matrix multiplication output dimensions mismatch." "Matrix multiplication output dimensions mismatch {:?} and {:?}: left cols != right rows.",
self.shape(),
out.shape()
);
assert!(
ncols2 == ncols3,
"Matrix multiplication output dimensions mismatch {:?} and {:?}: left cols != right cols",
rhs.shape(),
out.shape()
); );
for i in 0..ncols1 { for i in 0..ncols1 {