compiler: Implement basic element-wise array operations

This commit is contained in:
David Nadlinger 2020-07-29 01:28:55 +01:00
parent 9af6e5747d
commit 48fb80017f
2 changed files with 145 additions and 5 deletions

View File

@ -83,6 +83,13 @@ class ARTIQIRGenerator(algorithm.Visitor):
:ivar method_map: (map of :class:`ast.AttributeT` to :class:`ir.GetAttribute`) :ivar method_map: (map of :class:`ast.AttributeT` to :class:`ir.GetAttribute`)
the map from method resolution nodes to instructions retrieving the map from method resolution nodes to instructions retrieving
the called function inside a translated :class:`ast.CallT` node the called function inside a translated :class:`ast.CallT` node
Finally, functions that implement array operations are instantiated on the fly as
necessary. They are kept track of in global dictionaries, with a mangled name
containing types and operations as key:
:ivar array_binop_funcs: the map from mangled name to implementation of binary
operations between arrays
""" """
_size_type = builtins.TInt32() _size_type = builtins.TInt32()
@ -111,6 +118,7 @@ class ARTIQIRGenerator(algorithm.Visitor):
self.function_map = dict() self.function_map = dict()
self.variable_map = dict() self.variable_map = dict()
self.method_map = defaultdict(lambda: []) self.method_map = defaultdict(lambda: [])
self.array_binop_funcs = dict()
def annotate_calls(self, devirtualization): def annotate_calls(self, devirtualization):
for var_node in devirtualization.variable_map: for var_node in devirtualization.variable_map:
@ -1337,8 +1345,124 @@ class ARTIQIRGenerator(algorithm.Visitor):
name="{}.{}".format(_readable_name(value), name="{}.{}".format(_readable_name(value),
node.type.name))) node.type.name)))
def _get_total_array_len(self, shape):
    """Emit IR computing the total element count described by ``shape``.

    One ``ir.GetAttr`` is appended per entry of ``shape.type.elts`` to read
    the individual lengths; these are then multiplied together left to
    right with ``ir.Arith``/``ast.Mult`` instructions.

    :param shape: IR value whose type exposes an ``elts`` sequence
    :return: the IR value holding the product of all lengths (for a single
        dimension, simply that dimension's length; no multiply is emitted)
    """
    # Read every dimension first, so all attribute reads precede the
    # multiplications in the emitted instruction stream.
    dims = []
    for axis in range(len(shape.type.elts)):
        dims.append(self.append(ir.GetAttr(shape, axis)))

    total = dims[0]
    for dim in dims[1:]:
        total = self.append(ir.Arith(ast.Mult(loc=None), total, dim))
    return total
def _alloate_new_array(self, elt, shape):
    """Emit IR allocating a fresh array with the given element type and shape.

    A flat buffer sized by the total element count of ``shape`` is allocated
    first, then wrapped together with ``shape`` into a ``builtins.TArray``
    value whose dimension count is ``len(shape.type.elts)``.

    :param elt: element type of the new array
    :param shape: IR value holding the shape of the new array
    :return: the IR value of the newly allocated array
    """
    num_elts = self._get_total_array_len(shape)

    # Backing storage: one slot per element.
    storage_type = types._TPointer(elt=elt)
    storage = self.append(ir.Alloc([num_elts], storage_type))

    # The array object itself pairs the storage with its shape.
    rank = types.TValue(len(shape.type.elts))
    array_type = builtins.TArray(elt, rank)
    return self.append(ir.Alloc([storage, shape], array_type))
