Add filter_2d

This commit is contained in:
Guilherme Salustiano 2021-03-30 04:06:23 -03:00
parent e9535d5cb5
commit ac7f495f56
No known key found for this signature in database
GPG Key ID: FB947B826DBAA61A
7 changed files with 76 additions and 0 deletions

54
src/third_party/image/kernel.rs vendored Normal file
View File

@ -0,0 +1,54 @@
use crate::base::DMatrix;
use crate::storage::Storage;
use crate::{Dim, Dynamic, Matrix, Scalar};
use num::Zero;
use std::ops::{AddAssign, Mul};
impl<N, R1, C1, SA> Matrix<N, R1, C1, SA>
where
N: Scalar + Zero + AddAssign + Mul<Output = N> + Copy,
R1: Dim,
C1: Dim,
SA: Storage<N, R1, C1>,
{
/// Returns the convolution of the target matrix and a kernel.
///
/// # Arguments
///
/// * `kernel` - A Matrix with size > 0
///
/// # Errors
/// Inputs must satisfy `matrix.len() >= matrix.len() > 0`.
///
pub fn filter_2d<R2, C2, SB>(&self, kernel: Matrix<N, R2, C2, SB>) -> DMatrix<N>
where
R2: Dim,
C2: Dim,
SB: Storage<N, R2, C2>,
{
let mat_shape = self.shape();
let ker_shape = kernel.shape();
if ker_shape == (0, 0) || ker_shape > mat_shape {
panic!("filter_2d expects `self.shape() >= kernel.shape() > 0`, received {:?} and {:?} respectively.", mat_shape, ker_shape);
}
let result_shape = (mat_shape.0 - ker_shape.0 + 1, mat_shape.1 - ker_shape.1 + 1);
let mut conv = DMatrix::zeros_generic(
Dynamic::from_usize(result_shape.0),
Dynamic::from_usize(result_shape.1),
);
// TODO: optimize
for i in 0..(result_shape.0) {
for j in 0..(result_shape.1) {
for k in 0..(ker_shape.0) {
for l in 0..(ker_shape.1) {
conv[(i, j)] += self[(i + k, j + l)] * kernel[(k, l)]
}
}
}
}
conv
}
}

1
src/third_party/image/mod.rs vendored Normal file
View File

@ -0,0 +1 @@
mod kernel;

View File

@ -4,3 +4,5 @@ mod alga;
mod glam;
#[cfg(feature = "mint")]
mod mint;
mod image;

View File

@ -22,6 +22,7 @@ mod linalg;
#[cfg(feature = "proptest-support")]
mod proptest;
mod third_party;
//#[cfg(feature = "sparse")]
//mod sparse;

16
tests/third_party/image/kernel.rs vendored Normal file
View File

@ -0,0 +1,16 @@
use na::{Matrix3, MatrixMN, U10, U8};
use std::panic;
#[test]
fn image_convolve_check() {
// Static Tests
let expect = MatrixMN::<usize, U8, U8>::from_element(18);
let src = MatrixMN::<usize, U10, U10>::from_element(2);
let kernel = Matrix3::<usize>::from_element(1);
let result = src.filter_2d(kernel);
println!("src: {}", src);
println!("ker: {}", kernel);
println!("res: {}", result);
assert_eq!(result, expect);
}

1
tests/third_party/image/mod.rs vendored Normal file
View File

@ -0,0 +1 @@
mod kernel;

1
tests/third_party/mod.rs vendored Normal file
View File

@ -0,0 +1 @@
mod image;