diff --git a/src/linalg/convolution.rs b/src/linalg/convolution.rs index 49c8b154..b121b34a 100644 --- a/src/linalg/convolution.rs +++ b/src/linalg/convolution.rs @@ -5,138 +5,125 @@ use std::cmp; use storage::Storage; use {zero, Real, Vector, VectorN, U1}; -/// Returns the convolution of the target vector and a kernel -/// -/// # Arguments -/// -/// * `vector` - A Vector with size > 0 -/// * `kernel` - A Vector with size > 0 -/// -/// # Errors -/// Inputs must statisfy `vector.len() >= kernel.len() > 0`. -/// -pub fn convolve_full( - vector: Vector, - kernel: Vector, -) -> VectorN, U1>> -where - N: Real, - D1: DimAdd, - D2: DimAdd>, - DimSum: DimSub, - S1: Storage, - S2: Storage, - DefaultAllocator: Allocator, U1>>, -{ - let vec = vector.len(); - let ker = kernel.len(); +impl> Vector { + /// Returns the convolution of the target vector and a kernel + /// + /// # Arguments + /// + /// * `kernel` - A Vector with size > 0 + /// + /// # Errors + /// Inputs must statisfy `vector.len() >= kernel.len() > 0`. + /// + pub fn convolve_full( + &self, + kernel: Vector, + ) -> VectorN, U1>> + where + D1: DimAdd, + D2: DimAdd>, + DimSum: DimSub, + S2: Storage, + DefaultAllocator: Allocator, U1>>, + { + let vec = self.len(); + let ker = kernel.len(); - if ker == 0 || ker > vec { - panic!("convolve_full expects `vector.len() >= kernel.len() > 0`, received {} and {} respectively.",vec,ker); - } + if ker == 0 || ker > vec { + panic!("convolve_full expects `self.len() >= kernel.len() > 0`, received {} and {} respectively.",vec,ker); + } - let result_len = vector.data.shape().0.add(kernel.data.shape().0).sub(U1); - let mut conv = VectorN::zeros_generic(result_len, U1); + let result_len = self.data.shape().0.add(kernel.data.shape().0).sub(U1); + let mut conv = VectorN::zeros_generic(result_len, U1); - for i in 0..(vec + ker - 1) { - let u_i = if i > vec { i - ker } else { 0 }; - let u_f = cmp::min(i, vec - 1); + for i in 0..(vec + ker - 1) { + let u_i = if i > vec { i - ker } else { 0 }; + let u_f = cmp::min(i, vec - 1); - if u_i == u_f { - conv[i] += vector[u_i] * kernel[(i - u_i)]; - } else { - for u in u_i..(u_f + 1) { - if i - u < ker { - conv[i] += vector[u] * kernel[(i - u)]; + if u_i == u_f { + conv[i] += self[u_i] * kernel[(i - u_i)]; + } else { + for u in u_i..(u_f + 1) { + if i - u < ker { + conv[i] += self[u] * kernel[(i - u)]; + } } } } + conv } - conv -} + /// Returns the convolution of the target vector and a kernel + /// The output convolution consists only of those elements that do not rely on the zero-padding. + /// # Arguments + /// + /// * `kernel` - A Vector with size > 0 + /// + /// + /// # Errors + /// Inputs must statisfy `self.len() >= kernel.len() > 0`. + /// + pub fn convolve_valid(&self, kernel: Vector, + ) -> VectorN, D2>> + where + D1: DimAdd, + D2: Dim, + DimSum: DimSub, + S2: Storage, + DefaultAllocator: Allocator, D2>>, + { + let vec = self.len(); + let ker = kernel.len(); -/// Returns the convolution of the vector and a kernel -/// The output convolution consists only of those elements that do not rely on the zero-padding. -/// # Arguments -/// -/// * `vector` - A Vector with size > 0 -/// * `kernel` - A Vector with size > 0 -/// -/// -/// # Errors -/// Inputs must statisfy `vector.len() >= kernel.len() > 0`. -/// -pub fn convolve_valid( - vector: Vector, - kernel: Vector, -) -> VectorN, D2>> -where - N: Real, - D1: DimAdd, - D2: Dim, - DimSum: DimSub, - S1: Storage, - S2: Storage, - DefaultAllocator: Allocator, D2>>, -{ - let vec = vector.len(); - let ker = kernel.len(); - - if ker == 0 || ker > vec { - panic!("convolve_valid expects `vector.len() >= kernel.len() > 0`, received {} and {} respectively.",vec,ker); - } - - let result_len = vector.data.shape().0.add(U1).sub(kernel.data.shape().0); - let mut conv = VectorN::zeros_generic(result_len, U1); - - for i in 0..(vec - ker + 1) { - for j in 0..ker { - conv[i] += vector[i + j] * kernel[ker - j - 1]; + if ker == 0 || ker > vec { + panic!("convolve_valid expects `self.len() >= kernel.len() > 0`, received {} and {} respectively.",vec,ker); } - } - conv -} -/// Returns the convolution of the vector and a kernel -/// The output convolution is the same size as vector, centered with respect to the ‘full’ output. -/// # Arguments -/// -/// * `vector` - A Vector with size > 0 -/// * `kernel` - A Vector with size > 0 -/// -/// # Errors -/// Inputs must statisfy `vector.len() >= kernel.len() > 0`. -pub fn convolve_same( - vector: Vector, - kernel: Vector, -) -> VectorN> -where - N: Real, - D1: DimMax, - D2: DimMax>, - S1: Storage, - S2: Storage, - DefaultAllocator: Allocator>, -{ - let vec = vector.len(); - let ker = kernel.len(); + let result_len = self.data.shape().0.add(U1).sub(kernel.data.shape().0); + let mut conv = VectorN::zeros_generic(result_len, U1); - if ker == 0 || ker > vec { - panic!("convolve_same expects `vector.len() >= kernel.len() > 0`, received {} and {} respectively.",vec,ker); - } - - let result_len = vector.data.shape().0.max(kernel.data.shape().0); - let mut conv = VectorN::zeros_generic(result_len, U1); - - for i in 0..vec { - for j in 0..ker { - let val = if i + j < 1 || i + j >= vec + 1 { - zero::() - } else { - vector[i + j - 1] - }; - conv[i] += val * kernel[ker - j - 1]; + for i in 0..(vec - ker + 1) { + for j in 0..ker { + conv[i] += self[i + j] * kernel[ker - j - 1]; + } } + conv + } + + /// Returns the convolution of the targetvector and a kernel + /// The output convolution is the same size as vector, centered with respect to the ‘full’ output. + /// # Arguments + /// + /// * `kernel` - A Vector with size > 0 + /// + /// # Errors + /// Inputs must statisfy `self.len() >= kernel.len() > 0`. + pub fn convolve_same(&self, kernel: Vector) -> VectorN> + where + D1: DimMax, + D2: DimMax>, + S2: Storage, + DefaultAllocator: Allocator>, + { + let vec = self.len(); + let ker = kernel.len(); + + if ker == 0 || ker > vec { + panic!("convolve_same expects `self.len() >= kernel.len() > 0`, received {} and {} respectively.",vec,ker); + } + + let result_len = self.data.shape().0.max(kernel.data.shape().0); + let mut conv = VectorN::zeros_generic(result_len, U1); + + for i in 0..vec { + for j in 0..ker { + let val = if i + j < 1 || i + j >= vec + 1 { + zero::() + } else { + self[i + j - 1] + }; + conv[i] += val * kernel[ker - j - 1]; + } + } + conv } - conv } diff --git a/tests/linalg/convolution.rs b/tests/linalg/convolution.rs index ddcfe9f6..b0d57f72 100644 --- a/tests/linalg/convolution.rs +++ b/tests/linalg/convolution.rs @@ -1,4 +1,3 @@ -use na::linalg::{convolve_full,convolve_valid,convolve_same}; use na::{Vector2,Vector3,Vector4,Vector5,DVector}; use std::panic; @@ -13,13 +12,13 @@ use std::panic; fn convolve_same_check(){ // Static Tests let actual_s = Vector4::from_vec(vec![1.0,4.0,7.0,10.0]); - let expected_s = convolve_same(Vector4::new(1.0,2.0,3.0,4.0), Vector2::new(1.0,2.0)); + let expected_s = Vector4::new(1.0,2.0,3.0,4.0).convolve_same(Vector2::new(1.0,2.0)); assert!(relative_eq!(actual_s, expected_s, epsilon = 1.0e-7)); // Dynamic Tests let actual_d = DVector::from_vec(vec![1.0,4.0,7.0,10.0]); - let expected_d = convolve_same(DVector::from_vec(vec![1.0,2.0,3.0,4.0]),DVector::from_vec(vec![1.0,2.0])); + let expected_d = DVector::from_vec(vec![1.0,2.0,3.0,4.0]).convolve_same(DVector::from_vec(vec![1.0,2.0])); assert!(relative_eq!(actual_d, expected_d, epsilon = 1.0e-7)); @@ -27,19 +26,19 @@ fn convolve_same_check(){ // These really only apply to dynamic sized vectors assert!( panic::catch_unwind(|| { - convolve_same(DVector::from_vec(vec![1.0,2.0]), DVector::from_vec(vec![1.0,2.0,3.0,4.0])); + DVector::from_vec(vec![1.0,2.0]).convolve_same(DVector::from_vec(vec![1.0,2.0,3.0,4.0])); }).is_err() ); assert!( panic::catch_unwind(|| { - convolve_same(DVector::::from_vec(vec![]), DVector::from_vec(vec![1.0,2.0,3.0,4.0])); + DVector::::from_vec(vec![]).convolve_same(DVector::from_vec(vec![1.0,2.0,3.0,4.0])); }).is_err() ); assert!( panic::catch_unwind(|| { - convolve_same(DVector::from_vec(vec![1.0,2.0,3.0,4.0]),DVector::::from_vec(vec![])); + DVector::from_vec(vec![1.0,2.0,3.0,4.0]).convolve_same(DVector::::from_vec(vec![])); }).is_err() ); } @@ -50,13 +49,13 @@ fn convolve_same_check(){ fn convolve_full_check(){ // Static Tests let actual_s = Vector5::new(1.0,4.0,7.0,10.0,8.0); - let expected_s = convolve_full(Vector4::new(1.0,2.0,3.0,4.0), Vector2::new(1.0,2.0)); + let expected_s = Vector4::new(1.0,2.0,3.0,4.0).convolve_full(Vector2::new(1.0,2.0)); assert!(relative_eq!(actual_s, expected_s, epsilon = 1.0e-7)); // Dynamic Tests let actual_d = DVector::from_vec(vec![1.0,4.0,7.0,10.0,8.0]); - let expected_d = convolve_full(DVector::from_vec(vec![1.0,2.0,3.0,4.0]), DVector::from_vec(vec![1.0,2.0])); + let expected_d = DVector::from_vec(vec![1.0,2.0,3.0,4.0]).convolve_full(DVector::from_vec(vec![1.0,2.0])); assert!(relative_eq!(actual_d, expected_d, epsilon = 1.0e-7)); @@ -64,19 +63,19 @@ fn convolve_full_check(){ // These really only apply to dynamic sized vectors assert!( panic::catch_unwind(|| { - convolve_full(DVector::from_vec(vec![1.0,2.0]), DVector::from_vec(vec![1.0,2.0,3.0,4.0])); + DVector::from_vec(vec![1.0,2.0]).convolve_full(DVector::from_vec(vec![1.0,2.0,3.0,4.0])); }).is_err() ); assert!( panic::catch_unwind(|| { - convolve_full(DVector::::from_vec(vec![]), DVector::from_vec(vec![1.0,2.0,3.0,4.0])); + DVector::::from_vec(vec![]).convolve_full(DVector::from_vec(vec![1.0,2.0,3.0,4.0])); }).is_err() ); assert!( panic::catch_unwind(|| { - convolve_full(DVector::from_vec(vec![1.0,2.0,3.0,4.0]),DVector::::from_vec(vec![])); + DVector::from_vec(vec![1.0,2.0,3.0,4.0]).convolve_full(DVector::::from_vec(vec![])); }).is_err() ); } @@ -87,13 +86,13 @@ fn convolve_full_check(){ fn convolve_valid_check(){ // Static Tests let actual_s = Vector3::from_vec(vec![4.0,7.0,10.0]); - let expected_s = convolve_valid( Vector4::new(1.0,2.0,3.0,4.0), Vector2::new(1.0,2.0)); + let expected_s = Vector4::new(1.0,2.0,3.0,4.0).convolve_valid( Vector2::new(1.0,2.0)); assert!(relative_eq!(actual_s, expected_s, epsilon = 1.0e-7)); // Dynamic Tests let actual_d = DVector::from_vec(vec![4.0,7.0,10.0]); - let expected_d = convolve_valid(DVector::from_vec(vec![1.0,2.0,3.0,4.0]), DVector::from_vec(vec![1.0,2.0])); + let expected_d = DVector::from_vec(vec![1.0,2.0,3.0,4.0]).convolve_valid(DVector::from_vec(vec![1.0,2.0])); assert!(relative_eq!(actual_d, expected_d, epsilon = 1.0e-7)); @@ -101,19 +100,19 @@ fn convolve_valid_check(){ // These really only apply to dynamic sized vectors assert!( panic::catch_unwind(|| { - convolve_valid(DVector::from_vec(vec![1.0,2.0]), DVector::from_vec(vec![1.0,2.0,3.0,4.0])); + DVector::from_vec(vec![1.0,2.0]).convolve_valid(DVector::from_vec(vec![1.0,2.0,3.0,4.0])); }).is_err() ); assert!( panic::catch_unwind(|| { - convolve_valid(DVector::::from_vec(vec![]), DVector::from_vec(vec![1.0,2.0,3.0,4.0])); + DVector::::from_vec(vec![]).convolve_valid(DVector::from_vec(vec![1.0,2.0,3.0,4.0])); }).is_err() ); assert!( panic::catch_unwind(|| { - convolve_valid(DVector::from_vec(vec![1.0,2.0,3.0,4.0]),DVector::::from_vec(vec![])); + DVector::from_vec(vec![1.0,2.0,3.0,4.0]).convolve_valid(DVector::::from_vec(vec![])); }).is_err() );