diff --git a/src/runtime/src/rpc.rs b/src/runtime/src/rpc.rs index bb17f507..beaab895 100644 --- a/src/runtime/src/rpc.rs +++ b/src/runtime/src/rpc.rs @@ -14,15 +14,20 @@ use crate::proto_core_io::ProtoWrite; use crate::proto_async; use self::tag::{Tag, TagIterator, split_tag}; +#[inline] +fn alignment_offset(alignment: isize, ptr: isize) -> isize { + (alignment - ptr % alignment) % alignment +} + unsafe fn align_ptr(ptr: *const ()) -> *const T { let alignment = core::mem::align_of::() as isize; - let fix = (alignment - (ptr as isize) % alignment) % alignment; + let fix = alignment_offset(alignment, ptr as isize); ((ptr as isize) + fix) as *const T } unsafe fn align_ptr_mut(ptr: *mut ()) -> *mut T { let alignment = core::mem::align_of::() as isize; - let fix = (alignment - (ptr as isize) % alignment) % alignment; + let fix = alignment_offset(alignment, ptr as isize); ((ptr as isize) + fix) as *mut T } @@ -66,6 +71,7 @@ async unsafe fn recv_value(stream: &TcpStream, tag: Tag<'async_recursion>, da }) } Tag::Tuple(it, arity) => { + *data = (*data).offset(alignment_offset(tag.alignment() as isize, *data as isize)); let mut it = it.clone(); for _ in 0..arity { let tag = it.next().expect("truncated tag"); @@ -80,17 +86,24 @@ async unsafe fn recv_value(stream: &TcpStream, tag: Tag<'async_recursion>, da let length = proto_async::read_i32(stream).await? as usize; (*ptr).length = length as u32; let tag = it.clone().next().expect("truncated tag"); - let mut data = alloc(tag.size() * length as usize).await; + let data_size = tag.size() * length as usize + + match tag { + Tag::Int64 | Tag::Float64 => 4, + _ => 0 + }; + let mut data = alloc(data_size).await; + let alignment = tag.alignment(); + data = data.offset(alignment_offset(alignment as isize, data as isize)); (*ptr).elements = data; match tag { Tag::Bool => { - let ptr = align_ptr_mut::(data); + let ptr = data as *mut u8; let dest = core::slice::from_raw_parts_mut(ptr, length); proto_async::read_chunk(stream, dest).await?; }, Tag::Int32 => { - let ptr = align_ptr_mut::(data); + let ptr = data as *mut u32; // reading as raw bytes and do endianness conversion later let dest = core::slice::from_raw_parts_mut(ptr as *mut u8, length * 4); proto_async::read_chunk(stream, dest).await?; @@ -99,7 +112,7 @@ async unsafe fn recv_value(stream: &TcpStream, tag: Tag<'async_recursion>, da NativeEndian::from_slice_u32(dest); }, Tag::Int64 | Tag::Float64 => { - let ptr = align_ptr_mut::(data); + let ptr = data as *mut u64; let dest = core::slice::from_raw_parts_mut(ptr as *mut u8, length * 8); proto_async::read_chunk(stream, dest).await?; drop(dest); @@ -125,18 +138,25 @@ async unsafe fn recv_value(stream: &TcpStream, tag: Tag<'async_recursion>, da } let elt_tag = it.clone().next().expect("truncated tag"); - *buffer = alloc(elt_tag.size() * total_len as usize).await; + let data_size = elt_tag.size() * total_len as usize + + match elt_tag { + Tag::Int64 | Tag::Float64 => 4, + _ => 0 + }; + let mut data = alloc(data_size).await; + let alignment = tag.alignment(); + data = data.offset(alignment_offset(alignment as isize, data as isize)); + *buffer = data; let length = total_len as usize; - let mut data = *buffer; match elt_tag { Tag::Bool => { - let ptr = align_ptr_mut::(data); + let ptr = data as *mut u8; let dest = core::slice::from_raw_parts_mut(ptr, length); proto_async::read_chunk(stream, dest).await?; }, Tag::Int32 => { - let ptr = align_ptr_mut::(data); + let ptr = data as *mut u32; let dest = core::slice::from_raw_parts_mut(ptr as *mut u8, length * 4); proto_async::read_chunk(stream, dest).await?; drop(dest); @@ -144,7 +164,7 @@ async unsafe fn recv_value(stream: &TcpStream, tag: Tag<'async_recursion>, da NativeEndian::from_slice_u32(dest); }, Tag::Int64 | Tag::Float64 => { - let ptr = align_ptr_mut::(data); + let ptr = data as *mut u64; let dest = core::slice::from_raw_parts_mut(ptr as *mut u8, length * 8); proto_async::read_chunk(stream, dest).await?; drop(dest); @@ -161,6 +181,7 @@ async unsafe fn recv_value(stream: &TcpStream, tag: Tag<'async_recursion>, da }) } Tag::Range(it) => { + *data = (*data).offset(alignment_offset(tag.alignment() as isize, *data as isize)); let tag = it.clone().next().expect("truncated tag"); recv_value(stream, tag, data, alloc).await?; recv_value(stream, tag, data, alloc).await?; @@ -407,7 +428,35 @@ mod tag { } } + pub fn alignment(self) -> usize { + use cslice::CSlice; + match self { + Tag::None => 1, + Tag::Bool => core::mem::align_of::(), + Tag::Int32 => core::mem::align_of::(), + Tag::Int64 => core::mem::align_of::(), + Tag::Float64 => core::mem::align_of::(), + // struct type: align to largest element + Tag::Tuple(it, arity) => { + let it = it.clone(); + it.take(arity.into()).map(|t| t.alignment()).max().unwrap() + }, + Tag::Range(it) => { + let it = it.clone(); + it.take(3).map(|t| t.alignment()).max().unwrap() + } + // CSlice basically + Tag::Bytes | Tag::String | Tag::ByteArray | Tag::List(_) => + core::mem::align_of::>(), + // array buffer is allocated, so no need for alignment first + Tag::Array(_, _) => 1, + // will not be sent from the host + _ => unreachable!("unexpected tag from host") + } + } + pub fn size(self) -> usize { + use super::alignment_offset; match self { Tag::None => 0, Tag::Bool => 1, @@ -423,6 +472,8 @@ mod tag { for _ in 0..arity { let tag = it.next().expect("truncated tag"); size += tag.size(); + // includes padding + size += alignment_offset(tag.alignment() as isize, size as isize) as usize; } size } @@ -445,10 +496,23 @@ mod tag { impl<'a> TagIterator<'a> { pub fn new(data: &'a [u8]) -> TagIterator<'a> { - TagIterator { data: data } + TagIterator { data } } - pub fn next(&mut self) -> Option> { + + fn sub(&mut self, count: u8) -> TagIterator<'a> { + let data = self.data; + for _ in 0..count { + self.next().expect("truncated tag"); + } + TagIterator { data: &data[..(data.len() - self.data.len())] } + } + } + + impl<'a> core::iter::Iterator for TagIterator<'a> { + type Item = Tag<'a>; + + fn next(&mut self) -> Option> { if self.data.len() == 0 { return None } @@ -481,14 +545,6 @@ mod tag { _ => unreachable!() }) } - - fn sub(&mut self, count: u8) -> TagIterator<'a> { - let data = self.data; - for _ in 0..count { - self.next().expect("truncated tag"); - } - TagIterator { data: &data[..(data.len() - self.data.len())] } - } } impl<'a> fmt::Display for TagIterator<'a> {