Use an iterator to iterate through a column entries.

This commit is contained in:
sebcrozet 2018-10-21 07:42:32 +02:00
parent 9fa3e7a769
commit dc8edeceb2

View File

@ -1,21 +1,30 @@
use alga::general::{ClosedAdd, ClosedMul}; use alga::general::{ClosedAdd, ClosedMul};
use num::{One, Zero}; use num::{One, Zero};
use std::iter;
use std::marker::PhantomData; use std::marker::PhantomData;
use std::ops::{Add, Mul, Range}; use std::ops::{Add, Mul, Range};
use std::slice;
use allocator::Allocator; use allocator::Allocator;
use constraint::{AreMultipliable, DimEq, ShapeConstraint}; use constraint::{AreMultipliable, DimEq, ShapeConstraint};
use storage::{Storage, StorageMut}; use storage::{Storage, StorageMut};
use {DefaultAllocator, Dim, Matrix, MatrixMN, Scalar, Vector, VectorN, U1}; use {DefaultAllocator, Dim, Matrix, MatrixMN, Scalar, Vector, VectorN, U1};
pub trait CsStorage<N, R, C = U1> { // FIXME: this structure exists for now only because impl trait
// cannot be used for trait method return types.
pub trait CsStorageIter<'a, N, R, C = U1> {
type ColumnEntries: Iterator<Item = (usize, N)>;
fn column_entries(&'a self, j: usize) -> Self::ColumnEntries;
}
pub trait CsStorage<N, R, C = U1>: for<'a> CsStorageIter<'a, N, R, C> {
fn shape(&self) -> (R, C); fn shape(&self) -> (R, C);
fn nvalues(&self) -> usize; fn nvalues(&self) -> usize;
unsafe fn row_index_unchecked(&self, i: usize) -> usize; unsafe fn row_index_unchecked(&self, i: usize) -> usize;
unsafe fn get_value_unchecked(&self, i: usize) -> &N; unsafe fn get_value_unchecked(&self, i: usize) -> &N;
fn get_value(&self, i: usize) -> &N; fn get_value(&self, i: usize) -> &N;
fn row_index(&self, i: usize) -> usize; fn row_index(&self, i: usize) -> usize;
fn column_range(&self, j: usize) -> Range<usize>;
} }
pub trait CsStorageMut<N, R, C = U1>: CsStorage<N, R, C> { pub trait CsStorageMut<N, R, C = U1>: CsStorage<N, R, C> {
@ -24,10 +33,8 @@ pub trait CsStorageMut<N, R, C = U1>: CsStorage<N, R, C> {
/// ///
/// If the given length is larger than the current one, uninitialized entries are /// If the given length is larger than the current one, uninitialized entries are
/// added at the end of the column `i`. This will effectively shift all the matrix entries /// added at the end of the column `i`. This will effectively shift all the matrix entries
/// of the columns at indices `j` with `j > i`. Therefore this is a `O(n)` operation. /// of the columns at indices `j` with `j > i`.
/// This is unsafe as the row indices on newly created components may end up being out fn set_column_len(&mut self, i: usize, len: usize);
/// of bounds.
unsafe fn set_column_len(&mut self, i: usize, len: usize);
*/ */
} }
@ -42,6 +49,39 @@ where
vals: Vec<N>, vals: Vec<N>,
} }
impl<N: Scalar, R: Dim, C: Dim> CsVecStorage<N, R, C>
where
DefaultAllocator: Allocator<usize, C>,
{
#[inline]
fn column_range(&self, j: usize) -> Range<usize> {
let end = if j + 1 == self.p.len() {
self.nvalues()
} else {
self.p[j + 1]
};
self.p[j]..end
}
}
impl<'a, N: Scalar, R: Dim, C: Dim> CsStorageIter<'a, N, R, C> for CsVecStorage<N, R, C>
where
DefaultAllocator: Allocator<usize, C>,
{
type ColumnEntries =
iter::Zip<iter::Cloned<slice::Iter<'a, usize>>, iter::Cloned<slice::Iter<'a, N>>>;
#[inline]
fn column_entries(&'a self, j: usize) -> Self::ColumnEntries {
let rng = self.column_range(j);
self.i[rng.clone()]
.iter()
.cloned()
.zip(self.vals[rng].iter().cloned())
}
}
impl<N: Scalar, R: Dim, C: Dim> CsStorage<N, R, C> for CsVecStorage<N, R, C> impl<N: Scalar, R: Dim, C: Dim> CsStorage<N, R, C> for CsVecStorage<N, R, C>
where where
DefaultAllocator: Allocator<usize, C>, DefaultAllocator: Allocator<usize, C>,
@ -56,17 +96,6 @@ where
self.vals.len() self.vals.len()
} }
#[inline]
fn column_range(&self, j: usize) -> Range<usize> {
let end = if j + 1 == self.p.len() {
self.nvalues()
} else {
self.p[j + 1]
};
self.p[j]..end
}
#[inline] #[inline]
fn row_index(&self, i: usize) -> usize { fn row_index(&self, i: usize) -> usize {
self.i[i] self.i[i]
@ -175,13 +204,10 @@ impl<N: Scalar, R: Dim, C: Dim, S: CsStorage<N, R, C>> CsMatrix<N, R, C, S> {
// Fill the result. // Fill the result.
for j in 0..ncols.value() { for j in 0..ncols.value() {
let column_idx = self.data.column_range(j); for (row_id, value) in self.data.column_entries(j) {
for vi in column_idx {
let row_id = self.data.row_index(vi);
let shift = workspace[row_id]; let shift = workspace[row_id];
res.data.vals[shift] = *self.data.get_value(vi); res.data.vals[shift] = value;
res.data.i[shift] = j; res.data.i[shift] = j;
workspace[row_id] += 1; workspace[row_id] += 1;
} }
@ -204,19 +230,14 @@ impl<N: Scalar, R: Dim, C: Dim, S: CsStorage<N, R, C>> CsMatrix<N, R, C, S> {
N: ClosedAdd + ClosedMul, N: ClosedAdd + ClosedMul,
DefaultAllocator: Allocator<usize, C2>, DefaultAllocator: Allocator<usize, C2>,
{ {
let column_idx = self.data.column_range(j); for (i, val) in self.data.column_entries(j) {
for vi in column_idx {
let i = self.data.row_index(vi);
let val = beta * *self.data.get_value(vi);
if timestamps[i] < timestamp { if timestamps[i] < timestamp {
timestamps[i] = timestamp; timestamps[i] = timestamp;
res.data.i[nz] = i; res.data.i[nz] = i;
nz += 1; nz += 1;
workspace[i] = val; workspace[i] = val * beta;
} else { } else {
workspace[i] += val; workspace[i] += val * beta;
} }
} }
@ -340,16 +361,14 @@ where
for j in 0..ncols2.value() { for j in 0..ncols2.value() {
res.data.p[j] = nz; res.data.p[j] = nz;
let column_idx = rhs.data.column_range(j);
let new_size_bound = nz + nrows1.value(); let new_size_bound = nz + nrows1.value();
res.data.i.resize(new_size_bound, 0); res.data.i.resize(new_size_bound, 0);
res.data.vals.resize(new_size_bound, N::zero()); res.data.vals.resize(new_size_bound, N::zero());
for vi in column_idx { for (i, val) in rhs.data.column_entries(j) {
let i = rhs.data.row_index(vi);
nz = self.scatter( nz = self.scatter(
i, i,
*rhs.data.get_value(vi), val,
timestamps.as_mut_slice(), timestamps.as_mut_slice(),
j + 1, j + 1,
workspace.as_mut_slice(), workspace.as_mut_slice(),
@ -447,11 +466,8 @@ where
let mut res = MatrixMN::zeros_generic(nrows, ncols); let mut res = MatrixMN::zeros_generic(nrows, ncols);
for j in 0..ncols.value() { for j in 0..ncols.value() {
let column_idx = m.data.column_range(j); for (i, val) in m.data.column_entries(j) {
res[(i, j)] = val;
for iv in column_idx {
let i = m.data.row_index(iv);
res[(i, j)] = *m.data.get_value(iv);
} }
} }