compiler/firmware: RPCs for ndarrays

pull/1508/head
David Nadlinger 2020-08-08 22:34:46 +01:00
parent 5472e830f6
commit 8783ba2072
4 changed files with 97 additions and 14 deletions

View File

@ -64,7 +64,8 @@ def rpc_tag(typ, error_handler):
elif builtins.is_list(typ):
return b"l" + rpc_tag(builtins.get_iterable_elt(typ), error_handler)
elif builtins.is_array(typ):
return b"a" + rpc_tag(builtins.get_iterable_elt(typ), error_handler)
num_dims = typ["num_dims"].value
return b"a" + bytes([num_dims]) + rpc_tag(typ["elt"], error_handler)
elif builtins.is_range(typ):
return b"r" + rpc_tag(builtins.get_iterable_elt(typ), error_handler)
elif is_keyword(typ):

View File

@ -276,8 +276,10 @@ class CommKernel:
length = self._read_int32()
return [self._receive_rpc_value(embedding_map) for _ in range(length)]
elif tag == "a":
length = self._read_int32()
return numpy.array([self._receive_rpc_value(embedding_map) for _ in range(length)])
num_dims = self._read_int8()
shape = tuple(self._read_int32() for _ in range(num_dims))
elems = [self._receive_rpc_value(embedding_map) for _ in range(numpy.prod(shape))]
return numpy.array(elems).reshape(shape)
elif tag == "r":
start = self._receive_rpc_value(embedding_map)
stop = self._receive_rpc_value(embedding_map)
@ -380,6 +382,18 @@ class CommKernel:
tags_copy = bytearray(tags)
self._send_rpc_value(tags_copy, elt, root, function)
self._skip_rpc_value(tags)
elif tag == "a":
check(isinstance(value, numpy.ndarray),
lambda: "numpy.ndarray")
num_dims = tags.pop(0)
check(num_dims == len(value.shape),
lambda: "{}-dimensional numpy.ndarray".format(num_dims))
for s in value.shape:
self._write_int32(s)
for elt in value.reshape((-1,), order="C"):
tags_copy = bytearray(tags)
self._send_rpc_value(tags_copy, elt, root, function)
self._skip_rpc_value(tags)
elif tag == "r":
check(isinstance(value, range),
lambda: "range")

View File

