Allow the removal of multiple rows/columns given an array of indices. #530

This commit is contained in:
Stefan Mesken 2019-04-08 17:29:32 +02:00 committed by Sébastien Crozet
parent dda41c1508
commit 3baefb1319
2 changed files with 219 additions and 32 deletions

View File

@ -1,18 +1,18 @@
use num::{One, Zero};
use std::cmp;
use std::ptr;
#[cfg(any(feature = "std", feature = "alloc"))]
use std::iter::ExactSizeIterator;
#[cfg(any(feature = "std", feature = "alloc"))]
use std::mem;
use std::ptr;
use crate::base::allocator::{Allocator, Reallocator};
use crate::base::constraint::{DimEq, SameNumberOfColumns, SameNumberOfRows, ShapeConstraint};
#[cfg(any(feature = "std", feature = "alloc"))]
use crate::base::dimension::Dynamic;
use crate::base::dimension::{
Dim, DimAdd, DimDiff, DimMin, DimMinimum, DimName, DimSub, DimSum, U1,
};
#[cfg(any(feature = "std", feature = "alloc"))]
use crate::base::dimension::Dynamic;
use crate::base::storage::{Storage, StorageMut};
#[cfg(any(feature = "std", feature = "alloc"))]
use crate::base::DMatrix;
@ -42,12 +42,15 @@ impl<N: Scalar + Zero, R: Dim, C: Dim, S: Storage<N, R, C>> Matrix<N, R, C, S> {
/// Creates a new matrix by extracting the given set of rows from `self`.
#[cfg(any(feature = "std", feature = "alloc"))]
pub fn select_rows<'a, I>(&self, irows: I) -> MatrixMN<N, Dynamic, C>
where I: IntoIterator<Item = &'a usize>,
where
I: IntoIterator<Item = &'a usize>,
I::IntoIter: ExactSizeIterator + Clone,
DefaultAllocator: Allocator<N, Dynamic, C> {
DefaultAllocator: Allocator<N, Dynamic, C>,
{
let irows = irows.into_iter();
let ncols = self.data.shape().1;
let mut res = unsafe { MatrixMN::new_uninitialized_generic(Dynamic::new(irows.len()), ncols) };
let mut res =
unsafe { MatrixMN::new_uninitialized_generic(Dynamic::new(irows.len()), ncols) };
// First, check that all the indices from irows are valid.
// This will allow us to use unchecked access in the inner loop.
@ -61,9 +64,7 @@ impl<N: Scalar + Zero, R: Dim, C: Dim, S: Storage<N, R, C>> Matrix<N, R, C, S> {
let src = self.column(j);
for (destination, source) in irows.clone().enumerate() {
unsafe {
*res.vget_unchecked_mut(destination) = *src.vget_unchecked(*source)
}
unsafe { *res.vget_unchecked_mut(destination) = *src.vget_unchecked(*source) }
}
}
@ -73,12 +74,15 @@ impl<N: Scalar + Zero, R: Dim, C: Dim, S: Storage<N, R, C>> Matrix<N, R, C, S> {
/// Creates a new matrix by extracting the given set of columns from `self`.
#[cfg(any(feature = "std", feature = "alloc"))]
pub fn select_columns<'a, I>(&self, icols: I) -> MatrixMN<N, R, Dynamic>
where I: IntoIterator<Item = &'a usize>,
where
I: IntoIterator<Item = &'a usize>,
I::IntoIter: ExactSizeIterator,
DefaultAllocator: Allocator<N, R, Dynamic> {
DefaultAllocator: Allocator<N, R, Dynamic>,
{
let icols = icols.into_iter();
let nrows = self.data.shape().0;
let mut res = unsafe { MatrixMN::new_uninitialized_generic(nrows, Dynamic::new(icols.len())) };
let mut res =
unsafe { MatrixMN::new_uninitialized_generic(nrows, Dynamic::new(icols.len())) };
for (destination, source) in icols.enumerate() {
res.column_mut(destination).copy_from(&self.column(*source))
@ -303,6 +307,79 @@ impl<N: Scalar, R: Dim, C: Dim, S: Storage<N, R, C>> Matrix<N, R, C, S> {
self.remove_fixed_columns::<U1>(i)
}
/// Removes all columns in `indices`
#[inline]
pub fn remove_columns_at(self, indices: &[usize]) -> MatrixMN<N, R, Dynamic>
where
C: DimSub<Dynamic, Output = Dynamic>,
DefaultAllocator: Reallocator<N, R, C, R, Dynamic>,
{
let mut m = self.into_owned();
let mut v: Vec<usize> = indices.to_vec();
let (nrows, ncols) = m.data.shape();
let mut offset: usize = 0;
let mut target: usize = 0;
while offset + target < ncols.value() {
if v.contains(&(target + offset)) {
offset += 1;
} else {
unsafe {
let ptr_source = m
.data
.ptr()
.offset(((target + offset) * nrows.value()) as isize);
let ptr_target = m.data.ptr_mut().offset((target * nrows.value()) as isize);
ptr::copy(ptr_source, ptr_target, nrows.value());
target += 1;
}
}
}
unsafe {
Matrix::from_data(DefaultAllocator::reallocate_copy(
nrows,
ncols.sub(Dynamic::from_usize(v.len())),
m.data,
))
}
}
/// Removes all columns in `indices`
#[inline]
pub fn remove_rows_at(self, indices: &[usize]) -> MatrixMN<N, Dynamic, C>
where
R: DimSub<Dynamic, Output = Dynamic>,
DefaultAllocator: Reallocator<N, R, C, Dynamic, C>,
{
let mut m = self.into_owned();
let mut v: Vec<usize> = indices.to_vec();
let (nrows, ncols) = m.data.shape();
let mut offset: usize = 0;
let mut target: usize = 0;
while offset + target < nrows.value() * ncols.value() {
if v.contains(&((target + offset) % nrows.value())) {
offset += 1;
} else {
unsafe {
let ptr_source = m.data.ptr().offset((target + offset) as isize);
let ptr_target = m.data.ptr_mut().offset(target as isize);
ptr::copy(ptr_source, ptr_target, 1);
target += 1;
}
}
}
unsafe {
Matrix::from_data(DefaultAllocator::reallocate_copy(
nrows.sub(Dynamic::from_usize(v.len())),
ncols,
m.data,
))
}
}
/// Removes `D::dim()` consecutive columns from this matrix, starting with the `i`-th
/// (included).
#[inline]
@ -644,7 +721,6 @@ impl<N: Scalar, R: Dim, C: Dim, S: Storage<N, R, C>> Matrix<N, R, C, S> {
self.resize_generic(nrows, Dynamic::new(new_ncols), val)
}
/// Resizes this matrix so that it contains `R2::value()` rows and `C2::value()` columns.
///
/// The values are copied such that `self[(i, j)] == result[(i, j)]`. If the result has more
@ -741,7 +817,8 @@ impl<N: Scalar> DMatrix<N> {
#[cfg(any(feature = "std", feature = "alloc"))]
impl<N: Scalar, C: Dim> MatrixMN<N, Dynamic, C>
where DefaultAllocator: Allocator<N, Dynamic, C> {
where DefaultAllocator: Allocator<N, Dynamic, C>
{
/// Changes the number of rows of this matrix in-place.
///
/// The values are copied such that `self[(i, j)] == result[(i, j)]`. If the result has more
@ -751,7 +828,8 @@ impl<N: Scalar, C: Dim> MatrixMN<N, Dynamic, C>
#[cfg(any(feature = "std", feature = "alloc"))]
pub fn resize_vertically_mut(&mut self, new_nrows: usize, val: N)
where DefaultAllocator: Reallocator<N, Dynamic, C, Dynamic, C> {
let placeholder = unsafe { Self::new_uninitialized_generic(Dynamic::new(0), self.data.shape().1) };
let placeholder =
unsafe { Self::new_uninitialized_generic(Dynamic::new(0), self.data.shape().1) };
let old = mem::replace(self, placeholder);
let new = old.resize_vertically(new_nrows, val);
let _ = mem::replace(self, new);
@ -760,7 +838,8 @@ impl<N: Scalar, C: Dim> MatrixMN<N, Dynamic, C>
#[cfg(any(feature = "std", feature = "alloc"))]
impl<N: Scalar, R: Dim> MatrixMN<N, R, Dynamic>
where DefaultAllocator: Allocator<N, R, Dynamic> {
where DefaultAllocator: Allocator<N, R, Dynamic>
{
/// Changes the number of column of this matrix in-place.
///
/// The values are copied such that `self[(i, j)] == result[(i, j)]`. If the result has more
@ -770,7 +849,8 @@ impl<N: Scalar, R: Dim> MatrixMN<N, R, Dynamic>
#[cfg(any(feature = "std", feature = "alloc"))]
pub fn resize_horizontally_mut(&mut self, new_ncols: usize, val: N)
where DefaultAllocator: Reallocator<N, R, Dynamic, R, Dynamic> {
let placeholder = unsafe { Self::new_uninitialized_generic(self.data.shape().0, Dynamic::new(0)) };
let placeholder =
unsafe { Self::new_uninitialized_generic(self.data.shape().0, Dynamic::new(0)) };
let old = mem::replace(self, placeholder);
let new = old.resize_horizontally(new_ncols, val);
let _ = mem::replace(self, new);
@ -898,7 +978,7 @@ where
/// // The following panics because the vec length is not a multiple of 3.
/// matrix.extend(vec![6, 7, 8, 9]);
/// ```
fn extend<I: IntoIterator<Item=N>>(&mut self, iter: I) {
fn extend<I: IntoIterator<Item = N>>(&mut self, iter: I) {
self.data.extend(iter);
}
}
@ -921,7 +1001,7 @@ where
/// vector.extend(vec![3, 4, 5]);
/// assert!(vector.eq(&DVector::from_vec(vec![0, 1, 2, 3, 4, 5])));
/// ```
fn extend<I: IntoIterator<Item=N>>(&mut self, iter: I) {
fn extend<I: IntoIterator<Item = N>>(&mut self, iter: I) {
self.data.extend(iter);
}
}
@ -985,8 +1065,7 @@ where
/// matrix.extend(
/// vec![Vector4::new(6, 7, 8, 9)]); // too few dimensions!
/// ```
fn extend<I: IntoIterator<Item=Vector<N, RV, SV>>>(&mut self, iter: I)
{
fn extend<I: IntoIterator<Item = Vector<N, RV, SV>>>(&mut self, iter: I) {
self.data.extend(iter);
}
}

View File

@ -260,6 +260,63 @@ fn remove_columns() {
assert!(computed.eq(&expected2));
}
#[test]
fn remove_columns_at() {
let m = DMatrix::from_row_slice(5, 5, &[
11, 12, 13, 14, 15,
21, 22, 23, 24, 25,
31, 32, 33, 34, 35,
41, 42, 43, 44, 45,
51, 52, 53, 54, 55
]);
let expected1 = DMatrix::from_row_slice(5, 4, &[
12, 13, 14, 15,
22, 23, 24, 25,
32, 33, 34, 35,
42, 43, 44, 45,
52, 53, 54, 55
]);
assert_eq!(m.remove_columns_at(&[0]), expected1);
let m = DMatrix::from_row_slice(5, 5, &[
11, 12, 13, 14, 15,
21, 22, 23, 24, 25,
31, 32, 33, 34, 35,
41, 42, 43, 44, 45,
51, 52, 53, 54, 55
]);
let expected2 = DMatrix::from_row_slice(5, 3, &[
11, 13, 15,
21, 23, 25,
31, 33, 35,
41, 43, 45,
51, 53, 55
]);
assert_eq!(m.remove_columns_at(&[1,3]), expected2);
let m = DMatrix::from_row_slice(5, 5, &[
11, 12, 13, 14, 15,
21, 22, 23, 24, 25,
31, 32, 33, 34, 35,
41, 42, 43, 44, 45,
51, 52, 53, 54, 55
]);
let expected3 = DMatrix::from_row_slice(5, 2, &[
12, 13,
22, 23,
32, 33,
42, 43,
52, 53,
]);
assert_eq!(m.remove_columns_at(&[0,3,4]), expected3);
}
#[test]
fn remove_rows() {
@ -316,6 +373,57 @@ fn remove_rows() {
assert!(computed.eq(&expected2));
}
#[test]
fn remove_rows_at() {
let m = DMatrix::from_row_slice(5, 5, &[
11, 12, 13, 14, 15,
21, 22, 23, 24, 25,
31, 32, 33, 34, 35,
41, 42, 43, 44, 45,
51, 52, 53, 54, 55
]);
let expected1 = DMatrix::from_row_slice(4, 5, &[
21, 22, 23, 24, 25,
31, 32, 33, 34, 35,
41, 42, 43, 44, 45,
51, 52, 53, 54, 55
]);
assert_eq!(m.remove_rows_at(&[0]), expected1);
let m = DMatrix::from_row_slice(5, 5, &[
11, 12, 13, 14, 15,
21, 22, 23, 24, 25,
31, 32, 33, 34, 35,
41, 42, 43, 44, 45,
51, 52, 53, 54, 55
]);
let expected2 = DMatrix::from_row_slice(3, 5, &[
11, 12, 13, 14, 15,
31, 32, 33, 34, 35,
51, 52, 53, 54, 55
]);
assert_eq!(m.remove_rows_at(&[1,3]), expected2);
let m = DMatrix::from_row_slice(5, 5, &[
11, 12, 13, 14, 15,
21, 22, 23, 24, 25,
31, 32, 33, 34, 35,
41, 42, 43, 44, 45,
51, 52, 53, 54, 55
]);
let expected3 = DMatrix::from_row_slice(2, 5, &[
21, 22, 23, 24, 25,
31, 32, 33, 34, 35
]);
assert_eq!(m.remove_rows_at(&[0,3,4]), expected3);
}
#[test]
fn insert_columns() {