Add kronecker product.

Closes #248
This commit is contained in:
Sébastien Crozet 2017-05-03 22:27:05 +02:00 committed by Sébastien Crozet
parent 1cdad4c7c6
commit 35d2b6dc88
3 changed files with 94 additions and 6 deletions

View File

@ -4,6 +4,12 @@ documented here.
This project adheres to [Semantic Versioning](http://semver.org/). This project adheres to [Semantic Versioning](http://semver.org/).
## [0.13.0] - WIP
### Added
* `.kronecker(a, b)` computes the kronecker product (i.e. matrix tensor
product) of two matrices.
## [0.12.0] ## [0.12.0]
The main change of this release is the update of the dependency serde to 1.0. The main change of this release is the update of the dependency serde to 1.0.

View File

@ -5,7 +5,7 @@ use num::Zero;
use alga::general::{ClosedMul, ClosedDiv, ClosedAdd, ClosedSub, ClosedNeg}; use alga::general::{ClosedMul, ClosedDiv, ClosedAdd, ClosedSub, ClosedNeg};
use core::{Scalar, Matrix, OwnedMatrix, MatrixSum, MatrixMul, MatrixTrMul}; use core::{Scalar, Matrix, OwnedMatrix, MatrixSum, MatrixMul, MatrixTrMul};
use core::dimension::Dim; use core::dimension::{Dim, DimMul, DimProd};
use core::constraint::{ShapeConstraint, SameNumberOfRows, SameNumberOfColumns, AreMultipliable}; use core::constraint::{ShapeConstraint, SameNumberOfRows, SameNumberOfColumns, AreMultipliable};
use core::storage::{Storage, StorageMut, OwnedStorage}; use core::storage::{Storage, StorageMut, OwnedStorage};
use core::allocator::{SameShapeAllocator, Allocator, OwnedAllocator}; use core::allocator::{SameShapeAllocator, Allocator, OwnedAllocator};
@ -453,13 +453,15 @@ impl<'b, N, R1, C1, R2, SA, SB> MulAssign<&'b Matrix<N, R2, C1, SB>> for Matrix<
} }
// Transpose-multiplication.
impl<N, R1: Dim, C1: Dim, SA> Matrix<N, R1, C1, SA> impl<N, R1: Dim, C1: Dim, SA> Matrix<N, R1, C1, SA>
where N: Scalar + Zero + ClosedAdd + ClosedMul, where N: Scalar,
SA: Storage<N, R1, C1> { SA: Storage<N, R1, C1> {
/// Equivalent to `self.transpose() * right`. /// Equivalent to `self.transpose() * right`.
#[inline] #[inline]
pub fn tr_mul<R2: Dim, C2: Dim, SB>(&self, right: &Matrix<N, R2, C2, SB>) -> MatrixTrMul<N, R1, C1, C2, SA> pub fn tr_mul<R2: Dim, C2: Dim, SB>(&self, right: &Matrix<N, R2, C2, SB>) -> MatrixTrMul<N, R1, C1, C2, SA>
where SB: Storage<N, R2, C2>, where N: Zero + ClosedAdd + ClosedMul,
SB: Storage<N, R2, C2>,
SA::Alloc: Allocator<N, C1, C2>, SA::Alloc: Allocator<N, C1, C2>,
ShapeConstraint: AreMultipliable<C1, R1, R2, C2> { ShapeConstraint: AreMultipliable<C1, R1, R2, C2> {
let (nrows1, ncols1) = self.shape(); let (nrows1, ncols1) = self.shape();
@ -477,7 +479,7 @@ impl<N, R1: Dim, C1: Dim, SA> Matrix<N, R1, C1, SA>
unsafe { unsafe {
for k in 0 .. nrows1 { for k in 0 .. nrows1 {
acc = acc + *self.get_unchecked(k, i) * *right.get_unchecked(k, j); acc += *self.get_unchecked(k, i) * *right.get_unchecked(k, j);
} }
*res.get_unchecked_mut(i, j) = acc; *res.get_unchecked_mut(i, j) = acc;
@ -487,4 +489,42 @@ impl<N, R1: Dim, C1: Dim, SA> Matrix<N, R1, C1, SA>
res res
} }
/// The kronecker product of two matrices (aka. tensor product of the corresponding linear
/// maps).
pub fn kronecker<R2: Dim, C2: Dim, SB>(&self, rhs: &Matrix<N, R2, C2, SB>)
-> OwnedMatrix<N, DimProd<R1, R2>, DimProd<C1, C2>, SA::Alloc>
where N: ClosedMul,
R1: DimMul<R2>,
C1: DimMul<C2>,
SB: Storage<N, R2, C2>,
SA::Alloc: Allocator<N, DimProd<R1, R2>, DimProd<C1, C2>> {
let (nrows1, ncols1) = self.data.shape();
let (nrows2, ncols2) = rhs.data.shape();
let mut res: OwnedMatrix<_, _, _, SA::Alloc> =
unsafe { Matrix::new_uninitialized_generic(nrows1.mul(nrows2), ncols1.mul(ncols2)) };
{
let mut data_res = res.data.ptr_mut();
for j1 in 0 .. ncols1.value() {
for j2 in 0 .. ncols2.value() {
for i1 in 0 .. nrows1.value() {
unsafe {
let coeff = *self.get_unchecked(i1, j1);
for i2 in 0 .. nrows2.value() {
*data_res = coeff * *rhs.get_unchecked(i2, j2);
data_res = data_res.offset(1);
}
}
}
}
}
}
res
}
} }

View File

@ -12,11 +12,12 @@ use std::fmt::Display;
use alga::linear::FiniteDimInnerSpace; use alga::linear::FiniteDimInnerSpace;
use na::{DVector, DMatrix, use na::{U8, U15,
DVector, DMatrix,
Vector1, Vector2, Vector3, Vector4, Vector5, Vector6, Vector1, Vector2, Vector3, Vector4, Vector5, Vector6,
RowVector4, RowVector4,
Matrix1, Matrix2, Matrix3, Matrix4, Matrix5, Matrix6, Matrix1, Matrix2, Matrix3, Matrix4, Matrix5, Matrix6,
Matrix2x3, Matrix3x2, Matrix3x4, Matrix4x3, Matrix2x4, Matrix4x6}; MatrixNM, Matrix2x3, Matrix3x2, Matrix3x4, Matrix4x3, Matrix2x4, Matrix4x5, Matrix4x6};
#[test] #[test]
@ -433,6 +434,47 @@ fn components_mut() {
assert_eq!(expected_m3, m3); assert_eq!(expected_m3, m3);
} }
#[test]
fn kronecker() {
let a = Matrix2x3::new(
11, 12, 13,
21, 22, 23);
let b = Matrix4x5::new(
110, 120, 130, 140, 150,
210, 220, 230, 240, 250,
310, 320, 330, 340, 350,
410, 420, 430, 440, 450);
let expected = MatrixNM::<_, U8, U15>::from_row_slice(&[
1210, 1320, 1430, 1540, 1650, 1320, 1440, 1560, 1680, 1800, 1430, 1560, 1690, 1820, 1950,
2310, 2420, 2530, 2640, 2750, 2520, 2640, 2760, 2880, 3000, 2730, 2860, 2990, 3120, 3250,
3410, 3520, 3630, 3740, 3850, 3720, 3840, 3960, 4080, 4200, 4030, 4160, 4290, 4420, 4550,
4510, 4620, 4730, 4840, 4950, 4920, 5040, 5160, 5280, 5400, 5330, 5460, 5590, 5720, 5850,
2310, 2520, 2730, 2940, 3150, 2420, 2640, 2860, 3080, 3300, 2530, 2760, 2990, 3220, 3450,
4410, 4620, 4830, 5040, 5250, 4620, 4840, 5060, 5280, 5500, 4830, 5060, 5290, 5520, 5750,
6510, 6720, 6930, 7140, 7350, 6820, 7040, 7260, 7480, 7700, 7130, 7360, 7590, 7820, 8050,
8610, 8820, 9030, 9240, 9450, 9020, 9240, 9460, 9680, 9900, 9430, 9660, 9890, 10120, 10350 ]);
let computed = a.kronecker(&b);
assert_eq!(computed, expected);
let a = Vector2::new(1, 2);
let b = Vector3::new(10, 20, 30);
let expected = Vector6::new(10, 20, 30, 20, 40, 60);
assert_eq!(a.kronecker(&b), expected);
let a = Vector2::new(1, 2);
let b = RowVector4::new(10, 20, 30, 40);
let expected = Matrix2x4::new(
10, 20, 30, 40,
20, 40, 60, 80);
assert_eq!(a.kronecker(&b), expected);
}
#[cfg(feature = "arbitrary")] #[cfg(feature = "arbitrary")]
quickcheck!{ quickcheck!{
/* /*