Fix matrix resizing with empty matrices.

Fix #306.
This commit is contained in:
Sébastien Crozet 2018-02-03 10:46:04 +01:00
parent 5a4179c287
commit 487af7d979
2 changed files with 113 additions and 2 deletions

View File

@ -581,8 +581,12 @@ impl<N: Scalar, R: Dim, C: Dim, S: Storage<N, R, C>> Matrix<N, R, C, S> {
if new_nrows.value() == nrows { if new_nrows.value() == nrows {
let res = unsafe { DefaultAllocator::reallocate_copy(new_nrows, new_ncols, data) }; let res = unsafe { DefaultAllocator::reallocate_copy(new_nrows, new_ncols, data) };
let mut res = Matrix::from_data(res);
if new_ncols.value() > ncols {
res.columns_range_mut(ncols..).fill(val);
}
Matrix::from_data(res) res
} else { } else {
let mut res; let mut res;
@ -609,7 +613,7 @@ impl<N: Scalar, R: Dim, C: Dim, S: Storage<N, R, C>> Matrix<N, R, C, S> {
extend_rows( extend_rows(
&mut res.data.as_mut_slice(), &mut res.data.as_mut_slice(),
nrows, nrows,
ncols, new_ncols.value(),
nrows, nrows,
new_nrows.value() - nrows, new_nrows.value() - nrows,
); );
@ -638,6 +642,11 @@ unsafe fn compress_rows<N: Scalar>(
nremove: usize, nremove: usize,
) { ) {
let new_nrows = nrows - nremove; let new_nrows = nrows - nremove;
if new_nrows == 0 || ncols == 0 {
return; // Nothing to do as the output matrix is empty.
}
let ptr_in = data.as_ptr(); let ptr_in = data.as_ptr();
let ptr_out = data.as_mut_ptr(); let ptr_out = data.as_mut_ptr();
@ -670,6 +679,11 @@ unsafe fn extend_rows<N: Scalar>(
ninsert: usize, ninsert: usize,
) { ) {
let new_nrows = nrows + ninsert; let new_nrows = nrows + ninsert;
if new_nrows == 0 || ncols == 0 {
return; // Nothing to do as the output matrix is empty.
}
let ptr_in = data.as_ptr(); let ptr_in = data.as_ptr();
let ptr_out = data.as_mut_ptr(); let ptr_out = data.as_mut_ptr();

View File

@ -1,9 +1,12 @@
#![cfg_attr(rustfmt, rustfmt_skip)]
use na::{Matrix, use na::{Matrix,
DMatrix, DMatrix,
Matrix3, Matrix4, Matrix5, Matrix3, Matrix4, Matrix5,
Matrix4x3, Matrix3x4, Matrix5x3, Matrix3x5, Matrix4x5, Matrix5x4}; Matrix4x3, Matrix3x4, Matrix5x3, Matrix3x5, Matrix4x5, Matrix5x4};
use na::{Dynamic, U2, U3, U5}; use na::{Dynamic, U2, U3, U5};
#[test] #[test]
fn upper_lower_triangular() { fn upper_lower_triangular() {
let m = Matrix4::new( let m = Matrix4::new(
@ -378,6 +381,17 @@ fn insert_columns() {
assert!(computed.eq(&expected2)); assert!(computed.eq(&expected2));
} }
#[test]
fn insert_columns_to_empty_matrix() {
let m1 = DMatrix::repeat(0, 0, 0);
let m2 = DMatrix::repeat(3, 0, 0);
let expected1 = DMatrix::repeat(0, 5, 42);
let expected2 = DMatrix::repeat(3, 5, 42);
assert_eq!(expected1, m1.insert_columns(0, 5, 42));
assert_eq!(expected2, m2.insert_columns(0, 5, 42));
}
#[test] #[test]
fn insert_rows() { fn insert_rows() {
@ -438,6 +452,18 @@ fn insert_rows() {
assert!(computed.eq(&expected2)); assert!(computed.eq(&expected2));
} }
#[test]
fn insert_rows_to_empty_matrix() {
let m1 = DMatrix::repeat(0, 0, 0);
let m2 = DMatrix::repeat(0, 5, 0);
let expected1 = DMatrix::repeat(3, 0, 42);
let expected2 = DMatrix::repeat(3, 5, 42);
assert_eq!(expected1, m1.insert_rows(0, 3, 42));
assert_eq!(expected2, m2.insert_rows(0, 3, 42));
}
#[test] #[test]
fn resize() { fn resize() {
let m = Matrix3x5::new( let m = Matrix3x5::new(
@ -445,6 +471,26 @@ fn resize() {
21, 22, 23, 24, 25, 21, 22, 23, 24, 25,
31, 32, 33, 34, 35); 31, 32, 33, 34, 35);
let add_colls = DMatrix::from_row_slice(3, 6, &[
11, 12, 13, 14, 15, 42,
21, 22, 23, 24, 25, 42,
31, 32, 33, 34, 35, 42]);
let del_colls = DMatrix::from_row_slice(3, 4, &[
11, 12, 13, 14,
21, 22, 23, 24,
31, 32, 33, 34]);
let add_rows = DMatrix::from_row_slice(4, 5, &[
11, 12, 13, 14, 15,
21, 22, 23, 24, 25,
31, 32, 33, 34, 35,
42, 42, 42, 42, 42]);
let del_rows = DMatrix::from_row_slice(2, 5, &[
11, 12, 13, 14, 15,
21, 22, 23, 24, 25]);
let add_add = DMatrix::from_row_slice(5, 6, &[ let add_add = DMatrix::from_row_slice(5, 6, &[
11, 12, 13, 14, 15, 42, 11, 12, 13, 14, 15, 42,
21, 22, 23, 24, 25, 42, 21, 22, 23, 24, 25, 42,
@ -464,8 +510,59 @@ fn resize() {
let del_add = DMatrix::from_row_slice(1, 8, &[ let del_add = DMatrix::from_row_slice(1, 8, &[
11, 12, 13, 14, 15, 42, 42, 42]); 11, 12, 13, 14, 15, 42, 42, 42]);
assert_eq!(add_colls, m.resize(3, 6, 42));
assert_eq!(del_colls, m.resize(3, 4, 42));
assert_eq!(add_rows, m.resize(4, 5, 42));
assert_eq!(del_rows, m.resize(2, 5, 42));
assert_eq!(del_del, m.resize(1, 2, 42)); assert_eq!(del_del, m.resize(1, 2, 42));
assert_eq!(add_add, m.resize(5, 6, 42)); assert_eq!(add_add, m.resize(5, 6, 42));
assert_eq!(add_del, m.resize(5, 2, 42)); assert_eq!(add_del, m.resize(5, 2, 42));
assert_eq!(del_add, m.resize(1, 8, 42)); assert_eq!(del_add, m.resize(1, 8, 42));
} }
#[test]
fn resize_empty_matrix() {
let m1 = DMatrix::repeat(0, 0, 0);
let m2 = DMatrix::repeat(1, 0, 0); // Less rows than target size.
let m3 = DMatrix::repeat(3, 0, 0); // Same rows as target size.
let m4 = DMatrix::repeat(9, 0, 0); // More rows than target size.
let m5 = DMatrix::repeat(0, 1, 0); // Less columns than target size.
let m6 = DMatrix::repeat(0, 5, 0); // Same columns as target size.
let m7 = DMatrix::repeat(0, 9, 0); // More columns than target size.
let resized = DMatrix::repeat(3, 5, 42);
let resized_wo_rows = DMatrix::repeat(0, 5, 42);
let resized_wo_cols = DMatrix::repeat(3, 0, 42);
assert_eq!(resized, m1.clone().resize(3, 5, 42));
assert_eq!(resized, m2.clone().resize(3, 5, 42));
assert_eq!(resized, m3.clone().resize(3, 5, 42));
assert_eq!(resized, m4.clone().resize(3, 5, 42));
assert_eq!(resized, m5.clone().resize(3, 5, 42));
assert_eq!(resized, m6.clone().resize(3, 5, 42));
assert_eq!(resized, m7.clone().resize(3, 5, 42));
assert_eq!(resized_wo_rows, m1.clone().resize(0, 5, 42));
assert_eq!(resized_wo_rows, m2.clone().resize(0, 5, 42));
assert_eq!(resized_wo_rows, m3.clone().resize(0, 5, 42));
assert_eq!(resized_wo_rows, m4.clone().resize(0, 5, 42));
assert_eq!(resized_wo_rows, m5.clone().resize(0, 5, 42));
assert_eq!(resized_wo_rows, m6.clone().resize(0, 5, 42));
assert_eq!(resized_wo_rows, m7.clone().resize(0, 5, 42));
assert_eq!(resized_wo_cols, m1.clone().resize(3, 0, 42));
assert_eq!(resized_wo_cols, m2.clone().resize(3, 0, 42));
assert_eq!(resized_wo_cols, m3.clone().resize(3, 0, 42));
assert_eq!(resized_wo_cols, m4.clone().resize(3, 0, 42));
assert_eq!(resized_wo_cols, m5.clone().resize(3, 0, 42));
assert_eq!(resized_wo_cols, m6.clone().resize(3, 0, 42));
assert_eq!(resized_wo_cols, m7.clone().resize(3, 0, 42));
assert_eq!(m1, m1.clone().resize(0, 0, 42));
assert_eq!(m1, m2.resize(0, 0, 42));
assert_eq!(m1, m3.resize(0, 0, 42));
assert_eq!(m1, m4.resize(0, 0, 42));
assert_eq!(m1, m5.resize(0, 0, 42));
assert_eq!(m1, m6.resize(0, 0, 42));
assert_eq!(m1, m7.resize(0, 0, 42));
}