compiler: Implement broadcasting of math functions

This commit is contained in:
David Nadlinger 2020-08-03 01:29:39 +01:00
parent be7d78253f
commit cc00ae9580
4 changed files with 59 additions and 13 deletions

View File

@ -32,7 +32,9 @@ numpy_builtins = ["transpose"]
def unary_fp_type(name): def unary_fp_type(name):
return types.TExternalFunction(OrderedDict([("arg", builtins.TFloat())]), return types.TExternalFunction(OrderedDict([("arg", builtins.TFloat())]),
builtins.TFloat(), name) builtins.TFloat(),
name,
broadcast_across_arrays=True)
numpy_map = { numpy_map = {

View File

@ -1335,8 +1335,8 @@ class ARTIQIRGenerator(algorithm.Visitor):
ret=builtins.TNone()) ret=builtins.TNone())
env_args = [ir.EnvironmentArgument(self.current_env.type, "ARG.ENV")] env_args = [ir.EnvironmentArgument(self.current_env.type, "ARG.ENV")]
# TODO: What to use for loc? old_loc, self.current_loc = self.current_loc, None
func = ir.Function(typ, name, env_args + args, loc=None) func = ir.Function(typ, name, env_args + args)
func.is_internal = True func.is_internal = True
func.is_generated = True func.is_generated = True
self.functions.append(func) self.functions.append(func)
@ -1357,7 +1357,7 @@ class ARTIQIRGenerator(algorithm.Visitor):
def body_gen(index): def body_gen(index):
a = self.append(ir.GetElem(arg_buffer, index)) a = self.append(ir.GetElem(arg_buffer, index))
self.append( self.append(
ir.SetElem(result_buffer, index, self.append(make_op(a)))) ir.SetElem(result_buffer, index, make_op(a)))
return self.append( return self.append(
ir.Arith(ast.Add(loc=None), index, ir.Constant(1, self._size_type))) ir.Arith(ast.Add(loc=None), index, ir.Constant(1, self._size_type)))
@ -1368,6 +1368,7 @@ class ARTIQIRGenerator(algorithm.Visitor):
self.append(ir.Return(ir.Constant(None, builtins.TNone()))) self.append(ir.Return(ir.Constant(None, builtins.TNone())))
return func return func
finally: finally:
self.current_loc = old_loc
self.current_function = old_func self.current_function = old_func
self.current_block = old_block self.current_block = old_block
self.final_branch = old_final_branch self.final_branch = old_final_branch
@ -1393,8 +1394,8 @@ class ARTIQIRGenerator(algorithm.Visitor):
ir.Constant(-1, operand.type), operand)) ir.Constant(-1, operand.type), operand))
elif isinstance(node.op, ast.USub): elif isinstance(node.op, ast.USub):
def make_sub(val): def make_sub(val):
return ir.Arith(ast.Sub(loc=None), return self.append(ir.Arith(ast.Sub(loc=None),
ir.Constant(0, val.type), val) ir.Constant(0, val.type), val))
operand = self.visit(node.operand) operand = self.visit(node.operand)
if builtins.is_array(operand.type): if builtins.is_array(operand.type):
shape = self.append(ir.GetAttr(operand, "shape")) shape = self.append(ir.GetAttr(operand, "shape"))
@ -1403,7 +1404,7 @@ class ARTIQIRGenerator(algorithm.Visitor):
self._invoke_arrayop(func, [result, operand]) self._invoke_arrayop(func, [result, operand])
return result return result
else: else:
return self.append(make_sub(operand)) return make_sub(operand)
elif isinstance(node.op, ast.UAdd): elif isinstance(node.op, ast.UAdd):
# No-op. # No-op.
return self.visit(node.operand) return self.visit(node.operand)
@ -1419,8 +1420,8 @@ class ARTIQIRGenerator(algorithm.Visitor):
result_elt = node.type.find()["elt"] result_elt = node.type.find()["elt"]
shape = self.append(ir.GetAttr(value, "shape")) shape = self.append(ir.GetAttr(value, "shape"))
result = self._allocate_new_array(result_elt, shape) result = self._allocate_new_array(result_elt, shape)
func = self._get_array_unaryop("Coerce", func = self._get_array_unaryop(
lambda v: ir.Coerce(v, result_elt), "Coerce", lambda v: self.append(ir.Coerce(v, result_elt)),
node.type, value.type) node.type, value.type)
self._invoke_arrayop(func, [result, value]) self._invoke_arrayop(func, [result, value])
return result return result
@ -2383,12 +2384,32 @@ class ARTIQIRGenerator(algorithm.Visitor):
if types.is_builtin(node.func.type): if types.is_builtin(node.func.type):
insn = self.visit_builtin_call(node) insn = self.visit_builtin_call(node)
elif (types.is_broadcast_across_arrays(node.func.type) and len(args) >= 1
and builtins.is_array(args[0].type)):
# The iodelay machinery set up in the surrounding code was
# deprecated/a relic from the past when array broadcasting support
# was added, so no attempt to keep the delay tracking intact is
# made.
assert len(args) == 1, "Broadcasting for multiple arguments not implemented"
def make_call(val):
return self._user_call(ir.Constant(None, callee.type), [val], {},
node.arg_exprs)
shape = self.append(ir.GetAttr(args[0], "shape"))
result = self._allocate_new_array(node.type.find()["elt"], shape)
# TODO: Generate more generically if non-externals are allowed.
name = node.func.type.find().name
func = self._get_array_unaryop(name, make_call, node.type, args[0].type)
self._invoke_arrayop(func, [result, args[0]])
insn = result
else: else:
insn = self._user_call(callee, args, keywords, node.arg_exprs) insn = self._user_call(callee, args, keywords, node.arg_exprs)
if isinstance(node.func, asttyped.AttributeT): if isinstance(node.func, asttyped.AttributeT):
attr_node = node.func attr_node = node.func
self.method_map[(attr_node.value.type.find(), attr_node.attr)].append(insn) self.method_map[(attr_node.value.type.find(),
attr_node.attr)].append(insn)
if node.iodelay is not None and not iodelay.is_const(node.iodelay, 0): if node.iodelay is not None and not iodelay.is_const(node.iodelay, 0):
after_delay = self.add_block("delay.tail") after_delay = self.add_block("delay.tail")