def _make_array_binop(self, name, op, result_type, lhs_type, rhs_type):
    """Generate the IR function implementing an element-wise array binary op.

    The generated function takes three arguments, ``result``, ``lhs`` and
    ``rhs`` (after an environment argument), and returns ``None``. Its body:

    1. compares the ``shape`` attributes of ``lhs`` and ``rhs`` and, via
       ``_make_check``, arranges for a ``ValueError`` to be raised if they
       differ;
    2. loops a flat index from 0 up to the total element count of the
       shape, storing ``lhs[i] <op> rhs[i]`` into ``result``'s buffer.

    The shape of ``result`` is not checked; it is assumed to be correct.

    The new function is appended to ``self.functions``. While the body is
    emitted, the generator's current function/block, ``final_branch`` and
    ``unwind_target`` are temporarily replaced; they are restored on exit
    regardless of whether generation succeeds.

    :param name: name given to the generated IR function
    :param op: operator node passed to ``ir.Arith`` for each element pair
    :param result_type: type of the ``result`` argument
    :param lhs_type: type of the ``lhs`` argument
    :param rhs_type: type of the ``rhs`` argument
    :return: the generated ``ir.Function``
    """
    # Snapshot the generator state *before* entering the try block, so that
    # the finally clause can always restore it. Binding these names only
    # midway through the try body would make an early failure surface as an
    # UnboundLocalError from the cleanup code instead of the real error.
    old_func = self.current_function
    old_block = self.current_block
    old_final_branch = self.final_branch
    old_unwind = self.unwind_target
    try:
        result = ir.Argument(result_type, "result")
        lhs = ir.Argument(lhs_type, "lhs")
        rhs = ir.Argument(rhs_type, "rhs")

        # TODO: We'd like to use a "C function" here to be able to supply
        # specialised implementations in a library in the future (and e.g. avoid
        # passing around the context argument), but the code generator currently
        # doesn't allow emitting them.
        args = [result, lhs, rhs]
        typ = types.TFunction(args=OrderedDict([(arg.name, arg.type)
                                                for arg in args]),
                              optargs=OrderedDict(),
                              ret=builtins.TNone())

        env_args = [ir.EnvironmentArgument(self.current_env.type, "ARG.ENV")]

        # TODO: What to use for loc?
        func = ir.Function(typ, name, env_args + args, loc=None)
        func.is_internal = True
        func.is_generated = True
        self.functions.append(func)
        self.current_function = func

        entry = self.add_block("entry")
        self.current_block = entry

        # The generated function starts with no pending final branch and no
        # unwind target of its own.
        self.final_branch = None
        self.unwind_target = None

        shape = self.append(ir.GetAttr(lhs, "shape"))
        rhs_shape = self.append(ir.GetAttr(rhs, "shape"))
        self._make_check(
            self.append(ir.Compare(ast.Eq(loc=None), shape, rhs_shape)),
            lambda: self.alloc_exn(
                builtins.TException("ValueError"),
                ir.Constant("operands could not be broadcast together",
                            builtins.TStr())))

        # We assume result has correct shape; could just pass buffer pointer as well.
        result_buffer = self.append(ir.GetAttr(result, "buffer"))
        lhs_buffer = self.append(ir.GetAttr(lhs, "buffer"))
        rhs_buffer = self.append(ir.GetAttr(rhs, "buffer"))
        num_total_elts = self._get_total_array_len(shape)

        def body_gen(index):
            # Apply the operation to one element pair and hand back the
            # incremented index as this callback's result.
            lhs_elt = self.append(ir.GetElem(lhs_buffer, index))
            rhs_elt = self.append(ir.GetElem(rhs_buffer, index))
            self.append(
                ir.SetElem(result_buffer, index,
                           self.append(ir.Arith(op, lhs_elt, rhs_elt))))
            return self.append(
                ir.Arith(ast.Add(loc=None), index, ir.Constant(1, self._size_type)))

        self._make_loop(
            ir.Constant(0, self._size_type),
            lambda index: self.append(
                ir.Compare(ast.Lt(loc=None), index, num_total_elts)),
            body_gen)

        self.append(ir.Return(ir.Constant(None, builtins.TNone())))
        return func
    finally:
        self.current_function = old_func
        self.current_block = old_block
        self.final_branch = old_final_branch
        self.unwind_target = old_unwind
