From 9f898dd2b82b162e604443a67dc7e26602a17579 Mon Sep 17 00:00:00 2001 From: David Nadlinger Date: Sun, 9 Aug 2020 00:00:58 +0100 Subject: [PATCH] runtime/rpc: Support new TArray layout (ndarrays) This is a port of the respective commit in the main ARTIQ repository. --- src/runtime/src/rpc.rs | 58 ++++++++++++++++++++++++++++++++++++------ 1 file changed, 50 insertions(+), 8 deletions(-) diff --git a/src/runtime/src/rpc.rs b/src/runtime/src/rpc.rs index a0653fa..e558425 100644 --- a/src/runtime/src/rpc.rs +++ b/src/runtime/src/rpc.rs @@ -72,7 +72,7 @@ async unsafe fn recv_value(stream: &TcpStream, tag: Tag<'async_recursion>, da } Ok(()) } - Tag::List(it) | Tag::Array(it) => { + Tag::List(it) => { #[repr(C)] struct List { elements: *mut (), length: u32 }; consume_value!(List, |ptr| { @@ -88,6 +88,25 @@ async unsafe fn recv_value(stream: &TcpStream, tag: Tag<'async_recursion>, da Ok(()) }) } + Tag::Array(it, num_dims) => { + consume_value!(*mut (), |buffer| { + let mut total_len: u32 = 1; + for _ in 0..num_dims { + let len = proto_async::read_i32(stream).await? as u32; + total_len *= len; + consume_value!(u32, |ptr| *ptr = len ) + } + + let elt_tag = it.clone().next().expect("truncated tag"); + *buffer = alloc(elt_tag.size() * total_len as usize).await; + + let mut data = *buffer; + for _ in 0..total_len { + recv_value(stream, elt_tag, &mut data, alloc).await? + } + Ok(()) + }) + } Tag::Range(it) => { let tag = it.clone().next().expect("truncated tag"); recv_value(stream, tag, data, alloc).await?; @@ -154,7 +173,7 @@ unsafe fn send_value(writer: &mut W, tag: Tag, data: &mut *const ()) } Ok(()) } - Tag::List(it) | Tag::Array(it) => { + Tag::List(it) => { #[repr(C)] struct List { elements: *const (), length: u32 }; consume_value!(List, |ptr| { @@ -167,6 +186,25 @@ unsafe fn send_value(writer: &mut W, tag: Tag, data: &mut *const ()) Ok(()) }) } + Tag::Array(it, num_dims) => { + writer.write_u8(num_dims)?; + consume_value!(*const(), |buffer| { + let elt_tag = it.clone().next().expect("truncated tag"); + + let mut total_len = 1; + for _ in 0..num_dims { + consume_value!(u32, |len| { + writer.write_u32(*len)?; + total_len *= *len; + }) + } + let mut data = *buffer; + for _ in 0..total_len as usize { + send_value(writer, elt_tag, &mut data)?; + } + Ok(()) + }) + } Tag::Range(it) => { let tag = it.clone().next().expect("truncated tag"); send_value(writer, tag, data)?; @@ -245,7 +283,7 @@ mod tag { ByteArray, Tuple(TagIterator<'a>, u8), List(TagIterator<'a>), - Array(TagIterator<'a>), + Array(TagIterator<'a>, u8), Range(TagIterator<'a>), Keyword(TagIterator<'a>), Object @@ -264,7 +302,7 @@ mod tag { Tag::ByteArray => b'A', Tag::Tuple(_, _) => b't', Tag::List(_) => b'l', - Tag::Array(_) => b'a', + Tag::Array(_, _) => b'a', Tag::Range(_) => b'r', Tag::Keyword(_) => b'k', Tag::Object => b'O', @@ -291,7 +329,7 @@ mod tag { size } Tag::List(_) => 8, - Tag::Array(_) => 8, + Tag::Array(_, num_dims) => 4 * (1 + num_dims as usize), Tag::Range(it) => { let tag = it.clone().next().expect("truncated tag"); tag.size() * 3 @@ -334,7 +372,11 @@ mod tag { Tag::Tuple(self.sub(count), count) } b'l' => Tag::List(self.sub(1)), - b'a' => Tag::Array(self.sub(1)), + b'a' => { + let count = self.data[0]; + self.data = &self.data[1..]; + Tag::Array(self.sub(1), count) + } b'r' => Tag::Range(self.sub(1)), b'k' => Tag::Keyword(self.sub(1)), b'O' => Tag::Object, @@ -389,10 +431,10 @@ mod tag { it.fmt(f)?; write!(f, ")")?; } - Tag::Array(it) => { + Tag::Array(it, num_dims) => { write!(f, "Array(")?; it.fmt(f)?; - write!(f, ")")?; + write!(f, ", {})", num_dims)?; } Tag::Range(it) => { write!(f, "Range(")?;