compiler: add support for bytearray values in RPC (#714).

This commit is contained in:
whitequark 2017-06-09 07:10:30 +00:00
parent 9ed4e9c1cd
commit 284382b1f5
4 changed files with 19 additions and 2 deletions

View File

@ -1170,6 +1170,8 @@ class LLVMIRGenerator:
return b"s" return b"s"
elif builtins.is_bytes(typ): elif builtins.is_bytes(typ):
return b"B" return b"B"
elif builtins.is_bytearray(typ):
return b"A"
elif builtins.is_list(typ): elif builtins.is_list(typ):
return b"l" + self._rpc_tag(builtins.get_iterable_elt(typ), return b"l" + self._rpc_tag(builtins.get_iterable_elt(typ),
error_handler) error_handler)

View File

@ -371,6 +371,8 @@ class CommKernel:
return self._read_string() return self._read_string()
elif tag == "B": elif tag == "B":
return self._read_bytes() return self._read_bytes()
elif tag == "A":
return self._read_bytes()
elif tag == "l": elif tag == "l":
length = self._read_int32() length = self._read_int32()
return [self._receive_rpc_value(embedding_map) for _ in range(length)] return [self._receive_rpc_value(embedding_map) for _ in range(length)]
@ -467,6 +469,10 @@ class CommKernel:
check(isinstance(value, bytes), check(isinstance(value, bytes),
lambda: "bytes") lambda: "bytes")
self._write_bytes(value) self._write_bytes(value)
elif tag == "A":
check(isinstance(value, bytearray),
lambda: "bytearray")
self._write_bytes(value)
elif tag == "l": elif tag == "l":
check(isinstance(value, list), check(isinstance(value, list),
lambda: "list") lambda: "list")

View File

@ -28,7 +28,7 @@ unsafe fn recv_value(reader: &mut Read, tag: Tag, data: &mut *mut (),
consume_value!(u64, |ptr| { consume_value!(u64, |ptr| {
*ptr = reader.read_u64()?; Ok(()) *ptr = reader.read_u64()?; Ok(())
}), }),
Tag::String | Tag::Bytes => { Tag::String | Tag::Bytes | Tag::ByteArray => {
consume_value!(CMutSlice<u8>, |ptr| { consume_value!(CMutSlice<u8>, |ptr| {
let length = reader.read_u32()? as usize; let length = reader.read_u32()? as usize;
*ptr = CMutSlice::new(alloc(length)? as *mut u8, length); *ptr = CMutSlice::new(alloc(length)? as *mut u8, length);
@ -108,7 +108,7 @@ unsafe fn send_value(writer: &mut Write, tag: Tag, data: &mut *const ()) -> io::
Tag::String => Tag::String =>
consume_value!(CSlice<u8>, |ptr| consume_value!(CSlice<u8>, |ptr|
writer.write_string(str::from_utf8((*ptr).as_ref()).unwrap())), writer.write_string(str::from_utf8((*ptr).as_ref()).unwrap())),
Tag::Bytes => Tag::Bytes | Tag::ByteArray =>
consume_value!(CSlice<u8>, |ptr| consume_value!(CSlice<u8>, |ptr|
writer.write_bytes((*ptr).as_ref())), writer.write_bytes((*ptr).as_ref())),
Tag::Tuple(it, arity) => { Tag::Tuple(it, arity) => {
@ -206,6 +206,7 @@ mod tag {
Float64, Float64,
String, String,
Bytes, Bytes,
ByteArray,
Tuple(TagIterator<'a>, u8), Tuple(TagIterator<'a>, u8),
List(TagIterator<'a>), List(TagIterator<'a>),
Array(TagIterator<'a>), Array(TagIterator<'a>),
@ -224,6 +225,7 @@ mod tag {
Tag::Float64 => b'f', Tag::Float64 => b'f',
Tag::String => b's', Tag::String => b's',
Tag::Bytes => b'B', Tag::Bytes => b'B',
Tag::ByteArray => b'A',
Tag::Tuple(_, _) => b't', Tag::Tuple(_, _) => b't',
Tag::List(_) => b'l', Tag::List(_) => b'l',
Tag::Array(_) => b'a', Tag::Array(_) => b'a',
@ -242,6 +244,7 @@ mod tag {
Tag::Float64 => 8, Tag::Float64 => 8,
Tag::String => 4, Tag::String => 4,
Tag::Bytes => 4, Tag::Bytes => 4,
Tag::ByteArray => 4,
Tag::Tuple(it, arity) => { Tag::Tuple(it, arity) => {
let mut size = 0; let mut size = 0;
for _ in 0..arity { for _ in 0..arity {
@ -287,6 +290,7 @@ mod tag {
b'f' => Tag::Float64, b'f' => Tag::Float64,
b's' => Tag::String, b's' => Tag::String,
b'B' => Tag::Bytes, b'B' => Tag::Bytes,
b'A' => Tag::ByteArray,
b't' => { b't' => {
let count = self.data[0]; let count = self.data[0];
self.data = &self.data[1..]; self.data = &self.data[1..];
@ -336,6 +340,8 @@ mod tag {
write!(f, "String")?, write!(f, "String")?,
Tag::Bytes => Tag::Bytes =>
write!(f, "Bytes")?, write!(f, "Bytes")?,
Tag::ByteArray =>
write!(f, "ByteArray")?,
Tag::Tuple(it, _) => { Tag::Tuple(it, _) => {
write!(f, "Tuple(")?; write!(f, "Tuple(")?;
it.fmt(f)?; it.fmt(f)?;

View File

@ -41,6 +41,9 @@ class RoundtripTest(ExperimentCase):
def test_bytes(self): def test_bytes(self):
self.assertRoundtrip(b"foo") self.assertRoundtrip(b"foo")
def test_bytearray(self):
self.assertRoundtrip(bytearray(b"foo"))
def test_list(self): def test_list(self):
self.assertRoundtrip([10]) self.assertRoundtrip([10])