From ac7f495f566d590f60de9e8fee902e9e29ee3f08 Mon Sep 17 00:00:00 2001 From: Guilherme Salustiano Date: Tue, 30 Mar 2021 04:06:23 -0300 Subject: [PATCH] Add filter_2d --- src/third_party/image/kernel.rs | 54 +++++++++++++++++++++++++++++++ src/third_party/image/mod.rs | 1 + src/third_party/mod.rs | 2 ++ tests/lib.rs | 1 + tests/third_party/image/kernel.rs | 16 +++++++++ tests/third_party/image/mod.rs | 1 + tests/third_party/mod.rs | 1 + 7 files changed, 76 insertions(+) create mode 100644 src/third_party/image/kernel.rs create mode 100644 src/third_party/image/mod.rs create mode 100644 tests/third_party/image/kernel.rs create mode 100644 tests/third_party/image/mod.rs create mode 100644 tests/third_party/mod.rs diff --git a/src/third_party/image/kernel.rs b/src/third_party/image/kernel.rs new file mode 100644 index 00000000..b00a54ae --- /dev/null +++ b/src/third_party/image/kernel.rs @@ -0,0 +1,54 @@ +use crate::base::DMatrix; +use crate::storage::Storage; +use crate::{Dim, Dynamic, Matrix, Scalar}; +use num::Zero; +use std::ops::{AddAssign, Mul}; + +impl Matrix +where + N: Scalar + Zero + AddAssign + Mul + Copy, + R1: Dim, + C1: Dim, + SA: Storage, +{ + /// Returns the convolution of the target matrix and a kernel. + /// + /// # Arguments + /// + /// * `kernel` - A Matrix with size > 0 + /// + /// # Errors + /// Inputs must satisfy `matrix.len() >= matrix.len() > 0`. + /// + pub fn filter_2d(&self, kernel: Matrix) -> DMatrix + where + R2: Dim, + C2: Dim, + SB: Storage, + { + let mat_shape = self.shape(); + let ker_shape = kernel.shape(); + + if ker_shape == (0, 0) || ker_shape > mat_shape { + panic!("filter_2d expects `self.shape() >= kernel.shape() > 0`, received {:?} and {:?} respectively.", mat_shape, ker_shape); + } + + let result_shape = (mat_shape.0 - ker_shape.0 + 1, mat_shape.1 - ker_shape.1 + 1); + let mut conv = DMatrix::zeros_generic( + Dynamic::from_usize(result_shape.0), + Dynamic::from_usize(result_shape.1), + ); + + // TODO: optimize + for i in 0..(result_shape.0) { + for j in 0..(result_shape.1) { + for k in 0..(ker_shape.0) { + for l in 0..(ker_shape.1) { + conv[(i, j)] += self[(i + k, j + l)] * kernel[(k, l)] + } + } + } + } + conv + } +} diff --git a/src/third_party/image/mod.rs b/src/third_party/image/mod.rs new file mode 100644 index 00000000..a0119d93 --- /dev/null +++ b/src/third_party/image/mod.rs @@ -0,0 +1 @@ +mod kernel; diff --git a/src/third_party/mod.rs b/src/third_party/mod.rs index ce0fcaad..65d2ec69 100644 --- a/src/third_party/mod.rs +++ b/src/third_party/mod.rs @@ -4,3 +4,5 @@ mod alga; mod glam; #[cfg(feature = "mint")] mod mint; + +mod image; diff --git a/tests/lib.rs b/tests/lib.rs index add7a468..c32f32a6 100644 --- a/tests/lib.rs +++ b/tests/lib.rs @@ -22,6 +22,7 @@ mod linalg; #[cfg(feature = "proptest-support")] mod proptest; +mod third_party; //#[cfg(feature = "sparse")] //mod sparse; diff --git a/tests/third_party/image/kernel.rs b/tests/third_party/image/kernel.rs new file mode 100644 index 00000000..f806bdaa --- /dev/null +++ b/tests/third_party/image/kernel.rs @@ -0,0 +1,16 @@ +use na::{Matrix3, MatrixMN, U10, U8}; +use std::panic; + +#[test] +fn image_convolve_check() { + // Static Tests + let expect = MatrixMN::::from_element(18); + let src = MatrixMN::::from_element(2); + let kernel = Matrix3::::from_element(1); + let result = src.filter_2d(kernel); + println!("src: {}", src); + println!("ker: {}", kernel); + println!("res: {}", result); + + assert_eq!(result, expect); +} diff --git a/tests/third_party/image/mod.rs b/tests/third_party/image/mod.rs new file mode 100644 index 00000000..a0119d93 --- /dev/null +++ b/tests/third_party/image/mod.rs @@ -0,0 +1 @@ +mod kernel; diff --git a/tests/third_party/mod.rs b/tests/third_party/mod.rs new file mode 100644 index 00000000..a5b026f0 --- /dev/null +++ b/tests/third_party/mod.rs @@ -0,0 +1 @@ +mod image;