More verbose DMatrix dim asserts where possible.

Previously, most dimension mismatch asserts used raw `assert!` and did
not include the mismatching dimensions in the panic message. When using
dynamic matrices, this led to somewhat-opaque panics such as:

```rust
let m1 = DMatrix::<f32>::zeros(2, 3);
let m2 = DMatrix::<f32>::zeros(5, 10);
m1 + m2 // panic: Matrix addition/subtraction dimensions mismatch.
```

This patch adds dimension information in the panic messages wherever
doing so did not add additional bounds checks, mostly by simply changing
`assert!(a == b, ...)` cases to `assert_eq!`. After:

```rust
// panic: assertion failed: `(left == right)`
//   left: `(2, 3)`,
//  right: `(5, 10)`: Matrix addition/subtraction dimensions mismatch.
```

Note that the `gemv` and `ger` were not updated, as they are called from
within other functions on subset matricies -- e.g., `gemv` is called
from `gemm` which is called from `mul_to` . Including dimension
information in the `gemv` panic messages would be confusing to
`mul` / `mul_to` users, because it would include dimensions of the column
vectors that `gemm` passes to `gemv` rather than of the original `mul`
arguments. A fix would be to add bounds checks to `mul_to`, but that may
have performance and redundancy implications, so is left to another
patch.
This commit is contained in:
Jenan Wise 2020-06-22 15:29:13 -07:00
parent 2198b0e6b4
commit 85a64fb517
3 changed files with 68 additions and 39 deletions

View File

