forked from M-Labs/artiq
1
0
Fork 0

compiler: Add inferencer support for array operations

This commit is contained in:
David Nadlinger 2020-07-28 22:44:27 +01:00
parent ef57cad1a3
commit e77c7d1c39
2 changed files with 70 additions and 2 deletions

View File

@ -313,6 +313,10 @@ class Inferencer(algorithm.Visitor):
self.generic_visit(node) self.generic_visit(node)
if builtins.is_numeric(node.type) and builtins.is_numeric(node.value.type): if builtins.is_numeric(node.type) and builtins.is_numeric(node.value.type):
pass pass
elif (builtins.is_array(node.type) and builtins.is_array(node.value.type)
and builtins.is_numeric(node.type.find()["elt"])
and builtins.is_numeric(node.value.type.find()["elt"])):
pass
else: else:
printer = types.TypePrinter() printer = types.TypePrinter()
note = diagnostic.Diagnostic("note", note = diagnostic.Diagnostic("note",
@ -338,7 +342,7 @@ class Inferencer(algorithm.Visitor):
self.visit(node) self.visit(node)
return node return node
def _coerce_numeric(self, nodes, map_return=lambda typ: typ): def _coerce_numeric(self, nodes, map_return=lambda typ: typ, map_node_type =lambda typ:typ):
# See https://docs.python.org/3/library/stdtypes.html#numeric-types-int-float-complex. # See https://docs.python.org/3/library/stdtypes.html#numeric-types-int-float-complex.
node_types = [] node_types = []
for node in nodes: for node in nodes:
@ -354,6 +358,7 @@ class Inferencer(algorithm.Visitor):
node_types.append(node.type) node_types.append(node.type)
else: else:
node_types.append(node.type) node_types.append(node.type)
node_types = [map_node_type(typ) for typ in node_types]
if any(map(types.is_var, node_types)): # not enough info yet if any(map(types.is_var, node_types)): # not enough info yet
return return
elif not all(map(builtins.is_numeric, node_types)): elif not all(map(builtins.is_numeric, node_types)):
@ -386,7 +391,58 @@ class Inferencer(algorithm.Visitor):
assert False assert False
def _coerce_binop(self, op, left, right): def _coerce_binop(self, op, left, right):
if isinstance(op, (ast.BitAnd, ast.BitOr, ast.BitXor, if builtins.is_array(left.type) or builtins.is_array(right.type):
# Operations on arrays are element-wise (possibly using broadcasting).
# TODO: Matrix multiplication (which aren't element-wise).
# # TODO: Allow only for integer arrays.
# allowed_int_array_ops = (ast.BitAnd, ast.BitOr, ast.BitXor, ast.LShift,
# ast.RShift)
allowed_array_ops = (ast.Add, ast.Mult, ast.FloorDiv, ast.Mod,
ast.Pow, ast.Sub, ast.Div)
if not isinstance(op, allowed_array_ops):
diag = diagnostic.Diagnostic(
"error", "operator '{op}' not valid for array types",
{"op": op.loc.source()}, op.loc)
self.engine.process(diag)
return
def num_dims(typ):
if builtins.is_array(typ):
# TODO: If number of dimensions is ever made a non-fixed parameter,
# need to acutally unify num_dims in _coerce_binop.
return typ.find()["num_dims"].value
return 0
# TODO: Broadcasting.
left_dims = num_dims(left.type)
right_dims = num_dims(right.type)
if left_dims != right_dims:
note1 = diagnostic.Diagnostic("note", "operand of dimension {num_dims}",
{"num_dims": left_dims}, left.loc)
note2 = diagnostic.Diagnostic("note", "operand of dimension {num_dims}",
{"num_dims": right_dims}, right.loc)
diag = diagnostic.Diagnostic(
"error", "dimensions of '{op}' array operands must match",
{"op": op.loc.source()}, op.loc, [left.loc, right.loc], [note1, note2])
self.engine.process(diag)
return
def map_node_type(typ):
if builtins.is_array(typ):
return typ.find()["elt"]
else:
# This is (if later valid) a single value broadcast across the array.
return typ
def map_return(typ):
a = builtins.TArray(elt=typ, num_dims=left_dims)
return (a, a, a)
return self._coerce_numeric((left, right),
map_return=map_return,
map_node_type=map_node_type)
elif isinstance(op, (ast.BitAnd, ast.BitOr, ast.BitXor,
ast.LShift, ast.RShift)): ast.LShift, ast.RShift)):
# bitwise operators require integers # bitwise operators require integers
for operand in (left, right): for operand in (left, right):

View File

@ -0,0 +1,12 @@
# RUN: %python -m artiq.compiler.testbench.inferencer +diag %s >%t
# RUN: OutputCheck %s --file-to-check=%t
a = array([[1, 2], [3, 4]])
b = array([7, 8])
# NumPy supports implicit broadcasting over axes, which we don't (yet).
# Make sure there is a nice error message.
# CHECK-L: ${LINE:+3}: error: dimensions of '+' array operands must match
# CHECK-L: ${LINE:+2}: note: operand of dimension 2
# CHECK-L: ${LINE:+1}: note: operand of dimension 1
a + b