From 0a3ee99cdb0dfa9b9202ced71ae5c277dc18b0f4 Mon Sep 17 00:00:00 2001 From: Fredrik Jansson Date: Sun, 12 Apr 2020 11:46:00 +0200 Subject: [PATCH] Changed dimension name R to D Changed N::from_x to crate::convert --- src/linalg/exp.rs | 230 ++++++++++++++++++++++------------------------ 1 file changed, 111 insertions(+), 119 deletions(-) diff --git a/src/linalg/exp.rs b/src/linalg/exp.rs index 8822027b..f68cbe6c 100644 --- a/src/linalg/exp.rs +++ b/src/linalg/exp.rs @@ -6,25 +6,25 @@ use crate::{ dimension::{DimMin, DimMinimum, DimName}, DefaultAllocator, }, - try_convert, ComplexField, MatrixN, RealField, + convert, try_convert, ComplexField, MatrixN, RealField, }; // https://github.com/scipy/scipy/blob/c1372d8aa90a73d8a52f135529293ff4edb98fc8/scipy/sparse/linalg/matfuncs.py -struct ExpmPadeHelper +struct ExpmPadeHelper where N: RealField, - R: DimName + DimMin, - DefaultAllocator: Allocator + Allocator<(usize, usize), DimMinimum>, + D: DimName + DimMin, + DefaultAllocator: Allocator + Allocator<(usize, usize), DimMinimum>, { use_exact_norm: bool, - ident: MatrixN, + ident: MatrixN, - a: MatrixN, - a2: Option>, - a4: Option>, - a6: Option>, - a8: Option>, - a10: Option>, + a: MatrixN, + a2: Option>, + a4: Option>, + a6: Option>, + a8: Option>, + a10: Option>, d4_exact: Option, d6_exact: Option, @@ -37,16 +37,16 @@ where d10_approx: Option, } -impl ExpmPadeHelper +impl ExpmPadeHelper where N: RealField, - R: DimName + DimMin, - DefaultAllocator: Allocator + Allocator<(usize, usize), DimMinimum>, + D: DimName + DimMin, + DefaultAllocator: Allocator + Allocator<(usize, usize), DimMinimum>, { - fn new(a: MatrixN, use_exact_norm: bool) -> Self { + fn new(a: MatrixN, use_exact_norm: bool) -> Self { ExpmPadeHelper { use_exact_norm, - ident: MatrixN::::identity(), + ident: MatrixN::::identity(), a, a2: None, a4: None, @@ -64,9 +64,9 @@ where } } - fn a2(&self) -> &MatrixN { + fn a2(&self) -> &MatrixN { if self.a2.is_none() { - let ap = &self.a2 as *const Option> as *mut Option>; + let ap = &self.a2 as *const Option> as *mut Option>; unsafe { *ap = Some(&self.a * &self.a); } @@ -74,9 +74,9 @@ where self.a2.as_ref().unwrap() } - fn a4(&self) -> &MatrixN { + fn a4(&self) -> &MatrixN { if self.a4.is_none() { - let ap = &self.a4 as *const Option> as *mut Option>; + let ap = &self.a4 as *const Option> as *mut Option>; let a2 = self.a2(); unsafe { *ap = Some(a2 * a2); @@ -85,11 +85,11 @@ where self.a4.as_ref().unwrap() } - fn a6(&self) -> &MatrixN { + fn a6(&self) -> &MatrixN { if self.a6.is_none() { let a2 = self.a2(); let a4 = self.a4(); - let ap = &self.a6 as *const Option> as *mut Option>; + let ap = &self.a6 as *const Option> as *mut Option>; unsafe { *ap = Some(a4 * a2); } @@ -97,11 +97,11 @@ where self.a6.as_ref().unwrap() } - fn a8(&self) -> &MatrixN { + fn a8(&self) -> &MatrixN { if self.a8.is_none() { let a2 = self.a2(); let a6 = self.a6(); - let ap = &self.a8 as *const Option> as *mut Option>; + let ap = &self.a8 as *const Option> as *mut Option>; unsafe { *ap = Some(a6 * a2); } @@ -109,11 +109,11 @@ where self.a8.as_ref().unwrap() } - fn a10(&mut self) -> &MatrixN { + fn a10(&mut self) -> &MatrixN { if self.a10.is_none() { let a4 = self.a4(); let a6 = self.a6(); - let ap = &self.a10 as *const Option> as *mut Option>; + let ap = &self.a10 as *const Option> as *mut Option>; unsafe { *ap = Some(a6 * a4); } @@ -123,28 +123,28 @@ where fn d4_tight(&mut self) -> N { if self.d4_exact.is_none() { - self.d4_exact = Some(one_norm(self.a4()).powf(N::from_f64(0.25).unwrap())); + self.d4_exact = Some(one_norm(self.a4()).powf(convert(0.25))); } self.d4_exact.unwrap() } fn d6_tight(&mut self) -> N { if self.d6_exact.is_none() { - self.d6_exact = Some(one_norm(self.a6()).powf(N::from_f64(1.0 / 6.0).unwrap())); + self.d6_exact = Some(one_norm(self.a6()).powf(convert(1.0 / 6.0))); } self.d6_exact.unwrap() } fn d8_tight(&mut self) -> N { if self.d8_exact.is_none() { - self.d8_exact = Some(one_norm(self.a8()).powf(N::from_f64(1.0 / 8.0).unwrap())); + self.d8_exact = Some(one_norm(self.a8()).powf(convert(1.0 / 8.0))); } self.d8_exact.unwrap() } fn d10_tight(&mut self) -> N { if self.d10_exact.is_none() { - self.d10_exact = Some(one_norm(self.a10()).powf(N::from_f64(1.0 / 10.0).unwrap())); + self.d10_exact = Some(one_norm(self.a10()).powf(convert(1.0 / 10.0))); } self.d10_exact.unwrap() } @@ -159,7 +159,7 @@ where } if self.d4_approx.is_none() { - self.d4_approx = Some(one_norm(self.a4()).powf(N::from_f64(0.25).unwrap())); + self.d4_approx = Some(one_norm(self.a4()).powf(convert(0.25))); } self.d4_approx.unwrap() @@ -175,7 +175,7 @@ where } if self.d6_approx.is_none() { - self.d6_approx = Some(one_norm(self.a6()).powf(N::from_f64(1.0 / 6.0).unwrap())); + self.d6_approx = Some(one_norm(self.a6()).powf(convert(1.0 / 6.0))); } self.d6_approx.unwrap() @@ -191,7 +191,7 @@ where } if self.d8_approx.is_none() { - self.d8_approx = Some(one_norm(self.a8()).powf(N::from_f64(1.0 / 8.0).unwrap())); + self.d8_approx = Some(one_norm(self.a8()).powf(convert(1.0 / 8.0))); } self.d8_approx.unwrap() @@ -207,48 +207,43 @@ where } if self.d10_approx.is_none() { - self.d10_approx = Some(one_norm(self.a10()).powf(N::from_f64(1.0 / 10.0).unwrap())); + self.d10_approx = Some(one_norm(self.a10()).powf(convert(1.0 / 10.0))); } self.d10_approx.unwrap() } - fn pade3(&mut self) -> (MatrixN, MatrixN) { - let b = [ - N::from_f64(120.0).unwrap(), - N::from_f64(60.0).unwrap(), - N::from_f64(12.0).unwrap(), - N::from_f64(1.0).unwrap(), - ]; + fn pade3(&mut self) -> (MatrixN, MatrixN) { + let b: [N; 4] = [convert(120.0), convert(60.0), convert(12.0), convert(1.0)]; let u = &self.a * (self.a2() * b[3] + &self.ident * b[1]); let v = self.a2() * b[2] + &self.ident * b[0]; (u, v) } - fn pade5(&mut self) -> (MatrixN, MatrixN) { - let b = [ - N::from_f64(30240.0).unwrap(), - N::from_f64(15120.0).unwrap(), - N::from_f64(3360.0).unwrap(), - N::from_f64(420.0).unwrap(), - N::from_f64(30.0).unwrap(), - N::from_f64(1.0).unwrap(), + fn pade5(&mut self) -> (MatrixN, MatrixN) { + let b: [N; 6] = [ + convert(30240.0), + convert(15120.0), + convert(3360.0), + convert(420.0), + convert(30.0), + convert(1.0), ]; let u = &self.a * (self.a4() * b[5] + self.a2() * b[3] + &self.ident * b[1]); let v = self.a4() * b[4] + self.a2() * b[2] + &self.ident * b[0]; (u, v) } - fn pade7(&mut self) -> (MatrixN, MatrixN) { - let b = [ - N::from_f64(17297280.0).unwrap(), - N::from_f64(8648640.0).unwrap(), - N::from_f64(1995840.0).unwrap(), - N::from_f64(277200.0).unwrap(), - N::from_f64(25200.0).unwrap(), - N::from_f64(1512.0).unwrap(), - N::from_f64(56.0).unwrap(), - N::from_f64(1.0).unwrap(), + fn pade7(&mut self) -> (MatrixN, MatrixN) { + let b: [N; 8] = [ + convert(17297280.0), + convert(8648640.0), + convert(1995840.0), + convert(277200.0), + convert(25200.0), + convert(1512.0), + convert(56.0), + convert(1.0), ]; let u = &self.a * (self.a6() * b[7] + self.a4() * b[5] + self.a2() * b[3] + &self.ident * b[1]); @@ -256,18 +251,18 @@ where (u, v) } - fn pade9(&mut self) -> (MatrixN, MatrixN) { - let b = [ - N::from_f64(17643225600.0).unwrap(), - N::from_f64(8821612800.0).unwrap(), - N::from_f64(2075673600.0).unwrap(), - N::from_f64(302702400.0).unwrap(), - N::from_f64(30270240.0).unwrap(), - N::from_f64(2162160.0).unwrap(), - N::from_f64(110880.0).unwrap(), - N::from_f64(3960.0).unwrap(), - N::from_f64(90.0).unwrap(), - N::from_f64(1.0).unwrap(), + fn pade9(&mut self) -> (MatrixN, MatrixN) { + let b: [N; 10] = [ + convert(17643225600.0), + convert(8821612800.0), + convert(2075673600.0), + convert(302702400.0), + convert(30270240.0), + convert(2162160.0), + convert(110880.0), + convert(3960.0), + convert(90.0), + convert(1.0), ]; let u = &self.a * (self.a8() * b[9] @@ -283,29 +278,29 @@ where (u, v) } - fn pade13_scaled(&mut self, s: u64) -> (MatrixN, MatrixN) { - let b = [ - N::from_f64(64764752532480000.0).unwrap(), - N::from_f64(32382376266240000.0).unwrap(), - N::from_f64(7771770303897600.0).unwrap(), - N::from_f64(1187353796428800.0).unwrap(), - N::from_f64(129060195264000.0).unwrap(), - N::from_f64(10559470521600.0).unwrap(), - N::from_f64(670442572800.0).unwrap(), - N::from_f64(33522128640.0).unwrap(), - N::from_f64(1323241920.0).unwrap(), - N::from_f64(40840800.0).unwrap(), - N::from_f64(960960.0).unwrap(), - N::from_f64(16380.0).unwrap(), - N::from_f64(182.0).unwrap(), - N::from_f64(1.0).unwrap(), + fn pade13_scaled(&mut self, s: u64) -> (MatrixN, MatrixN) { + let b: [N; 14] = [ + convert(64764752532480000.0), + convert(32382376266240000.0), + convert(7771770303897600.0), + convert(1187353796428800.0), + convert(129060195264000.0), + convert(10559470521600.0), + convert(670442572800.0), + convert(33522128640.0), + convert(1323241920.0), + convert(40840800.0), + convert(960960.0), + convert(16380.0), + convert(182.0), + convert(1.0), ]; let s = s as f64; - let mb = &self.a * N::from_f64(2.0.powf(-s)).unwrap(); - let mb2 = self.a2() * N::from_f64(2.0.powf(-2.0 * s)).unwrap(); - let mb4 = self.a4() * N::from_f64(2.0.powf(-4.0 * s)).unwrap(); - let mb6 = self.a6() * N::from_f64(2.0.powf(-6.0 * s)).unwrap(); + let mb = &self.a * convert::(2.0_f64.powf(-s)); + let mb2 = self.a2() * convert::(2.0_f64.powf(-2.0 * s)); + let mb4 = self.a4() * convert::(2.0.powf(-4.0 * s)); + let mb6 = self.a6() * convert::(2.0.powf(-6.0 * s)); let u2 = &mb6 * (&mb6 * b[13] + &mb4 * b[11] + &mb2 * b[9]); let u = &mb * (&u2 + &mb6 * b[7] + &mb4 * b[5] + &mb2 * b[3] + &self.ident * b[1]); @@ -323,13 +318,13 @@ fn factorial(n: u128) -> u128 { } /// Compute the 1-norm of a non-negative integer power of a non-negative matrix. -fn onenorm_matrix_power_nonm(a: &MatrixN, p: u64) -> N +fn onenorm_matrix_power_nonm(a: &MatrixN, p: u64) -> N where N: RealField, - R: DimName, - DefaultAllocator: Allocator + Allocator, + D: DimName, + DefaultAllocator: Allocator + Allocator, { - let mut v = crate::VectorN::::repeat(N::from_f64(1.0).unwrap()); + let mut v = crate::VectorN::::repeat(convert(1.0)); let m = a.transpose(); for _ in 0..p { @@ -339,11 +334,11 @@ where v.max() } -fn ell(a: &MatrixN, m: u64) -> u64 +fn ell(a: &MatrixN, m: u64) -> u64 where N: RealField, - R: DimName, - DefaultAllocator: Allocator + Allocator, + D: DimName, + DefaultAllocator: Allocator + Allocator, { // 2m choose m = (2m)!/(m! * (2m-m)!) @@ -357,10 +352,10 @@ where factorial(2 * m as u128) / (factorial(m as u128) * factorial(2 * m as u128 - m as u128)); let abs_c_recip = choose_2m_m * factorial(2 * m as u128 + 1); let alpha = a_abs_onenorm / one_norm(a); - let alpha = alpha / N::from_u128(abs_c_recip).unwrap(); + let alpha: f64 = try_convert(alpha).unwrap() / abs_c_recip as f64; - let u = N::from_f64(2_f64.powf(-53.0)).unwrap(); - let log2_alpha_div_u = try_convert((alpha / u).log2()).unwrap(); + let u = 2_f64.powf(-53.0); + let log2_alpha_div_u = (alpha / u).log2(); let value = (log2_alpha_div_u / (2.0 * m as f64)).ceil(); if value > 0.0 { value as u64 @@ -369,11 +364,11 @@ where } } -fn solve_p_q(u: MatrixN, v: MatrixN) -> MatrixN +fn solve_p_q(u: MatrixN, v: MatrixN) -> MatrixN where N: ComplexField, - R: DimMin + DimName, - DefaultAllocator: Allocator + Allocator<(usize, usize), DimMinimum>, + D: DimMin + DimName, + DefaultAllocator: Allocator + Allocator<(usize, usize), DimMinimum>, { let p = &u + &v; let q = &v - &u; @@ -381,11 +376,11 @@ where q.lu().solve(&p).unwrap() } -fn one_norm(m: &MatrixN) -> N +fn one_norm(m: &MatrixN) -> N where N: RealField, - R: DimName, - DefaultAllocator: Allocator, + D: DimName, + DefaultAllocator: Allocator, { let mut col_sums = vec![N::zero(); m.ncols()]; for i in 0..m.ncols() { @@ -399,11 +394,11 @@ where max } -impl MatrixN +impl MatrixN where - R: DimMin + DimName, + D: DimMin + DimName, DefaultAllocator: - Allocator + Allocator<(usize, usize), DimMinimum> + Allocator, + Allocator + Allocator<(usize, usize), DimMinimum> + Allocator, { /// Computes exponential of this matrix pub fn exp(&self) -> Self { @@ -415,30 +410,30 @@ where let mut h = ExpmPadeHelper::new(self.clone(), true); let eta_1 = N::max(h.d4_loose(), h.d6_loose()); - if eta_1 < N::from_f64(1.495585217958292e-002).unwrap() && ell(&h.a, 3) == 0 { + if eta_1 < convert(1.495585217958292e-002) && ell(&h.a, 3) == 0 { let (u, v) = h.pade3(); return solve_p_q(u, v); } let eta_2 = N::max(h.d4_tight(), h.d6_loose()); - if eta_2 < N::from_f64(2.539398330063230e-001).unwrap() && ell(&h.a, 5) == 0 { + if eta_2 < convert(2.539398330063230e-001) && ell(&h.a, 5) == 0 { let (u, v) = h.pade5(); return solve_p_q(u, v); } let eta_3 = N::max(h.d6_tight(), h.d8_loose()); - if eta_3 < N::from_f64(9.504178996162932e-001).unwrap() && ell(&h.a, 7) == 0 { + if eta_3 < convert(9.504178996162932e-001) && ell(&h.a, 7) == 0 { let (u, v) = h.pade7(); return solve_p_q(u, v); } - if eta_3 < N::from_f64(2.097847961257068e+000).unwrap() && ell(&h.a, 9) == 0 { + if eta_3 < convert(2.097847961257068e+000) && ell(&h.a, 9) == 0 { let (u, v) = h.pade9(); return solve_p_q(u, v); } let eta_4 = N::max(h.d8_loose(), h.d10_loose()); let eta_5 = N::min(eta_3, eta_4); - let theta_13 = N::from_f64(4.25).unwrap(); + let theta_13 = convert(4.25); let mut s = if eta_5 == N::zero() { 0 @@ -452,10 +447,7 @@ where } }; - s += ell( - &(&h.a * N::from_f64(2.0_f64.powf(-(s as f64))).unwrap()), - 13, - ); + s += ell(&(&h.a * convert::(2.0_f64.powf(-(s as f64)))), 13); let (u, v) = h.pade13_scaled(s); let mut x = solve_p_q(u, v);