diff --git a/Cargo.toml b/Cargo.toml index bcfa07d2..1304e7ca 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -23,7 +23,7 @@ path = "src/lib.rs" [features] default = [ "std", "macros" ] -std = [ "matrixmultiply", "simba/std" ] +std = [ "gemm", "simba/std" ] sparse = [ ] debug = [ "approx/num-complex", "rand" ] alloc = [ ] @@ -80,7 +80,7 @@ approx = { version = "0.5", default-features = false } simba = { version = "0.7", default-features = false } alga = { version = "0.9", default-features = false, optional = true } rand_distr = { version = "0.4", default-features = false, optional = true } -matrixmultiply = { version = "0.3", optional = true } +gemm = { version = "0.11", optional = true } serde = { version = "1.0", default-features = false, features = [ "derive" ], optional = true } rkyv = { version = "~0.7.1", optional = true } bytecheck = { version = "~0.6.1", optional = true } diff --git a/src/base/blas_uninit.rs b/src/base/blas_uninit.rs index 7e449d7d..5c56870d 100644 --- a/src/base/blas_uninit.rs +++ b/src/base/blas_uninit.rs @@ -9,11 +9,9 @@ */ #[cfg(feature = "std")] -use matrixmultiply; +use gemm; use num::{One, Zero}; use simba::scalar::{ClosedAdd, ClosedMul}; -#[cfg(feature = "std")] -use std::mem; use crate::base::constraint::{ AreMultipliable, DimEq, SameNumberOfColumns, SameNumberOfRows, ShapeConstraint, @@ -210,7 +208,7 @@ pub unsafe fn gemm_uninit< #[cfg(feature = "std")] { // We assume large matrices will be Dynamic but small matrices static. - // We could use matrixmultiply for large statically-sized matrices but the performance + // We could use gemm for large statically-sized matrices but the performance // threshold to activate it would be different from SMALL_DIM because our code optimizes // better for statically-sized matrices. if R1::is::() @@ -220,7 +218,7 @@ pub unsafe fn gemm_uninit< || R3::is::() || C3::is::() { - // matrixmultiply can be used only if the std feature is available. + // gemm can be used only if the std feature is available. let nrows1 = y.nrows(); let (nrows2, ncols2) = a.shape(); let (nrows3, ncols3) = b.shape(); @@ -257,48 +255,29 @@ pub unsafe fn gemm_uninit< return; } - if TypeId::of::() == TypeId::of::() { + let id = TypeId::of::(); + if id == TypeId::of::() || id == TypeId::of::() { let (rsa, csa) = a.strides(); let (rsb, csb) = b.strides(); let (rsc, csc) = y.strides(); - matrixmultiply::sgemm( - nrows2, + gemm::gemm::( + nrows1, + ncols1, ncols2, - ncols3, - mem::transmute_copy(&alpha), - a.data.ptr() as *const f32, - rsa as isize, - csa as isize, - b.data.ptr() as *const f32, - rsb as isize, - csb as isize, - mem::transmute_copy(&beta), - y.data.ptr_mut() as *mut f32, - rsc as isize, + y.data.ptr_mut() as *mut T, csc as isize, - ); - return; - } else if TypeId::of::() == TypeId::of::() { - let (rsa, csa) = a.strides(); - let (rsb, csb) = b.strides(); - let (rsc, csc) = y.strides(); - - matrixmultiply::dgemm( - nrows2, - ncols2, - ncols3, - mem::transmute_copy(&alpha), - a.data.ptr() as *const f64, - rsa as isize, - csa as isize, - b.data.ptr() as *const f64, - rsb as isize, - csb as isize, - mem::transmute_copy(&beta), - y.data.ptr_mut() as *mut f64, rsc as isize, - csc as isize, + !beta.is_zero(), + a.data.ptr(), + csa as isize, + rsa as isize, + b.data.ptr(), + csb as isize, + rsb as isize, + beta, + alpha, + gemm::Parallelism::None, ); return; }