def _get_array_binop(self, op, result_type, lhs_type, rhs_type):
    """Return the implementation function for an array binary operation.

    Implementations are created on demand by ``_make_array_binop`` and
    memoised in ``self.array_binop_funcs`` under a mangled name built from
    the operator's class name and, for each of the result/lhs/rhs types,
    the RPC tag of its element type followed by its number of dimensions.

    :param op: operator node; its class name becomes part of the key
    :param result_type: array type of the result
    :param lhs_type: array type of the left operand
    :param rhs_type: array type of the right operand
    :return: the cached or newly generated implementation function
    """
    # Currently, we always have any type coercions resolved explicitly in the AST.
    # In the future, this might no longer be true and the three types might all
    # differ.
    def fail_on_untaggable(typ):
        assert False, "Internal compiler error: No RPC tag for {}".format(typ)

    def type_suffix(typ):
        resolved = typ.find()
        elt_tag = ir.rpc_tag(resolved["elt"], fail_on_untaggable).decode()
        return elt_tag + str(resolved["num_dims"].find().value)

    suffixes = [type_suffix(typ) for typ in (result_type, lhs_type, rhs_type)]
    name = "_".join(["_array", type(op).__name__] + suffixes)

    if name in self.array_binop_funcs:
        return self.array_binop_funcs[name]

    func = self._make_array_binop(name, op, result_type, lhs_type, rhs_type)
    self.array_binop_funcs[name] = func
    return func
def visit_BinOpT(self, node): def visit_BinOpT(self, node):
if builtins.is_numeric(node.type): if builtins.is_array(node.type):
lhs = self.visit(node.left)
rhs = self.visit(node.right)
# Array op implementation will check for matching shape.
# TODO: Broadcasts; select the widest shape.
# TODO: Detect and special-case matrix multiplication.
shape = self.append(ir.GetAttr(lhs, "shape"))
result = self._alloate_new_array(node.type.find()["elt"], shape)
func = self._get_array_binop(node.op, node.type, node.left.type, node.right.type)
closure = self.append(ir.Closure(func, ir.Constant(None, ir.TEnvironment("arrayop", {}))))
params = [result, lhs, rhs]
if self.unwind_target is None:
insn = self.append(ir.Call(closure, params, {}))
else:
after_invoke = self.add_block("arrayop.invoke")
insn = self.append(ir.Invoke(func, params, {}, after_invoke, self.unwind_target))
self.current_block = after_invoke
return result
elif builtins.is_numeric(node.type):
lhs = self.visit(node.left) lhs = self.visit(node.left)
rhs = self.visit(node.right) rhs = self.visit(node.right)
if isinstance(node.op, (ast.LShift, ast.RShift)): if isinstance(node.op, (ast.LShift, ast.RShift)):
@ -1703,11 +1827,8 @@ class ARTIQIRGenerator(algorithm.Visitor):
ir.Constant(0, self._size_type)) ir.Constant(0, self._size_type))
lengths.append(self.iterable_len(first_elt)) lengths.append(self.iterable_len(first_elt))
num_total_elts = reduce(
lambda l, r: self.append(ir.Arith(ast.Mult(loc=None), l, r)),
lengths[1:], lengths[0])
shape = self.append(ir.Alloc(lengths, result_type.attributes["shape"])) shape = self.append(ir.Alloc(lengths, result_type.attributes["shape"]))
num_total_elts = self._get_total_array_len(shape)
# Assign buffer from nested iterables. # Assign buffer from nested iterables.
buffer = self.append( buffer = self.append(

View File

@ -0,0 +1,19 @@
# RUN: %python -m artiq.compiler.testbench.jit %s
# Checks element-wise binary operations between two equal-length
# one-dimensional integer arrays.
# NOTE(review): `array` is not imported here; presumably it is supplied as a
# builtin by the testbench named in the RUN line - confirm.
a = array([1, 2, 3])
b = array([4, 5, 6])
# Element-wise addition: 1+4, 2+5, 3+6.
c = a + b
assert c[0] == 5
assert c[1] == 7
assert c[2] == 9
# Element-wise multiplication: 1*4, 2*5, 3*6.
c = a * b
assert c[0] == 4
assert c[1] == 10
assert c[2] == 18
# Element-wise floor division: 4//1, 5//2, 6//3.
c = b // a
assert c[0] == 4
assert c[1] == 2
assert c[2] == 2