From 85a64fb517c9783663ba1c241092fa7b2060276f Mon Sep 17 00:00:00 2001 From: Jenan Wise Date: Mon, 22 Jun 2020 15:29:13 -0700 Subject: [PATCH] More verbose DMatrix dim asserts where possible. Previously, most dimension mismatch asserts used raw `assert!` and did not include the mismatching dimensions in the panic message. When using dynamic matrices, this led to somewhat-opaque panics such as: ```rust let m1 = DMatrix::::zeros(2, 3); let m2 = DMatrix::::zeros(5, 10); m1 + m2 // panic: Matrix addition/subtraction dimensions mismatch. ``` This patch adds dimension information in the panic messages wherever doing so did not add additional bounds checks, mostly by simply changing `assert!(a == b, ...)` cases to `assert_eq!`. After: ```rust // panic: assertion failed: `(left == right)` // left: `(2, 3)`, // right: `(5, 10)`: Matrix addition/subtraction dimensions mismatch. ``` Note that the `gemv` and `ger` were not updated, as they are called from within other functions on subset matricies -- e.g., `gemv` is called from `gemm` which is called from `mul_to` . Including dimension information in the `gemv` panic messages would be confusing to `mul` / `mul_to` users, because it would include dimensions of the column vectors that `gemm` passes to `gemv` rather than of the original `mul` arguments. A fix would be to add bounds checks to `mul_to`, but that may have performance and redundancy implications, so is left to another patch. --- src/base/blas.rs | 9 ++++-- src/base/matrix.rs | 70 ++++++++++++++++++++++++++++------------------ src/base/ops.rs | 28 +++++++++++++------ 3 files changed, 68 insertions(+), 39 deletions(-) diff --git a/src/base/blas.rs b/src/base/blas.rs index add17af2..1d61b9d2 100644 --- a/src/base/blas.rs +++ b/src/base/blas.rs @@ -284,7 +284,9 @@ where { assert!( self.nrows() == rhs.nrows(), - "Dot product dimensions mismatch." + "Dot product dimensions mismatch for shapes {:?} and {:?}: left rows != right rows.", + self.shape(), + rhs.shape(), ); // So we do some special cases for common fixed-size vectors of dimension lower than 8 @@ -496,8 +498,9 @@ where ShapeConstraint: DimEq + DimEq, { let (nrows, ncols) = self.shape(); - assert!( - (ncols, nrows) == rhs.shape(), + assert_eq!( + (ncols, nrows), + rhs.shape(), "Transposed dot product dimension mismatch." ); diff --git a/src/base/matrix.rs b/src/base/matrix.rs index e4821bf8..d30773a3 100644 --- a/src/base/matrix.rs +++ b/src/base/matrix.rs @@ -538,8 +538,9 @@ impl> Matrix { let mut res = unsafe { MatrixMN::new_uninitialized_generic(nrows, ncols) }; - assert!( - (nrows.value(), ncols.value()) == rhs.shape(), + assert_eq!( + (nrows.value(), ncols.value()), + rhs.shape(), "Matrix simultaneous traversal error: dimension mismatch." ); @@ -578,9 +579,14 @@ impl> Matrix { let mut res = unsafe { MatrixMN::new_uninitialized_generic(nrows, ncols) }; - assert!( - (nrows.value(), ncols.value()) == b.shape() - && (nrows.value(), ncols.value()) == c.shape(), + assert_eq!( + (nrows.value(), ncols.value()), + b.shape(), + "Matrix simultaneous traversal error: dimension mismatch." + ); + assert_eq!( + (nrows.value(), ncols.value()), + c.shape(), "Matrix simultaneous traversal error: dimension mismatch." ); @@ -636,8 +642,9 @@ impl> Matrix { let mut res = init; - assert!( - (nrows.value(), ncols.value()) == rhs.shape(), + assert_eq!( + (nrows.value(), ncols.value()), + rhs.shape(), "Matrix simultaneous traversal error: dimension mismatch." ); @@ -884,8 +891,9 @@ impl> Matrix { { let (nrows, ncols) = self.shape(); - assert!( - (nrows, ncols) == rhs.shape(), + assert_eq!( + (nrows, ncols), + rhs.shape(), "Matrix simultaneous traversal error: dimension mismatch." ); @@ -922,12 +930,14 @@ impl> Matrix { { let (nrows, ncols) = self.shape(); - assert!( - (nrows, ncols) == b.shape(), + assert_eq!( + (nrows, ncols), + b.shape(), "Matrix simultaneous traversal error: dimension mismatch." ); - assert!( - (nrows, ncols) == c.shape(), + assert_eq!( + (nrows, ncols), + c.shape(), "Matrix simultaneous traversal error: dimension mismatch." ); @@ -1427,8 +1437,9 @@ where #[inline] fn lt(&self, right: &Self) -> bool { - assert!( - self.shape() == right.shape(), + assert_eq!( + self.shape(), + right.shape(), "Matrix comparison error: dimensions mismatch." ); self.iter().zip(right.iter()).all(|(a, b)| a.lt(b)) @@ -1436,8 +1447,9 @@ where #[inline] fn le(&self, right: &Self) -> bool { - assert!( - self.shape() == right.shape(), + assert_eq!( + self.shape(), + right.shape(), "Matrix comparison error: dimensions mismatch." ); self.iter().zip(right.iter()).all(|(a, b)| a.le(b)) @@ -1445,8 +1457,9 @@ where #[inline] fn gt(&self, right: &Self) -> bool { - assert!( - self.shape() == right.shape(), + assert_eq!( + self.shape(), + right.shape(), "Matrix comparison error: dimensions mismatch." ); self.iter().zip(right.iter()).all(|(a, b)| a.gt(b)) @@ -1454,8 +1467,9 @@ where #[inline] fn ge(&self, right: &Self) -> bool { - assert!( - self.shape() == right.shape(), + assert_eq!( + self.shape(), + right.shape(), "Matrix comparison error: dimensions mismatch." ); self.iter().zip(right.iter()).all(|(a, b)| a.ge(b)) @@ -1602,7 +1616,11 @@ impl + SameNumberOfColumns, { - assert!(self.shape() == (2, 1), "2D perpendicular product "); + assert!( + self.shape() == (2, 1), + "2D perpendicular product requires (2, 1) vector but found {:?}", + self.shape() + ); unsafe { self.get_unchecked((0, 0)).inlined_clone() * b.get_unchecked((1, 0)).inlined_clone() @@ -1626,13 +1644,11 @@ impl + SameNumberOfColumns, { let shape = self.shape(); - assert!( - shape == b.shape(), - "Vector cross product dimension mismatch." - ); + assert_eq!(shape, b.shape(), "Vector cross product dimension mismatch."); assert!( (shape.0 == 3 && shape.1 == 1) || (shape.0 == 1 && shape.1 == 3), - "Vector cross product dimension mismatch." + "Vector cross product dimension mismatch: must be (3, 1) or (1, 3) but found {:?}.", + shape ); if shape.0 == 3 { diff --git a/src/base/ops.rs b/src/base/ops.rs index 12b26d1a..7f0c6e4e 100644 --- a/src/base/ops.rs +++ b/src/base/ops.rs @@ -154,8 +154,8 @@ macro_rules! componentwise_binop_impl( out: &mut Matrix) where SB: Storage, SC: StorageMut { - assert!(self.shape() == rhs.shape(), "Matrix addition/subtraction dimensions mismatch."); - assert!(self.shape() == out.shape(), "Matrix addition/subtraction output dimensions mismatch."); + assert_eq!(self.shape(), rhs.shape(), "Matrix addition/subtraction dimensions mismatch."); + assert_eq!(self.shape(), out.shape(), "Matrix addition/subtraction output dimensions mismatch."); // This is the most common case and should be deduced at compile-time. // FIXME: use specialization instead? @@ -188,7 +188,7 @@ macro_rules! componentwise_binop_impl( C2: Dim, SA: StorageMut, SB: Storage { - assert!(self.shape() == rhs.shape(), "Matrix addition/subtraction dimensions mismatch."); + assert_eq!(self.shape(), rhs.shape(), "Matrix addition/subtraction dimensions mismatch."); // This is the most common case and should be deduced at compile-time. // FIXME: use specialization instead? @@ -218,7 +218,7 @@ macro_rules! componentwise_binop_impl( where R2: Dim, C2: Dim, SB: StorageMut { - assert!(self.shape() == rhs.shape(), "Matrix addition/subtraction dimensions mismatch."); + assert_eq!(self.shape(), rhs.shape(), "Matrix addition/subtraction dimensions mismatch."); // This is the most common case and should be deduced at compile-time. // FIXME: use specialization instead? @@ -277,7 +277,7 @@ macro_rules! componentwise_binop_impl( #[inline] fn $method(self, rhs: &'b Matrix) -> Self::Output { - assert!(self.shape() == rhs.shape(), "Matrix addition/subtraction dimensions mismatch."); + assert_eq!(self.shape(), rhs.shape(), "Matrix addition/subtraction dimensions mismatch."); let mut res = self.into_owned_sum::(); res.$method_assign_statically_unchecked(rhs); res @@ -296,7 +296,7 @@ macro_rules! componentwise_binop_impl( #[inline] fn $method(self, rhs: Matrix) -> Self::Output { let mut rhs = rhs.into_owned_sum::(); - assert!(self.shape() == rhs.shape(), "Matrix addition/subtraction dimensions mismatch."); + assert_eq!(self.shape(), rhs.shape(), "Matrix addition/subtraction dimensions mismatch."); self.$method_assign_statically_unchecked_rhs(&mut rhs); rhs } @@ -728,11 +728,21 @@ where assert!( nrows1 == nrows2, - "Matrix multiplication dimensions mismatch." + "Matrix multiplication dimensions mismatch {:?} and {:?}: left rows != right rows.", + self.shape(), + rhs.shape() ); assert!( - nrows3 == ncols1 && ncols3 == ncols2, - "Matrix multiplication output dimensions mismatch." + ncols1 == nrows3, + "Matrix multiplication output dimensions mismatch {:?} and {:?}: left cols != right rows.", + self.shape(), + out.shape() + ); + assert!( + ncols2 == ncols3, + "Matrix multiplication output dimensions mismatch {:?} and {:?}: left cols != right cols", + rhs.shape(), + out.shape() ); for i in 0..ncols1 {