@ -284,7 +284,9 @@ where
{
assert!(
self.nrows() == rhs.nrows(),
"Dot product dimensions mismatch."
"Dot product dimensions mismatch for shapes {:?} and {:?}: left rows != right rows.",
self.shape(),
rhs.shape(),
);
// So we do some special cases for common fixed-size vectors of dimension lower than 8
@ -496,8 +498,9 @@ where
ShapeConstraint: DimEq<C, R2> + DimEq<R, C2>,
{
let (nrows, ncols) = self.shape();
assert!(
(ncols, nrows) == rhs.shape(),
assert_eq!(
(ncols, nrows),
rhs.shape(),
"Transposed dot product dimension mismatch."
);

View File

@ -538,8 +538,9 @@ impl<N: Scalar, R: Dim, C: Dim, S: Storage<N, R, C>> Matrix<N, R, C, S> {
let mut res = unsafe { MatrixMN::new_uninitialized_generic(nrows, ncols) };
assert!(
(nrows.value(), ncols.value()) == rhs.shape(),
assert_eq!(
(nrows.value(), ncols.value()),
rhs.shape(),
"Matrix simultaneous traversal error: dimension mismatch."
);
@ -578,9 +579,14 @@ impl<N: Scalar, R: Dim, C: Dim, S: Storage<N, R, C>> Matrix<N, R, C, S> {
let mut res = unsafe { MatrixMN::new_uninitialized_generic(nrows, ncols) };
assert!(
(nrows.value(), ncols.value()) == b.shape()
&& (nrows.value(), ncols.value()) == c.shape(),
assert_eq!(
(nrows.value(), ncols.value()),
b.shape(),
"Matrix simultaneous traversal error: dimension mismatch."
);
assert_eq!(
(nrows.value(), ncols.value()),
c.shape(),
"Matrix simultaneous traversal error: dimension mismatch."
);
@ -636,8 +642,9 @@ impl<N: Scalar, R: Dim, C: Dim, S: Storage<N, R, C>> Matrix<N, R, C, S> {
let mut res = init;
assert!(
(nrows.value(), ncols.value()) == rhs.shape(),
assert_eq!(
(nrows.value(), ncols.value()),
rhs.shape(),
"Matrix simultaneous traversal error: dimension mismatch."
);
@ -884,8 +891,9 @@ impl<N: Scalar, R: Dim, C: Dim, S: StorageMut<N, R, C>> Matrix<N, R, C, S> {
{
let (nrows, ncols) = self.shape();
assert!(
(nrows, ncols) == rhs.shape(),
assert_eq!(
(nrows, ncols),
rhs.shape(),
"Matrix simultaneous traversal error: dimension mismatch."
);
@ -922,12 +930,14 @@ impl<N: Scalar, R: Dim, C: Dim, S: StorageMut<N, R, C>> Matrix<N, R, C, S> {
{
let (nrows, ncols) = self.shape();
assert!(
(nrows, ncols) == b.shape(),
assert_eq!(
(nrows, ncols),
b.shape(),
"Matrix simultaneous traversal error: dimension mismatch."
);
assert!(
(nrows, ncols) == c.shape(),
assert_eq!(
(nrows, ncols),
c.shape(),
"Matrix simultaneous traversal error: dimension mismatch."
);
@ -1427,8 +1437,9 @@ where
#[inline]
fn lt(&self, right: &Self) -> bool {
assert!(
self.shape() == right.shape(),
assert_eq!(
self.shape(),
right.shape(),
"Matrix comparison error: dimensions mismatch."
);
self.iter().zip(right.iter()).all(|(a, b)| a.lt(b))
@ -1436,8 +1447,9 @@ where
#[inline]
fn le(&self, right: &Self) -> bool {
assert!(
self.shape() == right.shape(),
assert_eq!(
self.shape(),
right.shape(),
"Matrix comparison error: dimensions mismatch."
);
self.iter().zip(right.iter()).all(|(a, b)| a.le(b))
@ -1445,8 +1457,9 @@ where
#[inline]
fn gt(&self, right: &Self) -> bool {
assert!(
self.shape() == right.shape(),
assert_eq!(
self.shape(),
right.shape(),
"Matrix comparison error: dimensions mismatch."
);
self.iter().zip(right.iter()).all(|(a, b)| a.gt(b))
@ -1454,8 +1467,9 @@ where
#[inline]
fn ge(&self, right: &Self) -> bool {
assert!(
self.shape() == right.shape(),
assert_eq!(
self.shape(),
right.shape(),
"Matrix comparison error: dimensions mismatch."
);
self.iter().zip(right.iter()).all(|(a, b)| a.ge(b))
@ -1602,7 +1616,11 @@ impl<N: Scalar + ClosedAdd + ClosedSub + ClosedMul, R: Dim, C: Dim, S: Storage<N
+ SameNumberOfRows<R2, U2>
+ SameNumberOfColumns<C2, U1>,
{
assert!(self.shape() == (2, 1), "2D perpendicular product ");
assert!(
self.shape() == (2, 1),
"2D perpendicular product requires (2, 1) vector but found {:?}",
self.shape()
);
unsafe {
self.get_unchecked((0, 0)).inlined_clone() * b.get_unchecked((1, 0)).inlined_clone()
@ -1626,13 +1644,11 @@ impl<N: Scalar + ClosedAdd + ClosedSub + ClosedMul, R: Dim, C: Dim, S: Storage<N
ShapeConstraint: SameNumberOfRows<R, R2> + SameNumberOfColumns<C, C2>,
{
let shape = self.shape();
assert!(
shape == b.shape(),
"Vector cross product dimension mismatch."
);
assert_eq!(shape, b.shape(), "Vector cross product dimension mismatch.");
assert!(
(shape.0 == 3 && shape.1 == 1) || (shape.0 == 1 && shape.1 == 3),
"Vector cross product dimension mismatch."
"Vector cross product dimension mismatch: must be (3, 1) or (1, 3) but found {:?}.",
shape
);
if shape.0 == 3 {

View File

@ -154,8 +154,8 @@ macro_rules! componentwise_binop_impl(
out: &mut Matrix<N, R3, C3, SC>)
where SB: Storage<N, R2, C2>,
SC: StorageMut<N, R3, C3> {
assert!(self.shape() == rhs.shape(), "Matrix addition/subtraction dimensions mismatch.");
assert!(self.shape() == out.shape(), "Matrix addition/subtraction output dimensions mismatch.");
assert_eq!(self.shape(), rhs.shape(), "Matrix addition/subtraction dimensions mismatch.");
assert_eq!(self.shape(), out.shape(), "Matrix addition/subtraction output dimensions mismatch.");
// This is the most common case and should be deduced at compile-time.
// FIXME: use specialization instead?
@ -188,7 +188,7 @@ macro_rules! componentwise_binop_impl(
C2: Dim,
SA: StorageMut<N, R1, C1>,
SB: Storage<N, R2, C2> {
assert!(self.shape() == rhs.shape(), "Matrix addition/subtraction dimensions mismatch.");
assert_eq!(self.shape(), rhs.shape(), "Matrix addition/subtraction dimensions mismatch.");
// This is the most common case and should be deduced at compile-time.
// FIXME: use specialization instead?
@ -218,7 +218,7 @@ macro_rules! componentwise_binop_impl(
where R2: Dim,
C2: Dim,
SB: StorageMut<N, R2, C2> {
assert!(self.shape() == rhs.shape(), "Matrix addition/subtraction dimensions mismatch.");
assert_eq!(self.shape(), rhs.shape(), "Matrix addition/subtraction dimensions mismatch.");
// This is the most common case and should be deduced at compile-time.
// FIXME: use specialization instead?
@ -277,7 +277,7 @@ macro_rules! componentwise_binop_impl(
#[inline]
fn $method(self, rhs: &'b Matrix<N, R2, C2, SB>) -> Self::Output {
assert!(self.shape() == rhs.shape(), "Matrix addition/subtraction dimensions mismatch.");
assert_eq!(self.shape(), rhs.shape(), "Matrix addition/subtraction dimensions mismatch.");
let mut res = self.into_owned_sum::<R2, C2>();
res.$method_assign_statically_unchecked(rhs);
res
@ -296,7 +296,7 @@ macro_rules! componentwise_binop_impl(
#[inline]
fn $method(self, rhs: Matrix<N, R2, C2, SB>) -> Self::Output {
let mut rhs = rhs.into_owned_sum::<R1, C1>();
assert!(self.shape() == rhs.shape(), "Matrix addition/subtraction dimensions mismatch.");
assert_eq!(self.shape(), rhs.shape(), "Matrix addition/subtraction dimensions mismatch.");
self.$method_assign_statically_unchecked_rhs(&mut rhs);
rhs
}
@ -728,11 +728,21 @@ where
assert!(
nrows1 == nrows2,
"Matrix multiplication dimensions mismatch."
"Matrix multiplication dimensions mismatch {:?} and {:?}: left rows != right rows.",
self.shape(),
rhs.shape()
);
assert!(
nrows3 == ncols1 && ncols3 == ncols2,
"Matrix multiplication output dimensions mismatch."
ncols1 == nrows3,
"Matrix multiplication output dimensions mismatch {:?} and {:?}: left cols != right rows.",
self.shape(),
out.shape()
);
assert!(
ncols2 == ncols3,
"Matrix multiplication output dimensions mismatch {:?} and {:?}: left cols != right cols",
rhs.shape(),
out.shape()
);
for i in 0..ncols1 {