forked from M-Labs/artiq
compiler: Add inferencer support for array operations
This commit is contained in:
parent
ef57cad1a3
commit
e77c7d1c39
|
@ -313,6 +313,10 @@ class Inferencer(algorithm.Visitor):
|
||||||
self.generic_visit(node)
|
self.generic_visit(node)
|
||||||
if builtins.is_numeric(node.type) and builtins.is_numeric(node.value.type):
|
if builtins.is_numeric(node.type) and builtins.is_numeric(node.value.type):
|
||||||
pass
|
pass
|
||||||
|
elif (builtins.is_array(node.type) and builtins.is_array(node.value.type)
|
||||||
|
and builtins.is_numeric(node.type.find()["elt"])
|
||||||
|
and builtins.is_numeric(node.value.type.find()["elt"])):
|
||||||
|
pass
|
||||||
else:
|
else:
|
||||||
printer = types.TypePrinter()
|
printer = types.TypePrinter()
|
||||||
note = diagnostic.Diagnostic("note",
|
note = diagnostic.Diagnostic("note",
|
||||||
|
@ -338,7 +342,7 @@ class Inferencer(algorithm.Visitor):
|
||||||
self.visit(node)
|
self.visit(node)
|
||||||
return node
|
return node
|
||||||
|
|
||||||
def _coerce_numeric(self, nodes, map_return=lambda typ: typ):
|
def _coerce_numeric(self, nodes, map_return=lambda typ: typ, map_node_type =lambda typ:typ):
|
||||||
# See https://docs.python.org/3/library/stdtypes.html#numeric-types-int-float-complex.
|
# See https://docs.python.org/3/library/stdtypes.html#numeric-types-int-float-complex.
|
||||||
node_types = []
|
node_types = []
|
||||||
for node in nodes:
|
for node in nodes:
|
||||||
|
@ -354,6 +358,7 @@ class Inferencer(algorithm.Visitor):
|
||||||
node_types.append(node.type)
|
node_types.append(node.type)
|
||||||
else:
|
else:
|
||||||
node_types.append(node.type)
|
node_types.append(node.type)
|
||||||
|
node_types = [map_node_type(typ) for typ in node_types]
|
||||||
if any(map(types.is_var, node_types)): # not enough info yet
|
if any(map(types.is_var, node_types)): # not enough info yet
|
||||||
return
|
return
|
||||||
elif not all(map(builtins.is_numeric, node_types)):
|
elif not all(map(builtins.is_numeric, node_types)):
|
||||||
|
@ -386,7 +391,58 @@ class Inferencer(algorithm.Visitor):
|
||||||
assert False
|
assert False
|
||||||
|
|
||||||
def _coerce_binop(self, op, left, right):
|
def _coerce_binop(self, op, left, right):
|
||||||
if isinstance(op, (ast.BitAnd, ast.BitOr, ast.BitXor,
|
if builtins.is_array(left.type) or builtins.is_array(right.type):
|
||||||
|
# Operations on arrays are element-wise (possibly using broadcasting).
|
||||||
|
# TODO: Matrix multiplication (which aren't element-wise).
|
||||||
|
|
||||||
|
# # TODO: Allow only for integer arrays.
|
||||||
|
# allowed_int_array_ops = (ast.BitAnd, ast.BitOr, ast.BitXor, ast.LShift,
|
||||||
|
# ast.RShift)
|
||||||
|
allowed_array_ops = (ast.Add, ast.Mult, ast.FloorDiv, ast.Mod,
|
||||||
|
ast.Pow, ast.Sub, ast.Div)
|
||||||
|
if not isinstance(op, allowed_array_ops):
|
||||||
|
diag = diagnostic.Diagnostic(
|
||||||
|
"error", "operator '{op}' not valid for array types",
|
||||||
|
{"op": op.loc.source()}, op.loc)
|
||||||
|
self.engine.process(diag)
|
||||||
|
return
|
||||||
|
|
||||||
|
def num_dims(typ):
|
||||||
|
if builtins.is_array(typ):
|
||||||
|
# TODO: If number of dimensions is ever made a non-fixed parameter,
|
||||||
|
# need to acutally unify num_dims in _coerce_binop.
|
||||||
|
return typ.find()["num_dims"].value
|
||||||
|
return 0
|
||||||
|
|
||||||
|
# TODO: Broadcasting.
|
||||||
|
left_dims = num_dims(left.type)
|
||||||
|
right_dims = num_dims(right.type)
|
||||||
|
if left_dims != right_dims:
|
||||||
|
note1 = diagnostic.Diagnostic("note", "operand of dimension {num_dims}",
|
||||||
|
{"num_dims": left_dims}, left.loc)
|
||||||
|
note2 = diagnostic.Diagnostic("note", "operand of dimension {num_dims}",
|
||||||
|
{"num_dims": right_dims}, right.loc)
|
||||||
|
diag = diagnostic.Diagnostic(
|
||||||
|
"error", "dimensions of '{op}' array operands must match",
|
||||||
|
{"op": op.loc.source()}, op.loc, [left.loc, right.loc], [note1, note2])
|
||||||
|
self.engine.process(diag)
|
||||||
|
return
|
||||||
|
|
||||||
|
def map_node_type(typ):
|
||||||
|
if builtins.is_array(typ):
|
||||||
|
return typ.find()["elt"]
|
||||||
|
else:
|
||||||
|
# This is (if later valid) a single value broadcast across the array.
|
||||||
|
return typ
|
||||||
|
|
||||||
|
def map_return(typ):
|
||||||
|
a = builtins.TArray(elt=typ, num_dims=left_dims)
|
||||||
|
return (a, a, a)
|
||||||
|
|
||||||
|
return self._coerce_numeric((left, right),
|
||||||
|
map_return=map_return,
|
||||||
|
map_node_type=map_node_type)
|
||||||
|
elif isinstance(op, (ast.BitAnd, ast.BitOr, ast.BitXor,
|
||||||
ast.LShift, ast.RShift)):
|
ast.LShift, ast.RShift)):
|
||||||
# bitwise operators require integers
|
# bitwise operators require integers
|
||||||
for operand in (left, right):
|
for operand in (left, right):
|
||||||
|
|
|
@ -0,0 +1,12 @@
|
||||||
|
# RUN: %python -m artiq.compiler.testbench.inferencer +diag %s >%t
|
||||||
|
# RUN: OutputCheck %s --file-to-check=%t
|
||||||
|
|
||||||
|
a = array([[1, 2], [3, 4]])
|
||||||
|
b = array([7, 8])
|
||||||
|
|
||||||
|
# NumPy supports implicit broadcasting over axes, which we don't (yet).
|
||||||
|
# Make sure there is a nice error message.
|
||||||
|
# CHECK-L: ${LINE:+3}: error: dimensions of '+' array operands must match
|
||||||
|
# CHECK-L: ${LINE:+2}: note: operand of dimension 2
|
||||||
|
# CHECK-L: ${LINE:+1}: note: operand of dimension 1
|
||||||
|
a + b
|
Loading…
Reference in New Issue