From 8783ba207214d52c67fa9c2236250b20319befdb Mon Sep 17 00:00:00 2001 From: David Nadlinger Date: Sat, 8 Aug 2020 22:34:46 +0100 Subject: [PATCH] compiler/firmware: RPCs for ndarrays --- artiq/compiler/ir.py | 3 +- artiq/coredevice/comm_kernel.py | 18 ++++++- artiq/firmware/libproto_artiq/rpc_proto.rs | 58 +++++++++++++++++++--- artiq/test/coredevice/test_embedding.py | 32 ++++++++++-- 4 files changed, 97 insertions(+), 14 deletions(-) diff --git a/artiq/compiler/ir.py b/artiq/compiler/ir.py index 941b04d36..3f984606f 100644 --- a/artiq/compiler/ir.py +++ b/artiq/compiler/ir.py @@ -64,7 +64,8 @@ def rpc_tag(typ, error_handler): elif builtins.is_list(typ): return b"l" + rpc_tag(builtins.get_iterable_elt(typ), error_handler) elif builtins.is_array(typ): - return b"a" + rpc_tag(builtins.get_iterable_elt(typ), error_handler) + num_dims = typ["num_dims"].value + return b"a" + bytes([num_dims]) + rpc_tag(typ["elt"], error_handler) elif builtins.is_range(typ): return b"r" + rpc_tag(builtins.get_iterable_elt(typ), error_handler) elif is_keyword(typ): diff --git a/artiq/coredevice/comm_kernel.py b/artiq/coredevice/comm_kernel.py index 838bad042..41bddd553 100644 --- a/artiq/coredevice/comm_kernel.py +++ b/artiq/coredevice/comm_kernel.py @@ -276,8 +276,10 @@ class CommKernel: length = self._read_int32() return [self._receive_rpc_value(embedding_map) for _ in range(length)] elif tag == "a": - length = self._read_int32() - return numpy.array([self._receive_rpc_value(embedding_map) for _ in range(length)]) + num_dims = self._read_int8() + shape = tuple(self._read_int32() for _ in range(num_dims)) + elems = [self._receive_rpc_value(embedding_map) for _ in range(numpy.prod(shape))] + return numpy.array(elems).reshape(shape) elif tag == "r": start = self._receive_rpc_value(embedding_map) stop = self._receive_rpc_value(embedding_map) @@ -380,6 +382,18 @@ class CommKernel: tags_copy = bytearray(tags) self._send_rpc_value(tags_copy, elt, root, function) self._skip_rpc_value(tags) + elif tag == "a": + check(isinstance(value, numpy.ndarray), + lambda: "numpy.ndarray") + num_dims = tags.pop(0) + check(num_dims == len(value.shape), + lambda: "{}-dimensional numpy.ndarray".format(num_dims)) + for s in value.shape: + self._write_int32(s) + for elt in value.reshape((-1,), order="C"): + tags_copy = bytearray(tags) + self._send_rpc_value(tags_copy, elt, root, function) + self._skip_rpc_value(tags) elif tag == "r": check(isinstance(value, range), lambda: "range") diff --git a/artiq/firmware/libproto_artiq/rpc_proto.rs b/artiq/firmware/libproto_artiq/rpc_proto.rs index 750f3f467..b35e6b905 100644 --- a/artiq/firmware/libproto_artiq/rpc_proto.rs +++ b/artiq/firmware/libproto_artiq/rpc_proto.rs @@ -48,7 +48,7 @@ unsafe fn recv_value(reader: &mut R, tag: Tag, data: &mut *mut (), } Ok(()) } - Tag::List(it) | Tag::Array(it) => { + Tag::List(it) => { #[repr(C)] struct List { elements: *mut (), length: u32 }; consume_value!(List, |ptr| { @@ -64,6 +64,25 @@ unsafe fn recv_value(reader: &mut R, tag: Tag, data: &mut *mut (), Ok(()) }) } + Tag::Array(it, num_dims) => { + consume_value!(*mut (), |buffer| { + let mut total_len: u32 = 1; + for _ in 0..num_dims { + let len = reader.read_u32()?; + total_len *= len; + consume_value!(u32, |ptr| *ptr = len ) + } + + let elt_tag = it.clone().next().expect("truncated tag"); + *buffer = alloc(elt_tag.size() * total_len as usize)?; + + let mut data = *buffer; + for _ in 0..total_len { + recv_value(reader, elt_tag, &mut data, alloc)? + } + Ok(()) + }) + } Tag::Range(it) => { let tag = it.clone().next().expect("truncated tag"); recv_value(reader, tag, data, alloc)?; @@ -132,7 +151,7 @@ unsafe fn send_value(writer: &mut W, tag: Tag, data: &mut *const ()) } Ok(()) } - Tag::List(it) | Tag::Array(it) => { + Tag::List(it) => { #[repr(C)] struct List { elements: *const (), length: u32 }; consume_value!(List, |ptr| { @@ -145,6 +164,25 @@ unsafe fn send_value(writer: &mut W, tag: Tag, data: &mut *const ()) Ok(()) }) } + Tag::Array(it, num_dims) => { + writer.write_u8(num_dims)?; + consume_value!(*const(), |buffer| { + let elt_tag = it.clone().next().expect("truncated tag"); + + let mut total_len = 1; + for _ in 0..num_dims { + consume_value!(u32, |len| { + writer.write_u32(*len)?; + total_len *= *len; + }) + } + let mut data = *buffer; + for _ in 0..total_len as usize { + send_value(writer, elt_tag, &mut data)?; + } + Ok(()) + }) + } Tag::Range(it) => { let tag = it.clone().next().expect("truncated tag"); send_value(writer, tag, data)?; @@ -226,7 +264,7 @@ mod tag { ByteArray, Tuple(TagIterator<'a>, u8), List(TagIterator<'a>), - Array(TagIterator<'a>), + Array(TagIterator<'a>, u8), Range(TagIterator<'a>), Keyword(TagIterator<'a>), Object @@ -245,7 +283,7 @@ mod tag { Tag::ByteArray => b'A', Tag::Tuple(_, _) => b't', Tag::List(_) => b'l', - Tag::Array(_) => b'a', + Tag::Array(_, _) => b'a', Tag::Range(_) => b'r', Tag::Keyword(_) => b'k', Tag::Object => b'O', @@ -272,7 +310,7 @@ mod tag { size } Tag::List(_) => 8, - Tag::Array(_) => 8, + Tag::Array(_, num_dims) => 4 * (1 + num_dims as usize), Tag::Range(it) => { let tag = it.clone().next().expect("truncated tag"); tag.size() * 3 @@ -315,7 +353,11 @@ mod tag { Tag::Tuple(self.sub(count), count) } b'l' => Tag::List(self.sub(1)), - b'a' => Tag::Array(self.sub(1)), + b'a' => { + let count = self.data[0]; + self.data = &self.data[1..]; + Tag::Array(self.sub(1), count) + } b'r' => Tag::Range(self.sub(1)), b'k' => Tag::Keyword(self.sub(1)), b'O' => Tag::Object, @@ -370,10 +412,10 @@ mod tag { it.fmt(f)?; write!(f, ")")?; } - Tag::Array(it) => { + Tag::Array(it, num_dims) => { write!(f, "Array(")?; it.fmt(f)?; - write!(f, ")")?; + write!(f, ", {})", num_dims)?; } Tag::Range(it) => { write!(f, "Range(")?; diff --git a/artiq/test/coredevice/test_embedding.py b/artiq/test/coredevice/test_embedding.py index 38e1a3d07..fbd47a1a4 100644 --- a/artiq/test/coredevice/test_embedding.py +++ b/artiq/test/coredevice/test_embedding.py @@ -22,6 +22,12 @@ class RoundtripTest(ExperimentCase): self.assertEqual(obj, objcopy) exp.roundtrip(obj, callback) + def assertArrayRoundtrip(self, obj): + exp = self.create(_Roundtrip) + def callback(objcopy): + numpy.testing.assert_array_equal(obj, objcopy) + exp.roundtrip(obj, callback) + def test_None(self): self.assertRoundtrip(None) @@ -48,9 +54,6 @@ class RoundtripTest(ExperimentCase): def test_list(self): self.assertRoundtrip([10]) - def test_array(self): - self.assertRoundtrip(numpy.array([10])) - def test_object(self): obj = object() self.assertRoundtrip(obj) @@ -64,6 +67,19 @@ class RoundtripTest(ExperimentCase): def test_list_mixed_tuple(self): self.assertRoundtrip([(0x12345678, [("foo", [0.0, 1.0], [0, 1])])]) + def test_array_1d(self): + self.assertArrayRoundtrip(numpy.array([1, 2, 3], dtype=numpy.int32)) + self.assertArrayRoundtrip(numpy.array([1.0, 2.0, 3.0])) + self.assertArrayRoundtrip(numpy.array(["a", "b", "c"])) + + def test_array_2d(self): + self.assertArrayRoundtrip(numpy.array([[1, 2], [3, 4]], dtype=numpy.int32)) + self.assertArrayRoundtrip(numpy.array([[1.0, 2.0], [3.0, 4.0]])) + self.assertArrayRoundtrip(numpy.array([["a", "b"], ["c", "d"]])) + + def test_array_jagged(self): + self.assertArrayRoundtrip(numpy.array([[1, 2], [3]])) + class _DefaultArg(EnvExperiment): def build(self): @@ -117,6 +133,12 @@ class _RPCTypes(EnvExperiment): def return_range(self) -> TRange32: return range(10) + def return_array(self) -> TArray(TInt32): + return numpy.array([1, 2]) + + def return_matrix(self) -> TArray(TInt32, 2): + return numpy.array([[1, 2], [3, 4]]) + def return_mismatch(self): return b"foo" @@ -132,6 +154,8 @@ class _RPCTypes(EnvExperiment): core_log(self.return_tuple()) core_log(self.return_list()) core_log(self.return_range()) + core_log(self.return_array()) + core_log(self.return_matrix()) def accept(self, value): pass @@ -150,6 +174,8 @@ class _RPCTypes(EnvExperiment): self.accept((2, 3)) self.accept([1, 2]) self.accept(range(10)) + self.accept(numpy.array([1, 2])) + self.accept(numpy.array([[1, 2], [3, 4]])) self.accept(self) @kernel