forked from M-Labs/artiq
compiler: Fix numpy.full, implement for >1D
This commit is contained in:
parent
53d64d08a8
commit
778f2cf905
@ -1411,7 +1411,7 @@ class ARTIQIRGenerator(algorithm.Visitor):
|
|||||||
operand = self.visit(node.operand)
|
operand = self.visit(node.operand)
|
||||||
if builtins.is_array(operand.type):
|
if builtins.is_array(operand.type):
|
||||||
shape = self.append(ir.GetAttr(operand, "shape"))
|
shape = self.append(ir.GetAttr(operand, "shape"))
|
||||||
result = self._allocate_new_array(node.type.find()["elt"], shape)
|
result, _ = self._allocate_new_array(node.type.find()["elt"], shape)
|
||||||
func = self._get_array_unaryop("USub", make_sub, node.type, operand.type)
|
func = self._get_array_unaryop("USub", make_sub, node.type, operand.type)
|
||||||
self._invoke_arrayop(func, [result, operand])
|
self._invoke_arrayop(func, [result, operand])
|
||||||
return result
|
return result
|
||||||
@ -1431,7 +1431,7 @@ class ARTIQIRGenerator(algorithm.Visitor):
|
|||||||
if builtins.is_array(node.type):
|
if builtins.is_array(node.type):
|
||||||
result_elt = node.type.find()["elt"]
|
result_elt = node.type.find()["elt"]
|
||||||
shape = self.append(ir.GetAttr(value, "shape"))
|
shape = self.append(ir.GetAttr(value, "shape"))
|
||||||
result = self._allocate_new_array(result_elt, shape)
|
result, _ = self._allocate_new_array(result_elt, shape)
|
||||||
func = self._get_array_unaryop(
|
func = self._get_array_unaryop(
|
||||||
"Coerce", lambda v: self.append(ir.Coerce(v, result_elt)),
|
"Coerce", lambda v: self.append(ir.Coerce(v, result_elt)),
|
||||||
node.type, value.type)
|
node.type, value.type)
|
||||||
@ -1455,7 +1455,7 @@ class ARTIQIRGenerator(algorithm.Visitor):
|
|||||||
total_length = self._get_total_array_len(shape)
|
total_length = self._get_total_array_len(shape)
|
||||||
buffer = self.append(ir.Alloc([total_length], types._TPointer(elt=elt)))
|
buffer = self.append(ir.Alloc([total_length], types._TPointer(elt=elt)))
|
||||||
result_type = builtins.TArray(elt, types.TValue(len(shape.type.elts)))
|
result_type = builtins.TArray(elt, types.TValue(len(shape.type.elts)))
|
||||||
return self.append(ir.Alloc([buffer, shape], result_type))
|
return self.append(ir.Alloc([buffer, shape], result_type)), total_length
|
||||||
|
|
||||||
def _make_array_binop(self, name, result_type, lhs_type, rhs_type, body_gen):
|
def _make_array_binop(self, name, result_type, lhs_type, rhs_type, body_gen):
|
||||||
try:
|
try:
|
||||||
@ -1704,7 +1704,7 @@ class ARTIQIRGenerator(algorithm.Visitor):
|
|||||||
elt = final_type["elt"]
|
elt = final_type["elt"]
|
||||||
result_dims = final_type["num_dims"].value
|
result_dims = final_type["num_dims"].value
|
||||||
|
|
||||||
result = self._allocate_new_array(elt, result_shape)
|
result, _ = self._allocate_new_array(elt, result_shape)
|
||||||
func = self._get_matmult(result.type, left.type, right.type)
|
func = self._get_matmult(result.type, left.type, right.type)
|
||||||
self._invoke_arrayop(func, [result, lhs, rhs])
|
self._invoke_arrayop(func, [result, lhs, rhs])
|
||||||
|
|
||||||
@ -1745,7 +1745,7 @@ class ARTIQIRGenerator(algorithm.Visitor):
|
|||||||
ir.Constant("operands could not be broadcast together",
|
ir.Constant("operands could not be broadcast together",
|
||||||
builtins.TStr())))
|
builtins.TStr())))
|
||||||
|
|
||||||
result = self._allocate_new_array(node.type.find()["elt"], shape)
|
result, _ = self._allocate_new_array(node.type.find()["elt"], shape)
|
||||||
func = self._get_array_binop(node.op, node.type, lhs.type, rhs.type)
|
func = self._get_array_binop(node.op, node.type, lhs.type, rhs.type)
|
||||||
self._invoke_arrayop(func, [result, lhs, rhs])
|
self._invoke_arrayop(func, [result, lhs, rhs])
|
||||||
return result
|
return result
|
||||||
@ -2223,14 +2223,26 @@ class ARTIQIRGenerator(algorithm.Visitor):
|
|||||||
if len(node.args) == 2 and len(node.keywords) == 0:
|
if len(node.args) == 2 and len(node.keywords) == 0:
|
||||||
arg0, arg1 = map(self.visit, node.args)
|
arg0, arg1 = map(self.visit, node.args)
|
||||||
|
|
||||||
result = self.append(ir.Alloc([arg0], node.type))
|
num_dims = node.type.find()["num_dims"].value
|
||||||
|
if types.is_tuple(arg0.type):
|
||||||
|
lens = [self.append(ir.GetAttr(arg0, i)) for i in range(num_dims)]
|
||||||
|
else:
|
||||||
|
assert num_dims == 1
|
||||||
|
lens = [arg0]
|
||||||
|
|
||||||
|
shape = self._make_array_shape(lens)
|
||||||
|
result, total_len = self._allocate_new_array(node.type.find()["elt"],
|
||||||
|
shape)
|
||||||
|
|
||||||
def body_gen(index):
|
def body_gen(index):
|
||||||
self.append(ir.SetElem(result, index, arg1))
|
self.append(ir.SetElem(result, index, arg1))
|
||||||
return self.append(ir.Arith(ast.Add(loc=None), index,
|
return self.append(
|
||||||
ir.Constant(1, arg0.type)))
|
ir.Arith(ast.Add(loc=None), index,
|
||||||
self._make_loop(ir.Constant(0, self._size_type),
|
ir.Constant(1, self._size_type)))
|
||||||
lambda index: self.append(ir.Compare(ast.Lt(loc=None), index, arg0)),
|
|
||||||
body_gen)
|
self._make_loop(
|
||||||
|
ir.Constant(0, self._size_type), lambda index: self.append(
|
||||||
|
ir.Compare(ast.Lt(loc=None), index, total_len)), body_gen)
|
||||||
return result
|
return result
|
||||||
else:
|
else:
|
||||||
assert False
|
assert False
|
||||||
@ -2247,7 +2259,7 @@ class ARTIQIRGenerator(algorithm.Visitor):
|
|||||||
dim0 = self.append(ir.GetAttr(arg_shape, 0))
|
dim0 = self.append(ir.GetAttr(arg_shape, 0))
|
||||||
dim1 = self.append(ir.GetAttr(arg_shape, 1))
|
dim1 = self.append(ir.GetAttr(arg_shape, 1))
|
||||||
shape = self._make_array_shape([dim1, dim0])
|
shape = self._make_array_shape([dim1, dim0])
|
||||||
result = self._allocate_new_array(node.type.find()["elt"], shape)
|
result, _ = self._allocate_new_array(node.type.find()["elt"], shape)
|
||||||
arg_buffer = self.append(ir.GetAttr(arg, "buffer"))
|
arg_buffer = self.append(ir.GetAttr(arg, "buffer"))
|
||||||
result_buffer = self.append(ir.GetAttr(result, "buffer"))
|
result_buffer = self.append(ir.GetAttr(result, "buffer"))
|
||||||
|
|
||||||
@ -2413,7 +2425,7 @@ class ARTIQIRGenerator(algorithm.Visitor):
|
|||||||
node.arg_exprs)
|
node.arg_exprs)
|
||||||
|
|
||||||
shape = self.append(ir.GetAttr(args[0], "shape"))
|
shape = self.append(ir.GetAttr(args[0], "shape"))
|
||||||
result = self._allocate_new_array(node.type.find()["elt"], shape)
|
result, _ = self._allocate_new_array(node.type.find()["elt"], shape)
|
||||||
|
|
||||||
# TODO: Generate more generically if non-externals are allowed.
|
# TODO: Generate more generically if non-externals are allowed.
|
||||||
name = node.func.type.find().name
|
name = node.func.type.find().name
|
||||||
|
@ -1103,17 +1103,26 @@ class Inferencer(algorithm.Visitor):
|
|||||||
diagnose(valid_forms())
|
diagnose(valid_forms())
|
||||||
elif types.is_builtin(typ, "make_array"):
|
elif types.is_builtin(typ, "make_array"):
|
||||||
valid_forms = lambda: [
|
valid_forms = lambda: [
|
||||||
valid_form("numpy.full(count:int32, value:'a) -> numpy.array(elt='a)")
|
valid_form("numpy.full(count:int32, value:'a) -> array(elt='a, num_dims=1)"),
|
||||||
|
valid_form("numpy.full(shape:(int32,)*'b, value:'a) -> array(elt='a, num_dims='b)"),
|
||||||
]
|
]
|
||||||
|
|
||||||
self._unify(node.type, builtins.TArray(),
|
|
||||||
node.loc, None)
|
|
||||||
|
|
||||||
if len(node.args) == 2 and len(node.keywords) == 0:
|
if len(node.args) == 2 and len(node.keywords) == 0:
|
||||||
arg0, arg1 = node.args
|
arg0, arg1 = node.args
|
||||||
|
|
||||||
|
if types.is_var(arg0.type):
|
||||||
|
return # undetermined yet
|
||||||
|
elif types.is_tuple(arg0.type):
|
||||||
|
num_dims = len(arg0.type.find().elts)
|
||||||
|
self._unify(arg0.type, types.TTuple([builtins.TInt32()] * num_dims),
|
||||||
|
arg0.loc, None)
|
||||||
|
else:
|
||||||
|
num_dims = 1
|
||||||
self._unify(arg0.type, builtins.TInt32(),
|
self._unify(arg0.type, builtins.TInt32(),
|
||||||
arg0.loc, None)
|
arg0.loc, None)
|
||||||
|
|
||||||
|
self._unify(node.type, builtins.TArray(num_dims=num_dims),
|
||||||
|
node.loc, None)
|
||||||
self._unify(arg1.type, node.type.find()["elt"],
|
self._unify(arg1.type, node.type.find()["elt"],
|
||||||
arg1.loc, None)
|
arg1.loc, None)
|
||||||
else:
|
else:
|
||||||
|
@ -244,6 +244,10 @@ class _RPCCalls(EnvExperiment):
|
|||||||
def numpy_full(self):
|
def numpy_full(self):
|
||||||
return numpy.full(10, 20)
|
return numpy.full(10, 20)
|
||||||
|
|
||||||
|
@kernel
|
||||||
|
def numpy_full_matrix(self):
|
||||||
|
return numpy.full((3, 2), 13)
|
||||||
|
|
||||||
@kernel
|
@kernel
|
||||||
def numpy_nan(self):
|
def numpy_nan(self):
|
||||||
return numpy.full(10, numpy.nan)
|
return numpy.full(10, numpy.nan)
|
||||||
@ -277,6 +281,7 @@ class RPCCallsTest(ExperimentCase):
|
|||||||
self.assertEqual(exp.numpy_things(),
|
self.assertEqual(exp.numpy_things(),
|
||||||
(numpy.int32(10), numpy.int64(20), numpy.array([42,])))
|
(numpy.int32(10), numpy.int64(20), numpy.array([42,])))
|
||||||
self.assertTrue((exp.numpy_full() == numpy.full(10, 20)).all())
|
self.assertTrue((exp.numpy_full() == numpy.full(10, 20)).all())
|
||||||
|
self.assertTrue((exp.numpy_full_matrix() == numpy.full((3, 2), 13)).all())
|
||||||
self.assertTrue(numpy.isnan(exp.numpy_nan()).all())
|
self.assertTrue(numpy.isnan(exp.numpy_nan()).all())
|
||||||
exp.builtin()
|
exp.builtin()
|
||||||
exp.async_in_try()
|
exp.async_in_try()
|
||||||
|
Loading…
Reference in New Issue
Block a user