From dc8edeceb208469bafb94c73ada00b4dae0ddb3b Mon Sep 17 00:00:00 2001 From: sebcrozet Date: Sun, 21 Oct 2018 07:42:32 +0200 Subject: [PATCH] Use an iterator to iterate through a column entries. --- src/sparse/cs_matrix.rs | 94 ++++++++++++++++++++++++----------------- 1 file changed, 55 insertions(+), 39 deletions(-) diff --git a/src/sparse/cs_matrix.rs b/src/sparse/cs_matrix.rs index e2ddf9de..2666425a 100644 --- a/src/sparse/cs_matrix.rs +++ b/src/sparse/cs_matrix.rs @@ -1,21 +1,30 @@ use alga::general::{ClosedAdd, ClosedMul}; use num::{One, Zero}; +use std::iter; use std::marker::PhantomData; use std::ops::{Add, Mul, Range}; +use std::slice; use allocator::Allocator; use constraint::{AreMultipliable, DimEq, ShapeConstraint}; use storage::{Storage, StorageMut}; use {DefaultAllocator, Dim, Matrix, MatrixMN, Scalar, Vector, VectorN, U1}; -pub trait CsStorage { +// FIXME: this structure exists for now only because impl trait +// cannot be used for trait method return types. +pub trait CsStorageIter<'a, N, R, C = U1> { + type ColumnEntries: Iterator; + + fn column_entries(&'a self, j: usize) -> Self::ColumnEntries; +} + +pub trait CsStorage: for<'a> CsStorageIter<'a, N, R, C> { fn shape(&self) -> (R, C); fn nvalues(&self) -> usize; unsafe fn row_index_unchecked(&self, i: usize) -> usize; unsafe fn get_value_unchecked(&self, i: usize) -> &N; fn get_value(&self, i: usize) -> &N; fn row_index(&self, i: usize) -> usize; - fn column_range(&self, j: usize) -> Range; } pub trait CsStorageMut: CsStorage { @@ -24,10 +33,8 @@ pub trait CsStorageMut: CsStorage { /// /// If the given length is larger than the current one, uninitialized entries are /// added at the end of the column `i`. This will effectively shift all the matrix entries - /// of the columns at indices `j` with `j > i`. Therefore this is a `O(n)` operation. - /// This is unsafe as the row indices on newly created components may end up being out - /// of bounds. - unsafe fn set_column_len(&mut self, i: usize, len: usize); + /// of the columns at indices `j` with `j > i`. + fn set_column_len(&mut self, i: usize, len: usize); */ } @@ -42,6 +49,39 @@ where vals: Vec, } +impl CsVecStorage +where + DefaultAllocator: Allocator, +{ + #[inline] + fn column_range(&self, j: usize) -> Range { + let end = if j + 1 == self.p.len() { + self.nvalues() + } else { + self.p[j + 1] + }; + + self.p[j]..end + } +} + +impl<'a, N: Scalar, R: Dim, C: Dim> CsStorageIter<'a, N, R, C> for CsVecStorage +where + DefaultAllocator: Allocator, +{ + type ColumnEntries = + iter::Zip>, iter::Cloned>>; + + #[inline] + fn column_entries(&'a self, j: usize) -> Self::ColumnEntries { + let rng = self.column_range(j); + self.i[rng.clone()] + .iter() + .cloned() + .zip(self.vals[rng].iter().cloned()) + } +} + impl CsStorage for CsVecStorage where DefaultAllocator: Allocator, @@ -56,17 +96,6 @@ where self.vals.len() } - #[inline] - fn column_range(&self, j: usize) -> Range { - let end = if j + 1 == self.p.len() { - self.nvalues() - } else { - self.p[j + 1] - }; - - self.p[j]..end - } - #[inline] fn row_index(&self, i: usize) -> usize { self.i[i] @@ -175,13 +204,10 @@ impl> CsMatrix { // Fill the result. for j in 0..ncols.value() { - let column_idx = self.data.column_range(j); - - for vi in column_idx { - let row_id = self.data.row_index(vi); + for (row_id, value) in self.data.column_entries(j) { let shift = workspace[row_id]; - res.data.vals[shift] = *self.data.get_value(vi); + res.data.vals[shift] = value; res.data.i[shift] = j; workspace[row_id] += 1; } @@ -204,19 +230,14 @@ impl> CsMatrix { N: ClosedAdd + ClosedMul, DefaultAllocator: Allocator, { - let column_idx = self.data.column_range(j); - - for vi in column_idx { - let i = self.data.row_index(vi); - let val = beta * *self.data.get_value(vi); - + for (i, val) in self.data.column_entries(j) { if timestamps[i] < timestamp { timestamps[i] = timestamp; res.data.i[nz] = i; nz += 1; - workspace[i] = val; + workspace[i] = val * beta; } else { - workspace[i] += val; + workspace[i] += val * beta; } } @@ -340,16 +361,14 @@ where for j in 0..ncols2.value() { res.data.p[j] = nz; - let column_idx = rhs.data.column_range(j); let new_size_bound = nz + nrows1.value(); res.data.i.resize(new_size_bound, 0); res.data.vals.resize(new_size_bound, N::zero()); - for vi in column_idx { - let i = rhs.data.row_index(vi); + for (i, val) in rhs.data.column_entries(j) { nz = self.scatter( i, - *rhs.data.get_value(vi), + val, timestamps.as_mut_slice(), j + 1, workspace.as_mut_slice(), @@ -447,11 +466,8 @@ where let mut res = MatrixMN::zeros_generic(nrows, ncols); for j in 0..ncols.value() { - let column_idx = m.data.column_range(j); - - for iv in column_idx { - let i = m.data.row_index(iv); - res[(i, j)] = *m.data.get_value(iv); + for (i, val) in m.data.column_entries(j) { + res[(i, j)] = val; } }