diff --git a/artiq/compiler/transforms/artiq_ir_generator.py b/artiq/compiler/transforms/artiq_ir_generator.py index c4d469cd7..22482c8c0 100644 --- a/artiq/compiler/transforms/artiq_ir_generator.py +++ b/artiq/compiler/transforms/artiq_ir_generator.py @@ -83,6 +83,13 @@ class ARTIQIRGenerator(algorithm.Visitor): :ivar method_map: (map of :class:`ast.AttributeT` to :class:`ir.GetAttribute`) the map from method resolution nodes to instructions retrieving the called function inside a translated :class:`ast.CallT` node + + Finally, functions that implement array operations are instantiated on the fly as + necessary. They are kept track of in global dictionaries, with a mangled name + containing types and operations as key: + + :ivar array_binop_funcs: the map from mangled name to implementation of binary + operations between arrays """ _size_type = builtins.TInt32() @@ -111,6 +118,7 @@ class ARTIQIRGenerator(algorithm.Visitor): self.function_map = dict() self.variable_map = dict() self.method_map = defaultdict(lambda: []) + self.array_binop_funcs = dict() def annotate_calls(self, devirtualization): for var_node in devirtualization.variable_map: @@ -1337,8 +1345,124 @@ class ARTIQIRGenerator(algorithm.Visitor): name="{}.{}".format(_readable_name(value), node.type.name))) + def _get_total_array_len(self, shape): + lengths = [ + self.append(ir.GetAttr(shape, i)) for i in range(len(shape.type.elts)) + ] + return reduce(lambda l, r: self.append(ir.Arith(ast.Mult(loc=None), l, r)), + lengths[1:], lengths[0]) + + def _alloate_new_array(self, elt, shape): + total_length = self._get_total_array_len(shape) + buffer = self.append(ir.Alloc([total_length], types._TPointer(elt=elt))) + result_type = builtins.TArray(elt, types.TValue(len(shape.type.elts))) + return self.append(ir.Alloc([buffer, shape], result_type)) + + def _make_array_binop(self, name, op, result_type, lhs_type, rhs_type): + try: + result = ir.Argument(result_type, "result") + lhs = ir.Argument(lhs_type, "lhs") + rhs = ir.Argument(rhs_type, "rhs") + + # TODO: We'd like to use a "C function" here to be able to supply + # specialised implementations in a library in the future (and e.g. avoid + # passing around the context argument), but the code generator currently + # doesn't allow emitting them. + args = [result, lhs, rhs] + typ = types.TFunction(args=OrderedDict([(arg.name, arg.type) + for arg in args]), + optargs=OrderedDict(), + ret=builtins.TNone()) + env_args = [ir.EnvironmentArgument(self.current_env.type, "ARG.ENV")] + + # TODO: What to use for loc? + func = ir.Function(typ, name, env_args + args, loc=None) + func.is_internal = True + func.is_generated = True + self.functions.append(func) + old_func, self.current_function = self.current_function, func + + entry = self.add_block("entry") + old_block, self.current_block = self.current_block, entry + + old_final_branch, self.final_branch = self.final_branch, None + old_unwind, self.unwind_target = self.unwind_target, None + + shape = self.append(ir.GetAttr(lhs, "shape")) + rhs_shape = self.append(ir.GetAttr(rhs, "shape")) + self._make_check( + self.append(ir.Compare(ast.Eq(loc=None), shape, rhs_shape)), + lambda: self.alloc_exn( + builtins.TException("ValueError"), + ir.Constant("operands could not be broadcast together", + builtins.TStr()))) + # We assume result has correct shape; could just pass buffer pointer as well. + + result_buffer = self.append(ir.GetAttr(result, "buffer")) + lhs_buffer = self.append(ir.GetAttr(lhs, "buffer")) + rhs_buffer = self.append(ir.GetAttr(rhs, "buffer")) + num_total_elts = self._get_total_array_len(shape) + + def body_gen(index): + l = self.append(ir.GetElem(lhs_buffer, index)) + r = self.append(ir.GetElem(rhs_buffer, index)) + self.append( + ir.SetElem(result_buffer, index, self.append(ir.Arith(op, l, r)))) + return self.append( + ir.Arith(ast.Add(loc=None), index, ir.Constant(1, self._size_type))) + + self._make_loop( + ir.Constant(0, self._size_type), lambda index: self.append( + ir.Compare(ast.Lt(loc=None), index, num_total_elts)), body_gen) + + self.append(ir.Return(ir.Constant(None, builtins.TNone()))) + return func + finally: + self.current_function = old_func + self.current_block = old_block + self.final_branch = old_final_branch + self.unwind_target = old_unwind + + def _get_array_binop(self, op, result_type, lhs_type, rhs_type): + # Currently, we always have any type coercions resolved explicitly in the AST. + # In the future, this might no longer be true and the three types might all + # differ. + def name_error(typ): + assert False, "Internal compiler error: No RPC tag for {}".format(typ) + def mangle_name(typ): + typ = typ.find() + return ir.rpc_tag(typ["elt"], name_error).decode() +\ + str(typ["num_dims"].find().value) + name = "_array_{}_{}_{}_{}".format( + type(op).__name__, + *(map(mangle_name, (result_type, lhs_type, rhs_type)))) + if name not in self.array_binop_funcs: + self.array_binop_funcs[name] = self._make_array_binop( + name, op, result_type, lhs_type, rhs_type) + return self.array_binop_funcs[name] + def visit_BinOpT(self, node): - if builtins.is_numeric(node.type): + if builtins.is_array(node.type): + lhs = self.visit(node.left) + rhs = self.visit(node.right) + + # Array op implementation will check for matching shape. + # TODO: Broadcasts; select the widest shape. + # TODO: Detect and special-case matrix multiplication. + shape = self.append(ir.GetAttr(lhs, "shape")) + result = self._alloate_new_array(node.type.find()["elt"], shape) + + func = self._get_array_binop(node.op, node.type, node.left.type, node.right.type) + closure = self.append(ir.Closure(func, ir.Constant(None, ir.TEnvironment("arrayop", {})))) + params = [result, lhs, rhs] + if self.unwind_target is None: + insn = self.append(ir.Call(closure, params, {})) + else: + after_invoke = self.add_block("arrayop.invoke") + insn = self.append(ir.Invoke(func, params, {}, after_invoke, self.unwind_target)) + self.current_block = after_invoke + return result + elif builtins.is_numeric(node.type): lhs = self.visit(node.left) rhs = self.visit(node.right) if isinstance(node.op, (ast.LShift, ast.RShift)): @@ -1703,11 +1827,8 @@ class ARTIQIRGenerator(algorithm.Visitor): ir.Constant(0, self._size_type)) lengths.append(self.iterable_len(first_elt)) - num_total_elts = reduce( - lambda l, r: self.append(ir.Arith(ast.Mult(loc=None), l, r)), - lengths[1:], lengths[0]) - shape = self.append(ir.Alloc(lengths, result_type.attributes["shape"])) + num_total_elts = self._get_total_array_len(shape) # Assign buffer from nested iterables. buffer = self.append( diff --git a/artiq/test/lit/integration/array_ops.py b/artiq/test/lit/integration/array_ops.py new file mode 100644 index 000000000..72d2f0812 --- /dev/null +++ b/artiq/test/lit/integration/array_ops.py @@ -0,0 +1,19 @@ +# RUN: %python -m artiq.compiler.testbench.jit %s + +a = array([1, 2, 3]) +b = array([4, 5, 6]) + +c = a + b +assert c[0] == 5 +assert c[1] == 7 +assert c[2] == 9 + +c = a * b +assert c[0] == 4 +assert c[1] == 10 +assert c[2] == 18 + +c = b // a +assert c[0] == 4 +assert c[1] == 2 +assert c[2] == 2