1
0
forked from M-Labs/artiq

test: Add coredevice tests for matrix multiplication

Also includes a regression test specifically for
mixing multiple types in one kernel.
This commit is contained in:
David Nadlinger 2021-01-12 00:55:11 +01:00
parent f0284b2549
commit 9b39b1e328

View File

@ -87,6 +87,12 @@ class CompareHostDeviceTest(ExperimentCase):
self._test_binop(code, matrix, scalar) self._test_binop(code, matrix, scalar)
self._test_binop(code, matrix, matrix) self._test_binop(code, matrix, matrix)
def test_matrix_mult(self):
for typ in [numpy.int32, numpy.int64, numpy.float]:
mat_a = numpy.array([[1, 2, 3], [4, 5, 6]], dtype=typ)
mat_b = numpy.array([[7, 8], [9, 10], [11, 12]], dtype=typ)
self._test_binop("a @ b", mat_a, mat_b)
def test_unary_math_fns(self): def test_unary_math_fns(self):
names = [ names = [
a for a, _ in math_fns.unary_fp_intrinsics + math_fns.unary_fp_runtime_calls a for a, _ in math_fns.unary_fp_intrinsics + math_fns.unary_fp_runtime_calls
@ -124,3 +130,28 @@ class CompareHostDeviceTest(ExperimentCase):
self._test_binop(code, 1.0, numpy.array([2.0, 3.0])) self._test_binop(code, 1.0, numpy.array([2.0, 3.0]))
self._test_binop(code, numpy.array([1.0, 2.0]), 3.0) self._test_binop(code, numpy.array([1.0, 2.0]), 3.0)
self._test_binop(code, numpy.array([1.0, 2.0]), numpy.array([3.0, 4.0])) self._test_binop(code, numpy.array([1.0, 2.0]), numpy.array([3.0, 4.0]))
class _MatrixMult(EnvExperiment):
"""Regression test for GitHub #1578 (ICE when mixing different matrix multiplication
types in one kernel).
"""
def build(self):
self.setattr_device("core")
self.imat = numpy.arange(4, dtype=numpy.int64).reshape((2, 2))
self.fmat = numpy.arange(4, dtype=numpy.float).reshape((2, 2))
@kernel
def run(self):
self.verify(self.imat, self.imat, self.imat @ self.imat)
self.verify(self.imat, self.fmat, self.imat @ self.fmat)
self.verify(self.fmat, self.imat, self.fmat @ self.imat)
self.verify(self.fmat, self.fmat, self.fmat @ self.fmat)
def verify(self, a, b, ab):
if not numpy.allclose(a @ b, ab):
raise ValueError("Mismatch for {} @ {}", a, b)
class TestMatrixMult(ExperimentCase):
def test_multiple_matrix_mult_types(self):
self.create(_MatrixMult).run()