diff --git a/artiq/compiler/builtins.py b/artiq/compiler/builtins.py index a64129dcc..77e102d3b 100644 --- a/artiq/compiler/builtins.py +++ b/artiq/compiler/builtins.py @@ -71,6 +71,10 @@ class TBytes(types.TMono): def __init__(self): super().__init__("bytes") +class TByteArray(types.TMono): + def __init__(self): + super().__init__("bytearray") + class TList(types.TMono): def __init__(self, elt=None): if elt is None: @@ -144,6 +148,9 @@ def fn_str(): def fn_bytes(): return types.TConstructor(TBytes()) +def fn_bytearray(): + return types.TConstructor(TByteArray()) + def fn_list(): return types.TConstructor(TList()) @@ -246,6 +253,9 @@ def is_str(typ): def is_bytes(typ): return types.is_mono(typ, "bytes") +def is_bytearray(typ): + return types.is_mono(typ, "bytearray") + def is_numeric(typ): typ = typ.find() return isinstance(typ, types.TMono) and \ @@ -267,7 +277,7 @@ def is_listish(typ, elt=None): if is_list(typ, elt) or is_array(typ, elt): return True elif elt is None: - return is_str(typ) or is_bytes(typ) + return is_str(typ) or is_bytes(typ) or is_bytearray(typ) else: return False @@ -288,7 +298,7 @@ def is_iterable(typ): return is_listish(typ) or is_range(typ) def get_iterable_elt(typ): - if is_str(typ) or is_bytes(typ): + if is_str(typ) or is_bytes(typ) or is_bytearray(typ): return TInt(types.TValue(8)) elif is_iterable(typ): return typ.find()["elt"].find() diff --git a/artiq/compiler/embedding.py b/artiq/compiler/embedding.py index 052a5d737..6cee76bf9 100644 --- a/artiq/compiler/embedding.py +++ b/artiq/compiler/embedding.py @@ -203,6 +203,13 @@ class ASTSynthesizer: elif isinstance(value, bytes): return asttyped.StrT(s=value, ctx=None, type=builtins.TBytes(), loc=self._add(repr(value))) + elif isinstance(value, bytearray): + quote_loc = self._add('`') + repr_loc = self._add(repr(value)) + unquote_loc = self._add('`') + loc = quote_loc.join(unquote_loc) + + return asttyped.QuoteT(value=value, type=builtins.TByteArray(), loc=loc) elif isinstance(value, list): begin_loc = self._add("[") elts = [] diff --git a/artiq/compiler/prelude.py b/artiq/compiler/prelude.py index f451b8a62..24a7bd1fa 100644 --- a/artiq/compiler/prelude.py +++ b/artiq/compiler/prelude.py @@ -13,6 +13,7 @@ def globals(): "float": builtins.fn_float(), "str": builtins.fn_str(), "bytes": builtins.fn_bytes(), + "bytearray": builtins.fn_bytearray(), "list": builtins.fn_list(), "array": builtins.fn_array(), "range": builtins.fn_range(), diff --git a/artiq/compiler/transforms/artiq_ir_generator.py b/artiq/compiler/transforms/artiq_ir_generator.py index 0d1d9c76d..1d2f6cbba 100644 --- a/artiq/compiler/transforms/artiq_ir_generator.py +++ b/artiq/compiler/transforms/artiq_ir_generator.py @@ -1620,7 +1620,8 @@ class ARTIQIRGenerator(algorithm.Visitor): return self.append(ir.Coerce(arg, node.type)) else: assert False - elif types.is_builtin(typ, "list") or types.is_builtin(typ, "array"): + elif (types.is_builtin(typ, "list") or types.is_builtin(typ, "array") or + types.is_builtin(typ, "bytearray")): if len(node.args) == 0 and len(node.keywords) == 0: length = ir.Constant(0, builtins.TInt32()) return self.append(ir.Alloc([length], node.type)) diff --git a/artiq/compiler/transforms/inferencer.py b/artiq/compiler/transforms/inferencer.py index d6a23a36e..0003f214e 100644 --- a/artiq/compiler/transforms/inferencer.py +++ b/artiq/compiler/transforms/inferencer.py @@ -180,8 +180,8 @@ class Inferencer(algorithm.Visitor): self.engine.process(diag) def _unify_iterable(self, element, collection): - if builtins.is_bytes(collection.type): - self._unify(element.type, builtins.TInt(), + if builtins.is_bytes(collection.type) or builtins.is_bytearray(collection.type): + self._unify(element.type, builtins.get_iterable_elt(collection.type), element.loc, None) elif builtins.is_iterable(collection.type): rhs_type = collection.type.find() diff --git a/artiq/compiler/transforms/llvm_ir_generator.py b/artiq/compiler/transforms/llvm_ir_generator.py index 84eddfc88..cbcd45e85 100644 --- a/artiq/compiler/transforms/llvm_ir_generator.py +++ b/artiq/compiler/transforms/llvm_ir_generator.py @@ -1436,8 +1436,8 @@ class LLVMIRGenerator: elif builtins.is_float(typ): assert isinstance(value, float), fail_msg return ll.Constant(llty, value) - elif builtins.is_str(typ) or builtins.is_bytes(typ): - assert isinstance(value, (str, bytes)), fail_msg + elif builtins.is_str(typ) or builtins.is_bytes(typ) or builtins.is_bytearray(typ): + assert isinstance(value, (str, bytes, bytearray)), fail_msg if isinstance(value, str): as_bytes = value.encode("utf-8") else: diff --git a/artiq/test/lit/inferencer/unify.py b/artiq/test/lit/inferencer/unify.py index 5ccbc6b70..e4fea57fe 100644 --- a/artiq/test/lit/inferencer/unify.py +++ b/artiq/test/lit/inferencer/unify.py @@ -60,6 +60,9 @@ k = "x" ka = b"x" # CHECK-L: ka:bytes +kb = bytearray(b"x") +# CHECK-L: kb:bytearray + l = array([1]) # CHECK-L: l:numpy.array(elt=numpy.int?) diff --git a/artiq/test/lit/integration/subscript.py b/artiq/test/lit/integration/subscript.py index f0398be4c..db50809e6 100644 --- a/artiq/test/lit/integration/subscript.py +++ b/artiq/test/lit/integration/subscript.py @@ -21,3 +21,7 @@ assert lst == [1, 0, 2, 0, 3] byt = b"abc" assert byt[0] == 97 assert byt[1] == 98 + +barr = bytearray(b"abc") +assert barr[0] == 97 +assert barr[1] == 98