forked from M-Labs/artiq
compiler: Support MatMult in inferencer
Still needs actual codegen support.
This commit is contained in:
parent
4d48470320
commit
78afa2ea8e
@ -402,11 +402,66 @@ class Inferencer(algorithm.Visitor):
|
||||
assert False
|
||||
|
||||
def _coerce_binop(self, op, left, right):
|
||||
if builtins.is_array(left.type) or builtins.is_array(right.type):
|
||||
# Operations on arrays are element-wise (possibly using broadcasting).
|
||||
# TODO: Matrix multiplication (which aren't element-wise).
|
||||
if isinstance(op, ast.MatMult):
|
||||
if types.is_var(left.type) or types.is_var(right.type):
|
||||
return
|
||||
|
||||
# # TODO: Allow only for integer arrays.
|
||||
def num_dims(operand):
|
||||
if not builtins.is_array(operand.type):
|
||||
diag = diagnostic.Diagnostic(
|
||||
"error",
|
||||
"expected matrix multiplication operand to be of array type, not {type}",
|
||||
{
|
||||
"op": op.loc.source(),
|
||||
"type": types.TypePrinter().name(operand.type)
|
||||
}, op.loc, [operand.loc])
|
||||
self.engine.process(diag)
|
||||
return
|
||||
num_dims = operand.type.find()["num_dims"].value
|
||||
if num_dims not in (1, 2):
|
||||
diag = diagnostic.Diagnostic(
|
||||
"error",
|
||||
"expected matrix multiplication operand to be 1- or 2-dimensional, not {type}",
|
||||
{
|
||||
"op": op.loc.source(),
|
||||
"type": types.TypePrinter().name(operand.type)
|
||||
}, op.loc, [operand.loc])
|
||||
self.engine.process(diag)
|
||||
return
|
||||
return num_dims
|
||||
|
||||
left_dims = num_dims(left)
|
||||
if not left_dims:
|
||||
return
|
||||
right_dims = num_dims(right)
|
||||
if not right_dims:
|
||||
return
|
||||
|
||||
def map_node_type(typ):
|
||||
return typ.find()["elt"]
|
||||
|
||||
def map_return(typ):
|
||||
if left_dims == 1:
|
||||
if right_dims == 1:
|
||||
result_dims = 0
|
||||
else:
|
||||
result_dims = 1
|
||||
elif right_dims == 1:
|
||||
result_dims = 1
|
||||
else:
|
||||
result_dims = 2
|
||||
result = typ if result_dims == 0 else builtins.TArray(
|
||||
typ, result_dims)
|
||||
return (result, builtins.TArray(typ, left_dims),
|
||||
builtins.TArray(typ, right_dims))
|
||||
|
||||
return self._coerce_numeric((left, right),
|
||||
map_return=map_return,
|
||||
map_node_type=map_node_type)
|
||||
elif builtins.is_array(left.type) or builtins.is_array(right.type):
|
||||
# Operations on arrays are element-wise (possibly using broadcasting).
|
||||
|
||||
# TODO: Allow only for integer arrays.
|
||||
# allowed_int_array_ops = (ast.BitAnd, ast.BitOr, ast.BitXor, ast.LShift,
|
||||
# ast.RShift)
|
||||
allowed_array_ops = (ast.Add, ast.Mult, ast.FloorDiv, ast.Mod,
|
||||
@ -553,7 +608,7 @@ class Inferencer(algorithm.Visitor):
|
||||
# division always returns a float
|
||||
return self._coerce_numeric((left, right),
|
||||
lambda typ: (builtins.TFloat(), builtins.TFloat(), builtins.TFloat()))
|
||||
else: # MatMult
|
||||
else:
|
||||
diag = diagnostic.Diagnostic("error",
|
||||
"operator '{op}' is not supported", {"op": op.loc.source()},
|
||||
op.loc)
|
||||
|
11
artiq/test/lit/inferencer/error_matmult.py
Normal file
11
artiq/test/lit/inferencer/error_matmult.py
Normal file
@ -0,0 +1,11 @@
|
||||
# RUN: %python -m artiq.compiler.testbench.inferencer +diag %s >%t
|
||||
# RUN: OutputCheck %s --file-to-check=%t
|
||||
|
||||
# CHECK-L: ${LINE:+1}: error: expected matrix multiplication operand to be of array type
|
||||
1 @ 2
|
||||
|
||||
# CHECK-L: ${LINE:+1}: error: expected matrix multiplication operand to be of array type
|
||||
[1] @ [2]
|
||||
|
||||
# CHECK-L: ${LINE:+1}: error: expected matrix multiplication operand to be 1- or 2-dimensional
|
||||
array([[[0]]]) @ array([[[1]]])
|
17
artiq/test/lit/inferencer/matmult.py
Normal file
17
artiq/test/lit/inferencer/matmult.py
Normal file
@ -0,0 +1,17 @@
|
||||
# RUN: %python -m artiq.compiler.testbench.inferencer %s >%t
|
||||
# RUN: OutputCheck %s --file-to-check=%t
|
||||
|
||||
vec = array([0, 1])
|
||||
mat = array([[0, 1], [2, 3]])
|
||||
|
||||
# CHECK-L: ):numpy.int?
|
||||
vec @ vec
|
||||
|
||||
# CHECK-L: ):numpy.array(elt=numpy.int?, num_dims=1)
|
||||
vec @ mat
|
||||
|
||||
# CHECK-L: ):numpy.array(elt=numpy.int?, num_dims=1)
|
||||
mat @ vec
|
||||
|
||||
# CHECK-L: ):numpy.array(elt=numpy.int?, num_dims=2)
|
||||
mat @ mat
|
Loading…
Reference in New Issue
Block a user