mirror of
https://github.com/m-labs/artiq.git
synced 2025-01-25 09:58:13 +08:00
compiler: Implement unary plus/minus for arrays
Implementation is needlessly generic to anticipate coercion/transcendental functions.
This commit is contained in:
parent
0d8fbd4f19
commit
4426e4144f
@ -88,6 +88,8 @@ class ARTIQIRGenerator(algorithm.Visitor):
|
||||
necessary. They are kept track of in global dictionaries, with a mangled name
|
||||
containing types and operations as key:
|
||||
|
||||
:ivar array_unaryop_funcs: the map from mangled name to implementation of unary
|
||||
operations for arrays
|
||||
:ivar array_binop_funcs: the map from mangled name to implementation of binary
|
||||
operations between arrays
|
||||
"""
|
||||
@ -118,6 +120,7 @@ class ARTIQIRGenerator(algorithm.Visitor):
|
||||
self.function_map = dict()
|
||||
self.variable_map = dict()
|
||||
self.method_map = defaultdict(lambda: [])
|
||||
self.array_unaryop_funcs = dict()
|
||||
self.array_binop_funcs = dict()
|
||||
|
||||
def annotate_calls(self, devirtualization):
|
||||
@ -1316,6 +1319,68 @@ class ARTIQIRGenerator(algorithm.Visitor):
|
||||
value_tail.append(ir.Branch(tail))
|
||||
return phi
|
||||
|
||||
def _make_array_unaryop(self, name, make_op, result_type, arg_type):
|
||||
try:
|
||||
result = ir.Argument(result_type, "result")
|
||||
arg = ir.Argument(arg_type, "arg")
|
||||
|
||||
# TODO: We'd like to use a "C function" here to be able to supply
|
||||
# specialised implementations in a library in the future (and e.g. avoid
|
||||
# passing around the context argument), but the code generator currently
|
||||
# doesn't allow emitting them.
|
||||
args = [result, arg]
|
||||
typ = types.TFunction(args=OrderedDict([(arg.name, arg.type)
|
||||
for arg in args]),
|
||||
optargs=OrderedDict(),
|
||||
ret=builtins.TNone())
|
||||
env_args = [ir.EnvironmentArgument(self.current_env.type, "ARG.ENV")]
|
||||
|
||||
# TODO: What to use for loc?
|
||||
func = ir.Function(typ, name, env_args + args, loc=None)
|
||||
func.is_internal = True
|
||||
func.is_generated = True
|
||||
self.functions.append(func)
|
||||
old_func, self.current_function = self.current_function, func
|
||||
|
||||
entry = self.add_block("entry")
|
||||
old_block, self.current_block = self.current_block, entry
|
||||
|
||||
old_final_branch, self.final_branch = self.final_branch, None
|
||||
old_unwind, self.unwind_target = self.unwind_target, None
|
||||
|
||||
shape = self.append(ir.GetAttr(arg, "shape"))
|
||||
|
||||
result_buffer = self.append(ir.GetAttr(result, "buffer"))
|
||||
arg_buffer = self.append(ir.GetAttr(arg, "buffer"))
|
||||
num_total_elts = self._get_total_array_len(shape)
|
||||
|
||||
def body_gen(index):
|
||||
a = self.append(ir.GetElem(arg_buffer, index))
|
||||
self.append(
|
||||
ir.SetElem(result_buffer, index, self.append(make_op(a))))
|
||||
return self.append(
|
||||
ir.Arith(ast.Add(loc=None), index, ir.Constant(1, self._size_type)))
|
||||
|
||||
self._make_loop(
|
||||
ir.Constant(0, self._size_type), lambda index: self.append(
|
||||
ir.Compare(ast.Lt(loc=None), index, num_total_elts)), body_gen)
|
||||
|
||||
self.append(ir.Return(ir.Constant(None, builtins.TNone())))
|
||||
return func
|
||||
finally:
|
||||
self.current_function = old_func
|
||||
self.current_block = old_block
|
||||
self.final_branch = old_final_branch
|
||||
self.unwind_target = old_unwind
|
||||
|
||||
def _get_array_unaryop(self, name, make_op, result_type, arg_type):
|
||||
name = "_array_{}_{}".format(
|
||||
name, self._mangle_arrayop_types([result_type, arg_type]))
|
||||
if name not in self.array_unaryop_funcs:
|
||||
self.array_binop_funcs[name] = self._make_array_unaryop(
|
||||
name, make_op, result_type, arg_type)
|
||||
return self.array_binop_funcs[name]
|
||||
|
||||
def visit_UnaryOpT(self, node):
|
||||
if isinstance(node.op, ast.Not):
|
||||
cond = self.coerce_to_bool(self.visit(node.operand))
|
||||
@ -1327,9 +1392,18 @@ class ARTIQIRGenerator(algorithm.Visitor):
|
||||
return self.append(ir.Arith(ast.BitXor(loc=None),
|
||||
ir.Constant(-1, operand.type), operand))
|
||||
elif isinstance(node.op, ast.USub):
|
||||
def make_sub(val):
|
||||
return ir.Arith(ast.Sub(loc=None),
|
||||
ir.Constant(0, val.type), val)
|
||||
operand = self.visit(node.operand)
|
||||
return self.append(ir.Arith(ast.Sub(loc=None),
|
||||
ir.Constant(0, operand.type), operand))
|
||||
if builtins.is_array(operand.type):
|
||||
shape = self.append(ir.GetAttr(operand, "shape"))
|
||||
result = self._alloate_new_array(node.type.find()["elt"], shape)
|
||||
func = self._get_array_unaryop("USub", make_sub, node.type, operand.type)
|
||||
self._invoke_arrayop(func, [result, operand])
|
||||
return result
|
||||
else:
|
||||
return self.append(make_sub(operand))
|
||||
elif isinstance(node.op, ast.UAdd):
|
||||
# No-op.
|
||||
return self.visit(node.operand)
|
||||
@ -1423,10 +1497,7 @@ class ARTIQIRGenerator(algorithm.Visitor):
|
||||
self.final_branch = old_final_branch
|
||||
self.unwind_target = old_unwind
|
||||
|
||||
def _get_array_binop(self, op, result_type, lhs_type, rhs_type):
|
||||
# Currently, we always have any type coercions resolved explicitly in the AST.
|
||||
# In the future, this might no longer be true and the three types might all
|
||||
# differ.
|
||||
def _mangle_arrayop_types(self, types):
|
||||
def name_error(typ):
|
||||
assert False, "Internal compiler error: No RPC tag for {}".format(typ)
|
||||
|
||||
@ -1438,13 +1509,30 @@ class ARTIQIRGenerator(algorithm.Visitor):
|
||||
return (ir.rpc_tag(typ["elt"], name_error).decode() +
|
||||
str(typ["num_dims"].find().value))
|
||||
|
||||
name = "_array_{}_{}_{}_{}".format(
|
||||
type(op).__name__, *(map(mangle_name, (result_type, lhs_type, rhs_type))))
|
||||
return "_".join(mangle_name(t) for t in types)
|
||||
|
||||
def _get_array_binop(self, op, result_type, lhs_type, rhs_type):
|
||||
# Currently, we always have any type coercions resolved explicitly in the AST.
|
||||
# In the future, this might no longer be true and the three types might all
|
||||
# differ.
|
||||
name = "_array_{}_{}".format(
|
||||
type(op).__name__,
|
||||
self._mangle_arrayop_types([result_type, lhs_type, rhs_type]))
|
||||
if name not in self.array_binop_funcs:
|
||||
self.array_binop_funcs[name] = self._make_array_binop(
|
||||
name, op, result_type, lhs_type, rhs_type)
|
||||
return self.array_binop_funcs[name]
|
||||
|
||||
def _invoke_arrayop(self, func, params):
|
||||
closure = self.append(
|
||||
ir.Closure(func, ir.Constant(None, ir.TEnvironment("arrayop", {}))))
|
||||
if self.unwind_target is None:
|
||||
self.append(ir.Call(closure, params, {}))
|
||||
else:
|
||||
after_invoke = self.add_block("arrayop.invoke")
|
||||
self.append(ir.Invoke(func, params, {}, after_invoke, self.unwind_target))
|
||||
self.current_block = after_invoke
|
||||
|
||||
def visit_BinOpT(self, node):
|
||||
if builtins.is_array(node.type):
|
||||
lhs = self.visit(node.left)
|
||||
@ -1457,14 +1545,8 @@ class ARTIQIRGenerator(algorithm.Visitor):
|
||||
result = self._alloate_new_array(node.type.find()["elt"], shape)
|
||||
|
||||
func = self._get_array_binop(node.op, node.type, node.left.type, node.right.type)
|
||||
closure = self.append(ir.Closure(func, ir.Constant(None, ir.TEnvironment("arrayop", {}))))
|
||||
params = [result, lhs, rhs]
|
||||
if self.unwind_target is None:
|
||||
insn = self.append(ir.Call(closure, params, {}))
|
||||
else:
|
||||
after_invoke = self.add_block("arrayop.invoke")
|
||||
insn = self.append(ir.Invoke(func, params, {}, after_invoke, self.unwind_target))
|
||||
self.current_block = after_invoke
|
||||
self._invoke_arrayop(func, [result, lhs, rhs])
|
||||
|
||||
return result
|
||||
elif builtins.is_numeric(node.type):
|
||||
lhs = self.visit(node.left)
|
||||
|
@ -298,16 +298,27 @@ class Inferencer(algorithm.Visitor):
|
||||
node.operand.loc)
|
||||
self.engine.process(diag)
|
||||
else: # UAdd, USub
|
||||
if types.is_var(operand_type):
|
||||
return
|
||||
|
||||
if builtins.is_numeric(operand_type):
|
||||
self._unify(node.type, operand_type,
|
||||
node.loc, None)
|
||||
elif not types.is_var(operand_type):
|
||||
diag = diagnostic.Diagnostic("error",
|
||||
"expected unary '{op}' operand to be of numeric type, not {type}",
|
||||
{"op": node.op.loc.source(),
|
||||
"type": types.TypePrinter().name(operand_type)},
|
||||
node.operand.loc)
|
||||
self.engine.process(diag)
|
||||
self._unify(node.type, operand_type, node.loc, None)
|
||||
return
|
||||
|
||||
if builtins.is_array(operand_type):
|
||||
elt = operand_type.find()["elt"]
|
||||
if builtins.is_numeric(elt):
|
||||
self._unify(node.type, operand_type, node.loc, None)
|
||||
return
|
||||
if types.is_var(elt):
|
||||
return
|
||||
|
||||
diag = diagnostic.Diagnostic("error",
|
||||
"expected unary '{op}' operand to be of numeric type, not {type}",
|
||||
{"op": node.op.loc.source(),
|
||||
"type": types.TypePrinter().name(operand_type)},
|
||||
node.operand.loc)
|
||||
self.engine.process(diag)
|
||||
|
||||
def visit_CoerceT(self, node):
|
||||
self.generic_visit(node)
|
||||
@ -436,7 +447,8 @@ class Inferencer(algorithm.Visitor):
|
||||
return typ
|
||||
|
||||
def map_return(typ):
|
||||
a = builtins.TArray(elt=typ, num_dims=left_dims)
|
||||
elt = builtins.TFloat() if isinstance(op, ast.Div) else typ
|
||||
a = builtins.TArray(elt=elt, num_dims=left_dims)
|
||||
return (a, a, a)
|
||||
|
||||
return self._coerce_numeric((left, right),
|
||||
|
@ -156,7 +156,7 @@ class RegionOf(algorithm.Visitor):
|
||||
visit_NameConstantT = visit_immutable
|
||||
visit_NumT = visit_immutable
|
||||
visit_EllipsisT = visit_immutable
|
||||
visit_UnaryOpT = visit_immutable
|
||||
visit_UnaryOpT = visit_sometimes_allocating # possibly array op
|
||||
visit_CompareT = visit_immutable
|
||||
|
||||
# Value lives forever
|
||||
|
11
artiq/test/lit/integration/array_unaryops.py
Normal file
11
artiq/test/lit/integration/array_unaryops.py
Normal file
@ -0,0 +1,11 @@
|
||||
# RUN: %python -m artiq.compiler.testbench.jit %s
|
||||
|
||||
a = array([1, 2])
|
||||
|
||||
b = +a
|
||||
assert b[0] == 1
|
||||
assert b[1] == 2
|
||||
|
||||
b = -a
|
||||
assert b[0] == -1
|
||||
assert b[1] == -2
|
Loading…
Reference in New Issue
Block a user