diff --git a/src/linalg/convolution.rs b/src/linalg/convolution.rs index c3256180..b5f95d19 100644 --- a/src/linalg/convolution.rs +++ b/src/linalg/convolution.rs @@ -4,10 +4,11 @@ use crate::base::allocator::Allocator; use crate::base::default_allocator::DefaultAllocator; use crate::base::dimension::{Dim, DimAdd, DimDiff, DimSub, DimSum, DimName}; use crate::storage::Storage; -use crate::{zero, RealField, Vector, VectorN, U1, Matrix, MatrixMN, DMatrix}; -use crate::alga::general::Field; +use crate::{RealField, Vector, VectorN, U1, Matrix, MatrixMN, DMatrix, Scalar, zero}; +use crate::alga::general::{ClosedMul, ClosedAdd, Identity, Additive}; +use crate::num::Zero; -impl> Vector { +impl + ClosedAdd + Zero + Identity, D1: Dim, S1: Storage> Vector { /// Returns the convolution of the target vector and a kernel.s /// /// # Arguments @@ -130,7 +131,7 @@ impl> Vector { } } -impl DMatrix { +impl + ClosedAdd + Zero + Identity> DMatrix { /// Returns the convolution of the target vector and a kernel. /// /// # Arguments @@ -164,7 +165,7 @@ impl DMatrix { } -impl MatrixMN where DefaultAllocator: Allocator { +impl + ClosedAdd + Zero + Identity, R1: Dim + DimName, C1: Dim + DimName> MatrixMN where DefaultAllocator: Allocator { /// Returns the convolution of the target vector and a kernel. /// /// # Arguments @@ -200,7 +201,7 @@ impl MatrixMN whe fn convolve(mat: &MatrixMN, kernel: &Matrix, target: &mut MatrixMN, mat_rows: i32, mat_cols: i32) where - N: RealField, + N: Scalar + ClosedMul + ClosedAdd + Zero + Identity, R1: Dim, C1: Dim, R2: Dim, diff --git a/tests/linalg/convolution.rs b/tests/linalg/convolution.rs index 3aa0ee2a..5b8f491c 100644 --- a/tests/linalg/convolution.rs +++ b/tests/linalg/convolution.rs @@ -16,6 +16,11 @@ fn convolve_same_check(){ assert!(relative_eq!(actual_s, expected_s, epsilon = 1.0e-7)); + let actual_s_int = Vector4::new(1, 4, 7, 10); + let expected_s_int = Vector4::new(1, 2, 3, 4).convolve_same(Vector2::new(1, 2)); + + assert_eq!(actual_s_int, expected_s_int); + // Dynamic Tests let actual_d = DVector::from_vec(vec![1.0, 4.0, 7.0, 10.0]); let expected_d = DVector::from_vec(vec![1.0, 2.0, 3.0, 4.0]).convolve_same(DVector::from_vec(vec![1.0, 2.0])); @@ -51,7 +56,12 @@ fn convolve_full_check(){ let actual_s = Vector5::new(1.0, 4.0, 7.0, 10.0, 8.0); let expected_s = Vector4::new(1.0, 2.0, 3.0, 4.0).convolve_full(Vector2::new(1.0, 2.0)); - assert!(relative_eq!(actual_s, expected_s, epsilon = 1.0e-7)); + assert!(relative_eq!(actual_s, expected_s)); + + let actual_s_int = Vector5::new(1, 4, 7, 10, 8); + let expected_s_int = Vector4::new(1, 2, 3, 4).convolve_full(Vector2::new(1, 2)); + + assert_eq!(actual_s_int, expected_s_int); // Dynamic Tests let actual_d = DVector::from_vec(vec![1.0, 4.0, 7.0, 10.0, 8.0]); @@ -141,6 +151,11 @@ fn convolve_same_mat_check(){ assert!(relative_eq!(actual_s, expected_s, epsilon = 1.0e-7)); + let actual_s_int = Matrix5::from_vec( vec![3,4,4,4,3,4,5,5,5,4,4,5,5,5,4,4,5,5,5,4,3,4,4,4,3]); + let expected_s_int = Matrix5::from_element(1).smat_convolve_full(Matrix3::from_vec(vec![0,1,0,1,1,1,0,1,0])); + + assert_eq!(actual_s_int, expected_s_int); + let actual_d = DMatrix::from_vec(5,5, vec![3.0,4.0,4.0,4.0,3.0,4.0,5.0,5.0,5.0,4.0,4.0,5.0,5.0,5.0,4.0,4.0,5.0,5.0,5.0,4.0,3.0,4.0,4.0,4.0,3.0]); let expected_d = DMatrix::from_element(5,5,1.0).dmat_convolve_full(DMatrix::from_vec(3,3,vec![0.0,1.0,0.0,1.0,1.0,1.0,0.0,1.0,0.0]));