runtime/rpc: fixes alignment and size problem
This commit is contained in:
parent
1155802cf7
commit
14c2abe578
@ -14,15 +14,20 @@ use crate::proto_core_io::ProtoWrite;
|
|||||||
use crate::proto_async;
|
use crate::proto_async;
|
||||||
use self::tag::{Tag, TagIterator, split_tag};
|
use self::tag::{Tag, TagIterator, split_tag};
|
||||||
|
|
||||||
|
#[inline]
|
||||||
|
fn alignment_offset(alignment: isize, ptr: isize) -> isize {
|
||||||
|
(alignment - ptr % alignment) % alignment
|
||||||
|
}
|
||||||
|
|
||||||
unsafe fn align_ptr<T>(ptr: *const ()) -> *const T {
|
unsafe fn align_ptr<T>(ptr: *const ()) -> *const T {
|
||||||
let alignment = core::mem::align_of::<T>() as isize;
|
let alignment = core::mem::align_of::<T>() as isize;
|
||||||
let fix = (alignment - (ptr as isize) % alignment) % alignment;
|
let fix = alignment_offset(alignment, ptr as isize);
|
||||||
((ptr as isize) + fix) as *const T
|
((ptr as isize) + fix) as *const T
|
||||||
}
|
}
|
||||||
|
|
||||||
unsafe fn align_ptr_mut<T>(ptr: *mut ()) -> *mut T {
|
unsafe fn align_ptr_mut<T>(ptr: *mut ()) -> *mut T {
|
||||||
let alignment = core::mem::align_of::<T>() as isize;
|
let alignment = core::mem::align_of::<T>() as isize;
|
||||||
let fix = (alignment - (ptr as isize) % alignment) % alignment;
|
let fix = alignment_offset(alignment, ptr as isize);
|
||||||
((ptr as isize) + fix) as *mut T
|
((ptr as isize) + fix) as *mut T
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -66,6 +71,7 @@ async unsafe fn recv_value<F>(stream: &TcpStream, tag: Tag<'async_recursion>, da
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
Tag::Tuple(it, arity) => {
|
Tag::Tuple(it, arity) => {
|
||||||
|
*data = (*data).offset(alignment_offset(tag.alignment() as isize, *data as isize));
|
||||||
let mut it = it.clone();
|
let mut it = it.clone();
|
||||||
for _ in 0..arity {
|
for _ in 0..arity {
|
||||||
let tag = it.next().expect("truncated tag");
|
let tag = it.next().expect("truncated tag");
|
||||||
@ -80,17 +86,24 @@ async unsafe fn recv_value<F>(stream: &TcpStream, tag: Tag<'async_recursion>, da
|
|||||||
let length = proto_async::read_i32(stream).await? as usize;
|
let length = proto_async::read_i32(stream).await? as usize;
|
||||||
(*ptr).length = length as u32;
|
(*ptr).length = length as u32;
|
||||||
let tag = it.clone().next().expect("truncated tag");
|
let tag = it.clone().next().expect("truncated tag");
|
||||||
let mut data = alloc(tag.size() * length as usize).await;
|
let data_size = tag.size() * length as usize +
|
||||||
|
match tag {
|
||||||
|
Tag::Int64 | Tag::Float64 => 4,
|
||||||
|
_ => 0
|
||||||
|
};
|
||||||
|
let mut data = alloc(data_size).await;
|
||||||
|
|
||||||
|
let alignment = tag.alignment();
|
||||||
|
data = data.offset(alignment_offset(alignment as isize, data as isize));
|
||||||
(*ptr).elements = data;
|
(*ptr).elements = data;
|
||||||
match tag {
|
match tag {
|
||||||
Tag::Bool => {
|
Tag::Bool => {
|
||||||
let ptr = align_ptr_mut::<u8>(data);
|
let ptr = data as *mut u8;
|
||||||
let dest = core::slice::from_raw_parts_mut(ptr, length);
|
let dest = core::slice::from_raw_parts_mut(ptr, length);
|
||||||
proto_async::read_chunk(stream, dest).await?;
|
proto_async::read_chunk(stream, dest).await?;
|
||||||
},
|
},
|
||||||
Tag::Int32 => {
|
Tag::Int32 => {
|
||||||
let ptr = align_ptr_mut::<u32>(data);
|
let ptr = data as *mut u32;
|
||||||
// reading as raw bytes and do endianness conversion later
|
// reading as raw bytes and do endianness conversion later
|
||||||
let dest = core::slice::from_raw_parts_mut(ptr as *mut u8, length * 4);
|
let dest = core::slice::from_raw_parts_mut(ptr as *mut u8, length * 4);
|
||||||
proto_async::read_chunk(stream, dest).await?;
|
proto_async::read_chunk(stream, dest).await?;
|
||||||
@ -99,7 +112,7 @@ async unsafe fn recv_value<F>(stream: &TcpStream, tag: Tag<'async_recursion>, da
|
|||||||
NativeEndian::from_slice_u32(dest);
|
NativeEndian::from_slice_u32(dest);
|
||||||
},
|
},
|
||||||
Tag::Int64 | Tag::Float64 => {
|
Tag::Int64 | Tag::Float64 => {
|
||||||
let ptr = align_ptr_mut::<u64>(data);
|
let ptr = data as *mut u64;
|
||||||
let dest = core::slice::from_raw_parts_mut(ptr as *mut u8, length * 8);
|
let dest = core::slice::from_raw_parts_mut(ptr as *mut u8, length * 8);
|
||||||
proto_async::read_chunk(stream, dest).await?;
|
proto_async::read_chunk(stream, dest).await?;
|
||||||
drop(dest);
|
drop(dest);
|
||||||
@ -125,18 +138,25 @@ async unsafe fn recv_value<F>(stream: &TcpStream, tag: Tag<'async_recursion>, da
|
|||||||
}
|
}
|
||||||
|
|
||||||
let elt_tag = it.clone().next().expect("truncated tag");
|
let elt_tag = it.clone().next().expect("truncated tag");
|
||||||
*buffer = alloc(elt_tag.size() * total_len as usize).await;
|
let data_size = elt_tag.size() * total_len as usize +
|
||||||
|
match elt_tag {
|
||||||
|
Tag::Int64 | Tag::Float64 => 4,
|
||||||
|
_ => 0
|
||||||
|
};
|
||||||
|
let mut data = alloc(data_size).await;
|
||||||
|
|
||||||
|
let alignment = tag.alignment();
|
||||||
|
data = data.offset(alignment_offset(alignment as isize, data as isize));
|
||||||
|
*buffer = data;
|
||||||
let length = total_len as usize;
|
let length = total_len as usize;
|
||||||
let mut data = *buffer;
|
|
||||||
match elt_tag {
|
match elt_tag {
|
||||||
Tag::Bool => {
|
Tag::Bool => {
|
||||||
let ptr = align_ptr_mut::<u8>(data);
|
let ptr = data as *mut u8;
|
||||||
let dest = core::slice::from_raw_parts_mut(ptr, length);
|
let dest = core::slice::from_raw_parts_mut(ptr, length);
|
||||||
proto_async::read_chunk(stream, dest).await?;
|
proto_async::read_chunk(stream, dest).await?;
|
||||||
},
|
},
|
||||||
Tag::Int32 => {
|
Tag::Int32 => {
|
||||||
let ptr = align_ptr_mut::<u32>(data);
|
let ptr = data as *mut u32;
|
||||||
let dest = core::slice::from_raw_parts_mut(ptr as *mut u8, length * 4);
|
let dest = core::slice::from_raw_parts_mut(ptr as *mut u8, length * 4);
|
||||||
proto_async::read_chunk(stream, dest).await?;
|
proto_async::read_chunk(stream, dest).await?;
|
||||||
drop(dest);
|
drop(dest);
|
||||||
@ -144,7 +164,7 @@ async unsafe fn recv_value<F>(stream: &TcpStream, tag: Tag<'async_recursion>, da
|
|||||||
NativeEndian::from_slice_u32(dest);
|
NativeEndian::from_slice_u32(dest);
|
||||||
},
|
},
|
||||||
Tag::Int64 | Tag::Float64 => {
|
Tag::Int64 | Tag::Float64 => {
|
||||||
let ptr = align_ptr_mut::<u64>(data);
|
let ptr = data as *mut u64;
|
||||||
let dest = core::slice::from_raw_parts_mut(ptr as *mut u8, length * 8);
|
let dest = core::slice::from_raw_parts_mut(ptr as *mut u8, length * 8);
|
||||||
proto_async::read_chunk(stream, dest).await?;
|
proto_async::read_chunk(stream, dest).await?;
|
||||||
drop(dest);
|
drop(dest);
|
||||||
@ -161,6 +181,7 @@ async unsafe fn recv_value<F>(stream: &TcpStream, tag: Tag<'async_recursion>, da
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
Tag::Range(it) => {
|
Tag::Range(it) => {
|
||||||
|
*data = (*data).offset(alignment_offset(tag.alignment() as isize, *data as isize));
|
||||||
let tag = it.clone().next().expect("truncated tag");
|
let tag = it.clone().next().expect("truncated tag");
|
||||||
recv_value(stream, tag, data, alloc).await?;
|
recv_value(stream, tag, data, alloc).await?;
|
||||||
recv_value(stream, tag, data, alloc).await?;
|
recv_value(stream, tag, data, alloc).await?;
|
||||||
@ -407,7 +428,35 @@ mod tag {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn alignment(self) -> usize {
|
||||||
|
use cslice::CSlice;
|
||||||
|
match self {
|
||||||
|
Tag::None => 1,
|
||||||
|
Tag::Bool => core::mem::align_of::<u8>(),
|
||||||
|
Tag::Int32 => core::mem::align_of::<i32>(),
|
||||||
|
Tag::Int64 => core::mem::align_of::<i64>(),
|
||||||
|
Tag::Float64 => core::mem::align_of::<f64>(),
|
||||||
|
// struct type: align to largest element
|
||||||
|
Tag::Tuple(it, arity) => {
|
||||||
|
let it = it.clone();
|
||||||
|
it.take(arity.into()).map(|t| t.alignment()).max().unwrap()
|
||||||
|
},
|
||||||
|
Tag::Range(it) => {
|
||||||
|
let it = it.clone();
|
||||||
|
it.take(3).map(|t| t.alignment()).max().unwrap()
|
||||||
|
}
|
||||||
|
// CSlice basically
|
||||||
|
Tag::Bytes | Tag::String | Tag::ByteArray | Tag::List(_) =>
|
||||||
|
core::mem::align_of::<CSlice<()>>(),
|
||||||
|
// array buffer is allocated, so no need for alignment first
|
||||||
|
Tag::Array(_, _) => 1,
|
||||||
|
// will not be sent from the host
|
||||||
|
_ => unreachable!("unexpected tag from host")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
pub fn size(self) -> usize {
|
pub fn size(self) -> usize {
|
||||||
|
use super::alignment_offset;
|
||||||
match self {
|
match self {
|
||||||
Tag::None => 0,
|
Tag::None => 0,
|
||||||
Tag::Bool => 1,
|
Tag::Bool => 1,
|
||||||
@ -423,6 +472,8 @@ mod tag {
|
|||||||
for _ in 0..arity {
|
for _ in 0..arity {
|
||||||
let tag = it.next().expect("truncated tag");
|
let tag = it.next().expect("truncated tag");
|
||||||
size += tag.size();
|
size += tag.size();
|
||||||
|
// includes padding
|
||||||
|
size += alignment_offset(tag.alignment() as isize, size as isize) as usize;
|
||||||
}
|
}
|
||||||
size
|
size
|
||||||
}
|
}
|
||||||
@ -445,10 +496,23 @@ mod tag {
|
|||||||
|
|
||||||
impl<'a> TagIterator<'a> {
|
impl<'a> TagIterator<'a> {
|
||||||
pub fn new(data: &'a [u8]) -> TagIterator<'a> {
|
pub fn new(data: &'a [u8]) -> TagIterator<'a> {
|
||||||
TagIterator { data: data }
|
TagIterator { data }
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn next(&mut self) -> Option<Tag<'a>> {
|
|
||||||
|
fn sub(&mut self, count: u8) -> TagIterator<'a> {
|
||||||
|
let data = self.data;
|
||||||
|
for _ in 0..count {
|
||||||
|
self.next().expect("truncated tag");
|
||||||
|
}
|
||||||
|
TagIterator { data: &data[..(data.len() - self.data.len())] }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'a> core::iter::Iterator for TagIterator<'a> {
|
||||||
|
type Item = Tag<'a>;
|
||||||
|
|
||||||
|
fn next(&mut self) -> Option<Tag<'a>> {
|
||||||
if self.data.len() == 0 {
|
if self.data.len() == 0 {
|
||||||
return None
|
return None
|
||||||
}
|
}
|
||||||
@ -481,14 +545,6 @@ mod tag {
|
|||||||
_ => unreachable!()
|
_ => unreachable!()
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
fn sub(&mut self, count: u8) -> TagIterator<'a> {
|
|
||||||
let data = self.data;
|
|
||||||
for _ in 0..count {
|
|
||||||
self.next().expect("truncated tag");
|
|
||||||
}
|
|
||||||
TagIterator { data: &data[..(data.len() - self.data.len())] }
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<'a> fmt::Display for TagIterator<'a> {
|
impl<'a> fmt::Display for TagIterator<'a> {
|
||||||
|
Loading…
Reference in New Issue
Block a user