Merge pull request #992 from MaxVerevkin/exp-rs

exp.rs: factorial(): use precomputed factorial array
This commit is contained in:
Sébastien Crozet 2021-09-25 12:32:11 +02:00 committed by GitHub
commit dd8b6800f5
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed file with 56 additions and 13 deletions

View File

@ -11,6 +11,47 @@ use crate::{
use crate::num::Zero; use crate::num::Zero;
/// Lookup table where entry `n` holds `n!`, for every `n` in `0..=34`.
///
/// The table stops at 34 because `35!` does not fit into a `u128`.
// TODO: find a better place for this array?
const FACTORIAL: [u128; 35] = {
    // Fill the table during constant evaluation using the recurrence
    // `n! = n * (n - 1)!`, starting from `0! = 1`. An overflowing product
    // would be rejected at compile time instead of yielding a wrong entry.
    let mut table = [1u128; 35];
    let mut n = 1;
    while n < table.len() {
        table[n] = table[n - 1] * n as u128;
        n += 1;
    }
    table
};
// https://github.com/scipy/scipy/blob/c1372d8aa90a73d8a52f135529293ff4edb98fc8/scipy/sparse/linalg/matfuncs.py // https://github.com/scipy/scipy/blob/c1372d8aa90a73d8a52f135529293ff4edb98fc8/scipy/sparse/linalg/matfuncs.py
struct ExpmPadeHelper<T, D> struct ExpmPadeHelper<T, D>
where where
@ -321,8 +362,8 @@ where
self.calc_a2(); self.calc_a2();
self.calc_a4(); self.calc_a4();
self.calc_a6(); self.calc_a6();
let mb2 = self.a2.as_ref().unwrap() * convert::<f64, T>(2.0_f64.powf(-2.0 * s.clone())); let mb2 = self.a2.as_ref().unwrap() * convert::<f64, T>(2.0_f64.powf(-2.0 * s));
let mb4 = self.a4.as_ref().unwrap() * convert::<f64, T>(2.0.powf(-4.0 * s.clone())); let mb4 = self.a4.as_ref().unwrap() * convert::<f64, T>(2.0.powf(-4.0 * s));
let mb6 = self.a6.as_ref().unwrap() * convert::<f64, T>(2.0.powf(-6.0 * s)); let mb6 = self.a6.as_ref().unwrap() * convert::<f64, T>(2.0.powf(-6.0 * s));
let u2 = &mb6 * (&mb6 * b[13].clone() + &mb4 * b[11].clone() + &mb2 * b[9].clone()); let u2 = &mb6 * (&mb6 * b[13].clone() + &mb4 * b[11].clone() + &mb2 * b[9].clone());
@ -342,15 +383,17 @@ where
} }
} }
fn factorial(n: u128) -> u128 { /// Compute `n!`
if n == 1 { #[inline(always)]
return 1; fn factorial(n: usize) -> u128 {
match FACTORIAL.get(n) {
Some(f) => *f,
None => panic!("{}! is greater than u128::MAX", n),
} }
n * factorial(n - 1)
} }
/// Compute the 1-norm of a non-negative integer power of a non-negative matrix. /// Compute the 1-norm of a non-negative integer power of a non-negative matrix.
fn onenorm_matrix_power_nonm<T, D>(a: &OMatrix<T, D, D>, p: u64) -> T fn onenorm_matrix_power_nonm<T, D>(a: &OMatrix<T, D, D>, p: usize) -> T
where where
T: RealField, T: RealField,
D: Dim, D: Dim,
@ -367,7 +410,7 @@ where
v.max() v.max()
} }
fn ell<T, D>(a: &OMatrix<T, D, D>, m: u64) -> u64 fn ell<T, D>(a: &OMatrix<T, D, D>, m: usize) -> u64
where where
T: ComplexField, T: ComplexField,
D: Dim, D: Dim,
@ -376,8 +419,6 @@ where
+ Allocator<T::RealField, D> + Allocator<T::RealField, D>
+ Allocator<T::RealField, D, D>, + Allocator<T::RealField, D, D>,
{ {
// 2m choose m = (2m)!/(m! * (2m-m)!)
let a_abs = a.map(|x| x.abs()); let a_abs = a.map(|x| x.abs());
let a_abs_onenorm = onenorm_matrix_power_nonm(&a_abs, 2 * m + 1); let a_abs_onenorm = onenorm_matrix_power_nonm(&a_abs, 2 * m + 1);
@ -386,9 +427,11 @@ where
return 0; return 0;
} }
let choose_2m_m = // 2m choose m = (2m)!/(m! * (2m-m)!) = (2m)!/((m!)^2)
factorial(2 * m as u128) / (factorial(m as u128) * factorial(2 * m as u128 - m as u128)); let m_factorial = factorial(m);
let abs_c_recip = choose_2m_m * factorial(2 * m as u128 + 1); let choose_2m_m = factorial(2 * m) / (m_factorial * m_factorial);
let abs_c_recip = choose_2m_m * factorial(2 * m + 1);
let alpha = a_abs_onenorm / one_norm(a); let alpha = a_abs_onenorm / one_norm(a);
let alpha: f64 = try_convert(alpha).unwrap() / abs_c_recip as f64; let alpha: f64 = try_convert(alpha).unwrap() / abs_c_recip as f64;