forked from M-Labs/nac3
rpc: Port over size/alignment fix for structs (tuples) with tail padding
This ports over the following commits from the main ARTIQ repo: - 8740ec3dd52d85084237797881ea137492bfe070 - dbbe8e8ed4f852e623775b7bd3aec818cdd03376 - b9f13d48aa7e2c0652210152b971b21c3c419347
This commit is contained in:
parent
800c12e794
commit
df4988c774
|
@ -15,22 +15,85 @@ use crate::proto_async;
|
|||
use self::tag::{Tag, TagIterator, split_tag};
|
||||
|
||||
#[inline]
|
||||
fn alignment_offset(alignment: isize, ptr: isize) -> isize {
|
||||
(alignment - ptr % alignment) % alignment
|
||||
fn round_up(val: usize, power_of_two: usize) -> usize {
|
||||
assert!(power_of_two.is_power_of_two());
|
||||
let max_rem = power_of_two - 1;
|
||||
(val + max_rem) & (!max_rem)
|
||||
}
|
||||
|
||||
#[inline]
|
||||
unsafe fn round_up_mut<T>(ptr: *mut T, power_of_two: usize) -> *mut T {
|
||||
round_up(ptr as usize, power_of_two) as *mut T
|
||||
}
|
||||
|
||||
#[inline]
|
||||
unsafe fn round_up_const<T>(ptr: *const T, power_of_two: usize) -> *const T {
|
||||
round_up(ptr as usize, power_of_two) as *const T
|
||||
}
|
||||
|
||||
#[inline]
|
||||
unsafe fn align_ptr<T>(ptr: *const ()) -> *const T {
|
||||
let alignment = core::mem::align_of::<T>() as isize;
|
||||
let fix = alignment_offset(alignment, ptr as isize);
|
||||
((ptr as isize) + fix) as *const T
|
||||
round_up_const(ptr, core::mem::align_of::<T>()) as *const T
|
||||
}
|
||||
|
||||
#[inline]
|
||||
unsafe fn align_ptr_mut<T>(ptr: *mut ()) -> *mut T {
|
||||
let alignment = core::mem::align_of::<T>() as isize;
|
||||
let fix = alignment_offset(alignment, ptr as isize);
|
||||
((ptr as isize) + fix) as *mut T
|
||||
round_up_mut(ptr, core::mem::align_of::<T>()) as *mut T
|
||||
}
|
||||
|
||||
/// Reads (deserializes) `length` array or list elements of type `tag` from `stream`,
|
||||
/// writing them into the buffer given by `storage`.
|
||||
///
|
||||
/// `alloc` is used for nested allocations (if elements themselves contain
|
||||
/// lists/arrays), see [recv_value].
|
||||
#[async_recursion(?Send)]
|
||||
async unsafe fn recv_elements<F>(
|
||||
stream: &TcpStream,
|
||||
elt_tag: Tag<'async_recursion>,
|
||||
length: usize,
|
||||
storage: *mut (),
|
||||
alloc: &(impl Fn(usize) -> F + 'async_recursion)
|
||||
) -> Result<(), smoltcp::Error>
|
||||
where
|
||||
F: Future<Output=*mut ()>,
|
||||
{
|
||||
// List of simple types are special-cased in the protocol for performance.
|
||||
match elt_tag {
|
||||
Tag::Bool => {
|
||||
let dest = core::slice::from_raw_parts_mut(storage as *mut u8, length);
|
||||
proto_async::read_chunk(stream, dest).await?;
|
||||
},
|
||||
Tag::Int32 => {
|
||||
let ptr = storage as *mut u32;
|
||||
let dest = core::slice::from_raw_parts_mut(ptr as *mut u8, length * 4);
|
||||
proto_async::read_chunk(stream, dest).await?;
|
||||
drop(dest);
|
||||
let dest = core::slice::from_raw_parts_mut(ptr, length);
|
||||
NativeEndian::from_slice_u32(dest);
|
||||
},
|
||||
Tag::Int64 | Tag::Float64 => {
|
||||
let ptr = storage as *mut u64;
|
||||
let dest = core::slice::from_raw_parts_mut(ptr as *mut u8, length * 8);
|
||||
proto_async::read_chunk(stream, dest).await?;
|
||||
drop(dest);
|
||||
let dest = core::slice::from_raw_parts_mut(ptr, length);
|
||||
NativeEndian::from_slice_u64(dest);
|
||||
},
|
||||
_ => {
|
||||
let mut data = storage;
|
||||
for _ in 0..length {
|
||||
recv_value(stream, elt_tag, &mut data, alloc).await?
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Reads (deserializes) a value of type `tag` from `stream`, writing the results to
|
||||
/// the kernel-side buffer `data` (the passed pointer to which is incremented to point
|
||||
/// past the just-received data). For nested allocations (lists/arrays), `alloc` is
|
||||
/// invoked any number of times with the size of the required allocation as a parameter
|
||||
/// (which is assumed to be correctly aligned for all payload types).
|
||||
#[async_recursion(?Send)]
|
||||
async unsafe fn recv_value<F>(stream: &TcpStream, tag: Tag<'async_recursion>, data: &mut *mut (),
|
||||
alloc: &(impl Fn(usize) -> F + 'async_recursion))
|
||||
|
@ -71,120 +134,63 @@ async unsafe fn recv_value<F>(stream: &TcpStream, tag: Tag<'async_recursion>, da
|
|||
})
|
||||
}
|
||||
Tag::Tuple(it, arity) => {
|
||||
*data = (*data).offset(alignment_offset(tag.alignment() as isize, *data as isize));
|
||||
let alignment = tag.alignment();
|
||||
*data = round_up_mut(*data, alignment);
|
||||
let mut it = it.clone();
|
||||
for _ in 0..arity {
|
||||
let tag = it.next().expect("truncated tag");
|
||||
recv_value(stream, tag, data, alloc).await?;
|
||||
recv_value(stream, tag, data, alloc).await?
|
||||
}
|
||||
// Take into account any tail padding (if element(s) with largest alignment
|
||||
// are not at the end).
|
||||
*data = round_up_mut(*data, alignment);
|
||||
Ok(())
|
||||
}
|
||||
Tag::List(it) => {
|
||||
#[repr(C)]
|
||||
struct List { elements: *mut (), length: u32 }
|
||||
consume_value!(*mut List, |ptr| {
|
||||
let length = proto_async::read_i32(stream).await? as usize;
|
||||
struct List { elements: *mut (), length: usize }
|
||||
consume_value!(*mut List, |ptr_to_list| {
|
||||
let tag = it.clone().next().expect("truncated tag");
|
||||
let data_size = tag.size() * length as usize +
|
||||
match tag {
|
||||
Tag::Int64 | Tag::Float64 => 4,
|
||||
_ => 0
|
||||
};
|
||||
let data = alloc(data_size + 8).await as *mut u8;
|
||||
*ptr = data as *mut List;
|
||||
let ptr = data as *mut List;
|
||||
let data = data.offset(8);
|
||||
let length = proto_async::read_i32(stream).await? as usize;
|
||||
|
||||
let alignment = tag.alignment();
|
||||
let mut data = data.offset(alignment_offset(alignment as isize, data as isize)) as *mut ();
|
||||
(*ptr).length = length as u32;
|
||||
(*ptr).elements = data;
|
||||
match tag {
|
||||
Tag::Bool => {
|
||||
let ptr = data as *mut u8;
|
||||
let dest = core::slice::from_raw_parts_mut(ptr, length);
|
||||
proto_async::read_chunk(stream, dest).await?;
|
||||
},
|
||||
Tag::Int32 => {
|
||||
let ptr = data as *mut u32;
|
||||
// reading as raw bytes and do endianness conversion later
|
||||
let dest = core::slice::from_raw_parts_mut(ptr as *mut u8, length * 4);
|
||||
proto_async::read_chunk(stream, dest).await?;
|
||||
drop(dest);
|
||||
let dest = core::slice::from_raw_parts_mut(ptr, length);
|
||||
NativeEndian::from_slice_u32(dest);
|
||||
},
|
||||
Tag::Int64 | Tag::Float64 => {
|
||||
let ptr = data as *mut u64;
|
||||
let dest = core::slice::from_raw_parts_mut(ptr as *mut u8, length * 8);
|
||||
proto_async::read_chunk(stream, dest).await?;
|
||||
drop(dest);
|
||||
let dest = core::slice::from_raw_parts_mut(ptr, length);
|
||||
NativeEndian::from_slice_u64(dest);
|
||||
},
|
||||
_ => {
|
||||
for _ in 0..(*ptr).length as usize {
|
||||
recv_value(stream, tag, &mut data, alloc).await?
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
// To avoid multiple kernel CPU roundtrips, use a single allocation for
|
||||
// both the pointer/length List (slice) and the backing storage for the
|
||||
// elements. We can assume that alloc() is aligned suitably, so just
|
||||
// need to take into account any extra padding required.
|
||||
// (Note: At the time of writing, there will never actually be any types
|
||||
// with alignment larger than 8 bytes, so storage_offset == 0 always.)
|
||||
let list_size = 4 + 4;
|
||||
let storage_offset = round_up(list_size, tag.alignment());
|
||||
let storage_size = tag.size() * length;
|
||||
|
||||
let allocation = alloc(storage_offset + storage_size).await as *mut u8;
|
||||
*ptr_to_list = allocation as *mut List;
|
||||
let storage = allocation.offset(storage_offset as isize) as *mut ();
|
||||
|
||||
(**ptr_to_list).length = length;
|
||||
(**ptr_to_list).elements = storage;
|
||||
recv_elements(stream, tag, length, storage, alloc).await
|
||||
})
|
||||
}
|
||||
Tag::Array(it, num_dims) => {
|
||||
consume_value!(*mut (), |buffer| {
|
||||
let mut total_len: u32 = 1;
|
||||
// Deserialize length along each dimension and compute total number of
|
||||
// elements.
|
||||
let mut total_len: usize = 1;
|
||||
for _ in 0..num_dims {
|
||||
let len = proto_async::read_i32(stream).await? as u32;
|
||||
let len = proto_async::read_i32(stream).await? as usize;
|
||||
total_len *= len;
|
||||
consume_value!(u32, |ptr| *ptr = len )
|
||||
consume_value!(usize, |ptr| *ptr = len )
|
||||
}
|
||||
|
||||
// Allocate backing storage for elements; deserialize them.
|
||||
let elt_tag = it.clone().next().expect("truncated tag");
|
||||
let data_size = elt_tag.size() * total_len as usize +
|
||||
match elt_tag {
|
||||
Tag::Int64 | Tag::Float64 => 4,
|
||||
_ => 0
|
||||
};
|
||||
let mut data = alloc(data_size).await;
|
||||
|
||||
let alignment = tag.alignment();
|
||||
data = data.offset(alignment_offset(alignment as isize, data as isize));
|
||||
*buffer = data;
|
||||
let length = total_len as usize;
|
||||
match elt_tag {
|
||||
Tag::Bool => {
|
||||
let ptr = data as *mut u8;
|
||||
let dest = core::slice::from_raw_parts_mut(ptr, length);
|
||||
proto_async::read_chunk(stream, dest).await?;
|
||||
},
|
||||
Tag::Int32 => {
|
||||
let ptr = data as *mut u32;
|
||||
let dest = core::slice::from_raw_parts_mut(ptr as *mut u8, length * 4);
|
||||
proto_async::read_chunk(stream, dest).await?;
|
||||
drop(dest);
|
||||
let dest = core::slice::from_raw_parts_mut(ptr, length);
|
||||
NativeEndian::from_slice_u32(dest);
|
||||
},
|
||||
Tag::Int64 | Tag::Float64 => {
|
||||
let ptr = data as *mut u64;
|
||||
let dest = core::slice::from_raw_parts_mut(ptr as *mut u8, length * 8);
|
||||
proto_async::read_chunk(stream, dest).await?;
|
||||
drop(dest);
|
||||
let dest = core::slice::from_raw_parts_mut(ptr, length);
|
||||
NativeEndian::from_slice_u64(dest);
|
||||
},
|
||||
_ => {
|
||||
for _ in 0..length {
|
||||
recv_value(stream, elt_tag, &mut data, alloc).await?
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
*buffer = alloc(elt_tag.size() * total_len).await;
|
||||
recv_elements(stream, elt_tag, total_len, *buffer, alloc).await
|
||||
})
|
||||
}
|
||||
Tag::Range(it) => {
|
||||
*data = (*data).offset(alignment_offset(tag.alignment() as isize, *data as isize));
|
||||
*data = round_up_mut(*data, tag.alignment());
|
||||
let tag = it.clone().next().expect("truncated tag");
|
||||
recv_value(stream, tag, data, alloc).await?;
|
||||
recv_value(stream, tag, data, alloc).await?;
|
||||
|
@ -211,6 +217,36 @@ pub async fn recv_return<F>(stream: &TcpStream, tag_bytes: &[u8], data: *mut (),
|
|||
Ok(())
|
||||
}
|
||||
|
||||
unsafe fn send_elements<W>(writer: &mut W, elt_tag: Tag, length: usize, data: *const ())
|
||||
-> Result<(), Error>
|
||||
where W: Write + ?Sized
|
||||
{
|
||||
writer.write_u8(elt_tag.as_u8())?;
|
||||
match elt_tag {
|
||||
// we cannot use NativeEndian::from_slice_i32 as the data is not mutable,
|
||||
// and that is not needed as the data is already in native endian
|
||||
Tag::Bool => {
|
||||
let slice = core::slice::from_raw_parts(data as *const u8, length);
|
||||
writer.write_all(slice)?;
|
||||
},
|
||||
Tag::Int32 => {
|
||||
let slice = core::slice::from_raw_parts(data as *const u8, length * 4);
|
||||
writer.write_all(slice)?;
|
||||
},
|
||||
Tag::Int64 | Tag::Float64 => {
|
||||
let slice = core::slice::from_raw_parts(data as *const u8, length * 8);
|
||||
writer.write_all(slice)?;
|
||||
},
|
||||
_ => {
|
||||
let mut data = data;
|
||||
for _ in 0..length {
|
||||
send_value(writer, elt_tag, &mut data)?;
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
unsafe fn send_value<W>(writer: &mut W, tag: Tag, data: &mut *const ())
|
||||
-> Result<(), Error>
|
||||
where W: Write + ?Sized
|
||||
|
@ -244,46 +280,23 @@ unsafe fn send_value<W>(writer: &mut W, tag: Tag, data: &mut *const ())
|
|||
Tag::Tuple(it, arity) => {
|
||||
let mut it = it.clone();
|
||||
writer.write_u8(arity)?;
|
||||
let mut max_alignment = 0;
|
||||
for _ in 0..arity {
|
||||
let tag = it.next().expect("truncated tag");
|
||||
max_alignment = core::cmp::max(max_alignment, tag.alignment());
|
||||
send_value(writer, tag, data)?
|
||||
}
|
||||
*data = round_up_const(*data, max_alignment);
|
||||
Ok(())
|
||||
}
|
||||
Tag::List(it) => {
|
||||
#[repr(C)]
|
||||
struct List { elements: *const (), length: u32 }
|
||||
consume_value!(&List, |ptr| {
|
||||
let length = (**ptr).length as isize;
|
||||
let length = (**ptr).length as usize;
|
||||
writer.write_u32((*ptr).length)?;
|
||||
let tag = it.clone().next().expect("truncated tag");
|
||||
let mut data = (**ptr).elements;
|
||||
writer.write_u8(tag.as_u8())?;
|
||||
match tag {
|
||||
Tag::Bool => {
|
||||
// we can pretend this is u8...
|
||||
let ptr1 = align_ptr::<u8>(data);
|
||||
let slice = core::slice::from_raw_parts(ptr1, length as usize);
|
||||
writer.write_all(slice)?;
|
||||
},
|
||||
Tag::Int32 => {
|
||||
let ptr1 = align_ptr::<i32>(data);
|
||||
let slice = core::slice::from_raw_parts(ptr1 as *const u8, length as usize * 4);
|
||||
writer.write_all(slice)?;
|
||||
},
|
||||
Tag::Int64 | Tag::Float64 => {
|
||||
let ptr1 = align_ptr::<i64>(data);
|
||||
let slice = core::slice::from_raw_parts(ptr1 as *const u8, length as usize * 8);
|
||||
writer.write_all(slice)?;
|
||||
},
|
||||
// non-primitive types, not sure if this would happen but we can handle it...
|
||||
_ => {
|
||||
for _ in 0..length {
|
||||
send_value(writer, tag, &mut data)?;
|
||||
}
|
||||
}
|
||||
};
|
||||
Ok(())
|
||||
send_elements(writer, tag, length, (**ptr).elements)
|
||||
})
|
||||
}
|
||||
Tag::Array(it, num_dims) => {
|
||||
|
@ -298,33 +311,8 @@ unsafe fn send_value<W>(writer: &mut W, tag: Tag, data: &mut *const ())
|
|||
total_len *= *len;
|
||||
})
|
||||
}
|
||||
let mut data = *buffer;
|
||||
let length = total_len as isize;
|
||||
writer.write_u8(elt_tag.as_u8())?;
|
||||
match elt_tag {
|
||||
Tag::Bool => {
|
||||
let ptr1 = align_ptr::<u8>(data);
|
||||
let slice = core::slice::from_raw_parts(ptr1, length as usize);
|
||||
writer.write_all(slice)?;
|
||||
},
|
||||
Tag::Int32 => {
|
||||
let ptr1 = align_ptr::<i32>(data);
|
||||
let slice = core::slice::from_raw_parts(ptr1 as *const u8, length as usize * 4);
|
||||
writer.write_all(slice)?;
|
||||
},
|
||||
Tag::Int64 | Tag::Float64 => {
|
||||
let ptr1 = align_ptr::<i64>(data);
|
||||
let slice = core::slice::from_raw_parts(ptr1 as *const u8, length as usize * 8);
|
||||
writer.write_all(slice)?;
|
||||
},
|
||||
// non-primitive types, not sure if this would happen but we can handle it...
|
||||
_ => {
|
||||
for _ in 0..length {
|
||||
send_value(writer, elt_tag, &mut data)?;
|
||||
}
|
||||
}
|
||||
};
|
||||
Ok(())
|
||||
let length = total_len as usize;
|
||||
send_elements(writer, elt_tag, length, *buffer)
|
||||
})
|
||||
}
|
||||
Tag::Range(it) => {
|
||||
|
@ -448,18 +436,15 @@ mod tag {
|
|||
let it = it.clone();
|
||||
it.take(3).map(|t| t.alignment()).max().unwrap()
|
||||
}
|
||||
// CSlice basically
|
||||
Tag::Bytes | Tag::String | Tag::ByteArray =>
|
||||
// the ptr/length(s) pair is basically CSlice
|
||||
Tag::Bytes | Tag::String | Tag::ByteArray | Tag::List(_) | Tag::Array(_, _) =>
|
||||
core::mem::align_of::<CSlice<()>>(),
|
||||
// array buffer is allocated, so no need for alignment first
|
||||
Tag::List(_) | Tag::Array(_, _) => 1,
|
||||
// will not be sent from the host
|
||||
_ => unreachable!("unexpected tag from host")
|
||||
Tag::Keyword(_) => unreachable!("Tag::Keyword should not appear in composite types"),
|
||||
Tag::Object => core::mem::align_of::<u32>(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn size(self) -> usize {
|
||||
use super::alignment_offset;
|
||||
match self {
|
||||
Tag::None => 0,
|
||||
Tag::Bool => 1,
|
||||
|
@ -471,13 +456,18 @@ mod tag {
|
|||
Tag::ByteArray => 8,
|
||||
Tag::Tuple(it, arity) => {
|
||||
let mut size = 0;
|
||||
let mut max_alignment = 0;
|
||||
let mut it = it.clone();
|
||||
for _ in 0..arity {
|
||||
let tag = it.next().expect("truncated tag");
|
||||
let alignment = tag.alignment();
|
||||
max_alignment = core::cmp::max(max_alignment, alignment);
|
||||
size = super::round_up(size, alignment);
|
||||
size += tag.size();
|
||||
// includes padding
|
||||
size += alignment_offset(tag.alignment() as isize, size as isize) as usize;
|
||||
}
|
||||
// Take into account any tail padding (if element(s) with largest
|
||||
// alignment are not at the end).
|
||||
size = super::round_up(size, max_alignment);
|
||||
size
|
||||
}
|
||||
Tag::List(_) => 4,
|
||||
|
|
Loading…
Reference in New Issue