diff --git a/src/base/matrix.rs b/src/base/matrix.rs index 1c1d9e1f..db745887 100644 --- a/src/base/matrix.rs +++ b/src/base/matrix.rs @@ -894,18 +894,17 @@ where { #[inline] fn partial_cmp(&self, other: &Self) -> Option { - assert!( - self.shape() == other.shape(), - "Matrix comparison error: dimensions mismatch." - ); + if self.shape() != other.shape() || self.nrows() == 0 || self.ncols() == 0 { + return None; + } - let first_ord = unsafe { + let mut first_ord = unsafe { self.data .get_unchecked_linear(0) .partial_cmp(other.data.get_unchecked_linear(0)) }; - if let Some(mut first_ord) = first_ord { + if let Some(first_ord) = first_ord.as_mut() { let mut it = self.iter().zip(other.iter()); let _ = it.next(); // Drop the first elements (we already tested it). @@ -914,16 +913,16 @@ where match ord { Ordering::Equal => { /* Does not change anything. */ } Ordering::Less => { - if first_ord == Ordering::Greater { + if *first_ord == Ordering::Greater { return None; } - first_ord = ord + *first_ord = ord } Ordering::Greater => { - if first_ord == Ordering::Less { + if *first_ord == Ordering::Less { return None; } - first_ord = ord + *first_ord = ord } } } else { @@ -976,8 +975,7 @@ impl Eq for Matrix where N: Scalar + Eq, S: Storage, -{ -} +{} impl PartialEq for Matrix where diff --git a/tests/core/matrix.rs b/tests/core/matrix.rs index 05e98384..411b2909 100644 --- a/tests/core/matrix.rs +++ b/tests/core/matrix.rs @@ -1,4 +1,5 @@ use num::{One, Zero}; +use std::cmp::Ordering; use na::dimension::{U15, U8}; use na::{ @@ -723,6 +724,19 @@ fn partial_clamp() { assert_eq!(*inter.unwrap(), n); } +#[test] +fn partial_cmp() { + // NOTE: from #401. + let a = Vector2::new(1.0, 6.0); + let b = Vector2::new(1.0, 3.0); + let c = Vector2::new(2.0, 7.0); + let d = Vector2::new(0.0, 7.0); + assert_eq!(a.partial_cmp(&a), Some(Ordering::Equal)); + assert_eq!(a.partial_cmp(&b), Some(Ordering::Greater)); + assert_eq!(a.partial_cmp(&c), Some(Ordering::Less)); + assert_eq!(a.partial_cmp(&d), None); +} + #[test] fn swizzle() { let a = Vector2::new(1.0f32, 2.0);