forked from M-Labs/artiq
compiler: Implement broadcasting of math functions
This commit is contained in:
parent
be7d78253f
commit
cc00ae9580
|
@ -32,7 +32,9 @@ numpy_builtins = ["transpose"]
|
||||||
|
|
||||||
def unary_fp_type(name):
|
def unary_fp_type(name):
|
||||||
return types.TExternalFunction(OrderedDict([("arg", builtins.TFloat())]),
|
return types.TExternalFunction(OrderedDict([("arg", builtins.TFloat())]),
|
||||||
builtins.TFloat(), name)
|
builtins.TFloat(),
|
||||||
|
name,
|
||||||
|
broadcast_across_arrays=True)
|
||||||
|
|
||||||
|
|
||||||
numpy_map = {
|
numpy_map = {
|
||||||
|
|
|
@ -1335,8 +1335,8 @@ class ARTIQIRGenerator(algorithm.Visitor):
|
||||||
ret=builtins.TNone())
|
ret=builtins.TNone())
|
||||||
env_args = [ir.EnvironmentArgument(self.current_env.type, "ARG.ENV")]
|
env_args = [ir.EnvironmentArgument(self.current_env.type, "ARG.ENV")]
|
||||||
|
|
||||||
# TODO: What to use for loc?
|
old_loc, self.current_loc = self.current_loc, None
|
||||||
func = ir.Function(typ, name, env_args + args, loc=None)
|
func = ir.Function(typ, name, env_args + args)
|
||||||
func.is_internal = True
|
func.is_internal = True
|
||||||
func.is_generated = True
|
func.is_generated = True
|
||||||
self.functions.append(func)
|
self.functions.append(func)
|
||||||
|
@ -1357,7 +1357,7 @@ class ARTIQIRGenerator(algorithm.Visitor):
|
||||||
def body_gen(index):
|
def body_gen(index):
|
||||||
a = self.append(ir.GetElem(arg_buffer, index))
|
a = self.append(ir.GetElem(arg_buffer, index))
|
||||||
self.append(
|
self.append(
|
||||||
ir.SetElem(result_buffer, index, self.append(make_op(a))))
|
ir.SetElem(result_buffer, index, make_op(a)))
|
||||||
return self.append(
|
return self.append(
|
||||||
ir.Arith(ast.Add(loc=None), index, ir.Constant(1, self._size_type)))
|
ir.Arith(ast.Add(loc=None), index, ir.Constant(1, self._size_type)))
|
||||||
|
|
||||||
|
@ -1368,6 +1368,7 @@ class ARTIQIRGenerator(algorithm.Visitor):
|
||||||
self.append(ir.Return(ir.Constant(None, builtins.TNone())))
|
self.append(ir.Return(ir.Constant(None, builtins.TNone())))
|
||||||
return func
|
return func
|
||||||
finally:
|
finally:
|
||||||
|
self.current_loc = old_loc
|
||||||
self.current_function = old_func
|
self.current_function = old_func
|
||||||
self.current_block = old_block
|
self.current_block = old_block
|
||||||
self.final_branch = old_final_branch
|
self.final_branch = old_final_branch
|
||||||
|
@ -1393,8 +1394,8 @@ class ARTIQIRGenerator(algorithm.Visitor):
|
||||||
ir.Constant(-1, operand.type), operand))
|
ir.Constant(-1, operand.type), operand))
|
||||||
elif isinstance(node.op, ast.USub):
|
elif isinstance(node.op, ast.USub):
|
||||||
def make_sub(val):
|
def make_sub(val):
|
||||||
return ir.Arith(ast.Sub(loc=None),
|
return self.append(ir.Arith(ast.Sub(loc=None),
|
||||||
ir.Constant(0, val.type), val)
|
ir.Constant(0, val.type), val))
|
||||||
operand = self.visit(node.operand)
|
operand = self.visit(node.operand)
|
||||||
if builtins.is_array(operand.type):
|
if builtins.is_array(operand.type):
|
||||||
shape = self.append(ir.GetAttr(operand, "shape"))
|
shape = self.append(ir.GetAttr(operand, "shape"))
|
||||||
|
@ -1403,7 +1404,7 @@ class ARTIQIRGenerator(algorithm.Visitor):
|
||||||
self._invoke_arrayop(func, [result, operand])
|
self._invoke_arrayop(func, [result, operand])
|
||||||
return result
|
return result
|
||||||
else:
|
else:
|
||||||
return self.append(make_sub(operand))
|
return make_sub(operand)
|
||||||
elif isinstance(node.op, ast.UAdd):
|
elif isinstance(node.op, ast.UAdd):
|
||||||
# No-op.
|
# No-op.
|
||||||
return self.visit(node.operand)
|
return self.visit(node.operand)
|
||||||
|
@ -1419,8 +1420,8 @@ class ARTIQIRGenerator(algorithm.Visitor):
|
||||||
result_elt = node.type.find()["elt"]
|
result_elt = node.type.find()["elt"]
|
||||||
shape = self.append(ir.GetAttr(value, "shape"))
|
shape = self.append(ir.GetAttr(value, "shape"))
|
||||||
result = self._allocate_new_array(result_elt, shape)
|
result = self._allocate_new_array(result_elt, shape)
|
||||||
func = self._get_array_unaryop("Coerce",
|
func = self._get_array_unaryop(
|
||||||
lambda v: ir.Coerce(v, result_elt),
|
"Coerce", lambda v: self.append(ir.Coerce(v, result_elt)),
|
||||||
node.type, value.type)
|
node.type, value.type)
|
||||||
self._invoke_arrayop(func, [result, value])
|
self._invoke_arrayop(func, [result, value])
|
||||||
return result
|
return result
|
||||||
|
@ -2383,12 +2384,32 @@ class ARTIQIRGenerator(algorithm.Visitor):
|
||||||
|
|
||||||
if types.is_builtin(node.func.type):
|
if types.is_builtin(node.func.type):
|
||||||
insn = self.visit_builtin_call(node)
|
insn = self.visit_builtin_call(node)
|
||||||
|
elif (types.is_broadcast_across_arrays(node.func.type) and len(args) >= 1
|
||||||
|
and builtins.is_array(args[0].type)):
|
||||||
|
# The iodelay machinery set up in the surrounding code was
|
||||||
|
# deprecated/a relic from the past when array broadcasting support
|
||||||
|
# was added, so no attempt to keep the delay tracking intact is
|
||||||
|
# made.
|
||||||
|
assert len(args) == 1, "Broadcasting for multiple arguments not implemented"
|
||||||
|
|
||||||
|
def make_call(val):
|
||||||
|
return self._user_call(ir.Constant(None, callee.type), [val], {},
|
||||||
|
node.arg_exprs)
|
||||||
|
|
||||||
|
shape = self.append(ir.GetAttr(args[0], "shape"))
|
||||||
|
result = self._allocate_new_array(node.type.find()["elt"], shape)
|
||||||
|
|
||||||
|
# TODO: Generate more generically if non-externals are allowed.
|
||||||
|
name = node.func.type.find().name
|
||||||
|
func = self._get_array_unaryop(name, make_call, node.type, args[0].type)
|
||||||
|
self._invoke_arrayop(func, [result, args[0]])
|
||||||
|
insn = result
|
||||||
else:
|
else:
|
||||||
insn = self._user_call(callee, args, keywords, node.arg_exprs)
|
insn = self._user_call(callee, args, keywords, node.arg_exprs)
|
||||||
|
|
||||||
if isinstance(node.func, asttyped.AttributeT):
|
if isinstance(node.func, asttyped.AttributeT):
|
||||||
attr_node = node.func
|
attr_node = node.func
|
||||||
self.method_map[(attr_node.value.type.find(), attr_node.attr)].append(insn)
|
self.method_map[(attr_node.value.type.find(),
|
||||||
|
attr_node.attr)].append(insn)
|
||||||
|
|
||||||
if node.iodelay is not None and not iodelay.is_const(node.iodelay, 0):
|
if node.iodelay is not None and not iodelay.is_const(node.iodelay, 0):
|
||||||
after_delay = self.add_block("delay.tail")
|
after_delay = self.add_block("delay.tail")
|
||||||
|
|
|
@ -1232,6 +1232,16 @@ class Inferencer(algorithm.Visitor):
|
||||||
self.engine.process(diag)
|
self.engine.process(diag)
|
||||||
return
|
return
|
||||||
|
|
||||||
|
# Array broadcasting for unary functions explicitly marked as such.
|
||||||
|
if len(node.args) == typ_arity == 1 and types.is_broadcast_across_arrays(typ):
|
||||||
|
arg_type = node.args[0].type.find()
|
||||||
|
if builtins.is_array(arg_type):
|
||||||
|
typ_arg, = typ_args.values()
|
||||||
|
self._unify(typ_arg, arg_type["elt"], node.args[0].loc, None)
|
||||||
|
self._unify(node.type, builtins.TArray(typ_ret, arg_type["num_dims"]),
|
||||||
|
node.loc, None)
|
||||||
|
return
|
||||||
|
|
||||||
for actualarg, (formalname, formaltyp) in \
|
for actualarg, (formalname, formaltyp) in \
|
||||||
zip(node.args, list(typ_args.items()) + list(typ_optargs.items())):
|
zip(node.args, list(typ_args.items()) + list(typ_optargs.items())):
|
||||||
self._unify(actualarg.type, formaltyp,
|
self._unify(actualarg.type, formaltyp,
|
||||||
|
|
|
@ -298,11 +298,14 @@ class TExternalFunction(TFunction):
|
||||||
Flag ``nounwind`` means the function never raises an exception.
|
Flag ``nounwind`` means the function never raises an exception.
|
||||||
Flag ``nowrite`` means the function never writes any memory
|
Flag ``nowrite`` means the function never writes any memory
|
||||||
that the ARTIQ Python code can observe.
|
that the ARTIQ Python code can observe.
|
||||||
|
:ivar broadcast_across_arrays: (bool)
|
||||||
|
If True, the function is transparently applied element-wise when called
|
||||||
|
with TArray arguments.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
attributes = OrderedDict()
|
attributes = OrderedDict()
|
||||||
|
|
||||||
def __init__(self, args, ret, name, flags=set()):
|
def __init__(self, args, ret, name, flags=set(), broadcast_across_arrays=False):
|
||||||
assert isinstance(flags, set)
|
assert isinstance(flags, set)
|
||||||
for flag in flags:
|
for flag in flags:
|
||||||
assert flag in {'nounwind', 'nowrite'}
|
assert flag in {'nounwind', 'nowrite'}
|
||||||
|
@ -310,6 +313,7 @@ class TExternalFunction(TFunction):
|
||||||
self.name = name
|
self.name = name
|
||||||
self.delay = TFixedDelay(iodelay.Const(0))
|
self.delay = TFixedDelay(iodelay.Const(0))
|
||||||
self.flags = flags
|
self.flags = flags
|
||||||
|
self.broadcast_across_arrays = broadcast_across_arrays
|
||||||
|
|
||||||
def unify(self, other):
|
def unify(self, other):
|
||||||
if other is self:
|
if other is self:
|
||||||
|
@ -644,6 +648,15 @@ def is_builtin_function(typ, name=None):
|
||||||
return isinstance(typ, TBuiltinFunction) and \
|
return isinstance(typ, TBuiltinFunction) and \
|
||||||
typ.name == name
|
typ.name == name
|
||||||
|
|
||||||
|
def is_broadcast_across_arrays(typ):
|
||||||
|
# For now, broadcasting is only exposed to predefined external functions, and
|
||||||
|
# statically selected. Might be extended to user-defined functions if the design
|
||||||
|
# pans out.
|
||||||
|
typ = typ.find()
|
||||||
|
if not isinstance(typ, TExternalFunction):
|
||||||
|
return False
|
||||||
|
return typ.broadcast_across_arrays
|
||||||
|
|
||||||
def is_constructor(typ, name=None):
|
def is_constructor(typ, name=None):
|
||||||
typ = typ.find()
|
typ = typ.find()
|
||||||
if name is not None:
|
if name is not None:
|
||||||
|
|
Loading…
Reference in New Issue