@ -48,7 +48,7 @@ unsafe fn recv_value<R, E>(reader: &mut R, tag: Tag, data: &mut *mut (),
}
Ok(())
}
Tag::List(it) | Tag::Array(it) => {
Tag::List(it) => {
#[repr(C)]
struct List { elements: *mut (), length: u32 };
consume_value!(List, |ptr| {
@ -64,6 +64,25 @@ unsafe fn recv_value<R, E>(reader: &mut R, tag: Tag, data: &mut *mut (),
Ok(())
})
}
Tag::Array(it, num_dims) => {
consume_value!(*mut (), |buffer| {
let mut total_len: u32 = 1;
for _ in 0..num_dims {
let len = reader.read_u32()?;
total_len *= len;
consume_value!(u32, |ptr| *ptr = len )
}
let elt_tag = it.clone().next().expect("truncated tag");
*buffer = alloc(elt_tag.size() * total_len as usize)?;
let mut data = *buffer;
for _ in 0..total_len {
recv_value(reader, elt_tag, &mut data, alloc)?
}
Ok(())
})
}
Tag::Range(it) => {
let tag = it.clone().next().expect("truncated tag");
recv_value(reader, tag, data, alloc)?;
@ -132,7 +151,7 @@ unsafe fn send_value<W>(writer: &mut W, tag: Tag, data: &mut *const ())
}
Ok(())
}
Tag::List(it) | Tag::Array(it) => {
Tag::List(it) => {
#[repr(C)]
struct List { elements: *const (), length: u32 };
consume_value!(List, |ptr| {
@ -145,6 +164,25 @@ unsafe fn send_value<W>(writer: &mut W, tag: Tag, data: &mut *const ())
Ok(())
})
}
Tag::Array(it, num_dims) => {
writer.write_u8(num_dims)?;
consume_value!(*const(), |buffer| {
let elt_tag = it.clone().next().expect("truncated tag");
let mut total_len = 1;
for _ in 0..num_dims {
consume_value!(u32, |len| {
writer.write_u32(*len)?;
total_len *= *len;
})
}
let mut data = *buffer;
for _ in 0..total_len as usize {
send_value(writer, elt_tag, &mut data)?;
}
Ok(())
})
}
Tag::Range(it) => {
let tag = it.clone().next().expect("truncated tag");
send_value(writer, tag, data)?;
@ -226,7 +264,7 @@ mod tag {
ByteArray,
Tuple(TagIterator<'a>, u8),
List(TagIterator<'a>),
Array(TagIterator<'a>),
Array(TagIterator<'a>, u8),
Range(TagIterator<'a>),
Keyword(TagIterator<'a>),
Object
@ -245,7 +283,7 @@ mod tag {
Tag::ByteArray => b'A',
Tag::Tuple(_, _) => b't',
Tag::List(_) => b'l',
Tag::Array(_) => b'a',
Tag::Array(_, _) => b'a',
Tag::Range(_) => b'r',
Tag::Keyword(_) => b'k',
Tag::Object => b'O',
@ -272,7 +310,7 @@ mod tag {
size
}
Tag::List(_) => 8,
Tag::Array(_) => 8,
Tag::Array(_, num_dims) => 4 * (1 + num_dims as usize),
Tag::Range(it) => {
let tag = it.clone().next().expect("truncated tag");
tag.size() * 3
@ -315,7 +353,11 @@ mod tag {
Tag::Tuple(self.sub(count), count)
}
b'l' => Tag::List(self.sub(1)),
b'a' => Tag::Array(self.sub(1)),
b'a' => {
let count = self.data[0];
self.data = &self.data[1..];
Tag::Array(self.sub(1), count)
}
b'r' => Tag::Range(self.sub(1)),
b'k' => Tag::Keyword(self.sub(1)),
b'O' => Tag::Object,
@ -370,10 +412,10 @@ mod tag {
it.fmt(f)?;
write!(f, ")")?;
}
Tag::Array(it) => {
Tag::Array(it, num_dims) => {
write!(f, "Array(")?;
it.fmt(f)?;
write!(f, ")")?;
write!(f, ", {})", num_dims)?;
}
Tag::Range(it) => {
write!(f, "Range(")?;

View File

@ -22,6 +22,12 @@ class RoundtripTest(ExperimentCase):
self.assertEqual(obj, objcopy)
exp.roundtrip(obj, callback)
def assertArrayRoundtrip(self, obj):
exp = self.create(_Roundtrip)
def callback(objcopy):
numpy.testing.assert_array_equal(obj, objcopy)
exp.roundtrip(obj, callback)
def test_None(self):
self.assertRoundtrip(None)
@ -48,9 +54,6 @@ class RoundtripTest(ExperimentCase):
def test_list(self):
self.assertRoundtrip([10])
def test_array(self):
self.assertRoundtrip(numpy.array([10]))
def test_object(self):
obj = object()
self.assertRoundtrip(obj)
@ -64,6 +67,19 @@ class RoundtripTest(ExperimentCase):
def test_list_mixed_tuple(self):
self.assertRoundtrip([(0x12345678, [("foo", [0.0, 1.0], [0, 1])])])
def test_array_1d(self):
self.assertArrayRoundtrip(numpy.array([1, 2, 3], dtype=numpy.int32))
self.assertArrayRoundtrip(numpy.array([1.0, 2.0, 3.0]))
self.assertArrayRoundtrip(numpy.array(["a", "b", "c"]))
def test_array_2d(self):
self.assertArrayRoundtrip(numpy.array([[1, 2], [3, 4]], dtype=numpy.int32))
self.assertArrayRoundtrip(numpy.array([[1.0, 2.0], [3.0, 4.0]]))
self.assertArrayRoundtrip(numpy.array([["a", "b"], ["c", "d"]]))
def test_array_jagged(self):
self.assertArrayRoundtrip(numpy.array([[1, 2], [3]]))
class _DefaultArg(EnvExperiment):
def build(self):
@ -117,6 +133,12 @@ class _RPCTypes(EnvExperiment):
def return_range(self) -> TRange32:
return range(10)
def return_array(self) -> TArray(TInt32):
return numpy.array([1, 2])
def return_matrix(self) -> TArray(TInt32, 2):
return numpy.array([[1, 2], [3, 4]])
def return_mismatch(self):
return b"foo"
@ -132,6 +154,8 @@ class _RPCTypes(EnvExperiment):
core_log(self.return_tuple())
core_log(self.return_list())
core_log(self.return_range())
core_log(self.return_array())
core_log(self.return_matrix())
def accept(self, value):
pass
@ -150,6 +174,8 @@ class _RPCTypes(EnvExperiment):
self.accept((2, 3))
self.accept([1, 2])
self.accept(range(10))
self.accept(numpy.array([1, 2]))
self.accept(numpy.array([[1, 2], [3, 4]]))
self.accept(self)
@kernel