standalone: extend test_ndarray_matmul
This commit is contained in:
parent
b034bde3e1
commit
90e23caaaf
|
@ -68,6 +68,19 @@ def output_ndarray_float_2(n: ndarray[float, Literal[2]]):
|
||||||
for c in range(len(n[r])):
|
for c in range(len(n[r])):
|
||||||
output_float64(n[r][c])
|
output_float64(n[r][c])
|
||||||
|
|
||||||
|
def output_ndarray_float_3(n: ndarray[float, Literal[3]]):
|
||||||
|
for d in range(len(n)):
|
||||||
|
for r in range(len(n[d])):
|
||||||
|
for c in range(len(n[d][r])):
|
||||||
|
output_float64(n[d][r][c])
|
||||||
|
|
||||||
|
def output_ndarray_float_4(n: ndarray[float, Literal[4]]):
|
||||||
|
for x in range(len(n)):
|
||||||
|
for y in range(len(n[x])):
|
||||||
|
for z in range(len(n[x][y])):
|
||||||
|
for w in range(len(n[x][y][z])):
|
||||||
|
output_float64(n[x][y][z][w])
|
||||||
|
|
||||||
def consume_ndarray_1(n: ndarray[float, Literal[1]]):
|
def consume_ndarray_1(n: ndarray[float, Literal[1]]):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
@ -530,11 +543,59 @@ def test_ndarray_ipow_broadcast_scalar():
|
||||||
output_ndarray_float_2(x)
|
output_ndarray_float_2(x)
|
||||||
|
|
||||||
def test_ndarray_matmul():
|
def test_ndarray_matmul():
|
||||||
x = np_identity(2)
|
# 2D @ 2D -> 2D
|
||||||
y = x @ np_ones([2, 2])
|
a1 = np_array([[2.0, 3.0], [5.0, 7.0]])
|
||||||
|
b1 = np_array([[11.0, 13.0], [17.0, 23.0]])
|
||||||
|
c1 = a1 @ b1
|
||||||
|
output_int32(np_shape(c1)[0])
|
||||||
|
output_int32(np_shape(c1)[1])
|
||||||
|
output_ndarray_float_2(c1)
|
||||||
|
|
||||||
output_ndarray_float_2(x)
|
# 1D @ 1D -> Scalar
|
||||||
output_ndarray_float_2(y)
|
a2 = np_array([2.0, 3.0, 5.0])
|
||||||
|
b2 = np_array([7.0, 11.0, 13.0])
|
||||||
|
c2 = a2 @ b2
|
||||||
|
output_float64(c2)
|
||||||
|
|
||||||
|
# 2D @ 1D -> 1D
|
||||||
|
a3 = np_array([[1.0, 2.0, 3.0], [7.0, 8.0, 9.0]])
|
||||||
|
b3 = np_array([4.0, 5.0, 6.0])
|
||||||
|
c3 = a3 @ b3
|
||||||
|
output_int32(np_shape(c3)[0])
|
||||||
|
output_ndarray_float_1(c3)
|
||||||
|
|
||||||
|
# 1D @ 2D -> 1D
|
||||||
|
a4 = np_array([1.0, 2.0, 3.0])
|
||||||
|
b4 = np_array([[4.0, 5.0], [6.0, 7.0], [8.0, 9.0]])
|
||||||
|
c4 = a4 @ b4
|
||||||
|
output_int32(np_shape(c4)[0])
|
||||||
|
output_ndarray_float_1(c4)
|
||||||
|
|
||||||
|
# Broadcasting
|
||||||
|
a5 = np_array([
|
||||||
|
[[ 0.0, 1.0, 2.0, 3.0],
|
||||||
|
[ 4.0, 5.0, 6.0, 7.0]],
|
||||||
|
[[ 8.0, 9.0, 10.0, 11.0],
|
||||||
|
[12.0, 13.0, 14.0, 15.0]],
|
||||||
|
[[16.0, 17.0, 18.0, 19.0],
|
||||||
|
[20.0, 21.0, 22.0, 23.0]]
|
||||||
|
])
|
||||||
|
b5 = np_array([
|
||||||
|
[[[ 0.0, 1.0, 2.0],
|
||||||
|
[ 3.0, 4.0, 5.0],
|
||||||
|
[ 6.0, 7.0, 8.0],
|
||||||
|
[ 9.0, 10.0, 11.0]]],
|
||||||
|
[[[12.0, 13.0, 14.0],
|
||||||
|
[15.0, 16.0, 17.0],
|
||||||
|
[18.0, 19.0, 20.0],
|
||||||
|
[21.0, 22.0, 23.0]]]
|
||||||
|
])
|
||||||
|
c5 = a5 @ b5
|
||||||
|
output_int32(np_shape(c5)[0])
|
||||||
|
output_int32(np_shape(c5)[1])
|
||||||
|
output_int32(np_shape(c5)[2])
|
||||||
|
output_int32(np_shape(c5)[3])
|
||||||
|
output_ndarray_float_4(c5)
|
||||||
|
|
||||||
def test_ndarray_imatmul():
|
def test_ndarray_imatmul():
|
||||||
x = np_identity(2)
|
x = np_identity(2)
|
||||||
|
|
Loading…
Reference in New Issue