Use an iterator to iterate through a column entries.
This commit is contained in:
parent
9fa3e7a769
commit
dc8edeceb2
@ -1,21 +1,30 @@
|
||||
use alga::general::{ClosedAdd, ClosedMul};
|
||||
use num::{One, Zero};
|
||||
use std::iter;
|
||||
use std::marker::PhantomData;
|
||||
use std::ops::{Add, Mul, Range};
|
||||
use std::slice;
|
||||
|
||||
use allocator::Allocator;
|
||||
use constraint::{AreMultipliable, DimEq, ShapeConstraint};
|
||||
use storage::{Storage, StorageMut};
|
||||
use {DefaultAllocator, Dim, Matrix, MatrixMN, Scalar, Vector, VectorN, U1};
|
||||
|
||||
pub trait CsStorage<N, R, C = U1> {
|
||||
// FIXME: this structure exists for now only because impl trait
|
||||
// cannot be used for trait method return types.
|
||||
pub trait CsStorageIter<'a, N, R, C = U1> {
|
||||
type ColumnEntries: Iterator<Item = (usize, N)>;
|
||||
|
||||
fn column_entries(&'a self, j: usize) -> Self::ColumnEntries;
|
||||
}
|
||||
|
||||
pub trait CsStorage<N, R, C = U1>: for<'a> CsStorageIter<'a, N, R, C> {
|
||||
fn shape(&self) -> (R, C);
|
||||
fn nvalues(&self) -> usize;
|
||||
unsafe fn row_index_unchecked(&self, i: usize) -> usize;
|
||||
unsafe fn get_value_unchecked(&self, i: usize) -> &N;
|
||||
fn get_value(&self, i: usize) -> &N;
|
||||
fn row_index(&self, i: usize) -> usize;
|
||||
fn column_range(&self, j: usize) -> Range<usize>;
|
||||
}
|
||||
|
||||
pub trait CsStorageMut<N, R, C = U1>: CsStorage<N, R, C> {
|
||||
@ -24,10 +33,8 @@ pub trait CsStorageMut<N, R, C = U1>: CsStorage<N, R, C> {
|
||||
///
|
||||
/// If the given length is larger than the current one, uninitialized entries are
|
||||
/// added at the end of the column `i`. This will effectively shift all the matrix entries
|
||||
/// of the columns at indices `j` with `j > i`. Therefore this is a `O(n)` operation.
|
||||
/// This is unsafe as the row indices on newly created components may end up being out
|
||||
/// of bounds.
|
||||
unsafe fn set_column_len(&mut self, i: usize, len: usize);
|
||||
/// of the columns at indices `j` with `j > i`.
|
||||
fn set_column_len(&mut self, i: usize, len: usize);
|
||||
*/
|
||||
}
|
||||
|
||||
@ -42,6 +49,39 @@ where
|
||||
vals: Vec<N>,
|
||||
}
|
||||
|
||||
impl<N: Scalar, R: Dim, C: Dim> CsVecStorage<N, R, C>
|
||||
where
|
||||
DefaultAllocator: Allocator<usize, C>,
|
||||
{
|
||||
#[inline]
|
||||
fn column_range(&self, j: usize) -> Range<usize> {
|
||||
let end = if j + 1 == self.p.len() {
|
||||
self.nvalues()
|
||||
} else {
|
||||
self.p[j + 1]
|
||||
};
|
||||
|
||||
self.p[j]..end
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a, N: Scalar, R: Dim, C: Dim> CsStorageIter<'a, N, R, C> for CsVecStorage<N, R, C>
|
||||
where
|
||||
DefaultAllocator: Allocator<usize, C>,
|
||||
{
|
||||
type ColumnEntries =
|
||||
iter::Zip<iter::Cloned<slice::Iter<'a, usize>>, iter::Cloned<slice::Iter<'a, N>>>;
|
||||
|
||||
#[inline]
|
||||
fn column_entries(&'a self, j: usize) -> Self::ColumnEntries {
|
||||
let rng = self.column_range(j);
|
||||
self.i[rng.clone()]
|
||||
.iter()
|
||||
.cloned()
|
||||
.zip(self.vals[rng].iter().cloned())
|
||||
}
|
||||
}
|
||||
|
||||
impl<N: Scalar, R: Dim, C: Dim> CsStorage<N, R, C> for CsVecStorage<N, R, C>
|
||||
where
|
||||
DefaultAllocator: Allocator<usize, C>,
|
||||
@ -56,17 +96,6 @@ where
|
||||
self.vals.len()
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn column_range(&self, j: usize) -> Range<usize> {
|
||||
let end = if j + 1 == self.p.len() {
|
||||
self.nvalues()
|
||||
} else {
|
||||
self.p[j + 1]
|
||||
};
|
||||
|
||||
self.p[j]..end
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn row_index(&self, i: usize) -> usize {
|
||||
self.i[i]
|
||||
@ -175,13 +204,10 @@ impl<N: Scalar, R: Dim, C: Dim, S: CsStorage<N, R, C>> CsMatrix<N, R, C, S> {
|
||||
|
||||
// Fill the result.
|
||||
for j in 0..ncols.value() {
|
||||
let column_idx = self.data.column_range(j);
|
||||
|
||||
for vi in column_idx {
|
||||
let row_id = self.data.row_index(vi);
|
||||
for (row_id, value) in self.data.column_entries(j) {
|
||||
let shift = workspace[row_id];
|
||||
|
||||
res.data.vals[shift] = *self.data.get_value(vi);
|
||||
res.data.vals[shift] = value;
|
||||
res.data.i[shift] = j;
|
||||
workspace[row_id] += 1;
|
||||
}
|
||||
@ -204,19 +230,14 @@ impl<N: Scalar, R: Dim, C: Dim, S: CsStorage<N, R, C>> CsMatrix<N, R, C, S> {
|
||||
N: ClosedAdd + ClosedMul,
|
||||
DefaultAllocator: Allocator<usize, C2>,
|
||||
{
|
||||
let column_idx = self.data.column_range(j);
|
||||
|
||||
for vi in column_idx {
|
||||
let i = self.data.row_index(vi);
|
||||
let val = beta * *self.data.get_value(vi);
|
||||
|
||||
for (i, val) in self.data.column_entries(j) {
|
||||
if timestamps[i] < timestamp {
|
||||
timestamps[i] = timestamp;
|
||||
res.data.i[nz] = i;
|
||||
nz += 1;
|
||||
workspace[i] = val;
|
||||
workspace[i] = val * beta;
|
||||
} else {
|
||||
workspace[i] += val;
|
||||
workspace[i] += val * beta;
|
||||
}
|
||||
}
|
||||
|
||||
@ -340,16 +361,14 @@ where
|
||||
|
||||
for j in 0..ncols2.value() {
|
||||
res.data.p[j] = nz;
|
||||
let column_idx = rhs.data.column_range(j);
|
||||
let new_size_bound = nz + nrows1.value();
|
||||
res.data.i.resize(new_size_bound, 0);
|
||||
res.data.vals.resize(new_size_bound, N::zero());
|
||||
|
||||
for vi in column_idx {
|
||||
let i = rhs.data.row_index(vi);
|
||||
for (i, val) in rhs.data.column_entries(j) {
|
||||
nz = self.scatter(
|
||||
i,
|
||||
*rhs.data.get_value(vi),
|
||||
val,
|
||||
timestamps.as_mut_slice(),
|
||||
j + 1,
|
||||
workspace.as_mut_slice(),
|
||||
@ -447,11 +466,8 @@ where
|
||||
let mut res = MatrixMN::zeros_generic(nrows, ncols);
|
||||
|
||||
for j in 0..ncols.value() {
|
||||
let column_idx = m.data.column_range(j);
|
||||
|
||||
for iv in column_idx {
|
||||
let i = m.data.row_index(iv);
|
||||
res[(i, j)] = *m.data.get_value(iv);
|
||||
for (i, val) in m.data.column_entries(j) {
|
||||
res[(i, j)] = val;
|
||||
}
|
||||
}
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user