forked from M-Labs/artiq
compiler: Implement matrix multiplication
LLVM will take care of optimising the loops. This was still unnecessarily painful; implementing generics and implementing this in ARTIQ Python looks very attractive right now.
This commit is contained in:
parent
0da4a61d99
commit
ef260adca8
|
@ -527,7 +527,7 @@ class ARTIQIRGenerator(algorithm.Visitor):
|
||||||
if num_dims > 1:
|
if num_dims > 1:
|
||||||
old_shape = self.append(ir.GetAttr(value, "shape"))
|
old_shape = self.append(ir.GetAttr(value, "shape"))
|
||||||
lengths = [self.append(ir.GetAttr(old_shape, i)) for i in range(1, num_dims)]
|
lengths = [self.append(ir.GetAttr(old_shape, i)) for i in range(1, num_dims)]
|
||||||
new_shape = self.append(ir.Alloc(lengths, types.TTuple(old_shape.type.elts[1:])))
|
new_shape = self._make_array_shape(lengths)
|
||||||
|
|
||||||
stride = reduce(
|
stride = reduce(
|
||||||
lambda l, r: self.append(ir.Arith(ast.Mult(loc=None), l, r)),
|
lambda l, r: self.append(ir.Arith(ast.Mult(loc=None), l, r)),
|
||||||
|
@ -1444,7 +1444,7 @@ class ARTIQIRGenerator(algorithm.Visitor):
|
||||||
result_type = builtins.TArray(elt, types.TValue(len(shape.type.elts)))
|
result_type = builtins.TArray(elt, types.TValue(len(shape.type.elts)))
|
||||||
return self.append(ir.Alloc([buffer, shape], result_type))
|
return self.append(ir.Alloc([buffer, shape], result_type))
|
||||||
|
|
||||||
def _make_array_binop(self, name, op, result_type, lhs_type, rhs_type):
|
def _make_array_binop(self, name, result_type, lhs_type, rhs_type, body_gen):
|
||||||
try:
|
try:
|
||||||
result = ir.Argument(result_type, "result")
|
result = ir.Argument(result_type, "result")
|
||||||
lhs = ir.Argument(lhs_type, "lhs")
|
lhs = ir.Argument(lhs_type, "lhs")
|
||||||
|
@ -1461,8 +1461,8 @@ class ARTIQIRGenerator(algorithm.Visitor):
|
||||||
ret=builtins.TNone())
|
ret=builtins.TNone())
|
||||||
env_args = [ir.EnvironmentArgument(self.current_env.type, "ARG.ENV")]
|
env_args = [ir.EnvironmentArgument(self.current_env.type, "ARG.ENV")]
|
||||||
|
|
||||||
# TODO: What to use for loc?
|
old_loc, self.current_loc = self.current_loc, None
|
||||||
func = ir.Function(typ, name, env_args + args, loc=None)
|
func = ir.Function(typ, name, env_args + args)
|
||||||
func.is_internal = True
|
func.is_internal = True
|
||||||
func.is_generated = True
|
func.is_generated = True
|
||||||
self.functions.append(func)
|
self.functions.append(func)
|
||||||
|
@ -1474,36 +1474,12 @@ class ARTIQIRGenerator(algorithm.Visitor):
|
||||||
old_final_branch, self.final_branch = self.final_branch, None
|
old_final_branch, self.final_branch = self.final_branch, None
|
||||||
old_unwind, self.unwind_target = self.unwind_target, None
|
old_unwind, self.unwind_target = self.unwind_target, None
|
||||||
|
|
||||||
shape = self.append(ir.GetAttr(lhs, "shape"))
|
body_gen(result, lhs, rhs)
|
||||||
rhs_shape = self.append(ir.GetAttr(rhs, "shape"))
|
|
||||||
self._make_check(
|
|
||||||
self.append(ir.Compare(ast.Eq(loc=None), shape, rhs_shape)),
|
|
||||||
lambda: self.alloc_exn(
|
|
||||||
builtins.TException("ValueError"),
|
|
||||||
ir.Constant("operands could not be broadcast together",
|
|
||||||
builtins.TStr())))
|
|
||||||
# We assume result has correct shape; could just pass buffer pointer as well.
|
|
||||||
|
|
||||||
result_buffer = self.append(ir.GetAttr(result, "buffer"))
|
|
||||||
lhs_buffer = self.append(ir.GetAttr(lhs, "buffer"))
|
|
||||||
rhs_buffer = self.append(ir.GetAttr(rhs, "buffer"))
|
|
||||||
num_total_elts = self._get_total_array_len(shape)
|
|
||||||
|
|
||||||
def body_gen(index):
|
|
||||||
l = self.append(ir.GetElem(lhs_buffer, index))
|
|
||||||
r = self.append(ir.GetElem(rhs_buffer, index))
|
|
||||||
self.append(
|
|
||||||
ir.SetElem(result_buffer, index, self.append(ir.Arith(op, l, r))))
|
|
||||||
return self.append(
|
|
||||||
ir.Arith(ast.Add(loc=None), index, ir.Constant(1, self._size_type)))
|
|
||||||
|
|
||||||
self._make_loop(
|
|
||||||
ir.Constant(0, self._size_type), lambda index: self.append(
|
|
||||||
ir.Compare(ast.Lt(loc=None), index, num_total_elts)), body_gen)
|
|
||||||
|
|
||||||
self.append(ir.Return(ir.Constant(None, builtins.TNone())))
|
self.append(ir.Return(ir.Constant(None, builtins.TNone())))
|
||||||
return func
|
return func
|
||||||
finally:
|
finally:
|
||||||
|
self.current_loc = old_loc
|
||||||
self.current_function = old_func
|
self.current_function = old_func
|
||||||
self.current_block = old_block
|
self.current_block = old_block
|
||||||
self.final_branch = old_final_branch
|
self.final_branch = old_final_branch
|
||||||
|
@ -1518,8 +1494,9 @@ class ARTIQIRGenerator(algorithm.Visitor):
|
||||||
# rpc_tag is used to turn element types into mangled names for no
|
# rpc_tag is used to turn element types into mangled names for no
|
||||||
# particularly good reason apart from not having to invent yet another
|
# particularly good reason apart from not having to invent yet another
|
||||||
# string representation.
|
# string representation.
|
||||||
return (ir.rpc_tag(typ["elt"], name_error).decode() +
|
if builtins.is_array(typ):
|
||||||
str(typ["num_dims"].find().value))
|
return mangle_name(typ["elt"]) + str(typ["num_dims"].find().value)
|
||||||
|
return ir.rpc_tag(typ, name_error).decode()
|
||||||
|
|
||||||
return "_".join(mangle_name(t) for t in types)
|
return "_".join(mangle_name(t) for t in types)
|
||||||
|
|
||||||
|
@ -1531,8 +1508,41 @@ class ARTIQIRGenerator(algorithm.Visitor):
|
||||||
type(op).__name__,
|
type(op).__name__,
|
||||||
self._mangle_arrayop_types([result_type, lhs_type, rhs_type]))
|
self._mangle_arrayop_types([result_type, lhs_type, rhs_type]))
|
||||||
if name not in self.array_binop_funcs:
|
if name not in self.array_binop_funcs:
|
||||||
|
|
||||||
|
def body_gen(result, lhs, rhs):
|
||||||
|
# TODO: Move into caller for correct location information (or pass)?
|
||||||
|
shape = self.append(ir.GetAttr(lhs, "shape"))
|
||||||
|
rhs_shape = self.append(ir.GetAttr(rhs, "shape"))
|
||||||
|
self._make_check(
|
||||||
|
self.append(ir.Compare(ast.Eq(loc=None), shape, rhs_shape)),
|
||||||
|
lambda: self.alloc_exn(
|
||||||
|
builtins.TException("ValueError"),
|
||||||
|
ir.Constant("operands could not be broadcast together",
|
||||||
|
builtins.TStr())))
|
||||||
|
# We assume result has correct shape; could just pass buffer pointer
|
||||||
|
# as well.
|
||||||
|
|
||||||
|
result_buffer = self.append(ir.GetAttr(result, "buffer"))
|
||||||
|
lhs_buffer = self.append(ir.GetAttr(lhs, "buffer"))
|
||||||
|
rhs_buffer = self.append(ir.GetAttr(rhs, "buffer"))
|
||||||
|
num_total_elts = self._get_total_array_len(shape)
|
||||||
|
|
||||||
|
def loop_gen(index):
|
||||||
|
l = self.append(ir.GetElem(lhs_buffer, index))
|
||||||
|
r = self.append(ir.GetElem(rhs_buffer, index))
|
||||||
|
self.append(
|
||||||
|
ir.SetElem(result_buffer, index, self.append(ir.Arith(op, l,
|
||||||
|
r))))
|
||||||
|
return self.append(
|
||||||
|
ir.Arith(ast.Add(loc=None), index,
|
||||||
|
ir.Constant(1, self._size_type)))
|
||||||
|
|
||||||
|
self._make_loop(
|
||||||
|
ir.Constant(0, self._size_type), lambda index: self.append(
|
||||||
|
ir.Compare(ast.Lt(loc=None), index, num_total_elts)), loop_gen)
|
||||||
|
|
||||||
self.array_binop_funcs[name] = self._make_array_binop(
|
self.array_binop_funcs[name] = self._make_array_binop(
|
||||||
name, op, result_type, lhs_type, rhs_type)
|
name, result_type, lhs_type, rhs_type, body_gen)
|
||||||
return self.array_binop_funcs[name]
|
return self.array_binop_funcs[name]
|
||||||
|
|
||||||
def _invoke_arrayop(self, func, params):
|
def _invoke_arrayop(self, func, params):
|
||||||
|
@ -1545,14 +1555,162 @@ class ARTIQIRGenerator(algorithm.Visitor):
|
||||||
self.append(ir.Invoke(func, params, {}, after_invoke, self.unwind_target))
|
self.append(ir.Invoke(func, params, {}, after_invoke, self.unwind_target))
|
||||||
self.current_block = after_invoke
|
self.current_block = after_invoke
|
||||||
|
|
||||||
|
def _get_array_offset(self, shape, indices):
|
||||||
|
last_stride = None
|
||||||
|
result = indices[0]
|
||||||
|
for dim, index in zip(shape[:-1], indices[1:]):
|
||||||
|
result = self.append(ir.Arith(ast.Mult(loc=None), result, dim))
|
||||||
|
result = self.append(ir.Arith(ast.Add(loc=None), result, index))
|
||||||
|
return result
|
||||||
|
|
||||||
|
def _get_matmult(self, result_type, lhs_type, rhs_type):
|
||||||
|
name = "_array_MatMult_" + self._mangle_arrayop_types(
|
||||||
|
[result_type, lhs_type, rhs_type])
|
||||||
|
if name not in self.array_binop_funcs:
|
||||||
|
|
||||||
|
def body_gen(result, lhs, rhs):
|
||||||
|
assert builtins.is_array(result.type), \
|
||||||
|
"vec @ vec should have been normalised into array result"
|
||||||
|
|
||||||
|
# We assume result has correct shape; could just pass buffer pointer
|
||||||
|
# as well.
|
||||||
|
result_buffer = self.append(ir.GetAttr(result, "buffer"))
|
||||||
|
lhs_buffer = self.append(ir.GetAttr(lhs, "buffer"))
|
||||||
|
rhs_buffer = self.append(ir.GetAttr(rhs, "buffer"))
|
||||||
|
|
||||||
|
num_rows, num_summands, _, num_cols = self._get_matmult_shapes(lhs, rhs)
|
||||||
|
|
||||||
|
elt = result.type["elt"].find()
|
||||||
|
env_type = ir.TEnvironment("loop", {"$total": elt})
|
||||||
|
env = self.append(ir.Alloc([], env_type))
|
||||||
|
|
||||||
|
def row_loop(row_idx):
|
||||||
|
lhs_base_offset = self.append(
|
||||||
|
ir.Arith(ast.Mult(loc=None), row_idx, num_summands))
|
||||||
|
lhs_base = self.append(ir.Offset(lhs_buffer, lhs_base_offset))
|
||||||
|
result_base_offset = self.append(
|
||||||
|
ir.Arith(ast.Mult(loc=None), row_idx, num_cols))
|
||||||
|
result_base = self.append(
|
||||||
|
ir.Offset(result_buffer, result_base_offset))
|
||||||
|
|
||||||
|
def col_loop(col_idx):
|
||||||
|
rhs_base = self.append(ir.Offset(rhs_buffer, col_idx))
|
||||||
|
|
||||||
|
self.append(
|
||||||
|
ir.SetLocal(env, "$total", ir.Constant(elt.zero(), elt)))
|
||||||
|
|
||||||
|
def sum_loop(sum_idx):
|
||||||
|
lhs_elem = self.append(ir.GetElem(lhs_base, sum_idx))
|
||||||
|
rhs_offset = self.append(
|
||||||
|
ir.Arith(ast.Mult(loc=None), sum_idx, num_cols))
|
||||||
|
rhs_elem = self.append(ir.GetElem(rhs_base, rhs_offset))
|
||||||
|
product = self.append(
|
||||||
|
ir.Arith(ast.Mult(loc=None), lhs_elem, rhs_elem))
|
||||||
|
prev_total = self.append(ir.GetLocal(env, "$total"))
|
||||||
|
total = self.append(
|
||||||
|
ir.Arith(ast.Add(loc=None), prev_total, product))
|
||||||
|
self.append(ir.SetLocal(env, "$total", total))
|
||||||
|
return self.append(
|
||||||
|
ir.Arith(ast.Add(loc=None), sum_idx,
|
||||||
|
ir.Constant(1, self._size_type)))
|
||||||
|
|
||||||
|
self._make_loop(
|
||||||
|
ir.Constant(0, self._size_type), lambda index: self.append(
|
||||||
|
ir.Compare(ast.Lt(loc=None), index, num_summands)),
|
||||||
|
sum_loop)
|
||||||
|
|
||||||
|
total = self.append(ir.GetLocal(env, "$total"))
|
||||||
|
self.append(ir.SetElem(result_base, col_idx, total))
|
||||||
|
|
||||||
|
return self.append(
|
||||||
|
ir.Arith(ast.Add(loc=None), col_idx,
|
||||||
|
ir.Constant(1, self._size_type)))
|
||||||
|
|
||||||
|
self._make_loop(
|
||||||
|
ir.Constant(0, self._size_type), lambda index: self.append(
|
||||||
|
ir.Compare(ast.Lt(loc=None), index, num_cols)), col_loop)
|
||||||
|
return self.append(
|
||||||
|
ir.Arith(ast.Add(loc=None), row_idx,
|
||||||
|
ir.Constant(1, self._size_type)))
|
||||||
|
|
||||||
|
self._make_loop(
|
||||||
|
ir.Constant(0, self._size_type), lambda index: self.append(
|
||||||
|
ir.Compare(ast.Lt(loc=None), index, num_rows)), row_loop)
|
||||||
|
|
||||||
|
self.array_binop_funcs[name] = self._make_array_binop(
|
||||||
|
name, result_type, lhs_type, rhs_type, body_gen)
|
||||||
|
return self.array_binop_funcs[name]
|
||||||
|
|
||||||
|
def _get_matmult_shapes(self, lhs, rhs):
|
||||||
|
lhs_shape = self.append(ir.GetAttr(lhs, "shape"))
|
||||||
|
if lhs.type["num_dims"].value == 1:
|
||||||
|
lhs_shape_outer = ir.Constant(1, self._size_type)
|
||||||
|
lhs_shape_inner = self.append(ir.GetAttr(lhs_shape, 0))
|
||||||
|
else:
|
||||||
|
lhs_shape_outer = self.append(ir.GetAttr(lhs_shape, 0))
|
||||||
|
lhs_shape_inner = self.append(ir.GetAttr(lhs_shape, 1))
|
||||||
|
|
||||||
|
rhs_shape = self.append(ir.GetAttr(rhs, "shape"))
|
||||||
|
if rhs.type["num_dims"].value == 1:
|
||||||
|
rhs_shape_inner = self.append(ir.GetAttr(rhs_shape, 0))
|
||||||
|
rhs_shape_outer = ir.Constant(1, self._size_type)
|
||||||
|
else:
|
||||||
|
rhs_shape_inner = self.append(ir.GetAttr(rhs_shape, 0))
|
||||||
|
rhs_shape_outer = self.append(ir.GetAttr(rhs_shape, 1))
|
||||||
|
|
||||||
|
return lhs_shape_outer, lhs_shape_inner, rhs_shape_inner, rhs_shape_outer
|
||||||
|
|
||||||
|
def _make_array_shape(self, dims):
|
||||||
|
return self.append(ir.Alloc(dims, types.TTuple([self._size_type] * len(dims))))
|
||||||
|
|
||||||
|
def _emit_matmult(self, node, left, right):
|
||||||
|
# TODO: Also expose as numpy.dot.
|
||||||
|
lhs = self.visit(left)
|
||||||
|
rhs = self.visit(right)
|
||||||
|
|
||||||
|
num_rows, lhs_inner, rhs_inner, num_cols = self._get_matmult_shapes(lhs, rhs)
|
||||||
|
self._make_check(
|
||||||
|
self.append(ir.Compare(ast.Eq(loc=None), lhs_inner, rhs_inner)),
|
||||||
|
lambda lhs_inner, rhs_inner: self.alloc_exn(
|
||||||
|
builtins.TException("ValueError"),
|
||||||
|
ir.Constant(
|
||||||
|
"inner dimensions for matrix multiplication do not match ({0} vs. {1})",
|
||||||
|
builtins.TStr()), lhs_inner, rhs_inner),
|
||||||
|
params=[lhs_inner, rhs_inner],
|
||||||
|
loc=node.loc)
|
||||||
|
result_shape = self._make_array_shape([num_rows, num_cols])
|
||||||
|
|
||||||
|
final_type = node.type.find()
|
||||||
|
if not builtins.is_array(final_type):
|
||||||
|
elt = node.type
|
||||||
|
result_dims = 0
|
||||||
|
else:
|
||||||
|
elt = final_type["elt"]
|
||||||
|
result_dims = final_type["num_dims"].value
|
||||||
|
|
||||||
|
result = self._allocate_new_array(elt, result_shape)
|
||||||
|
func = self._get_matmult(result.type, left.type, right.type)
|
||||||
|
self._invoke_arrayop(func, [result, lhs, rhs])
|
||||||
|
|
||||||
|
if result_dims == 2:
|
||||||
|
return result
|
||||||
|
result_buffer = self.append(ir.GetAttr(result, "buffer"))
|
||||||
|
if result_dims == 1:
|
||||||
|
shape = self._make_array_shape(
|
||||||
|
[num_cols if lhs.type["num_dims"].value == 1 else num_rows])
|
||||||
|
return self.append(ir.Alloc([result_buffer, shape], node.type))
|
||||||
|
return self.append(ir.GetElem(result_buffer, ir.Constant(0, self._size_type)))
|
||||||
|
|
||||||
|
|
||||||
def visit_BinOpT(self, node):
|
def visit_BinOpT(self, node):
|
||||||
if builtins.is_array(node.type):
|
if isinstance(node.op, ast.MatMult):
|
||||||
|
return self._emit_matmult(node, node.left, node.right)
|
||||||
|
elif builtins.is_array(node.type):
|
||||||
lhs = self.visit(node.left)
|
lhs = self.visit(node.left)
|
||||||
rhs = self.visit(node.right)
|
rhs = self.visit(node.right)
|
||||||
|
|
||||||
# Array op implementation will check for matching shape.
|
# Array op implementation will check for matching shape.
|
||||||
# TODO: Broadcasts; select the widest shape.
|
# TODO: Broadcasts; select the widest shape.
|
||||||
# TODO: Detect and special-case matrix multiplication.
|
|
||||||
shape = self.append(ir.GetAttr(lhs, "shape"))
|
shape = self.append(ir.GetAttr(lhs, "shape"))
|
||||||
result = self._allocate_new_array(node.type.find()["elt"], shape)
|
result = self._allocate_new_array(node.type.find()["elt"], shape)
|
||||||
|
|
||||||
|
|
|
@ -0,0 +1,25 @@
|
||||||
|
# RUN: %python -m artiq.compiler.testbench.jit %s
|
||||||
|
|
||||||
|
mat23 = array([[1, 2, 3], [4, 5, 6]])
|
||||||
|
mat32 = array([[1, 2], [3, 4], [5, 6]])
|
||||||
|
vec2 = array([1, 2])
|
||||||
|
vec3 = array([1, 2, 3])
|
||||||
|
|
||||||
|
assert vec3 @ vec3 == 14
|
||||||
|
|
||||||
|
a = mat23 @ mat32
|
||||||
|
assert a.shape == (2, 2)
|
||||||
|
assert a[0][0] == 22
|
||||||
|
assert a[0][1] == 28
|
||||||
|
assert a[1][0] == 49
|
||||||
|
assert a[1][1] == 64
|
||||||
|
|
||||||
|
b = mat23 @ vec3
|
||||||
|
assert b.shape == (2,)
|
||||||
|
assert b[0] == 14
|
||||||
|
assert b[1] == 32
|
||||||
|
|
||||||
|
b = vec3 @ mat32
|
||||||
|
assert b.shape == (2,)
|
||||||
|
assert b[0] == 22
|
||||||
|
assert b[1] == 28
|
Loading…
Reference in New Issue