forked from M-Labs/nalgebra
Fix Vector::axpy for noncommutative cases
One example would be performing simple matrix multiplication over a division algebra such as quaternions.
This commit is contained in:
parent
bde8fbe10f
commit
d0fa79f6e1
@ -473,7 +473,7 @@ where N: Scalar + Zero + ClosedAdd + ClosedMul {
|
|||||||
for i in 0..len {
|
for i in 0..len {
|
||||||
unsafe {
|
unsafe {
|
||||||
let y = y.get_unchecked_mut(i * stride1);
|
let y = y.get_unchecked_mut(i * stride1);
|
||||||
*y = a * *x.get_unchecked(i * stride2) + beta * *y;
|
*y = *x.get_unchecked(i * stride2) * a + *y * beta;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -482,7 +482,7 @@ fn array_ax<N>(y: &mut [N], a: N, x: &[N], stride1: usize, stride2: usize, len:
|
|||||||
where N: Scalar + Zero + ClosedAdd + ClosedMul {
|
where N: Scalar + Zero + ClosedAdd + ClosedMul {
|
||||||
for i in 0..len {
|
for i in 0..len {
|
||||||
unsafe {
|
unsafe {
|
||||||
*y.get_unchecked_mut(i * stride1) = a * *x.get_unchecked(i * stride2);
|
*y.get_unchecked_mut(i * stride1) = *x.get_unchecked(i * stride2) * a;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -579,13 +579,13 @@ where
|
|||||||
// FIXME: avoid bound checks.
|
// FIXME: avoid bound checks.
|
||||||
let col2 = a.column(0);
|
let col2 = a.column(0);
|
||||||
let val = unsafe { *x.vget_unchecked(0) };
|
let val = unsafe { *x.vget_unchecked(0) };
|
||||||
self.axpy(alpha * val, &col2, beta);
|
self.axpy(val * alpha, &col2, beta);
|
||||||
|
|
||||||
for j in 1..ncols2 {
|
for j in 1..ncols2 {
|
||||||
let col2 = a.column(j);
|
let col2 = a.column(j);
|
||||||
let val = unsafe { *x.vget_unchecked(j) };
|
let val = unsafe { *x.vget_unchecked(j) };
|
||||||
|
|
||||||
self.axpy(alpha * val, &col2, N::one());
|
self.axpy(val * alpha, &col2, N::one());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -624,7 +624,7 @@ where
|
|||||||
// FIXME: avoid bound checks.
|
// FIXME: avoid bound checks.
|
||||||
let col2 = a.column(0);
|
let col2 = a.column(0);
|
||||||
let val = unsafe { *x.vget_unchecked(0) };
|
let val = unsafe { *x.vget_unchecked(0) };
|
||||||
self.axpy(alpha * val, &col2, beta);
|
self.axpy(val * alpha, &col2, beta);
|
||||||
self[0] += alpha * dot(&a.slice_range(1.., 0), &x.rows_range(1..));
|
self[0] += alpha * dot(&a.slice_range(1.., 0), &x.rows_range(1..));
|
||||||
|
|
||||||
for j in 1..dim2 {
|
for j in 1..dim2 {
|
||||||
@ -637,7 +637,7 @@ where
|
|||||||
*self.vget_unchecked_mut(j) += alpha * dot;
|
*self.vget_unchecked_mut(j) += alpha * dot;
|
||||||
}
|
}
|
||||||
self.rows_range_mut(j + 1..)
|
self.rows_range_mut(j + 1..)
|
||||||
.axpy(alpha * val, &col2.rows_range(j + 1..), N::one());
|
.axpy(val * alpha, &col2.rows_range(j + 1..), N::one());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -890,7 +890,7 @@ where N: Scalar + Zero + ClosedAdd + ClosedMul
|
|||||||
for j in 0..ncols1 {
|
for j in 0..ncols1 {
|
||||||
// FIXME: avoid bound checks.
|
// FIXME: avoid bound checks.
|
||||||
let val = unsafe { conjugate(*y.vget_unchecked(j)) };
|
let val = unsafe { conjugate(*y.vget_unchecked(j)) };
|
||||||
self.column_mut(j).axpy(alpha * val, x, beta);
|
self.column_mut(j).axpy(val * alpha, x, beta);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -1256,7 +1256,7 @@ where N: Scalar + Zero + ClosedAdd + ClosedMul
|
|||||||
let subdim = Dynamic::new(dim1 - j);
|
let subdim = Dynamic::new(dim1 - j);
|
||||||
// FIXME: avoid bound checks.
|
// FIXME: avoid bound checks.
|
||||||
self.generic_slice_mut((j, j), (subdim, U1)).axpy(
|
self.generic_slice_mut((j, j), (subdim, U1)).axpy(
|
||||||
alpha * val,
|
val * alpha,
|
||||||
&x.rows_range(j..),
|
&x.rows_range(j..),
|
||||||
beta,
|
beta,
|
||||||
);
|
);
|
||||||
|
@ -1,5 +1,28 @@
|
|||||||
#![cfg(feature = "arbitrary")]
|
use na::{geometry::Quaternion, Matrix2, Vector3};
|
||||||
|
use num_traits::{One, Zero};
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn gemm_noncommutative() {
|
||||||
|
type Qf64 = Quaternion<f64>;
|
||||||
|
let i = Qf64::from_imag(Vector3::new(1.0, 0.0, 0.0));
|
||||||
|
let j = Qf64::from_imag(Vector3::new(0.0, 1.0, 0.0));
|
||||||
|
let k = Qf64::from_imag(Vector3::new(0.0, 0.0, 1.0));
|
||||||
|
|
||||||
|
let m1 = Matrix2::new(k, Qf64::zero(), j, i);
|
||||||
|
// this is the inverse of m1
|
||||||
|
let m2 = Matrix2::new(-k, Qf64::zero(), Qf64::one(), -i);
|
||||||
|
|
||||||
|
let mut res: Matrix2<Qf64> = Matrix2::zero();
|
||||||
|
res.gemm(Qf64::one(), &m1, &m2, Qf64::zero());
|
||||||
|
assert_eq!(res, Matrix2::identity());
|
||||||
|
|
||||||
|
let mut res: Matrix2<Qf64> = Matrix2::identity();
|
||||||
|
res.gemm(k, &m1, &m2, -k);
|
||||||
|
assert_eq!(res, Matrix2::zero());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(feature = "arbitrary")]
|
||||||
|
mod blas_quickcheck {
|
||||||
use na::{DMatrix, DVector};
|
use na::{DMatrix, DVector};
|
||||||
use std::cmp;
|
use std::cmp;
|
||||||
|
|
||||||
@ -103,3 +126,4 @@ quickcheck! {
|
|||||||
relative_eq!(res, expected, epsilon = 1.0e-7)
|
relative_eq!(res, expected, epsilon = 1.0e-7)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user