View File

@ -1232,6 +1232,16 @@ class Inferencer(algorithm.Visitor):
self.engine.process(diag) self.engine.process(diag)
return return
# Array broadcasting for unary functions explicitly marked as such.
if len(node.args) == typ_arity == 1 and types.is_broadcast_across_arrays(typ):
arg_type = node.args[0].type.find()
if builtins.is_array(arg_type):
typ_arg, = typ_args.values()
self._unify(typ_arg, arg_type["elt"], node.args[0].loc, None)
self._unify(node.type, builtins.TArray(typ_ret, arg_type["num_dims"]),
node.loc, None)
return
for actualarg, (formalname, formaltyp) in \ for actualarg, (formalname, formaltyp) in \
zip(node.args, list(typ_args.items()) + list(typ_optargs.items())): zip(node.args, list(typ_args.items()) + list(typ_optargs.items())):
self._unify(actualarg.type, formaltyp, self._unify(actualarg.type, formaltyp,

View File

@ -298,11 +298,14 @@ class TExternalFunction(TFunction):
Flag ``nounwind`` means the function never raises an exception. Flag ``nounwind`` means the function never raises an exception.
Flag ``nowrite`` means the function never writes any memory Flag ``nowrite`` means the function never writes any memory
that the ARTIQ Python code can observe. that the ARTIQ Python code can observe.
:ivar broadcast_across_arrays: (bool)
If True, the function is transparently applied element-wise when called
with TArray arguments.
""" """
attributes = OrderedDict() attributes = OrderedDict()
def __init__(self, args, ret, name, flags=set()): def __init__(self, args, ret, name, flags=set(), broadcast_across_arrays=False):
assert isinstance(flags, set) assert isinstance(flags, set)
for flag in flags: for flag in flags:
assert flag in {'nounwind', 'nowrite'} assert flag in {'nounwind', 'nowrite'}
@ -310,6 +313,7 @@ class TExternalFunction(TFunction):
self.name = name self.name = name
self.delay = TFixedDelay(iodelay.Const(0)) self.delay = TFixedDelay(iodelay.Const(0))
self.flags = flags self.flags = flags
self.broadcast_across_arrays = broadcast_across_arrays
def unify(self, other): def unify(self, other):
if other is self: if other is self:
@ -644,6 +648,15 @@ def is_builtin_function(typ, name=None):
return isinstance(typ, TBuiltinFunction) and \ return isinstance(typ, TBuiltinFunction) and \
typ.name == name typ.name == name
def is_broadcast_across_arrays(typ):
# For now, broadcasting is only exposed to predefined external functions, and
# statically selected. Might be extended to user-defined functions if the design
# pans out.
typ = typ.find()
if not isinstance(typ, TExternalFunction):
return False
return typ.broadcast_across_arrays
def is_constructor(typ, name=None): def is_constructor(typ, name=None):
typ = typ.find() typ = typ.find()
if name is not None: if name is not None: