replace matrixmultiply with gemm
This commit is contained in:
parent
6ded8db479
commit
6be21fb294
|
@ -23,7 +23,7 @@ path = "src/lib.rs"
|
|||
|
||||
[features]
|
||||
default = [ "std", "macros" ]
|
||||
std = [ "matrixmultiply", "simba/std" ]
|
||||
std = [ "gemm", "simba/std" ]
|
||||
sparse = [ ]
|
||||
debug = [ "approx/num-complex", "rand" ]
|
||||
alloc = [ ]
|
||||
|
@ -80,7 +80,7 @@ approx = { version = "0.5", default-features = false }
|
|||
simba = { version = "0.7", default-features = false }
|
||||
alga = { version = "0.9", default-features = false, optional = true }
|
||||
rand_distr = { version = "0.4", default-features = false, optional = true }
|
||||
matrixmultiply = { version = "0.3", optional = true }
|
||||
gemm = { version = "0.11", optional = true }
|
||||
serde = { version = "1.0", default-features = false, features = [ "derive" ], optional = true }
|
||||
rkyv = { version = "~0.7.1", optional = true }
|
||||
bytecheck = { version = "~0.6.1", optional = true }
|
||||
|
|
|
@ -9,11 +9,9 @@
|
|||
*/
|
||||
|
||||
#[cfg(feature = "std")]
|
||||
use matrixmultiply;
|
||||
use gemm;
|
||||
use num::{One, Zero};
|
||||
use simba::scalar::{ClosedAdd, ClosedMul};
|
||||
#[cfg(feature = "std")]
|
||||
use std::mem;
|
||||
|
||||
use crate::base::constraint::{
|
||||
AreMultipliable, DimEq, SameNumberOfColumns, SameNumberOfRows, ShapeConstraint,
|
||||
|
@ -210,7 +208,7 @@ pub unsafe fn gemm_uninit<
|
|||
#[cfg(feature = "std")]
|
||||
{
|
||||
// We assume large matrices will be Dynamic but small matrices static.
|
||||
// We could use matrixmultiply for large statically-sized matrices but the performance
|
||||
// We could use gemm for large statically-sized matrices but the performance
|
||||
// threshold to activate it would be different from SMALL_DIM because our code optimizes
|
||||
// better for statically-sized matrices.
|
||||
if R1::is::<Dynamic>()
|
||||
|
@ -220,7 +218,7 @@ pub unsafe fn gemm_uninit<
|
|||
|| R3::is::<Dynamic>()
|
||||
|| C3::is::<Dynamic>()
|
||||
{
|
||||
// matrixmultiply can be used only if the std feature is available.
|
||||
// gemm can be used only if the std feature is available.
|
||||
let nrows1 = y.nrows();
|
||||
let (nrows2, ncols2) = a.shape();
|
||||
let (nrows3, ncols3) = b.shape();
|
||||
|
@ -257,48 +255,29 @@ pub unsafe fn gemm_uninit<
|
|||
return;
|
||||
}
|
||||
|
||||
if TypeId::of::<T>() == TypeId::of::<f32>() {
|
||||
let id = TypeId::of::<T>();
|
||||
if id == TypeId::of::<f32>() || id == TypeId::of::<f64>() {
|
||||
let (rsa, csa) = a.strides();
|
||||
let (rsb, csb) = b.strides();
|
||||
let (rsc, csc) = y.strides();
|
||||
|
||||
matrixmultiply::sgemm(
|
||||
nrows2,
|
||||
gemm::gemm::<T>(
|
||||
nrows1,
|
||||
ncols1,
|
||||
ncols2,
|
||||
ncols3,
|
||||
mem::transmute_copy(&alpha),
|
||||
a.data.ptr() as *const f32,
|
||||
rsa as isize,
|
||||
csa as isize,
|
||||
b.data.ptr() as *const f32,
|
||||
rsb as isize,
|
||||
csb as isize,
|
||||
mem::transmute_copy(&beta),
|
||||
y.data.ptr_mut() as *mut f32,
|
||||
rsc as isize,
|
||||
y.data.ptr_mut() as *mut T,
|
||||
csc as isize,
|
||||
);
|
||||
return;
|
||||
} else if TypeId::of::<T>() == TypeId::of::<f64>() {
|
||||
let (rsa, csa) = a.strides();
|
||||
let (rsb, csb) = b.strides();
|
||||
let (rsc, csc) = y.strides();
|
||||
|
||||
matrixmultiply::dgemm(
|
||||
nrows2,
|
||||
ncols2,
|
||||
ncols3,
|
||||
mem::transmute_copy(&alpha),
|
||||
a.data.ptr() as *const f64,
|
||||
rsa as isize,
|
||||
csa as isize,
|
||||
b.data.ptr() as *const f64,
|
||||
rsb as isize,
|
||||
csb as isize,
|
||||
mem::transmute_copy(&beta),
|
||||
y.data.ptr_mut() as *mut f64,
|
||||
rsc as isize,
|
||||
csc as isize,
|
||||
!beta.is_zero(),
|
||||
a.data.ptr(),
|
||||
csa as isize,
|
||||
rsa as isize,
|
||||
b.data.ptr(),
|
||||
csb as isize,
|
||||
rsb as isize,
|
||||
beta,
|
||||
alpha,
|
||||
gemm::Parallelism::None,
|
||||
);
|
||||
return;
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue