coredevice.comm_kernel: Fix unpacking of lists of numpy.int64

test.coredevice.test_embedding: Add tests for list of numpy.int64
pull/1670/head
Peter Drmota 2021-04-19 18:54:00 +02:00 committed by David Nadlinger
parent af4fadcd54
commit 47bf5d36af
2 changed files with 13 additions and 1 deletions

View File

@ -72,7 +72,7 @@ def _receive_list(kernel, embedding_map):
return list(struct.unpack(kernel.endian + "%sl" % length, buffer))
elif tag == "I":
buffer = kernel._read(8 * length)
return list(struct.unpack(kernel.endian + "%sq" % length, buffer))
return list(numpy.ndarray((length, ), kernel.endian + 'i8', buffer))
elif tag == "f":
buffer = kernel._read(8 * length)
return list(struct.unpack(kernel.endian + "%sd" % length, buffer))

View File

@ -64,6 +64,9 @@ class RoundtripTest(ExperimentCase):
def test_bool_list(self):
self.assertRoundtrip([True, False])
def test_int64_list(self):
self.assertRoundtrip([numpy.int64(0), numpy.int64(1)])
def test_object(self):
obj = object()
self.assertRoundtrip(obj)
@ -216,6 +219,7 @@ class RPCTypesTest(ExperimentCase):
class _RPCCalls(EnvExperiment):
def build(self):
self.setattr_device("core")
self._list_int64 = [numpy.int64(1)]
def args(self, *args) -> TInt32:
return len(args)
@ -251,6 +255,10 @@ class _RPCCalls(EnvExperiment):
def args1kwargs2(self):
return self.kwargs("X", a="A", b=1)
@kernel
def list_int64(self):
return self._list_int64
@kernel
def numpy_things(self):
return (numpy.int32(10), numpy.int64(20), numpy.array([42,]))
@ -295,6 +303,10 @@ class RPCCallsTest(ExperimentCase):
self.assertEqual(exp.args1kwargs2(), 2)
self.assertEqual(exp.numpy_things(),
(numpy.int32(10), numpy.int64(20), numpy.array([42,])))
# Ensure lists of int64s don't decay to variable-length builtin integers.
list_int64 = exp.list_int64()
self.assertEqual(list_int64, [numpy.int64(1)])
self.assertTrue(isinstance(list_int64[0], numpy.int64))
self.assertTrue((exp.numpy_full() == numpy.full(10, 20)).all())
self.assertTrue((exp.numpy_full_matrix() == numpy.full((3, 2), 13)).all())
self.assertTrue(numpy.isnan(exp.numpy_nan()).all())