forked from M-Labs/artiq
compiler: Implement basic element-wise array operations
This commit is contained in:
parent
9af6e5747d
commit
48fb80017f
@ -83,6 +83,13 @@ class ARTIQIRGenerator(algorithm.Visitor):
|
||||
:ivar method_map: (map of :class:`ast.AttributeT` to :class:`ir.GetAttribute`)
|
||||
the map from method resolution nodes to instructions retrieving
|
||||
the called function inside a translated :class:`ast.CallT` node
|
||||
|
||||
Finally, functions that implement array operations are instantiated on the fly as
|
||||
necessary. They are kept track of in global dictionaries, with a mangled name
|
||||
containing types and operations as key:
|
||||
|
||||
:ivar array_binop_funcs: the map from mangled name to implementation of binary
|
||||
operations between arrays
|
||||
"""
|
||||
|
||||
_size_type = builtins.TInt32()
|
||||
@ -111,6 +118,7 @@ class ARTIQIRGenerator(algorithm.Visitor):
|
||||
self.function_map = dict()
|
||||
self.variable_map = dict()
|
||||
self.method_map = defaultdict(lambda: [])
|
||||
self.array_binop_funcs = dict()
|
||||
|
||||
def annotate_calls(self, devirtualization):
|
||||
for var_node in devirtualization.variable_map:
|
||||
@ -1337,8 +1345,124 @@ class ARTIQIRGenerator(algorithm.Visitor):
|
||||
name="{}.{}".format(_readable_name(value),
|
||||
node.type.name)))
|
||||
|
||||
def _get_total_array_len(self, shape):
|
||||
lengths = [
|
||||
self.append(ir.GetAttr(shape, i)) for i in range(len(shape.type.elts))
|
||||
]
|
||||
return reduce(lambda l, r: self.append(ir.Arith(ast.Mult(loc=None), l, r)),
|
||||
lengths[1:], lengths[0])
|
||||
|
||||
def _alloate_new_array(self, elt, shape):
|
||||
total_length = self._get_total_array_len(shape)
|
||||
buffer = self.append(ir.Alloc([total_length], types._TPointer(elt=elt)))
|
||||
result_type = builtins.TArray(elt, types.TValue(len(shape.type.elts)))
|
||||
return self.append(ir.Alloc([buffer, shape], result_type))
|
||||
|
||||
def _make_array_binop(self, name, op, result_type, lhs_type, rhs_type):
|
||||
try:
|
||||
result = ir.Argument(result_type, "result")
|
||||
lhs = ir.Argument(lhs_type, "lhs")
|
||||
rhs = ir.Argument(rhs_type, "rhs")
|
||||
|
||||
# TODO: We'd like to use a "C function" here to be able to supply
|
||||
# specialised implementations in a library in the future (and e.g. avoid
|
||||
# passing around the context argument), but the code generator currently
|
||||
# doesn't allow emitting them.
|
||||
args = [result, lhs, rhs]
|
||||
typ = types.TFunction(args=OrderedDict([(arg.name, arg.type)
|
||||
for arg in args]),
|
||||
optargs=OrderedDict(),
|
||||
ret=builtins.TNone())
|
||||
env_args = [ir.EnvironmentArgument(self.current_env.type, "ARG.ENV")]
|
||||
|
||||
# TODO: What to use for loc?
|
||||
func = ir.Function(typ, name, env_args + args, loc=None)
|
||||
func.is_internal = True
|
||||
func.is_generated = True
|
||||
self.functions.append(func)
|
||||
old_func, self.current_function = self.current_function, func
|
||||
|
||||
entry = self.add_block("entry")
|
||||
old_block, self.current_block = self.current_block, entry
|
||||
|
||||
old_final_branch, self.final_branch = self.final_branch, None
|
||||
old_unwind, self.unwind_target = self.unwind_target, None
|
||||
|
||||
shape = self.append(ir.GetAttr(lhs, "shape"))
|
||||
rhs_shape = self.append(ir.GetAttr(rhs, "shape"))
|
||||
self._make_check(
|
||||
self.append(ir.Compare(ast.Eq(loc=None), shape, rhs_shape)),
|
||||
lambda: self.alloc_exn(
|
||||
builtins.TException("ValueError"),
|
||||
ir.Constant("operands could not be broadcast together",
|
||||
builtins.TStr())))
|
||||
# We assume result has correct shape; could just pass buffer pointer as well.
|
||||
|
||||
result_buffer = self.append(ir.GetAttr(result, "buffer"))
|
||||
lhs_buffer = self.append(ir.GetAttr(lhs, "buffer"))
|
||||
rhs_buffer = self.append(ir.GetAttr(rhs, "buffer"))
|
||||
num_total_elts = self._get_total_array_len(shape)
|
||||
|
||||
def body_gen(index):
|
||||
l = self.append(ir.GetElem(lhs_buffer, index))
|
||||
r = self.append(ir.GetElem(rhs_buffer, index))
|
||||
self.append(
|
||||
ir.SetElem(result_buffer, index, self.append(ir.Arith(op, l, r))))
|
||||
return self.append(
|
||||
ir.Arith(ast.Add(loc=None), index, ir.Constant(1, self._size_type)))
|
||||
|
||||
self._make_loop(
|
||||
ir.Constant(0, self._size_type), lambda index: self.append(
|
||||
ir.Compare(ast.Lt(loc=None), index, num_total_elts)), body_gen)
|
||||
|
||||
self.append(ir.Return(ir.Constant(None, builtins.TNone())))
|
||||
return func
|
||||
finally:
|
||||
self.current_function = old_func
|
||||
self.current_block = old_block
|
||||
self.final_branch = old_final_branch
|
||||
self.unwind_target = old_unwind
|
||||
|
||||
def _get_array_binop(self, op, result_type, lhs_type, rhs_type):
|
||||
# Currently, we always have any type coercions resolved explicitly in the AST.
|
||||
# In the future, this might no longer be true and the three types might all
|
||||
# differ.
|
||||
def name_error(typ):
|
||||
assert False, "Internal compiler error: No RPC tag for {}".format(typ)
|
||||
def mangle_name(typ):
|
||||
typ = typ.find()
|
||||
return ir.rpc_tag(typ["elt"], name_error).decode() +\
|
||||
str(typ["num_dims"].find().value)
|
||||
name = "_array_{}_{}_{}_{}".format(
|
||||
type(op).__name__,
|
||||
*(map(mangle_name, (result_type, lhs_type, rhs_type))))
|
||||
if name not in self.array_binop_funcs:
|
||||
self.array_binop_funcs[name] = self._make_array_binop(
|
||||
name, op, result_type, lhs_type, rhs_type)
|
||||
return self.array_binop_funcs[name]
|
||||
|
||||
def visit_BinOpT(self, node):
|
||||
if builtins.is_numeric(node.type):
|
||||
if builtins.is_array(node.type):
|
||||
lhs = self.visit(node.left)
|
||||
rhs = self.visit(node.right)
|
||||
|
||||
# Array op implementation will check for matching shape.
|
||||
# TODO: Broadcasts; select the widest shape.
|
||||
# TODO: Detect and special-case matrix multiplication.
|
||||
shape = self.append(ir.GetAttr(lhs, "shape"))
|
||||
result = self._alloate_new_array(node.type.find()["elt"], shape)
|
||||
|
||||
func = self._get_array_binop(node.op, node.type, node.left.type, node.right.type)
|
||||
closure = self.append(ir.Closure(func, ir.Constant(None, ir.TEnvironment("arrayop", {}))))
|
||||
params = [result, lhs, rhs]
|
||||
if self.unwind_target is None:
|
||||
insn = self.append(ir.Call(closure, params, {}))
|
||||
else:
|
||||
after_invoke = self.add_block("arrayop.invoke")
|
||||
insn = self.append(ir.Invoke(func, params, {}, after_invoke, self.unwind_target))
|
||||
self.current_block = after_invoke
|
||||
return result
|
||||
elif builtins.is_numeric(node.type):
|
||||
lhs = self.visit(node.left)
|
||||
rhs = self.visit(node.right)
|
||||
if isinstance(node.op, (ast.LShift, ast.RShift)):
|
||||
@ -1703,11 +1827,8 @@ class ARTIQIRGenerator(algorithm.Visitor):
|
||||
ir.Constant(0, self._size_type))
|
||||
lengths.append(self.iterable_len(first_elt))
|
||||
|
||||
num_total_elts = reduce(
|
||||
lambda l, r: self.append(ir.Arith(ast.Mult(loc=None), l, r)),
|
||||
lengths[1:], lengths[0])
|
||||
|
||||
shape = self.append(ir.Alloc(lengths, result_type.attributes["shape"]))
|
||||
num_total_elts = self._get_total_array_len(shape)
|
||||
|
||||
# Assign buffer from nested iterables.
|
||||
buffer = self.append(
|
||||
|
19
artiq/test/lit/integration/array_ops.py
Normal file
19
artiq/test/lit/integration/array_ops.py
Normal file
@ -0,0 +1,19 @@
|
||||
# RUN: %python -m artiq.compiler.testbench.jit %s
|
||||
|
||||
a = array([1, 2, 3])
|
||||
b = array([4, 5, 6])
|
||||
|
||||
c = a + b
|
||||
assert c[0] == 5
|
||||
assert c[1] == 7
|
||||
assert c[2] == 9
|
||||
|
||||
c = a * b
|
||||
assert c[0] == 4
|
||||
assert c[1] == 10
|
||||
assert c[2] == 18
|
||||
|
||||
c = b // a
|
||||
assert c[0] == 4
|
||||
assert c[1] == 2
|
||||
assert c[2] == 2
|
Loading…
Reference in New Issue
Block a user