Add matrix exponential for complex matrices (#744)

Added matrix exponential for complex matrices.
This commit is contained in:
danielschlaugies 2020-07-16 10:29:52 +02:00 committed by GitHub
parent bc70258e5c
commit f9f7169558
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 95 additions and 35 deletions

View File

@ -10,10 +10,12 @@ use crate::{
convert, try_convert, ComplexField, MatrixN, RealField, convert, try_convert, ComplexField, MatrixN, RealField,
}; };
use crate::num::Zero;
// https://github.com/scipy/scipy/blob/c1372d8aa90a73d8a52f135529293ff4edb98fc8/scipy/sparse/linalg/matfuncs.py // https://github.com/scipy/scipy/blob/c1372d8aa90a73d8a52f135529293ff4edb98fc8/scipy/sparse/linalg/matfuncs.py
struct ExpmPadeHelper<N, D> struct ExpmPadeHelper<N, D>
where where
N: RealField, N: ComplexField,
D: DimMin<D>, D: DimMin<D>,
DefaultAllocator: Allocator<N, D, D> + Allocator<(usize, usize), DimMinimum<D, D>>, DefaultAllocator: Allocator<N, D, D> + Allocator<(usize, usize), DimMinimum<D, D>>,
{ {
@ -27,20 +29,20 @@ where
a8: Option<MatrixN<N, D>>, a8: Option<MatrixN<N, D>>,
a10: Option<MatrixN<N, D>>, a10: Option<MatrixN<N, D>>,
d4_exact: Option<N>, d4_exact: Option<N::RealField>,
d6_exact: Option<N>, d6_exact: Option<N::RealField>,
d8_exact: Option<N>, d8_exact: Option<N::RealField>,
d10_exact: Option<N>, d10_exact: Option<N::RealField>,
d4_approx: Option<N>, d4_approx: Option<N::RealField>,
d6_approx: Option<N>, d6_approx: Option<N::RealField>,
d8_approx: Option<N>, d8_approx: Option<N::RealField>,
d10_approx: Option<N>, d10_approx: Option<N::RealField>,
} }
impl<N, D> ExpmPadeHelper<N, D> impl<N, D> ExpmPadeHelper<N, D>
where where
N: RealField, N: ComplexField,
D: DimMin<D>, D: DimMin<D>,
DefaultAllocator: Allocator<N, D, D> + Allocator<(usize, usize), DimMinimum<D, D>>, DefaultAllocator: Allocator<N, D, D> + Allocator<(usize, usize), DimMinimum<D, D>>,
{ {
@ -110,7 +112,7 @@ where
} }
} }
fn d4_tight(&mut self) -> N { fn d4_tight(&mut self) -> N::RealField {
if self.d4_exact.is_none() { if self.d4_exact.is_none() {
self.calc_a4(); self.calc_a4();
self.d4_exact = Some(one_norm(self.a4.as_ref().unwrap()).powf(convert(0.25))); self.d4_exact = Some(one_norm(self.a4.as_ref().unwrap()).powf(convert(0.25)));
@ -118,7 +120,7 @@ where
self.d4_exact.unwrap() self.d4_exact.unwrap()
} }
fn d6_tight(&mut self) -> N { fn d6_tight(&mut self) -> N::RealField {
if self.d6_exact.is_none() { if self.d6_exact.is_none() {
self.calc_a6(); self.calc_a6();
self.d6_exact = Some(one_norm(self.a6.as_ref().unwrap()).powf(convert(1.0 / 6.0))); self.d6_exact = Some(one_norm(self.a6.as_ref().unwrap()).powf(convert(1.0 / 6.0)));
@ -126,7 +128,7 @@ where
self.d6_exact.unwrap() self.d6_exact.unwrap()
} }
fn d8_tight(&mut self) -> N { fn d8_tight(&mut self) -> N::RealField {
if self.d8_exact.is_none() { if self.d8_exact.is_none() {
self.calc_a8(); self.calc_a8();
self.d8_exact = Some(one_norm(self.a8.as_ref().unwrap()).powf(convert(1.0 / 8.0))); self.d8_exact = Some(one_norm(self.a8.as_ref().unwrap()).powf(convert(1.0 / 8.0)));
@ -134,7 +136,7 @@ where
self.d8_exact.unwrap() self.d8_exact.unwrap()
} }
fn d10_tight(&mut self) -> N { fn d10_tight(&mut self) -> N::RealField {
if self.d10_exact.is_none() { if self.d10_exact.is_none() {
self.calc_a10(); self.calc_a10();
self.d10_exact = Some(one_norm(self.a10.as_ref().unwrap()).powf(convert(1.0 / 10.0))); self.d10_exact = Some(one_norm(self.a10.as_ref().unwrap()).powf(convert(1.0 / 10.0)));
@ -142,7 +144,7 @@ where
self.d10_exact.unwrap() self.d10_exact.unwrap()
} }
fn d4_loose(&mut self) -> N { fn d4_loose(&mut self) -> N::RealField {
if self.use_exact_norm { if self.use_exact_norm {
return self.d4_tight(); return self.d4_tight();
} }
@ -159,7 +161,7 @@ where
self.d4_approx.unwrap() self.d4_approx.unwrap()
} }
fn d6_loose(&mut self) -> N { fn d6_loose(&mut self) -> N::RealField {
if self.use_exact_norm { if self.use_exact_norm {
return self.d6_tight(); return self.d6_tight();
} }
@ -176,7 +178,7 @@ where
self.d6_approx.unwrap() self.d6_approx.unwrap()
} }
fn d8_loose(&mut self) -> N { fn d8_loose(&mut self) -> N::RealField {
if self.use_exact_norm { if self.use_exact_norm {
return self.d8_tight(); return self.d8_tight();
} }
@ -193,7 +195,7 @@ where
self.d8_approx.unwrap() self.d8_approx.unwrap()
} }
fn d10_loose(&mut self) -> N { fn d10_loose(&mut self) -> N::RealField {
if self.use_exact_norm { if self.use_exact_norm {
return self.d10_tight(); return self.d10_tight();
} }
@ -359,15 +361,20 @@ where
fn ell<N, D>(a: &MatrixN<N, D>, m: u64) -> u64 fn ell<N, D>(a: &MatrixN<N, D>, m: u64) -> u64
where where
N: RealField, N: ComplexField,
D: Dim, D: Dim,
DefaultAllocator: Allocator<N, D, D> + Allocator<N, D>, DefaultAllocator: Allocator<N, D, D>
+ Allocator<N, D>
+ Allocator<N::RealField, D>
+ Allocator<N::RealField, D, D>,
{ {
// 2m choose m = (2m)!/(m! * (2m-m)!) // 2m choose m = (2m)!/(m! * (2m-m)!)
let a_abs_onenorm = onenorm_matrix_power_nonm(&a.abs(), 2 * m + 1); let a_abs = a.map(|x| x.abs());
if a_abs_onenorm == N::zero() { let a_abs_onenorm = onenorm_matrix_power_nonm(&a_abs, 2 * m + 1);
if a_abs_onenorm == <N as ComplexField>::RealField::zero() {
return 0; return 0;
} }
@ -399,27 +406,33 @@ where
q.lu().solve(&p).unwrap() q.lu().solve(&p).unwrap()
} }
fn one_norm<N, D>(m: &MatrixN<N, D>) -> N fn one_norm<N, D>(m: &MatrixN<N, D>) -> N::RealField
where where
N: RealField, N: ComplexField,
D: Dim, D: Dim,
DefaultAllocator: Allocator<N, D, D>, DefaultAllocator: Allocator<N, D, D>,
{ {
let mut max = N::zero(); let mut max = <N as ComplexField>::RealField::zero();
for i in 0..m.ncols() { for i in 0..m.ncols() {
let col = m.column(i); let col = m.column(i);
max = max.max(col.iter().fold(N::zero(), |a, b| a + b.abs())); max = max.max(
col.iter()
.fold(<N as ComplexField>::RealField::zero(), |a, b| a + b.abs()),
);
} }
max max
} }
impl<N: RealField, D> MatrixN<N, D> impl<N: ComplexField, D> MatrixN<N, D>
where where
D: DimMin<D, Output = D>, D: DimMin<D, Output = D>,
DefaultAllocator: DefaultAllocator: Allocator<N, D, D>
Allocator<N, D, D> + Allocator<(usize, usize), DimMinimum<D, D>> + Allocator<N, D>, + Allocator<(usize, usize), DimMinimum<D, D>>
+ Allocator<N, D>
+ Allocator<N::RealField, D>
+ Allocator<N::RealField, D, D>,
{ {
/// Computes exponential of this matrix /// Computes exponential of this matrix
pub fn exp(&self) -> Self { pub fn exp(&self) -> Self {
@ -430,19 +443,19 @@ where
let mut h = ExpmPadeHelper::new(self.clone(), true); let mut h = ExpmPadeHelper::new(self.clone(), true);
let eta_1 = N::max(h.d4_loose(), h.d6_loose()); let eta_1 = N::RealField::max(h.d4_loose(), h.d6_loose());
if eta_1 < convert(1.495585217958292e-002) && ell(&h.a, 3) == 0 { if eta_1 < convert(1.495585217958292e-002) && ell(&h.a, 3) == 0 {
let (u, v) = h.pade3(); let (u, v) = h.pade3();
return solve_p_q(u, v); return solve_p_q(u, v);
} }
let eta_2 = N::max(h.d4_tight(), h.d6_loose()); let eta_2 = N::RealField::max(h.d4_tight(), h.d6_loose());
if eta_2 < convert(2.539398330063230e-001) && ell(&h.a, 5) == 0 { if eta_2 < convert(2.539398330063230e-001) && ell(&h.a, 5) == 0 {
let (u, v) = h.pade5(); let (u, v) = h.pade5();
return solve_p_q(u, v); return solve_p_q(u, v);
} }
let eta_3 = N::max(h.d6_tight(), h.d8_loose()); let eta_3 = N::RealField::max(h.d6_tight(), h.d8_loose());
if eta_3 < convert(9.504178996162932e-001) && ell(&h.a, 7) == 0 { if eta_3 < convert(9.504178996162932e-001) && ell(&h.a, 7) == 0 {
let (u, v) = h.pade7(); let (u, v) = h.pade7();
return solve_p_q(u, v); return solve_p_q(u, v);
@ -452,11 +465,11 @@ where
return solve_p_q(u, v); return solve_p_q(u, v);
} }
let eta_4 = N::max(h.d8_loose(), h.d10_loose()); let eta_4 = N::RealField::max(h.d8_loose(), h.d10_loose());
let eta_5 = N::min(eta_3, eta_4); let eta_5 = N::RealField::min(eta_3, eta_4);
let theta_13 = convert(4.25); let theta_13 = convert(4.25);
let mut s = if eta_5 == N::zero() { let mut s = if eta_5 == N::RealField::zero() {
0 0
} else { } else {
let l2 = try_convert((eta_5 / theta_13).log2().ceil()).unwrap(); let l2 = try_convert((eta_5 / theta_13).log2().ceil()).unwrap();

View File

@ -126,4 +126,51 @@ mod tests {
assert!(relative_eq!(f, m.exp(), epsilon = 1.0e-7)); assert!(relative_eq!(f, m.exp(), epsilon = 1.0e-7));
} }
#[test]
fn exp_complex() {
use nalgebra::{Complex, ComplexField, DMatrix, DVector, Matrix2, RealField};
{
let z = Matrix2::<Complex<f64>>::zeros();
let identity = Matrix2::<Complex<f64>>::identity();
assert!((z.exp() - identity).norm() < 1e-7);
}
{
let a = Matrix2::<Complex<f64>>::new(
Complex::<f64>::new(0.0, 1.0),
Complex::<f64>::new(0.0, 2.0),
Complex::<f64>::new(0.0, -1.0),
Complex::<f64>::new(0.0, 3.0),
);
let b = Matrix2::<Complex<f64>>::new(
Complex::<f64>::new(0.42645929666726, 1.89217550966333),
Complex::<f64>::new(-2.13721484276556, -0.97811251808259),
Complex::<f64>::new(1.06860742138278, 0.48905625904129),
Complex::<f64>::new(-1.7107555460983, 0.91406299158075),
);
assert!((a.exp() - b).norm() < 1.0e-07);
}
{
let d1 = Complex::<f64>::new(0.0, <f64 as RealField>::pi());
let d2 = Complex::<f64>::new(0.0, <f64 as RealField>::frac_pi_2());
let d3 = Complex::<f64>::new(0.0, <f64 as RealField>::frac_pi_4());
let m = DMatrix::<Complex<f64>>::from_diagonal(&DVector::from_row_slice(&[d1, d2, d3]));
let res = DMatrix::<Complex<f64>>::from_diagonal(&DVector::from_row_slice(&[
d1.exp(),
d2.exp(),
d3.exp(),
]));
assert!((m.exp() - res).norm() < 1e-07);
}
}
} }