Tied some blas loose strings
This commit is contained in:
parent
df9b6f5f64
commit
54e9750191
@ -6,7 +6,7 @@
|
||||
//! that return an owned matrix that would otherwise result from setting a
|
||||
//! parameter to zero in the other methods.
|
||||
|
||||
use crate::{OMatrix, OVector, SimdComplexField};
|
||||
use crate::{OMatrix, SimdComplexField};
|
||||
#[cfg(feature = "std")]
|
||||
use matrixmultiply;
|
||||
use num::{One, Zero};
|
||||
@ -795,7 +795,7 @@ where
|
||||
}
|
||||
}
|
||||
|
||||
impl<T, R1: Dim, C1: Dim> OMatrix<T, R1, C1>
|
||||
impl<T, R1: Dim, C1: Dim, S: StorageMut<MaybeUninit<T>, R1, C1>> Matrix<MaybeUninit<T>, R1, C1, S>
|
||||
where
|
||||
T: Scalar + Zero + One + ClosedAdd + ClosedMul,
|
||||
DefaultAllocator: Allocator<T, R1, C1>,
|
||||
@ -821,27 +821,18 @@ where
|
||||
/// ```
|
||||
#[inline]
|
||||
pub fn gemm_z<R2: Dim, C2: Dim, R3: Dim, C3: Dim, SB, SC>(
|
||||
&mut self,
|
||||
alpha: T,
|
||||
a: &Matrix<T, R2, C2, SB>,
|
||||
b: &Matrix<T, R3, C3, SC>,
|
||||
) -> Self
|
||||
where
|
||||
) where
|
||||
SB: Storage<T, R2, C2>,
|
||||
SC: Storage<T, R3, C3>,
|
||||
ShapeConstraint: SameNumberOfRows<R1, R2>
|
||||
+ SameNumberOfColumns<C1, C3>
|
||||
+ AreMultipliable<R2, C2, R3, C3>,
|
||||
{
|
||||
let (nrows1, ncols1) = a.shape();
|
||||
let (nrows2, ncols2) = b.shape();
|
||||
|
||||
assert_eq!(
|
||||
ncols1, nrows2,
|
||||
"gemm: dimensions mismatch for multiplication."
|
||||
);
|
||||
|
||||
let mut res =
|
||||
Matrix::new_uninitialized_generic(R1::from_usize(nrows1), C1::from_usize(ncols2));
|
||||
let ncols1 = self.ncols();
|
||||
|
||||
#[cfg(feature = "std")]
|
||||
{
|
||||
@ -857,6 +848,9 @@ where
|
||||
|| C3::is::<Dynamic>()
|
||||
{
|
||||
// matrixmultiply can be used only if the std feature is available.
|
||||
let nrows1 = self.nrows();
|
||||
let (nrows2, ncols2) = a.shape();
|
||||
let (nrows3, ncols3) = b.shape();
|
||||
|
||||
// Threshold determined empirically.
|
||||
const SMALL_DIM: usize = 5;
|
||||
@ -866,29 +860,35 @@ where
|
||||
&& nrows2 > SMALL_DIM
|
||||
&& ncols2 > SMALL_DIM
|
||||
{
|
||||
assert_eq!(
|
||||
ncols1, nrows2,
|
||||
"gemm: dimensions mismatch for multiplication."
|
||||
);
|
||||
assert_eq!(
|
||||
(nrows1, ncols1),
|
||||
(nrows2, ncols3),
|
||||
"gemm: dimensions mismatch for addition."
|
||||
);
|
||||
|
||||
// NOTE: this case should never happen because we enter this
|
||||
// codepath only when ncols2 > SMALL_DIM. Though we keep this
|
||||
// here just in case if in the future we change the conditions to
|
||||
// enter this codepath.
|
||||
if ncols1 == 0 {
|
||||
// NOTE: we can't just always multiply by beta
|
||||
// because we documented the guaranty that `self` is
|
||||
// never read if `beta` is zero.
|
||||
|
||||
// Safety: this buffer is empty.
|
||||
return res.assume_init();
|
||||
self.fill_fn(|| MaybeUninit::new(T::zero()));
|
||||
return;
|
||||
}
|
||||
|
||||
let (rsa, csa) = a.strides();
|
||||
let (rsb, csb) = b.strides();
|
||||
let (rsc, csc) = res.strides();
|
||||
let (rsc, csc) = self.strides();
|
||||
|
||||
if T::is::<f32>() {
|
||||
unsafe {
|
||||
matrixmultiply::sgemm(
|
||||
nrows1,
|
||||
ncols1,
|
||||
nrows2,
|
||||
ncols2,
|
||||
ncols3,
|
||||
mem::transmute_copy(&alpha),
|
||||
a.data.ptr() as *const f32,
|
||||
rsa as isize,
|
||||
@ -897,19 +897,19 @@ where
|
||||
rsb as isize,
|
||||
csb as isize,
|
||||
0.0,
|
||||
res.data.ptr_mut() as *mut f32,
|
||||
self.data.ptr_mut() as *mut f32,
|
||||
rsc as isize,
|
||||
csc as isize,
|
||||
);
|
||||
|
||||
return res.assume_init();
|
||||
return;
|
||||
}
|
||||
} else if T::is::<f64>() {
|
||||
unsafe {
|
||||
matrixmultiply::dgemm(
|
||||
nrows1,
|
||||
ncols1,
|
||||
nrows2,
|
||||
ncols2,
|
||||
ncols3,
|
||||
mem::transmute_copy(&alpha),
|
||||
a.data.ptr() as *const f64,
|
||||
rsa as isize,
|
||||
@ -918,12 +918,12 @@ where
|
||||
rsb as isize,
|
||||
csb as isize,
|
||||
0.0,
|
||||
res.data.ptr_mut() as *mut f64,
|
||||
self.data.ptr_mut() as *mut f64,
|
||||
rsc as isize,
|
||||
csc as isize,
|
||||
);
|
||||
|
||||
return res.assume_init();
|
||||
return ;
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -932,11 +932,9 @@ where
|
||||
|
||||
for j1 in 0..ncols1 {
|
||||
// TODO: avoid bound checks.
|
||||
res.column_mut(j1)
|
||||
self.column_mut(j1)
|
||||
.gemv_z(alpha.inlined_clone(), a, &b.column(j1));
|
||||
}
|
||||
|
||||
unsafe { res.assume_init() }
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -633,13 +633,13 @@ where
|
||||
); // Arguments for non-generic constructors.
|
||||
}
|
||||
|
||||
impl<T, R: DimName, C: DimName> OMatrix<MaybeUninit<T>, R, C>
|
||||
impl<T, R: DimName, C: DimName> OMatrix<T, R, C>
|
||||
where
|
||||
DefaultAllocator: Allocator<T, R, C>,
|
||||
{
|
||||
/// Creates a new uninitialized matrix or vector.
|
||||
#[inline]
|
||||
pub fn new_uninitialized() -> Self {
|
||||
pub fn new_uninitialized() -> OMatrix<MaybeUninit<T>, R, C> {
|
||||
Self::new_uninitialized_generic(R::name(), C::name())
|
||||
}
|
||||
}
|
||||
@ -655,13 +655,13 @@ where
|
||||
ncols);
|
||||
}
|
||||
|
||||
impl<T, R: DimName> OMatrix<MaybeUninit<T>, R, Dynamic>
|
||||
impl<T, R: DimName> OMatrix<T, R, Dynamic>
|
||||
where
|
||||
DefaultAllocator: Allocator<T, R, Dynamic>,
|
||||
{
|
||||
/// Creates a new uninitialized matrix or vector.
|
||||
#[inline]
|
||||
pub fn new_uninitialized(ncols: usize) -> Self {
|
||||
pub fn new_uninitialized(ncols: usize) -> OMatrix<MaybeUninit<T>, R, Dynamic> {
|
||||
Self::new_uninitialized_generic(R::name(), Dynamic::new(ncols))
|
||||
}
|
||||
}
|
||||
@ -677,13 +677,13 @@ where
|
||||
nrows);
|
||||
}
|
||||
|
||||
impl<T, C: DimName> OMatrix<MaybeUninit<T>, Dynamic, C>
|
||||
impl<T, C: DimName> OMatrix<T, Dynamic, C>
|
||||
where
|
||||
DefaultAllocator: Allocator<T, Dynamic, C>,
|
||||
{
|
||||
/// Creates a new uninitialized matrix or vector.
|
||||
#[inline]
|
||||
pub fn new_uninitialized(nrows: usize) -> Self {
|
||||
pub fn new_uninitialized(nrows: usize) -> OMatrix<MaybeUninit<T>, Dynamic, C> {
|
||||
Self::new_uninitialized_generic(Dynamic::new(nrows), C::name())
|
||||
}
|
||||
}
|
||||
@ -699,13 +699,13 @@ where
|
||||
nrows, ncols);
|
||||
}
|
||||
|
||||
impl<T> OMatrix<MaybeUninit<T>, Dynamic, Dynamic>
|
||||
impl<T> OMatrix<T, Dynamic, Dynamic>
|
||||
where
|
||||
DefaultAllocator: Allocator<T, Dynamic, Dynamic>,
|
||||
{
|
||||
/// Creates a new uninitialized matrix or vector.
|
||||
#[inline]
|
||||
pub fn new_uninitialized(nrows: usize, ncols: usize) -> Self {
|
||||
pub fn new_uninitialized(nrows: usize, ncols: usize) -> OMatrix<MaybeUninit<T>, Dynamic, Dynamic> {
|
||||
Self::new_uninitialized_generic(Dynamic::new(nrows), Dynamic::new(ncols))
|
||||
}
|
||||
}
|
||||
|
@ -4,7 +4,6 @@
|
||||
//! heap-allocated buffers for matrices with at least one dimension unknown at compile-time.
|
||||
|
||||
use std::cmp;
|
||||
use std::mem;
|
||||
use std::mem::ManuallyDrop;
|
||||
use std::mem::MaybeUninit;
|
||||
use std::ptr;
|
||||
|
@ -53,7 +53,8 @@ impl<T: Scalar + Zero, R: Dim, C: Dim, S: Storage<T, R, C>> Matrix<T, R, C, S> {
|
||||
{
|
||||
let irows = irows.into_iter();
|
||||
let ncols = self.data.shape().1;
|
||||
let mut res = OMatrix::<T, Dynamic, C>::new_uninitialized_generic(Dynamic::new(irows.len()), ncols);
|
||||
let mut res =
|
||||
OMatrix::<T, Dynamic, C>::new_uninitialized_generic(Dynamic::new(irows.len()), ncols);
|
||||
|
||||
// First, check that all the indices from irows are valid.
|
||||
// This will allow us to use unchecked access in the inner loop.
|
||||
|
@ -223,7 +223,7 @@ impl<'a, T, R: Dim, C: Dim, RStride: Dim, CStride: Dim>
|
||||
SliceStorage<'a, MaybeUninit<T>, R, C, RStride, CStride>
|
||||
{
|
||||
pub unsafe fn assume_init(self) -> SliceStorage<'a, T, R, C, RStride, CStride> {
|
||||
Self::from_raw_parts(self.ptr as *const T, self.shape, self.strides)
|
||||
SliceStorage::from_raw_parts(self.ptr as *const T, self.shape, self.strides)
|
||||
}
|
||||
}
|
||||
|
||||
@ -231,7 +231,7 @@ impl<'a, T, R: Dim, C: Dim, RStride: Dim, CStride: Dim>
|
||||
SliceStorageMut<'a, MaybeUninit<T>, R, C, RStride, CStride>
|
||||
{
|
||||
pub unsafe fn assume_init(self) -> SliceStorageMut<'a, T, R, C, RStride, CStride> {
|
||||
Self::from_raw_parts(self.ptr as *mut T, self.shape, self.strides)
|
||||
SliceStorageMut::from_raw_parts(self.ptr as *mut T, self.shape, self.strides)
|
||||
}
|
||||
}
|
||||
|
||||
@ -606,7 +606,7 @@ macro_rules! matrix_slice_impl(
|
||||
/// Returns a slice containing the entire matrix.
|
||||
pub fn $full_slice($me: $Me) -> $MatrixSlice<T, R, C, S::RStride, S::CStride> {
|
||||
let (nrows, ncols) = $me.shape();
|
||||
$me.generic_slice((0, 0), (R::from_usize(nrows), C::from_usize(ncols)))
|
||||
$me.$generic_slice((0, 0), (R::from_usize(nrows), C::from_usize(ncols)))
|
||||
}
|
||||
|
||||
/*
|
||||
|
@ -640,7 +640,8 @@ where
|
||||
// TODO: this is too restrictive:
|
||||
// − we can't use `a *= b` when `a` is a mutable slice.
|
||||
// − we can't use `a *= b` when C2 is not equal to C1.
|
||||
impl<T, R1: Dim, C1: Dim, R2: Dim, SA, SB> MulAssign<Matrix<T, R2, C1, SB>> for Matrix<T, R1, C1, SA>
|
||||
impl<T, R1: Dim, C1: Dim, R2: Dim, SA, SB> MulAssign<Matrix<T, R2, C1, SB>>
|
||||
for Matrix<T, R1, C1, SA>
|
||||
where
|
||||
T: Scalar + Zero + One + ClosedAdd + ClosedMul,
|
||||
SB: Storage<T, R2, C1>,
|
||||
@ -654,7 +655,8 @@ where
|
||||
}
|
||||
}
|
||||
|
||||
impl<'b, T, R1: Dim, C1: Dim, R2: Dim, SA, SB> MulAssign<&'b Matrix<T, R2, C1, SB>> for Matrix<T, R1, C1, SA>
|
||||
impl<'b, T, R1: Dim, C1: Dim, R2: Dim, SA, SB> MulAssign<&'b Matrix<T, R2, C1, SB>>
|
||||
for Matrix<T, R1, C1, SA>
|
||||
where
|
||||
T: Scalar + Zero + One + ClosedAdd + ClosedMul,
|
||||
SB: Storage<T, R2, C1>,
|
||||
@ -794,6 +796,7 @@ where
|
||||
ShapeConstraint: SameNumberOfRows<R3, R1>
|
||||
+ SameNumberOfColumns<C3, C2>
|
||||
+ AreMultipliable<R1, C1, R2, C2>,
|
||||
DefaultAllocator: Allocator<T, R3, C3>,
|
||||
{
|
||||
out.gemm_z(T::one(), self, rhs);
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user