2
0
mirror of https://github.com/m-labs/artiq.git synced 2024-12-25 19:28:26 +08:00

compiler: Support MatMult in inferencer

Still needs actual codegen support.
This commit is contained in:
David Nadlinger 2020-08-02 17:52:15 +01:00
parent 4d48470320
commit 78afa2ea8e
3 changed files with 88 additions and 5 deletions

View File

@ -402,11 +402,66 @@ class Inferencer(algorithm.Visitor):
assert False
def _coerce_binop(self, op, left, right):
if builtins.is_array(left.type) or builtins.is_array(right.type):
# Operations on arrays are element-wise (possibly using broadcasting).
# TODO: Matrix multiplication (which aren't element-wise).
if isinstance(op, ast.MatMult):
if types.is_var(left.type) or types.is_var(right.type):
return
# # TODO: Allow only for integer arrays.
def num_dims(operand):
if not builtins.is_array(operand.type):
diag = diagnostic.Diagnostic(
"error",
"expected matrix multiplication operand to be of array type, not {type}",
{
"op": op.loc.source(),
"type": types.TypePrinter().name(operand.type)
}, op.loc, [operand.loc])
self.engine.process(diag)
return
num_dims = operand.type.find()["num_dims"].value
if num_dims not in (1, 2):
diag = diagnostic.Diagnostic(
"error",
"expected matrix multiplication operand to be 1- or 2-dimensional, not {type}",
{
"op": op.loc.source(),
"type": types.TypePrinter().name(operand.type)
}, op.loc, [operand.loc])
self.engine.process(diag)
return
return num_dims
left_dims = num_dims(left)
if not left_dims:
return
right_dims = num_dims(right)
if not right_dims:
return
def map_node_type(typ):
return typ.find()["elt"]
def map_return(typ):
if left_dims == 1:
if right_dims == 1:
result_dims = 0
else:
result_dims = 1
elif right_dims == 1:
result_dims = 1
else:
result_dims = 2
result = typ if result_dims == 0 else builtins.TArray(
typ, result_dims)
return (result, builtins.TArray(typ, left_dims),
builtins.TArray(typ, right_dims))
return self._coerce_numeric((left, right),
map_return=map_return,
map_node_type=map_node_type)
elif builtins.is_array(left.type) or builtins.is_array(right.type):
# Operations on arrays are element-wise (possibly using broadcasting).
# TODO: Allow only for integer arrays.
# allowed_int_array_ops = (ast.BitAnd, ast.BitOr, ast.BitXor, ast.LShift,
# ast.RShift)
allowed_array_ops = (ast.Add, ast.Mult, ast.FloorDiv, ast.Mod,
@ -553,7 +608,7 @@ class Inferencer(algorithm.Visitor):
# division always returns a float
return self._coerce_numeric((left, right),
lambda typ: (builtins.TFloat(), builtins.TFloat(), builtins.TFloat()))
else: # MatMult
else:
diag = diagnostic.Diagnostic("error",
"operator '{op}' is not supported", {"op": op.loc.source()},
op.loc)

View File

@ -0,0 +1,11 @@
# RUN: %python -m artiq.compiler.testbench.inferencer +diag %s >%t
# RUN: OutputCheck %s --file-to-check=%t
# CHECK-L: ${LINE:+1}: error: expected matrix multiplication operand to be of array type
1 @ 2
# CHECK-L: ${LINE:+1}: error: expected matrix multiplication operand to be of array type
[1] @ [2]
# CHECK-L: ${LINE:+1}: error: expected matrix multiplication operand to be 1- or 2-dimensional
array([[[0]]]) @ array([[[1]]])

View File

@ -0,0 +1,17 @@
# RUN: %python -m artiq.compiler.testbench.inferencer %s >%t
# RUN: OutputCheck %s --file-to-check=%t
vec = array([0, 1])
mat = array([[0, 1], [2, 3]])
# CHECK-L: ):numpy.int?
vec @ vec
# CHECK-L: ):numpy.array(elt=numpy.int?, num_dims=1)
vec @ mat
# CHECK-L: ):numpy.array(elt=numpy.int?, num_dims=1)
mat @ vec
# CHECK-L: ):numpy.array(elt=numpy.int?, num_dims=2)
mat @ mat