use std::uint::iterate; use std::num::{One, Zero}; use std::vec::{from_elem, swap}; use std::cmp::ApproxEq; use std::iterator::IteratorUtil; use traits::inv::Inv; use traits::division_ring::DivisionRing; use traits::transpose::Transpose; use traits::workarounds::rlmul::{RMul, LMul}; use ndim::dvec::{DVec, zero_vec_with_dim}; #[deriving(Eq, ToStr, Clone)] pub struct DMat { dim: uint, // FIXME: handle more than just square matrices mij: ~[N] } pub fn zero_mat_with_dim(dim: uint) -> DMat { DMat { dim: dim, mij: from_elem(dim * dim, Zero::zero()) } } pub fn is_zero_mat(mat: &DMat) -> bool { mat.mij.all(|e| e.is_zero()) } pub fn one_mat_with_dim(dim: uint) -> DMat { let mut res = zero_mat_with_dim(dim); let _1 = One::one::(); for iterate(0u, dim) |i| { res.set(i, i, &_1); } res } impl DMat { pub fn offset(&self, i: uint, j: uint) -> uint { i * self.dim + j } pub fn set(&mut self, i: uint, j: uint, t: &N) { assert!(i < self.dim); assert!(j < self.dim); self.mij[self.offset(i, j)] = *t } pub fn at(&self, i: uint, j: uint) -> N { assert!(i < self.dim); assert!(j < self.dim); self.mij[self.offset(i, j)] } } impl Index<(uint, uint), N> for DMat { fn index(&self, &(i, j): &(uint, uint)) -> N { self.at(i, j) } } impl + Add + Zero> Mul, DMat> for DMat { fn mul(&self, other: &DMat) -> DMat { assert!(self.dim == other.dim); let dim = self.dim; let mut res = zero_mat_with_dim(dim); for iterate(0u, dim) |i| { for iterate(0u, dim) |j| { let mut acc = Zero::zero::(); for iterate(0u, dim) |k| { acc += self.at(i, k) * other.at(k, j); } res.set(i, j, &acc); } } res } } impl + Mul + Zero> RMul> for DMat { fn rmul(&self, other: &DVec) -> DVec { assert!(self.dim == other.at.len()); let dim = self.dim; let mut res : DVec = zero_vec_with_dim(dim); for iterate(0u, dim) |i| { for iterate(0u, dim) |j| { res.at[i] = res.at[i] + other.at[j] * self.at(i, j); } } res } } impl + Mul + Zero> LMul> for DMat { fn lmul(&self, other: &DVec) -> DVec { assert!(self.dim == other.at.len()); let dim = self.dim; let mut res : DVec = zero_vec_with_dim(dim); for iterate(0u, dim) |i| { for iterate(0u, dim) |j| { res.at[i] = res.at[i] + other.at[j] * self.at(j, i); } } res } } impl Inv for DMat { fn inverse(&self) -> DMat { let mut res : DMat = self.clone(); res.invert(); res } fn invert(&mut self) { let dim = self.dim; let mut res = one_mat_with_dim::(dim); let _0T = Zero::zero::(); // inversion using Gauss-Jordan elimination for iterate(0u, dim) |k| { // search a non-zero value on the k-th column // FIXME: would it be worth it to spend some more time searching for the // max instead? let mut n0 = k; // index of a non-zero entry while (n0 != dim) { if (self.at(n0, k) != _0T) { break; } n0 += 1; } assert!(n0 != dim); // non inversible matrix // swap pivot line if (n0 != k) { for iterate(0u, dim) |j| { let off_n0_j = self.offset(n0, j); let off_k_j = self.offset(k, j); swap(self.mij, off_n0_j, off_k_j); swap(res.mij, off_n0_j, off_k_j); } } let pivot = self.at(k, k); for iterate(k, dim) |j| { let selfval = &(self.at(k, j) / pivot); self.set(k, j, selfval); } for iterate(0u, dim) |j| { let resval = &(res.at(k, j) / pivot); res.set(k, j, resval); } for iterate(0u, dim) |l| { if (l != k) { let normalizer = self.at(l, k); for iterate(k, dim) |j| { let selfval = &(self.at(l, j) - self.at(k, j) * normalizer); self.set(l, j, selfval); } for iterate(0u, dim) |j| { let resval = &(res.at(l, j) - res.at(k, j) * normalizer); res.set(l, j, resval); } } } } *self = res; } } impl Transpose for DMat { fn transposed(&self) -> DMat { let mut res = copy *self; res.transpose(); res } fn transpose(&mut self) { let dim = self.dim; for iterate(1u, dim) |i| { for iterate(0u, dim - 1) |j| { let off_i_j = self.offset(i, j); let off_j_i = self.offset(j, i); swap(self.mij, off_i_j, off_j_i); } } } } impl> ApproxEq for DMat { fn approx_epsilon() -> N { ApproxEq::approx_epsilon::() } fn approx_eq(&self, other: &DMat) -> bool { let mut zip = self.mij.iter().zip(other.mij.iter()); do zip.all |(a, b)| { a.approx_eq(b) } } fn approx_eq_eps(&self, other: &DMat, epsilon: &N) -> bool { let mut zip = self.mij.iter().zip(other.mij.iter()); do zip.all |(a, b)| { a.approx_eq_eps(b, epsilon) } } }