Compare commits

..

2 Commits

2 changed files with 26 additions and 2 deletions

View File

@ -5,11 +5,12 @@ import importlib.util
import importlib.machinery
import math
import numpy as np
import numpy.typing as npt
import pathlib
from numpy import int32, int64, uint32, uint64
from scipy import special
from typing import TypeVar, Generic, Literal
from typing import TypeVar, Generic, Literal, Union
T = TypeVar('T')
class Option(Generic[T]):
@ -50,6 +51,13 @@ class _ConstGenericMarker:
def ConstGeneric(name, constraint):
return TypeVar(name, _ConstGenericMarker, constraint)
N = TypeVar("N", bound=np.uint64)
class _NDArrayDummy(Generic[T, N]):
pass
# https://stackoverflow.com/questions/67803260/how-to-create-a-type-alias-with-a-throw-away-generic
NDArray = Union[npt.NDArray[T], _NDArrayDummy[T, N]]
def round_away_zero(x):
if x >= 0.0:
return math.floor(x + 0.5)
@ -124,6 +132,16 @@ def patch(module):
module.ceil64 = math.ceil
module.np_ceil = np.ceil
# NumPy ndarray functions
module.ndarray = NDArray
module.np_ndarray = np.ndarray
module.np_empty = np.empty
module.np_zeros = np.zeros
module.np_ones = np.ones
module.np_full = np.full
module.np_eye = np.eye
module.np_identity = np.identity
# NumPy Math functions
module.np_isnan = np.isnan
module.np_isinf = np.isinf

View File

@ -1,4 +1,10 @@
def consume_ndarray_1(n: ndarray[float, 1]):
def consume_ndarray_1(n: ndarray[float, Literal[1]]):
pass
def consume_ndarray_i32_1(n: ndarray[int32, Literal[1]]):
pass
def consume_ndarray_2(n: ndarray[float, Literal[2]]):
pass
def consume_ndarray_i32_1(n: ndarray[int32, 1]):