diff --git a/src/linalg/exp.rs b/src/linalg/exp.rs index 8e8227d5..877ce0d3 100644 --- a/src/linalg/exp.rs +++ b/src/linalg/exp.rs @@ -121,28 +121,28 @@ where fn d4_tight(&mut self) -> N { if self.d4_exact.is_none() { - self.d4_exact = Some(self.a4().amax().powf(N::from_f64(0.25).unwrap())); + self.d4_exact = Some(one_norm(self.a4()).powf(N::from_f64(0.25).unwrap())); } self.d4_exact.unwrap() } fn d6_tight(&mut self) -> N { if self.d6_exact.is_none() { - self.d6_exact = Some(self.a6().amax().powf(N::from_f64(1.0 / 6.0).unwrap())); + self.d6_exact = Some(one_norm(self.a6()).powf(N::from_f64(1.0 / 6.0).unwrap())); } self.d6_exact.unwrap() } fn d8_tight(&mut self) -> N { if self.d8_exact.is_none() { - self.d8_exact = Some(self.a8().amax().powf(N::from_f64(1.0 / 8.0).unwrap())); + self.d8_exact = Some(one_norm(self.a8()).powf(N::from_f64(1.0 / 8.0).unwrap())); } self.d8_exact.unwrap() } fn d10_tight(&mut self) -> N { if self.d10_exact.is_none() { - self.d10_exact = Some(self.a10().amax().powf(N::from_f64(1.0 / 10.0).unwrap())); + self.d10_exact = Some(one_norm(self.a10()).powf(N::from_f64(1.0 / 10.0).unwrap())); } self.d10_exact.unwrap() } @@ -157,7 +157,7 @@ where } if self.d4_approx.is_none() { - self.d4_approx = Some(self.a4().amax().powf(N::from_f64(0.25).unwrap())); + self.d4_approx = Some(one_norm(self.a4()).powf(N::from_f64(0.25).unwrap())); } self.d4_approx.unwrap() @@ -173,7 +173,7 @@ where } if self.d6_approx.is_none() { - self.d6_approx = Some(self.a6().amax().powf(N::from_f64(1.0 / 6.0).unwrap())); + self.d6_approx = Some(one_norm(self.a6()).powf(N::from_f64(1.0 / 6.0).unwrap())); } self.d6_approx.unwrap() @@ -189,7 +189,7 @@ where } if self.d8_approx.is_none() { - self.d8_approx = Some(self.a8().amax().powf(N::from_f64(1.0 / 8.0).unwrap())); + self.d8_approx = Some(one_norm(self.a8()).powf(N::from_f64(1.0 / 8.0).unwrap())); } self.d8_approx.unwrap() @@ -205,7 +205,7 @@ where } if self.d10_approx.is_none() { - self.d10_approx = Some(self.a10().amax().powf(N::from_f64(1.0 / 10.0).unwrap())); + self.d10_approx = Some(one_norm(self.a10()).powf(N::from_f64(1.0 / 10.0).unwrap())); } self.d10_approx.unwrap() @@ -333,7 +333,7 @@ where v = &m * v; } - v.amax() + one_norm(&v) } fn ell(a: &MatrixN, m: u64) -> u64 @@ -353,7 +353,7 @@ where let choose_2m_m = factorial(2 * m as u128) / (factorial(m as u128) * factorial(2 * m as u128 - m as u128)); let abs_c_recip = choose_2m_m * factorial(2 * m as u128 + 1); - let alpha = a_abs_onenorm / a.amax(); + let alpha = a_abs_onenorm / one_norm(a); let alpha = alpha / N::from_u128(abs_c_recip).unwrap(); let u = N::from_f64(2_f64.powf(-53.0)).unwrap(); @@ -378,6 +378,24 @@ where q.lu().solve(&p).unwrap() } +pub fn one_norm(m: &MatrixN) -> N +where + N: RealField, + R: DimName, + DefaultAllocator: Allocator, +{ + let mut col_sums = vec![N::zero(); m.ncols()]; + for i in 0..m.ncols() { + let col = m.column(i); + col.iter().for_each(|v| col_sums[i] += v.abs()); + } + let mut max = col_sums[0]; + for i in 1..col_sums.len() { + max = N::max(max, col_sums[i]); + } + max +} + impl + DimName> MatrixN where DefaultAllocator: Allocator + Allocator<(usize, usize), DimMinimum>, @@ -389,7 +407,7 @@ where return self.clone().map(|v| v.exp()); } - let mut h = ExpmPadeHelper::new(self.clone(), true); + let mut h = ExpmPadeHelper::new(self.clone(), false); let eta_1 = N::max(h.d4_loose(), h.d6_loose()); if eta_1 < N::from_f64(1.495585217958292e-002).unwrap() && ell(&h.a, 3) == 0 { @@ -443,3 +461,14 @@ where x } } + +#[cfg(test)] +mod tests { + #[test] + fn one_norm() { + use crate::Matrix3; + let m = Matrix3::new(-3.0, 5.0, 7.0, 2.0, 6.0, 4.0, 0.0, 2.0, 8.0); + + assert_eq!(super::one_norm(&m), 19.0); + } +}