replace matrixmultiply with gemm
This commit is contained in:
parent
6ded8db479
commit
6be21fb294
|
@ -23,7 +23,7 @@ path = "src/lib.rs"
|
||||||
|
|
||||||
[features]
|
[features]
|
||||||
default = [ "std", "macros" ]
|
default = [ "std", "macros" ]
|
||||||
std = [ "matrixmultiply", "simba/std" ]
|
std = [ "gemm", "simba/std" ]
|
||||||
sparse = [ ]
|
sparse = [ ]
|
||||||
debug = [ "approx/num-complex", "rand" ]
|
debug = [ "approx/num-complex", "rand" ]
|
||||||
alloc = [ ]
|
alloc = [ ]
|
||||||
|
@ -80,7 +80,7 @@ approx = { version = "0.5", default-features = false }
|
||||||
simba = { version = "0.7", default-features = false }
|
simba = { version = "0.7", default-features = false }
|
||||||
alga = { version = "0.9", default-features = false, optional = true }
|
alga = { version = "0.9", default-features = false, optional = true }
|
||||||
rand_distr = { version = "0.4", default-features = false, optional = true }
|
rand_distr = { version = "0.4", default-features = false, optional = true }
|
||||||
matrixmultiply = { version = "0.3", optional = true }
|
gemm = { version = "0.11", optional = true }
|
||||||
serde = { version = "1.0", default-features = false, features = [ "derive" ], optional = true }
|
serde = { version = "1.0", default-features = false, features = [ "derive" ], optional = true }
|
||||||
rkyv = { version = "~0.7.1", optional = true }
|
rkyv = { version = "~0.7.1", optional = true }
|
||||||
bytecheck = { version = "~0.6.1", optional = true }
|
bytecheck = { version = "~0.6.1", optional = true }
|
||||||
|
|
|
@ -9,11 +9,9 @@
|
||||||
*/
|
*/
|
||||||
|
|
||||||
#[cfg(feature = "std")]
|
#[cfg(feature = "std")]
|
||||||
use matrixmultiply;
|
use gemm;
|
||||||
use num::{One, Zero};
|
use num::{One, Zero};
|
||||||
use simba::scalar::{ClosedAdd, ClosedMul};
|
use simba::scalar::{ClosedAdd, ClosedMul};
|
||||||
#[cfg(feature = "std")]
|
|
||||||
use std::mem;
|
|
||||||
|
|
||||||
use crate::base::constraint::{
|
use crate::base::constraint::{
|
||||||
AreMultipliable, DimEq, SameNumberOfColumns, SameNumberOfRows, ShapeConstraint,
|
AreMultipliable, DimEq, SameNumberOfColumns, SameNumberOfRows, ShapeConstraint,
|
||||||
|
@ -210,7 +208,7 @@ pub unsafe fn gemm_uninit<
|
||||||
#[cfg(feature = "std")]
|
#[cfg(feature = "std")]
|
||||||
{
|
{
|
||||||
// We assume large matrices will be Dynamic but small matrices static.
|
// We assume large matrices will be Dynamic but small matrices static.
|
||||||
// We could use matrixmultiply for large statically-sized matrices but the performance
|
// We could use gemm for large statically-sized matrices but the performance
|
||||||
// threshold to activate it would be different from SMALL_DIM because our code optimizes
|
// threshold to activate it would be different from SMALL_DIM because our code optimizes
|
||||||
// better for statically-sized matrices.
|
// better for statically-sized matrices.
|
||||||
if R1::is::<Dynamic>()
|
if R1::is::<Dynamic>()
|
||||||
|
@ -220,7 +218,7 @@ pub unsafe fn gemm_uninit<
|
||||||
|| R3::is::<Dynamic>()
|
|| R3::is::<Dynamic>()
|
||||||
|| C3::is::<Dynamic>()
|
|| C3::is::<Dynamic>()
|
||||||
{
|
{
|
||||||
// matrixmultiply can be used only if the std feature is available.
|
// gemm can be used only if the std feature is available.
|
||||||
let nrows1 = y.nrows();
|
let nrows1 = y.nrows();
|
||||||
let (nrows2, ncols2) = a.shape();
|
let (nrows2, ncols2) = a.shape();
|
||||||
let (nrows3, ncols3) = b.shape();
|
let (nrows3, ncols3) = b.shape();
|
||||||
|
@ -257,48 +255,29 @@ pub unsafe fn gemm_uninit<
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
if TypeId::of::<T>() == TypeId::of::<f32>() {
|
let id = TypeId::of::<T>();
|
||||||
|
if id == TypeId::of::<f32>() || id == TypeId::of::<f64>() {
|
||||||
let (rsa, csa) = a.strides();
|
let (rsa, csa) = a.strides();
|
||||||
let (rsb, csb) = b.strides();
|
let (rsb, csb) = b.strides();
|
||||||
let (rsc, csc) = y.strides();
|
let (rsc, csc) = y.strides();
|
||||||
|
|
||||||
matrixmultiply::sgemm(
|
gemm::gemm::<T>(
|
||||||
nrows2,
|
nrows1,
|
||||||
|
ncols1,
|
||||||
ncols2,
|
ncols2,
|
||||||
ncols3,
|
y.data.ptr_mut() as *mut T,
|
||||||
mem::transmute_copy(&alpha),
|
|
||||||
a.data.ptr() as *const f32,
|
|
||||||
rsa as isize,
|
|
||||||
csa as isize,
|
|
||||||
b.data.ptr() as *const f32,
|
|
||||||
rsb as isize,
|
|
||||||
csb as isize,
|
|
||||||
mem::transmute_copy(&beta),
|
|
||||||
y.data.ptr_mut() as *mut f32,
|
|
||||||
rsc as isize,
|
|
||||||
csc as isize,
|
csc as isize,
|
||||||
);
|
|
||||||
return;
|
|
||||||
} else if TypeId::of::<T>() == TypeId::of::<f64>() {
|
|
||||||
let (rsa, csa) = a.strides();
|
|
||||||
let (rsb, csb) = b.strides();
|
|
||||||
let (rsc, csc) = y.strides();
|
|
||||||
|
|
||||||
matrixmultiply::dgemm(
|
|
||||||
nrows2,
|
|
||||||
ncols2,
|
|
||||||
ncols3,
|
|
||||||
mem::transmute_copy(&alpha),
|
|
||||||
a.data.ptr() as *const f64,
|
|
||||||
rsa as isize,
|
|
||||||
csa as isize,
|
|
||||||
b.data.ptr() as *const f64,
|
|
||||||
rsb as isize,
|
|
||||||
csb as isize,
|
|
||||||
mem::transmute_copy(&beta),
|
|
||||||
y.data.ptr_mut() as *mut f64,
|
|
||||||
rsc as isize,
|
rsc as isize,
|
||||||
csc as isize,
|
!beta.is_zero(),
|
||||||
|
a.data.ptr(),
|
||||||
|
csa as isize,
|
||||||
|
rsa as isize,
|
||||||
|
b.data.ptr(),
|
||||||
|
csb as isize,
|
||||||
|
rsb as isize,
|
||||||
|
beta,
|
||||||
|
alpha,
|
||||||
|
gemm::Parallelism::None,
|
||||||
);
|
);
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue