use std::fmt::{self, Debug, Formatter}; // use std::hash::{Hash, Hasher}; use std::ops::Mul; #[cfg(feature = "serde-serialize-no-std")] use serde::de::{Error, SeqAccess, Visitor}; #[cfg(feature = "serde-serialize-no-std")] use serde::ser::SerializeSeq; #[cfg(feature = "serde-serialize-no-std")] use serde::{Deserialize, Deserializer, Serialize, Serializer}; #[cfg(feature = "serde-serialize-no-std")] use std::marker::PhantomData; use crate::base::allocator::Allocator; use crate::base::default_allocator::DefaultAllocator; use crate::base::dimension::{Const, ToTypenum}; use crate::base::storage::{IsContiguous, Owned, RawStorage, RawStorageMut, ReshapableStorage}; use crate::base::Scalar; use crate::Storage; use std::mem; /* * * Static RawStorage. * */ /// A array-based statically sized matrix data storage. #[repr(transparent)] #[derive(Copy, Clone, PartialEq, Eq, Hash)] #[cfg_attr(feature = "rkyv-serialize", derive(bytecheck::CheckBytes))] #[cfg_attr( feature = "rkyv-serialize-no-std", derive(rkyv::Archive, rkyv::Serialize, rkyv::Deserialize) )] #[cfg_attr(feature = "cuda", derive(cust_core::DeviceCopy))] pub struct ArrayStorage(pub [[T; R]; C]); impl ArrayStorage { /// Converts this array storage to a slice. #[inline] pub fn as_slice(&self) -> &[T] { // SAFETY: this is OK because ArrayStorage is contiguous. unsafe { self.as_slice_unchecked() } } /// Converts this array storage to a mutable slice. #[inline] pub fn as_mut_slice(&mut self) -> &mut [T] { // SAFETY: this is OK because ArrayStorage is contiguous. unsafe { self.as_mut_slice_unchecked() } } } // TODO: remove this once the stdlib implements Default for arrays. 
// The bound is placed on the whole nested array type rather than only on `T`, so this impl
// exists exactly for the sizes where `[[T; R]; C]` itself implements `Default`.
impl<T: Default, const R: usize, const C: usize> Default for ArrayStorage<T, R, C>
where
    [[T; R]; C]: Default,
{
    /// Wraps `<[[T; R]; C]>::default()`.
    #[inline]
    fn default() -> Self {
        Self(Default::default())
    }
}

impl<T: Debug, const R: usize, const C: usize> Debug for ArrayStorage<T, R, C> {
    /// Formats the storage exactly like the underlying nested array (a list of columns).
    #[inline]
    fn fmt(&self, fmt: &mut Formatter<'_>) -> fmt::Result {
        self.0.fmt(fmt)
    }
}

// The pointer, shape, strides and slice length reported below all describe the same
// `R * C` elements of `self.0`, laid out contiguously in column-major order.
unsafe impl<T, const R: usize, const C: usize> RawStorage<T, Const<R>, Const<C>>
    for ArrayStorage<T, R, C>
{
    /// Elements of the same column are adjacent: one step down a column is one element.
    type RStride = Const<1>;
    /// Consecutive columns start `R` elements apart.
    type CStride = Const<R>;

    /// Pointer to the first element (row 0, column 0); the cast flattens the
    /// `*const [T; R]` column pointer into a `*const T`.
    #[inline]
    fn ptr(&self) -> *const T {
        self.0.as_ptr() as *const T
    }

    /// `(rows, columns) == (R, C)`, both known at compile time.
    #[inline]
    fn shape(&self) -> (Const<R>, Const<C>) {
        (Const, Const)
    }

    /// `(row stride, column stride) == (1, R)`, matching `RStride`/`CStride`.
    #[inline]
    fn strides(&self) -> (Self::RStride, Self::CStride) {
        (Const, Const)
    }

    /// Always `true`: the nested array has no gaps between columns.
    #[inline]
    fn is_contiguous(&self) -> bool {
        true
    }

    /// Views all `R * C` elements as one flat, column-major slice.
    #[inline]
    unsafe fn as_slice_unchecked(&self) -> &[T] {
        std::slice::from_raw_parts(self.ptr(), R * C)
    }
}

