test_embedding: port imports and type annotations to NAC3

This commit is contained in:
Sebastien Bourdeauducq 2022-02-26 18:47:59 +08:00
parent 70531ae1e2
commit 9a05907b7a
1 changed files with 24 additions and 23 deletions

View File

@ -1,4 +1,5 @@
import numpy import numpy
from numpy import int32, int64
import unittest import unittest
from time import sleep from time import sleep
@ -103,7 +104,7 @@ class _DefaultArg(EnvExperiment):
def build(self): def build(self):
self.setattr_device("core") self.setattr_device("core")
def test(self, foo=42) -> TInt32: def test(self, foo=42) -> int32:
return foo return foo
@kernel @kernel
@ -121,40 +122,40 @@ class _RPCTypes(EnvExperiment):
def build(self): def build(self):
self.setattr_device("core") self.setattr_device("core")
def return_bool(self) -> TBool: def return_bool(self) -> bool:
return True return True
def return_int32(self) -> TInt32: def return_int32(self) -> int32:
return 1 return 1
def return_int64(self) -> TInt64: def return_int64(self) -> int64:
return 0x100000000 return 0x100000000
def return_float(self) -> TFloat: def return_float(self) -> float:
return 1.0 return 1.0
def return_str(self) -> TStr: def return_str(self) -> str:
return "foo" return "foo"
def return_bytes(self) -> TBytes: def return_bytes(self) -> bytes:
return b"foo" return b"foo"
def return_bytearray(self) -> TByteArray: def return_bytearray(self) -> bytearray:
return bytearray(b"foo") return bytearray(b"foo")
def return_tuple(self) -> TTuple([TInt32, TInt32]): def return_tuple(self) -> tuple[int32, int32]:
return (1, 2) return (1, 2)
def return_list(self) -> TList(TInt32): def return_list(self) -> list[int32]:
return [2, 3] return [2, 3]
def return_range(self) -> TRange32: def return_range(self) -> range:
return range(10) return range(10)
def return_array(self) -> TArray(TInt32): def return_array(self) -> numpy.ndarray: # NAC3TODO [int32]
return numpy.array([1, 2]) return numpy.array([1, 2])
def return_matrix(self) -> TArray(TInt32, 2): def return_matrix(self) -> numpy.ndarray: # NAC3TODO [int32, 2]
return numpy.array([[1, 2], [3, 4]]) return numpy.array([[1, 2], [3, 4]])
def return_mismatch(self): def return_mismatch(self):
@ -221,10 +222,10 @@ class _RPCCalls(EnvExperiment):
self.setattr_device("core") self.setattr_device("core")
self._list_int64 = [numpy.int64(1)] self._list_int64 = [numpy.int64(1)]
def args(self, *args) -> TInt32: def args(self, *args) -> int32:
return len(args) return len(args)
def kwargs(self, x="", **kwargs) -> TInt32: def kwargs(self, x="", **kwargs) -> int32:
return len(kwargs) return len(kwargs)
@kernel @kernel
@ -304,11 +305,11 @@ class _Annotation(EnvExperiment):
self.setattr_device("core") self.setattr_device("core")
@kernel @kernel
def overflow(self, x: TInt64) -> TBool: def overflow(self, x: int64) -> bool:
return (x << 32) != 0 return (x << 32) != 0
@kernel @kernel
def monomorphize(self, x: TList(TInt32)): def monomorphize(self, x: list[int32]):
pass pass
@ -382,10 +383,10 @@ class _ListTuple(EnvExperiment):
if data[i] != data[0] + i: if data[i] != data[0] + i:
raise ValueError raise ValueError
def get_num_iters(self) -> TInt32: def get_num_iters(self) -> int32:
return 2 return 2
def get_values(self, base_a, base_b, n) -> TTuple([TList(TInt32), TList(TInt32)]): def get_values(self, base_a, base_b, n) -> tuple[list[int32], list[int32]]:
return [numpy.int32(base_a + i) for i in range(n)], \ return [numpy.int32(base_a + i) for i in range(n)], \
[numpy.int32(base_b + i) for i in range(n)] [numpy.int32(base_b + i) for i in range(n)]
@ -396,8 +397,8 @@ class _NestedTupleList(EnvExperiment):
self.data = [(0x12345678, [("foo", [0.0, 1.0], [2, 3])]), self.data = [(0x12345678, [("foo", [0.0, 1.0], [2, 3])]),
(0x76543210, [("bar", [4.0, 5.0], [6, 7])])] (0x76543210, [("bar", [4.0, 5.0], [6, 7])])]
def get_data(self) -> TList(TTuple( def get_data(self) -> list[tuple
[TInt32, TList(TTuple([TStr, TList(TFloat), TList(TInt32)]))])): [int32, list[tuple[str, list[float], list[int32]]]]]:
return self.data return self.data
@kernel @kernel
@ -411,7 +412,7 @@ class _EmptyList(EnvExperiment):
def build(self): def build(self):
self.setattr_device("core") self.setattr_device("core")
def get_empty(self) -> TList(TInt32): def get_empty(self) -> list[int32]:
return [] return []
@kernel @kernel
@ -522,7 +523,7 @@ class _Alignment(EnvExperiment):
self.setattr_device("core") self.setattr_device("core")
@rpc @rpc
def a_tuple(self) -> TList(TTuple([TBool, TFloat, TBool])): def a_tuple(self) -> list[tuple[bool, float, bool]]:
return [(True, 1234.5678, True)] return [(True, 1234.5678, True)]
@kernel @kernel