diff --git a/artiq/compiler/transforms/artiq_ir_generator.py b/artiq/compiler/transforms/artiq_ir_generator.py index ac06f8717..5c5f00139 100644 --- a/artiq/compiler/transforms/artiq_ir_generator.py +++ b/artiq/compiler/transforms/artiq_ir_generator.py @@ -1411,7 +1411,7 @@ class ARTIQIRGenerator(algorithm.Visitor): operand = self.visit(node.operand) if builtins.is_array(operand.type): shape = self.append(ir.GetAttr(operand, "shape")) - result = self._allocate_new_array(node.type.find()["elt"], shape) + result, _ = self._allocate_new_array(node.type.find()["elt"], shape) func = self._get_array_unaryop("USub", make_sub, node.type, operand.type) self._invoke_arrayop(func, [result, operand]) return result @@ -1431,7 +1431,7 @@ class ARTIQIRGenerator(algorithm.Visitor): if builtins.is_array(node.type): result_elt = node.type.find()["elt"] shape = self.append(ir.GetAttr(value, "shape")) - result = self._allocate_new_array(result_elt, shape) + result, _ = self._allocate_new_array(result_elt, shape) func = self._get_array_unaryop( "Coerce", lambda v: self.append(ir.Coerce(v, result_elt)), node.type, value.type) @@ -1455,7 +1455,7 @@ class ARTIQIRGenerator(algorithm.Visitor): total_length = self._get_total_array_len(shape) buffer = self.append(ir.Alloc([total_length], types._TPointer(elt=elt))) result_type = builtins.TArray(elt, types.TValue(len(shape.type.elts))) - return self.append(ir.Alloc([buffer, shape], result_type)) + return self.append(ir.Alloc([buffer, shape], result_type)), total_length def _make_array_binop(self, name, result_type, lhs_type, rhs_type, body_gen): try: @@ -1704,7 +1704,7 @@ class ARTIQIRGenerator(algorithm.Visitor): elt = final_type["elt"] result_dims = final_type["num_dims"].value - result = self._allocate_new_array(elt, result_shape) + result, _ = self._allocate_new_array(elt, result_shape) func = self._get_matmult(result.type, left.type, right.type) self._invoke_arrayop(func, [result, lhs, rhs]) @@ -1745,7 +1745,7 @@ class ARTIQIRGenerator(algorithm.Visitor): ir.Constant("operands could not be broadcast together", builtins.TStr()))) - result = self._allocate_new_array(node.type.find()["elt"], shape) + result, _ = self._allocate_new_array(node.type.find()["elt"], shape) func = self._get_array_binop(node.op, node.type, lhs.type, rhs.type) self._invoke_arrayop(func, [result, lhs, rhs]) return result @@ -2223,14 +2223,26 @@ class ARTIQIRGenerator(algorithm.Visitor): if len(node.args) == 2 and len(node.keywords) == 0: arg0, arg1 = map(self.visit, node.args) - result = self.append(ir.Alloc([arg0], node.type)) + num_dims = node.type.find()["num_dims"].value + if types.is_tuple(arg0.type): + lens = [self.append(ir.GetAttr(arg0, i)) for i in range(num_dims)] + else: + assert num_dims == 1 + lens = [arg0] + + shape = self._make_array_shape(lens) + result, total_len = self._allocate_new_array(node.type.find()["elt"], + shape) + def body_gen(index): self.append(ir.SetElem(result, index, arg1)) - return self.append(ir.Arith(ast.Add(loc=None), index, - ir.Constant(1, arg0.type))) - self._make_loop(ir.Constant(0, self._size_type), - lambda index: self.append(ir.Compare(ast.Lt(loc=None), index, arg0)), - body_gen) + return self.append( + ir.Arith(ast.Add(loc=None), index, + ir.Constant(1, self._size_type))) + + self._make_loop( + ir.Constant(0, self._size_type), lambda index: self.append( + ir.Compare(ast.Lt(loc=None), index, total_len)), body_gen) return result else: assert False @@ -2247,7 +2259,7 @@ class ARTIQIRGenerator(algorithm.Visitor): dim0 = self.append(ir.GetAttr(arg_shape, 0)) dim1 = self.append(ir.GetAttr(arg_shape, 1)) shape = self._make_array_shape([dim1, dim0]) - result = self._allocate_new_array(node.type.find()["elt"], shape) + result, _ = self._allocate_new_array(node.type.find()["elt"], shape) arg_buffer = self.append(ir.GetAttr(arg, "buffer")) result_buffer = self.append(ir.GetAttr(result, "buffer")) @@ -2413,7 +2425,7 @@ class ARTIQIRGenerator(algorithm.Visitor): node.arg_exprs) shape = self.append(ir.GetAttr(args[0], "shape")) - result = self._allocate_new_array(node.type.find()["elt"], shape) + result, _ = self._allocate_new_array(node.type.find()["elt"], shape) # TODO: Generate more generically if non-externals are allowed. name = node.func.type.find().name diff --git a/artiq/compiler/transforms/inferencer.py b/artiq/compiler/transforms/inferencer.py index 81a4c21ec..e5cbc561b 100644 --- a/artiq/compiler/transforms/inferencer.py +++ b/artiq/compiler/transforms/inferencer.py @@ -1103,17 +1103,26 @@ class Inferencer(algorithm.Visitor): diagnose(valid_forms()) elif types.is_builtin(typ, "make_array"): valid_forms = lambda: [ - valid_form("numpy.full(count:int32, value:'a) -> numpy.array(elt='a)") + valid_form("numpy.full(count:int32, value:'a) -> array(elt='a, num_dims=1)"), + valid_form("numpy.full(shape:(int32,)*'b, value:'a) -> array(elt='a, num_dims='b)"), ] - self._unify(node.type, builtins.TArray(), - node.loc, None) - if len(node.args) == 2 and len(node.keywords) == 0: arg0, arg1 = node.args - self._unify(arg0.type, builtins.TInt32(), - arg0.loc, None) + if types.is_var(arg0.type): + return # undetermined yet + elif types.is_tuple(arg0.type): + num_dims = len(arg0.type.find().elts) + self._unify(arg0.type, types.TTuple([builtins.TInt32()] * num_dims), + arg0.loc, None) + else: + num_dims = 1 + self._unify(arg0.type, builtins.TInt32(), + arg0.loc, None) + + self._unify(node.type, builtins.TArray(num_dims=num_dims), + node.loc, None) self._unify(arg1.type, node.type.find()["elt"], arg1.loc, None) else: diff --git a/artiq/test/coredevice/test_embedding.py b/artiq/test/coredevice/test_embedding.py index fbd47a1a4..a2c6cf823 100644 --- a/artiq/test/coredevice/test_embedding.py +++ b/artiq/test/coredevice/test_embedding.py @@ -244,6 +244,10 @@ class _RPCCalls(EnvExperiment): def numpy_full(self): return numpy.full(10, 20) + @kernel + def numpy_full_matrix(self): + return numpy.full((3, 2), 13) + @kernel def numpy_nan(self): return numpy.full(10, numpy.nan) @@ -277,6 +281,7 @@ class RPCCallsTest(ExperimentCase): self.assertEqual(exp.numpy_things(), (numpy.int32(10), numpy.int64(20), numpy.array([42,]))) self.assertTrue((exp.numpy_full() == numpy.full(10, 20)).all()) + self.assertTrue((exp.numpy_full_matrix() == numpy.full((3, 2), 13)).all()) self.assertTrue(numpy.isnan(exp.numpy_nan()).all()) exp.builtin() exp.async_in_try()