From e77c7d1c39bf151830143a07161377eb9e7e4a96 Mon Sep 17 00:00:00 2001 From: David Nadlinger Date: Tue, 28 Jul 2020 22:44:27 +0100 Subject: [PATCH] compiler: Add inferencer support for array operations --- artiq/compiler/transforms/inferencer.py | 60 +++++++++++++++++++- artiq/test/lit/inferencer/error_array_ops.py | 12 ++++ 2 files changed, 70 insertions(+), 2 deletions(-) create mode 100644 artiq/test/lit/inferencer/error_array_ops.py diff --git a/artiq/compiler/transforms/inferencer.py b/artiq/compiler/transforms/inferencer.py index 1c9ea3e6c..f2a0a79ee 100644 --- a/artiq/compiler/transforms/inferencer.py +++ b/artiq/compiler/transforms/inferencer.py @@ -313,6 +313,10 @@ class Inferencer(algorithm.Visitor): self.generic_visit(node) if builtins.is_numeric(node.type) and builtins.is_numeric(node.value.type): pass + elif (builtins.is_array(node.type) and builtins.is_array(node.value.type) + and builtins.is_numeric(node.type.find()["elt"]) + and builtins.is_numeric(node.value.type.find()["elt"])): + pass else: printer = types.TypePrinter() note = diagnostic.Diagnostic("note", @@ -338,7 +342,7 @@ class Inferencer(algorithm.Visitor): self.visit(node) return node - def _coerce_numeric(self, nodes, map_return=lambda typ: typ): + def _coerce_numeric(self, nodes, map_return=lambda typ: typ, map_node_type =lambda typ:typ): # See https://docs.python.org/3/library/stdtypes.html#numeric-types-int-float-complex. node_types = [] for node in nodes: @@ -354,6 +358,7 @@ class Inferencer(algorithm.Visitor): node_types.append(node.type) else: node_types.append(node.type) + node_types = [map_node_type(typ) for typ in node_types] if any(map(types.is_var, node_types)): # not enough info yet return elif not all(map(builtins.is_numeric, node_types)): @@ -386,7 +391,58 @@ class Inferencer(algorithm.Visitor): assert False def _coerce_binop(self, op, left, right): - if isinstance(op, (ast.BitAnd, ast.BitOr, ast.BitXor, + if builtins.is_array(left.type) or builtins.is_array(right.type): + # Operations on arrays are element-wise (possibly using broadcasting). + # TODO: Matrix multiplication (which aren't element-wise). + + # # TODO: Allow only for integer arrays. + # allowed_int_array_ops = (ast.BitAnd, ast.BitOr, ast.BitXor, ast.LShift, + # ast.RShift) + allowed_array_ops = (ast.Add, ast.Mult, ast.FloorDiv, ast.Mod, + ast.Pow, ast.Sub, ast.Div) + if not isinstance(op, allowed_array_ops): + diag = diagnostic.Diagnostic( + "error", "operator '{op}' not valid for array types", + {"op": op.loc.source()}, op.loc) + self.engine.process(diag) + return + + def num_dims(typ): + if builtins.is_array(typ): + # TODO: If number of dimensions is ever made a non-fixed parameter, + # need to acutally unify num_dims in _coerce_binop. + return typ.find()["num_dims"].value + return 0 + + # TODO: Broadcasting. + left_dims = num_dims(left.type) + right_dims = num_dims(right.type) + if left_dims != right_dims: + note1 = diagnostic.Diagnostic("note", "operand of dimension {num_dims}", + {"num_dims": left_dims}, left.loc) + note2 = diagnostic.Diagnostic("note", "operand of dimension {num_dims}", + {"num_dims": right_dims}, right.loc) + diag = diagnostic.Diagnostic( + "error", "dimensions of '{op}' array operands must match", + {"op": op.loc.source()}, op.loc, [left.loc, right.loc], [note1, note2]) + self.engine.process(diag) + return + + def map_node_type(typ): + if builtins.is_array(typ): + return typ.find()["elt"] + else: + # This is (if later valid) a single value broadcast across the array. + return typ + + def map_return(typ): + a = builtins.TArray(elt=typ, num_dims=left_dims) + return (a, a, a) + + return self._coerce_numeric((left, right), + map_return=map_return, + map_node_type=map_node_type) + elif isinstance(op, (ast.BitAnd, ast.BitOr, ast.BitXor, ast.LShift, ast.RShift)): # bitwise operators require integers for operand in (left, right): diff --git a/artiq/test/lit/inferencer/error_array_ops.py b/artiq/test/lit/inferencer/error_array_ops.py new file mode 100644 index 000000000..4f85290c1 --- /dev/null +++ b/artiq/test/lit/inferencer/error_array_ops.py @@ -0,0 +1,12 @@ +# RUN: %python -m artiq.compiler.testbench.inferencer +diag %s >%t +# RUN: OutputCheck %s --file-to-check=%t + +a = array([[1, 2], [3, 4]]) +b = array([7, 8]) + +# NumPy supports implicit broadcasting over axes, which we don't (yet). +# Make sure there is a nice error message. +# CHECK-L: ${LINE:+3}: error: dimensions of '+' array operands must match +# CHECK-L: ${LINE:+2}: note: operand of dimension 2 +# CHECK-L: ${LINE:+1}: note: operand of dimension 1 +a + b