Tied some blas loose strings
This commit is contained in:
parent
df9b6f5f64
commit
54e9750191
|
@ -6,7 +6,7 @@
|
||||||
//! that return an owned matrix that would otherwise result from setting a
|
//! that return an owned matrix that would otherwise result from setting a
|
||||||
//! parameter to zero in the other methods.
|
//! parameter to zero in the other methods.
|
||||||
|
|
||||||
use crate::{OMatrix, OVector, SimdComplexField};
|
use crate::{OMatrix, SimdComplexField};
|
||||||
#[cfg(feature = "std")]
|
#[cfg(feature = "std")]
|
||||||
use matrixmultiply;
|
use matrixmultiply;
|
||||||
use num::{One, Zero};
|
use num::{One, Zero};
|
||||||
|
@ -795,7 +795,7 @@ where
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<T, R1: Dim, C1: Dim> OMatrix<T, R1, C1>
|
impl<T, R1: Dim, C1: Dim, S: StorageMut<MaybeUninit<T>, R1, C1>> Matrix<MaybeUninit<T>, R1, C1, S>
|
||||||
where
|
where
|
||||||
T: Scalar + Zero + One + ClosedAdd + ClosedMul,
|
T: Scalar + Zero + One + ClosedAdd + ClosedMul,
|
||||||
DefaultAllocator: Allocator<T, R1, C1>,
|
DefaultAllocator: Allocator<T, R1, C1>,
|
||||||
|
@ -821,27 +821,18 @@ where
|
||||||
/// ```
|
/// ```
|
||||||
#[inline]
|
#[inline]
|
||||||
pub fn gemm_z<R2: Dim, C2: Dim, R3: Dim, C3: Dim, SB, SC>(
|
pub fn gemm_z<R2: Dim, C2: Dim, R3: Dim, C3: Dim, SB, SC>(
|
||||||
|
&mut self,
|
||||||
alpha: T,
|
alpha: T,
|
||||||
a: &Matrix<T, R2, C2, SB>,
|
a: &Matrix<T, R2, C2, SB>,
|
||||||
b: &Matrix<T, R3, C3, SC>,
|
b: &Matrix<T, R3, C3, SC>,
|
||||||
) -> Self
|
) where
|
||||||
where
|
|
||||||
SB: Storage<T, R2, C2>,
|
SB: Storage<T, R2, C2>,
|
||||||
SC: Storage<T, R3, C3>,
|
SC: Storage<T, R3, C3>,
|
||||||
ShapeConstraint: SameNumberOfRows<R1, R2>
|
ShapeConstraint: SameNumberOfRows<R1, R2>
|
||||||
+ SameNumberOfColumns<C1, C3>
|
+ SameNumberOfColumns<C1, C3>
|
||||||
+ AreMultipliable<R2, C2, R3, C3>,
|
+ AreMultipliable<R2, C2, R3, C3>,
|
||||||
{
|
{
|
||||||
let (nrows1, ncols1) = a.shape();
|
let ncols1 = self.ncols();
|
||||||
let (nrows2, ncols2) = b.shape();
|
|
||||||
|
|
||||||
assert_eq!(
|
|
||||||
ncols1, nrows2,
|
|
||||||
"gemm: dimensions mismatch for multiplication."
|
|
||||||
);
|
|
||||||
|
|
||||||
let mut res =
|
|
||||||
Matrix::new_uninitialized_generic(R1::from_usize(nrows1), C1::from_usize(ncols2));
|
|
||||||
|
|
||||||
#[cfg(feature = "std")]
|
#[cfg(feature = "std")]
|
||||||
{
|
{
|
||||||
|
@ -857,6 +848,9 @@ where
|
||||||
|| C3::is::<Dynamic>()
|
|| C3::is::<Dynamic>()
|
||||||
{
|
{
|
||||||
// matrixmultiply can be used only if the std feature is available.
|
// matrixmultiply can be used only if the std feature is available.
|
||||||
|
let nrows1 = self.nrows();
|
||||||
|
let (nrows2, ncols2) = a.shape();
|
||||||
|
let (nrows3, ncols3) = b.shape();
|
||||||
|
|
||||||
// Threshold determined empirically.
|
// Threshold determined empirically.
|
||||||
const SMALL_DIM: usize = 5;
|
const SMALL_DIM: usize = 5;
|
||||||
|
@ -866,29 +860,35 @@ where
|
||||||
&& nrows2 > SMALL_DIM
|
&& nrows2 > SMALL_DIM
|
||||||
&& ncols2 > SMALL_DIM
|
&& ncols2 > SMALL_DIM
|
||||||
{
|
{
|
||||||
|
assert_eq!(
|
||||||
|
ncols1, nrows2,
|
||||||
|
"gemm: dimensions mismatch for multiplication."
|
||||||
|
);
|
||||||
|
assert_eq!(
|
||||||
|
(nrows1, ncols1),
|
||||||
|
(nrows2, ncols3),
|
||||||
|
"gemm: dimensions mismatch for addition."
|
||||||
|
);
|
||||||
|
|
||||||
// NOTE: this case should never happen because we enter this
|
// NOTE: this case should never happen because we enter this
|
||||||
// codepath only when ncols2 > SMALL_DIM. Though we keep this
|
// codepath only when ncols2 > SMALL_DIM. Though we keep this
|
||||||
// here just in case if in the future we change the conditions to
|
// here just in case if in the future we change the conditions to
|
||||||
// enter this codepath.
|
// enter this codepath.
|
||||||
if ncols1 == 0 {
|
if ncols1 == 0 {
|
||||||
// NOTE: we can't just always multiply by beta
|
self.fill_fn(|| MaybeUninit::new(T::zero()));
|
||||||
// because we documented the guaranty that `self` is
|
return;
|
||||||
// never read if `beta` is zero.
|
|
||||||
|
|
||||||
// Safety: this buffer is empty.
|
|
||||||
return res.assume_init();
|
|
||||||
}
|
}
|
||||||
|
|
||||||
let (rsa, csa) = a.strides();
|
let (rsa, csa) = a.strides();
|
||||||
let (rsb, csb) = b.strides();
|
let (rsb, csb) = b.strides();
|
||||||
let (rsc, csc) = res.strides();
|
let (rsc, csc) = self.strides();
|
||||||
|
|
||||||
if T::is::<f32>() {
|
if T::is::<f32>() {
|
||||||
unsafe {
|
unsafe {
|
||||||
matrixmultiply::sgemm(
|
matrixmultiply::sgemm(
|
||||||
nrows1,
|
nrows2,
|
||||||
ncols1,
|
|
||||||
ncols2,
|
ncols2,
|
||||||
|
ncols3,
|
||||||
mem::transmute_copy(&alpha),
|
mem::transmute_copy(&alpha),
|
||||||
a.data.ptr() as *const f32,
|
a.data.ptr() as *const f32,
|
||||||
rsa as isize,
|
rsa as isize,
|
||||||
|
@ -897,19 +897,19 @@ where
|
||||||
rsb as isize,
|
rsb as isize,
|
||||||
csb as isize,
|
csb as isize,
|
||||||
0.0,
|
0.0,
|
||||||
res.data.ptr_mut() as *mut f32,
|
self.data.ptr_mut() as *mut f32,
|
||||||
rsc as isize,
|
rsc as isize,
|
||||||
csc as isize,
|
csc as isize,
|
||||||
);
|
);
|
||||||
|
|
||||||
return res.assume_init();
|
return;
|
||||||
}
|
}
|
||||||
} else if T::is::<f64>() {
|
} else if T::is::<f64>() {
|
||||||
unsafe {
|
unsafe {
|
||||||
matrixmultiply::dgemm(
|
matrixmultiply::dgemm(
|
||||||
nrows1,
|
nrows2,
|
||||||
ncols1,
|
|
||||||
ncols2,
|
ncols2,
|
||||||
|
ncols3,
|
||||||
mem::transmute_copy(&alpha),
|
mem::transmute_copy(&alpha),
|
||||||
a.data.ptr() as *const f64,
|
a.data.ptr() as *const f64,
|
||||||
rsa as isize,
|
rsa as isize,
|
||||||
|
@ -918,12 +918,12 @@ where
|
||||||
rsb as isize,
|
rsb as isize,
|
||||||
csb as isize,
|
csb as isize,
|
||||||
0.0,
|
0.0,
|
||||||
res.data.ptr_mut() as *mut f64,
|
self.data.ptr_mut() as *mut f64,
|
||||||
rsc as isize,
|
rsc as isize,
|
||||||
csc as isize,
|
csc as isize,
|
||||||
);
|
);
|
||||||
|
|
||||||
return res.assume_init();
|
return ;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -932,11 +932,9 @@ where
|
||||||
|
|
||||||
for j1 in 0..ncols1 {
|
for j1 in 0..ncols1 {
|
||||||
// TODO: avoid bound checks.
|
// TODO: avoid bound checks.
|
||||||
res.column_mut(j1)
|
self.column_mut(j1)
|
||||||
.gemv_z(alpha.inlined_clone(), a, &b.column(j1));
|
.gemv_z(alpha.inlined_clone(), a, &b.column(j1));
|
||||||
}
|
}
|
||||||
|
|
||||||
unsafe { res.assume_init() }
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -633,13 +633,13 @@ where
|
||||||
); // Arguments for non-generic constructors.
|
); // Arguments for non-generic constructors.
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<T, R: DimName, C: DimName> OMatrix<MaybeUninit<T>, R, C>
|
impl<T, R: DimName, C: DimName> OMatrix<T, R, C>
|
||||||
where
|
where
|
||||||
DefaultAllocator: Allocator<T, R, C>,
|
DefaultAllocator: Allocator<T, R, C>,
|
||||||
{
|
{
|
||||||
/// Creates a new uninitialized matrix or vector.
|
/// Creates a new uninitialized matrix or vector.
|
||||||
#[inline]
|
#[inline]
|
||||||
pub fn new_uninitialized() -> Self {
|
pub fn new_uninitialized() -> OMatrix<MaybeUninit<T>, R, C> {
|
||||||
Self::new_uninitialized_generic(R::name(), C::name())
|
Self::new_uninitialized_generic(R::name(), C::name())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -655,13 +655,13 @@ where
|
||||||
ncols);
|
ncols);
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<T, R: DimName> OMatrix<MaybeUninit<T>, R, Dynamic>
|
impl<T, R: DimName> OMatrix<T, R, Dynamic>
|
||||||
where
|
where
|
||||||
DefaultAllocator: Allocator<T, R, Dynamic>,
|
DefaultAllocator: Allocator<T, R, Dynamic>,
|
||||||
{
|
{
|
||||||
/// Creates a new uninitialized matrix or vector.
|
/// Creates a new uninitialized matrix or vector.
|
||||||
#[inline]
|
#[inline]
|
||||||
pub fn new_uninitialized(ncols: usize) -> Self {
|
pub fn new_uninitialized(ncols: usize) -> OMatrix<MaybeUninit<T>, R, Dynamic> {
|
||||||
Self::new_uninitialized_generic(R::name(), Dynamic::new(ncols))
|
Self::new_uninitialized_generic(R::name(), Dynamic::new(ncols))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -677,13 +677,13 @@ where
|
||||||
nrows);
|
nrows);
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<T, C: DimName> OMatrix<MaybeUninit<T>, Dynamic, C>
|
impl<T, C: DimName> OMatrix<T, Dynamic, C>
|
||||||
where
|
where
|
||||||
DefaultAllocator: Allocator<T, Dynamic, C>,
|
DefaultAllocator: Allocator<T, Dynamic, C>,
|
||||||
{
|
{
|
||||||
/// Creates a new uninitialized matrix or vector.
|
/// Creates a new uninitialized matrix or vector.
|
||||||
#[inline]
|
#[inline]
|
||||||
pub fn new_uninitialized(nrows: usize) -> Self {
|
pub fn new_uninitialized(nrows: usize) -> OMatrix<MaybeUninit<T>, Dynamic, C> {
|
||||||
Self::new_uninitialized_generic(Dynamic::new(nrows), C::name())
|
Self::new_uninitialized_generic(Dynamic::new(nrows), C::name())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -699,13 +699,13 @@ where
|
||||||
nrows, ncols);
|
nrows, ncols);
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<T> OMatrix<MaybeUninit<T>, Dynamic, Dynamic>
|
impl<T> OMatrix<T, Dynamic, Dynamic>
|
||||||
where
|
where
|
||||||
DefaultAllocator: Allocator<T, Dynamic, Dynamic>,
|
DefaultAllocator: Allocator<T, Dynamic, Dynamic>,
|
||||||
{
|
{
|
||||||
/// Creates a new uninitialized matrix or vector.
|
/// Creates a new uninitialized matrix or vector.
|
||||||
#[inline]
|
#[inline]
|
||||||
pub fn new_uninitialized(nrows: usize, ncols: usize) -> Self {
|
pub fn new_uninitialized(nrows: usize, ncols: usize) -> OMatrix<MaybeUninit<T>, Dynamic, Dynamic> {
|
||||||
Self::new_uninitialized_generic(Dynamic::new(nrows), Dynamic::new(ncols))
|
Self::new_uninitialized_generic(Dynamic::new(nrows), Dynamic::new(ncols))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -4,7 +4,6 @@
|
||||||
//! heap-allocated buffers for matrices with at least one dimension unknown at compile-time.
|
//! heap-allocated buffers for matrices with at least one dimension unknown at compile-time.
|
||||||
|
|
||||||
use std::cmp;
|
use std::cmp;
|
||||||
use std::mem;
|
|
||||||
use std::mem::ManuallyDrop;
|
use std::mem::ManuallyDrop;
|
||||||
use std::mem::MaybeUninit;
|
use std::mem::MaybeUninit;
|
||||||
use std::ptr;
|
use std::ptr;
|
||||||
|
|
|
@ -53,7 +53,8 @@ impl<T: Scalar + Zero, R: Dim, C: Dim, S: Storage<T, R, C>> Matrix<T, R, C, S> {
|
||||||
{
|
{
|
||||||
let irows = irows.into_iter();
|
let irows = irows.into_iter();
|
||||||
let ncols = self.data.shape().1;
|
let ncols = self.data.shape().1;
|
||||||
let mut res = OMatrix::<T, Dynamic, C>::new_uninitialized_generic(Dynamic::new(irows.len()), ncols);
|
let mut res =
|
||||||
|
OMatrix::<T, Dynamic, C>::new_uninitialized_generic(Dynamic::new(irows.len()), ncols);
|
||||||
|
|
||||||
// First, check that all the indices from irows are valid.
|
// First, check that all the indices from irows are valid.
|
||||||
// This will allow us to use unchecked access in the inner loop.
|
// This will allow us to use unchecked access in the inner loop.
|
||||||
|
|
|
@ -223,7 +223,7 @@ impl<'a, T, R: Dim, C: Dim, RStride: Dim, CStride: Dim>
|
||||||
SliceStorage<'a, MaybeUninit<T>, R, C, RStride, CStride>
|
SliceStorage<'a, MaybeUninit<T>, R, C, RStride, CStride>
|
||||||
{
|
{
|
||||||
pub unsafe fn assume_init(self) -> SliceStorage<'a, T, R, C, RStride, CStride> {
|
pub unsafe fn assume_init(self) -> SliceStorage<'a, T, R, C, RStride, CStride> {
|
||||||
Self::from_raw_parts(self.ptr as *const T, self.shape, self.strides)
|
SliceStorage::from_raw_parts(self.ptr as *const T, self.shape, self.strides)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -231,7 +231,7 @@ impl<'a, T, R: Dim, C: Dim, RStride: Dim, CStride: Dim>
|
||||||
SliceStorageMut<'a, MaybeUninit<T>, R, C, RStride, CStride>
|
SliceStorageMut<'a, MaybeUninit<T>, R, C, RStride, CStride>
|
||||||
{
|
{
|
||||||
pub unsafe fn assume_init(self) -> SliceStorageMut<'a, T, R, C, RStride, CStride> {
|
pub unsafe fn assume_init(self) -> SliceStorageMut<'a, T, R, C, RStride, CStride> {
|
||||||
Self::from_raw_parts(self.ptr as *mut T, self.shape, self.strides)
|
SliceStorageMut::from_raw_parts(self.ptr as *mut T, self.shape, self.strides)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -606,7 +606,7 @@ macro_rules! matrix_slice_impl(
|
||||||
/// Returns a slice containing the entire matrix.
|
/// Returns a slice containing the entire matrix.
|
||||||
pub fn $full_slice($me: $Me) -> $MatrixSlice<T, R, C, S::RStride, S::CStride> {
|
pub fn $full_slice($me: $Me) -> $MatrixSlice<T, R, C, S::RStride, S::CStride> {
|
||||||
let (nrows, ncols) = $me.shape();
|
let (nrows, ncols) = $me.shape();
|
||||||
$me.generic_slice((0, 0), (R::from_usize(nrows), C::from_usize(ncols)))
|
$me.$generic_slice((0, 0), (R::from_usize(nrows), C::from_usize(ncols)))
|
||||||
}
|
}
|
||||||
|
|
||||||
/*
|
/*
|
||||||
|
|
|
@ -640,7 +640,8 @@ where
|
||||||
// TODO: this is too restrictive:
|
// TODO: this is too restrictive:
|
||||||
// − we can't use `a *= b` when `a` is a mutable slice.
|
// − we can't use `a *= b` when `a` is a mutable slice.
|
||||||
// − we can't use `a *= b` when C2 is not equal to C1.
|
// − we can't use `a *= b` when C2 is not equal to C1.
|
||||||
impl<T, R1: Dim, C1: Dim, R2: Dim, SA, SB> MulAssign<Matrix<T, R2, C1, SB>> for Matrix<T, R1, C1, SA>
|
impl<T, R1: Dim, C1: Dim, R2: Dim, SA, SB> MulAssign<Matrix<T, R2, C1, SB>>
|
||||||
|
for Matrix<T, R1, C1, SA>
|
||||||
where
|
where
|
||||||
T: Scalar + Zero + One + ClosedAdd + ClosedMul,
|
T: Scalar + Zero + One + ClosedAdd + ClosedMul,
|
||||||
SB: Storage<T, R2, C1>,
|
SB: Storage<T, R2, C1>,
|
||||||
|
@ -654,7 +655,8 @@ where
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<'b, T, R1: Dim, C1: Dim, R2: Dim, SA, SB> MulAssign<&'b Matrix<T, R2, C1, SB>> for Matrix<T, R1, C1, SA>
|
impl<'b, T, R1: Dim, C1: Dim, R2: Dim, SA, SB> MulAssign<&'b Matrix<T, R2, C1, SB>>
|
||||||
|
for Matrix<T, R1, C1, SA>
|
||||||
where
|
where
|
||||||
T: Scalar + Zero + One + ClosedAdd + ClosedMul,
|
T: Scalar + Zero + One + ClosedAdd + ClosedMul,
|
||||||
SB: Storage<T, R2, C1>,
|
SB: Storage<T, R2, C1>,
|
||||||
|
@ -794,6 +796,7 @@ where
|
||||||
ShapeConstraint: SameNumberOfRows<R3, R1>
|
ShapeConstraint: SameNumberOfRows<R3, R1>
|
||||||
+ SameNumberOfColumns<C3, C2>
|
+ SameNumberOfColumns<C3, C2>
|
||||||
+ AreMultipliable<R1, C1, R2, C2>,
|
+ AreMultipliable<R1, C1, R2, C2>,
|
||||||
|
DefaultAllocator: Allocator<T, R3, C3>,
|
||||||
{
|
{
|
||||||
out.gemm_z(T::one(), self, rhs);
|
out.gemm_z(T::one(), self, rhs);
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue