forked from M-Labs/nac3
1
0
Fork 0

standalone: extend test_ndarray_matmul

This commit is contained in:
lyken 2024-08-21 23:20:28 +08:00
parent 4fef633090
commit 70c26561e1
No known key found for this signature in database
GPG Key ID: 3BD5FC6AC8325DD8
1 changed files with 65 additions and 4 deletions

View File

@ -68,6 +68,19 @@ def output_ndarray_float_2(n: ndarray[float, Literal[2]]):
for c in range(len(n[r])):
output_float64(n[r][c])
def output_ndarray_float_3(n: ndarray[float, Literal[3]]):
for d in range(len(n)):
for r in range(len(n[d])):
for c in range(len(n[d][r])):
output_float64(n[d][r][c])
def output_ndarray_float_4(n: ndarray[float, Literal[4]]):
for x in range(len(n)):
for y in range(len(n[x])):
for z in range(len(n[x][y])):
for w in range(len(n[x][y][z])):
output_float64(n[x][y][z][w])
def consume_ndarray_1(n: ndarray[float, Literal[1]]):
pass
@ -530,11 +543,59 @@ def test_ndarray_ipow_broadcast_scalar():
output_ndarray_float_2(x)
def test_ndarray_matmul():
x = np_identity(2)
y = x @ np_ones([2, 2])
# 2D @ 2D -> 2D
a1 = np_array([[2.0, 3.0], [5.0, 7.0]])
b1 = np_array([[11.0, 13.0], [17.0, 23.0]])
c1 = a1 @ b1
output_int32(np_shape(c1)[0])
output_int32(np_shape(c1)[1])
output_ndarray_float_2(c1)
output_ndarray_float_2(x)
output_ndarray_float_2(y)
# 1D @ 1D -> Scalar
a2 = np_array([2.0, 3.0, 5.0])
b2 = np_array([7.0, 11.0, 13.0])
c2 = a2 @ b2
output_float64(c2)
# 2D @ 1D -> 1D
a3 = np_array([[1.0, 2.0, 3.0], [7.0, 8.0, 9.0]])
b3 = np_array([4.0, 5.0, 6.0])
c3 = a3 @ b3
output_int32(np_shape(c3)[0])
output_ndarray_float_1(c3)
# 1D @ 2D -> 1D
a4 = np_array([1.0, 2.0, 3.0])
b4 = np_array([[4.0, 5.0], [6.0, 7.0], [8.0, 9.0]])
c4 = a4 @ b4
output_int32(np_shape(c4)[0])
output_ndarray_float_1(c4)
# Broadcasting
a5 = np_array([
[[ 0.0, 1.0, 2.0, 3.0],
[ 4.0, 5.0, 6.0, 7.0]],
[[ 8.0, 9.0, 10.0, 11.0],
[12.0, 13.0, 14.0, 15.0]],
[[16.0, 17.0, 18.0, 19.0],
[20.0, 21.0, 22.0, 23.0]]
])
b5 = np_array([
[[[ 0.0, 1.0, 2.0],
[ 3.0, 4.0, 5.0],
[ 6.0, 7.0, 8.0],
[ 9.0, 10.0, 11.0]]],
[[[12.0, 13.0, 14.0],
[15.0, 16.0, 17.0],
[18.0, 19.0, 20.0],
[21.0, 22.0, 23.0]]]
])
c5 = a5 @ b5
output_int32(np_shape(c5)[0])
output_int32(np_shape(c5)[1])
output_int32(np_shape(c5)[2])
output_int32(np_shape(c5)[3])
output_ndarray_float_4(c5)
def test_ndarray_imatmul():
x = np_identity(2)