replace matrixmultiply with gemm

This commit is contained in:
sarah 2022-11-06 15:33:11 +01:00
parent 6ded8db479
commit 6be21fb294
2 changed files with 21 additions and 42 deletions

View File

@ -23,7 +23,7 @@ path = "src/lib.rs"
[features]
default = [ "std", "macros" ]
std = [ "matrixmultiply", "simba/std" ]
std = [ "gemm", "simba/std" ]
sparse = [ ]
debug = [ "approx/num-complex", "rand" ]
alloc = [ ]
@ -80,7 +80,7 @@ approx = { version = "0.5", default-features = false }
simba = { version = "0.7", default-features = false }
alga = { version = "0.9", default-features = false, optional = true }
rand_distr = { version = "0.4", default-features = false, optional = true }
matrixmultiply = { version = "0.3", optional = true }
gemm = { version = "0.11", optional = true }
serde = { version = "1.0", default-features = false, features = [ "derive" ], optional = true }
rkyv = { version = "~0.7.1", optional = true }
bytecheck = { version = "~0.6.1", optional = true }

View File

@ -9,11 +9,9 @@
*/
#[cfg(feature = "std")]
use matrixmultiply;
use gemm;
use num::{One, Zero};
use simba::scalar::{ClosedAdd, ClosedMul};
#[cfg(feature = "std")]
use std::mem;
use crate::base::constraint::{
AreMultipliable, DimEq, SameNumberOfColumns, SameNumberOfRows, ShapeConstraint,
@ -210,7 +208,7 @@ pub unsafe fn gemm_uninit<
#[cfg(feature = "std")]
{
// We assume large matrices will be Dynamic but small matrices static.
// We could use matrixmultiply for large statically-sized matrices but the performance
// We could use gemm for large statically-sized matrices but the performance
// threshold to activate it would be different from SMALL_DIM because our code optimizes
// better for statically-sized matrices.
if R1::is::<Dynamic>()
@ -220,7 +218,7 @@ pub unsafe fn gemm_uninit<
|| R3::is::<Dynamic>()
|| C3::is::<Dynamic>()
{
// matrixmultiply can be used only if the std feature is available.
// gemm can be used only if the std feature is available.
let nrows1 = y.nrows();
let (nrows2, ncols2) = a.shape();
let (nrows3, ncols3) = b.shape();
@ -257,48 +255,29 @@ pub unsafe fn gemm_uninit<
return;
}
if TypeId::of::<T>() == TypeId::of::<f32>() {
let id = TypeId::of::<T>();
if id == TypeId::of::<f32>() || id == TypeId::of::<f64>() {
let (rsa, csa) = a.strides();
let (rsb, csb) = b.strides();
let (rsc, csc) = y.strides();
matrixmultiply::sgemm(
nrows2,
gemm::gemm::<T>(
nrows1,
ncols1,
ncols2,
ncols3,
mem::transmute_copy(&alpha),
a.data.ptr() as *const f32,
rsa as isize,
csa as isize,
b.data.ptr() as *const f32,
rsb as isize,
csb as isize,
mem::transmute_copy(&beta),
y.data.ptr_mut() as *mut f32,
rsc as isize,
y.data.ptr_mut() as *mut T,
csc as isize,
);
return;
} else if TypeId::of::<T>() == TypeId::of::<f64>() {
let (rsa, csa) = a.strides();
let (rsb, csb) = b.strides();
let (rsc, csc) = y.strides();
matrixmultiply::dgemm(
nrows2,
ncols2,
ncols3,
mem::transmute_copy(&alpha),
a.data.ptr() as *const f64,
rsa as isize,
csa as isize,
b.data.ptr() as *const f64,
rsb as isize,
csb as isize,
mem::transmute_copy(&beta),
y.data.ptr_mut() as *mut f64,
rsc as isize,
csc as isize,
!beta.is_zero(),
a.data.ptr(),
csa as isize,
rsa as isize,
b.data.ptr(),
csb as isize,
rsb as isize,
beta,
alpha,
gemm::Parallelism::None,
);
return;
}