diff --git a/src/linalg/convolution.rs b/src/linalg/convolution.rs index 3f934eda..48894d3b 100644 --- a/src/linalg/convolution.rs +++ b/src/linalg/convolution.rs @@ -7,14 +7,14 @@ use crate::storage::Storage; use crate::{zero, RealField, Vector, VectorN, U1}; impl> Vector { - /// Returns the convolution of the target vector and a kernel + /// Returns the convolution of the target vector and a kernel. /// /// # Arguments /// /// * `kernel` - A Vector with size > 0 /// /// # Errors - /// Inputs must statisfy `vector.len() >= kernel.len() > 0`. + /// Inputs must satisfy `vector.len() >= kernel.len() > 0`. /// pub fn convolve_full( &self, @@ -53,7 +53,8 @@ impl> Vector { } conv } - /// Returns the convolution of the target vector and a kernel + /// Returns the convolution of the target vector and a kernel. + /// /// The output convolution consists only of those elements that do not rely on the zero-padding. /// # Arguments /// @@ -61,10 +62,9 @@ impl> Vector { /// /// /// # Errors - /// Inputs must statisfy `self.len() >= kernel.len() > 0`. + /// Inputs must satisfy `self.len() >= kernel.len() > 0`. /// - pub fn convolve_valid(&self, kernel: Vector, - ) -> VectorN, D2>> + pub fn convolve_valid(&self, kernel: Vector) -> VectorN, D2>> where D1: DimAdd, D2: Dim, @@ -90,20 +90,20 @@ impl> Vector { conv } - /// Returns the convolution of the targetvector and a kernel + /// Returns the convolution of the target vector and a kernel. + /// /// The output convolution is the same size as vector, centered with respect to the ‘full’ output. /// # Arguments /// /// * `kernel` - A Vector with size > 0 /// /// # Errors - /// Inputs must statisfy `self.len() >= kernel.len() > 0`. - pub fn convolve_same(&self, kernel: Vector) -> VectorN> + /// Inputs must satisfy `self.len() >= kernel.len() > 0`. + pub fn convolve_same(&self, kernel: Vector) -> VectorN where - D1: DimMax, - D2: DimMax>, + D2: Dim, S2: Storage, - DefaultAllocator: Allocator>, + DefaultAllocator: Allocator, { let vec = self.len(); let ker = kernel.len(); @@ -112,8 +112,7 @@ impl> Vector { panic!("convolve_same expects `self.len() >= kernel.len() > 0`, received {} and {} respectively.",vec,ker); } - let result_len = self.data.shape().0.max(kernel.data.shape().0); - let mut conv = VectorN::zeros_generic(result_len, U1); + let mut conv = VectorN::zeros_generic(self.data.shape().0, U1); for i in 0..vec { for j in 0..ker { @@ -125,6 +124,7 @@ impl> Vector { conv[i] += val * kernel[ker - j - 1]; } } + conv } } diff --git a/tests/linalg/convolution.rs b/tests/linalg/convolution.rs index b0d57f72..65380162 100644 --- a/tests/linalg/convolution.rs +++ b/tests/linalg/convolution.rs @@ -11,14 +11,14 @@ use std::panic; #[test] fn convolve_same_check(){ // Static Tests - let actual_s = Vector4::from_vec(vec![1.0,4.0,7.0,10.0]); - let expected_s = Vector4::new(1.0,2.0,3.0,4.0).convolve_same(Vector2::new(1.0,2.0)); + let actual_s = Vector4::new(1.0, 4.0, 7.0, 10.0); + let expected_s = Vector4::new(1.0, 2.0, 3.0, 4.0).convolve_same(Vector2::new(1.0, 2.0)); assert!(relative_eq!(actual_s, expected_s, epsilon = 1.0e-7)); // Dynamic Tests - let actual_d = DVector::from_vec(vec![1.0,4.0,7.0,10.0]); - let expected_d = DVector::from_vec(vec![1.0,2.0,3.0,4.0]).convolve_same(DVector::from_vec(vec![1.0,2.0])); + let actual_d = DVector::from_vec(vec![1.0, 4.0, 7.0, 10.0]); + let expected_d = DVector::from_vec(vec![1.0, 2.0, 3.0, 4.0]).convolve_same(DVector::from_vec(vec![1.0, 2.0])); assert!(relative_eq!(actual_d, expected_d, epsilon = 1.0e-7)); @@ -26,19 +26,19 @@ fn convolve_same_check(){ // These really only apply to dynamic sized vectors assert!( panic::catch_unwind(|| { - DVector::from_vec(vec![1.0,2.0]).convolve_same(DVector::from_vec(vec![1.0,2.0,3.0,4.0])); + DVector::from_vec(vec![1.0, 2.0]).convolve_same(DVector::from_vec(vec![1.0, 2.0, 3.0, 4.0])); }).is_err() ); assert!( panic::catch_unwind(|| { - DVector::::from_vec(vec![]).convolve_same(DVector::from_vec(vec![1.0,2.0,3.0,4.0])); + DVector::::from_vec(vec![]).convolve_same(DVector::from_vec(vec![1.0, 2.0, 3.0, 4.0])); }).is_err() ); assert!( panic::catch_unwind(|| { - DVector::from_vec(vec![1.0,2.0,3.0,4.0]).convolve_same(DVector::::from_vec(vec![])); + DVector::from_vec(vec![1.0, 2.0, 3.0, 4.0]).convolve_same(DVector::::from_vec(vec![])); }).is_err() ); } @@ -48,14 +48,14 @@ fn convolve_same_check(){ #[test] fn convolve_full_check(){ // Static Tests - let actual_s = Vector5::new(1.0,4.0,7.0,10.0,8.0); - let expected_s = Vector4::new(1.0,2.0,3.0,4.0).convolve_full(Vector2::new(1.0,2.0)); + let actual_s = Vector5::new(1.0, 4.0, 7.0, 10.0, 8.0); + let expected_s = Vector4::new(1.0, 2.0, 3.0, 4.0).convolve_full(Vector2::new(1.0, 2.0)); assert!(relative_eq!(actual_s, expected_s, epsilon = 1.0e-7)); // Dynamic Tests - let actual_d = DVector::from_vec(vec![1.0,4.0,7.0,10.0,8.0]); - let expected_d = DVector::from_vec(vec![1.0,2.0,3.0,4.0]).convolve_full(DVector::from_vec(vec![1.0,2.0])); + let actual_d = DVector::from_vec(vec![1.0, 4.0, 7.0, 10.0, 8.0]); + let expected_d = DVector::from_vec(vec![1.0, 2.0, 3.0, 4.0]).convolve_full(DVector::from_vec(vec![1.0, 2.0])); assert!(relative_eq!(actual_d, expected_d, epsilon = 1.0e-7)); @@ -63,36 +63,36 @@ fn convolve_full_check(){ // These really only apply to dynamic sized vectors assert!( panic::catch_unwind(|| { - DVector::from_vec(vec![1.0,2.0]).convolve_full(DVector::from_vec(vec![1.0,2.0,3.0,4.0])); + DVector::from_vec(vec![1.0, 2.0] ).convolve_full(DVector::from_vec(vec![1.0, 2.0, 3.0, 4.0] )); }).is_err() ); assert!( panic::catch_unwind(|| { - DVector::::from_vec(vec![]).convolve_full(DVector::from_vec(vec![1.0,2.0,3.0,4.0])); + DVector::::from_vec(vec![]).convolve_full(DVector::from_vec(vec![1.0, 2.0, 3.0, 4.0] )); }).is_err() ); assert!( panic::catch_unwind(|| { - DVector::from_vec(vec![1.0,2.0,3.0,4.0]).convolve_full(DVector::::from_vec(vec![])); + DVector::from_vec(vec![1.0, 2.0, 3.0, 4.0] ).convolve_full(DVector::::from_vec(vec![])); }).is_err() ); } -// >>> convolve([1,2,3,4],[1,2],"valid") -// array([ 4, 7, 10]) +// >>> convolve([1, 2, 3, 4],[1, 2],"valid") +// array([4, 7, 10]) #[test] fn convolve_valid_check(){ // Static Tests - let actual_s = Vector3::from_vec(vec![4.0,7.0,10.0]); - let expected_s = Vector4::new(1.0,2.0,3.0,4.0).convolve_valid( Vector2::new(1.0,2.0)); + let actual_s = Vector3::from_vec(vec![4.0, 7.0, 10.0]); + let expected_s = Vector4::new(1.0, 2.0, 3.0, 4.0).convolve_valid( Vector2::new(1.0, 2.0)); assert!(relative_eq!(actual_s, expected_s, epsilon = 1.0e-7)); // Dynamic Tests - let actual_d = DVector::from_vec(vec![4.0,7.0,10.0]); - let expected_d = DVector::from_vec(vec![1.0,2.0,3.0,4.0]).convolve_valid(DVector::from_vec(vec![1.0,2.0])); + let actual_d = DVector::from_vec(vec![4.0, 7.0, 10.0]); + let expected_d = DVector::from_vec(vec![1.0, 2.0, 3.0, 4.0]).convolve_valid(DVector::from_vec(vec![1.0, 2.0])); assert!(relative_eq!(actual_d, expected_d, epsilon = 1.0e-7)); @@ -100,19 +100,19 @@ fn convolve_valid_check(){ // These really only apply to dynamic sized vectors assert!( panic::catch_unwind(|| { - DVector::from_vec(vec![1.0,2.0]).convolve_valid(DVector::from_vec(vec![1.0,2.0,3.0,4.0])); + DVector::from_vec(vec![1.0, 2.0]).convolve_valid(DVector::from_vec(vec![1.0, 2.0, 3.0, 4.0])); }).is_err() ); assert!( panic::catch_unwind(|| { - DVector::::from_vec(vec![]).convolve_valid(DVector::from_vec(vec![1.0,2.0,3.0,4.0])); + DVector::::from_vec(vec![]).convolve_valid(DVector::from_vec(vec![1.0, 2.0, 3.0, 4.0])); }).is_err() ); assert!( panic::catch_unwind(|| { - DVector::from_vec(vec![1.0,2.0,3.0,4.0]).convolve_valid(DVector::::from_vec(vec![])); + DVector::from_vec(vec![1.0, 2.0, 3.0, 4.0]).convolve_valid(DVector::::from_vec(vec![])); }).is_err() );