// nalgebra/src/base/blas_uninit.rs

/*
* This file implements some BLAS operations in such a way that they work
* even if the first argument (the output parameter) is an uninitialized matrix.
*
* Because doing this makes the code harder to read, we only implemented the operations that we
* know would benefit from this performance-wise, namely, GEMM (which we use for our matrix
* multiplication code). If we identify other operations like that in the future, we could add
* them here.
*/
#[cfg(feature = "std")]
use matrixmultiply;
use num::{One, Zero};
use simba::scalar::{ClosedAdd, ClosedMul};
#[cfg(feature = "std")]
use std::mem;
use crate::base::constraint::{
AreMultipliable, DimEq, SameNumberOfColumns, SameNumberOfRows, ShapeConstraint,
};
use crate::base::dimension::{Dim, Dyn, U1};
use crate::base::storage::{RawStorage, RawStorageMut};
use crate::base::uninit::InitStatus;
use crate::base::{Matrix, Scalar, Vector};
use std::any::TypeId;
// Computes `y[k * stride1] = a * x[k * stride2] * c + beta * y[k * stride1]`
// for every `k` in `0..len`, reading the previous value of each visited `y`
// element.
//
// # Safety
// - Every element of `y` visited here must be one for which
//   `Status::assume_init_mut` is sound (i.e. it must already hold a value when
//   `Status` tracks uninitialized memory).
// - When `len > 0`, `(len - 1) * stride1` must be in bounds for `y` and
//   `(len - 1) * stride2` must be in bounds for `x`: both slices are indexed
//   without bounds checks.
#[allow(clippy::too_many_arguments)]
unsafe fn array_axcpy<Status, T>(
    _: Status,
    y: &mut [Status::Value],
    a: T,
    x: &[T],
    c: T,
    beta: T,
    stride1: usize,
    stride2: usize,
    len: usize,
) where
    Status: InitStatus<T>,
    T: Scalar + Zero + ClosedAdd + ClosedMul,
{
    let offsets = (0..len).map(|k| (k * stride1, k * stride2));
    for (y_off, x_off) in offsets {
        let x_k = x.get_unchecked(x_off).clone();
        let y_k = Status::assume_init_mut(y.get_unchecked_mut(y_off));
        let scaled_x = a.clone() * x_k * c.clone();
        let scaled_y = beta.clone() * y_k.clone();
        *y_k = scaled_x + scaled_y;
    }
}
// Computes `y[k * stride1] = a * x[k * stride2] * c` for every `k` in `0..len`.
//
// The elements of `y` are only ever passed to `Status::init`, never to
// `Status::assume_init_mut`, so `y` may be uninitialized.
//
// # Safety
// When `len > 0`, `(len - 1) * stride1` must be in bounds for `y` and
// `(len - 1) * stride2` must be in bounds for `x`: both slices are indexed
// without bounds checks. Because this precondition cannot be checked by the
// type system, the function is `unsafe` (a safe function must not be able to
// trigger UB for any argument values).
#[allow(clippy::too_many_arguments)]
unsafe fn array_axc<Status, T>(
    _: Status,
    y: &mut [Status::Value],
    a: T,
    x: &[T],
    c: T,
    stride1: usize,
    stride2: usize,
    len: usize,
) where
    Status: InitStatus<T>,
    T: Scalar + Zero + ClosedAdd + ClosedMul,
{
    for i in 0..len {
        // SAFETY: the caller guarantees both indices are in bounds (see above).
        Status::init(
            y.get_unchecked_mut(i * stride1),
            a.clone() * x.get_unchecked(i * stride2).clone() * c.clone(),
        );
    }
}
2021-08-03 23:26:56 +08:00
/// Computes `y = a * x * c + b * y`.
///
2021-08-03 23:26:56 +08:00
/// If `b` is zero, `y` is never read from and may be uninitialized.
///
2021-08-03 23:26:56 +08:00
/// # Safety
2021-08-04 17:19:57 +08:00
/// This is UB if b != 0 and any component of `y` is uninitialized.
#[inline(always)]
#[allow(clippy::many_single_char_names)]
pub unsafe fn axcpy_uninit<Status, T, D1: Dim, D2: Dim, SA, SB>(
status: Status,
y: &mut Vector<Status::Value, D1, SA>,
a: T,
x: &Vector<T, D2, SB>,
c: T,
b: T,
) where
T: Scalar + Zero + ClosedAdd + ClosedMul,
SA: RawStorageMut<Status::Value, D1>,
SB: RawStorage<T, D2>,
ShapeConstraint: DimEq<D1, D2>,
Status: InitStatus<T>,
{
assert_eq!(y.nrows(), x.nrows(), "Axcpy: mismatched vector shapes.");
let rstride1 = y.strides().0;
let rstride2 = x.strides().0;
// SAFETY: the conversion to slices is OK because we access the
// elements taking the strides into account.
let y = y.data.as_mut_slice_unchecked();
let x = x.data.as_slice_unchecked();
if !b.is_zero() {
array_axcpy(status, y, a, x, c, b, rstride1, rstride2, x.len());
} else {
array_axc(status, y, a, x, c, rstride1, rstride2, x.len());
}
}
2021-08-03 23:26:56 +08:00
/// Computes `y = alpha * a * x + beta * y`, where `a` is a matrix, `x` a vector, and
/// `alpha, beta` two scalars.
///
2021-08-03 23:26:56 +08:00
/// If `beta` is zero, `y` is never read from and may be uninitialized.
///
2021-08-03 23:26:56 +08:00
/// # Safety
2021-08-04 17:19:57 +08:00
/// This is UB if beta != 0 and any component of `y` is uninitialized.
#[inline(always)]
pub unsafe fn gemv_uninit<Status, T, D1: Dim, R2: Dim, C2: Dim, D3: Dim, SA, SB, SC>(
status: Status,
y: &mut Vector<Status::Value, D1, SA>,
alpha: T,
a: &Matrix<T, R2, C2, SB>,
x: &Vector<T, D3, SC>,
beta: T,
) where
Status: InitStatus<T>,
T: Scalar + Zero + One + ClosedAdd + ClosedMul,
SA: RawStorageMut<Status::Value, D1>,
SB: RawStorage<T, R2, C2>,
SC: RawStorage<T, D3>,
ShapeConstraint: DimEq<D1, R2> + AreMultipliable<R2, C2, D3, U1>,
{
let dim1 = y.nrows();
let (nrows2, ncols2) = a.shape();
let dim3 = x.nrows();
assert!(
ncols2 == dim3 && dim1 == nrows2,
"Gemv: dimensions mismatch."
);
if ncols2 == 0 {
if beta.is_zero() {
y.apply(|e| Status::init(e, T::zero()));
} else {
// SAFETY: this is UB if y is uninitialized.
y.apply(|e| *Status::assume_init_mut(e) *= beta.clone());
}
return;
}
// TODO: avoid bound checks.
let col2 = a.column(0);
let val = x.vget_unchecked(0).clone();
// SAFETY: this is the call that makes this method unsafe: it is UB if Status = Uninit and beta != 0.
axcpy_uninit(status, y, alpha.clone(), &col2, val, beta);
for j in 1..ncols2 {
let col2 = a.column(j);
let val = x.vget_unchecked(j).clone();
2021-08-04 17:19:57 +08:00
// SAFETY: safe because y was initialized above.
axcpy_uninit(status, y, alpha.clone(), &col2, val, T::one());
}
}
2021-08-03 23:26:56 +08:00
/// Computes `y = alpha * a * b + beta * y`, where `a, b, y` are matrices.
/// `alpha` and `beta` are scalar.
///
2021-08-03 23:26:56 +08:00
/// If `beta` is zero, `y` is never read from and may be uninitialized.
///
2021-08-03 23:26:56 +08:00
/// # Safety
2021-08-04 17:19:57 +08:00
/// This is UB if beta != 0 and any component of `y` is uninitialized.
#[inline(always)]
pub unsafe fn gemm_uninit<
Status,
T,
R1: Dim,
C1: Dim,
R2: Dim,
C2: Dim,
R3: Dim,
C3: Dim,
SA,
SB,
SC,
>(
status: Status,
y: &mut Matrix<Status::Value, R1, C1, SA>,
alpha: T,
a: &Matrix<T, R2, C2, SB>,
b: &Matrix<T, R3, C3, SC>,
beta: T,
) where
Status: InitStatus<T>,
T: Scalar + Zero + One + ClosedAdd + ClosedMul,
SA: RawStorageMut<Status::Value, R1, C1>,
SB: RawStorage<T, R2, C2>,
SC: RawStorage<T, R3, C3>,
ShapeConstraint:
SameNumberOfRows<R1, R2> + SameNumberOfColumns<C1, C3> + AreMultipliable<R2, C2, R3, C3>,
{
let ncols1 = y.ncols();
#[cfg(feature = "std")]
{
2023-01-14 23:22:27 +08:00
// We assume large matrices will be Dyn but small matrices static.
// We could use matrixmultiply for large statically-sized matrices but the performance
// threshold to activate it would be different from SMALL_DIM because our code optimizes
// better for statically-sized matrices.
2023-01-14 23:22:27 +08:00
if R1::is::<Dyn>()
|| C1::is::<Dyn>()
|| R2::is::<Dyn>()
|| C2::is::<Dyn>()
|| R3::is::<Dyn>()
|| C3::is::<Dyn>()
{
// matrixmultiply can be used only if the std feature is available.
let nrows1 = y.nrows();
let (nrows2, ncols2) = a.shape();
let (nrows3, ncols3) = b.shape();
// Threshold determined empirically.
const SMALL_DIM: usize = 5;
if nrows1 > SMALL_DIM && ncols1 > SMALL_DIM && nrows2 > SMALL_DIM && ncols2 > SMALL_DIM
{
assert_eq!(
ncols2, nrows3,
"gemm: dimensions mismatch for multiplication."
);
assert_eq!(
(nrows1, ncols1),
(nrows2, ncols3),
"gemm: dimensions mismatch for addition."
);
// NOTE: this case should never happen because we enter this
// codepath only when ncols2 > SMALL_DIM. Though we keep this
// here just in case if in the future we change the conditions to
// enter this codepath.
if ncols2 == 0 {
// NOTE: we can't just always multiply by beta
// because we documented the guaranty that `self` is
// never read if `beta` is zero.
if beta.is_zero() {
y.apply(|e| Status::init(e, T::zero()));
} else {
// SAFETY: this is UB if Status = Uninit
y.apply(|e| *Status::assume_init_mut(e) *= beta.clone());
}
return;
}
if TypeId::of::<T>() == TypeId::of::<f32>() {
let (rsa, csa) = a.strides();
let (rsb, csb) = b.strides();
let (rsc, csc) = y.strides();
matrixmultiply::sgemm(
nrows2,
ncols2,
ncols3,
mem::transmute_copy(&alpha),
a.data.ptr() as *const f32,
rsa as isize,
csa as isize,
b.data.ptr() as *const f32,
rsb as isize,
csb as isize,
mem::transmute_copy(&beta),
y.data.ptr_mut() as *mut f32,
rsc as isize,
csc as isize,
);
return;
} else if TypeId::of::<T>() == TypeId::of::<f64>() {
let (rsa, csa) = a.strides();
let (rsb, csb) = b.strides();
let (rsc, csc) = y.strides();
matrixmultiply::dgemm(
nrows2,
ncols2,
ncols3,
mem::transmute_copy(&alpha),
a.data.ptr() as *const f64,
rsa as isize,
csa as isize,
b.data.ptr() as *const f64,
rsb as isize,
csb as isize,
mem::transmute_copy(&beta),
y.data.ptr_mut() as *mut f64,
rsc as isize,
csc as isize,
);
return;
}
}
}
}
for j1 in 0..ncols1 {
// TODO: avoid bound checks.
// SAFETY: this is UB if Status = Uninit && beta != 0
gemv_uninit(
status,
&mut y.column_mut(j1),
alpha.clone(),
a,
&b.column(j1),
beta.clone(),
);
}
}