Add kronecker product.

Closes #248
This commit is contained in:
Sébastien Crozet 2017-05-03 22:27:05 +02:00 committed by Sébastien Crozet
parent 1cdad4c7c6
commit 35d2b6dc88
3 changed files with 94 additions and 6 deletions

View File

@ -4,6 +4,12 @@ documented here.
This project adheres to [Semantic Versioning](http://semver.org/).
## [0.13.0] - WIP
### Added
* `.kronecker(a, b)` computes the kronecker product (i.e. matrix tensor
product) of two matrices.
## [0.12.0]
The main change of this release is the update of the dependency serde to 1.0.

View File

@ -5,7 +5,7 @@ use num::Zero;
use alga::general::{ClosedMul, ClosedDiv, ClosedAdd, ClosedSub, ClosedNeg};
use core::{Scalar, Matrix, OwnedMatrix, MatrixSum, MatrixMul, MatrixTrMul};
use core::dimension::Dim;
use core::dimension::{Dim, DimMul, DimProd};
use core::constraint::{ShapeConstraint, SameNumberOfRows, SameNumberOfColumns, AreMultipliable};
use core::storage::{Storage, StorageMut, OwnedStorage};
use core::allocator::{SameShapeAllocator, Allocator, OwnedAllocator};
@ -453,13 +453,15 @@ impl<'b, N, R1, C1, R2, SA, SB> MulAssign<&'b Matrix<N, R2, C1, SB>> for Matrix<
}
// Transpose-multiplication.
impl<N, R1: Dim, C1: Dim, SA> Matrix<N, R1, C1, SA>
where N: Scalar + Zero + ClosedAdd + ClosedMul,
where N: Scalar,
SA: Storage<N, R1, C1> {
/// Equivalent to `self.transpose() * right`.
#[inline]
pub fn tr_mul<R2: Dim, C2: Dim, SB>(&self, right: &Matrix<N, R2, C2, SB>) -> MatrixTrMul<N, R1, C1, C2, SA>
where SB: Storage<N, R2, C2>,
where N: Zero + ClosedAdd + ClosedMul,
SB: Storage<N, R2, C2>,
SA::Alloc: Allocator<N, C1, C2>,
ShapeConstraint: AreMultipliable<C1, R1, R2, C2> {
let (nrows1, ncols1) = self.shape();
@ -477,7 +479,7 @@ impl<N, R1: Dim, C1: Dim, SA> Matrix<N, R1, C1, SA>
unsafe {
for k in 0 .. nrows1 {
acc = acc + *self.get_unchecked(k, i) * *right.get_unchecked(k, j);
acc += *self.get_unchecked(k, i) * *right.get_unchecked(k, j);
}
*res.get_unchecked_mut(i, j) = acc;
@ -487,4 +489,42 @@ impl<N, R1: Dim, C1: Dim, SA> Matrix<N, R1, C1, SA>
res
}
/// The kronecker product of two matrices (aka. tensor product of the corresponding linear
/// maps).
pub fn kronecker<R2: Dim, C2: Dim, SB>(&self, rhs: &Matrix<N, R2, C2, SB>)
-> OwnedMatrix<N, DimProd<R1, R2>, DimProd<C1, C2>, SA::Alloc>
where N: ClosedMul,
R1: DimMul<R2>,
C1: DimMul<C2>,
SB: Storage<N, R2, C2>,
SA::Alloc: Allocator<N, DimProd<R1, R2>, DimProd<C1, C2>> {
let (nrows1, ncols1) = self.data.shape();
let (nrows2, ncols2) = rhs.data.shape();
let mut res: OwnedMatrix<_, _, _, SA::Alloc> =
unsafe { Matrix::new_uninitialized_generic(nrows1.mul(nrows2), ncols1.mul(ncols2)) };
{
let mut data_res = res.data.ptr_mut();
for j1 in 0 .. ncols1.value() {
for j2 in 0 .. ncols2.value() {
for i1 in 0 .. nrows1.value() {
unsafe {
let coeff = *self.get_unchecked(i1, j1);
for i2 in 0 .. nrows2.value() {
*data_res = coeff * *rhs.get_unchecked(i2, j2);
data_res = data_res.offset(1);
}
}
}
}
}
}
res
}
}

View File

@ -12,11 +12,12 @@ use std::fmt::Display;
use alga::linear::FiniteDimInnerSpace;
use na::{DVector, DMatrix,
use na::{U8, U15,
DVector, DMatrix,
Vector1, Vector2, Vector3, Vector4, Vector5, Vector6,
RowVector4,
Matrix1, Matrix2, Matrix3, Matrix4, Matrix5, Matrix6,
Matrix2x3, Matrix3x2, Matrix3x4, Matrix4x3, Matrix2x4, Matrix4x6};
MatrixNM, Matrix2x3, Matrix3x2, Matrix3x4, Matrix4x3, Matrix2x4, Matrix4x5, Matrix4x6};
#[test]
@ -433,6 +434,47 @@ fn components_mut() {
assert_eq!(expected_m3, m3);
}
#[test]
fn kronecker() {
let a = Matrix2x3::new(
11, 12, 13,
21, 22, 23);
let b = Matrix4x5::new(
110, 120, 130, 140, 150,
210, 220, 230, 240, 250,
310, 320, 330, 340, 350,
410, 420, 430, 440, 450);
let expected = MatrixNM::<_, U8, U15>::from_row_slice(&[
1210, 1320, 1430, 1540, 1650, 1320, 1440, 1560, 1680, 1800, 1430, 1560, 1690, 1820, 1950,
2310, 2420, 2530, 2640, 2750, 2520, 2640, 2760, 2880, 3000, 2730, 2860, 2990, 3120, 3250,
3410, 3520, 3630, 3740, 3850, 3720, 3840, 3960, 4080, 4200, 4030, 4160, 4290, 4420, 4550,
4510, 4620, 4730, 4840, 4950, 4920, 5040, 5160, 5280, 5400, 5330, 5460, 5590, 5720, 5850,
2310, 2520, 2730, 2940, 3150, 2420, 2640, 2860, 3080, 3300, 2530, 2760, 2990, 3220, 3450,
4410, 4620, 4830, 5040, 5250, 4620, 4840, 5060, 5280, 5500, 4830, 5060, 5290, 5520, 5750,
6510, 6720, 6930, 7140, 7350, 6820, 7040, 7260, 7480, 7700, 7130, 7360, 7590, 7820, 8050,
8610, 8820, 9030, 9240, 9450, 9020, 9240, 9460, 9680, 9900, 9430, 9660, 9890, 10120, 10350 ]);
let computed = a.kronecker(&b);
assert_eq!(computed, expected);
let a = Vector2::new(1, 2);
let b = Vector3::new(10, 20, 30);
let expected = Vector6::new(10, 20, 30, 20, 40, 60);
assert_eq!(a.kronecker(&b), expected);
let a = Vector2::new(1, 2);
let b = RowVector4::new(10, 20, 30, 40);
let expected = Matrix2x4::new(
10, 20, 30, 40,
20, 40, 60, 80);
assert_eq!(a.kronecker(&b), expected);
}
#[cfg(feature = "arbitrary")]
quickcheck!{
/*