diff --git a/src/linalg/convolution.rs b/src/linalg/convolution.rs index 60798b6c..eba9dec9 100644 --- a/src/linalg/convolution.rs +++ b/src/linalg/convolution.rs @@ -130,13 +130,12 @@ impl> Vector { } } - impl DMatrix { /// Returns the convolution of the target vector and a kernel. /// /// # Arguments /// - /// * `kernel` - A Matrix with rows > 0 and cols > 0 + /// * `kernel` - A Matrix with rows > 0 and cols > 0 rows == cols /// /// # Errors /// Inputs must satisfy `self.shape() >= kernel.shape() > 0`. @@ -152,49 +151,11 @@ impl DMatrix { { let mat_rows = self.nrows() as i32; let mat_cols = self.ncols() as i32; - let ker_rows = kernel.data.shape().0.value() as i32; - let ker_cols = kernel.data.shape().1.value() as i32; - if ker_rows == 0 || ker_rows > mat_rows || ker_cols == 0|| ker_cols > mat_cols { - panic!( - "convolve_full expects `self.nrows() >= kernel.nrows() > 0 and self.ncols() >= kernel.ncols() > 0 `, \ - rows received {} and {} respectively. \ - cols received {} and {} respectively.", - mat_rows, ker_rows, mat_cols, ker_cols); - } - - let kernel_size = ker_rows; - let kernel_min = kernel_size/2; - let zero = zero::(); let mut conv = DMatrix::::zeros(mat_cols as usize, mat_rows as usize); - for i in 0..mat_rows { - for j in 0..mat_cols { - for k_i in 0..kernel_size { - for k_j in 0..kernel_size { - let i_matrix = i + k_i - kernel_min; - let j_matrix = j + k_j - kernel_min; - let is_i_in_range = i_matrix >=0 && i_matrix < mat_rows; - let is_j_in_range = j_matrix >=0 && j_matrix < mat_cols; - - let convolved_value = - match is_i_in_range && is_j_in_range { - true => { - let pixel_value = *self.index((i_matrix as usize, j_matrix as usize)); - let kernel_value = *kernel.index((k_i as usize,k_j as usize)); - kernel_value*pixel_value - } - //TODO: More behaviour on borders - false => zero - }; - - *conv.index_mut((i as usize,j as usize)) += convolved_value; - } - } - - } - } + convolve(&self, &kernel,&mut conv,mat_rows,mat_cols); conv } @@ -209,37 +170,65 @@ impl MatrixMN where /// /// # Arguments /// - /// * `kernel` - A Matrix with rows > 0 and cols > 0 + /// * `kernel` - A Matrix with rows > 0 and cols > 0 and rows == cols /// /// # Errors /// Inputs must satisfy `self.shape() >= kernel.shape() > 0`. /// - pub fn smat_convolve_full( + pub fn smat_convolve_full( &self, - kernel: Matrix, //TODO: Would be nice to have an IsOdd trait. As kernels could be of even size atm + kernel: Matrix, //TODO: Would be nice to have an IsOdd trait. As kernels could be of even size atm ) -> MatrixMN where R2: Dim, C2: Dim, - S1: Storage + S2: Storage { + + let mat_rows = self.nrows() as i32; let mat_cols = self.ncols() as i32; + + let mut conv = MatrixMN::::zeros(); + + convolve(&self, &kernel,&mut conv,mat_rows,mat_cols); + + + conv + } + + //TODO: rest ? + + +} + + +fn convolve(mat: &MatrixMN, kernel: &Matrix, target: &mut MatrixMN, mat_rows: i32, mat_cols: i32) + where + N: RealField, + R1: Dim, + C1: Dim, + R2: Dim, + C2: Dim, + S2: Storage, + DefaultAllocator: Allocator + { + let ker_rows = kernel.data.shape().0.value() as i32; let ker_cols = kernel.data.shape().1.value() as i32; - if ker_rows == 0 || ker_rows > mat_rows || ker_cols == 0|| ker_cols > mat_cols { + if ker_rows == 0 || ker_rows > mat_rows || ker_cols == 0 || ker_cols > mat_cols || ker_cols != ker_rows { panic!( - "convolve_full expects `self.nrows() >= kernel.nrows() > 0 and self.ncols() >= kernel.ncols() > 0 `, \ + "convolve_full expects `self.nrows() >= kernel.nrows() > 0 and self.ncols() >= kernel.ncols() > 0 and kernel.nrows() == kernel.ncols() `, \ rows received {} and {} respectively. \ cols received {} and {} respectively.", mat_rows, ker_rows, mat_cols, ker_cols); } + let kernel_size = ker_rows; let kernel_min = kernel_size/2; let zero = zero::(); - let mut conv = MatrixMN::::zeros(); for i in 0..mat_rows { for j in 0..mat_cols { @@ -254,7 +243,7 @@ impl MatrixMN where let convolved_value = match is_i_in_range && is_j_in_range { true => { - let pixel_value = *self.index((i_matrix as usize, j_matrix as usize)); + let pixel_value = *mat.index((i_matrix as usize, j_matrix as usize)); let kernel_value = *kernel.index((k_i as usize,k_j as usize)); kernel_value*pixel_value } @@ -262,17 +251,9 @@ impl MatrixMN where false => zero }; - *conv.index_mut((i as usize,j as usize)) += convolved_value; + *target.index_mut((i as usize,j as usize)) += convolved_value; } } - } } - - conv } - - //TODO: rest ? - - -}