2021-07-18 10:43:50 +08:00
|
|
|
|
use std::fmt::{self, Debug, Formatter};
|
2021-01-03 22:20:34 +08:00
|
|
|
|
// use std::hash::{Hash, Hasher};
|
|
|
|
|
use std::ops::Mul;
|
2017-02-16 05:04:34 +08:00
|
|
|
|
|
2021-04-12 18:14:16 +08:00
|
|
|
|
#[cfg(feature = "serde-serialize-no-std")]
|
2018-07-20 21:25:55 +08:00
|
|
|
|
use serde::de::{Error, SeqAccess, Visitor};
|
2021-04-12 18:14:16 +08:00
|
|
|
|
#[cfg(feature = "serde-serialize-no-std")]
|
2017-02-13 01:17:09 +08:00
|
|
|
|
use serde::ser::SerializeSeq;
|
2021-04-12 18:14:16 +08:00
|
|
|
|
#[cfg(feature = "serde-serialize-no-std")]
|
2018-07-20 21:25:55 +08:00
|
|
|
|
use serde::{Deserialize, Deserializer, Serialize, Serializer};
|
2021-04-12 18:14:16 +08:00
|
|
|
|
#[cfg(feature = "serde-serialize-no-std")]
|
2017-02-16 05:04:34 +08:00
|
|
|
|
use std::marker::PhantomData;
|
2016-12-05 05:44:42 +08:00
|
|
|
|
|
2021-08-03 00:41:46 +08:00
|
|
|
|
use crate::base::allocator::Allocator;
|
2019-03-23 21:29:07 +08:00
|
|
|
|
use crate::base::default_allocator::DefaultAllocator;
|
2021-01-03 22:20:34 +08:00
|
|
|
|
use crate::base::dimension::{Const, ToTypenum};
|
2021-08-03 00:41:46 +08:00
|
|
|
|
use crate::base::storage::{IsContiguous, Owned, RawStorage, RawStorageMut, ReshapableStorage};
|
|
|
|
|
use crate::base::Scalar;
|
|
|
|
|
use crate::Storage;
|
2021-08-03 23:26:56 +08:00
|
|
|
|
use std::mem;
|
2016-12-05 05:44:42 +08:00
|
|
|
|
|
|
|
|
|
/*
|
|
|
|
|
*
|
2021-08-03 00:41:46 +08:00
|
|
|
|
* Static RawStorage.
|
2016-12-05 05:44:42 +08:00
|
|
|
|
*
|
|
|
|
|
*/
|
|
|
|
|
/// An array-based statically sized matrix data storage.
///
/// The wrapped `[[T; R]; C]` holds `C` inner arrays of `R` elements each. The storage's
/// `RawStorage` impl reports a row stride of 1 and a column stride of `R`, so each inner
/// array is one column and the elements are laid out in column-major order.
/// `#[repr(transparent)]` gives the wrapper exactly the layout of that nested array.
#[repr(transparent)]
#[derive(Copy, Clone, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "rkyv-serialize", derive(bytecheck::CheckBytes))]
#[cfg_attr(
    feature = "rkyv-serialize-no-std",
    derive(rkyv::Archive, rkyv::Serialize, rkyv::Deserialize)
)]
#[cfg_attr(feature = "cuda", derive(cust_core::DeviceCopy))]
pub struct ArrayStorage<T, const R: usize, const C: usize>(pub [[T; R]; C]);
|
2016-12-05 05:44:42 +08:00
|
|
|
|
|
2021-08-03 00:41:46 +08:00
|
|
|
|
impl<T, const R: usize, const C: usize> ArrayStorage<T, R, C> {
|
2021-08-03 23:26:56 +08:00
|
|
|
|
/// Converts this array storage to a slice.
|
2021-08-03 00:41:46 +08:00
|
|
|
|
#[inline]
|
|
|
|
|
pub fn as_slice(&self) -> &[T] {
|
|
|
|
|
// SAFETY: this is OK because ArrayStorage is contiguous.
|
|
|
|
|
unsafe { self.as_slice_unchecked() }
|
|
|
|
|
}
|
|
|
|
|
|
2021-08-03 23:26:56 +08:00
|
|
|
|
/// Converts this array storage to a mutable slice.
|
2021-08-03 00:41:46 +08:00
|
|
|
|
#[inline]
|
|
|
|
|
pub fn as_mut_slice(&mut self) -> &mut [T] {
|
|
|
|
|
// SAFETY: this is OK because ArrayStorage is contiguous.
|
|
|
|
|
unsafe { self.as_mut_slice_unchecked() }
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
2021-01-03 22:20:34 +08:00
|
|
|
|
// TODO: remove this once the stdlib implements Default for arrays.
|
2021-04-11 17:00:38 +08:00
|
|
|
|
impl<T: Default, const R: usize, const C: usize> Default for ArrayStorage<T, R, C>
|
2019-04-16 16:11:27 +08:00
|
|
|
|
where
|
2021-04-11 17:00:38 +08:00
|
|
|
|
[[T; R]; C]: Default,
|
2019-04-16 16:11:27 +08:00
|
|
|
|
{
|
2021-01-03 22:20:34 +08:00
|
|
|
|
#[inline]
|
2019-04-16 16:11:27 +08:00
|
|
|
|
fn default() -> Self {
|
2021-04-11 17:00:38 +08:00
|
|
|
|
Self(Default::default())
|
2019-04-16 16:11:27 +08:00
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
2021-04-11 17:00:38 +08:00
|
|
|
|
impl<T: Debug, const R: usize, const C: usize> Debug for ArrayStorage<T, R, C> {
|
2016-12-05 05:44:42 +08:00
|
|
|
|
#[inline]
|
2021-07-26 01:06:14 +08:00
|
|
|
|
fn fmt(&self, fmt: &mut Formatter<'_>) -> fmt::Result {
|
2021-04-11 17:00:38 +08:00
|
|
|
|
self.0.fmt(fmt)
|
2016-12-05 05:44:42 +08:00
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
2021-08-03 00:41:46 +08:00
|
|
|
|
unsafe impl<T, const R: usize, const C: usize> RawStorage<T, Const<R>, Const<C>>
|
2021-04-11 17:00:38 +08:00
|
|
|
|
for ArrayStorage<T, R, C>
|
2018-02-02 19:26:35 +08:00
|
|
|
|
{
|
2021-01-03 22:20:34 +08:00
|
|
|
|
type RStride = Const<1>;
|
|
|
|
|
type CStride = Const<R>;
|
2016-12-05 05:44:42 +08:00
|
|
|
|
|
|
|
|
|
#[inline]
|
2021-04-11 17:00:38 +08:00
|
|
|
|
fn ptr(&self) -> *const T {
|
|
|
|
|
self.0.as_ptr() as *const T
|
2016-12-05 05:44:42 +08:00
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
#[inline]
|
2021-01-03 22:20:34 +08:00
|
|
|
|
fn shape(&self) -> (Const<R>, Const<C>) {
|
|
|
|
|
(Const, Const)
|
2017-08-03 01:37:44 +08:00
|
|
|
|
}
|
2016-12-05 05:44:42 +08:00
|
|
|
|
|
2017-08-03 01:37:44 +08:00
|
|
|
|
#[inline]
|
|
|
|
|
fn strides(&self) -> (Self::RStride, Self::CStride) {
|
2021-01-03 22:20:34 +08:00
|
|
|
|
(Const, Const)
|
2016-12-05 05:44:42 +08:00
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
#[inline]
|
2021-07-09 00:12:43 +08:00
|
|
|
|
fn is_contiguous(&self) -> bool {
|
2017-08-03 01:37:44 +08:00
|
|
|
|
true
|
2016-12-05 05:44:42 +08:00
|
|
|
|
}
|
|
|
|
|
|
2021-08-03 00:41:46 +08:00
|
|
|
|
#[inline]
|
|
|
|
|
unsafe fn as_slice_unchecked(&self) -> &[T] {
|
|
|
|
|
std::slice::from_raw_parts(self.ptr(), R * C)
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
unsafe impl<T: Scalar, const R: usize, const C: usize> Storage<T, Const<R>, Const<C>>
|
|
|
|
|
for ArrayStorage<T, R, C>
|
|
|
|
|
where
|
|
|
|
|
DefaultAllocator: Allocator<T, Const<R>, Const<C>, Buffer = Self>,
|
|
|
|
|
{
|
2016-12-05 05:44:42 +08:00
|
|
|
|
#[inline]
|
2021-04-11 17:00:38 +08:00
|
|
|
|
fn into_owned(self) -> Owned<T, Const<R>, Const<C>>
|
2020-04-06 00:49:48 +08:00
|
|
|
|
where
|
2021-08-03 00:41:46 +08:00
|
|
|
|
DefaultAllocator: Allocator<T, Const<R>, Const<C>>,
|
2020-04-06 00:49:48 +08:00
|
|
|
|
{
|
2021-08-03 00:41:46 +08:00
|
|
|
|
self
|
2016-12-05 05:44:42 +08:00
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
#[inline]
|
2021-04-11 17:00:38 +08:00
|
|
|
|
fn clone_owned(&self) -> Owned<T, Const<R>, Const<C>>
|
2020-04-06 00:49:48 +08:00
|
|
|
|
where
|
2021-08-03 00:41:46 +08:00
|
|
|
|
DefaultAllocator: Allocator<T, Const<R>, Const<C>>,
|
2020-04-06 00:49:48 +08:00
|
|
|
|
{
|
2021-08-03 00:41:46 +08:00
|
|
|
|
self.clone()
|
2016-12-05 05:44:42 +08:00
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
2021-08-03 00:41:46 +08:00
|
|
|
|
unsafe impl<T, const R: usize, const C: usize> RawStorageMut<T, Const<R>, Const<C>>
|
2021-04-11 17:00:38 +08:00
|
|
|
|
for ArrayStorage<T, R, C>
|
2018-02-02 19:26:35 +08:00
|
|
|
|
{
|
2016-12-05 05:44:42 +08:00
|
|
|
|
#[inline]
|
2021-04-11 17:00:38 +08:00
|
|
|
|
fn ptr_mut(&mut self) -> *mut T {
|
|
|
|
|
self.0.as_mut_ptr() as *mut T
|
2016-12-05 05:44:42 +08:00
|
|
|
|
}
|
2017-08-03 01:37:44 +08:00
|
|
|
|
|
|
|
|
|
#[inline]
|
2021-06-17 15:46:49 +08:00
|
|
|
|
unsafe fn as_mut_slice_unchecked(&mut self) -> &mut [T] {
|
|
|
|
|
std::slice::from_raw_parts_mut(self.ptr_mut(), R * C)
|
2017-08-03 01:37:44 +08:00
|
|
|
|
}
|
2016-12-05 05:44:42 +08:00
|
|
|
|
}
|
|
|
|
|
|
2021-08-03 00:41:46 +08:00
|
|
|
|
// SAFETY: `ArrayStorage` wraps a single `[[T; R]; C]`, whose `R * C` elements are stored
// back to back with no padding between the inner arrays, so the storage is always contiguous.
unsafe impl<T, const R: usize, const C: usize> IsContiguous for ArrayStorage<T, R, C> {}
|
2017-02-13 01:17:09 +08:00
|
|
|
|
|
2021-04-11 17:00:38 +08:00
|
|
|
|
impl<T, const R1: usize, const C1: usize, const R2: usize, const C2: usize>
|
|
|
|
|
ReshapableStorage<T, Const<R1>, Const<C1>, Const<R2>, Const<C2>> for ArrayStorage<T, R1, C1>
|
2020-08-19 13:52:26 +08:00
|
|
|
|
where
|
2021-08-03 00:41:46 +08:00
|
|
|
|
T: Scalar,
|
2021-01-03 22:20:34 +08:00
|
|
|
|
Const<R1>: ToTypenum,
|
|
|
|
|
Const<C1>: ToTypenum,
|
|
|
|
|
Const<R2>: ToTypenum,
|
|
|
|
|
Const<C2>: ToTypenum,
|
|
|
|
|
<Const<R1> as ToTypenum>::Typenum: Mul<<Const<C1> as ToTypenum>::Typenum>,
|
|
|
|
|
<Const<R2> as ToTypenum>::Typenum: Mul<
|
|
|
|
|
<Const<C2> as ToTypenum>::Typenum,
|
|
|
|
|
Output = typenum::Prod<
|
|
|
|
|
<Const<R1> as ToTypenum>::Typenum,
|
|
|
|
|
<Const<C1> as ToTypenum>::Typenum,
|
|
|
|
|
>,
|
|
|
|
|
>,
|
2020-08-19 13:52:26 +08:00
|
|
|
|
{
|
2021-04-11 17:00:38 +08:00
|
|
|
|
type Output = ArrayStorage<T, R2, C2>;
|
2020-08-19 13:52:26 +08:00
|
|
|
|
|
2021-01-03 22:20:34 +08:00
|
|
|
|
fn reshape_generic(self, _: Const<R2>, _: Const<C2>) -> Self::Output {
|
|
|
|
|
unsafe {
|
2021-08-03 23:02:42 +08:00
|
|
|
|
let data: [[T; R2]; C2] = mem::transmute_copy(&self.0);
|
|
|
|
|
mem::forget(self.0);
|
2021-04-11 17:00:38 +08:00
|
|
|
|
ArrayStorage(data)
|
2021-01-03 22:20:34 +08:00
|
|
|
|
}
|
2020-08-19 13:52:26 +08:00
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
2017-02-13 01:17:09 +08:00
|
|
|
|
/*
|
|
|
|
|
*
|
2021-01-03 22:20:34 +08:00
|
|
|
|
* Serialization.
|
2017-02-13 01:17:09 +08:00
|
|
|
|
*
|
|
|
|
|
*/
|
2021-01-03 22:20:34 +08:00
|
|
|
|
// XXX: open an issue for serde so that it allows the serialization/deserialization of all arrays?
|
2021-04-12 18:14:16 +08:00
|
|
|
|
#[cfg(feature = "serde-serialize-no-std")]
impl<T, const R: usize, const C: usize> Serialize for ArrayStorage<T, R, C>
where
    T: Scalar + Serialize,
{
    /// Serializes the storage as one flat sequence of `R * C` elements, in the
    /// column-major order returned by `as_slice`.
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        let mut seq = serializer.serialize_seq(Some(R * C))?;
        self.as_slice()
            .iter()
            .try_for_each(|element| seq.serialize_element(element))?;
        seq.end()
    }
}
|
2017-02-13 01:17:09 +08:00
|
|
|
|
|
2021-04-12 18:14:16 +08:00
|
|
|
|
#[cfg(feature = "serde-serialize-no-std")]
impl<'a, T, const R: usize, const C: usize> Deserialize<'a> for ArrayStorage<T, R, C>
where
    T: Scalar + Deserialize<'a>,
{
    /// Deserializes the flat sequence produced by the `Serialize` impl, delegating the
    /// element-by-element work to `ArrayStorageVisitor`.
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: Deserializer<'a>,
    {
        let visitor = ArrayStorageVisitor::<T, R, C>::new();
        deserializer.deserialize_seq(visitor)
    }
}
|
|
|
|
|
|
2021-04-12 18:14:16 +08:00
|
|
|
|
#[cfg(feature = "serde-serialize-no-std")]
/// A serde visitor that builds an `ArrayStorage<T, R, C>` from a flat sequence of elements.
struct ArrayStorageVisitor<T, const R: usize, const C: usize> {
    // Zero-sized marker tying the element type `T` to the visitor; no `T` is stored.
    marker: PhantomData<T>,
}
|
|
|
|
|
|
2021-04-12 18:14:16 +08:00
|
|
|
|
#[cfg(feature = "serde-serialize-no-std")]
impl<T, const R: usize, const C: usize> ArrayStorageVisitor<T, R, C>
where
    T: Scalar,
{
    /// Creates a visitor for an `R x C` storage; it carries no state besides its type.
    pub fn new() -> Self {
        Self {
            marker: PhantomData,
        }
    }
}
|
|
|
|
|
|
2021-04-12 18:14:16 +08:00
|
|
|
|
#[cfg(feature = "serde-serialize-no-std")]
impl<'a, T, const R: usize, const C: usize> Visitor<'a> for ArrayStorageVisitor<T, R, C>
where
    T: Scalar + Deserialize<'a>,
{
    type Value = ArrayStorage<T, R, C>;

    fn expecting(&self, formatter: &mut Formatter<'_>) -> fmt::Result {
        formatter.write_str("a matrix array")
    }

    /// Reads exactly `R * C` elements into a new `ArrayStorage`, filling it in the same
    /// column-major order that `as_mut_slice` exposes.
    ///
    /// # Errors
    /// - Returns `invalid_length(n, ..)` if the sequence ends after `n < R * C` elements,
    ///   or if it holds more than `R * C` elements (then `n = R * C + 1`, the number of
    ///   elements seen when the overflow was detected).
    /// - Propagates any error produced while deserializing an element.
    ///
    /// On every error path, the elements already written are dropped instead of leaked.
    #[inline]
    fn visit_seq<V>(self, mut visitor: V) -> Result<ArrayStorage<T, R, C>, V::Error>
    where
        V: SeqAccess<'a>,
    {
        let mut out: ArrayStorage<core::mem::MaybeUninit<T>, R, C> =
            DefaultAllocator::allocate_uninit(Const::<R>, Const::<C>);
        // Number of leading slots of `out` that currently hold an initialized `T`.
        let mut curr = 0;

        // Fill `out` slot by slot. Failures `break` with the error instead of returning
        // directly (e.g. via `?`), so the initialized prefix can be cleaned up below.
        let failure = loop {
            match visitor.next_element::<T>() {
                Ok(Some(value)) => match out.as_mut_slice().get_mut(curr) {
                    Some(slot) => {
                        *slot = core::mem::MaybeUninit::new(value);
                        curr += 1;
                    }
                    // More than `R * C` elements: `curr + 1` have been encountered so far.
                    // `value` is dropped normally at the end of this arm.
                    None => break Some(V::Error::invalid_length(curr + 1, &self)),
                },
                Ok(None) if curr == R * C => break None,
                Ok(None) => break Some(V::Error::invalid_length(curr, &self)),
                Err(err) => break Some(err),
            }
        };

        match failure {
            None => {
                // Safety: all the `R * C` elements have been initialized by the loop above.
                unsafe {
                    Ok(<DefaultAllocator as Allocator<T, Const<R>, Const<C>>>::assume_init(out))
                }
            }
            Some(err) => {
                for slot in &mut out.as_mut_slice()[..curr] {
                    // Safety: slots `0..curr` were written with `MaybeUninit::new` and have
                    // not been read or dropped since; drop them so they are not leaked.
                    unsafe { std::ptr::drop_in_place(slot.as_mut_ptr()) };
                }
                Err(err)
            }
        }
    }
}
|
2017-08-14 18:07:06 +08:00
|
|
|
|
|
2021-02-25 21:10:34 +08:00
|
|
|
|
#[cfg(feature = "bytemuck")]
// SAFETY: `ArrayStorage` is `#[repr(transparent)]` over `[[T; R]; C]`, so an all-zero bit
// pattern is a valid `ArrayStorage` whenever it is a valid `T` (guaranteed by `T: Zeroable`).
unsafe impl<T: Scalar + Copy + bytemuck::Zeroable, const R: usize, const C: usize>
    bytemuck::Zeroable for ArrayStorage<T, R, C>
{
}
|
|
|
|
|
|
|
|
|
|
#[cfg(feature = "bytemuck")]
// SAFETY: `ArrayStorage` is `#[repr(transparent)]` over `[[T; R]; C]`. Nested arrays of a
// `Pod` type contain no padding and accept every bit pattern, and the wrapper is `Copy`
// (derived) and `Zeroable` (impl above) under these bounds, so it is `Pod` as well.
unsafe impl<T: Scalar + Copy + bytemuck::Pod, const R: usize, const C: usize> bytemuck::Pod
    for ArrayStorage<T, R, C>
{
}
|