forked from M-Labs/artiq
compiler: Implement broadcasting of math functions
This commit is contained in:
parent
be7d78253f
commit
cc00ae9580
@ -32,7 +32,9 @@ numpy_builtins = ["transpose"]
|
||||
|
||||
def unary_fp_type(name):
|
||||
return types.TExternalFunction(OrderedDict([("arg", builtins.TFloat())]),
|
||||
builtins.TFloat(), name)
|
||||
builtins.TFloat(),
|
||||
name,
|
||||
broadcast_across_arrays=True)
|
||||
|
||||
|
||||
numpy_map = {
|
||||
|
@ -1335,8 +1335,8 @@ class ARTIQIRGenerator(algorithm.Visitor):
|
||||
ret=builtins.TNone())
|
||||
env_args = [ir.EnvironmentArgument(self.current_env.type, "ARG.ENV")]
|
||||
|
||||
# TODO: What to use for loc?
|
||||
func = ir.Function(typ, name, env_args + args, loc=None)
|
||||
old_loc, self.current_loc = self.current_loc, None
|
||||
func = ir.Function(typ, name, env_args + args)
|
||||
func.is_internal = True
|
||||
func.is_generated = True
|
||||
self.functions.append(func)
|
||||
@ -1357,7 +1357,7 @@ class ARTIQIRGenerator(algorithm.Visitor):
|
||||
def body_gen(index):
|
||||
a = self.append(ir.GetElem(arg_buffer, index))
|
||||
self.append(
|
||||
ir.SetElem(result_buffer, index, self.append(make_op(a))))
|
||||
ir.SetElem(result_buffer, index, make_op(a)))
|
||||
return self.append(
|
||||
ir.Arith(ast.Add(loc=None), index, ir.Constant(1, self._size_type)))
|
||||
|
||||
@ -1368,6 +1368,7 @@ class ARTIQIRGenerator(algorithm.Visitor):
|
||||
self.append(ir.Return(ir.Constant(None, builtins.TNone())))
|
||||
return func
|
||||
finally:
|
||||
self.current_loc = old_loc
|
||||
self.current_function = old_func
|
||||
self.current_block = old_block
|
||||
self.final_branch = old_final_branch
|
||||
@ -1393,8 +1394,8 @@ class ARTIQIRGenerator(algorithm.Visitor):
|
||||
ir.Constant(-1, operand.type), operand))
|
||||
elif isinstance(node.op, ast.USub):
|
||||
def make_sub(val):
|
||||
return ir.Arith(ast.Sub(loc=None),
|
||||
ir.Constant(0, val.type), val)
|
||||
return self.append(ir.Arith(ast.Sub(loc=None),
|
||||
ir.Constant(0, val.type), val))
|
||||
operand = self.visit(node.operand)
|
||||
if builtins.is_array(operand.type):
|
||||
shape = self.append(ir.GetAttr(operand, "shape"))
|
||||
@ -1403,7 +1404,7 @@ class ARTIQIRGenerator(algorithm.Visitor):
|
||||
self._invoke_arrayop(func, [result, operand])
|
||||
return result
|
||||
else:
|
||||
return self.append(make_sub(operand))
|
||||
return make_sub(operand)
|
||||
elif isinstance(node.op, ast.UAdd):
|
||||
# No-op.
|
||||
return self.visit(node.operand)
|
||||
@ -1419,9 +1420,9 @@ class ARTIQIRGenerator(algorithm.Visitor):
|
||||
result_elt = node.type.find()["elt"]
|
||||
shape = self.append(ir.GetAttr(value, "shape"))
|
||||
result = self._allocate_new_array(result_elt, shape)
|
||||
func = self._get_array_unaryop("Coerce",
|
||||
lambda v: ir.Coerce(v, result_elt),
|
||||
node.type, value.type)
|
||||
func = self._get_array_unaryop(
|
||||
"Coerce", lambda v: self.append(ir.Coerce(v, result_elt)),
|
||||
node.type, value.type)
|
||||
self._invoke_arrayop(func, [result, value])
|
||||
return result
|
||||
else:
|
||||
@ -2383,12 +2384,32 @@ class ARTIQIRGenerator(algorithm.Visitor):
|
||||
|
||||
if types.is_builtin(node.func.type):
|
||||
insn = self.visit_builtin_call(node)
|
||||
elif (types.is_broadcast_across_arrays(node.func.type) and len(args) >= 1
|
||||
and builtins.is_array(args[0].type)):
|
||||
# The iodelay machinery set up in the surrounding code was
|
||||
# deprecated/a relic from the past when array broadcasting support
|
||||
# was added, so no attempt to keep the delay tracking intact is
|
||||
# made.
|
||||
assert len(args) == 1, "Broadcasting for multiple arguments not implemented"
|
||||
|
||||
def make_call(val):
|
||||
return self._user_call(ir.Constant(None, callee.type), [val], {},
|
||||
node.arg_exprs)
|
||||
|
||||
shape = self.append(ir.GetAttr(args[0], "shape"))
|
||||
result = self._allocate_new_array(node.type.find()["elt"], shape)
|
||||
|
||||
# TODO: Generate more generically if non-externals are allowed.
|
||||
name = node.func.type.find().name
|
||||
func = self._get_array_unaryop(name, make_call, node.type, args[0].type)
|
||||
self._invoke_arrayop(func, [result, args[0]])
|
||||
insn = result
|
||||
else:
|
||||
insn = self._user_call(callee, args, keywords, node.arg_exprs)
|
||||
|
||||
if isinstance(node.func, asttyped.AttributeT):
|
||||
attr_node = node.func
|
||||
self.method_map[(attr_node.value.type.find(), attr_node.attr)].append(insn)
|
||||
self.method_map[(attr_node.value.type.find(),
|
||||
attr_node.attr)].append(insn)
|
||||
|
||||
if node.iodelay is not None and not iodelay.is_const(node.iodelay, 0):
|
||||
after_delay = self.add_block("delay.tail")
|
||||
|
@ -1232,6 +1232,16 @@ class Inferencer(algorithm.Visitor):
|
||||
self.engine.process(diag)
|
||||
return
|
||||
|
||||
# Array broadcasting for unary functions explicitly marked as such.
|
||||
if len(node.args) == typ_arity == 1 and types.is_broadcast_across_arrays(typ):
|
||||
arg_type = node.args[0].type.find()
|
||||
if builtins.is_array(arg_type):
|
||||
typ_arg, = typ_args.values()
|
||||
self._unify(typ_arg, arg_type["elt"], node.args[0].loc, None)
|
||||
self._unify(node.type, builtins.TArray(typ_ret, arg_type["num_dims"]),
|
||||
node.loc, None)
|
||||
return
|
||||
|
||||
for actualarg, (formalname, formaltyp) in \
|
||||
zip(node.args, list(typ_args.items()) + list(typ_optargs.items())):
|
||||
self._unify(actualarg.type, formaltyp,
|
||||
|
@ -298,11 +298,14 @@ class TExternalFunction(TFunction):
|
||||
Flag ``nounwind`` means the function never raises an exception.
|
||||
Flag ``nowrite`` means the function never writes any memory
|
||||
that the ARTIQ Python code can observe.
|
||||
:ivar broadcast_across_arrays: (bool)
|
||||
If True, the function is transparently applied element-wise when called
|
||||
with TArray arguments.
|
||||
"""
|
||||
|
||||
attributes = OrderedDict()
|
||||
|
||||
def __init__(self, args, ret, name, flags=set()):
|
||||
def __init__(self, args, ret, name, flags=set(), broadcast_across_arrays=False):
|
||||
assert isinstance(flags, set)
|
||||
for flag in flags:
|
||||
assert flag in {'nounwind', 'nowrite'}
|
||||
@ -310,6 +313,7 @@ class TExternalFunction(TFunction):
|
||||
self.name = name
|
||||
self.delay = TFixedDelay(iodelay.Const(0))
|
||||
self.flags = flags
|
||||
self.broadcast_across_arrays = broadcast_across_arrays
|
||||
|
||||
def unify(self, other):
|
||||
if other is self:
|
||||
@ -644,6 +648,15 @@ def is_builtin_function(typ, name=None):
|
||||
return isinstance(typ, TBuiltinFunction) and \
|
||||
typ.name == name
|
||||
|
||||
def is_broadcast_across_arrays(typ):
|
||||
# For now, broadcasting is only exposed to predefined external functions, and
|
||||
# statically selected. Might be extended to user-defined functions if the design
|
||||
# pans out.
|
||||
typ = typ.find()
|
||||
if not isinstance(typ, TExternalFunction):
|
||||
return False
|
||||
return typ.broadcast_across_arrays
|
||||
|
||||
def is_constructor(typ, name=None):
|
||||
typ = typ.find()
|
||||
if name is not None:
|
||||
|
Loading…
Reference in New Issue
Block a user