// This storage is the buffer the default allocator uses for `R x C` matrices
// (`Buffer = Self`), so the owned form of the storage is the storage itself.
unsafe impl<T: Scalar, const R: usize, const C: usize> Storage<T, Const<R>, Const<C>>
    for ArrayStorage<T, R, C>
where
    DefaultAllocator: Allocator<T, Const<R>, Const<C>, Buffer = Self>,
{
    /// Already owned: returned unchanged.
    #[inline]
    fn into_owned(self) -> Owned<T, Const<R>, Const<C>>
    where
        DefaultAllocator: Allocator<T, Const<R>, Const<C>>,
    {
        self
    }

    /// Element-wise clone of the whole array.
    #[inline]
    fn clone_owned(&self) -> Owned<T, Const<R>, Const<C>>
    where
        DefaultAllocator: Allocator<T, Const<R>, Const<C>>,
    {
        self.clone()
    }
}

// Mutable counterpart of the `RawStorage` impl; describes the same `R * C` contiguous
// column-major elements.
unsafe impl<T, const R: usize, const C: usize> RawStorageMut<T, Const<R>, Const<C>>
    for ArrayStorage<T, R, C>
{
    /// Mutable pointer to the first element (row 0, column 0).
    #[inline]
    fn ptr_mut(&mut self) -> *mut T {
        self.0.as_mut_ptr() as *mut T
    }

    /// Views all `R * C` elements as one flat, mutable, column-major slice.
    #[inline]
    unsafe fn as_mut_slice_unchecked(&mut self) -> &mut [T] {
        std::slice::from_raw_parts_mut(self.ptr_mut(), R * C)
    }
}

// Marker impl: the elements always form a single contiguous block (see `is_contiguous`).
unsafe impl<T, const R: usize, const C: usize> IsContiguous for ArrayStorage<T, R, C> {}

/// Reinterprets an `R1 x C1` array storage as an `R2 x C2` one.
///
/// The `where` clauses require, at the type level through `typenum`, that
/// `R2 * C2 == R1 * C1`. Both arrays therefore hold the same number of `T`s, and the
/// flat column-major element order is preserved by the reinterpretation.
impl<T, const R1: usize, const C1: usize, const R2: usize, const C2: usize>
    ReshapableStorage<T, Const<R1>, Const<C1>, Const<R2>, Const<C2>> for ArrayStorage<T, R1, C1>
where
    T: Scalar,
    Const<R1>: ToTypenum,
    Const<C1>: ToTypenum,
    Const<R2>: ToTypenum,
    Const<C2>: ToTypenum,
    <Const<R1> as ToTypenum>::Typenum: Mul<<Const<C1> as ToTypenum>::Typenum>,
    <Const<R2> as ToTypenum>::Typenum: Mul<
        <Const<C2> as ToTypenum>::Typenum,
        Output = typenum::Prod<
            <Const<R1> as ToTypenum>::Typenum,
            <Const<C1> as ToTypenum>::Typenum,
        >,
    >,
{
    type Output = ArrayStorage<T, R2, C2>;

    fn reshape_generic(self, _: Const<R2>, _: Const<C2>) -> Self::Output {
        unsafe {
            // Bitwise-copy the `[[T; R1]; C1]` array into a `[[T; R2]; C2]` array; the
            // `where` clauses above guarantee both have the same total element count.
            let data: [[T; R2]; C2] = mem::transmute_copy(&self.0);
            // The elements now live in `data`; forget the source so they are not dropped
            // a second time.
            mem::forget(self.0);
            ArrayStorage(data)
        }
    }
}

/*
 *
 * Serialization.
 *
 */
// XXX: open an issue for serde so that it allows the serialization/deserialization of all arrays?
#[cfg(feature = "serde-serialize-no-std")]
impl<T, const R: usize, const C: usize> Serialize for ArrayStorage<T, R, C>
where
    T: Scalar + Serialize,
{
    /// Serializes the `R * C` elements as one flat sequence, in column-major order.
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        let mut serializer = serializer.serialize_seq(Some(R * C))?;

        for e in self.as_slice().iter() {
            serializer.serialize_element(e)?;
        }

        serializer.end()
    }
}

