test_embedding: port imports and type annotations to NAC3

This commit is contained in:
Sebastien Bourdeauducq 2022-02-26 18:47:59 +08:00
parent 70531ae1e2
commit 9a05907b7a
1 changed files with 24 additions and 23 deletions

View File

@ -1,4 +1,5 @@
import numpy
from numpy import int32, int64
import unittest
from time import sleep
@ -103,7 +104,7 @@ class _DefaultArg(EnvExperiment):
def build(self):
self.setattr_device("core")
def test(self, foo=42) -> TInt32:
def test(self, foo=42) -> int32:
return foo
@kernel
@ -121,40 +122,40 @@ class _RPCTypes(EnvExperiment):
def build(self):
self.setattr_device("core")
def return_bool(self) -> TBool:
def return_bool(self) -> bool:
return True
def return_int32(self) -> TInt32:
def return_int32(self) -> int32:
return 1
def return_int64(self) -> TInt64:
def return_int64(self) -> int64:
return 0x100000000
def return_float(self) -> TFloat:
def return_float(self) -> float:
return 1.0
def return_str(self) -> TStr:
def return_str(self) -> str:
return "foo"
def return_bytes(self) -> TBytes:
def return_bytes(self) -> bytes:
return b"foo"
def return_bytearray(self) -> TByteArray:
def return_bytearray(self) -> bytearray:
return bytearray(b"foo")
def return_tuple(self) -> TTuple([TInt32, TInt32]):
def return_tuple(self) -> tuple[int32, int32]:
return (1, 2)
def return_list(self) -> TList(TInt32):
def return_list(self) -> list[int32]:
return [2, 3]
def return_range(self) -> TRange32:
def return_range(self) -> range:
return range(10)
def return_array(self) -> TArray(TInt32):
def return_array(self) -> numpy.ndarray: # NAC3TODO [int32]
return numpy.array([1, 2])
def return_matrix(self) -> TArray(TInt32, 2):
def return_matrix(self) -> numpy.ndarray: # NAC3TODO [int32, 2]
return numpy.array([[1, 2], [3, 4]])
def return_mismatch(self):
@ -221,10 +222,10 @@ class _RPCCalls(EnvExperiment):
self.setattr_device("core")
self._list_int64 = [numpy.int64(1)]
def args(self, *args) -> TInt32:
def args(self, *args) -> int32:
return len(args)
def kwargs(self, x="", **kwargs) -> TInt32:
def kwargs(self, x="", **kwargs) -> int32:
return len(kwargs)
@kernel
@ -304,11 +305,11 @@ class _Annotation(EnvExperiment):
self.setattr_device("core")
@kernel
def overflow(self, x: TInt64) -> TBool:
def overflow(self, x: int64) -> bool:
return (x << 32) != 0
@kernel
def monomorphize(self, x: TList(TInt32)):
def monomorphize(self, x: list[int32]):
pass
@ -382,10 +383,10 @@ class _ListTuple(EnvExperiment):
if data[i] != data[0] + i:
raise ValueError
def get_num_iters(self) -> TInt32:
def get_num_iters(self) -> int32:
return 2
def get_values(self, base_a, base_b, n) -> TTuple([TList(TInt32), TList(TInt32)]):
def get_values(self, base_a, base_b, n) -> tuple[list[int32], list[int32]]:
return [numpy.int32(base_a + i) for i in range(n)], \
[numpy.int32(base_b + i) for i in range(n)]
@ -396,8 +397,8 @@ class _NestedTupleList(EnvExperiment):
self.data = [(0x12345678, [("foo", [0.0, 1.0], [2, 3])]),
(0x76543210, [("bar", [4.0, 5.0], [6, 7])])]
def get_data(self) -> TList(TTuple(
[TInt32, TList(TTuple([TStr, TList(TFloat), TList(TInt32)]))])):
def get_data(self) -> list[tuple
[int32, list[tuple[str, list[float], list[int32]]]]]:
return self.data
@kernel
@ -411,7 +412,7 @@ class _EmptyList(EnvExperiment):
def build(self):
self.setattr_device("core")
def get_empty(self) -> TList(TInt32):
def get_empty(self) -> list[int32]:
return []
@kernel
@ -522,7 +523,7 @@ class _Alignment(EnvExperiment):
self.setattr_device("core")
@rpc
def a_tuple(self) -> TList(TTuple([TBool, TFloat, TBool])):
def a_tuple(self) -> list[tuple[bool, float, bool]]:
return [(True, 1234.5678, True)]
@kernel