diff --git a/artiq/compiler/math_fns.py b/artiq/compiler/math_fns.py index 4e67dfef0..5ae80ea4d 100644 --- a/artiq/compiler/math_fns.py +++ b/artiq/compiler/math_fns.py @@ -32,7 +32,9 @@ numpy_builtins = ["transpose"] def unary_fp_type(name): return types.TExternalFunction(OrderedDict([("arg", builtins.TFloat())]), - builtins.TFloat(), name) + builtins.TFloat(), + name, + broadcast_across_arrays=True) numpy_map = { diff --git a/artiq/compiler/transforms/artiq_ir_generator.py b/artiq/compiler/transforms/artiq_ir_generator.py index ab162288d..39731b709 100644 --- a/artiq/compiler/transforms/artiq_ir_generator.py +++ b/artiq/compiler/transforms/artiq_ir_generator.py @@ -1335,8 +1335,8 @@ class ARTIQIRGenerator(algorithm.Visitor): ret=builtins.TNone()) env_args = [ir.EnvironmentArgument(self.current_env.type, "ARG.ENV")] - # TODO: What to use for loc? - func = ir.Function(typ, name, env_args + args, loc=None) + old_loc, self.current_loc = self.current_loc, None + func = ir.Function(typ, name, env_args + args) func.is_internal = True func.is_generated = True self.functions.append(func) @@ -1357,7 +1357,7 @@ class ARTIQIRGenerator(algorithm.Visitor): def body_gen(index): a = self.append(ir.GetElem(arg_buffer, index)) self.append( - ir.SetElem(result_buffer, index, self.append(make_op(a)))) + ir.SetElem(result_buffer, index, make_op(a))) return self.append( ir.Arith(ast.Add(loc=None), index, ir.Constant(1, self._size_type))) @@ -1368,6 +1368,7 @@ class ARTIQIRGenerator(algorithm.Visitor): self.append(ir.Return(ir.Constant(None, builtins.TNone()))) return func finally: + self.current_loc = old_loc self.current_function = old_func self.current_block = old_block self.final_branch = old_final_branch @@ -1393,8 +1394,8 @@ class ARTIQIRGenerator(algorithm.Visitor): ir.Constant(-1, operand.type), operand)) elif isinstance(node.op, ast.USub): def make_sub(val): - return ir.Arith(ast.Sub(loc=None), - ir.Constant(0, val.type), val) + return self.append(ir.Arith(ast.Sub(loc=None), + ir.Constant(0, val.type), val)) operand = self.visit(node.operand) if builtins.is_array(operand.type): shape = self.append(ir.GetAttr(operand, "shape")) @@ -1403,7 +1404,7 @@ class ARTIQIRGenerator(algorithm.Visitor): self._invoke_arrayop(func, [result, operand]) return result else: - return self.append(make_sub(operand)) + return make_sub(operand) elif isinstance(node.op, ast.UAdd): # No-op. return self.visit(node.operand) @@ -1419,9 +1420,9 @@ class ARTIQIRGenerator(algorithm.Visitor): result_elt = node.type.find()["elt"] shape = self.append(ir.GetAttr(value, "shape")) result = self._allocate_new_array(result_elt, shape) - func = self._get_array_unaryop("Coerce", - lambda v: ir.Coerce(v, result_elt), - node.type, value.type) + func = self._get_array_unaryop( + "Coerce", lambda v: self.append(ir.Coerce(v, result_elt)), + node.type, value.type) self._invoke_arrayop(func, [result, value]) return result else: @@ -2383,12 +2384,32 @@ class ARTIQIRGenerator(algorithm.Visitor): if types.is_builtin(node.func.type): insn = self.visit_builtin_call(node) + elif (types.is_broadcast_across_arrays(node.func.type) and len(args) >= 1 + and builtins.is_array(args[0].type)): + # The iodelay machinery set up in the surrounding code was + # deprecated/a relic from the past when array broadcasting support + # was added, so no attempt to keep the delay tracking intact is + # made. + assert len(args) == 1, "Broadcasting for multiple arguments not implemented" + + def make_call(val): + return self._user_call(ir.Constant(None, callee.type), [val], {}, + node.arg_exprs) + + shape = self.append(ir.GetAttr(args[0], "shape")) + result = self._allocate_new_array(node.type.find()["elt"], shape) + + # TODO: Generate more generically if non-externals are allowed. + name = node.func.type.find().name + func = self._get_array_unaryop(name, make_call, node.type, args[0].type) + self._invoke_arrayop(func, [result, args[0]]) + insn = result else: insn = self._user_call(callee, args, keywords, node.arg_exprs) - if isinstance(node.func, asttyped.AttributeT): attr_node = node.func - self.method_map[(attr_node.value.type.find(), attr_node.attr)].append(insn) + self.method_map[(attr_node.value.type.find(), + attr_node.attr)].append(insn) if node.iodelay is not None and not iodelay.is_const(node.iodelay, 0): after_delay = self.add_block("delay.tail") diff --git a/artiq/compiler/transforms/inferencer.py b/artiq/compiler/transforms/inferencer.py index d017af963..6b765c2c7 100644 --- a/artiq/compiler/transforms/inferencer.py +++ b/artiq/compiler/transforms/inferencer.py @@ -1232,6 +1232,16 @@ class Inferencer(algorithm.Visitor): self.engine.process(diag) return + # Array broadcasting for unary functions explicitly marked as such. + if len(node.args) == typ_arity == 1 and types.is_broadcast_across_arrays(typ): + arg_type = node.args[0].type.find() + if builtins.is_array(arg_type): + typ_arg, = typ_args.values() + self._unify(typ_arg, arg_type["elt"], node.args[0].loc, None) + self._unify(node.type, builtins.TArray(typ_ret, arg_type["num_dims"]), + node.loc, None) + return + for actualarg, (formalname, formaltyp) in \ zip(node.args, list(typ_args.items()) + list(typ_optargs.items())): self._unify(actualarg.type, formaltyp, diff --git a/artiq/compiler/types.py b/artiq/compiler/types.py index 281b53577..e7b68a3a4 100644 --- a/artiq/compiler/types.py +++ b/artiq/compiler/types.py @@ -298,11 +298,14 @@ class TExternalFunction(TFunction): Flag ``nounwind`` means the function never raises an exception. Flag ``nowrite`` means the function never writes any memory that the ARTIQ Python code can observe. + :ivar broadcast_across_arrays: (bool) + If True, the function is transparently applied element-wise when called + with TArray arguments. """ attributes = OrderedDict() - def __init__(self, args, ret, name, flags=set()): + def __init__(self, args, ret, name, flags=set(), broadcast_across_arrays=False): assert isinstance(flags, set) for flag in flags: assert flag in {'nounwind', 'nowrite'} @@ -310,6 +313,7 @@ class TExternalFunction(TFunction): self.name = name self.delay = TFixedDelay(iodelay.Const(0)) self.flags = flags + self.broadcast_across_arrays = broadcast_across_arrays def unify(self, other): if other is self: @@ -644,6 +648,15 @@ def is_builtin_function(typ, name=None): return isinstance(typ, TBuiltinFunction) and \ typ.name == name +def is_broadcast_across_arrays(typ): + # For now, broadcasting is only exposed to predefined external functions, and + # statically selected. Might be extended to user-defined functions if the design + # pans out. + typ = typ.find() + if not isinstance(typ, TExternalFunction): + return False + return typ.broadcast_across_arrays + def is_constructor(typ, name=None): typ = typ.find() if name is not None: