mirror of https://github.com/m-labs/artiq.git
compiler: Implement 1D-/2D- array transpose
Left generic transpose (shape order inversion) for now, as that would be less ugly if we implement forwarding to Python function bodies for array function implementations. Needs a runtime test case.
This commit is contained in:
parent
faea886c44
commit
be7d78253f
|
@ -26,6 +26,9 @@ unary_fp_runtime_calls = [
|
|||
("arctan", "atan"),
|
||||
]
|
||||
|
||||
#: Array handling builtins (special treatment due to allocations).
|
||||
numpy_builtins = ["transpose"]
|
||||
|
||||
|
||||
def unary_fp_type(name):
|
||||
return types.TExternalFunction(OrderedDict([("arg", builtins.TFloat())]),
|
||||
|
@ -36,6 +39,8 @@ numpy_map = {
|
|||
getattr(numpy, symbol): unary_fp_type(mangle)
|
||||
for symbol, mangle in (unary_fp_intrinsics + unary_fp_runtime_calls)
|
||||
}
|
||||
for name in numpy_builtins:
|
||||
numpy_map[getattr(numpy, name)] = types.TBuiltinFunction("numpy." + name)
|
||||
|
||||
|
||||
def match(obj):
|
||||
|
|
|
@ -2217,6 +2217,51 @@ class ARTIQIRGenerator(algorithm.Visitor):
|
|||
return result
|
||||
else:
|
||||
assert False
|
||||
elif types.is_builtin(typ, "numpy.transpose"):
|
||||
if len(node.args) == 1 and len(node.keywords) == 0:
|
||||
arg, = map(self.visit, node.args)
|
||||
|
||||
num_dims = arg.type.find()["num_dims"].value
|
||||
if num_dims == 1:
|
||||
# No-op as per NumPy semantics.
|
||||
return arg
|
||||
assert num_dims == 2
|
||||
arg_shape = self.append(ir.GetAttr(arg, "shape"))
|
||||
dim0 = self.append(ir.GetAttr(arg_shape, 0))
|
||||
dim1 = self.append(ir.GetAttr(arg_shape, 1))
|
||||
shape = self._make_array_shape([dim1, dim0])
|
||||
result = self._allocate_new_array(node.type.find()["elt"], shape)
|
||||
arg_buffer = self.append(ir.GetAttr(arg, "buffer"))
|
||||
result_buffer = self.append(ir.GetAttr(result, "buffer"))
|
||||
|
||||
def outer_gen(idx1):
|
||||
arg_base = self.append(ir.Offset(arg_buffer, idx1))
|
||||
result_offset = self.append(ir.Arith(ast.Mult(loc=None), idx1,
|
||||
dim0))
|
||||
result_base = self.append(ir.Offset(result_buffer, result_offset))
|
||||
|
||||
def inner_gen(idx0):
|
||||
arg_offset = self.append(
|
||||
ir.Arith(ast.Mult(loc=None), idx0, dim1))
|
||||
val = self.append(ir.GetElem(arg_base, arg_offset))
|
||||
self.append(ir.SetElem(result_base, idx0, val))
|
||||
return self.append(
|
||||
ir.Arith(ast.Add(loc=None), idx0, ir.Constant(1,
|
||||
idx0.type)))
|
||||
|
||||
self._make_loop(
|
||||
ir.Constant(0, self._size_type), lambda idx0: self.append(
|
||||
ir.Compare(ast.Lt(loc=None), idx0, dim0)), inner_gen)
|
||||
return self.append(
|
||||
ir.Arith(ast.Add(loc=None), idx1, ir.Constant(1, idx1.type)))
|
||||
|
||||
self._make_loop(
|
||||
ir.Constant(0, self._size_type),
|
||||
lambda idx1: self.append(ir.Compare(ast.Lt(loc=None), idx1, dim1)),
|
||||
outer_gen)
|
||||
return result
|
||||
else:
|
||||
assert False
|
||||
elif types.is_builtin(typ, "print"):
|
||||
self.polymorphic_print([self.visit(arg) for arg in node.args],
|
||||
separator=" ", suffix="\n")
|
||||
|
|
|
@ -1074,6 +1074,45 @@ class Inferencer(algorithm.Visitor):
|
|||
arg1.loc, None)
|
||||
else:
|
||||
diagnose(valid_forms())
|
||||
elif types.is_builtin(typ, "numpy.transpose"):
|
||||
valid_forms = lambda: [
|
||||
valid_form("transpose(x: array(elt='a, num_dims=1)) -> array(elt='a, num_dims=1)"),
|
||||
valid_form("transpose(x: array(elt='a, num_dims=2)) -> array(elt='a, num_dims=2)")
|
||||
]
|
||||
|
||||
if len(node.args) == 1 and len(node.keywords) == 0:
|
||||
arg, = node.args
|
||||
|
||||
if types.is_var(arg.type):
|
||||
pass # undetermined yet
|
||||
elif not builtins.is_array(arg.type):
|
||||
note = diagnostic.Diagnostic(
|
||||
"note", "this expression has type {type}",
|
||||
{"type": types.TypePrinter().name(arg.type)}, arg.loc)
|
||||
diag = diagnostic.Diagnostic(
|
||||
"error",
|
||||
"the argument of {builtin}() must be an array",
|
||||
{"builtin": typ.find().name},
|
||||
node.func.loc,
|
||||
notes=[note])
|
||||
self.engine.process(diag)
|
||||
else:
|
||||
num_dims = arg.type.find()["num_dims"].value
|
||||
if num_dims not in (1, 2):
|
||||
note = diagnostic.Diagnostic(
|
||||
"note", "argument is {num_dims}-dimensional",
|
||||
{"num_dims": num_dims}, arg.loc)
|
||||
diag = diagnostic.Diagnostic(
|
||||
"error",
|
||||
"{builtin}() is currently only supported for up to "
|
||||
"two-dimensional arrays", {"builtin": typ.find().name},
|
||||
node.func.loc,
|
||||
notes=[note])
|
||||
self.engine.process(diag)
|
||||
else:
|
||||
self._unify(node.type, arg.type, node.loc, None)
|
||||
else:
|
||||
diagnose(valid_forms())
|
||||
elif types.is_builtin(typ, "rtio_log"):
|
||||
valid_forms = lambda: [
|
||||
valid_form("rtio_log(channel:str, args...) -> None"),
|
||||
|
|
|
@ -0,0 +1,22 @@
|
|||
# RUN: %python -m artiq.compiler.testbench.embedding %s
|
||||
|
||||
from artiq.language.core import *
|
||||
from artiq.language.types import *
|
||||
import numpy as np
|
||||
|
||||
@kernel
|
||||
def entrypoint():
|
||||
# FIXME: This needs to be a runtime test (but numpy.* integration is
|
||||
# currently embedding-only).
|
||||
a = np.array([1, 2, 3])
|
||||
b = np.transpose(a)
|
||||
assert a.shape == b.shape
|
||||
for i in range(len(a)):
|
||||
assert a[i] == b[i]
|
||||
|
||||
c = np.array([[1, 2, 3], [4, 5, 6]])
|
||||
d = np.transpose(c)
|
||||
assert c.shape == d.shape
|
||||
for i in range(2):
|
||||
for j in range(3):
|
||||
assert c[i][j] == d[j][i]
|
Loading…
Reference in New Issue