From b00ba5ece11cf1c04306621d24523ece794a0c41 Mon Sep 17 00:00:00 2001 From: David Nadlinger Date: Sun, 9 Aug 2020 02:44:54 +0100 Subject: [PATCH] =?UTF-8?q?compiler:=20Support=20explicit=20array(?= =?UTF-8?q?=E2=80=A6,=20dtype=3D=E2=80=A6)=20syntax?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../compiler/transforms/artiq_ir_generator.py | 5 +-- artiq/compiler/transforms/inferencer.py | 33 +++++++++++++++++-- 2 files changed, 34 insertions(+), 4 deletions(-) diff --git a/artiq/compiler/transforms/artiq_ir_generator.py b/artiq/compiler/transforms/artiq_ir_generator.py index 240dd4957..9d5201800 100644 --- a/artiq/compiler/transforms/artiq_ir_generator.py +++ b/artiq/compiler/transforms/artiq_ir_generator.py @@ -2082,7 +2082,7 @@ class ARTIQIRGenerator(algorithm.Visitor): else: assert False elif types.is_builtin(typ, "array"): - if len(node.args) == 1 and len(node.keywords) == 0: + if len(node.args) == 1 and len(node.keywords) in (0, 1): result_type = node.type.find() arg = self.visit(node.args[0]) @@ -2111,7 +2111,8 @@ class ARTIQIRGenerator(algorithm.Visitor): def assign_elems(outer_indices, indexed_arg): if len(outer_indices) == num_dims: dest_idx = self._get_array_offset(lengths, outer_indices) - self.append(ir.SetElem(buffer, dest_idx, indexed_arg)) + coerced = self.append(ir.Coerce(indexed_arg, result_elt)) + self.append(ir.SetElem(buffer, dest_idx, coerced)) else: this_level_len = self.iterable_len(indexed_arg) dim_idx = len(outer_indices) diff --git a/artiq/compiler/transforms/inferencer.py b/artiq/compiler/transforms/inferencer.py index ef982c3ce..8ac5fab15 100644 --- a/artiq/compiler/transforms/inferencer.py +++ b/artiq/compiler/transforms/inferencer.py @@ -832,10 +832,19 @@ class Inferencer(algorithm.Visitor): self.engine.process(diag) elif types.is_builtin(typ, "array"): valid_forms = lambda: [ - valid_form("array(x:'a) -> array(elt='b) where 'a is iterable") + valid_form("array(x:'a) -> array(elt='b) where 'a is iterable"), + valid_form("array(x:'a, dtype:'b) -> array(elt='b) where 'a is iterable") ] - if len(node.args) == 1 and len(node.keywords) == 0: + explicit_dtype = None + keywords_acceptable = False + if len(node.keywords) == 0: + keywords_acceptable = True + elif len(node.keywords) == 1: + if node.keywords[0].arg == "dtype": + keywords_acceptable = True + explicit_dtype = node.keywords[0].value + if len(node.args) == 1 and keywords_acceptable: arg, = node.args # In the absence of any other information (there currently isn't a way @@ -858,6 +867,26 @@ class Inferencer(algorithm.Visitor): num_dims += 1 elt = builtins.get_iterable_elt(elt) + if explicit_dtype is not None: + # TODO: Factor out type detection; support quoted type constructors + # (TList(TInt32), …)? + typ = explicit_dtype.type + if types.is_builtin(typ, "int32"): + elt = builtins.TInt32() + elif types.is_builtin(typ, "int64"): + elt = builtins.TInt64() + elif types.is_constructor(typ): + elt = typ.find().instance + else: + diag = diagnostic.Diagnostic( + "error", + "dtype argument of {builtin}() must be a valid constructor", + {"builtin": typ.find().name}, + node.func.loc, + notes=[note]) + self.engine.process(diag) + return + if num_dims == 0: note = diagnostic.Diagnostic( "note", "this expression has type {type}",