forked from M-Labs/artiq
compiler: Fix numpy.full, implement for >1D
This commit is contained in:
parent
53d64d08a8
commit
778f2cf905
@ -1411,7 +1411,7 @@ class ARTIQIRGenerator(algorithm.Visitor):
|
||||
operand = self.visit(node.operand)
|
||||
if builtins.is_array(operand.type):
|
||||
shape = self.append(ir.GetAttr(operand, "shape"))
|
||||
result = self._allocate_new_array(node.type.find()["elt"], shape)
|
||||
result, _ = self._allocate_new_array(node.type.find()["elt"], shape)
|
||||
func = self._get_array_unaryop("USub", make_sub, node.type, operand.type)
|
||||
self._invoke_arrayop(func, [result, operand])
|
||||
return result
|
||||
@ -1431,7 +1431,7 @@ class ARTIQIRGenerator(algorithm.Visitor):
|
||||
if builtins.is_array(node.type):
|
||||
result_elt = node.type.find()["elt"]
|
||||
shape = self.append(ir.GetAttr(value, "shape"))
|
||||
result = self._allocate_new_array(result_elt, shape)
|
||||
result, _ = self._allocate_new_array(result_elt, shape)
|
||||
func = self._get_array_unaryop(
|
||||
"Coerce", lambda v: self.append(ir.Coerce(v, result_elt)),
|
||||
node.type, value.type)
|
||||
@ -1455,7 +1455,7 @@ class ARTIQIRGenerator(algorithm.Visitor):
|
||||
total_length = self._get_total_array_len(shape)
|
||||
buffer = self.append(ir.Alloc([total_length], types._TPointer(elt=elt)))
|
||||
result_type = builtins.TArray(elt, types.TValue(len(shape.type.elts)))
|
||||
return self.append(ir.Alloc([buffer, shape], result_type))
|
||||
return self.append(ir.Alloc([buffer, shape], result_type)), total_length
|
||||
|
||||
def _make_array_binop(self, name, result_type, lhs_type, rhs_type, body_gen):
|
||||
try:
|
||||
@ -1704,7 +1704,7 @@ class ARTIQIRGenerator(algorithm.Visitor):
|
||||
elt = final_type["elt"]
|
||||
result_dims = final_type["num_dims"].value
|
||||
|
||||
result = self._allocate_new_array(elt, result_shape)
|
||||
result, _ = self._allocate_new_array(elt, result_shape)
|
||||
func = self._get_matmult(result.type, left.type, right.type)
|
||||
self._invoke_arrayop(func, [result, lhs, rhs])
|
||||
|
||||
@ -1745,7 +1745,7 @@ class ARTIQIRGenerator(algorithm.Visitor):
|
||||
ir.Constant("operands could not be broadcast together",
|
||||
builtins.TStr())))
|
||||
|
||||
result = self._allocate_new_array(node.type.find()["elt"], shape)
|
||||
result, _ = self._allocate_new_array(node.type.find()["elt"], shape)
|
||||
func = self._get_array_binop(node.op, node.type, lhs.type, rhs.type)
|
||||
self._invoke_arrayop(func, [result, lhs, rhs])
|
||||
return result
|
||||
@ -2223,14 +2223,26 @@ class ARTIQIRGenerator(algorithm.Visitor):
|
||||
if len(node.args) == 2 and len(node.keywords) == 0:
|
||||
arg0, arg1 = map(self.visit, node.args)
|
||||
|
||||
result = self.append(ir.Alloc([arg0], node.type))
|
||||
num_dims = node.type.find()["num_dims"].value
|
||||
if types.is_tuple(arg0.type):
|
||||
lens = [self.append(ir.GetAttr(arg0, i)) for i in range(num_dims)]
|
||||
else:
|
||||
assert num_dims == 1
|
||||
lens = [arg0]
|
||||
|
||||
shape = self._make_array_shape(lens)
|
||||
result, total_len = self._allocate_new_array(node.type.find()["elt"],
|
||||
shape)
|
||||
|
||||
def body_gen(index):
|
||||
self.append(ir.SetElem(result, index, arg1))
|
||||
return self.append(ir.Arith(ast.Add(loc=None), index,
|
||||
ir.Constant(1, arg0.type)))
|
||||
self._make_loop(ir.Constant(0, self._size_type),
|
||||
lambda index: self.append(ir.Compare(ast.Lt(loc=None), index, arg0)),
|
||||
body_gen)
|
||||
return self.append(
|
||||
ir.Arith(ast.Add(loc=None), index,
|
||||
ir.Constant(1, self._size_type)))
|
||||
|
||||
self._make_loop(
|
||||
ir.Constant(0, self._size_type), lambda index: self.append(
|
||||
ir.Compare(ast.Lt(loc=None), index, total_len)), body_gen)
|
||||
return result
|
||||
else:
|
||||
assert False
|
||||
@ -2247,7 +2259,7 @@ class ARTIQIRGenerator(algorithm.Visitor):
|
||||
dim0 = self.append(ir.GetAttr(arg_shape, 0))
|
||||
dim1 = self.append(ir.GetAttr(arg_shape, 1))
|
||||
shape = self._make_array_shape([dim1, dim0])
|
||||
result = self._allocate_new_array(node.type.find()["elt"], shape)
|
||||
result, _ = self._allocate_new_array(node.type.find()["elt"], shape)
|
||||
arg_buffer = self.append(ir.GetAttr(arg, "buffer"))
|
||||
result_buffer = self.append(ir.GetAttr(result, "buffer"))
|
||||
|
||||
@ -2413,7 +2425,7 @@ class ARTIQIRGenerator(algorithm.Visitor):
|
||||
node.arg_exprs)
|
||||
|
||||
shape = self.append(ir.GetAttr(args[0], "shape"))
|
||||
result = self._allocate_new_array(node.type.find()["elt"], shape)
|
||||
result, _ = self._allocate_new_array(node.type.find()["elt"], shape)
|
||||
|
||||
# TODO: Generate more generically if non-externals are allowed.
|
||||
name = node.func.type.find().name
|
||||
|
@ -1103,17 +1103,26 @@ class Inferencer(algorithm.Visitor):
|
||||
diagnose(valid_forms())
|
||||
elif types.is_builtin(typ, "make_array"):
|
||||
valid_forms = lambda: [
|
||||
valid_form("numpy.full(count:int32, value:'a) -> numpy.array(elt='a)")
|
||||
valid_form("numpy.full(count:int32, value:'a) -> array(elt='a, num_dims=1)"),
|
||||
valid_form("numpy.full(shape:(int32,)*'b, value:'a) -> array(elt='a, num_dims='b)"),
|
||||
]
|
||||
|
||||
self._unify(node.type, builtins.TArray(),
|
||||
node.loc, None)
|
||||
|
||||
if len(node.args) == 2 and len(node.keywords) == 0:
|
||||
arg0, arg1 = node.args
|
||||
|
||||
if types.is_var(arg0.type):
|
||||
return # undetermined yet
|
||||
elif types.is_tuple(arg0.type):
|
||||
num_dims = len(arg0.type.find().elts)
|
||||
self._unify(arg0.type, types.TTuple([builtins.TInt32()] * num_dims),
|
||||
arg0.loc, None)
|
||||
else:
|
||||
num_dims = 1
|
||||
self._unify(arg0.type, builtins.TInt32(),
|
||||
arg0.loc, None)
|
||||
|
||||
self._unify(node.type, builtins.TArray(num_dims=num_dims),
|
||||
node.loc, None)
|
||||
self._unify(arg1.type, node.type.find()["elt"],
|
||||
arg1.loc, None)
|
||||
else:
|
||||
|
@ -244,6 +244,10 @@ class _RPCCalls(EnvExperiment):
|
||||
def numpy_full(self):
|
||||
return numpy.full(10, 20)
|
||||
|
||||
@kernel
|
||||
def numpy_full_matrix(self):
|
||||
return numpy.full((3, 2), 13)
|
||||
|
||||
@kernel
|
||||
def numpy_nan(self):
|
||||
return numpy.full(10, numpy.nan)
|
||||
@ -277,6 +281,7 @@ class RPCCallsTest(ExperimentCase):
|
||||
self.assertEqual(exp.numpy_things(),
|
||||
(numpy.int32(10), numpy.int64(20), numpy.array([42,])))
|
||||
self.assertTrue((exp.numpy_full() == numpy.full(10, 20)).all())
|
||||
self.assertTrue((exp.numpy_full_matrix() == numpy.full((3, 2), 13)).all())
|
||||
self.assertTrue(numpy.isnan(exp.numpy_nan()).all())
|
||||
exp.builtin()
|
||||
exp.async_in_try()
|
||||
|
Loading…
Reference in New Issue
Block a user