forked from M-Labs/artiq
compiler: Implement unary plus/minus for arrays
Implementation is needlessly generic to anticipate coercion/transcendental functions.
This commit is contained in:
parent
0d8fbd4f19
commit
4426e4144f
|
@ -88,6 +88,8 @@ class ARTIQIRGenerator(algorithm.Visitor):
|
||||||
necessary. They are kept track of in global dictionaries, with a mangled name
|
necessary. They are kept track of in global dictionaries, with a mangled name
|
||||||
containing types and operations as key:
|
containing types and operations as key:
|
||||||
|
|
||||||
|
:ivar array_unaryop_funcs: the map from mangled name to implementation of unary
|
||||||
|
operations for arrays
|
||||||
:ivar array_binop_funcs: the map from mangled name to implementation of binary
|
:ivar array_binop_funcs: the map from mangled name to implementation of binary
|
||||||
operations between arrays
|
operations between arrays
|
||||||
"""
|
"""
|
||||||
|
@ -118,6 +120,7 @@ class ARTIQIRGenerator(algorithm.Visitor):
|
||||||
self.function_map = dict()
|
self.function_map = dict()
|
||||||
self.variable_map = dict()
|
self.variable_map = dict()
|
||||||
self.method_map = defaultdict(lambda: [])
|
self.method_map = defaultdict(lambda: [])
|
||||||
|
self.array_unaryop_funcs = dict()
|
||||||
self.array_binop_funcs = dict()
|
self.array_binop_funcs = dict()
|
||||||
|
|
||||||
def annotate_calls(self, devirtualization):
|
def annotate_calls(self, devirtualization):
|
||||||
|
@ -1316,6 +1319,68 @@ class ARTIQIRGenerator(algorithm.Visitor):
|
||||||
value_tail.append(ir.Branch(tail))
|
value_tail.append(ir.Branch(tail))
|
||||||
return phi
|
return phi
|
||||||
|
|
||||||
|
def _make_array_unaryop(self, name, make_op, result_type, arg_type):
|
||||||
|
try:
|
||||||
|
result = ir.Argument(result_type, "result")
|
||||||
|
arg = ir.Argument(arg_type, "arg")
|
||||||
|
|
||||||
|
# TODO: We'd like to use a "C function" here to be able to supply
|
||||||
|
# specialised implementations in a library in the future (and e.g. avoid
|
||||||
|
# passing around the context argument), but the code generator currently
|
||||||
|
# doesn't allow emitting them.
|
||||||
|
args = [result, arg]
|
||||||
|
typ = types.TFunction(args=OrderedDict([(arg.name, arg.type)
|
||||||
|
for arg in args]),
|
||||||
|
optargs=OrderedDict(),
|
||||||
|
ret=builtins.TNone())
|
||||||
|
env_args = [ir.EnvironmentArgument(self.current_env.type, "ARG.ENV")]
|
||||||
|
|
||||||
|
# TODO: What to use for loc?
|
||||||
|
func = ir.Function(typ, name, env_args + args, loc=None)
|
||||||
|
func.is_internal = True
|
||||||
|
func.is_generated = True
|
||||||
|
self.functions.append(func)
|
||||||
|
old_func, self.current_function = self.current_function, func
|
||||||
|
|
||||||
|
entry = self.add_block("entry")
|
||||||
|
old_block, self.current_block = self.current_block, entry
|
||||||
|
|
||||||
|
old_final_branch, self.final_branch = self.final_branch, None
|
||||||
|
old_unwind, self.unwind_target = self.unwind_target, None
|
||||||
|
|
||||||
|
shape = self.append(ir.GetAttr(arg, "shape"))
|
||||||
|
|
||||||
|
result_buffer = self.append(ir.GetAttr(result, "buffer"))
|
||||||
|
arg_buffer = self.append(ir.GetAttr(arg, "buffer"))
|
||||||
|
num_total_elts = self._get_total_array_len(shape)
|
||||||
|
|
||||||
|
def body_gen(index):
|
||||||
|
a = self.append(ir.GetElem(arg_buffer, index))
|
||||||
|
self.append(
|
||||||
|
ir.SetElem(result_buffer, index, self.append(make_op(a))))
|
||||||
|
return self.append(
|
||||||
|
ir.Arith(ast.Add(loc=None), index, ir.Constant(1, self._size_type)))
|
||||||
|
|
||||||
|
self._make_loop(
|
||||||
|
ir.Constant(0, self._size_type), lambda index: self.append(
|
||||||
|
ir.Compare(ast.Lt(loc=None), index, num_total_elts)), body_gen)
|
||||||
|
|
||||||
|
self.append(ir.Return(ir.Constant(None, builtins.TNone())))
|
||||||
|
return func
|
||||||
|
finally:
|
||||||
|
self.current_function = old_func
|
||||||
|
self.current_block = old_block
|
||||||
|
self.final_branch = old_final_branch
|
||||||
|
self.unwind_target = old_unwind
|
||||||
|
|
||||||
|
def _get_array_unaryop(self, name, make_op, result_type, arg_type):
|
||||||
|
name = "_array_{}_{}".format(
|
||||||
|
name, self._mangle_arrayop_types([result_type, arg_type]))
|
||||||
|
if name not in self.array_unaryop_funcs:
|
||||||
|
self.array_binop_funcs[name] = self._make_array_unaryop(
|
||||||
|
name, make_op, result_type, arg_type)
|
||||||
|
return self.array_binop_funcs[name]
|
||||||
|
|
||||||
def visit_UnaryOpT(self, node):
|
def visit_UnaryOpT(self, node):
|
||||||
if isinstance(node.op, ast.Not):
|
if isinstance(node.op, ast.Not):
|
||||||
cond = self.coerce_to_bool(self.visit(node.operand))
|
cond = self.coerce_to_bool(self.visit(node.operand))
|
||||||
|
@ -1327,9 +1392,18 @@ class ARTIQIRGenerator(algorithm.Visitor):
|
||||||
return self.append(ir.Arith(ast.BitXor(loc=None),
|
return self.append(ir.Arith(ast.BitXor(loc=None),
|
||||||
ir.Constant(-1, operand.type), operand))
|
ir.Constant(-1, operand.type), operand))
|
||||||
elif isinstance(node.op, ast.USub):
|
elif isinstance(node.op, ast.USub):
|
||||||
|
def make_sub(val):
|
||||||
|
return ir.Arith(ast.Sub(loc=None),
|
||||||
|
ir.Constant(0, val.type), val)
|
||||||
operand = self.visit(node.operand)
|
operand = self.visit(node.operand)
|
||||||
return self.append(ir.Arith(ast.Sub(loc=None),
|
if builtins.is_array(operand.type):
|
||||||
ir.Constant(0, operand.type), operand))
|
shape = self.append(ir.GetAttr(operand, "shape"))
|
||||||
|
result = self._alloate_new_array(node.type.find()["elt"], shape)
|
||||||
|
func = self._get_array_unaryop("USub", make_sub, node.type, operand.type)
|
||||||
|
self._invoke_arrayop(func, [result, operand])
|
||||||
|
return result
|
||||||
|
else:
|
||||||
|
return self.append(make_sub(operand))
|
||||||
elif isinstance(node.op, ast.UAdd):
|
elif isinstance(node.op, ast.UAdd):
|
||||||
# No-op.
|
# No-op.
|
||||||
return self.visit(node.operand)
|
return self.visit(node.operand)
|
||||||
|
@ -1423,10 +1497,7 @@ class ARTIQIRGenerator(algorithm.Visitor):
|
||||||
self.final_branch = old_final_branch
|
self.final_branch = old_final_branch
|
||||||
self.unwind_target = old_unwind
|
self.unwind_target = old_unwind
|
||||||
|
|
||||||
def _get_array_binop(self, op, result_type, lhs_type, rhs_type):
|
def _mangle_arrayop_types(self, types):
|
||||||
# Currently, we always have any type coercions resolved explicitly in the AST.
|
|
||||||
# In the future, this might no longer be true and the three types might all
|
|
||||||
# differ.
|
|
||||||
def name_error(typ):
|
def name_error(typ):
|
||||||
assert False, "Internal compiler error: No RPC tag for {}".format(typ)
|
assert False, "Internal compiler error: No RPC tag for {}".format(typ)
|
||||||
|
|
||||||
|
@ -1438,13 +1509,30 @@ class ARTIQIRGenerator(algorithm.Visitor):
|
||||||
return (ir.rpc_tag(typ["elt"], name_error).decode() +
|
return (ir.rpc_tag(typ["elt"], name_error).decode() +
|
||||||
str(typ["num_dims"].find().value))
|
str(typ["num_dims"].find().value))
|
||||||
|
|
||||||
name = "_array_{}_{}_{}_{}".format(
|
return "_".join(mangle_name(t) for t in types)
|
||||||
type(op).__name__, *(map(mangle_name, (result_type, lhs_type, rhs_type))))
|
|
||||||
|
def _get_array_binop(self, op, result_type, lhs_type, rhs_type):
|
||||||
|
# Currently, we always have any type coercions resolved explicitly in the AST.
|
||||||
|
# In the future, this might no longer be true and the three types might all
|
||||||
|
# differ.
|
||||||
|
name = "_array_{}_{}".format(
|
||||||
|
type(op).__name__,
|
||||||
|
self._mangle_arrayop_types([result_type, lhs_type, rhs_type]))
|
||||||
if name not in self.array_binop_funcs:
|
if name not in self.array_binop_funcs:
|
||||||
self.array_binop_funcs[name] = self._make_array_binop(
|
self.array_binop_funcs[name] = self._make_array_binop(
|
||||||
name, op, result_type, lhs_type, rhs_type)
|
name, op, result_type, lhs_type, rhs_type)
|
||||||
return self.array_binop_funcs[name]
|
return self.array_binop_funcs[name]
|
||||||
|
|
||||||
|
def _invoke_arrayop(self, func, params):
|
||||||
|
closure = self.append(
|
||||||
|
ir.Closure(func, ir.Constant(None, ir.TEnvironment("arrayop", {}))))
|
||||||
|
if self.unwind_target is None:
|
||||||
|
self.append(ir.Call(closure, params, {}))
|
||||||
|
else:
|
||||||
|
after_invoke = self.add_block("arrayop.invoke")
|
||||||
|
self.append(ir.Invoke(func, params, {}, after_invoke, self.unwind_target))
|
||||||
|
self.current_block = after_invoke
|
||||||
|
|
||||||
def visit_BinOpT(self, node):
|
def visit_BinOpT(self, node):
|
||||||
if builtins.is_array(node.type):
|
if builtins.is_array(node.type):
|
||||||
lhs = self.visit(node.left)
|
lhs = self.visit(node.left)
|
||||||
|
@ -1457,14 +1545,8 @@ class ARTIQIRGenerator(algorithm.Visitor):
|
||||||
result = self._alloate_new_array(node.type.find()["elt"], shape)
|
result = self._alloate_new_array(node.type.find()["elt"], shape)
|
||||||
|
|
||||||
func = self._get_array_binop(node.op, node.type, node.left.type, node.right.type)
|
func = self._get_array_binop(node.op, node.type, node.left.type, node.right.type)
|
||||||
closure = self.append(ir.Closure(func, ir.Constant(None, ir.TEnvironment("arrayop", {}))))
|
self._invoke_arrayop(func, [result, lhs, rhs])
|
||||||
params = [result, lhs, rhs]
|
|
||||||
if self.unwind_target is None:
|
|
||||||
insn = self.append(ir.Call(closure, params, {}))
|
|
||||||
else:
|
|
||||||
after_invoke = self.add_block("arrayop.invoke")
|
|
||||||
insn = self.append(ir.Invoke(func, params, {}, after_invoke, self.unwind_target))
|
|
||||||
self.current_block = after_invoke
|
|
||||||
return result
|
return result
|
||||||
elif builtins.is_numeric(node.type):
|
elif builtins.is_numeric(node.type):
|
||||||
lhs = self.visit(node.left)
|
lhs = self.visit(node.left)
|
||||||
|
|
|
@ -298,16 +298,27 @@ class Inferencer(algorithm.Visitor):
|
||||||
node.operand.loc)
|
node.operand.loc)
|
||||||
self.engine.process(diag)
|
self.engine.process(diag)
|
||||||
else: # UAdd, USub
|
else: # UAdd, USub
|
||||||
|
if types.is_var(operand_type):
|
||||||
|
return
|
||||||
|
|
||||||
if builtins.is_numeric(operand_type):
|
if builtins.is_numeric(operand_type):
|
||||||
self._unify(node.type, operand_type,
|
self._unify(node.type, operand_type, node.loc, None)
|
||||||
node.loc, None)
|
return
|
||||||
elif not types.is_var(operand_type):
|
|
||||||
diag = diagnostic.Diagnostic("error",
|
if builtins.is_array(operand_type):
|
||||||
"expected unary '{op}' operand to be of numeric type, not {type}",
|
elt = operand_type.find()["elt"]
|
||||||
{"op": node.op.loc.source(),
|
if builtins.is_numeric(elt):
|
||||||
"type": types.TypePrinter().name(operand_type)},
|
self._unify(node.type, operand_type, node.loc, None)
|
||||||
node.operand.loc)
|
return
|
||||||
self.engine.process(diag)
|
if types.is_var(elt):
|
||||||
|
return
|
||||||
|
|
||||||
|
diag = diagnostic.Diagnostic("error",
|
||||||
|
"expected unary '{op}' operand to be of numeric type, not {type}",
|
||||||
|
{"op": node.op.loc.source(),
|
||||||
|
"type": types.TypePrinter().name(operand_type)},
|
||||||
|
node.operand.loc)
|
||||||
|
self.engine.process(diag)
|
||||||
|
|
||||||
def visit_CoerceT(self, node):
|
def visit_CoerceT(self, node):
|
||||||
self.generic_visit(node)
|
self.generic_visit(node)
|
||||||
|
@ -436,7 +447,8 @@ class Inferencer(algorithm.Visitor):
|
||||||
return typ
|
return typ
|
||||||
|
|
||||||
def map_return(typ):
|
def map_return(typ):
|
||||||
a = builtins.TArray(elt=typ, num_dims=left_dims)
|
elt = builtins.TFloat() if isinstance(op, ast.Div) else typ
|
||||||
|
a = builtins.TArray(elt=elt, num_dims=left_dims)
|
||||||
return (a, a, a)
|
return (a, a, a)
|
||||||
|
|
||||||
return self._coerce_numeric((left, right),
|
return self._coerce_numeric((left, right),
|
||||||
|
|
|
@ -156,7 +156,7 @@ class RegionOf(algorithm.Visitor):
|
||||||
visit_NameConstantT = visit_immutable
|
visit_NameConstantT = visit_immutable
|
||||||
visit_NumT = visit_immutable
|
visit_NumT = visit_immutable
|
||||||
visit_EllipsisT = visit_immutable
|
visit_EllipsisT = visit_immutable
|
||||||
visit_UnaryOpT = visit_immutable
|
visit_UnaryOpT = visit_sometimes_allocating # possibly array op
|
||||||
visit_CompareT = visit_immutable
|
visit_CompareT = visit_immutable
|
||||||
|
|
||||||
# Value lives forever
|
# Value lives forever
|
||||||
|
|
|
@ -0,0 +1,11 @@
|
||||||
|
# RUN: %python -m artiq.compiler.testbench.jit %s
|
||||||
|
|
||||||
|
a = array([1, 2])
|
||||||
|
|
||||||
|
b = +a
|
||||||
|
assert b[0] == 1
|
||||||
|
assert b[1] == 2
|
||||||
|
|
||||||
|
b = -a
|
||||||
|
assert b[0] == -1
|
||||||
|
assert b[1] == -2
|
Loading…
Reference in New Issue