remove the checked suffix to keep backward compatibility

This commit is contained in:
Saurabh 2022-04-01 15:26:36 -06:00
parent 2606409a02
commit e3fd0e7393
3 changed files with 19 additions and 21 deletions

View File

@ -63,7 +63,7 @@ where
Ok(()) Ok(())
} }
pub fn spmm_cs_prealloc_checked<T>( pub fn spmm_cs_prealloc<T>(
beta: T, beta: T,
c: &mut CsMatrix<T>, c: &mut CsMatrix<T>,
alpha: T, alpha: T,

View File

@ -1,6 +1,6 @@
use crate::csc::CscMatrix; use crate::csc::CscMatrix;
use crate::ops::serial::cs::{ use crate::ops::serial::cs::{
spadd_cs_prealloc, spmm_cs_dense, spmm_cs_prealloc_checked, spmm_cs_prealloc_unchecked, spadd_cs_prealloc, spmm_cs_dense, spmm_cs_prealloc, spmm_cs_prealloc_unchecked,
}; };
use crate::ops::serial::{OperationError, OperationErrorKind}; use crate::ops::serial::{OperationError, OperationErrorKind};
use crate::ops::Op; use crate::ops::Op;
@ -73,7 +73,7 @@ where
/// # Panics /// # Panics
/// ///
/// Panics if the dimensions of the matrices involved are not compatible with the expression. /// Panics if the dimensions of the matrices involved are not compatible with the expression.
pub fn spmm_csc_prealloc_checked<T>( pub fn spmm_csc_prealloc<T>(
beta: T, beta: T,
c: &mut CscMatrix<T>, c: &mut CscMatrix<T>,
alpha: T, alpha: T,
@ -90,9 +90,9 @@ where
match (&a, &b) { match (&a, &b) {
(NoOp(ref a), NoOp(ref b)) => { (NoOp(ref a), NoOp(ref b)) => {
// Note: We have to reverse the order for CSC matrices // Note: We have to reverse the order for CSC matrices
spmm_cs_prealloc_checked(beta, &mut c.cs, alpha, &b.cs, &a.cs) spmm_cs_prealloc(beta, &mut c.cs, alpha, &b.cs, &a.cs)
} }
_ => do_transposes(beta, c, alpha, a, b, spmm_csc_prealloc_checked), _ => spmm_csc_transposed(beta, c, alpha, a, b, spmm_csc_prealloc),
} }
} }
@ -101,7 +101,7 @@ where
/// Should be used for situations where pattern creation immediately preceeds multiplication. /// Should be used for situations where pattern creation immediately preceeds multiplication.
/// ///
/// Panics if the dimensions of the matrices involved are not compatible with the expression. /// Panics if the dimensions of the matrices involved are not compatible with the expression.
pub(crate) fn spmm_csc_prealloc_unchecked<T>( pub fn spmm_csc_prealloc_unchecked<T>(
beta: T, beta: T,
c: &mut CscMatrix<T>, c: &mut CscMatrix<T>,
alpha: T, alpha: T,
@ -120,17 +120,17 @@ where
// Note: We have to reverse the order for CSC matrices // Note: We have to reverse the order for CSC matrices
spmm_cs_prealloc_unchecked(beta, &mut c.cs, alpha, &b.cs, &a.cs) spmm_cs_prealloc_unchecked(beta, &mut c.cs, alpha, &b.cs, &a.cs)
} }
_ => do_transposes(beta, c, alpha, a, b, spmm_csc_prealloc_unchecked), _ => spmm_csc_transposed(beta, c, alpha, a, b, spmm_csc_prealloc_unchecked),
} }
} }
fn do_transposes<T, F>( fn spmm_csc_transposed<T, F>(
beta: T, beta: T,
c: &mut CscMatrix<T>, c: &mut CscMatrix<T>,
alpha: T, alpha: T,
a: Op<&CscMatrix<T>>, a: Op<&CscMatrix<T>>,
b: Op<&CscMatrix<T>>, b: Op<&CscMatrix<T>>,
caller: F, spmm_kernel: F,
) -> Result<(), OperationError> ) -> Result<(), OperationError>
where where
T: Scalar + ClosedAdd + ClosedMul + Zero + One, T: Scalar + ClosedAdd + ClosedMul + Zero + One,
@ -157,7 +157,7 @@ where
(Transpose(ref a), Transpose(ref b)) => (Owned(a.transpose()), Owned(b.transpose())), (Transpose(ref a), Transpose(ref b)) => (Owned(a.transpose()), Owned(b.transpose())),
} }
}; };
caller(beta, c, alpha, NoOp(a.as_ref()), NoOp(b.as_ref())) spmm_kernel(beta, c, alpha, NoOp(a.as_ref()), NoOp(b.as_ref()))
} }
/// Solve the lower triangular system `op(L) X = B`. /// Solve the lower triangular system `op(L) X = B`.

