1
0
forked from M-Labs/artiq

compiler: Implement unary plus/minus for arrays

Implementation is needlessly generic to anticipate
coercion/transcendental functions.
This commit is contained in:
David Nadlinger 2020-07-30 00:09:12 +01:00
parent 0d8fbd4f19
commit 4426e4144f
5 changed files with 132 additions and 27 deletions

View File

@ -88,6 +88,8 @@ class ARTIQIRGenerator(algorithm.Visitor):
necessary. They are kept track of in global dictionaries, with a mangled name
containing types and operations as key:
:ivar array_unaryop_funcs: the map from mangled name to implementation of unary
operations for arrays
:ivar array_binop_funcs: the map from mangled name to implementation of binary
operations between arrays
"""
@ -118,6 +120,7 @@ class ARTIQIRGenerator(algorithm.Visitor):
self.function_map = dict()
self.variable_map = dict()
self.method_map = defaultdict(lambda: [])
self.array_unaryop_funcs = dict()
self.array_binop_funcs = dict()
def annotate_calls(self, devirtualization):
@ -1316,6 +1319,68 @@ class ARTIQIRGenerator(algorithm.Visitor):
value_tail.append(ir.Branch(tail))
return phi
def _make_array_unaryop(self, name, make_op, result_type, arg_type):
try:
result = ir.Argument(result_type, "result")
arg = ir.Argument(arg_type, "arg")
# TODO: We'd like to use a "C function" here to be able to supply
# specialised implementations in a library in the future (and e.g. avoid
# passing around the context argument), but the code generator currently
# doesn't allow emitting them.
args = [result, arg]
typ = types.TFunction(args=OrderedDict([(arg.name, arg.type)
for arg in args]),
optargs=OrderedDict(),
ret=builtins.TNone())
env_args = [ir.EnvironmentArgument(self.current_env.type, "ARG.ENV")]
# TODO: What to use for loc?
func = ir.Function(typ, name, env_args + args, loc=None)
func.is_internal = True
func.is_generated = True
self.functions.append(func)
old_func, self.current_function = self.current_function, func
entry = self.add_block("entry")
old_block, self.current_block = self.current_block, entry
old_final_branch, self.final_branch = self.final_branch, None
old_unwind, self.unwind_target = self.unwind_target, None
shape = self.append(ir.GetAttr(arg, "shape"))
result_buffer = self.append(ir.GetAttr(result, "buffer"))
arg_buffer = self.append(ir.GetAttr(arg, "buffer"))
num_total_elts = self._get_total_array_len(shape)
def body_gen(index):
a = self.append(ir.GetElem(arg_buffer, index))
self.append(
ir.SetElem(result_buffer, index, self.append(make_op(a))))
return self.append(
ir.Arith(ast.Add(loc=None), index, ir.Constant(1, self._size_type)))
self._make_loop(
ir.Constant(0, self._size_type), lambda index: self.append(
ir.Compare(ast.Lt(loc=None), index, num_total_elts)), body_gen)
self.append(ir.Return(ir.Constant(None, builtins.TNone())))
return func
finally:
self.current_function = old_func
self.current_block = old_block
self.final_branch = old_final_branch
self.unwind_target = old_unwind
def _get_array_unaryop(self, name, make_op, result_type, arg_type):
name = "_array_{}_{}".format(
name, self._mangle_arrayop_types([result_type, arg_type]))
if name not in self.array_unaryop_funcs:
self.array_binop_funcs[name] = self._make_array_unaryop(
name, make_op, result_type, arg_type)
return self.array_binop_funcs[name]
def visit_UnaryOpT(self, node):
if isinstance(node.op, ast.Not):
cond = self.coerce_to_bool(self.visit(node.operand))
@ -1327,9 +1392,18 @@ class ARTIQIRGenerator(algorithm.Visitor):
return self.append(ir.Arith(ast.BitXor(loc=None),
ir.Constant(-1, operand.type), operand))
elif isinstance(node.op, ast.USub):
def make_sub(val):
return ir.Arith(ast.Sub(loc=None),
ir.Constant(0, val.type), val)
operand = self.visit(node.operand)
return self.append(ir.Arith(ast.Sub(loc=None),
ir.Constant(0, operand.type), operand))
if builtins.is_array(operand.type):
shape = self.append(ir.GetAttr(operand, "shape"))
result = self._alloate_new_array(node.type.find()["elt"], shape)
func = self._get_array_unaryop("USub", make_sub, node.type, operand.type)
self._invoke_arrayop(func, [result, operand])
return result
else:
return self.append(make_sub(operand))
elif isinstance(node.op, ast.UAdd):
# No-op.
return self.visit(node.operand)
@ -1423,10 +1497,7 @@ class ARTIQIRGenerator(algorithm.Visitor):
self.final_branch = old_final_branch
self.unwind_target = old_unwind
def _get_array_binop(self, op, result_type, lhs_type, rhs_type):
# Currently, we always have any type coercions resolved explicitly in the AST.
# In the future, this might no longer be true and the three types might all
# differ.
def _mangle_arrayop_types(self, types):
def name_error(typ):
assert False, "Internal compiler error: No RPC tag for {}".format(typ)
@ -1438,13 +1509,30 @@ class ARTIQIRGenerator(algorithm.Visitor):
return (ir.rpc_tag(typ["elt"], name_error).decode() +
str(typ["num_dims"].find().value))
name = "_array_{}_{}_{}_{}".format(
type(op).__name__, *(map(mangle_name, (result_type, lhs_type, rhs_type))))
return "_".join(mangle_name(t) for t in types)
def _get_array_binop(self, op, result_type, lhs_type, rhs_type):
# Currently, we always have any type coercions resolved explicitly in the AST.
# In the future, this might no longer be true and the three types might all
# differ.
name = "_array_{}_{}".format(
type(op).__name__,
self._mangle_arrayop_types([result_type, lhs_type, rhs_type]))
if name not in self.array_binop_funcs:
self.array_binop_funcs[name] = self._make_array_binop(
name, op, result_type, lhs_type, rhs_type)
return self.array_binop_funcs[name]
def _invoke_arrayop(self, func, params):
closure = self.append(
ir.Closure(func, ir.Constant(None, ir.TEnvironment("arrayop", {}))))
if self.unwind_target is None:
self.append(ir.Call(closure, params, {}))
else:
after_invoke = self.add_block("arrayop.invoke")
self.append(ir.Invoke(func, params, {}, after_invoke, self.unwind_target))
self.current_block = after_invoke
def visit_BinOpT(self, node):
if builtins.is_array(node.type):
lhs = self.visit(node.left)
@ -1457,14 +1545,8 @@ class ARTIQIRGenerator(algorithm.Visitor):
result = self._alloate_new_array(node.type.find()["elt"], shape)
func = self._get_array_binop(node.op, node.type, node.left.type, node.right.type)
closure = self.append(ir.Closure(func, ir.Constant(None, ir.TEnvironment("arrayop", {}))))
params = [result, lhs, rhs]
if self.unwind_target is None:
insn = self.append(ir.Call(closure, params, {}))
else:
after_invoke = self.add_block("arrayop.invoke")
insn = self.append(ir.Invoke(func, params, {}, after_invoke, self.unwind_target))
self.current_block = after_invoke
self._invoke_arrayop(func, [result, lhs, rhs])
return result
elif builtins.is_numeric(node.type):
lhs = self.visit(node.left)

View File

@ -298,16 +298,27 @@ class Inferencer(algorithm.Visitor):
node.operand.loc)
self.engine.process(diag)
else: # UAdd, USub
if types.is_var(operand_type):
return
if builtins.is_numeric(operand_type):
self._unify(node.type, operand_type,
node.loc, None)
elif not types.is_var(operand_type):
diag = diagnostic.Diagnostic("error",
"expected unary '{op}' operand to be of numeric type, not {type}",
{"op": node.op.loc.source(),
"type": types.TypePrinter().name(operand_type)},
node.operand.loc)
self.engine.process(diag)
self._unify(node.type, operand_type, node.loc, None)
return
if builtins.is_array(operand_type):
elt = operand_type.find()["elt"]
if builtins.is_numeric(elt):
self._unify(node.type, operand_type, node.loc, None)
return
if types.is_var(elt):
return
diag = diagnostic.Diagnostic("error",
"expected unary '{op}' operand to be of numeric type, not {type}",
{"op": node.op.loc.source(),
"type": types.TypePrinter().name(operand_type)},
node.operand.loc)
self.engine.process(diag)
def visit_CoerceT(self, node):
self.generic_visit(node)
@ -436,7 +447,8 @@ class Inferencer(algorithm.Visitor):
return typ
def map_return(typ):
a = builtins.TArray(elt=typ, num_dims=left_dims)
elt = builtins.TFloat() if isinstance(op, ast.Div) else typ
a = builtins.TArray(elt=elt, num_dims=left_dims)
return (a, a, a)
return self._coerce_numeric((left, right),

View File

@ -156,7 +156,7 @@ class RegionOf(algorithm.Visitor):
visit_NameConstantT = visit_immutable
visit_NumT = visit_immutable
visit_EllipsisT = visit_immutable
visit_UnaryOpT = visit_immutable
visit_UnaryOpT = visit_sometimes_allocating # possibly array op
visit_CompareT = visit_immutable
# Value lives forever

View File

@ -0,0 +1,11 @@
# RUN: %python -m artiq.compiler.testbench.jit %s
a = array([1, 2])
b = +a
assert b[0] == 1
assert b[1] == 2
b = -a
assert b[0] == -1
assert b[1] == -2