forked from M-Labs/artiq
compiler: Add inferencer support for array operations
This commit is contained in:
parent
ef57cad1a3
commit
e77c7d1c39
|
@ -313,6 +313,10 @@ class Inferencer(algorithm.Visitor):
|
|||
self.generic_visit(node)
|
||||
if builtins.is_numeric(node.type) and builtins.is_numeric(node.value.type):
|
||||
pass
|
||||
elif (builtins.is_array(node.type) and builtins.is_array(node.value.type)
|
||||
and builtins.is_numeric(node.type.find()["elt"])
|
||||
and builtins.is_numeric(node.value.type.find()["elt"])):
|
||||
pass
|
||||
else:
|
||||
printer = types.TypePrinter()
|
||||
note = diagnostic.Diagnostic("note",
|
||||
|
@ -338,7 +342,7 @@ class Inferencer(algorithm.Visitor):
|
|||
self.visit(node)
|
||||
return node
|
||||
|
||||
def _coerce_numeric(self, nodes, map_return=lambda typ: typ):
|
||||
def _coerce_numeric(self, nodes, map_return=lambda typ: typ, map_node_type =lambda typ:typ):
|
||||
# See https://docs.python.org/3/library/stdtypes.html#numeric-types-int-float-complex.
|
||||
node_types = []
|
||||
for node in nodes:
|
||||
|
@ -354,6 +358,7 @@ class Inferencer(algorithm.Visitor):
|
|||
node_types.append(node.type)
|
||||
else:
|
||||
node_types.append(node.type)
|
||||
node_types = [map_node_type(typ) for typ in node_types]
|
||||
if any(map(types.is_var, node_types)): # not enough info yet
|
||||
return
|
||||
elif not all(map(builtins.is_numeric, node_types)):
|
||||
|
@ -386,7 +391,58 @@ class Inferencer(algorithm.Visitor):
|
|||
assert False
|
||||
|
||||
def _coerce_binop(self, op, left, right):
|
||||
if isinstance(op, (ast.BitAnd, ast.BitOr, ast.BitXor,
|
||||
if builtins.is_array(left.type) or builtins.is_array(right.type):
|
||||
# Operations on arrays are element-wise (possibly using broadcasting).
|
||||
# TODO: Matrix multiplication (which aren't element-wise).
|
||||
|
||||
# # TODO: Allow only for integer arrays.
|
||||
# allowed_int_array_ops = (ast.BitAnd, ast.BitOr, ast.BitXor, ast.LShift,
|
||||
# ast.RShift)
|
||||
allowed_array_ops = (ast.Add, ast.Mult, ast.FloorDiv, ast.Mod,
|
||||
ast.Pow, ast.Sub, ast.Div)
|
||||
if not isinstance(op, allowed_array_ops):
|
||||
diag = diagnostic.Diagnostic(
|
||||
"error", "operator '{op}' not valid for array types",
|
||||
{"op": op.loc.source()}, op.loc)
|
||||
self.engine.process(diag)
|
||||
return
|
||||
|
||||
def num_dims(typ):
|
||||
if builtins.is_array(typ):
|
||||
# TODO: If number of dimensions is ever made a non-fixed parameter,
|
||||
# need to acutally unify num_dims in _coerce_binop.
|
||||
return typ.find()["num_dims"].value
|
||||
return 0
|
||||
|
||||
# TODO: Broadcasting.
|
||||
left_dims = num_dims(left.type)
|
||||
right_dims = num_dims(right.type)
|
||||
if left_dims != right_dims:
|
||||
note1 = diagnostic.Diagnostic("note", "operand of dimension {num_dims}",
|
||||
{"num_dims": left_dims}, left.loc)
|
||||
note2 = diagnostic.Diagnostic("note", "operand of dimension {num_dims}",
|
||||
{"num_dims": right_dims}, right.loc)
|
||||
diag = diagnostic.Diagnostic(
|
||||
"error", "dimensions of '{op}' array operands must match",
|
||||
{"op": op.loc.source()}, op.loc, [left.loc, right.loc], [note1, note2])
|
||||
self.engine.process(diag)
|
||||
return
|
||||
|
||||
def map_node_type(typ):
|
||||
if builtins.is_array(typ):
|
||||
return typ.find()["elt"]
|
||||
else:
|
||||
# This is (if later valid) a single value broadcast across the array.
|
||||
return typ
|
||||
|
||||
def map_return(typ):
|
||||
a = builtins.TArray(elt=typ, num_dims=left_dims)
|
||||
return (a, a, a)
|
||||
|
||||
return self._coerce_numeric((left, right),
|
||||
map_return=map_return,
|
||||
map_node_type=map_node_type)
|
||||
elif isinstance(op, (ast.BitAnd, ast.BitOr, ast.BitXor,
|
||||
ast.LShift, ast.RShift)):
|
||||
# bitwise operators require integers
|
||||
for operand in (left, right):
|
||||
|
|
|
@ -0,0 +1,12 @@
|
|||
# RUN: %python -m artiq.compiler.testbench.inferencer +diag %s >%t
|
||||
# RUN: OutputCheck %s --file-to-check=%t
|
||||
|
||||
a = array([[1, 2], [3, 4]])
|
||||
b = array([7, 8])
|
||||
|
||||
# NumPy supports implicit broadcasting over axes, which we don't (yet).
|
||||
# Make sure there is a nice error message.
|
||||
# CHECK-L: ${LINE:+3}: error: dimensions of '+' array operands must match
|
||||
# CHECK-L: ${LINE:+2}: note: operand of dimension 2
|
||||
# CHECK-L: ${LINE:+1}: note: operand of dimension 1
|
||||
a + b
|
Loading…
Reference in New Issue