2
0
mirror of https://github.com/m-labs/artiq.git synced 2024-12-26 03:38:25 +08:00

compiler: Fix numpy.full, implement for >1D

This commit is contained in:
David Nadlinger 2020-08-09 23:30:25 +01:00
parent 53d64d08a8
commit 778f2cf905
3 changed files with 45 additions and 19 deletions

View File

@ -1411,7 +1411,7 @@ class ARTIQIRGenerator(algorithm.Visitor):
operand = self.visit(node.operand)
if builtins.is_array(operand.type):
shape = self.append(ir.GetAttr(operand, "shape"))
result = self._allocate_new_array(node.type.find()["elt"], shape)
result, _ = self._allocate_new_array(node.type.find()["elt"], shape)
func = self._get_array_unaryop("USub", make_sub, node.type, operand.type)
self._invoke_arrayop(func, [result, operand])
return result
@ -1431,7 +1431,7 @@ class ARTIQIRGenerator(algorithm.Visitor):
if builtins.is_array(node.type):
result_elt = node.type.find()["elt"]
shape = self.append(ir.GetAttr(value, "shape"))
result = self._allocate_new_array(result_elt, shape)
result, _ = self._allocate_new_array(result_elt, shape)
func = self._get_array_unaryop(
"Coerce", lambda v: self.append(ir.Coerce(v, result_elt)),
node.type, value.type)
@ -1455,7 +1455,7 @@ class ARTIQIRGenerator(algorithm.Visitor):
total_length = self._get_total_array_len(shape)
buffer = self.append(ir.Alloc([total_length], types._TPointer(elt=elt)))
result_type = builtins.TArray(elt, types.TValue(len(shape.type.elts)))
return self.append(ir.Alloc([buffer, shape], result_type))
return self.append(ir.Alloc([buffer, shape], result_type)), total_length
def _make_array_binop(self, name, result_type, lhs_type, rhs_type, body_gen):
try:
@ -1704,7 +1704,7 @@ class ARTIQIRGenerator(algorithm.Visitor):
elt = final_type["elt"]
result_dims = final_type["num_dims"].value
result = self._allocate_new_array(elt, result_shape)
result, _ = self._allocate_new_array(elt, result_shape)
func = self._get_matmult(result.type, left.type, right.type)
self._invoke_arrayop(func, [result, lhs, rhs])
@ -1745,7 +1745,7 @@ class ARTIQIRGenerator(algorithm.Visitor):
ir.Constant("operands could not be broadcast together",
builtins.TStr())))
result = self._allocate_new_array(node.type.find()["elt"], shape)
result, _ = self._allocate_new_array(node.type.find()["elt"], shape)
func = self._get_array_binop(node.op, node.type, lhs.type, rhs.type)
self._invoke_arrayop(func, [result, lhs, rhs])
return result
@ -2223,14 +2223,26 @@ class ARTIQIRGenerator(algorithm.Visitor):
if len(node.args) == 2 and len(node.keywords) == 0:
arg0, arg1 = map(self.visit, node.args)
result = self.append(ir.Alloc([arg0], node.type))
num_dims = node.type.find()["num_dims"].value
if types.is_tuple(arg0.type):
lens = [self.append(ir.GetAttr(arg0, i)) for i in range(num_dims)]
else:
assert num_dims == 1
lens = [arg0]
shape = self._make_array_shape(lens)
result, total_len = self._allocate_new_array(node.type.find()["elt"],
shape)
def body_gen(index):
self.append(ir.SetElem(result, index, arg1))
return self.append(ir.Arith(ast.Add(loc=None), index,
ir.Constant(1, arg0.type)))
self._make_loop(ir.Constant(0, self._size_type),
lambda index: self.append(ir.Compare(ast.Lt(loc=None), index, arg0)),
body_gen)
return self.append(
ir.Arith(ast.Add(loc=None), index,
ir.Constant(1, self._size_type)))
self._make_loop(
ir.Constant(0, self._size_type), lambda index: self.append(
ir.Compare(ast.Lt(loc=None), index, total_len)), body_gen)
return result
else:
assert False
@ -2247,7 +2259,7 @@ class ARTIQIRGenerator(algorithm.Visitor):
dim0 = self.append(ir.GetAttr(arg_shape, 0))
dim1 = self.append(ir.GetAttr(arg_shape, 1))
shape = self._make_array_shape([dim1, dim0])
result = self._allocate_new_array(node.type.find()["elt"], shape)
result, _ = self._allocate_new_array(node.type.find()["elt"], shape)
arg_buffer = self.append(ir.GetAttr(arg, "buffer"))
result_buffer = self.append(ir.GetAttr(result, "buffer"))
@ -2413,7 +2425,7 @@ class ARTIQIRGenerator(algorithm.Visitor):
node.arg_exprs)
shape = self.append(ir.GetAttr(args[0], "shape"))
result = self._allocate_new_array(node.type.find()["elt"], shape)
result, _ = self._allocate_new_array(node.type.find()["elt"], shape)
# TODO: Generate more generically if non-externals are allowed.
name = node.func.type.find().name

View File

@ -1103,17 +1103,26 @@ class Inferencer(algorithm.Visitor):
diagnose(valid_forms())
elif types.is_builtin(typ, "make_array"):
valid_forms = lambda: [
valid_form("numpy.full(count:int32, value:'a) -> numpy.array(elt='a)")
valid_form("numpy.full(count:int32, value:'a) -> array(elt='a, num_dims=1)"),
valid_form("numpy.full(shape:(int32,)*'b, value:'a) -> array(elt='a, num_dims='b)"),
]
self._unify(node.type, builtins.TArray(),
node.loc, None)
if len(node.args) == 2 and len(node.keywords) == 0:
arg0, arg1 = node.args
self._unify(arg0.type, builtins.TInt32(),
arg0.loc, None)
if types.is_var(arg0.type):
return # undetermined yet
elif types.is_tuple(arg0.type):
num_dims = len(arg0.type.find().elts)
self._unify(arg0.type, types.TTuple([builtins.TInt32()] * num_dims),
arg0.loc, None)
else:
num_dims = 1
self._unify(arg0.type, builtins.TInt32(),
arg0.loc, None)
self._unify(node.type, builtins.TArray(num_dims=num_dims),
node.loc, None)
self._unify(arg1.type, node.type.find()["elt"],
arg1.loc, None)
else:

View File

@ -244,6 +244,10 @@ class _RPCCalls(EnvExperiment):
def numpy_full(self):
return numpy.full(10, 20)
@kernel
def numpy_full_matrix(self):
return numpy.full((3, 2), 13)
@kernel
def numpy_nan(self):
return numpy.full(10, numpy.nan)
@ -277,6 +281,7 @@ class RPCCallsTest(ExperimentCase):
self.assertEqual(exp.numpy_things(),
(numpy.int32(10), numpy.int64(20), numpy.array([42,])))
self.assertTrue((exp.numpy_full() == numpy.full(10, 20)).all())
self.assertTrue((exp.numpy_full_matrix() == numpy.full((3, 2), 13)).all())
self.assertTrue(numpy.isnan(exp.numpy_nan()).all())
exp.builtin()
exp.async_in_try()