2
0
mirror of https://github.com/m-labs/artiq.git synced 2025-01-25 09:58:13 +08:00

compiler: Implement matrix multiplication

LLVM will take care of optimising the loops. This was still
unnecessarily painful; implementing generics and implementing
this in ARTIQ Python looks very attractive right now.
This commit is contained in:
David Nadlinger 2020-08-02 20:26:20 +01:00
parent 0da4a61d99
commit ef260adca8
2 changed files with 218 additions and 35 deletions

View File

@ -527,7 +527,7 @@ class ARTIQIRGenerator(algorithm.Visitor):
if num_dims > 1:
old_shape = self.append(ir.GetAttr(value, "shape"))
lengths = [self.append(ir.GetAttr(old_shape, i)) for i in range(1, num_dims)]
new_shape = self.append(ir.Alloc(lengths, types.TTuple(old_shape.type.elts[1:])))
new_shape = self._make_array_shape(lengths)
stride = reduce(
lambda l, r: self.append(ir.Arith(ast.Mult(loc=None), l, r)),
@ -1444,7 +1444,7 @@ class ARTIQIRGenerator(algorithm.Visitor):
result_type = builtins.TArray(elt, types.TValue(len(shape.type.elts)))
return self.append(ir.Alloc([buffer, shape], result_type))
def _make_array_binop(self, name, op, result_type, lhs_type, rhs_type):
def _make_array_binop(self, name, result_type, lhs_type, rhs_type, body_gen):
try:
result = ir.Argument(result_type, "result")
lhs = ir.Argument(lhs_type, "lhs")
@ -1461,8 +1461,8 @@ class ARTIQIRGenerator(algorithm.Visitor):
ret=builtins.TNone())
env_args = [ir.EnvironmentArgument(self.current_env.type, "ARG.ENV")]
# TODO: What to use for loc?
func = ir.Function(typ, name, env_args + args, loc=None)
old_loc, self.current_loc = self.current_loc, None
func = ir.Function(typ, name, env_args + args)
func.is_internal = True
func.is_generated = True
self.functions.append(func)
@ -1474,36 +1474,12 @@ class ARTIQIRGenerator(algorithm.Visitor):
old_final_branch, self.final_branch = self.final_branch, None
old_unwind, self.unwind_target = self.unwind_target, None
shape = self.append(ir.GetAttr(lhs, "shape"))
rhs_shape = self.append(ir.GetAttr(rhs, "shape"))
self._make_check(
self.append(ir.Compare(ast.Eq(loc=None), shape, rhs_shape)),
lambda: self.alloc_exn(
builtins.TException("ValueError"),
ir.Constant("operands could not be broadcast together",
builtins.TStr())))
# We assume result has correct shape; could just pass buffer pointer as well.
result_buffer = self.append(ir.GetAttr(result, "buffer"))
lhs_buffer = self.append(ir.GetAttr(lhs, "buffer"))
rhs_buffer = self.append(ir.GetAttr(rhs, "buffer"))
num_total_elts = self._get_total_array_len(shape)
def body_gen(index):
l = self.append(ir.GetElem(lhs_buffer, index))
r = self.append(ir.GetElem(rhs_buffer, index))
self.append(
ir.SetElem(result_buffer, index, self.append(ir.Arith(op, l, r))))
return self.append(
ir.Arith(ast.Add(loc=None), index, ir.Constant(1, self._size_type)))
self._make_loop(
ir.Constant(0, self._size_type), lambda index: self.append(
ir.Compare(ast.Lt(loc=None), index, num_total_elts)), body_gen)
body_gen(result, lhs, rhs)
self.append(ir.Return(ir.Constant(None, builtins.TNone())))
return func
finally:
self.current_loc = old_loc
self.current_function = old_func
self.current_block = old_block
self.final_branch = old_final_branch
@ -1518,8 +1494,9 @@ class ARTIQIRGenerator(algorithm.Visitor):
# rpc_tag is used to turn element types into mangled names for no
# particularly good reason apart from not having to invent yet another
# string representation.
return (ir.rpc_tag(typ["elt"], name_error).decode() +
str(typ["num_dims"].find().value))
if builtins.is_array(typ):
return mangle_name(typ["elt"]) + str(typ["num_dims"].find().value)
return ir.rpc_tag(typ, name_error).decode()
return "_".join(mangle_name(t) for t in types)
@ -1531,8 +1508,41 @@ class ARTIQIRGenerator(algorithm.Visitor):
type(op).__name__,
self._mangle_arrayop_types([result_type, lhs_type, rhs_type]))
if name not in self.array_binop_funcs:
def body_gen(result, lhs, rhs):
# TODO: Move into caller for correct location information (or pass)?
shape = self.append(ir.GetAttr(lhs, "shape"))
rhs_shape = self.append(ir.GetAttr(rhs, "shape"))
self._make_check(
self.append(ir.Compare(ast.Eq(loc=None), shape, rhs_shape)),
lambda: self.alloc_exn(
builtins.TException("ValueError"),
ir.Constant("operands could not be broadcast together",
builtins.TStr())))
# We assume result has correct shape; could just pass buffer pointer
# as well.
result_buffer = self.append(ir.GetAttr(result, "buffer"))
lhs_buffer = self.append(ir.GetAttr(lhs, "buffer"))
rhs_buffer = self.append(ir.GetAttr(rhs, "buffer"))
num_total_elts = self._get_total_array_len(shape)
def loop_gen(index):
l = self.append(ir.GetElem(lhs_buffer, index))
r = self.append(ir.GetElem(rhs_buffer, index))
self.append(
ir.SetElem(result_buffer, index, self.append(ir.Arith(op, l,
r))))
return self.append(
ir.Arith(ast.Add(loc=None), index,
ir.Constant(1, self._size_type)))
self._make_loop(
ir.Constant(0, self._size_type), lambda index: self.append(
ir.Compare(ast.Lt(loc=None), index, num_total_elts)), loop_gen)
self.array_binop_funcs[name] = self._make_array_binop(
name, op, result_type, lhs_type, rhs_type)
name, result_type, lhs_type, rhs_type, body_gen)
return self.array_binop_funcs[name]
def _invoke_arrayop(self, func, params):
@ -1545,14 +1555,162 @@ class ARTIQIRGenerator(algorithm.Visitor):
self.append(ir.Invoke(func, params, {}, after_invoke, self.unwind_target))
self.current_block = after_invoke
def _get_array_offset(self, shape, indices):
last_stride = None
result = indices[0]
for dim, index in zip(shape[:-1], indices[1:]):
result = self.append(ir.Arith(ast.Mult(loc=None), result, dim))
result = self.append(ir.Arith(ast.Add(loc=None), result, index))
return result
def _get_matmult(self, result_type, lhs_type, rhs_type):
name = "_array_MatMult_" + self._mangle_arrayop_types(
[result_type, lhs_type, rhs_type])
if name not in self.array_binop_funcs:
def body_gen(result, lhs, rhs):
assert builtins.is_array(result.type), \
"vec @ vec should have been normalised into array result"
# We assume result has correct shape; could just pass buffer pointer
# as well.
result_buffer = self.append(ir.GetAttr(result, "buffer"))
lhs_buffer = self.append(ir.GetAttr(lhs, "buffer"))
rhs_buffer = self.append(ir.GetAttr(rhs, "buffer"))
num_rows, num_summands, _, num_cols = self._get_matmult_shapes(lhs, rhs)
elt = result.type["elt"].find()
env_type = ir.TEnvironment("loop", {"$total": elt})
env = self.append(ir.Alloc([], env_type))
def row_loop(row_idx):
lhs_base_offset = self.append(
ir.Arith(ast.Mult(loc=None), row_idx, num_summands))
lhs_base = self.append(ir.Offset(lhs_buffer, lhs_base_offset))
result_base_offset = self.append(
ir.Arith(ast.Mult(loc=None), row_idx, num_cols))
result_base = self.append(
ir.Offset(result_buffer, result_base_offset))
def col_loop(col_idx):
rhs_base = self.append(ir.Offset(rhs_buffer, col_idx))
self.append(
ir.SetLocal(env, "$total", ir.Constant(elt.zero(), elt)))
def sum_loop(sum_idx):
lhs_elem = self.append(ir.GetElem(lhs_base, sum_idx))
rhs_offset = self.append(
ir.Arith(ast.Mult(loc=None), sum_idx, num_cols))
rhs_elem = self.append(ir.GetElem(rhs_base, rhs_offset))
product = self.append(
ir.Arith(ast.Mult(loc=None), lhs_elem, rhs_elem))
prev_total = self.append(ir.GetLocal(env, "$total"))
total = self.append(
ir.Arith(ast.Add(loc=None), prev_total, product))
self.append(ir.SetLocal(env, "$total", total))
return self.append(
ir.Arith(ast.Add(loc=None), sum_idx,
ir.Constant(1, self._size_type)))
self._make_loop(
ir.Constant(0, self._size_type), lambda index: self.append(
ir.Compare(ast.Lt(loc=None), index, num_summands)),
sum_loop)
total = self.append(ir.GetLocal(env, "$total"))
self.append(ir.SetElem(result_base, col_idx, total))
return self.append(
ir.Arith(ast.Add(loc=None), col_idx,
ir.Constant(1, self._size_type)))
self._make_loop(
ir.Constant(0, self._size_type), lambda index: self.append(
ir.Compare(ast.Lt(loc=None), index, num_cols)), col_loop)
return self.append(
ir.Arith(ast.Add(loc=None), row_idx,
ir.Constant(1, self._size_type)))
self._make_loop(
ir.Constant(0, self._size_type), lambda index: self.append(
ir.Compare(ast.Lt(loc=None), index, num_rows)), row_loop)
self.array_binop_funcs[name] = self._make_array_binop(
name, result_type, lhs_type, rhs_type, body_gen)
return self.array_binop_funcs[name]
def _get_matmult_shapes(self, lhs, rhs):
lhs_shape = self.append(ir.GetAttr(lhs, "shape"))
if lhs.type["num_dims"].value == 1:
lhs_shape_outer = ir.Constant(1, self._size_type)
lhs_shape_inner = self.append(ir.GetAttr(lhs_shape, 0))
else:
lhs_shape_outer = self.append(ir.GetAttr(lhs_shape, 0))
lhs_shape_inner = self.append(ir.GetAttr(lhs_shape, 1))
rhs_shape = self.append(ir.GetAttr(rhs, "shape"))
if rhs.type["num_dims"].value == 1:
rhs_shape_inner = self.append(ir.GetAttr(rhs_shape, 0))
rhs_shape_outer = ir.Constant(1, self._size_type)
else:
rhs_shape_inner = self.append(ir.GetAttr(rhs_shape, 0))
rhs_shape_outer = self.append(ir.GetAttr(rhs_shape, 1))
return lhs_shape_outer, lhs_shape_inner, rhs_shape_inner, rhs_shape_outer
def _make_array_shape(self, dims):
return self.append(ir.Alloc(dims, types.TTuple([self._size_type] * len(dims))))
def _emit_matmult(self, node, left, right):
# TODO: Also expose as numpy.dot.
lhs = self.visit(left)
rhs = self.visit(right)
num_rows, lhs_inner, rhs_inner, num_cols = self._get_matmult_shapes(lhs, rhs)
self._make_check(
self.append(ir.Compare(ast.Eq(loc=None), lhs_inner, rhs_inner)),
lambda lhs_inner, rhs_inner: self.alloc_exn(
builtins.TException("ValueError"),
ir.Constant(
"inner dimensions for matrix multiplication do not match ({0} vs. {1})",
builtins.TStr()), lhs_inner, rhs_inner),
params=[lhs_inner, rhs_inner],
loc=node.loc)
result_shape = self._make_array_shape([num_rows, num_cols])
final_type = node.type.find()
if not builtins.is_array(final_type):
elt = node.type
result_dims = 0
else:
elt = final_type["elt"]
result_dims = final_type["num_dims"].value
result = self._allocate_new_array(elt, result_shape)
func = self._get_matmult(result.type, left.type, right.type)
self._invoke_arrayop(func, [result, lhs, rhs])
if result_dims == 2:
return result
result_buffer = self.append(ir.GetAttr(result, "buffer"))
if result_dims == 1:
shape = self._make_array_shape(
[num_cols if lhs.type["num_dims"].value == 1 else num_rows])
return self.append(ir.Alloc([result_buffer, shape], node.type))
return self.append(ir.GetElem(result_buffer, ir.Constant(0, self._size_type)))
def visit_BinOpT(self, node):
if builtins.is_array(node.type):
if isinstance(node.op, ast.MatMult):
return self._emit_matmult(node, node.left, node.right)
elif builtins.is_array(node.type):
lhs = self.visit(node.left)
rhs = self.visit(node.right)
# Array op implementation will check for matching shape.
# TODO: Broadcasts; select the widest shape.
# TODO: Detect and special-case matrix multiplication.
shape = self.append(ir.GetAttr(lhs, "shape"))
result = self._allocate_new_array(node.type.find()["elt"], shape)

View File

@ -0,0 +1,25 @@
# RUN: %python -m artiq.compiler.testbench.jit %s
mat23 = array([[1, 2, 3], [4, 5, 6]])
mat32 = array([[1, 2], [3, 4], [5, 6]])
vec2 = array([1, 2])
vec3 = array([1, 2, 3])
assert vec3 @ vec3 == 14
a = mat23 @ mat32
assert a.shape == (2, 2)
assert a[0][0] == 22
assert a[0][1] == 28
assert a[1][0] == 49
assert a[1][1] == 64
b = mat23 @ vec3
assert b.shape == (2,)
assert b[0] == 14
assert b[1] == 32
b = vec3 @ mat32
assert b.shape == (2,)
assert b[0] == 22
assert b[1] == 28