mirror of
https://github.com/m-labs/artiq.git
synced 2024-12-19 00:16:29 +08:00
test: Add coredevice tests for matrix multiplication
Also includes a regression test specifically for mixing multiple types in one kernel.
This commit is contained in:
parent
f0284b2549
commit
9b39b1e328
@ -87,6 +87,12 @@ class CompareHostDeviceTest(ExperimentCase):
|
||||
self._test_binop(code, matrix, scalar)
|
||||
self._test_binop(code, matrix, matrix)
|
||||
|
||||
def test_matrix_mult(self):
|
||||
for typ in [numpy.int32, numpy.int64, numpy.float]:
|
||||
mat_a = numpy.array([[1, 2, 3], [4, 5, 6]], dtype=typ)
|
||||
mat_b = numpy.array([[7, 8], [9, 10], [11, 12]], dtype=typ)
|
||||
self._test_binop("a @ b", mat_a, mat_b)
|
||||
|
||||
def test_unary_math_fns(self):
|
||||
names = [
|
||||
a for a, _ in math_fns.unary_fp_intrinsics + math_fns.unary_fp_runtime_calls
|
||||
@ -124,3 +130,28 @@ class CompareHostDeviceTest(ExperimentCase):
|
||||
self._test_binop(code, 1.0, numpy.array([2.0, 3.0]))
|
||||
self._test_binop(code, numpy.array([1.0, 2.0]), 3.0)
|
||||
self._test_binop(code, numpy.array([1.0, 2.0]), numpy.array([3.0, 4.0]))
|
||||
|
||||
|
||||
class _MatrixMult(EnvExperiment):
|
||||
"""Regression test for GitHub #1578 (ICE when mixing different matrix multiplication
|
||||
types in one kernel).
|
||||
"""
|
||||
def build(self):
|
||||
self.setattr_device("core")
|
||||
self.imat = numpy.arange(4, dtype=numpy.int64).reshape((2, 2))
|
||||
self.fmat = numpy.arange(4, dtype=numpy.float).reshape((2, 2))
|
||||
|
||||
@kernel
|
||||
def run(self):
|
||||
self.verify(self.imat, self.imat, self.imat @ self.imat)
|
||||
self.verify(self.imat, self.fmat, self.imat @ self.fmat)
|
||||
self.verify(self.fmat, self.imat, self.fmat @ self.imat)
|
||||
self.verify(self.fmat, self.fmat, self.fmat @ self.fmat)
|
||||
|
||||
def verify(self, a, b, ab):
|
||||
if not numpy.allclose(a @ b, ab):
|
||||
raise ValueError("Mismatch for {} @ {}", a, b)
|
||||
|
||||
class TestMatrixMult(ExperimentCase):
|
||||
def test_multiple_matrix_mult_types(self):
|
||||
self.create(_MatrixMult).run()
|
||||
|
Loading…
Reference in New Issue
Block a user