From 2ad1970004d7b6c658ee7b175a5b4f635d3d3022 Mon Sep 17 00:00:00 2001 From: David Nadlinger Date: Thu, 29 Dec 2022 12:27:50 +0100 Subject: [PATCH] rpc: Port over size/alignment fix for structs (tuples) with tail padding This ports over the following commits from the main ARTIQ repo: - 8740ec3dd52d85084237797881ea137492bfe070 - dbbe8e8ed4f852e623775b7bd3aec818cdd03376 - b9f13d48aa7e2c0652210152b971b21c3c419347 --- src/runtime/src/rpc.rs | 318 ++++++++++++++++++++--------------------- 1 file changed, 154 insertions(+), 164 deletions(-) diff --git a/src/runtime/src/rpc.rs b/src/runtime/src/rpc.rs index 3e41340..44b6394 100644 --- a/src/runtime/src/rpc.rs +++ b/src/runtime/src/rpc.rs @@ -15,22 +15,85 @@ use crate::proto_async; use self::tag::{Tag, TagIterator, split_tag}; #[inline] -fn alignment_offset(alignment: isize, ptr: isize) -> isize { - (alignment - ptr % alignment) % alignment +fn round_up(val: usize, power_of_two: usize) -> usize { + assert!(power_of_two.is_power_of_two()); + let max_rem = power_of_two - 1; + (val + max_rem) & (!max_rem) } +#[inline] +unsafe fn round_up_mut(ptr: *mut T, power_of_two: usize) -> *mut T { + round_up(ptr as usize, power_of_two) as *mut T +} + +#[inline] +unsafe fn round_up_const(ptr: *const T, power_of_two: usize) -> *const T { + round_up(ptr as usize, power_of_two) as *const T +} + +#[inline] unsafe fn align_ptr(ptr: *const ()) -> *const T { - let alignment = core::mem::align_of::() as isize; - let fix = alignment_offset(alignment, ptr as isize); - ((ptr as isize) + fix) as *const T + round_up_const(ptr, core::mem::align_of::()) as *const T } +#[inline] unsafe fn align_ptr_mut(ptr: *mut ()) -> *mut T { - let alignment = core::mem::align_of::() as isize; - let fix = alignment_offset(alignment, ptr as isize); - ((ptr as isize) + fix) as *mut T + round_up_mut(ptr, core::mem::align_of::()) as *mut T } +/// Reads (deserializes) `length` array or list elements of type `tag` from `stream`, +/// writing them into the buffer given by `storage`. +/// +/// `alloc` is used for nested allocations (if elements themselves contain +/// lists/arrays), see [recv_value]. +#[async_recursion(?Send)] +async unsafe fn recv_elements( + stream: &TcpStream, + elt_tag: Tag<'async_recursion>, + length: usize, + storage: *mut (), + alloc: &(impl Fn(usize) -> F + 'async_recursion) +) -> Result<(), smoltcp::Error> +where + F: Future, +{ + // List of simple types are special-cased in the protocol for performance. + match elt_tag { + Tag::Bool => { + let dest = core::slice::from_raw_parts_mut(storage as *mut u8, length); + proto_async::read_chunk(stream, dest).await?; + }, + Tag::Int32 => { + let ptr = storage as *mut u32; + let dest = core::slice::from_raw_parts_mut(ptr as *mut u8, length * 4); + proto_async::read_chunk(stream, dest).await?; + drop(dest); + let dest = core::slice::from_raw_parts_mut(ptr, length); + NativeEndian::from_slice_u32(dest); + }, + Tag::Int64 | Tag::Float64 => { + let ptr = storage as *mut u64; + let dest = core::slice::from_raw_parts_mut(ptr as *mut u8, length * 8); + proto_async::read_chunk(stream, dest).await?; + drop(dest); + let dest = core::slice::from_raw_parts_mut(ptr, length); + NativeEndian::from_slice_u64(dest); + }, + _ => { + let mut data = storage; + for _ in 0..length { + recv_value(stream, elt_tag, &mut data, alloc).await? + } + } + } + Ok(()) +} + +/// Reads (deserializes) a value of type `tag` from `stream`, writing the results to +/// the kernel-side buffer `data` (the passed pointer to which is incremented to point +/// past the just-received data). For nested allocations (lists/arrays), `alloc` is +/// invoked any number of times with the size of the required allocation as a parameter +/// (which is assumed to be correctly aligned for all payload types). #[async_recursion(?Send)] async unsafe fn recv_value(stream: &TcpStream, tag: Tag<'async_recursion>, data: &mut *mut (), alloc: &(impl Fn(usize) -> F + 'async_recursion)) @@ -71,120 +134,63 @@ async unsafe fn recv_value(stream: &TcpStream, tag: Tag<'async_recursion>, da }) } Tag::Tuple(it, arity) => { - *data = (*data).offset(alignment_offset(tag.alignment() as isize, *data as isize)); + let alignment = tag.alignment(); + *data = round_up_mut(*data, alignment); let mut it = it.clone(); for _ in 0..arity { let tag = it.next().expect("truncated tag"); - recv_value(stream, tag, data, alloc).await?; + recv_value(stream, tag, data, alloc).await? } + // Take into account any tail padding (if element(s) with largest alignment + // are not at the end). + *data = round_up_mut(*data, alignment); Ok(()) } Tag::List(it) => { #[repr(C)] - struct List { elements: *mut (), length: u32 } - consume_value!(*mut List, |ptr| { - let length = proto_async::read_i32(stream).await? as usize; + struct List { elements: *mut (), length: usize } + consume_value!(*mut List, |ptr_to_list| { let tag = it.clone().next().expect("truncated tag"); - let data_size = tag.size() * length as usize + - match tag { - Tag::Int64 | Tag::Float64 => 4, - _ => 0 - }; - let data = alloc(data_size + 8).await as *mut u8; - *ptr = data as *mut List; - let ptr = data as *mut List; - let data = data.offset(8); + let length = proto_async::read_i32(stream).await? as usize; - let alignment = tag.alignment(); - let mut data = data.offset(alignment_offset(alignment as isize, data as isize)) as *mut (); - (*ptr).length = length as u32; - (*ptr).elements = data; - match tag { - Tag::Bool => { - let ptr = data as *mut u8; - let dest = core::slice::from_raw_parts_mut(ptr, length); - proto_async::read_chunk(stream, dest).await?; - }, - Tag::Int32 => { - let ptr = data as *mut u32; - // reading as raw bytes and do endianness conversion later - let dest = core::slice::from_raw_parts_mut(ptr as *mut u8, length * 4); - proto_async::read_chunk(stream, dest).await?; - drop(dest); - let dest = core::slice::from_raw_parts_mut(ptr, length); - NativeEndian::from_slice_u32(dest); - }, - Tag::Int64 | Tag::Float64 => { - let ptr = data as *mut u64; - let dest = core::slice::from_raw_parts_mut(ptr as *mut u8, length * 8); - proto_async::read_chunk(stream, dest).await?; - drop(dest); - let dest = core::slice::from_raw_parts_mut(ptr, length); - NativeEndian::from_slice_u64(dest); - }, - _ => { - for _ in 0..(*ptr).length as usize { - recv_value(stream, tag, &mut data, alloc).await? - } - } - } - Ok(()) + // To avoid multiple kernel CPU roundtrips, use a single allocation for + // both the pointer/length List (slice) and the backing storage for the + // elements. We can assume that alloc() is aligned suitably, so just + // need to take into account any extra padding required. + // (Note: At the time of writing, there will never actually be any types + // with alignment larger than 8 bytes, so storage_offset == 0 always.) + let list_size = 4 + 4; + let storage_offset = round_up(list_size, tag.alignment()); + let storage_size = tag.size() * length; + + let allocation = alloc(storage_offset + storage_size).await as *mut u8; + *ptr_to_list = allocation as *mut List; + let storage = allocation.offset(storage_offset as isize) as *mut (); + + (**ptr_to_list).length = length; + (**ptr_to_list).elements = storage; + recv_elements(stream, tag, length, storage, alloc).await }) } Tag::Array(it, num_dims) => { consume_value!(*mut (), |buffer| { - let mut total_len: u32 = 1; + // Deserialize length along each dimension and compute total number of + // elements. + let mut total_len: usize = 1; for _ in 0..num_dims { - let len = proto_async::read_i32(stream).await? as u32; + let len = proto_async::read_i32(stream).await? as usize; total_len *= len; - consume_value!(u32, |ptr| *ptr = len ) + consume_value!(usize, |ptr| *ptr = len ) } + // Allocate backing storage for elements; deserialize them. let elt_tag = it.clone().next().expect("truncated tag"); - let data_size = elt_tag.size() * total_len as usize + - match elt_tag { - Tag::Int64 | Tag::Float64 => 4, - _ => 0 - }; - let mut data = alloc(data_size).await; - - let alignment = tag.alignment(); - data = data.offset(alignment_offset(alignment as isize, data as isize)); - *buffer = data; - let length = total_len as usize; - match elt_tag { - Tag::Bool => { - let ptr = data as *mut u8; - let dest = core::slice::from_raw_parts_mut(ptr, length); - proto_async::read_chunk(stream, dest).await?; - }, - Tag::Int32 => { - let ptr = data as *mut u32; - let dest = core::slice::from_raw_parts_mut(ptr as *mut u8, length * 4); - proto_async::read_chunk(stream, dest).await?; - drop(dest); - let dest = core::slice::from_raw_parts_mut(ptr, length); - NativeEndian::from_slice_u32(dest); - }, - Tag::Int64 | Tag::Float64 => { - let ptr = data as *mut u64; - let dest = core::slice::from_raw_parts_mut(ptr as *mut u8, length * 8); - proto_async::read_chunk(stream, dest).await?; - drop(dest); - let dest = core::slice::from_raw_parts_mut(ptr, length); - NativeEndian::from_slice_u64(dest); - }, - _ => { - for _ in 0..length { - recv_value(stream, elt_tag, &mut data, alloc).await? - } - } - } - Ok(()) + *buffer = alloc(elt_tag.size() * total_len).await; + recv_elements(stream, elt_tag, total_len, *buffer, alloc).await }) } Tag::Range(it) => { - *data = (*data).offset(alignment_offset(tag.alignment() as isize, *data as isize)); + *data = round_up_mut(*data, tag.alignment()); let tag = it.clone().next().expect("truncated tag"); recv_value(stream, tag, data, alloc).await?; recv_value(stream, tag, data, alloc).await?; @@ -211,6 +217,36 @@ pub async fn recv_return(stream: &TcpStream, tag_bytes: &[u8], data: *mut (), Ok(()) } +unsafe fn send_elements(writer: &mut W, elt_tag: Tag, length: usize, data: *const ()) + -> Result<(), Error> + where W: Write + ?Sized +{ + writer.write_u8(elt_tag.as_u8())?; + match elt_tag { + // we cannot use NativeEndian::from_slice_i32 as the data is not mutable, + // and that is not needed as the data is already in native endian + Tag::Bool => { + let slice = core::slice::from_raw_parts(data as *const u8, length); + writer.write_all(slice)?; + }, + Tag::Int32 => { + let slice = core::slice::from_raw_parts(data as *const u8, length * 4); + writer.write_all(slice)?; + }, + Tag::Int64 | Tag::Float64 => { + let slice = core::slice::from_raw_parts(data as *const u8, length * 8); + writer.write_all(slice)?; + }, + _ => { + let mut data = data; + for _ in 0..length { + send_value(writer, elt_tag, &mut data)?; + } + } + } + Ok(()) +} + unsafe fn send_value(writer: &mut W, tag: Tag, data: &mut *const ()) -> Result<(), Error> where W: Write + ?Sized @@ -244,46 +280,23 @@ unsafe fn send_value(writer: &mut W, tag: Tag, data: &mut *const ()) Tag::Tuple(it, arity) => { let mut it = it.clone(); writer.write_u8(arity)?; + let mut max_alignment = 0; for _ in 0..arity { let tag = it.next().expect("truncated tag"); + max_alignment = core::cmp::max(max_alignment, tag.alignment()); send_value(writer, tag, data)? } + *data = round_up_const(*data, max_alignment); Ok(()) } Tag::List(it) => { #[repr(C)] struct List { elements: *const (), length: u32 } consume_value!(&List, |ptr| { - let length = (**ptr).length as isize; + let length = (**ptr).length as usize; writer.write_u32((*ptr).length)?; let tag = it.clone().next().expect("truncated tag"); - let mut data = (**ptr).elements; - writer.write_u8(tag.as_u8())?; - match tag { - Tag::Bool => { - // we can pretend this is u8... - let ptr1 = align_ptr::(data); - let slice = core::slice::from_raw_parts(ptr1, length as usize); - writer.write_all(slice)?; - }, - Tag::Int32 => { - let ptr1 = align_ptr::(data); - let slice = core::slice::from_raw_parts(ptr1 as *const u8, length as usize * 4); - writer.write_all(slice)?; - }, - Tag::Int64 | Tag::Float64 => { - let ptr1 = align_ptr::(data); - let slice = core::slice::from_raw_parts(ptr1 as *const u8, length as usize * 8); - writer.write_all(slice)?; - }, - // non-primitive types, not sure if this would happen but we can handle it... - _ => { - for _ in 0..length { - send_value(writer, tag, &mut data)?; - } - } - }; - Ok(()) + send_elements(writer, tag, length, (**ptr).elements) }) } Tag::Array(it, num_dims) => { @@ -298,33 +311,8 @@ unsafe fn send_value(writer: &mut W, tag: Tag, data: &mut *const ()) total_len *= *len; }) } - let mut data = *buffer; - let length = total_len as isize; - writer.write_u8(elt_tag.as_u8())?; - match elt_tag { - Tag::Bool => { - let ptr1 = align_ptr::(data); - let slice = core::slice::from_raw_parts(ptr1, length as usize); - writer.write_all(slice)?; - }, - Tag::Int32 => { - let ptr1 = align_ptr::(data); - let slice = core::slice::from_raw_parts(ptr1 as *const u8, length as usize * 4); - writer.write_all(slice)?; - }, - Tag::Int64 | Tag::Float64 => { - let ptr1 = align_ptr::(data); - let slice = core::slice::from_raw_parts(ptr1 as *const u8, length as usize * 8); - writer.write_all(slice)?; - }, - // non-primitive types, not sure if this would happen but we can handle it... - _ => { - for _ in 0..length { - send_value(writer, elt_tag, &mut data)?; - } - } - }; - Ok(()) + let length = total_len as usize; + send_elements(writer, elt_tag, length, *buffer) }) } Tag::Range(it) => { @@ -448,18 +436,15 @@ mod tag { let it = it.clone(); it.take(3).map(|t| t.alignment()).max().unwrap() } - // CSlice basically - Tag::Bytes | Tag::String | Tag::ByteArray => + // the ptr/length(s) pair is basically CSlice + Tag::Bytes | Tag::String | Tag::ByteArray | Tag::List(_) | Tag::Array(_, _) => core::mem::align_of::>(), - // array buffer is allocated, so no need for alignment first - Tag::List(_) | Tag::Array(_, _) => 1, - // will not be sent from the host - _ => unreachable!("unexpected tag from host") + Tag::Keyword(_) => unreachable!("Tag::Keyword should not appear in composite types"), + Tag::Object => core::mem::align_of::(), } } pub fn size(self) -> usize { - use super::alignment_offset; match self { Tag::None => 0, Tag::Bool => 1, @@ -471,13 +456,18 @@ mod tag { Tag::ByteArray => 8, Tag::Tuple(it, arity) => { let mut size = 0; + let mut max_alignment = 0; let mut it = it.clone(); for _ in 0..arity { let tag = it.next().expect("truncated tag"); + let alignment = tag.alignment(); + max_alignment = core::cmp::max(max_alignment, alignment); + size = super::round_up(size, alignment); size += tag.size(); - // includes padding - size += alignment_offset(tag.alignment() as isize, size as isize) as usize; } + // Take into account any tail padding (if element(s) with largest + // alignment are not at the end). + size = super::round_up(size, max_alignment); size } Tag::List(_) => 4,