Fix Matrix PartialEq bug and add some generalization (#692)

Fix Matrix PartialEq bug and add some generalization
This commit is contained in:
Sébastien Crozet 2020-01-25 22:33:02 +01:00
commit da23386beb
4 changed files with 78 additions and 13 deletions

View File

@ -241,7 +241,7 @@ impl NamedDim for typenum::U1 {
type Name = U1;
}
macro_rules! named_dimension(
macro_rules! named_dimension (
($($D: ident),* $(,)*) => {$(
/// A type level dimension.
#[derive(Debug, Copy, Clone, Hash, PartialEq, Eq)]

View File

@ -1137,7 +1137,7 @@ impl<N: Scalar + Zero + One, D: DimAdd<U1> + IsNotStaticOne, S: Storage<N, D, D>
where DefaultAllocator: Allocator<N, DimSum<D, U1>, DimSum<D, U1>> {
assert!(self.is_square(), "Only square matrices can currently be transformed to homogeneous coordinates.");
let dim = DimSum::<D, U1>::from_usize(self.nrows() + 1);
let mut res = MatrixN::identity_generic(dim, dim);
let mut res = MatrixN::identity_generic(dim, dim);
res.generic_slice_mut::<D, D>((0, 0), self.data.shape()).copy_from(&self);
res
}
@ -1344,18 +1344,19 @@ where
S: Storage<N, R, C>,
{}
impl<N, R: Dim, C: Dim, S> PartialEq for Matrix<N, R, C, S>
impl<N, R, R2, C, C2, S, S2> PartialEq<Matrix<N, R2, C2, S2>> for Matrix<N, R, C, S>
where
N: Scalar,
N: Scalar + PartialEq,
C: Dim,
C2: Dim,
R: Dim,
R2: Dim,
S: Storage<N, R, C>,
S2: Storage<N, R2, C2>
{
#[inline]
fn eq(&self, right: &Matrix<N, R, C, S>) -> bool {
assert!(
self.shape() == right.shape(),
"Matrix equality test dimension mismatch."
);
self.iter().zip(right.iter()).all(|(l, r)| l == r)
fn eq(&self, right: &Matrix<N, R2, C2, S2>) -> bool {
self.shape() == right.shape() && self.iter().zip(right.iter()).all(|(l, r)| l == r)
}
}

View File

@ -1,12 +1,15 @@
use num::{One, Zero};
use std::cmp::Ordering;
use na::dimension::{U15, U8};
use na::dimension::{U15, U8, U2, U4};
use na::{
self, DMatrix, DVector, Matrix2, Matrix2x3, Matrix2x4, Matrix3, Matrix3x2, Matrix3x4, Matrix4,
Matrix4x3, Matrix4x5, Matrix5, Matrix6, MatrixMN, RowVector3, RowVector4, RowVector5,
Vector1, Vector2, Vector3, Vector4, Vector5, Vector6,
};
use typenum::{UInt, UTerm};
use serde_json::error::Category::Data;
use typenum::bit::{B0, B1};
#[test]
fn iter() {
@ -1047,3 +1050,62 @@ mod finite_dim_inner_space_tests {
true
}
}
#[test]
fn partial_eq_shape_mismatch() {
let a = Matrix2::new(1, 2, 3, 4);
let b = Matrix2x3::new(1, 2, 3, 4, 5, 6);
assert_ne!(a, b);
assert_ne!(b, a);
}
#[test]
fn partial_eq_different_types() {
// Ensure comparability of several types of Matrices
let dynamic_mat = DMatrix::from_row_slice(2, 4, &[1, 2, 3, 4, 5, 6, 7, 8]);
let static_mat = Matrix2x4::new(1, 2, 3, 4, 5, 6, 7, 8);
let mut typenum_static_mat = MatrixMN::<u8, typenum::U1024, U4>::zeros();
let mut slice = typenum_static_mat.slice_mut((0,0), (2, 4));
slice += static_mat;
let fslice_of_dmat = dynamic_mat.fixed_slice::<U2, U2>(0, 0);
let dslice_of_dmat = dynamic_mat.slice((0, 0), (2, 2));
let fslice_of_smat = static_mat.fixed_slice::<U2, U2>(0, 0);
let dslice_of_smat = static_mat.slice((0, 0), (2, 2));
assert_eq!(dynamic_mat, static_mat);
assert_eq!(static_mat, dynamic_mat);
assert_eq!(dynamic_mat, slice);
assert_eq!(slice, dynamic_mat);
assert_eq!(static_mat, slice);
assert_eq!(slice, static_mat);
assert_eq!(fslice_of_dmat, dslice_of_dmat);
assert_eq!(dslice_of_dmat, fslice_of_dmat);
assert_eq!(fslice_of_dmat, fslice_of_smat);
assert_eq!(fslice_of_smat, fslice_of_dmat);
assert_eq!(fslice_of_dmat, dslice_of_smat);
assert_eq!(dslice_of_smat, fslice_of_dmat);
assert_eq!(dslice_of_dmat, fslice_of_smat);
assert_eq!(fslice_of_smat, dslice_of_dmat);
assert_eq!(dslice_of_dmat, dslice_of_smat);
assert_eq!(dslice_of_smat, dslice_of_dmat);
assert_eq!(fslice_of_smat, dslice_of_smat);
assert_eq!(dslice_of_smat, fslice_of_smat);
assert_ne!(dynamic_mat, dslice_of_smat);
assert_ne!(dslice_of_smat, dynamic_mat);
// TODO - implement those comparisons
// assert_ne!(static_mat, typenum_static_mat);
//assert_ne!(typenum_static_mat, static_mat);
}

View File

@ -14,7 +14,8 @@ macro_rules! test_serde(
fn $test() {
let v: $ty<f32> = rand::random();
let serialized = serde_json::to_string(&v).unwrap();
assert_eq!(v, serde_json::from_str(&serialized).unwrap());
let deserialized: $ty<f32> = serde_json::from_str(&serialized).unwrap();
assert_eq!(v, deserialized);
}
)*}
);
@ -23,7 +24,8 @@ macro_rules! test_serde(
fn serde_dmatrix() {
let v: DMatrix<f32> = DMatrix::new_random(3, 4);
let serialized = serde_json::to_string(&v).unwrap();
assert_eq!(v, serde_json::from_str(&serialized).unwrap());
let deserialized: DMatrix<f32> = serde_json::from_str(&serialized).unwrap();
assert_eq!(v, deserialized);
}
test_serde!(