test_embedding: integer typing fixes

This commit is contained in:
Sebastien Bourdeauducq 2024-06-24 14:13:30 +08:00
parent 3db8612f12
commit 1d71d8de9a
1 changed files with 20 additions and 20 deletions

View File

@ -44,8 +44,8 @@ class RoundtripTest(ExperimentCase):
self.assertRoundtrip(numpy.False_) self.assertRoundtrip(numpy.False_)
def test_int(self): def test_int(self):
self.assertRoundtrip(numpy.int32(42)) self.assertRoundtrip(int32(42))
self.assertRoundtrip(numpy.int64(42)) self.assertRoundtrip(int64(42))
def test_float(self): def test_float(self):
self.assertRoundtrip(42.0) self.assertRoundtrip(42.0)
@ -66,7 +66,7 @@ class RoundtripTest(ExperimentCase):
self.assertRoundtrip([True, False]) self.assertRoundtrip([True, False])
def test_int64_list(self): def test_int64_list(self):
self.assertRoundtrip([numpy.int64(0), numpy.int64(1)]) self.assertRoundtrip([int64(0), int64(1)])
def test_object(self): def test_object(self):
obj = object() obj = object()
@ -89,12 +89,12 @@ class RoundtripTest(ExperimentCase):
def test_array_1d(self): def test_array_1d(self):
self.assertArrayRoundtrip(numpy.array([True, False])) self.assertArrayRoundtrip(numpy.array([True, False]))
self.assertArrayRoundtrip(numpy.array([1, 2, 3], dtype=numpy.int32)) self.assertArrayRoundtrip(numpy.array([1, 2, 3], dtype=int32))
self.assertArrayRoundtrip(numpy.array([1.0, 2.0, 3.0])) self.assertArrayRoundtrip(numpy.array([1.0, 2.0, 3.0]))
self.assertArrayRoundtrip(numpy.array(["a", "b", "c"])) self.assertArrayRoundtrip(numpy.array(["a", "b", "c"]))
def test_array_2d(self): def test_array_2d(self):
self.assertArrayRoundtrip(numpy.array([[1, 2], [3, 4]], dtype=numpy.int32)) self.assertArrayRoundtrip(numpy.array([[1, 2], [3, 4]], dtype=int32))
self.assertArrayRoundtrip(numpy.array([[1.0, 2.0], [3.0, 4.0]])) self.assertArrayRoundtrip(numpy.array([[1.0, 2.0], [3.0, 4.0]]))
self.assertArrayRoundtrip(numpy.array([["a", "b"], ["c", "d"]])) self.assertArrayRoundtrip(numpy.array([["a", "b"], ["c", "d"]]))
@ -226,7 +226,7 @@ class RPCTypesTest(ExperimentCase):
class _RPCCalls(EnvExperiment): class _RPCCalls(EnvExperiment):
def build(self): def build(self):
self.setattr_device("core") self.setattr_device("core")
self._list_int64 = [numpy.int64(1)] self._list_int64 = [int64(1)]
def args(self, *args) -> int32: def args(self, *args) -> int32:
return len(args) return len(args)
@ -268,7 +268,7 @@ class _RPCCalls(EnvExperiment):
@kernel @kernel
def numpy_things(self): def numpy_things(self):
return (numpy.int32(10), numpy.int64(20)) return (int32(10), int64(20))
@kernel @kernel
def builtin(self): def builtin(self):
@ -297,11 +297,11 @@ class RPCCallsTest(ExperimentCase):
self.assertEqual(exp.kwargs2(), 2) self.assertEqual(exp.kwargs2(), 2)
self.assertEqual(exp.args1kwargs2(), 2) self.assertEqual(exp.args1kwargs2(), 2)
self.assertEqual(exp.numpy_things(), self.assertEqual(exp.numpy_things(),
(numpy.int32(10), numpy.int64(20))) (int32(10), int64(20)))
# Ensure lists of int64s don't decay to variable-length builtin integers. # Ensure lists of int64s don't decay to variable-length builtin integers.
list_int64 = exp.list_int64() list_int64 = exp.list_int64()
self.assertEqual(list_int64, [numpy.int64(1)]) self.assertEqual(list_int64, [int64(1)])
self.assertTrue(isinstance(list_int64[0], numpy.int64)) self.assertTrue(isinstance(list_int64[0], int64))
exp.builtin() exp.builtin()
exp.async_in_try() exp.async_in_try()
@ -393,8 +393,8 @@ class _ListTuple(EnvExperiment):
return 2 return 2
def get_values(self, base_a, base_b, n) -> tuple[list[int32], list[int32]]: def get_values(self, base_a, base_b, n) -> tuple[list[int32], list[int32]]:
return [numpy.int32(base_a + i) for i in range(n)], \ return [int32(base_a + i) for i in range(n)], \
[numpy.int32(base_b + i) for i in range(n)] [int32(base_b + i) for i in range(n)]
class _NestedTupleList(EnvExperiment): class _NestedTupleList(EnvExperiment):
@ -442,8 +442,8 @@ class ListTupleTest(ExperimentCase):
class _ArrayQuoting(EnvExperiment): class _ArrayQuoting(EnvExperiment):
def build(self): def build(self):
self.setattr_device("core") self.setattr_device("core")
self.vec_i32 = numpy.array([0, 1], dtype=numpy.int32) self.vec_i32 = numpy.array([0, 1], dtype=int32)
self.mat_i64 = numpy.array([[0, 1], [2, 3]], dtype=numpy.int64) self.mat_i64 = numpy.array([[0, 1], [2, 3]], dtype=int64)
self.arr_f64 = numpy.array([[[0.0, 1.0], [2.0, 3.0]], self.arr_f64 = numpy.array([[[0.0, 1.0], [2.0, 3.0]],
[[4.0, 5.0], [6.0, 7.0]]]) [[4.0, 5.0], [6.0, 7.0]]])
self.strs = numpy.array(["foo", "bar"]) self.strs = numpy.array(["foo", "bar"])
@ -579,17 +579,17 @@ class NumpyQuotingTest(ExperimentCase):
class _IntBoundary(EnvExperiment): class _IntBoundary(EnvExperiment):
def build(self): def build(self):
self.setattr_device("core") self.setattr_device("core")
self.int32_min = numpy.iinfo(numpy.int32).min self.int32_min = numpy.iinfo(int32).min
self.int32_max = numpy.iinfo(numpy.int32).max self.int32_max = numpy.iinfo(int32).max
self.int64_min = numpy.iinfo(numpy.int64).min self.int64_min = numpy.iinfo(int64).min
self.int64_max = numpy.iinfo(numpy.int64).max self.int64_max = numpy.iinfo(int64).max
@kernel @kernel
def test_int32_bounds(self, min_val: TInt32, max_val: TInt32): def test_int32_bounds(self, min_val: int32, max_val: int32):
return min_val == self.int32_min and max_val == self.int32_max return min_val == self.int32_min and max_val == self.int32_max
@kernel @kernel
def test_int64_bounds(self, min_val: TInt64, max_val: TInt64): def test_int64_bounds(self, min_val: int64, max_val: int64):
return min_val == self.int64_min and max_val == self.int64_max return min_val == self.int64_min and max_val == self.int64_max
@kernel @kernel