diff --git a/artiq/compiler/transforms/inferencer.py b/artiq/compiler/transforms/inferencer.py index 010d4ec07..7649682c2 100644 --- a/artiq/compiler/transforms/inferencer.py +++ b/artiq/compiler/transforms/inferencer.py @@ -402,11 +402,66 @@ class Inferencer(algorithm.Visitor): assert False def _coerce_binop(self, op, left, right): - if builtins.is_array(left.type) or builtins.is_array(right.type): - # Operations on arrays are element-wise (possibly using broadcasting). - # TODO: Matrix multiplication (which aren't element-wise). + if isinstance(op, ast.MatMult): + if types.is_var(left.type) or types.is_var(right.type): + return - # # TODO: Allow only for integer arrays. + def num_dims(operand): + if not builtins.is_array(operand.type): + diag = diagnostic.Diagnostic( + "error", + "expected matrix multiplication operand to be of array type, not {type}", + { + "op": op.loc.source(), + "type": types.TypePrinter().name(operand.type) + }, op.loc, [operand.loc]) + self.engine.process(diag) + return + num_dims = operand.type.find()["num_dims"].value + if num_dims not in (1, 2): + diag = diagnostic.Diagnostic( + "error", + "expected matrix multiplication operand to be 1- or 2-dimensional, not {type}", + { + "op": op.loc.source(), + "type": types.TypePrinter().name(operand.type) + }, op.loc, [operand.loc]) + self.engine.process(diag) + return + return num_dims + + left_dims = num_dims(left) + if not left_dims: + return + right_dims = num_dims(right) + if not right_dims: + return + + def map_node_type(typ): + return typ.find()["elt"] + + def map_return(typ): + if left_dims == 1: + if right_dims == 1: + result_dims = 0 + else: + result_dims = 1 + elif right_dims == 1: + result_dims = 1 + else: + result_dims = 2 + result = typ if result_dims == 0 else builtins.TArray( + typ, result_dims) + return (result, builtins.TArray(typ, left_dims), + builtins.TArray(typ, right_dims)) + + return self._coerce_numeric((left, right), + map_return=map_return, + map_node_type=map_node_type) + elif builtins.is_array(left.type) or builtins.is_array(right.type): + # Operations on arrays are element-wise (possibly using broadcasting). + + # TODO: Allow only for integer arrays. # allowed_int_array_ops = (ast.BitAnd, ast.BitOr, ast.BitXor, ast.LShift, # ast.RShift) allowed_array_ops = (ast.Add, ast.Mult, ast.FloorDiv, ast.Mod, @@ -553,7 +608,7 @@ class Inferencer(algorithm.Visitor): # division always returns a float return self._coerce_numeric((left, right), lambda typ: (builtins.TFloat(), builtins.TFloat(), builtins.TFloat())) - else: # MatMult + else: diag = diagnostic.Diagnostic("error", "operator '{op}' is not supported", {"op": op.loc.source()}, op.loc) diff --git a/artiq/test/lit/inferencer/error_matmult.py b/artiq/test/lit/inferencer/error_matmult.py new file mode 100644 index 000000000..2586aec31 --- /dev/null +++ b/artiq/test/lit/inferencer/error_matmult.py @@ -0,0 +1,11 @@ +# RUN: %python -m artiq.compiler.testbench.inferencer +diag %s >%t +# RUN: OutputCheck %s --file-to-check=%t + +# CHECK-L: ${LINE:+1}: error: expected matrix multiplication operand to be of array type +1 @ 2 + +# CHECK-L: ${LINE:+1}: error: expected matrix multiplication operand to be of array type +[1] @ [2] + +# CHECK-L: ${LINE:+1}: error: expected matrix multiplication operand to be 1- or 2-dimensional +array([[[0]]]) @ array([[[1]]]) diff --git a/artiq/test/lit/inferencer/matmult.py b/artiq/test/lit/inferencer/matmult.py new file mode 100644 index 000000000..e8e982c57 --- /dev/null +++ b/artiq/test/lit/inferencer/matmult.py @@ -0,0 +1,17 @@ +# RUN: %python -m artiq.compiler.testbench.inferencer %s >%t +# RUN: OutputCheck %s --file-to-check=%t + +vec = array([0, 1]) +mat = array([[0, 1], [2, 3]]) + +# CHECK-L: ):numpy.int? +vec @ vec + +# CHECK-L: ):numpy.array(elt=numpy.int?, num_dims=1) +vec @ mat + +# CHECK-L: ):numpy.array(elt=numpy.int?, num_dims=1) +mat @ vec + +# CHECK-L: ):numpy.array(elt=numpy.int?, num_dims=2) +mat @ mat