use alga::general::{ClosedAdd, ClosedMul}; use num::{One, Zero}; use std::iter; use std::marker::PhantomData; use std::ops::{Add, Mul, Range}; use std::slice; use allocator::Allocator; use constraint::{AreMultipliable, DimEq, SameNumberOfRows, ShapeConstraint}; use sparse::{CsMatrix, CsStorage, CsVector}; use storage::{Storage, StorageMut}; use {DefaultAllocator, Dim, Matrix, MatrixMN, Real, Scalar, Vector, VectorN, U1}; impl> CsMatrix { pub fn solve_lower_triangular( &self, b: &Matrix, ) -> Option> where S2: Storage, DefaultAllocator: Allocator, ShapeConstraint: SameNumberOfRows, { let mut b = b.clone_owned(); if self.solve_lower_triangular_mut(&mut b) { Some(b) } else { None } } pub fn tr_solve_lower_triangular( &self, b: &Matrix, ) -> Option> where S2: Storage, DefaultAllocator: Allocator, ShapeConstraint: SameNumberOfRows, { let mut b = b.clone_owned(); if self.tr_solve_lower_triangular_mut(&mut b) { Some(b) } else { None } } pub fn solve_lower_triangular_mut( &self, b: &mut Matrix, ) -> bool where S2: StorageMut, ShapeConstraint: SameNumberOfRows, { let (nrows, ncols) = self.data.shape(); assert_eq!(nrows.value(), ncols.value(), "The matrix must be square."); assert_eq!(nrows.value(), b.len(), "Mismatched matrix dimensions."); for j2 in 0..b.ncols() { let mut b = b.column_mut(j2); for j in 0..ncols.value() { let mut column = self.data.column_entries(j); let mut diag_found = false; while let Some((i, val)) = column.next() { if i == j { if val.is_zero() { return false; } b[j] /= val; diag_found = true; break; } } if !diag_found { return false; } for (i, val) in column { b[i] -= b[j] * val; } } } true } pub fn tr_solve_lower_triangular_mut( &self, b: &mut Matrix, ) -> bool where S2: StorageMut, ShapeConstraint: SameNumberOfRows, { let (nrows, ncols) = self.data.shape(); assert_eq!(nrows.value(), ncols.value(), "The matrix must be square."); assert_eq!(nrows.value(), b.len(), "Mismatched matrix dimensions."); for j2 in 0..b.ncols() { let mut b = b.column_mut(j2); for j in (0..ncols.value()).rev() { let mut column = self.data.column_entries(j); let mut diag = None; while let Some((i, val)) = column.next() { if i == j { if val.is_zero() { return false; } diag = Some(val); break; } } if let Some(diag) = diag { for (i, val) in column { b[j] -= val * b[i]; } b[j] /= diag; } else { return false; } } } true } pub fn solve_lower_triangular_cs( &self, b: &CsVector, ) -> Option> where S2: CsStorage, DefaultAllocator: Allocator + Allocator + Allocator, ShapeConstraint: SameNumberOfRows, { let mut reach = Vec::new(); // We don't compute a postordered reach here because it will be sorted after anyway. self.lower_triangular_reach(b, &mut reach); // We sort the reach so the result matrix has sorted indices. reach.sort(); let mut workspace = unsafe { VectorN::new_uninitialized_generic(b.data.shape().0, U1) }; for i in reach.iter().cloned() { workspace[i] = N::zero(); } for (i, val) in b.data.column_entries(0) { workspace[i] = val; } for j in reach.iter().cloned() { let mut column = self.data.column_entries(j); let mut diag_found = false; while let Some((i, val)) = column.next() { if i == j { if val.is_zero() { break; } workspace[j] /= val; diag_found = true; break; } } if !diag_found { return None; } for (i, val) in column { workspace[i] -= workspace[j] * val; } } // Copy the result into a sparse vector. let mut result = CsVector::new_uninitialized_generic(b.data.shape().0, U1, reach.len()); for (i, val) in reach.iter().zip(result.data.vals.iter_mut()) { *val = workspace[*i]; } result.data.i = reach; Some(result) } // Computes the reachable, post-ordered, nodes from `b`. fn lower_triangular_reach_postordered( &self, b: &CsVector, xi: &mut Vec, ) where S2: CsStorage, DefaultAllocator: Allocator, { let mut visited = VectorN::repeat_generic(self.data.shape().1, U1, false); let mut stack = Vec::new(); for i in b.data.column_range(0) { let row_index = b.data.row_index(i); if !visited[row_index] { let rng = self.data.column_range(row_index); stack.push((row_index, rng)); self.lower_triangular_dfs(visited.as_mut_slice(), &mut stack, xi); } } } fn lower_triangular_dfs( &self, visited: &mut [bool], stack: &mut Vec<(usize, Range)>, xi: &mut Vec, ) { 'recursion: while let Some((j, rng)) = stack.pop() { visited[j] = true; for i in rng.clone() { let row_id = self.data.row_index(i); if row_id > j && !visited[row_id] { stack.push((j, (i + 1)..rng.end)); stack.push((row_id, self.data.column_range(row_id))); continue 'recursion; } } xi.push(j) } } // Computes the nodes reachable from `b` in an arbitrary order. fn lower_triangular_reach(&self, b: &CsVector, xi: &mut Vec) where S2: CsStorage, DefaultAllocator: Allocator, { let mut visited = VectorN::repeat_generic(self.data.shape().1, U1, false); let mut stack = Vec::new(); for irow in b.data.column_row_indices(0) { self.lower_triangular_bfs(irow, visited.as_mut_slice(), &mut stack, xi); } } fn lower_triangular_bfs( &self, start: usize, visited: &mut [bool], stack: &mut Vec, xi: &mut Vec, ) { if !visited[start] { stack.clear(); stack.push(start); xi.push(start); visited[start] = true; while let Some(j) = stack.pop() { for irow in self.data.column_row_indices(j) { if irow > j && !visited[irow] { stack.push(irow); xi.push(irow); visited[irow] = true; } } } } } }