diff --git a/examples/convolution.rs b/src/linalg/convolution.rs similarity index 58% rename from examples/convolution.rs rename to src/linalg/convolution.rs index 07b1f974..dba6c97d 100644 --- a/examples/convolution.rs +++ b/src/linalg/convolution.rs @@ -1,20 +1,25 @@ -extern crate nalgebra as na; -use na::storage::Storage; -use na::{zero, DVector, Dim, Dynamic, Matrix, Real, VecStorage, Vector, U1}; +use storage::Storage; +use {zero, DVector, Dim, Dynamic, Matrix, Real, VecStorage, Vector, U1}; use std::cmp; -enum ConvolveMode { - Full, - Valid, - Same, -} - -fn convolve_full, Q: Storage>( +/// +/// The output is the full discrete linear convolution of the inputs +/// +pub fn convolve_full, Q: Storage>( vector: Vector, kernel: Vector, ) -> Matrix> { let vec = vector.len(); let ker = kernel.len(); + + if vec == 0 || ker == 0 { + panic!("Convolve's inputs must not be 0-sized. "); + } + + if ker > vec { + return convolve_full(kernel, vector); + } + let newlen = vec + ker - 1; let mut conv = DVector::::zeros(newlen); @@ -36,12 +41,24 @@ fn convolve_full, Q: Storage>( conv } -fn convolve_valid, Q: Storage>( +/// +/// The output consists only of those elements that do not rely on the zero-padding. +/// +pub fn convolve_valid, Q: Storage>( vector: Vector, kernel: Vector, ) -> Matrix> { let vec = vector.len(); let ker = kernel.len(); + + if vec == 0 || ker == 0 { + panic!("Convolve's inputs must not be 0-sized. "); + } + + if ker > vec { + return convolve_valid(kernel, vector); + } + let newlen = vec - ker + 1; let mut conv = DVector::::zeros(newlen); @@ -54,13 +71,24 @@ fn convolve_valid, Q: Storage>( conv } -fn convolve_same, Q: Storage>( +/// +/// The output is the same size as in1, centered with respect to the ‘full’ output. +/// +pub fn convolve_same, Q: Storage>( vector: Vector, kernel: Vector, ) -> Matrix> { let vec = vector.len(); let ker = kernel.len(); + if vec == 0 || ker == 0 { + panic!("Convolve's inputs must not be 0-sized. "); + } + + if ker > vec { + return convolve_same(kernel, vector); + } + let mut conv = DVector::::zeros(vec); for i in 0..vec { @@ -74,24 +102,4 @@ fn convolve_same, Q: Storage>( } } conv -} - -fn convolve, Q: Storage>( - vector: Vector, - kernel: Vector, - mode: Option, -) -> Matrix> { - if kernel.len() > vector.len() { - return convolve(kernel, vector, mode); - } - - match mode.unwrap_or(ConvolveMode::Full) { - ConvolveMode::Full => return convolve_full(vector, kernel), - ConvolveMode::Valid => return convolve_valid(vector, kernel), - ConvolveMode::Same => return convolve_same(vector, kernel), - } -} - -fn main() { - -} +} \ No newline at end of file diff --git a/src/linalg/mod.rs b/src/linalg/mod.rs index 4418b283..b6a9e8d8 100644 --- a/src/linalg/mod.rs +++ b/src/linalg/mod.rs @@ -17,6 +17,7 @@ mod solve; mod svd; mod symmetric_eigen; mod symmetric_tridiagonal; +mod convolution; //// FIXME: Not complete enough for publishing. //// This handles only cases where each eigenvalue has multiplicity one. @@ -33,3 +34,4 @@ pub use self::schur::*; pub use self::svd::*; pub use self::symmetric_eigen::*; pub use self::symmetric_tridiagonal::*; +pub use self::convolution::*; diff --git a/tests/linalg/convolution.rs b/tests/linalg/convolution.rs new file mode 100644 index 00000000..ef3a02db --- /dev/null +++ b/tests/linalg/convolution.rs @@ -0,0 +1,49 @@ +use na::linalg::{convolve_full,convolve_valid,convolve_same}; +use na::{Vector2,Vector4,DVector}; + +// +// Should mimic calculations in Python's scipy library +// >>>from scipy.signal import convolve +// + +// >>> convolve([1,2,3,4],[1,2],"same") +// array([ 1, 4, 7, 10]) +#[test] +fn convolve_same_check(){ + let vec = Vector4::new(1.0,2.0,3.0,4.0); + let ker = Vector2::new(1.0,2.0); + + let actual = DVector::from_vec(4, vec![1.0,4.0,7.0,10.0]); + + let expected = convolve_same(vec,ker); + + assert!(relative_eq!(actual, expected, epsilon = 1.0e-7)); +} + +// >>> convolve([1,2,3,4],[1,2],"valid") +// array([ 1, 4, 7, 10, 8]) +#[test] +fn convolve_full_check(){ + let vec = Vector4::new(1.0,2.0,3.0,4.0); + let ker = Vector2::new(1.0,2.0); + + let actual = DVector::from_vec(5, vec![1.0,4.0,7.0,10.0,8.0]); + + let expected = convolve_full(vec,ker); + + assert!(relative_eq!(actual, expected, epsilon = 1.0e-7)); +} + +// >>> convolve([1,2,3,4],[1,2],"valid") +// array([ 4, 7, 10]) +#[test] +fn convolve_valid_check(){ + let vec = Vector4::new(1.0,2.0,3.0,4.0); + let ker = Vector2::new(1.0,2.0); + + let actual = DVector::from_vec(3, vec![4.0,7.0,10.0]); + + let expected = convolve_valid(vec,ker); + + assert!(relative_eq!(actual, expected, epsilon = 1.0e-7)); +} \ No newline at end of file diff --git a/tests/linalg/mod.rs b/tests/linalg/mod.rs index 74a5e03c..4e0bf2eb 100644 --- a/tests/linalg/mod.rs +++ b/tests/linalg/mod.rs @@ -11,3 +11,4 @@ mod real_schur; mod solve; mod svd; mod tridiagonal; +mod convolution; \ No newline at end of file