forked from M-Labs/nalgebra
Add quadform/cmpy/cdpy.
This commit is contained in:
parent
52598de44c
commit
144dfbd555
@ -5,6 +5,10 @@ documented here.
|
||||
This project adheres to [Semantic Versioning](http://semver.org/).
|
||||
|
||||
## [0.14.0] − WIP
|
||||
### Modified
|
||||
* `quadform` has been renamed `quadform_tr`. The new `quadform` method takes
|
||||
the matrix on the right-hand-side instead of the matrix on the
|
||||
left-hand-side of the quadratic form.
|
||||
### Added
|
||||
* The `mint` feature that can be enabled in order to allow conversions from
|
||||
and to types of the [mint](https://crates.io/crates/mint) crate.
|
||||
@ -12,6 +16,10 @@ This project adheres to [Semantic Versioning](http://semver.org/).
|
||||
`::from_element(...)`.
|
||||
* The `.iamin()` methods that returns the index of the vector entry with
|
||||
smallest absolute value.
|
||||
* Add blas-like operations: `cmpy, cdpy` for componentwise multiplicatons and
|
||||
division with scalar factors:
|
||||
- `self <- alpha * self + beta * a * b`
|
||||
- `self <- alpha * self + beta / a * b`
|
||||
* `UnitQuaternion::scaled_rotation_between_axis` and
|
||||
`UnitQuaternion::rotation_between_axis` that take Unit vectors instead of
|
||||
Vector as arguments.
|
||||
|
@ -481,6 +481,35 @@ impl<N, R1: Dim, C1: Dim, S: StorageMut<N, R1, C1>> Matrix<N, R1, C1, S>
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Computes `self = alpha * a.transpose() * b + beta * self`, where `a, b, self` are matrices.
|
||||
/// `alpha` and `beta` are scalar.
|
||||
///
|
||||
/// If `beta` is zero, `self` is never read.
|
||||
#[inline]
|
||||
pub fn gemm_tr<R2: Dim, C2: Dim, R3: Dim, C3: Dim, SB, SC>(&mut self,
|
||||
alpha: N,
|
||||
a: &Matrix<N, R2, C2, SB>,
|
||||
b: &Matrix<N, R3, C3, SC>,
|
||||
beta: N)
|
||||
where N: One,
|
||||
SB: Storage<N, R2, C2>,
|
||||
SC: Storage<N, R3, C3>,
|
||||
ShapeConstraint: SameNumberOfRows<R1, C2> +
|
||||
SameNumberOfColumns<C1, C3> +
|
||||
AreMultipliable<C2, R2, R3, C3> {
|
||||
let (nrows1, ncols1) = self.shape();
|
||||
let (nrows2, ncols2) = a.shape();
|
||||
let (nrows3, ncols3) = b.shape();
|
||||
|
||||
assert_eq!(nrows2, nrows3, "gemm: dimensions mismatch for multiplication.");
|
||||
assert_eq!((nrows1, ncols1), (ncols2, ncols3), "gemm: dimensions mismatch for addition.");
|
||||
|
||||
for j1 in 0 .. ncols1 {
|
||||
// FIXME: avoid bound checks.
|
||||
self.column_mut(j1).gemv_tr(alpha, a, &b.column(j1), beta);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@ -520,8 +549,10 @@ impl<N, R1: Dim, C1: Dim, S: StorageMut<N, R1, C1>> Matrix<N, R1, C1, S>
|
||||
impl<N, D1: Dim, S: StorageMut<N, D1, D1>> SquareMatrix<N, D1, S>
|
||||
where N: Scalar + Zero + One + ClosedAdd + ClosedMul {
|
||||
|
||||
/// Computes the quadratic form `self = alpha * lrs * mid * lhs.transpose() + beta * self`.
|
||||
pub fn quadform_with_workspace<D2, S2, R3, C3, S3, D4, S4>(&mut self,
|
||||
/// Computes the quadratic form `self = alpha * lhs * mid * lhs.transpose() + beta * self`.
|
||||
///
|
||||
/// This uses the provided workspace `work` to avoid allocations for intermediate results.
|
||||
pub fn quadform_tr_with_workspace<D2, S2, R3, C3, S3, D4, S4>(&mut self,
|
||||
work: &mut Vector<N, D2, S2>,
|
||||
alpha: N,
|
||||
lhs: &Matrix<N, R3, C3, S3>,
|
||||
@ -544,8 +575,11 @@ impl<N, D1: Dim, S: StorageMut<N, D1, D1>> SquareMatrix<N, D1, S>
|
||||
}
|
||||
}
|
||||
|
||||
/// Computes the quadratic form `self = alpha * lrs * mid * lhs.transpose() + beta * self`.
|
||||
pub fn quadform<R3, C3, S3, D4, S4>(&mut self,
|
||||
/// Computes the quadratic form `self = alpha * lhs * mid * lhs.transpose() + beta * self`.
|
||||
///
|
||||
/// This allocates a workspace vector of dimension D1 for intermediate results.
|
||||
/// Use `.quadform_tr_with_workspace(...)` instead to avoid allocations.
|
||||
pub fn quadform_tr<R3, C3, S3, D4, S4>(&mut self,
|
||||
alpha: N,
|
||||
lhs: &Matrix<N, R3, C3, S3>,
|
||||
mid: &SquareMatrix<N, D4, S4>,
|
||||
@ -556,6 +590,53 @@ impl<N, D1: Dim, S: StorageMut<N, D1, D1>> SquareMatrix<N, D1, S>
|
||||
ShapeConstraint: DimEq<D1, D1> + DimEq<D1, R3> + DimEq<C3, D4>,
|
||||
DefaultAllocator: Allocator<N, D1> {
|
||||
let mut work = unsafe { Vector::new_uninitialized_generic(self.data.shape().0, U1) };
|
||||
self.quadform_with_workspace(&mut work, alpha, lhs, mid, beta)
|
||||
self.quadform_tr_with_workspace(&mut work, alpha, lhs, mid, beta)
|
||||
}
|
||||
|
||||
/// Computes the quadratic form `self = alpha * rhs.transpose() * mid * rhs + beta * self`.
|
||||
///
|
||||
/// This uses the provided workspace `work` to avoid allocations for intermediate results.
|
||||
pub fn quadform_with_workspace<D2, S2, D3, S3, R4, C4, S4>(&mut self,
|
||||
work: &mut Vector<N, D2, S2>,
|
||||
alpha: N,
|
||||
mid: &SquareMatrix<N, D3, S3>,
|
||||
rhs: &Matrix<N, R4, C4, S4>,
|
||||
beta: N)
|
||||
where D2: Dim, D3: Dim, R4: Dim, C4: Dim,
|
||||
S2: StorageMut<N, D2>,
|
||||
S3: Storage<N, D3, D3>,
|
||||
S4: Storage<N, R4, C4>,
|
||||
ShapeConstraint: DimEq<D3, R4> +
|
||||
DimEq<D1, C4> +
|
||||
DimEq<D2, D3> +
|
||||
AreMultipliable<C4, R4, D2, U1> {
|
||||
work.gemv(N::one(), mid, &rhs.column(0), N::zero());
|
||||
self.column_mut(0).gemv_tr(alpha, &rhs, work, beta);
|
||||
|
||||
for j in 1 .. rhs.ncols() {
|
||||
work.gemv(N::one(), mid, &rhs.column(j), N::zero());
|
||||
self.column_mut(j).gemv_tr(alpha, &rhs, work, beta);
|
||||
}
|
||||
}
|
||||
|
||||
/// Computes the quadratic form `self = alpha * rhs.transpose() * mid * rhs + beta * self`.
|
||||
///
|
||||
/// This allocates a workspace vector of dimension D2 for intermediate results.
|
||||
/// Use `.quadform_with_workspace(...)` instead to avoid allocations.
|
||||
pub fn quadform<D2, S2, R3, C3, S3>(&mut self,
|
||||
alpha: N,
|
||||
mid: &SquareMatrix<N, D2, S2>,
|
||||
rhs: &Matrix<N, R3, C3, S3>,
|
||||
beta: N)
|
||||
where D2: Dim, R3: Dim, C3: Dim,
|
||||
S2: Storage<N, D2, D2>,
|
||||
S3: Storage<N, R3, C3>,
|
||||
ShapeConstraint: DimEq<D2, R3> +
|
||||
DimEq<D1, C3> +
|
||||
AreMultipliable<C3, R3, D2, U1>,
|
||||
DefaultAllocator: Allocator<N, D2> {
|
||||
|
||||
let mut work = unsafe { Vector::new_uninitialized_generic(mid.data.shape().0, U1) };
|
||||
self.quadform_with_workspace(&mut work, alpha, mid, rhs, beta)
|
||||
}
|
||||
}
|
||||
|
@ -1,6 +1,7 @@
|
||||
// Non-convensional componentwise operators.
|
||||
|
||||
use num::Signed;
|
||||
use std::ops::{Add, Mul};
|
||||
use num::{Zero, Signed};
|
||||
|
||||
use alga::general::{ClosedMul, ClosedDiv};
|
||||
|
||||
@ -33,7 +34,7 @@ impl<N: Scalar, R: Dim, C: Dim, S: Storage<N, R, C>> Matrix<N, R, C, S> {
|
||||
}
|
||||
|
||||
macro_rules! component_binop_impl(
|
||||
($($binop: ident, $binop_mut: ident, $binop_assign: ident, $Trait: ident . $op_assign: ident, $desc:expr, $desc_mut:expr);* $(;)*) => {$(
|
||||
($($binop: ident, $binop_mut: ident, $binop_assign: ident, $cbpy: ident, $Trait: ident . $op: ident . $op_assign: ident, $desc:expr, $desc_mut:expr);* $(;)*) => {$(
|
||||
impl<N: Scalar, R1: Dim, C1: Dim, SA: Storage<N, R1, C1>> Matrix<N, R1, C1, SA> {
|
||||
#[doc = $desc]
|
||||
#[inline]
|
||||
@ -60,6 +61,41 @@ macro_rules! component_binop_impl(
|
||||
}
|
||||
|
||||
impl<N: Scalar, R1: Dim, C1: Dim, SA: StorageMut<N, R1, C1>> Matrix<N, R1, C1, SA> {
|
||||
// componentwise binop plus Y.
|
||||
#[inline]
|
||||
pub fn $cbpy<R2, C2, SB, R3, C3, SC>(&mut self, alpha: N, a: &Matrix<N, R2, C2, SB>, b: &Matrix<N, R3, C3, SC>, beta: N)
|
||||
where N: $Trait + Zero + Mul<N, Output = N> + Add<N, Output = N>,
|
||||
R2: Dim, C2: Dim,
|
||||
R3: Dim, C3: Dim,
|
||||
SB: Storage<N, R2, C2>,
|
||||
SC: Storage<N, R3, C3>,
|
||||
ShapeConstraint: SameNumberOfRows<R1, R2> + SameNumberOfColumns<C1, C2> +
|
||||
SameNumberOfRows<R1, R3> + SameNumberOfColumns<C1, C3> {
|
||||
assert_eq!(self.shape(), a.shape(), "Componentwise mul/div: mismatched matrix dimensions.");
|
||||
assert_eq!(self.shape(), b.shape(), "Componentwise mul/div: mismatched matrix dimensions.");
|
||||
|
||||
if beta.is_zero() {
|
||||
for j in 0 .. self.ncols() {
|
||||
for i in 0 .. self.nrows() {
|
||||
unsafe {
|
||||
let res = alpha * a.get_unchecked(i, j).$op(*b.get_unchecked(i, j));
|
||||
*self.get_unchecked_mut(i, j) = res;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
else {
|
||||
for j in 0 .. self.ncols() {
|
||||
for i in 0 .. self.nrows() {
|
||||
unsafe {
|
||||
let res = alpha * a.get_unchecked(i, j).$op(*b.get_unchecked(i, j));
|
||||
*self.get_unchecked_mut(i, j) = beta * *self.get_unchecked(i, j) + res;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[doc = $desc_mut]
|
||||
#[inline]
|
||||
pub fn $binop_assign<R2, C2, SB>(&mut self, rhs: &Matrix<N, R2, C2, SB>)
|
||||
@ -96,9 +132,9 @@ macro_rules! component_binop_impl(
|
||||
);
|
||||
|
||||
component_binop_impl!(
|
||||
component_mul, component_mul_mut, component_mul_assign, ClosedMul.mul_assign,
|
||||
component_mul, component_mul_mut, component_mul_assign, cmpy, ClosedMul.mul.mul_assign,
|
||||
"Componentwise matrix multiplication.", "Mutable, componentwise matrix multiplication.";
|
||||
component_div, component_div_mut, component_div_assign, ClosedDiv.div_assign,
|
||||
component_div, component_div_mut, component_div_assign, cdpy, ClosedDiv.div.div_assign,
|
||||
"Componentwise matrix division.", "Mutable, componentwise matrix division.";
|
||||
// FIXME: add other operators like bitshift, etc. ?
|
||||
);
|
||||
|
@ -74,6 +74,21 @@ quickcheck! {
|
||||
}
|
||||
|
||||
fn quadform(n: usize, alpha: f64, beta: f64) -> bool {
|
||||
let n = cmp::max(1, cmp::min(n, 50));
|
||||
let rhs = DMatrix::<f64>::new_random(6, n);
|
||||
let mid = DMatrix::<f64>::new_random(6, 6);
|
||||
let mut res = DMatrix::new_random(n, n);
|
||||
|
||||
let expected = &res * beta + rhs.transpose() * &mid * &rhs * alpha;
|
||||
|
||||
res.quadform(alpha, &mid, &rhs, beta);
|
||||
|
||||
println!("{}{}", res, expected);
|
||||
|
||||
relative_eq!(res, expected, epsilon = 1.0e-7)
|
||||
}
|
||||
|
||||
fn quadform_tr(n: usize, alpha: f64, beta: f64) -> bool {
|
||||
let n = cmp::max(1, cmp::min(n, 50));
|
||||
let lhs = DMatrix::<f64>::new_random(6, n);
|
||||
let mid = DMatrix::<f64>::new_random(n, n);
|
||||
@ -81,7 +96,7 @@ quickcheck! {
|
||||
|
||||
let expected = &res * beta + &lhs * &mid * lhs.transpose() * alpha;
|
||||
|
||||
res.quadform(alpha, &lhs, &mid , beta);
|
||||
res.quadform_tr(alpha, &lhs, &mid , beta);
|
||||
|
||||
println!("{}{}", res, expected);
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user