From 35d2b6dc88ddac6b18ab303f48d10e1826842f01 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?S=C3=A9bastien=20Crozet?= Date: Wed, 3 May 2017 22:27:05 +0200 Subject: [PATCH] Add kronecker product. Closes #248 --- CHANGELOG.md | 6 ++++++ src/core/ops.rs | 48 ++++++++++++++++++++++++++++++++++++++++++++---- tests/matrix.rs | 46 ++++++++++++++++++++++++++++++++++++++++++++-- 3 files changed, 94 insertions(+), 6 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 0ef19d0e..04479117 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -4,6 +4,12 @@ documented here. This project adheres to [Semantic Versioning](http://semver.org/). +## [0.13.0] - WIP +### Added + * `.kronecker(a, b)` computes the kronecker product (i.e. matrix tensor + product) of two matrices. + + ## [0.12.0] The main change of this release is the update of the dependency serde to 1.0. diff --git a/src/core/ops.rs b/src/core/ops.rs index f310f68b..a530e066 100644 --- a/src/core/ops.rs +++ b/src/core/ops.rs @@ -5,7 +5,7 @@ use num::Zero; use alga::general::{ClosedMul, ClosedDiv, ClosedAdd, ClosedSub, ClosedNeg}; use core::{Scalar, Matrix, OwnedMatrix, MatrixSum, MatrixMul, MatrixTrMul}; -use core::dimension::Dim; +use core::dimension::{Dim, DimMul, DimProd}; use core::constraint::{ShapeConstraint, SameNumberOfRows, SameNumberOfColumns, AreMultipliable}; use core::storage::{Storage, StorageMut, OwnedStorage}; use core::allocator::{SameShapeAllocator, Allocator, OwnedAllocator}; @@ -453,13 +453,15 @@ impl<'b, N, R1, C1, R2, SA, SB> MulAssign<&'b Matrix> for Matrix< } +// Transpose-multiplication. impl Matrix - where N: Scalar + Zero + ClosedAdd + ClosedMul, + where N: Scalar, SA: Storage { /// Equivalent to `self.transpose() * right`. #[inline] pub fn tr_mul(&self, right: &Matrix) -> MatrixTrMul - where SB: Storage, + where N: Zero + ClosedAdd + ClosedMul, + SB: Storage, SA::Alloc: Allocator, ShapeConstraint: AreMultipliable { let (nrows1, ncols1) = self.shape(); @@ -477,7 +479,7 @@ impl Matrix unsafe { for k in 0 .. nrows1 { - acc = acc + *self.get_unchecked(k, i) * *right.get_unchecked(k, j); + acc += *self.get_unchecked(k, i) * *right.get_unchecked(k, j); } *res.get_unchecked_mut(i, j) = acc; @@ -487,4 +489,42 @@ impl Matrix res } + + + /// The kronecker product of two matrices (aka. tensor product of the corresponding linear + /// maps). + pub fn kronecker(&self, rhs: &Matrix) + -> OwnedMatrix, DimProd, SA::Alloc> + where N: ClosedMul, + R1: DimMul, + C1: DimMul, + SB: Storage, + SA::Alloc: Allocator, DimProd> { + let (nrows1, ncols1) = self.data.shape(); + let (nrows2, ncols2) = rhs.data.shape(); + + let mut res: OwnedMatrix<_, _, _, SA::Alloc> = + unsafe { Matrix::new_uninitialized_generic(nrows1.mul(nrows2), ncols1.mul(ncols2)) }; + + { + let mut data_res = res.data.ptr_mut(); + + for j1 in 0 .. ncols1.value() { + for j2 in 0 .. ncols2.value() { + for i1 in 0 .. nrows1.value() { + unsafe { + let coeff = *self.get_unchecked(i1, j1); + + for i2 in 0 .. nrows2.value() { + *data_res = coeff * *rhs.get_unchecked(i2, j2); + data_res = data_res.offset(1); + } + } + } + } + } + } + + res + } } diff --git a/tests/matrix.rs b/tests/matrix.rs index dcef35d7..bc5d44e5 100644 --- a/tests/matrix.rs +++ b/tests/matrix.rs @@ -12,11 +12,12 @@ use std::fmt::Display; use alga::linear::FiniteDimInnerSpace; -use na::{DVector, DMatrix, +use na::{U8, U15, + DVector, DMatrix, Vector1, Vector2, Vector3, Vector4, Vector5, Vector6, RowVector4, Matrix1, Matrix2, Matrix3, Matrix4, Matrix5, Matrix6, - Matrix2x3, Matrix3x2, Matrix3x4, Matrix4x3, Matrix2x4, Matrix4x6}; + MatrixNM, Matrix2x3, Matrix3x2, Matrix3x4, Matrix4x3, Matrix2x4, Matrix4x5, Matrix4x6}; #[test] @@ -433,6 +434,47 @@ fn components_mut() { assert_eq!(expected_m3, m3); } +#[test] +fn kronecker() { + let a = Matrix2x3::new( + 11, 12, 13, + 21, 22, 23); + + let b = Matrix4x5::new( + 110, 120, 130, 140, 150, + 210, 220, 230, 240, 250, + 310, 320, 330, 340, 350, + 410, 420, 430, 440, 450); + + let expected = MatrixNM::<_, U8, U15>::from_row_slice(&[ + 1210, 1320, 1430, 1540, 1650, 1320, 1440, 1560, 1680, 1800, 1430, 1560, 1690, 1820, 1950, + 2310, 2420, 2530, 2640, 2750, 2520, 2640, 2760, 2880, 3000, 2730, 2860, 2990, 3120, 3250, + 3410, 3520, 3630, 3740, 3850, 3720, 3840, 3960, 4080, 4200, 4030, 4160, 4290, 4420, 4550, + 4510, 4620, 4730, 4840, 4950, 4920, 5040, 5160, 5280, 5400, 5330, 5460, 5590, 5720, 5850, + 2310, 2520, 2730, 2940, 3150, 2420, 2640, 2860, 3080, 3300, 2530, 2760, 2990, 3220, 3450, + 4410, 4620, 4830, 5040, 5250, 4620, 4840, 5060, 5280, 5500, 4830, 5060, 5290, 5520, 5750, + 6510, 6720, 6930, 7140, 7350, 6820, 7040, 7260, 7480, 7700, 7130, 7360, 7590, 7820, 8050, + 8610, 8820, 9030, 9240, 9450, 9020, 9240, 9460, 9680, 9900, 9430, 9660, 9890, 10120, 10350 ]); + + let computed = a.kronecker(&b); + + assert_eq!(computed, expected); + + let a = Vector2::new(1, 2); + let b = Vector3::new(10, 20, 30); + let expected = Vector6::new(10, 20, 30, 20, 40, 60); + + assert_eq!(a.kronecker(&b), expected); + + let a = Vector2::new(1, 2); + let b = RowVector4::new(10, 20, 30, 40); + let expected = Matrix2x4::new( + 10, 20, 30, 40, + 20, 40, 60, 80); + + assert_eq!(a.kronecker(&b), expected); +} + #[cfg(feature = "arbitrary")] quickcheck!{ /*