mirror of
https://github.com/m-labs/artiq.git
synced 2025-01-25 01:48:12 +08:00
compiler/firmware: RPCs for ndarrays
This commit is contained in:
parent
5472e830f6
commit
8783ba2072
@ -64,7 +64,8 @@ def rpc_tag(typ, error_handler):
|
||||
elif builtins.is_list(typ):
|
||||
return b"l" + rpc_tag(builtins.get_iterable_elt(typ), error_handler)
|
||||
elif builtins.is_array(typ):
|
||||
return b"a" + rpc_tag(builtins.get_iterable_elt(typ), error_handler)
|
||||
num_dims = typ["num_dims"].value
|
||||
return b"a" + bytes([num_dims]) + rpc_tag(typ["elt"], error_handler)
|
||||
elif builtins.is_range(typ):
|
||||
return b"r" + rpc_tag(builtins.get_iterable_elt(typ), error_handler)
|
||||
elif is_keyword(typ):
|
||||
|
@ -276,8 +276,10 @@ class CommKernel:
|
||||
length = self._read_int32()
|
||||
return [self._receive_rpc_value(embedding_map) for _ in range(length)]
|
||||
elif tag == "a":
|
||||
length = self._read_int32()
|
||||
return numpy.array([self._receive_rpc_value(embedding_map) for _ in range(length)])
|
||||
num_dims = self._read_int8()
|
||||
shape = tuple(self._read_int32() for _ in range(num_dims))
|
||||
elems = [self._receive_rpc_value(embedding_map) for _ in range(numpy.prod(shape))]
|
||||
return numpy.array(elems).reshape(shape)
|
||||
elif tag == "r":
|
||||
start = self._receive_rpc_value(embedding_map)
|
||||
stop = self._receive_rpc_value(embedding_map)
|
||||
@ -380,6 +382,18 @@ class CommKernel:
|
||||
tags_copy = bytearray(tags)
|
||||
self._send_rpc_value(tags_copy, elt, root, function)
|
||||
self._skip_rpc_value(tags)
|
||||
elif tag == "a":
|
||||
check(isinstance(value, numpy.ndarray),
|
||||
lambda: "numpy.ndarray")
|
||||
num_dims = tags.pop(0)
|
||||
check(num_dims == len(value.shape),
|
||||
lambda: "{}-dimensional numpy.ndarray".format(num_dims))
|
||||
for s in value.shape:
|
||||
self._write_int32(s)
|
||||
for elt in value.reshape((-1,), order="C"):
|
||||
tags_copy = bytearray(tags)
|
||||
self._send_rpc_value(tags_copy, elt, root, function)
|
||||
self._skip_rpc_value(tags)
|
||||
elif tag == "r":
|
||||
check(isinstance(value, range),
|
||||
lambda: "range")
|
||||
|
@ -48,7 +48,7 @@ unsafe fn recv_value<R, E>(reader: &mut R, tag: Tag, data: &mut *mut (),
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
Tag::List(it) | Tag::Array(it) => {
|
||||
Tag::List(it) => {
|
||||
#[repr(C)]
|
||||
struct List { elements: *mut (), length: u32 };
|
||||
consume_value!(List, |ptr| {
|
||||
@ -64,6 +64,25 @@ unsafe fn recv_value<R, E>(reader: &mut R, tag: Tag, data: &mut *mut (),
|
||||
Ok(())
|
||||
})
|
||||
}
|
||||
Tag::Array(it, num_dims) => {
|
||||
consume_value!(*mut (), |buffer| {
|
||||
let mut total_len: u32 = 1;
|
||||
for _ in 0..num_dims {
|
||||
let len = reader.read_u32()?;
|
||||
total_len *= len;
|
||||
consume_value!(u32, |ptr| *ptr = len )
|
||||
}
|
||||
|
||||
let elt_tag = it.clone().next().expect("truncated tag");
|
||||
*buffer = alloc(elt_tag.size() * total_len as usize)?;
|
||||
|
||||
let mut data = *buffer;
|
||||
for _ in 0..total_len {
|
||||
recv_value(reader, elt_tag, &mut data, alloc)?
|
||||
}
|
||||
Ok(())
|
||||
})
|
||||
}
|
||||
Tag::Range(it) => {
|
||||
let tag = it.clone().next().expect("truncated tag");
|
||||
recv_value(reader, tag, data, alloc)?;
|
||||
@ -132,7 +151,7 @@ unsafe fn send_value<W>(writer: &mut W, tag: Tag, data: &mut *const ())
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
Tag::List(it) | Tag::Array(it) => {
|
||||
Tag::List(it) => {
|
||||
#[repr(C)]
|
||||
struct List { elements: *const (), length: u32 };
|
||||
consume_value!(List, |ptr| {
|
||||
@ -145,6 +164,25 @@ unsafe fn send_value<W>(writer: &mut W, tag: Tag, data: &mut *const ())
|
||||
Ok(())
|
||||
})
|
||||
}
|
||||
Tag::Array(it, num_dims) => {
|
||||
writer.write_u8(num_dims)?;
|
||||
consume_value!(*const(), |buffer| {
|
||||
let elt_tag = it.clone().next().expect("truncated tag");
|
||||
|
||||
let mut total_len = 1;
|
||||
for _ in 0..num_dims {
|
||||
consume_value!(u32, |len| {
|
||||
writer.write_u32(*len)?;
|
||||
total_len *= *len;
|
||||
})
|
||||
}
|
||||
let mut data = *buffer;
|
||||
for _ in 0..total_len as usize {
|
||||
send_value(writer, elt_tag, &mut data)?;
|
||||
}
|
||||
Ok(())
|
||||
})
|
||||
}
|
||||
Tag::Range(it) => {
|
||||
let tag = it.clone().next().expect("truncated tag");
|
||||
send_value(writer, tag, data)?;
|
||||
@ -226,7 +264,7 @@ mod tag {
|
||||
ByteArray,
|
||||
Tuple(TagIterator<'a>, u8),
|
||||
List(TagIterator<'a>),
|
||||
Array(TagIterator<'a>),
|
||||
Array(TagIterator<'a>, u8),
|
||||
Range(TagIterator<'a>),
|
||||
Keyword(TagIterator<'a>),
|
||||
Object
|
||||
@ -245,7 +283,7 @@ mod tag {
|
||||
Tag::ByteArray => b'A',
|
||||
Tag::Tuple(_, _) => b't',
|
||||
Tag::List(_) => b'l',
|
||||
Tag::Array(_) => b'a',
|
||||
Tag::Array(_, _) => b'a',
|
||||
Tag::Range(_) => b'r',
|
||||
Tag::Keyword(_) => b'k',
|
||||
Tag::Object => b'O',
|
||||
@ -272,7 +310,7 @@ mod tag {
|
||||
size
|
||||
}
|
||||
Tag::List(_) => 8,
|
||||
Tag::Array(_) => 8,
|
||||
Tag::Array(_, num_dims) => 4 * (1 + num_dims as usize),
|
||||
Tag::Range(it) => {
|
||||
let tag = it.clone().next().expect("truncated tag");
|
||||
tag.size() * 3
|
||||
@ -315,7 +353,11 @@ mod tag {
|
||||
Tag::Tuple(self.sub(count), count)
|
||||
}
|
||||
b'l' => Tag::List(self.sub(1)),
|
||||
b'a' => Tag::Array(self.sub(1)),
|
||||
b'a' => {
|
||||
let count = self.data[0];
|
||||
self.data = &self.data[1..];
|
||||
Tag::Array(self.sub(1), count)
|
||||
}
|
||||
b'r' => Tag::Range(self.sub(1)),
|
||||
b'k' => Tag::Keyword(self.sub(1)),
|
||||
b'O' => Tag::Object,
|
||||
@ -370,10 +412,10 @@ mod tag {
|
||||
it.fmt(f)?;
|
||||
write!(f, ")")?;
|
||||
}
|
||||
Tag::Array(it) => {
|
||||
Tag::Array(it, num_dims) => {
|
||||
write!(f, "Array(")?;
|
||||
it.fmt(f)?;
|
||||
write!(f, ")")?;
|
||||
write!(f, ", {})", num_dims)?;
|
||||
}
|
||||
Tag::Range(it) => {
|
||||
write!(f, "Range(")?;
|
||||
|
@ -22,6 +22,12 @@ class RoundtripTest(ExperimentCase):
|
||||
self.assertEqual(obj, objcopy)
|
||||
exp.roundtrip(obj, callback)
|
||||
|
||||
def assertArrayRoundtrip(self, obj):
|
||||
exp = self.create(_Roundtrip)
|
||||
def callback(objcopy):
|
||||
numpy.testing.assert_array_equal(obj, objcopy)
|
||||
exp.roundtrip(obj, callback)
|
||||
|
||||
def test_None(self):
|
||||
self.assertRoundtrip(None)
|
||||
|
||||
@ -48,9 +54,6 @@ class RoundtripTest(ExperimentCase):
|
||||
def test_list(self):
|
||||
self.assertRoundtrip([10])
|
||||
|
||||
def test_array(self):
|
||||
self.assertRoundtrip(numpy.array([10]))
|
||||
|
||||
def test_object(self):
|
||||
obj = object()
|
||||
self.assertRoundtrip(obj)
|
||||
@ -64,6 +67,19 @@ class RoundtripTest(ExperimentCase):
|
||||
def test_list_mixed_tuple(self):
|
||||
self.assertRoundtrip([(0x12345678, [("foo", [0.0, 1.0], [0, 1])])])
|
||||
|
||||
def test_array_1d(self):
|
||||
self.assertArrayRoundtrip(numpy.array([1, 2, 3], dtype=numpy.int32))
|
||||
self.assertArrayRoundtrip(numpy.array([1.0, 2.0, 3.0]))
|
||||
self.assertArrayRoundtrip(numpy.array(["a", "b", "c"]))
|
||||
|
||||
def test_array_2d(self):
|
||||
self.assertArrayRoundtrip(numpy.array([[1, 2], [3, 4]], dtype=numpy.int32))
|
||||
self.assertArrayRoundtrip(numpy.array([[1.0, 2.0], [3.0, 4.0]]))
|
||||
self.assertArrayRoundtrip(numpy.array([["a", "b"], ["c", "d"]]))
|
||||
|
||||
def test_array_jagged(self):
|
||||
self.assertArrayRoundtrip(numpy.array([[1, 2], [3]]))
|
||||
|
||||
|
||||
class _DefaultArg(EnvExperiment):
|
||||
def build(self):
|
||||
@ -117,6 +133,12 @@ class _RPCTypes(EnvExperiment):
|
||||
def return_range(self) -> TRange32:
|
||||
return range(10)
|
||||
|
||||
def return_array(self) -> TArray(TInt32):
|
||||
return numpy.array([1, 2])
|
||||
|
||||
def return_matrix(self) -> TArray(TInt32, 2):
|
||||
return numpy.array([[1, 2], [3, 4]])
|
||||
|
||||
def return_mismatch(self):
|
||||
return b"foo"
|
||||
|
||||
@ -132,6 +154,8 @@ class _RPCTypes(EnvExperiment):
|
||||
core_log(self.return_tuple())
|
||||
core_log(self.return_list())
|
||||
core_log(self.return_range())
|
||||
core_log(self.return_array())
|
||||
core_log(self.return_matrix())
|
||||
|
||||
def accept(self, value):
|
||||
pass
|
||||
@ -150,6 +174,8 @@ class _RPCTypes(EnvExperiment):
|
||||
self.accept((2, 3))
|
||||
self.accept([1, 2])
|
||||
self.accept(range(10))
|
||||
self.accept(numpy.array([1, 2]))
|
||||
self.accept(numpy.array([[1, 2], [3, 4]]))
|
||||
self.accept(self)
|
||||
|
||||
@kernel
|
||||
|
Loading…
Reference in New Issue
Block a user