Fix compilation of matrix exponential when targetting no-std.

This commit is contained in:
sebcrozet 2020-04-21 10:19:03 +02:00
parent f30476eb2b
commit f4c0897764

View File

@ -42,7 +42,8 @@ impl<N, D> ExpmPadeHelper<N, D>
where where
N: RealField, N: RealField,
D: DimMin<D>, D: DimMin<D>,
DefaultAllocator: Allocator<N, D, D> + Allocator<(usize, usize), DimMinimum<D, D>>, DefaultAllocator:
Allocator<N, D, D> + Allocator<N, D> + Allocator<(usize, usize), DimMinimum<D, D>>,
{ {
fn new(a: MatrixN<N, D>, use_exact_norm: bool) -> Self { fn new(a: MatrixN<N, D>, use_exact_norm: bool) -> Self {
let (nrows, ncols) = a.data.shape(); let (nrows, ncols) = a.data.shape();
@ -405,15 +406,13 @@ where
D: Dim, D: Dim,
DefaultAllocator: Allocator<N, D, D>, DefaultAllocator: Allocator<N, D, D>,
{ {
let mut col_sums = vec![N::zero(); m.ncols()]; let mut max = N::zero();
for i in 0..m.ncols() { for i in 0..m.ncols() {
let col = m.column(i); let col = m.column(i);
col.iter().for_each(|v| col_sums[i] += v.abs()); max = max.max(col.iter().fold(N::zero(), |a, b| a + b.abs()));
}
let mut max = col_sums[0];
for i in 1..col_sums.len() {
max = N::max(max, col_sums[i]);
} }
max max
} }
@ -427,7 +426,7 @@ where
pub fn exp(&self) -> Self { pub fn exp(&self) -> Self {
// Simple case // Simple case
if self.nrows() == 1 { if self.nrows() == 1 {
return self.clone().map(|v| v.exp()); return self.map(|v| v.exp());
} }
let mut h = ExpmPadeHelper::new(self.clone(), true); let mut h = ExpmPadeHelper::new(self.clone(), true);