From c4bce82290718efeb9509b97e44c993887483409 Mon Sep 17 00:00:00 2001 From: Oleg Beloglazov Date: Sun, 10 Jan 2016 16:03:04 +0300 Subject: [PATCH] Implement Row & Col Traits for DMat --- src/structs/dmat.rs | 80 +++++++++++++++++++++++++++++++++++------- tests/mat.rs | 84 +++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 151 insertions(+), 13 deletions(-) diff --git a/src/structs/dmat.rs b/src/structs/dmat.rs index 7cf5ab7f..74a0308e 100644 --- a/src/structs/dmat.rs +++ b/src/structs/dmat.rs @@ -10,7 +10,7 @@ use rand::{self, Rand}; use num::{Zero, One}; use structs::dvec::DVec; use traits::operations::{ApproxEq, Inv, Transpose, Mean, Cov}; -use traits::structure::{Cast, ColSlice, RowSlice, Diag, DiagMut, Eye, Indexable, Shape, BaseNum}; +use traits::structure::{Cast, Col, Row, ColSlice, RowSlice, Diag, DiagMut, Eye, Indexable, Shape, BaseNum}; #[cfg(feature="arbitrary")] use quickcheck::{Arbitrary, Gen}; @@ -135,18 +135,6 @@ impl DMat { } } - /// The number of row on the matrix. - #[inline] - pub fn nrows(&self) -> usize { - self.nrows - } - - /// The number of columns on the matrix. - #[inline] - pub fn ncols(&self) -> usize { - self.ncols - } - /// Transforms this matrix isizeo an array. This consumes the matrix and is O(1). /// The returned vector contains the matrix data in column-major order. #[inline] @@ -229,6 +217,72 @@ impl Indexable<(usize, usize), N> for DMat { } +impl Col> for DMat { + /// The number of columns on the matrix. + #[inline] + fn ncols(&self) -> usize { + self.ncols + } + + /// The `i`-th column of matrix. + #[inline] + fn col(&self, col_id: usize) -> DVec { + assert!(col_id < self.ncols); + + let start = self.offset(0, col_id); + let stop = self.offset(self.nrows, col_id); + DVec::from_slice(self.nrows, &self.mij[start .. stop]) + } + + /// Writes the `i`-th column of matrix. + #[inline] + fn set_col(&mut self, col_id: usize, col: DVec) { + assert!(col_id < self.ncols); + + for row_id in 0..self.nrows { + unsafe { + self.unsafe_set((row_id, col_id), col.unsafe_at(row_id)); + } + } + } +} + +impl Row> for DMat { + /// The number of row on the matrix. + #[inline] + fn nrows(&self) -> usize { + self.nrows + } + + /// The `i`-th row of matrix. + #[inline] + fn row(&self, row_id: usize) -> DVec { + assert!(row_id < self.nrows); + + let mut slice : DVec = unsafe { + DVec::new_uninitialized(self.ncols) + }; + for col_id in 0..self.ncols { + unsafe { + slice.unsafe_set(col_id, self.unsafe_at((row_id, col_id))); + } + } + slice + } + + /// Writes the `i`-th row of matrix. + #[inline] + fn set_row(&mut self, row_id: usize, row: DVec) { + assert!(row_id < self.nrows); + + for col_id in 0..self.ncols { + unsafe { + self.unsafe_set((row_id, col_id), row.unsafe_at(col_id)); + } + } + } +} + impl Shape<(usize, usize)> for DMat { #[inline] fn shape(&self) -> (usize, usize) { diff --git a/tests/mat.rs b/tests/mat.rs index afae18d8..f0575152 100644 --- a/tests/mat.rs +++ b/tests/mat.rs @@ -521,6 +521,90 @@ fn test_dmat_subtraction() { assert!((mat1 - mat2) == res); } +#[test] +fn test_dmat_col() { + let mat = DMat::from_row_vec( + 3, + 3, + &[ + 1.0, 2.0, 3.0, + 4.0, 5.0, 6.0, + 7.0, 8.0, 9.0, + ] + ); + + assert!(mat.col(1) == DVec::from_slice(3, &[2.0, 5.0, 8.0])); +} + +#[test] +fn test_dmat_set_col() { + let mut mat = DMat::from_row_vec( + 3, + 3, + &[ + 1.0, 2.0, 3.0, + 4.0, 5.0, 6.0, + 7.0, 8.0, 9.0, + ] + ); + + mat.set_col(1, DVec::from_slice(3, &[12.0, 15.0, 18.0])); + + let expected = DMat::from_row_vec( + 3, + 3, + &[ + 1.0, 12.0, 3.0, + 4.0, 15.0, 6.0, + 7.0, 18.0, 9.0, + ] + ); + + assert!(mat == expected); +} + +#[test] +fn test_dmat_row() { + let mat = DMat::from_row_vec( + 3, + 3, + &[ + 1.0, 2.0, 3.0, + 4.0, 5.0, 6.0, + 7.0, 8.0, 9.0, + ] + ); + + assert!(mat.row(1) == DVec::from_slice(3, &[4.0, 5.0, 6.0])); +} + +#[test] +fn test_dmat_set_row() { + let mut mat = DMat::from_row_vec( + 3, + 3, + &[ + 1.0, 2.0, 3.0, + 4.0, 5.0, 6.0, + 7.0, 8.0, 9.0, + ] + ); + + mat.set_row(1, DVec::from_slice(3, &[14.0, 15.0, 16.0])); + + let expected = DMat::from_row_vec( + 3, + 3, + &[ + 1.0, 2.0, 3.0, + 14.0, 15.0, 16.0, + 7.0, 8.0, 9.0, + ] + ); + + assert!(mat == expected); +} + /* FIXME: review qr decomposition to make it work with DMat. #[test] fn test_qr() {