forked from M-Labs/artiq
compiler: Implement basic element-wise array operations
This commit is contained in:
parent
9af6e5747d
commit
48fb80017f
|
@ -83,6 +83,13 @@ class ARTIQIRGenerator(algorithm.Visitor):
|
||||||
:ivar method_map: (map of :class:`ast.AttributeT` to :class:`ir.GetAttribute`)
|
:ivar method_map: (map of :class:`ast.AttributeT` to :class:`ir.GetAttribute`)
|
||||||
the map from method resolution nodes to instructions retrieving
|
the map from method resolution nodes to instructions retrieving
|
||||||
the called function inside a translated :class:`ast.CallT` node
|
the called function inside a translated :class:`ast.CallT` node
|
||||||
|
|
||||||
|
Finally, functions that implement array operations are instantiated on the fly as
|
||||||
|
necessary. They are kept track of in global dictionaries, with a mangled name
|
||||||
|
containing types and operations as key:
|
||||||
|
|
||||||
|
:ivar array_binop_funcs: the map from mangled name to implementation of binary
|
||||||
|
operations between arrays
|
||||||
"""
|
"""
|
||||||
|
|
||||||
_size_type = builtins.TInt32()
|
_size_type = builtins.TInt32()
|
||||||
|
@ -111,6 +118,7 @@ class ARTIQIRGenerator(algorithm.Visitor):
|
||||||
self.function_map = dict()
|
self.function_map = dict()
|
||||||
self.variable_map = dict()
|
self.variable_map = dict()
|
||||||
self.method_map = defaultdict(lambda: [])
|
self.method_map = defaultdict(lambda: [])
|
||||||
|
self.array_binop_funcs = dict()
|
||||||
|
|
||||||
def annotate_calls(self, devirtualization):
|
def annotate_calls(self, devirtualization):
|
||||||
for var_node in devirtualization.variable_map:
|
for var_node in devirtualization.variable_map:
|
||||||
|
@ -1337,8 +1345,124 @@ class ARTIQIRGenerator(algorithm.Visitor):
|
||||||
name="{}.{}".format(_readable_name(value),
|
name="{}.{}".format(_readable_name(value),
|
||||||
node.type.name)))
|
node.type.name)))
|
||||||
|
|
||||||
|
def _get_total_array_len(self, shape):
|
||||||
|
lengths = [
|
||||||
|
self.append(ir.GetAttr(shape, i)) for i in range(len(shape.type.elts))
|
||||||
|
]
|
||||||
|
return reduce(lambda l, r: self.append(ir.Arith(ast.Mult(loc=None), l, r)),
|
||||||
|
lengths[1:], lengths[0])
|
||||||
|
|
||||||
|
def _alloate_new_array(self, elt, shape):
|
||||||
|
total_length = self._get_total_array_len(shape)
|
||||||
|
buffer = self.append(ir.Alloc([total_length], types._TPointer(elt=elt)))
|
||||||
|
result_type = builtins.TArray(elt, types.TValue(len(shape.type.elts)))
|
||||||
|
return self.append(ir.Alloc([buffer, shape], result_type))
|
||||||
|
|
||||||
|
def _make_array_binop(self, name, op, result_type, lhs_type, rhs_type):
|
||||||
|
try:
|
||||||
|
result = ir.Argument(result_type, "result")
|
||||||
|
lhs = ir.Argument(lhs_type, "lhs")
|
||||||
|
rhs = ir.Argument(rhs_type, "rhs")
|
||||||
|
|
||||||
|
# TODO: We'd like to use a "C function" here to be able to supply
|
||||||
|
# specialised implementations in a library in the future (and e.g. avoid
|
||||||
|
# passing around the context argument), but the code generator currently
|
||||||
|
# doesn't allow emitting them.
|
||||||
|
args = [result, lhs, rhs]
|
||||||
|
typ = types.TFunction(args=OrderedDict([(arg.name, arg.type)
|
||||||
|
for arg in args]),
|
||||||
|
optargs=OrderedDict(),
|
||||||
|
ret=builtins.TNone())
|
||||||
|
env_args = [ir.EnvironmentArgument(self.current_env.type, "ARG.ENV")]
|
||||||
|
|
||||||
|
# TODO: What to use for loc?
|
||||||
|
func = ir.Function(typ, name, env_args + args, loc=None)
|
||||||
|
func.is_internal = True
|
||||||
|
func.is_generated = True
|
||||||
|
self.functions.append(func)
|
||||||
|
old_func, self.current_function = self.current_function, func
|
||||||
|
|
||||||
|
entry = self.add_block("entry")
|
||||||
|
old_block, self.current_block = self.current_block, entry
|
||||||
|
|
||||||
|
old_final_branch, self.final_branch = self.final_branch, None
|
||||||
|
old_unwind, self.unwind_target = self.unwind_target, None
|
||||||
|
|
||||||
|
shape = self.append(ir.GetAttr(lhs, "shape"))
|
||||||
|
rhs_shape = self.append(ir.GetAttr(rhs, "shape"))
|
||||||
|
self._make_check(
|
||||||
|
self.append(ir.Compare(ast.Eq(loc=None), shape, rhs_shape)),
|
||||||
|
lambda: self.alloc_exn(
|
||||||
|
builtins.TException("ValueError"),
|
||||||
|
ir.Constant("operands could not be broadcast together",
|
||||||
|
builtins.TStr())))
|
||||||
|
# We assume result has correct shape; could just pass buffer pointer as well.
|
||||||
|
|
||||||
|
result_buffer = self.append(ir.GetAttr(result, "buffer"))
|
||||||
|
lhs_buffer = self.append(ir.GetAttr(lhs, "buffer"))
|
||||||
|
rhs_buffer = self.append(ir.GetAttr(rhs, "buffer"))
|
||||||
|
num_total_elts = self._get_total_array_len(shape)
|
||||||
|
|
||||||
|
def body_gen(index):
|
||||||
|
l = self.append(ir.GetElem(lhs_buffer, index))
|
||||||
|
r = self.append(ir.GetElem(rhs_buffer, index))
|
||||||
|
self.append(
|
||||||
|
ir.SetElem(result_buffer, index, self.append(ir.Arith(op, l, r))))
|
||||||
|
return self.append(
|
||||||
|
ir.Arith(ast.Add(loc=None), index, ir.Constant(1, self._size_type)))
|
||||||
|
|
||||||
|
self._make_loop(
|
||||||
|
ir.Constant(0, self._size_type), lambda index: self.append(
|
||||||
|
ir.Compare(ast.Lt(loc=None), index, num_total_elts)), body_gen)
|
||||||
|
|
||||||
|
self.append(ir.Return(ir.Constant(None, builtins.TNone())))
|
||||||
|
return func
|
||||||
|
finally:
|
||||||
|
self.current_function = old_func
|
||||||
|
self.current_block = old_block
|
||||||
|
self.final_branch = old_final_branch
|
||||||
|
self.unwind_target = old_unwind
|
||||||
|
|
||||||
|
def _get_array_binop(self, op, result_type, lhs_type, rhs_type):
|
||||||
|
# Currently, we always have any type coercions resolved explicitly in the AST.
|
||||||
|
# In the future, this might no longer be true and the three types might all
|
||||||
|
# differ.
|
||||||
|
def name_error(typ):
|
||||||
|
assert False, "Internal compiler error: No RPC tag for {}".format(typ)
|
||||||
|
def mangle_name(typ):
|
||||||
|
typ = typ.find()
|
||||||
|
return ir.rpc_tag(typ["elt"], name_error).decode() +\
|
||||||
|
str(typ["num_dims"].find().value)
|
||||||
|
name = "_array_{}_{}_{}_{}".format(
|
||||||
|
type(op).__name__,
|
||||||
|
*(map(mangle_name, (result_type, lhs_type, rhs_type))))
|
||||||
|
if name not in self.array_binop_funcs:
|
||||||
|
self.array_binop_funcs[name] = self._make_array_binop(
|
||||||
|
name, op, result_type, lhs_type, rhs_type)
|
||||||
|
return self.array_binop_funcs[name]
|
||||||
|
|
||||||
def visit_BinOpT(self, node):
|
def visit_BinOpT(self, node):
|
||||||
if builtins.is_numeric(node.type):
|
if builtins.is_array(node.type):
|
||||||
|
lhs = self.visit(node.left)
|
||||||
|
rhs = self.visit(node.right)
|
||||||
|
|
||||||
|
# Array op implementation will check for matching shape.
|
||||||
|
# TODO: Broadcasts; select the widest shape.
|
||||||
|
# TODO: Detect and special-case matrix multiplication.
|
||||||
|
shape = self.append(ir.GetAttr(lhs, "shape"))
|
||||||
|
result = self._alloate_new_array(node.type.find()["elt"], shape)
|
||||||
|
|
||||||
|
func = self._get_array_binop(node.op, node.type, node.left.type, node.right.type)
|
||||||
|
closure = self.append(ir.Closure(func, ir.Constant(None, ir.TEnvironment("arrayop", {}))))
|
||||||
|
params = [result, lhs, rhs]
|
||||||
|
if self.unwind_target is None:
|
||||||
|
insn = self.append(ir.Call(closure, params, {}))
|
||||||
|
else:
|
||||||
|
after_invoke = self.add_block("arrayop.invoke")
|
||||||
|
insn = self.append(ir.Invoke(func, params, {}, after_invoke, self.unwind_target))
|
||||||
|
self.current_block = after_invoke
|
||||||
|
return result
|
||||||
|
elif builtins.is_numeric(node.type):
|
||||||
lhs = self.visit(node.left)
|
lhs = self.visit(node.left)
|
||||||
rhs = self.visit(node.right)
|
rhs = self.visit(node.right)
|
||||||
if isinstance(node.op, (ast.LShift, ast.RShift)):
|
if isinstance(node.op, (ast.LShift, ast.RShift)):
|
||||||
|
@ -1703,11 +1827,8 @@ class ARTIQIRGenerator(algorithm.Visitor):
|
||||||
ir.Constant(0, self._size_type))
|
ir.Constant(0, self._size_type))
|
||||||
lengths.append(self.iterable_len(first_elt))
|
lengths.append(self.iterable_len(first_elt))
|
||||||
|
|
||||||
num_total_elts = reduce(
|
|
||||||
lambda l, r: self.append(ir.Arith(ast.Mult(loc=None), l, r)),
|
|
||||||
lengths[1:], lengths[0])
|
|
||||||
|
|
||||||
shape = self.append(ir.Alloc(lengths, result_type.attributes["shape"]))
|
shape = self.append(ir.Alloc(lengths, result_type.attributes["shape"]))
|
||||||
|
num_total_elts = self._get_total_array_len(shape)
|
||||||
|
|
||||||
# Assign buffer from nested iterables.
|
# Assign buffer from nested iterables.
|
||||||
buffer = self.append(
|
buffer = self.append(
|
||||||
|
|
|
@ -0,0 +1,19 @@
|
||||||
|
# RUN: %python -m artiq.compiler.testbench.jit %s
|
||||||
|
|
||||||
|
a = array([1, 2, 3])
|
||||||
|
b = array([4, 5, 6])
|
||||||
|
|
||||||
|
c = a + b
|
||||||
|
assert c[0] == 5
|
||||||
|
assert c[1] == 7
|
||||||
|
assert c[2] == 9
|
||||||
|
|
||||||
|
c = a * b
|
||||||
|
assert c[0] == 4
|
||||||
|
assert c[1] == 10
|
||||||
|
assert c[2] == 18
|
||||||
|
|
||||||
|
c = b // a
|
||||||
|
assert c[0] == 4
|
||||||
|
assert c[1] == 2
|
||||||
|
assert c[2] == 2
|
Loading…
Reference in New Issue