forked from M-Labs/nalgebra
More verbose DMatrix dim asserts where possible.
Previously, most dimension mismatch asserts used raw `assert!` and did not include the mismatching dimensions in the panic message. When using dynamic matrices, this led to somewhat-opaque panics such as: ```rust let m1 = DMatrix::<f32>::zeros(2, 3); let m2 = DMatrix::<f32>::zeros(5, 10); m1 + m2 // panic: Matrix addition/subtraction dimensions mismatch. ``` This patch adds dimension information in the panic messages wherever doing so did not add additional bounds checks, mostly by simply changing `assert!(a == b, ...)` cases to `assert_eq!`. After: ```rust // panic: assertion failed: `(left == right)` // left: `(2, 3)`, // right: `(5, 10)`: Matrix addition/subtraction dimensions mismatch. ``` Note that the `gemv` and `ger` were not updated, as they are called from within other functions on subset matricies -- e.g., `gemv` is called from `gemm` which is called from `mul_to` . Including dimension information in the `gemv` panic messages would be confusing to `mul` / `mul_to` users, because it would include dimensions of the column vectors that `gemm` passes to `gemv` rather than of the original `mul` arguments. A fix would be to add bounds checks to `mul_to`, but that may have performance and redundancy implications, so is left to another patch.
This commit is contained in:
parent
2198b0e6b4
commit
85a64fb517
@ -284,7 +284,9 @@ where
|
||||
{
|
||||
assert!(
|
||||
self.nrows() == rhs.nrows(),
|
||||
"Dot product dimensions mismatch."
|
||||
"Dot product dimensions mismatch for shapes {:?} and {:?}: left rows != right rows.",
|
||||
self.shape(),
|
||||
rhs.shape(),
|
||||
);
|
||||
|
||||
// So we do some special cases for common fixed-size vectors of dimension lower than 8
|
||||
@ -496,8 +498,9 @@ where
|
||||
ShapeConstraint: DimEq<C, R2> + DimEq<R, C2>,
|
||||
{
|
||||
let (nrows, ncols) = self.shape();
|
||||
assert!(
|
||||
(ncols, nrows) == rhs.shape(),
|
||||
assert_eq!(
|
||||
(ncols, nrows),
|
||||
rhs.shape(),
|
||||
"Transposed dot product dimension mismatch."
|
||||
);
|
||||
|
||||
|
@ -538,8 +538,9 @@ impl<N: Scalar, R: Dim, C: Dim, S: Storage<N, R, C>> Matrix<N, R, C, S> {
|
||||
|
||||
let mut res = unsafe { MatrixMN::new_uninitialized_generic(nrows, ncols) };
|
||||
|
||||
assert!(
|
||||
(nrows.value(), ncols.value()) == rhs.shape(),
|
||||
assert_eq!(
|
||||
(nrows.value(), ncols.value()),
|
||||
rhs.shape(),
|
||||
"Matrix simultaneous traversal error: dimension mismatch."
|
||||
);
|
||||
|
||||
@ -578,9 +579,14 @@ impl<N: Scalar, R: Dim, C: Dim, S: Storage<N, R, C>> Matrix<N, R, C, S> {
|
||||
|
||||
let mut res = unsafe { MatrixMN::new_uninitialized_generic(nrows, ncols) };
|
||||
|
||||
assert!(
|
||||
(nrows.value(), ncols.value()) == b.shape()
|
||||
&& (nrows.value(), ncols.value()) == c.shape(),
|
||||
assert_eq!(
|
||||
(nrows.value(), ncols.value()),
|
||||
b.shape(),
|
||||
"Matrix simultaneous traversal error: dimension mismatch."
|
||||
);
|
||||
assert_eq!(
|
||||
(nrows.value(), ncols.value()),
|
||||
c.shape(),
|
||||
"Matrix simultaneous traversal error: dimension mismatch."
|
||||
);
|
||||
|
||||
@ -636,8 +642,9 @@ impl<N: Scalar, R: Dim, C: Dim, S: Storage<N, R, C>> Matrix<N, R, C, S> {
|
||||
|
||||
let mut res = init;
|
||||
|
||||
assert!(
|
||||
(nrows.value(), ncols.value()) == rhs.shape(),
|
||||
assert_eq!(
|
||||
(nrows.value(), ncols.value()),
|
||||
rhs.shape(),
|
||||
"Matrix simultaneous traversal error: dimension mismatch."
|
||||
);
|
||||
|
||||
@ -884,8 +891,9 @@ impl<N: Scalar, R: Dim, C: Dim, S: StorageMut<N, R, C>> Matrix<N, R, C, S> {
|
||||
{
|
||||
let (nrows, ncols) = self.shape();
|
||||
|
||||
assert!(
|
||||
(nrows, ncols) == rhs.shape(),
|
||||
assert_eq!(
|
||||
(nrows, ncols),
|
||||
rhs.shape(),
|
||||
"Matrix simultaneous traversal error: dimension mismatch."
|
||||
);
|
||||
|
||||
@ -922,12 +930,14 @@ impl<N: Scalar, R: Dim, C: Dim, S: StorageMut<N, R, C>> Matrix<N, R, C, S> {
|
||||
{
|
||||
let (nrows, ncols) = self.shape();
|
||||
|
||||
assert!(
|
||||
(nrows, ncols) == b.shape(),
|
||||
assert_eq!(
|
||||
(nrows, ncols),
|
||||
b.shape(),
|
||||
"Matrix simultaneous traversal error: dimension mismatch."
|
||||
);
|
||||
assert!(
|
||||
(nrows, ncols) == c.shape(),
|
||||
assert_eq!(
|
||||
(nrows, ncols),
|
||||
c.shape(),
|
||||
"Matrix simultaneous traversal error: dimension mismatch."
|
||||
);
|
||||
|
||||
@ -1427,8 +1437,9 @@ where
|
||||
|
||||
#[inline]
|
||||
fn lt(&self, right: &Self) -> bool {
|
||||
assert!(
|
||||
self.shape() == right.shape(),
|
||||
assert_eq!(
|
||||
self.shape(),
|
||||
right.shape(),
|
||||
"Matrix comparison error: dimensions mismatch."
|
||||
);
|
||||
self.iter().zip(right.iter()).all(|(a, b)| a.lt(b))
|
||||
@ -1436,8 +1447,9 @@ where
|
||||
|
||||
#[inline]
|
||||
fn le(&self, right: &Self) -> bool {
|
||||
assert!(
|
||||
self.shape() == right.shape(),
|
||||
assert_eq!(
|
||||
self.shape(),
|
||||
right.shape(),
|
||||
"Matrix comparison error: dimensions mismatch."
|
||||
);
|
||||
self.iter().zip(right.iter()).all(|(a, b)| a.le(b))
|
||||
@ -1445,8 +1457,9 @@ where
|
||||
|
||||
#[inline]
|
||||
fn gt(&self, right: &Self) -> bool {
|
||||
assert!(
|
||||
self.shape() == right.shape(),
|
||||
assert_eq!(
|
||||
self.shape(),
|
||||
right.shape(),
|
||||
"Matrix comparison error: dimensions mismatch."
|
||||
);
|
||||
self.iter().zip(right.iter()).all(|(a, b)| a.gt(b))
|
||||
@ -1454,8 +1467,9 @@ where
|
||||
|
||||
#[inline]
|
||||
fn ge(&self, right: &Self) -> bool {
|
||||
assert!(
|
||||
self.shape() == right.shape(),
|
||||
assert_eq!(
|
||||
self.shape(),
|
||||
right.shape(),
|
||||
"Matrix comparison error: dimensions mismatch."
|
||||
);
|
||||
self.iter().zip(right.iter()).all(|(a, b)| a.ge(b))
|
||||
@ -1602,7 +1616,11 @@ impl<N: Scalar + ClosedAdd + ClosedSub + ClosedMul, R: Dim, C: Dim, S: Storage<N
|
||||
+ SameNumberOfRows<R2, U2>
|
||||
+ SameNumberOfColumns<C2, U1>,
|
||||
{
|
||||
assert!(self.shape() == (2, 1), "2D perpendicular product ");
|
||||
assert!(
|
||||
self.shape() == (2, 1),
|
||||
"2D perpendicular product requires (2, 1) vector but found {:?}",
|
||||
self.shape()
|
||||
);
|
||||
|
||||
unsafe {
|
||||
self.get_unchecked((0, 0)).inlined_clone() * b.get_unchecked((1, 0)).inlined_clone()
|
||||
@ -1626,13 +1644,11 @@ impl<N: Scalar + ClosedAdd + ClosedSub + ClosedMul, R: Dim, C: Dim, S: Storage<N
|
||||
ShapeConstraint: SameNumberOfRows<R, R2> + SameNumberOfColumns<C, C2>,
|
||||
{
|
||||
let shape = self.shape();
|
||||
assert!(
|
||||
shape == b.shape(),
|
||||
"Vector cross product dimension mismatch."
|
||||
);
|
||||
assert_eq!(shape, b.shape(), "Vector cross product dimension mismatch.");
|
||||
assert!(
|
||||
(shape.0 == 3 && shape.1 == 1) || (shape.0 == 1 && shape.1 == 3),
|
||||
"Vector cross product dimension mismatch."
|
||||
"Vector cross product dimension mismatch: must be (3, 1) or (1, 3) but found {:?}.",
|
||||
shape
|
||||
);
|
||||
|
||||
if shape.0 == 3 {
|
||||
|
@ -154,8 +154,8 @@ macro_rules! componentwise_binop_impl(
|
||||
out: &mut Matrix<N, R3, C3, SC>)
|
||||
where SB: Storage<N, R2, C2>,
|
||||
SC: StorageMut<N, R3, C3> {
|
||||
assert!(self.shape() == rhs.shape(), "Matrix addition/subtraction dimensions mismatch.");
|
||||
assert!(self.shape() == out.shape(), "Matrix addition/subtraction output dimensions mismatch.");
|
||||
assert_eq!(self.shape(), rhs.shape(), "Matrix addition/subtraction dimensions mismatch.");
|
||||
assert_eq!(self.shape(), out.shape(), "Matrix addition/subtraction output dimensions mismatch.");
|
||||
|
||||
// This is the most common case and should be deduced at compile-time.
|
||||
// FIXME: use specialization instead?
|
||||
@ -188,7 +188,7 @@ macro_rules! componentwise_binop_impl(
|
||||
C2: Dim,
|
||||
SA: StorageMut<N, R1, C1>,
|
||||
SB: Storage<N, R2, C2> {
|
||||
assert!(self.shape() == rhs.shape(), "Matrix addition/subtraction dimensions mismatch.");
|
||||
assert_eq!(self.shape(), rhs.shape(), "Matrix addition/subtraction dimensions mismatch.");
|
||||
|
||||
// This is the most common case and should be deduced at compile-time.
|
||||
// FIXME: use specialization instead?
|
||||
@ -218,7 +218,7 @@ macro_rules! componentwise_binop_impl(
|
||||
where R2: Dim,
|
||||
C2: Dim,
|
||||
SB: StorageMut<N, R2, C2> {
|
||||
assert!(self.shape() == rhs.shape(), "Matrix addition/subtraction dimensions mismatch.");
|
||||
assert_eq!(self.shape(), rhs.shape(), "Matrix addition/subtraction dimensions mismatch.");
|
||||
|
||||
// This is the most common case and should be deduced at compile-time.
|
||||
// FIXME: use specialization instead?
|
||||
@ -277,7 +277,7 @@ macro_rules! componentwise_binop_impl(
|
||||
|
||||
#[inline]
|
||||
fn $method(self, rhs: &'b Matrix<N, R2, C2, SB>) -> Self::Output {
|
||||
assert!(self.shape() == rhs.shape(), "Matrix addition/subtraction dimensions mismatch.");
|
||||
assert_eq!(self.shape(), rhs.shape(), "Matrix addition/subtraction dimensions mismatch.");
|
||||
let mut res = self.into_owned_sum::<R2, C2>();
|
||||
res.$method_assign_statically_unchecked(rhs);
|
||||
res
|
||||
@ -296,7 +296,7 @@ macro_rules! componentwise_binop_impl(
|
||||
#[inline]
|
||||
fn $method(self, rhs: Matrix<N, R2, C2, SB>) -> Self::Output {
|
||||
let mut rhs = rhs.into_owned_sum::<R1, C1>();
|
||||
assert!(self.shape() == rhs.shape(), "Matrix addition/subtraction dimensions mismatch.");
|
||||
assert_eq!(self.shape(), rhs.shape(), "Matrix addition/subtraction dimensions mismatch.");
|
||||
self.$method_assign_statically_unchecked_rhs(&mut rhs);
|
||||
rhs
|
||||
}
|
||||
@ -728,11 +728,21 @@ where
|
||||
|
||||
assert!(
|
||||
nrows1 == nrows2,
|
||||
"Matrix multiplication dimensions mismatch."
|
||||
"Matrix multiplication dimensions mismatch {:?} and {:?}: left rows != right rows.",
|
||||
self.shape(),
|
||||
rhs.shape()
|
||||
);
|
||||
assert!(
|
||||
nrows3 == ncols1 && ncols3 == ncols2,
|
||||
"Matrix multiplication output dimensions mismatch."
|
||||
ncols1 == nrows3,
|
||||
"Matrix multiplication output dimensions mismatch {:?} and {:?}: left cols != right rows.",
|
||||
self.shape(),
|
||||
out.shape()
|
||||
);
|
||||
assert!(
|
||||
ncols2 == ncols3,
|
||||
"Matrix multiplication output dimensions mismatch {:?} and {:?}: left cols != right cols",
|
||||
rhs.shape(),
|
||||
out.shape()
|
||||
);
|
||||
|
||||
for i in 0..ncols1 {
|
||||
|
Loading…
Reference in New Issue
Block a user