test_embedding: integer typing fixes

This commit is contained in:
Sebastien Bourdeauducq 2024-06-24 14:13:30 +08:00
parent 3db8612f12
commit 1d71d8de9a
1 changed files with 20 additions and 20 deletions

View File

@ -44,8 +44,8 @@ class RoundtripTest(ExperimentCase):
self.assertRoundtrip(numpy.False_)
def test_int(self):
self.assertRoundtrip(numpy.int32(42))
self.assertRoundtrip(numpy.int64(42))
self.assertRoundtrip(int32(42))
self.assertRoundtrip(int64(42))
def test_float(self):
self.assertRoundtrip(42.0)
@ -66,7 +66,7 @@ class RoundtripTest(ExperimentCase):
self.assertRoundtrip([True, False])
def test_int64_list(self):
self.assertRoundtrip([numpy.int64(0), numpy.int64(1)])
self.assertRoundtrip([int64(0), int64(1)])
def test_object(self):
obj = object()
@ -89,12 +89,12 @@ class RoundtripTest(ExperimentCase):
def test_array_1d(self):
self.assertArrayRoundtrip(numpy.array([True, False]))
self.assertArrayRoundtrip(numpy.array([1, 2, 3], dtype=numpy.int32))
self.assertArrayRoundtrip(numpy.array([1, 2, 3], dtype=int32))
self.assertArrayRoundtrip(numpy.array([1.0, 2.0, 3.0]))
self.assertArrayRoundtrip(numpy.array(["a", "b", "c"]))
def test_array_2d(self):
self.assertArrayRoundtrip(numpy.array([[1, 2], [3, 4]], dtype=numpy.int32))
self.assertArrayRoundtrip(numpy.array([[1, 2], [3, 4]], dtype=int32))
self.assertArrayRoundtrip(numpy.array([[1.0, 2.0], [3.0, 4.0]]))
self.assertArrayRoundtrip(numpy.array([["a", "b"], ["c", "d"]]))
@ -226,7 +226,7 @@ class RPCTypesTest(ExperimentCase):
class _RPCCalls(EnvExperiment):
def build(self):
self.setattr_device("core")
self._list_int64 = [numpy.int64(1)]
self._list_int64 = [int64(1)]
def args(self, *args) -> int32:
return len(args)
@ -268,7 +268,7 @@ class _RPCCalls(EnvExperiment):
@kernel
def numpy_things(self):
return (numpy.int32(10), numpy.int64(20))
return (int32(10), int64(20))
@kernel
def builtin(self):
@ -297,11 +297,11 @@ class RPCCallsTest(ExperimentCase):
self.assertEqual(exp.kwargs2(), 2)
self.assertEqual(exp.args1kwargs2(), 2)
self.assertEqual(exp.numpy_things(),
(numpy.int32(10), numpy.int64(20)))
(int32(10), int64(20)))
# Ensure lists of int64s don't decay to variable-length builtin integers.
list_int64 = exp.list_int64()
self.assertEqual(list_int64, [numpy.int64(1)])
self.assertTrue(isinstance(list_int64[0], numpy.int64))
self.assertEqual(list_int64, [int64(1)])
self.assertTrue(isinstance(list_int64[0], int64))
exp.builtin()
exp.async_in_try()
@ -393,8 +393,8 @@ class _ListTuple(EnvExperiment):
return 2
def get_values(self, base_a, base_b, n) -> tuple[list[int32], list[int32]]:
return [numpy.int32(base_a + i) for i in range(n)], \
[numpy.int32(base_b + i) for i in range(n)]
return [int32(base_a + i) for i in range(n)], \
[int32(base_b + i) for i in range(n)]
class _NestedTupleList(EnvExperiment):
@ -442,8 +442,8 @@ class ListTupleTest(ExperimentCase):
class _ArrayQuoting(EnvExperiment):
def build(self):
self.setattr_device("core")
self.vec_i32 = numpy.array([0, 1], dtype=numpy.int32)
self.mat_i64 = numpy.array([[0, 1], [2, 3]], dtype=numpy.int64)
self.vec_i32 = numpy.array([0, 1], dtype=int32)
self.mat_i64 = numpy.array([[0, 1], [2, 3]], dtype=int64)
self.arr_f64 = numpy.array([[[0.0, 1.0], [2.0, 3.0]],
[[4.0, 5.0], [6.0, 7.0]]])
self.strs = numpy.array(["foo", "bar"])
@ -579,17 +579,17 @@ class NumpyQuotingTest(ExperimentCase):
class _IntBoundary(EnvExperiment):
def build(self):
self.setattr_device("core")
self.int32_min = numpy.iinfo(numpy.int32).min
self.int32_max = numpy.iinfo(numpy.int32).max
self.int64_min = numpy.iinfo(numpy.int64).min
self.int64_max = numpy.iinfo(numpy.int64).max
self.int32_min = numpy.iinfo(int32).min
self.int32_max = numpy.iinfo(int32).max
self.int64_min = numpy.iinfo(int64).min
self.int64_max = numpy.iinfo(int64).max
@kernel
def test_int32_bounds(self, min_val: TInt32, max_val: TInt32):
def test_int32_bounds(self, min_val: int32, max_val: int32):
return min_val == self.int32_min and max_val == self.int32_max
@kernel
def test_int64_bounds(self, min_val: TInt64, max_val: TInt64):
def test_int64_bounds(self, min_val: int64, max_val: int64):
return min_val == self.int64_min and max_val == self.int64_max
@kernel