#[cfg(feature = "serde-serialize-no-std")]
impl<'a, T, const R: usize, const C: usize> Deserialize<'a> for ArrayStorage<T, R, C>
where
    T: Scalar + Deserialize<'a>,
{
    /// Deserializes a flat sequence of exactly `R * C` elements in column-major order, the
    /// format produced by the `Serialize` impl above.
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: Deserializer<'a>,
    {
        deserializer.deserialize_seq(ArrayStorageVisitor::new())
    }
}

#[cfg(feature = "serde-serialize-no-std")]
/// A visitor that produces a matrix array.
struct ArrayStorageVisitor<T, const R: usize, const C: usize> {
    marker: PhantomData<T>,
}

#[cfg(feature = "serde-serialize-no-std")]
impl<T, const R: usize, const C: usize> ArrayStorageVisitor<T, R, C>
where
    T: Scalar,
{
    /// Construct a new sequence visitor.
    pub fn new() -> Self {
        ArrayStorageVisitor {
            marker: PhantomData,
        }
    }
}

#[cfg(feature = "serde-serialize-no-std")]
impl<'a, T, const R: usize, const C: usize> Visitor<'a> for ArrayStorageVisitor<T, R, C>
where
    T: Scalar + Deserialize<'a>,
{
    type Value = ArrayStorage<T, R, C>;

    fn expecting(&self, formatter: &mut Formatter<'_>) -> fmt::Result {
        formatter.write_str("a matrix array")
    }

    /// Reads exactly `R * C` elements into a new storage.
    ///
    /// # Errors
    ///
    /// Returns an `invalid_length` error if the sequence holds fewer or more than `R * C`
    /// elements, and forwards any error raised while deserializing an element. In every
    /// error case the elements already read are dropped before returning, so nothing leaks.
    #[inline]
    fn visit_seq<V>(self, mut visitor: V) -> Result<ArrayStorage<T, R, C>, V::Error>
    where
        V: SeqAccess<'a>,
    {
        let mut out: ArrayStorage<core::mem::MaybeUninit<T>, R, C> =
            DefaultAllocator::allocate_uninit(Const::<R>, Const::<C>);
        // Invariant: exactly the first `curr` slots of `out` are initialized.
        let mut curr = 0;

        // `None` means the sequence ended normally; `Some(err)` means we must bail out.
        // Errors are collected instead of propagated with `?` because `MaybeUninit` never
        // drops its contents: an early return would leak the initialized prefix.
        let failure = loop {
            match visitor.next_element::<T>() {
                Ok(Some(value)) => match out.as_mut_slice().get_mut(curr) {
                    Some(slot) => {
                        *slot = core::mem::MaybeUninit::new(value);
                        curr += 1;
                    }
                    // One element more than the storage can hold: `curr + 1` elements have
                    // been encountered. `value` is an ordinary `T` and is dropped here.
                    None => break Some(V::Error::invalid_length(curr + 1, &self)),
                },
                Ok(None) => break None,
                Err(err) => break Some(err),
            }
        };

        let error = match failure {
            None if curr == R * C => {
                // Safety: all the elements have been initialized.
                return unsafe {
                    Ok(<DefaultAllocator as Allocator<T, Const<R>, Const<C>>>::assume_init(out))
                };
            }
            None => V::Error::invalid_length(curr, &self),
            Some(err) => err,
        };

        for slot in &mut out.as_mut_slice()[..curr] {
            // Safety:
            // - We couldn’t initialize the whole storage. Drop the ones we initialized:
            //   exactly the first `curr` slots were written, each exactly once.
            unsafe { std::ptr::drop_in_place(slot.as_mut_ptr()) };
        }

        Err(error)
    }
}

#[cfg(feature = "bytemuck")]
unsafe impl<T: Scalar + Copy + bytemuck::Zeroable, const R: usize, const C: usize>
    bytemuck::Zeroable for ArrayStorage<T, R, C>
{
}

#[cfg(feature = "bytemuck")]
unsafe impl<T: Scalar + Copy + bytemuck::Pod, const R: usize, const C: usize> bytemuck::Pod
    for ArrayStorage<T, R, C>
{
}