/* * This file implements some BLAS operations in such a way that they work * even if the first argument (the output parameter) is an uninitialized matrix. * * Because doing this makes the code harder to read, we only implemented the operations that we * know would benefit from this performance-wise, namely, GEMM (which we use for our matrix * multiplication code). If we identify other operations like that in the future, we could add * them here. */ #[cfg(feature = "std")] use matrixmultiply; use num::{One, Zero}; use simba::scalar::{ClosedAdd, ClosedMul}; #[cfg(feature = "std")] use std::mem; use crate::base::constraint::{ AreMultipliable, DimEq, SameNumberOfColumns, SameNumberOfRows, ShapeConstraint, }; use crate::base::dimension::{Dim, Dyn, U1}; use crate::base::storage::{RawStorage, RawStorageMut}; use crate::base::uninit::InitStatus; use crate::base::{Matrix, Scalar, Vector}; use std::any::TypeId; // # Safety // The content of `y` must only contain values for which // `Status::assume_init_mut` is sound. #[allow(clippy::too_many_arguments)] unsafe fn array_axcpy( _: Status, y: &mut [Status::Value], a: T, x: &[T], c: T, beta: T, stride1: usize, stride2: usize, len: usize, ) where Status: InitStatus, T: Scalar + Zero + ClosedAdd + ClosedMul, { for i in 0..len { let y = Status::assume_init_mut(y.get_unchecked_mut(i * stride1)); *y = a.clone() * x.get_unchecked(i * stride2).clone() * c.clone() + beta.clone() * y.clone(); } } fn array_axc( _: Status, y: &mut [Status::Value], a: T, x: &[T], c: T, stride1: usize, stride2: usize, len: usize, ) where Status: InitStatus, T: Scalar + Zero + ClosedAdd + ClosedMul, { for i in 0..len { unsafe { Status::init( y.get_unchecked_mut(i * stride1), a.clone() * x.get_unchecked(i * stride2).clone() * c.clone(), ); } } } /// Computes `y = a * x * c + b * y`. /// /// If `b` is zero, `y` is never read from and may be uninitialized. /// /// # Safety /// This is UB if b != 0 and any component of `y` is uninitialized. #[inline(always)] #[allow(clippy::many_single_char_names)] pub unsafe fn axcpy_uninit( status: Status, y: &mut Vector, a: T, x: &Vector, c: T, b: T, ) where T: Scalar + Zero + ClosedAdd + ClosedMul, SA: RawStorageMut, SB: RawStorage, ShapeConstraint: DimEq, Status: InitStatus, { assert_eq!(y.nrows(), x.nrows(), "Axcpy: mismatched vector shapes."); let rstride1 = y.strides().0; let rstride2 = x.strides().0; // SAFETY: the conversion to slices is OK because we access the // elements taking the strides into account. let y = y.data.as_mut_slice_unchecked(); let x = x.data.as_slice_unchecked(); if !b.is_zero() { array_axcpy(status, y, a, x, c, b, rstride1, rstride2, x.len()); } else { array_axc(status, y, a, x, c, rstride1, rstride2, x.len()); } } /// Computes `y = alpha * a * x + beta * y`, where `a` is a matrix, `x` a vector, and /// `alpha, beta` two scalars. /// /// If `beta` is zero, `y` is never read from and may be uninitialized. /// /// # Safety /// This is UB if beta != 0 and any component of `y` is uninitialized. #[inline(always)] pub unsafe fn gemv_uninit( status: Status, y: &mut Vector, alpha: T, a: &Matrix, x: &Vector, beta: T, ) where Status: InitStatus, T: Scalar + Zero + One + ClosedAdd + ClosedMul, SA: RawStorageMut, SB: RawStorage, SC: RawStorage, ShapeConstraint: DimEq + AreMultipliable, { let dim1 = y.nrows(); let (nrows2, ncols2) = a.shape(); let dim3 = x.nrows(); assert!( ncols2 == dim3 && dim1 == nrows2, "Gemv: dimensions mismatch." ); if ncols2 == 0 { if beta.is_zero() { y.apply(|e| Status::init(e, T::zero())); } else { // SAFETY: this is UB if y is uninitialized. y.apply(|e| *Status::assume_init_mut(e) *= beta.clone()); } return; } // TODO: avoid bound checks. let col2 = a.column(0); let val = x.vget_unchecked(0).clone(); // SAFETY: this is the call that makes this method unsafe: it is UB if Status = Uninit and beta != 0. axcpy_uninit(status, y, alpha.clone(), &col2, val, beta); for j in 1..ncols2 { let col2 = a.column(j); let val = x.vget_unchecked(j).clone(); // SAFETY: safe because y was initialized above. axcpy_uninit(status, y, alpha.clone(), &col2, val, T::one()); } } /// Computes `y = alpha * a * b + beta * y`, where `a, b, y` are matrices. /// `alpha` and `beta` are scalar. /// /// If `beta` is zero, `y` is never read from and may be uninitialized. /// /// # Safety /// This is UB if beta != 0 and any component of `y` is uninitialized. #[inline(always)] pub unsafe fn gemm_uninit< Status, T, R1: Dim, C1: Dim, R2: Dim, C2: Dim, R3: Dim, C3: Dim, SA, SB, SC, >( status: Status, y: &mut Matrix, alpha: T, a: &Matrix, b: &Matrix, beta: T, ) where Status: InitStatus, T: Scalar + Zero + One + ClosedAdd + ClosedMul, SA: RawStorageMut, SB: RawStorage, SC: RawStorage, ShapeConstraint: SameNumberOfRows + SameNumberOfColumns + AreMultipliable, { let ncols1 = y.ncols(); #[cfg(feature = "std")] { // We assume large matrices will be Dyn but small matrices static. // We could use matrixmultiply for large statically-sized matrices but the performance // threshold to activate it would be different from SMALL_DIM because our code optimizes // better for statically-sized matrices. if R1::is::() || C1::is::() || R2::is::() || C2::is::() || R3::is::() || C3::is::() { // matrixmultiply can be used only if the std feature is available. let nrows1 = y.nrows(); let (nrows2, ncols2) = a.shape(); let (nrows3, ncols3) = b.shape(); // Threshold determined empirically. const SMALL_DIM: usize = 5; if nrows1 > SMALL_DIM && ncols1 > SMALL_DIM && nrows2 > SMALL_DIM && ncols2 > SMALL_DIM { assert_eq!( ncols2, nrows3, "gemm: dimensions mismatch for multiplication." ); assert_eq!( (nrows1, ncols1), (nrows2, ncols3), "gemm: dimensions mismatch for addition." ); // NOTE: this case should never happen because we enter this // codepath only when ncols2 > SMALL_DIM. Though we keep this // here just in case if in the future we change the conditions to // enter this codepath. if ncols2 == 0 { // NOTE: we can't just always multiply by beta // because we documented the guaranty that `self` is // never read if `beta` is zero. if beta.is_zero() { y.apply(|e| Status::init(e, T::zero())); } else { // SAFETY: this is UB if Status = Uninit y.apply(|e| *Status::assume_init_mut(e) *= beta.clone()); } return; } if TypeId::of::() == TypeId::of::() { let (rsa, csa) = a.strides(); let (rsb, csb) = b.strides(); let (rsc, csc) = y.strides(); matrixmultiply::sgemm( nrows2, ncols2, ncols3, mem::transmute_copy(&alpha), a.data.ptr() as *const f32, rsa as isize, csa as isize, b.data.ptr() as *const f32, rsb as isize, csb as isize, mem::transmute_copy(&beta), y.data.ptr_mut() as *mut f32, rsc as isize, csc as isize, ); return; } else if TypeId::of::() == TypeId::of::() { let (rsa, csa) = a.strides(); let (rsb, csb) = b.strides(); let (rsc, csc) = y.strides(); matrixmultiply::dgemm( nrows2, ncols2, ncols3, mem::transmute_copy(&alpha), a.data.ptr() as *const f64, rsa as isize, csa as isize, b.data.ptr() as *const f64, rsb as isize, csb as isize, mem::transmute_copy(&beta), y.data.ptr_mut() as *mut f64, rsc as isize, csc as isize, ); return; } } } } for j1 in 0..ncols1 { // TODO: avoid bound checks. // SAFETY: this is UB if Status = Uninit && beta != 0 gemv_uninit( status, &mut y.column_mut(j1), alpha.clone(), a, &b.column(j1), beta.clone(), ); } }