compiler: Support explicit array(…, dtype=…) syntax

pull/1508/head
David Nadlinger 2020-08-09 02:44:54 +01:00
parent ad34df3de1
commit b00ba5ece1
2 changed files with 34 additions and 4 deletions

View File

@ -2082,7 +2082,7 @@ class ARTIQIRGenerator(algorithm.Visitor):
else:
assert False
elif types.is_builtin(typ, "array"):
if len(node.args) == 1 and len(node.keywords) == 0:
if len(node.args) == 1 and len(node.keywords) in (0, 1):
result_type = node.type.find()
arg = self.visit(node.args[0])
@ -2111,7 +2111,8 @@ class ARTIQIRGenerator(algorithm.Visitor):
def assign_elems(outer_indices, indexed_arg):
if len(outer_indices) == num_dims:
dest_idx = self._get_array_offset(lengths, outer_indices)
self.append(ir.SetElem(buffer, dest_idx, indexed_arg))
coerced = self.append(ir.Coerce(indexed_arg, result_elt))
self.append(ir.SetElem(buffer, dest_idx, coerced))
else:
this_level_len = self.iterable_len(indexed_arg)
dim_idx = len(outer_indices)

View File

@ -832,10 +832,19 @@ class Inferencer(algorithm.Visitor):
self.engine.process(diag)
elif types.is_builtin(typ, "array"):
valid_forms = lambda: [
valid_form("array(x:'a) -> array(elt='b) where 'a is iterable")
valid_form("array(x:'a) -> array(elt='b) where 'a is iterable"),
valid_form("array(x:'a, dtype:'b) -> array(elt='b) where 'a is iterable")
]
if len(node.args) == 1 and len(node.keywords) == 0:
explicit_dtype = None
keywords_acceptable = False
if len(node.keywords) == 0:
keywords_acceptable = True
elif len(node.keywords) == 1:
if node.keywords[0].arg == "dtype":
keywords_acceptable = True
explicit_dtype = node.keywords[0].value
if len(node.args) == 1 and keywords_acceptable:
arg, = node.args
# In the absence of any other information (there currently isn't a way
@ -858,6 +867,26 @@ class Inferencer(algorithm.Visitor):
num_dims += 1
elt = builtins.get_iterable_elt(elt)
if explicit_dtype is not None:
# TODO: Factor out type detection; support quoted type constructors
# (TList(TInt32), …)?
typ = explicit_dtype.type
if types.is_builtin(typ, "int32"):
elt = builtins.TInt32()
elif types.is_builtin(typ, "int64"):
elt = builtins.TInt64()
elif types.is_constructor(typ):
elt = typ.find().instance
else:
diag = diagnostic.Diagnostic(
"error",
"dtype argument of {builtin}() must be a valid constructor",
{"builtin": typ.find().name},
node.func.loc,
notes=[note])
self.engine.process(diag)
return
if num_dims == 0:
note = diagnostic.Diagnostic(
"note", "this expression has type {type}",