/*
 * This file implements some BLAS operations in such a way that they work
 * even if the first argument (the output parameter) is an uninitialized matrix.
 *
 * Because doing this makes the code harder to read, we only implemented the operations that we
 * know would benefit from this performance-wise, namely, GEMM (which we use for our matrix
 * multiplication code). If we identify other operations like that in the future, we could add
 * them here.
 */
|
|
|
|
#[cfg(feature = "std")]
|
|
use matrixmultiply;
|
|
use num::{One, Zero};
|
|
use simba::scalar::{ClosedAdd, ClosedMul};
|
|
#[cfg(feature = "std")]
|
|
use std::mem;
|
|
|
|
use crate::base::constraint::{
|
|
AreMultipliable, DimEq, SameNumberOfColumns, SameNumberOfRows, ShapeConstraint,
|
|
};
|
|
use crate::base::dimension::{Dim, Dyn, U1};
|
|
use crate::base::storage::{RawStorage, RawStorageMut};
|
|
use crate::base::uninit::InitStatus;
|
|
use crate::base::{Matrix, Scalar, Vector};
|
|
use std::any::TypeId;
|
|
|
|
// # Safety
|
|
// The content of `y` must only contain values for which
|
|
// `Status::assume_init_mut` is sound.
|
|
#[allow(clippy::too_many_arguments)]
|
|
unsafe fn array_axcpy<Status, T>(
|
|
_: Status,
|
|
y: &mut [Status::Value],
|
|
a: T,
|
|
x: &[T],
|
|
c: T,
|
|
beta: T,
|
|
stride1: usize,
|
|
stride2: usize,
|
|
len: usize,
|
|
) where
|
|
Status: InitStatus<T>,
|
|
T: Scalar + Zero + ClosedAdd + ClosedMul,
|
|
{
|
|
for i in 0..len {
|
|
let y = Status::assume_init_mut(y.get_unchecked_mut(i * stride1));
|
|
*y =
|
|
a.clone() * x.get_unchecked(i * stride2).clone() * c.clone() + beta.clone() * y.clone();
|
|
}
|
|
}
|
|
|
|
fn array_axc<Status, T>(
|
|
_: Status,
|
|
y: &mut [Status::Value],
|
|
a: T,
|
|
x: &[T],
|
|
c: T,
|
|
stride1: usize,
|
|
stride2: usize,
|
|
len: usize,
|
|
) where
|
|
Status: InitStatus<T>,
|
|
T: Scalar + Zero + ClosedAdd + ClosedMul,
|
|
{
|
|
for i in 0..len {
|
|
unsafe {
|
|
Status::init(
|
|
y.get_unchecked_mut(i * stride1),
|
|
a.clone() * x.get_unchecked(i * stride2).clone() * c.clone(),
|
|
);
|
|
}
|
|
}
|
|
}
|
|
|
|
/// Computes `y = a * x * c + b * y`.
|
|
///
|
|
/// If `b` is zero, `y` is never read from and may be uninitialized.
|
|
///
|
|
/// # Safety
|
|
/// This is UB if b != 0 and any component of `y` is uninitialized.
|
|
#[inline(always)]
|
|
#[allow(clippy::many_single_char_names)]
|
|
pub unsafe fn axcpy_uninit<Status, T, D1: Dim, D2: Dim, SA, SB>(
|
|
status: Status,
|
|
y: &mut Vector<Status::Value, D1, SA>,
|
|
a: T,
|
|
x: &Vector<T, D2, SB>,
|
|
c: T,
|
|
b: T,
|
|
) where
|
|
T: Scalar + Zero + ClosedAdd + ClosedMul,
|
|
SA: RawStorageMut<Status::Value, D1>,
|
|
SB: RawStorage<T, D2>,
|
|
ShapeConstraint: DimEq<D1, D2>,
|
|
Status: InitStatus<T>,
|
|
{
|
|
assert_eq!(y.nrows(), x.nrows(), "Axcpy: mismatched vector shapes.");
|
|
|
|
let rstride1 = y.strides().0;
|
|
let rstride2 = x.strides().0;
|
|
|
|
// SAFETY: the conversion to slices is OK because we access the
|
|
// elements taking the strides into account.
|
|
let y = y.data.as_mut_slice_unchecked();
|
|
let x = x.data.as_slice_unchecked();
|
|
|
|
if !b.is_zero() {
|
|
array_axcpy(status, y, a, x, c, b, rstride1, rstride2, x.len());
|
|
} else {
|
|
array_axc(status, y, a, x, c, rstride1, rstride2, x.len());
|
|
}
|
|
}
|
|
|
|
/// Computes `y = alpha * a * x + beta * y`, where `a` is a matrix, `x` a vector, and
|
|
/// `alpha, beta` two scalars.
|
|
///
|
|
/// If `beta` is zero, `y` is never read from and may be uninitialized.
|
|
///
|
|
/// # Safety
|
|
/// This is UB if beta != 0 and any component of `y` is uninitialized.
|
|
#[inline(always)]
|
|
pub unsafe fn gemv_uninit<Status, T, D1: Dim, R2: Dim, C2: Dim, D3: Dim, SA, SB, SC>(
|
|
status: Status,
|
|
y: &mut Vector<Status::Value, D1, SA>,
|
|
alpha: T,
|
|
a: &Matrix<T, R2, C2, SB>,
|
|
x: &Vector<T, D3, SC>,
|
|
beta: T,
|
|
) where
|
|
Status: InitStatus<T>,
|
|
T: Scalar + Zero + One + ClosedAdd + ClosedMul,
|
|
SA: RawStorageMut<Status::Value, D1>,
|
|
SB: RawStorage<T, R2, C2>,
|
|
SC: RawStorage<T, D3>,
|
|
ShapeConstraint: DimEq<D1, R2> + AreMultipliable<R2, C2, D3, U1>,
|
|
{
|
|
let dim1 = y.nrows();
|
|
let (nrows2, ncols2) = a.shape();
|
|
let dim3 = x.nrows();
|
|
|
|
assert!(
|
|
ncols2 == dim3 && dim1 == nrows2,
|
|
"Gemv: dimensions mismatch."
|
|
);
|
|
|
|
if ncols2 == 0 {
|
|
if beta.is_zero() {
|
|
y.apply(|e| Status::init(e, T::zero()));
|
|
} else {
|
|
// SAFETY: this is UB if y is uninitialized.
|
|
y.apply(|e| *Status::assume_init_mut(e) *= beta.clone());
|
|
}
|
|
return;
|
|
}
|
|
|
|
// TODO: avoid bound checks.
|
|
let col2 = a.column(0);
|
|
let val = x.vget_unchecked(0).clone();
|
|
|
|
// SAFETY: this is the call that makes this method unsafe: it is UB if Status = Uninit and beta != 0.
|
|
axcpy_uninit(status, y, alpha.clone(), &col2, val, beta);
|
|
|
|
for j in 1..ncols2 {
|
|
let col2 = a.column(j);
|
|
let val = x.vget_unchecked(j).clone();
|
|
|
|
// SAFETY: safe because y was initialized above.
|
|
axcpy_uninit(status, y, alpha.clone(), &col2, val, T::one());
|
|
}
|
|
}
|
|
|
|
/// Computes `y = alpha * a * b + beta * y`, where `a, b, y` are matrices.
|
|
/// `alpha` and `beta` are scalar.
|
|
///
|
|
/// If `beta` is zero, `y` is never read from and may be uninitialized.
|
|
///
|
|
/// # Safety
|
|
/// This is UB if beta != 0 and any component of `y` is uninitialized.
|
|
#[inline(always)]
|
|
pub unsafe fn gemm_uninit<
|
|
Status,
|
|
T,
|
|
R1: Dim,
|
|
C1: Dim,
|
|
R2: Dim,
|
|
C2: Dim,
|
|
R3: Dim,
|
|
C3: Dim,
|
|
SA,
|
|
SB,
|
|
SC,
|
|
>(
|
|
status: Status,
|
|
y: &mut Matrix<Status::Value, R1, C1, SA>,
|
|
alpha: T,
|
|
a: &Matrix<T, R2, C2, SB>,
|
|
b: &Matrix<T, R3, C3, SC>,
|
|
beta: T,
|
|
) where
|
|
Status: InitStatus<T>,
|
|
T: Scalar + Zero + One + ClosedAdd + ClosedMul,
|
|
SA: RawStorageMut<Status::Value, R1, C1>,
|
|
SB: RawStorage<T, R2, C2>,
|
|
SC: RawStorage<T, R3, C3>,
|
|
ShapeConstraint:
|
|
SameNumberOfRows<R1, R2> + SameNumberOfColumns<C1, C3> + AreMultipliable<R2, C2, R3, C3>,
|
|
{
|
|
let ncols1 = y.ncols();
|
|
|
|
#[cfg(feature = "std")]
|
|
{
|
|
// We assume large matrices will be Dyn but small matrices static.
|
|
// We could use matrixmultiply for large statically-sized matrices but the performance
|
|
// threshold to activate it would be different from SMALL_DIM because our code optimizes
|
|
// better for statically-sized matrices.
|
|
if R1::is::<Dyn>()
|
|
|| C1::is::<Dyn>()
|
|
|| R2::is::<Dyn>()
|
|
|| C2::is::<Dyn>()
|
|
|| R3::is::<Dyn>()
|
|
|| C3::is::<Dyn>()
|
|
{
|
|
// matrixmultiply can be used only if the std feature is available.
|
|
let nrows1 = y.nrows();
|
|
let (nrows2, ncols2) = a.shape();
|
|
let (nrows3, ncols3) = b.shape();
|
|
|
|
// Threshold determined empirically.
|
|
const SMALL_DIM: usize = 5;
|
|
|
|
if nrows1 > SMALL_DIM && ncols1 > SMALL_DIM && nrows2 > SMALL_DIM && ncols2 > SMALL_DIM
|
|
{
|
|
assert_eq!(
|
|
ncols2, nrows3,
|
|
"gemm: dimensions mismatch for multiplication."
|
|
);
|
|
assert_eq!(
|
|
(nrows1, ncols1),
|
|
(nrows2, ncols3),
|
|
"gemm: dimensions mismatch for addition."
|
|
);
|
|
|
|
// NOTE: this case should never happen because we enter this
|
|
// codepath only when ncols2 > SMALL_DIM. Though we keep this
|
|
// here just in case if in the future we change the conditions to
|
|
// enter this codepath.
|
|
if ncols2 == 0 {
|
|
// NOTE: we can't just always multiply by beta
|
|
// because we documented the guaranty that `self` is
|
|
// never read if `beta` is zero.
|
|
if beta.is_zero() {
|
|
y.apply(|e| Status::init(e, T::zero()));
|
|
} else {
|
|
// SAFETY: this is UB if Status = Uninit
|
|
y.apply(|e| *Status::assume_init_mut(e) *= beta.clone());
|
|
}
|
|
return;
|
|
}
|
|
|
|
if TypeId::of::<T>() == TypeId::of::<f32>() {
|
|
let (rsa, csa) = a.strides();
|
|
let (rsb, csb) = b.strides();
|
|
let (rsc, csc) = y.strides();
|
|
|
|
matrixmultiply::sgemm(
|
|
nrows2,
|
|
ncols2,
|
|
ncols3,
|
|
mem::transmute_copy(&alpha),
|
|
a.data.ptr() as *const f32,
|
|
rsa as isize,
|
|
csa as isize,
|
|
b.data.ptr() as *const f32,
|
|
rsb as isize,
|
|
csb as isize,
|
|
mem::transmute_copy(&beta),
|
|
y.data.ptr_mut() as *mut f32,
|
|
rsc as isize,
|
|
csc as isize,
|
|
);
|
|
return;
|
|
} else if TypeId::of::<T>() == TypeId::of::<f64>() {
|
|
let (rsa, csa) = a.strides();
|
|
let (rsb, csb) = b.strides();
|
|
let (rsc, csc) = y.strides();
|
|
|
|
matrixmultiply::dgemm(
|
|
nrows2,
|
|
ncols2,
|
|
ncols3,
|
|
mem::transmute_copy(&alpha),
|
|
a.data.ptr() as *const f64,
|
|
rsa as isize,
|
|
csa as isize,
|
|
b.data.ptr() as *const f64,
|
|
rsb as isize,
|
|
csb as isize,
|
|
mem::transmute_copy(&beta),
|
|
y.data.ptr_mut() as *mut f64,
|
|
rsc as isize,
|
|
csc as isize,
|
|
);
|
|
return;
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
for j1 in 0..ncols1 {
|
|
// TODO: avoid bound checks.
|
|
// SAFETY: this is UB if Status = Uninit && beta != 0
|
|
gemv_uninit(
|
|
status,
|
|
&mut y.column_mut(j1),
|
|
alpha.clone(),
|
|
a,
|
|
&b.column(j1),
|
|
beta.clone(),
|
|
);
|
|
}
|
|
}
|