diff --git a/artiq/compiler/transforms/artiq_ir_generator.py b/artiq/compiler/transforms/artiq_ir_generator.py index 81b4b5ce5..fa5912999 100644 --- a/artiq/compiler/transforms/artiq_ir_generator.py +++ b/artiq/compiler/transforms/artiq_ir_generator.py @@ -527,7 +527,7 @@ class ARTIQIRGenerator(algorithm.Visitor): if num_dims > 1: old_shape = self.append(ir.GetAttr(value, "shape")) lengths = [self.append(ir.GetAttr(old_shape, i)) for i in range(1, num_dims)] - new_shape = self.append(ir.Alloc(lengths, types.TTuple(old_shape.type.elts[1:]))) + new_shape = self._make_array_shape(lengths) stride = reduce( lambda l, r: self.append(ir.Arith(ast.Mult(loc=None), l, r)), @@ -1444,7 +1444,7 @@ class ARTIQIRGenerator(algorithm.Visitor): result_type = builtins.TArray(elt, types.TValue(len(shape.type.elts))) return self.append(ir.Alloc([buffer, shape], result_type)) - def _make_array_binop(self, name, op, result_type, lhs_type, rhs_type): + def _make_array_binop(self, name, result_type, lhs_type, rhs_type, body_gen): try: result = ir.Argument(result_type, "result") lhs = ir.Argument(lhs_type, "lhs") @@ -1461,8 +1461,8 @@ class ARTIQIRGenerator(algorithm.Visitor): ret=builtins.TNone()) env_args = [ir.EnvironmentArgument(self.current_env.type, "ARG.ENV")] - # TODO: What to use for loc? - func = ir.Function(typ, name, env_args + args, loc=None) + old_loc, self.current_loc = self.current_loc, None + func = ir.Function(typ, name, env_args + args) func.is_internal = True func.is_generated = True self.functions.append(func) @@ -1474,36 +1474,12 @@ class ARTIQIRGenerator(algorithm.Visitor): old_final_branch, self.final_branch = self.final_branch, None old_unwind, self.unwind_target = self.unwind_target, None - shape = self.append(ir.GetAttr(lhs, "shape")) - rhs_shape = self.append(ir.GetAttr(rhs, "shape")) - self._make_check( - self.append(ir.Compare(ast.Eq(loc=None), shape, rhs_shape)), - lambda: self.alloc_exn( - builtins.TException("ValueError"), - ir.Constant("operands could not be broadcast together", - builtins.TStr()))) - # We assume result has correct shape; could just pass buffer pointer as well. - - result_buffer = self.append(ir.GetAttr(result, "buffer")) - lhs_buffer = self.append(ir.GetAttr(lhs, "buffer")) - rhs_buffer = self.append(ir.GetAttr(rhs, "buffer")) - num_total_elts = self._get_total_array_len(shape) - - def body_gen(index): - l = self.append(ir.GetElem(lhs_buffer, index)) - r = self.append(ir.GetElem(rhs_buffer, index)) - self.append( - ir.SetElem(result_buffer, index, self.append(ir.Arith(op, l, r)))) - return self.append( - ir.Arith(ast.Add(loc=None), index, ir.Constant(1, self._size_type))) - - self._make_loop( - ir.Constant(0, self._size_type), lambda index: self.append( - ir.Compare(ast.Lt(loc=None), index, num_total_elts)), body_gen) + body_gen(result, lhs, rhs) self.append(ir.Return(ir.Constant(None, builtins.TNone()))) return func finally: + self.current_loc = old_loc self.current_function = old_func self.current_block = old_block self.final_branch = old_final_branch @@ -1518,8 +1494,9 @@ class ARTIQIRGenerator(algorithm.Visitor): # rpc_tag is used to turn element types into mangled names for no # particularly good reason apart from not having to invent yet another # string representation. - return (ir.rpc_tag(typ["elt"], name_error).decode() + - str(typ["num_dims"].find().value)) + if builtins.is_array(typ): + return mangle_name(typ["elt"]) + str(typ["num_dims"].find().value) + return ir.rpc_tag(typ, name_error).decode() return "_".join(mangle_name(t) for t in types) @@ -1531,8 +1508,41 @@ class ARTIQIRGenerator(algorithm.Visitor): type(op).__name__, self._mangle_arrayop_types([result_type, lhs_type, rhs_type])) if name not in self.array_binop_funcs: + + def body_gen(result, lhs, rhs): + # TODO: Move into caller for correct location information (or pass)? + shape = self.append(ir.GetAttr(lhs, "shape")) + rhs_shape = self.append(ir.GetAttr(rhs, "shape")) + self._make_check( + self.append(ir.Compare(ast.Eq(loc=None), shape, rhs_shape)), + lambda: self.alloc_exn( + builtins.TException("ValueError"), + ir.Constant("operands could not be broadcast together", + builtins.TStr()))) + # We assume result has correct shape; could just pass buffer pointer + # as well. + + result_buffer = self.append(ir.GetAttr(result, "buffer")) + lhs_buffer = self.append(ir.GetAttr(lhs, "buffer")) + rhs_buffer = self.append(ir.GetAttr(rhs, "buffer")) + num_total_elts = self._get_total_array_len(shape) + + def loop_gen(index): + l = self.append(ir.GetElem(lhs_buffer, index)) + r = self.append(ir.GetElem(rhs_buffer, index)) + self.append( + ir.SetElem(result_buffer, index, self.append(ir.Arith(op, l, + r)))) + return self.append( + ir.Arith(ast.Add(loc=None), index, + ir.Constant(1, self._size_type))) + + self._make_loop( + ir.Constant(0, self._size_type), lambda index: self.append( + ir.Compare(ast.Lt(loc=None), index, num_total_elts)), loop_gen) + self.array_binop_funcs[name] = self._make_array_binop( - name, op, result_type, lhs_type, rhs_type) + name, result_type, lhs_type, rhs_type, body_gen) return self.array_binop_funcs[name] def _invoke_arrayop(self, func, params): @@ -1545,14 +1555,162 @@ class ARTIQIRGenerator(algorithm.Visitor): self.append(ir.Invoke(func, params, {}, after_invoke, self.unwind_target)) self.current_block = after_invoke + def _get_array_offset(self, shape, indices): + last_stride = None + result = indices[0] + for dim, index in zip(shape[:-1], indices[1:]): + result = self.append(ir.Arith(ast.Mult(loc=None), result, dim)) + result = self.append(ir.Arith(ast.Add(loc=None), result, index)) + return result + + def _get_matmult(self, result_type, lhs_type, rhs_type): + name = "_array_MatMult_" + self._mangle_arrayop_types( + [result_type, lhs_type, rhs_type]) + if name not in self.array_binop_funcs: + + def body_gen(result, lhs, rhs): + assert builtins.is_array(result.type), \ + "vec @ vec should have been normalised into array result" + + # We assume result has correct shape; could just pass buffer pointer + # as well. + result_buffer = self.append(ir.GetAttr(result, "buffer")) + lhs_buffer = self.append(ir.GetAttr(lhs, "buffer")) + rhs_buffer = self.append(ir.GetAttr(rhs, "buffer")) + + num_rows, num_summands, _, num_cols = self._get_matmult_shapes(lhs, rhs) + + elt = result.type["elt"].find() + env_type = ir.TEnvironment("loop", {"$total": elt}) + env = self.append(ir.Alloc([], env_type)) + + def row_loop(row_idx): + lhs_base_offset = self.append( + ir.Arith(ast.Mult(loc=None), row_idx, num_summands)) + lhs_base = self.append(ir.Offset(lhs_buffer, lhs_base_offset)) + result_base_offset = self.append( + ir.Arith(ast.Mult(loc=None), row_idx, num_cols)) + result_base = self.append( + ir.Offset(result_buffer, result_base_offset)) + + def col_loop(col_idx): + rhs_base = self.append(ir.Offset(rhs_buffer, col_idx)) + + self.append( + ir.SetLocal(env, "$total", ir.Constant(elt.zero(), elt))) + + def sum_loop(sum_idx): + lhs_elem = self.append(ir.GetElem(lhs_base, sum_idx)) + rhs_offset = self.append( + ir.Arith(ast.Mult(loc=None), sum_idx, num_cols)) + rhs_elem = self.append(ir.GetElem(rhs_base, rhs_offset)) + product = self.append( + ir.Arith(ast.Mult(loc=None), lhs_elem, rhs_elem)) + prev_total = self.append(ir.GetLocal(env, "$total")) + total = self.append( + ir.Arith(ast.Add(loc=None), prev_total, product)) + self.append(ir.SetLocal(env, "$total", total)) + return self.append( + ir.Arith(ast.Add(loc=None), sum_idx, + ir.Constant(1, self._size_type))) + + self._make_loop( + ir.Constant(0, self._size_type), lambda index: self.append( + ir.Compare(ast.Lt(loc=None), index, num_summands)), + sum_loop) + + total = self.append(ir.GetLocal(env, "$total")) + self.append(ir.SetElem(result_base, col_idx, total)) + + return self.append( + ir.Arith(ast.Add(loc=None), col_idx, + ir.Constant(1, self._size_type))) + + self._make_loop( + ir.Constant(0, self._size_type), lambda index: self.append( + ir.Compare(ast.Lt(loc=None), index, num_cols)), col_loop) + return self.append( + ir.Arith(ast.Add(loc=None), row_idx, + ir.Constant(1, self._size_type))) + + self._make_loop( + ir.Constant(0, self._size_type), lambda index: self.append( + ir.Compare(ast.Lt(loc=None), index, num_rows)), row_loop) + + self.array_binop_funcs[name] = self._make_array_binop( + name, result_type, lhs_type, rhs_type, body_gen) + return self.array_binop_funcs[name] + + def _get_matmult_shapes(self, lhs, rhs): + lhs_shape = self.append(ir.GetAttr(lhs, "shape")) + if lhs.type["num_dims"].value == 1: + lhs_shape_outer = ir.Constant(1, self._size_type) + lhs_shape_inner = self.append(ir.GetAttr(lhs_shape, 0)) + else: + lhs_shape_outer = self.append(ir.GetAttr(lhs_shape, 0)) + lhs_shape_inner = self.append(ir.GetAttr(lhs_shape, 1)) + + rhs_shape = self.append(ir.GetAttr(rhs, "shape")) + if rhs.type["num_dims"].value == 1: + rhs_shape_inner = self.append(ir.GetAttr(rhs_shape, 0)) + rhs_shape_outer = ir.Constant(1, self._size_type) + else: + rhs_shape_inner = self.append(ir.GetAttr(rhs_shape, 0)) + rhs_shape_outer = self.append(ir.GetAttr(rhs_shape, 1)) + + return lhs_shape_outer, lhs_shape_inner, rhs_shape_inner, rhs_shape_outer + + def _make_array_shape(self, dims): + return self.append(ir.Alloc(dims, types.TTuple([self._size_type] * len(dims)))) + + def _emit_matmult(self, node, left, right): + # TODO: Also expose as numpy.dot. + lhs = self.visit(left) + rhs = self.visit(right) + + num_rows, lhs_inner, rhs_inner, num_cols = self._get_matmult_shapes(lhs, rhs) + self._make_check( + self.append(ir.Compare(ast.Eq(loc=None), lhs_inner, rhs_inner)), + lambda lhs_inner, rhs_inner: self.alloc_exn( + builtins.TException("ValueError"), + ir.Constant( + "inner dimensions for matrix multiplication do not match ({0} vs. {1})", + builtins.TStr()), lhs_inner, rhs_inner), + params=[lhs_inner, rhs_inner], + loc=node.loc) + result_shape = self._make_array_shape([num_rows, num_cols]) + + final_type = node.type.find() + if not builtins.is_array(final_type): + elt = node.type + result_dims = 0 + else: + elt = final_type["elt"] + result_dims = final_type["num_dims"].value + + result = self._allocate_new_array(elt, result_shape) + func = self._get_matmult(result.type, left.type, right.type) + self._invoke_arrayop(func, [result, lhs, rhs]) + + if result_dims == 2: + return result + result_buffer = self.append(ir.GetAttr(result, "buffer")) + if result_dims == 1: + shape = self._make_array_shape( + [num_cols if lhs.type["num_dims"].value == 1 else num_rows]) + return self.append(ir.Alloc([result_buffer, shape], node.type)) + return self.append(ir.GetElem(result_buffer, ir.Constant(0, self._size_type))) + + def visit_BinOpT(self, node): - if builtins.is_array(node.type): + if isinstance(node.op, ast.MatMult): + return self._emit_matmult(node, node.left, node.right) + elif builtins.is_array(node.type): lhs = self.visit(node.left) rhs = self.visit(node.right) # Array op implementation will check for matching shape. # TODO: Broadcasts; select the widest shape. - # TODO: Detect and special-case matrix multiplication. shape = self.append(ir.GetAttr(lhs, "shape")) result = self._allocate_new_array(node.type.find()["elt"], shape) diff --git a/artiq/test/lit/integration/array_matmult.py b/artiq/test/lit/integration/array_matmult.py new file mode 100644 index 000000000..7519c10ff --- /dev/null +++ b/artiq/test/lit/integration/array_matmult.py @@ -0,0 +1,25 @@ +# RUN: %python -m artiq.compiler.testbench.jit %s + +mat23 = array([[1, 2, 3], [4, 5, 6]]) +mat32 = array([[1, 2], [3, 4], [5, 6]]) +vec2 = array([1, 2]) +vec3 = array([1, 2, 3]) + +assert vec3 @ vec3 == 14 + +a = mat23 @ mat32 +assert a.shape == (2, 2) +assert a[0][0] == 22 +assert a[0][1] == 28 +assert a[1][0] == 49 +assert a[1][1] == 64 + +b = mat23 @ vec3 +assert b.shape == (2,) +assert b[0] == 14 +assert b[1] == 32 + +b = vec3 @ mat32 +assert b.shape == (2,) +assert b[0] == 22 +assert b[1] == 28