forked from M-Labs/artiq
compiler: Support MatMult in inferencer
Still needs actual codegen support.
This commit is contained in:
parent
4d48470320
commit
78afa2ea8e
@ -402,11 +402,66 @@ class Inferencer(algorithm.Visitor):
|
|||||||
assert False
|
assert False
|
||||||
|
|
||||||
def _coerce_binop(self, op, left, right):
|
def _coerce_binop(self, op, left, right):
|
||||||
if builtins.is_array(left.type) or builtins.is_array(right.type):
|
if isinstance(op, ast.MatMult):
|
||||||
# Operations on arrays are element-wise (possibly using broadcasting).
|
if types.is_var(left.type) or types.is_var(right.type):
|
||||||
# TODO: Matrix multiplication (which aren't element-wise).
|
return
|
||||||
|
|
||||||
# # TODO: Allow only for integer arrays.
|
def num_dims(operand):
|
||||||
|
if not builtins.is_array(operand.type):
|
||||||
|
diag = diagnostic.Diagnostic(
|
||||||
|
"error",
|
||||||
|
"expected matrix multiplication operand to be of array type, not {type}",
|
||||||
|
{
|
||||||
|
"op": op.loc.source(),
|
||||||
|
"type": types.TypePrinter().name(operand.type)
|
||||||
|
}, op.loc, [operand.loc])
|
||||||
|
self.engine.process(diag)
|
||||||
|
return
|
||||||
|
num_dims = operand.type.find()["num_dims"].value
|
||||||
|
if num_dims not in (1, 2):
|
||||||
|
diag = diagnostic.Diagnostic(
|
||||||
|
"error",
|
||||||
|
"expected matrix multiplication operand to be 1- or 2-dimensional, not {type}",
|
||||||
|
{
|
||||||
|
"op": op.loc.source(),
|
||||||
|
"type": types.TypePrinter().name(operand.type)
|
||||||
|
}, op.loc, [operand.loc])
|
||||||
|
self.engine.process(diag)
|
||||||
|
return
|
||||||
|
return num_dims
|
||||||
|
|
||||||
|
left_dims = num_dims(left)
|
||||||
|
if not left_dims:
|
||||||
|
return
|
||||||
|
right_dims = num_dims(right)
|
||||||
|
if not right_dims:
|
||||||
|
return
|
||||||
|
|
||||||
|
def map_node_type(typ):
|
||||||
|
return typ.find()["elt"]
|
||||||
|
|
||||||
|
def map_return(typ):
|
||||||
|
if left_dims == 1:
|
||||||
|
if right_dims == 1:
|
||||||
|
result_dims = 0
|
||||||
|
else:
|
||||||
|
result_dims = 1
|
||||||
|
elif right_dims == 1:
|
||||||
|
result_dims = 1
|
||||||
|
else:
|
||||||
|
result_dims = 2
|
||||||
|
result = typ if result_dims == 0 else builtins.TArray(
|
||||||
|
typ, result_dims)
|
||||||
|
return (result, builtins.TArray(typ, left_dims),
|
||||||
|
builtins.TArray(typ, right_dims))
|
||||||
|
|
||||||
|
return self._coerce_numeric((left, right),
|
||||||
|
map_return=map_return,
|
||||||
|
map_node_type=map_node_type)
|
||||||
|
elif builtins.is_array(left.type) or builtins.is_array(right.type):
|
||||||
|
# Operations on arrays are element-wise (possibly using broadcasting).
|
||||||
|
|
||||||
|
# TODO: Allow only for integer arrays.
|
||||||
# allowed_int_array_ops = (ast.BitAnd, ast.BitOr, ast.BitXor, ast.LShift,
|
# allowed_int_array_ops = (ast.BitAnd, ast.BitOr, ast.BitXor, ast.LShift,
|
||||||
# ast.RShift)
|
# ast.RShift)
|
||||||
allowed_array_ops = (ast.Add, ast.Mult, ast.FloorDiv, ast.Mod,
|
allowed_array_ops = (ast.Add, ast.Mult, ast.FloorDiv, ast.Mod,
|
||||||
@ -553,7 +608,7 @@ class Inferencer(algorithm.Visitor):
|
|||||||
# division always returns a float
|
# division always returns a float
|
||||||
return self._coerce_numeric((left, right),
|
return self._coerce_numeric((left, right),
|
||||||
lambda typ: (builtins.TFloat(), builtins.TFloat(), builtins.TFloat()))
|
lambda typ: (builtins.TFloat(), builtins.TFloat(), builtins.TFloat()))
|
||||||
else: # MatMult
|
else:
|
||||||
diag = diagnostic.Diagnostic("error",
|
diag = diagnostic.Diagnostic("error",
|
||||||
"operator '{op}' is not supported", {"op": op.loc.source()},
|
"operator '{op}' is not supported", {"op": op.loc.source()},
|
||||||
op.loc)
|
op.loc)
|
||||||
|
11
artiq/test/lit/inferencer/error_matmult.py
Normal file
11
artiq/test/lit/inferencer/error_matmult.py
Normal file
@ -0,0 +1,11 @@
|
|||||||
|
# RUN: %python -m artiq.compiler.testbench.inferencer +diag %s >%t
|
||||||
|
# RUN: OutputCheck %s --file-to-check=%t
|
||||||
|
|
||||||
|
# CHECK-L: ${LINE:+1}: error: expected matrix multiplication operand to be of array type
|
||||||
|
1 @ 2
|
||||||
|
|
||||||
|
# CHECK-L: ${LINE:+1}: error: expected matrix multiplication operand to be of array type
|
||||||
|
[1] @ [2]
|
||||||
|
|
||||||
|
# CHECK-L: ${LINE:+1}: error: expected matrix multiplication operand to be 1- or 2-dimensional
|
||||||
|
array([[[0]]]) @ array([[[1]]])
|
17
artiq/test/lit/inferencer/matmult.py
Normal file
17
artiq/test/lit/inferencer/matmult.py
Normal file
@ -0,0 +1,17 @@
|
|||||||
|
# RUN: %python -m artiq.compiler.testbench.inferencer %s >%t
|
||||||
|
# RUN: OutputCheck %s --file-to-check=%t
|
||||||
|
|
||||||
|
vec = array([0, 1])
|
||||||
|
mat = array([[0, 1], [2, 3]])
|
||||||
|
|
||||||
|
# CHECK-L: ):numpy.int?
|
||||||
|
vec @ vec
|
||||||
|
|
||||||
|
# CHECK-L: ):numpy.array(elt=numpy.int?, num_dims=1)
|
||||||
|
vec @ mat
|
||||||
|
|
||||||
|
# CHECK-L: ):numpy.array(elt=numpy.int?, num_dims=1)
|
||||||
|
mat @ vec
|
||||||
|
|
||||||
|
# CHECK-L: ):numpy.array(elt=numpy.int?, num_dims=2)
|
||||||
|
mat @ mat
|
Loading…
Reference in New Issue
Block a user