diff --git a/artiq/runtime.rs/src/rpc.rs b/artiq/runtime.rs/src/rpc.rs index cc01c8d10..beef28119 100644 --- a/artiq/runtime.rs/src/rpc.rs +++ b/artiq/runtime.rs/src/rpc.rs @@ -1,20 +1,84 @@ +use std::slice; use std::io::{self, Read, Write}; use proto::*; use self::tag::{Tag, TagIterator, split_tag}; -fn recv_value(reader: &mut Read, tag: Tag, data: &mut *const ()) -> io::Result<()> { +unsafe fn recv_value(reader: &mut Read, tag: Tag, data: &mut *mut (), + alloc: &Fn(usize) -> io::Result<*mut ()>) -> io::Result<()> { + macro_rules! consume_value { + ($ty:ty, |$ptr:ident| $map:expr) => ({ + let ptr = (*data) as *mut $ty; + *data = ptr.offset(1) as *mut (); + (|$ptr: *mut $ty| $map)(ptr) + }) + } + match tag { Tag::None => Ok(()), - _ => unreachable!() + Tag::Bool => + consume_value!(u8, |ptr| { + *ptr = try!(read_u8(reader)); Ok(()) + }), + Tag::Int32 => + consume_value!(u32, |ptr| { + *ptr = try!(read_u32(reader)); Ok(()) + }), + Tag::Int64 | Tag::Float64 => + consume_value!(u64, |ptr| { + *ptr = try!(read_u64(reader)); Ok(()) + }), + Tag::String => { + consume_value!(*mut u8, |ptr| { + let length = try!(read_u32(reader)); + // NB: the received string includes a trailing \0 + *ptr = try!(alloc(length as usize)) as *mut u8; + try!(reader.read_exact(slice::from_raw_parts_mut(*ptr, length as usize))); + Ok(()) + }) + } + Tag::Tuple(it, arity) => { + let mut it = it.clone(); + for _ in 0..arity { + let tag = it.next().expect("truncated tag"); + try!(recv_value(reader, tag, data, alloc)) + } + Ok(()) + } + Tag::List(it) | Tag::Array(it) => { + struct List { length: u32, elements: *mut () }; + consume_value!(List, |ptr| { + (*ptr).length = try!(read_u32(reader)); + + let tag = it.clone().next().expect("truncated tag"); + (*ptr).elements = try!(alloc(tag.size() * (*ptr).length as usize)); + + let mut data = (*ptr).elements; + for _ in 0..(*ptr).length as usize { + try!(recv_value(reader, tag, &mut data, alloc)); + } + Ok(()) + }) + } + Tag::Range(it) => { + let tag = it.clone().next().expect("truncated tag"); + try!(recv_value(reader, tag, data, alloc)); + try!(recv_value(reader, tag, data, alloc)); + try!(recv_value(reader, tag, data, alloc)); + Ok(()) + } + Tag::Keyword(_) => unreachable!(), + Tag::Object => unreachable!() } } -pub fn recv_return(reader: &mut Read, tag_bytes: &[u8], data: *const ()) -> io::Result<()> { +pub fn recv_return(reader: &mut Read, tag_bytes: &[u8], data: *mut (), + alloc: &Fn(usize) -> io::Result<*mut ()>) -> io::Result<()> { let mut it = TagIterator::new(tag_bytes); trace!("recv ...->{}", it); + let tag = it.next().expect("truncated tag"); let mut data = data; - try!(recv_value(reader, it.next().expect("RPC without a return value"), &mut data)); + try!(unsafe { recv_value(reader, tag, &mut data, alloc) }); Ok(()) } @@ -37,26 +101,18 @@ unsafe fn send_value(writer: &mut Write, tag: Tag, data: &mut *const ()) -> io:: try!(write_u8(writer, tag.as_u8())); match tag { Tag::None => Ok(()), - Tag::Bool => { + Tag::Bool => consume_value!(u8, |ptr| - write_u8(writer, *ptr)) - } - Tag::Int32 => { + write_u8(writer, *ptr)), + Tag::Int32 => consume_value!(u32, |ptr| - write_u32(writer, *ptr)) - } - Tag::Int64 => { + write_u32(writer, *ptr)), + Tag::Int64 | Tag::Float64 => consume_value!(u64, |ptr| - write_u64(writer, *ptr)) - } - Tag::Float64 => { - consume_value!(u64, |ptr| - write_u64(writer, *ptr)) - } - Tag::String => { + write_u64(writer, *ptr)), + Tag::String => consume_value!(*const u8, |ptr| - write_string(writer, from_c_str(*ptr))) - } + write_string(writer, from_c_str(*ptr))), Tag::Tuple(it, arity) => { let mut it = it.clone(); try!(write_u8(writer, arity)); @@ -172,6 +228,33 @@ mod tag { Tag::Object => b'O', } } + + pub fn size(self) -> usize { + match self { + Tag::None => 0, + Tag::Bool => 1, + Tag::Int32 => 4, + Tag::Int64 => 8, + Tag::Float64 => 8, + Tag::String => 4, + Tag::Tuple(it, arity) => { + let mut size = 0; + for _ in 0..arity { + let tag = it.clone().next().expect("truncated tag"); + size += tag.size(); + } + size + } + Tag::List(_) => 8, + Tag::Array(_) => 8, + Tag::Range(it) => { + let tag = it.clone().next().expect("truncated tag"); + tag.size() * 3 + } + Tag::Keyword(_) => unreachable!(), + Tag::Object => unreachable!(), + } + } } #[derive(Debug, Clone, Copy)] @@ -245,7 +328,7 @@ mod tag { try!(write!(f, "Float64")), Tag::String => try!(write!(f, "String")), - Tag::Tuple(it, cnt) => { + Tag::Tuple(it, _) => { try!(write!(f, "Tuple(")); try!(it.fmt(f)); try!(write!(f, ")")) diff --git a/artiq/runtime.rs/src/session.rs b/artiq/runtime.rs/src/session.rs index d4b85ead6..f549c90a7 100644 --- a/artiq/runtime.rs/src/session.rs +++ b/artiq/runtime.rs/src/session.rs @@ -240,7 +240,17 @@ fn process_host_message(waiter: Waiter, match reply { kern::RpcRecvRequest { slot } => { let mut data = io::Cursor::new(data); - rpc::recv_return(&mut data, &tag, slot) + rpc::recv_return(&mut data, &tag, slot, &|size| { + try!(kern_send(waiter, kern::RpcRecvReply { + alloc_size: size, exception: None + })); + kern_recv(waiter, |reply| { + match reply { + kern::RpcRecvRequest { slot } => Ok(slot), + _ => unreachable!() + } + }) + }) } other => unexpected!("unexpected reply from kernel CPU: {:?}", other)