View File

@ -1,6 +1,6 @@
use crate::csr::CsrMatrix; use crate::csr::CsrMatrix;
use crate::ops::serial::cs::{ use crate::ops::serial::cs::{
spadd_cs_prealloc, spmm_cs_dense, spmm_cs_prealloc_checked, spmm_cs_prealloc_unchecked, spadd_cs_prealloc, spmm_cs_dense, spmm_cs_prealloc, spmm_cs_prealloc_unchecked,
}; };
use crate::ops::serial::OperationError; use crate::ops::serial::OperationError;
use crate::ops::Op; use crate::ops::Op;
@ -67,7 +67,7 @@ where
/// # Panics /// # Panics
/// ///
/// Panics if the dimensions of the matrices involved are not compatible with the expression. /// Panics if the dimensions of the matrices involved are not compatible with the expression.
pub fn spmm_csr_prealloc_checked<T>( pub fn spmm_csr_prealloc<T>(
beta: T, beta: T,
c: &mut CsrMatrix<T>, c: &mut CsrMatrix<T>,
alpha: T, alpha: T,
@ -82,10 +82,8 @@ where
use Op::NoOp; use Op::NoOp;
match (&a, &b) { match (&a, &b) {
(NoOp(ref a), NoOp(ref b)) => { (NoOp(ref a), NoOp(ref b)) => spmm_cs_prealloc(beta, &mut c.cs, alpha, &a.cs, &b.cs),
spmm_cs_prealloc_checked(beta, &mut c.cs, alpha, &a.cs, &b.cs) _ => spmm_csr_transposed(beta, c, alpha, a, b, spmm_csr_prealloc),
}
_ => do_transposes(beta, c, alpha, a, b, spmm_csr_prealloc_checked),
} }
} }
@ -94,7 +92,7 @@ where
/// Should be used for situations where pattern creation immediately preceeds multiplication. /// Should be used for situations where pattern creation immediately preceeds multiplication.
/// ///
/// Panics if the dimensions of the matrices involved are not compatible with the expression. /// Panics if the dimensions of the matrices involved are not compatible with the expression.
pub(crate) fn spmm_csr_prealloc_unchecked<T>( pub fn spmm_csr_prealloc_unchecked<T>(
beta: T, beta: T,
c: &mut CsrMatrix<T>, c: &mut CsrMatrix<T>,
alpha: T, alpha: T,
@ -112,17 +110,17 @@ where
(NoOp(ref a), NoOp(ref b)) => { (NoOp(ref a), NoOp(ref b)) => {
spmm_cs_prealloc_unchecked(beta, &mut c.cs, alpha, &a.cs, &b.cs) spmm_cs_prealloc_unchecked(beta, &mut c.cs, alpha, &a.cs, &b.cs)
} }
_ => do_transposes(beta, c, alpha, a, b, spmm_csr_prealloc_unchecked), _ => spmm_csr_transposed(beta, c, alpha, a, b, spmm_csr_prealloc_unchecked),
} }
} }
fn do_transposes<T, F>( fn spmm_csr_transposed<T, F>(
beta: T, beta: T,
c: &mut CsrMatrix<T>, c: &mut CsrMatrix<T>,
alpha: T, alpha: T,
a: Op<&CsrMatrix<T>>, a: Op<&CsrMatrix<T>>,
b: Op<&CsrMatrix<T>>, b: Op<&CsrMatrix<T>>,
caller: F, spmm_kernel: F,
) -> Result<(), OperationError> ) -> Result<(), OperationError>
where where
T: Scalar + ClosedAdd + ClosedMul + Zero + One, T: Scalar + ClosedAdd + ClosedMul + Zero + One,
@ -149,5 +147,5 @@ where
(Transpose(ref a), Transpose(ref b)) => (Owned(a.transpose()), Owned(b.transpose())), (Transpose(ref a), Transpose(ref b)) => (Owned(a.transpose()), Owned(b.transpose())),
} }
}; };
caller(beta, c, alpha, NoOp(a.as_ref()), NoOp(b.as_ref())) spmm_kernel(beta, c, alpha, NoOp(a.as_ref()), NoOp(b.as_ref()))
} }