diff --git a/RELEASE_NOTES.rst b/RELEASE_NOTES.rst index 98f111c2d..ea94f5651 100644 --- a/RELEASE_NOTES.rst +++ b/RELEASE_NOTES.rst @@ -25,6 +25,8 @@ Highlights: support legacy installations, but may be removed in a future release. * Added channel names to RTIO errors. * Full Python 3.10 support. +* Python's built-in types (such as `float`, or `List[...]`) can now be used in type annotations on + kernel functions. * Distributed DMA is now supported, allowing DMA to be run directly on satellites for corresponding RTIO events, increasing bandwidth in scenarios with heavy satellite usage. * API extensions have been implemented, enabling applets to directly modify datasets. diff --git a/artiq/compiler/embedding.py b/artiq/compiler/embedding.py index 502b364b9..040fc80ee 100644 --- a/artiq/compiler/embedding.py +++ b/artiq/compiler/embedding.py @@ -5,6 +5,7 @@ the references to the host objects and translates the functions annotated as ``@kernel`` when they are referenced. """ +import typing import os, re, linecache, inspect, textwrap, types as pytypes, numpy from collections import OrderedDict, defaultdict @@ -1071,9 +1072,6 @@ class Stitcher: return function_node def _extract_annot(self, function, annot, kind, call_loc, fn_kind): - if annot is None: - annot = builtins.TNone() - if isinstance(function, SpecializedFunction): host_function = function.host_function else: @@ -1087,9 +1085,20 @@ class Stitcher: if isinstance(embedded_function, str): embedded_function = host_function + return self._to_artiq_type( + annot, + function=function, + kind=kind, + eval_in_scope=lambda x: eval(x, embedded_function.__globals__), + call_loc=call_loc, + fn_kind=fn_kind) + + def _to_artiq_type( + self, annot, *, function, kind: str, eval_in_scope, call_loc: str, fn_kind: str + ) -> types.Type: if isinstance(annot, str): try: - annot = eval(annot, embedded_function.__globals__) + annot = eval_in_scope(annot) except Exception: diag = diagnostic.Diagnostic( "error", @@ -1099,18 +1108,68 @@ class Stitcher: notes=self._call_site_note(call_loc, fn_kind)) self.engine.process(diag) - if not isinstance(annot, types.Type): - diag = diagnostic.Diagnostic("error", - "type annotation for {kind}, '{annot}', is not an ARTIQ type", - {"kind": kind, "annot": repr(annot)}, - self._function_loc(function), - notes=self._call_site_note(call_loc, fn_kind)) - self.engine.process(diag) - - return types.TVar() - else: + if isinstance(annot, types.Type): return annot + # Convert built-in Python types to ARTIQ ones. + if annot is None: + return builtins.TNone() + elif annot is numpy.int64: + return builtins.TInt64() + elif annot is numpy.int32: + return builtins.TInt32() + elif annot is float: + return builtins.TFloat() + elif annot is bool: + return builtins.TBool() + elif annot is str: + return builtins.TStr() + elif annot is bytes: + return builtins.TBytes() + elif annot is bytearray: + return builtins.TByteArray() + + # Convert generic Python types to ARTIQ ones. + generic_ty = typing.get_origin(annot) + if generic_ty is not None: + type_args = typing.get_args(annot) + artiq_args = [ + self._to_artiq_type( + x, + function=function, + kind=kind, + eval_in_scope=eval_in_scope, + call_loc=call_loc, + fn_kind=fn_kind) + for x in type_args + ] + + if generic_ty is list and len(artiq_args) == 1: + return builtins.TList(artiq_args[0]) + elif generic_ty is tuple: + return types.TTuple(artiq_args) + + # Otherwise report an unknown type and just use a fresh tyvar. + + if annot is int: + message = ( + "type annotation for {kind}, 'int' cannot be used as an ARTIQ type. " + "Use numpy's int32 or int64 instead." + ) + ty = builtins.TInt() + else: + message = "type annotation for {kind}, '{annot}', is not an ARTIQ type" + ty = types.TVar() + + diag = diagnostic.Diagnostic("error", + message, + {"kind": kind, "annot": repr(annot)}, + self._function_loc(function), + notes=self._call_site_note(call_loc, fn_kind)) + self.engine.process(diag) + + return ty + def _quote_syscall(self, function, loc): signature = inspect.signature(function) diff --git a/artiq/test/lit/embedding/annotation_py.py b/artiq/test/lit/embedding/annotation_py.py new file mode 100644 index 000000000..c790b6914 --- /dev/null +++ b/artiq/test/lit/embedding/annotation_py.py @@ -0,0 +1,34 @@ +# RUN: env ARTIQ_DUMP_LLVM=%t %python -m artiq.compiler.testbench.embedding +compile %s +# RUN: OutputCheck %s --file-to-check=%t.ll + +from typing import List, Tuple + +import numpy as np + +from artiq.language.core import * +from artiq.language.types import * + +# CHECK-L: i64 @_Z13testbench.foozz(i64 %ARG.x, { i1, i32 } %ARG.y) + +@kernel +def foo(x: np.int64, y: np.int32 = 1) -> np.int64: + print(x + y) + return x + y + +# CHECK-L: void @_Z13testbench.barzz() +@kernel +def bar(x: np.int32) -> None: + print(x) + +# CHECK-L: @_Z21testbench.unpack_listzz({ i1, i64 }* nocapture writeonly sret({ i1, i64 }) %.1, { i64*, i32 }* %ARG.xs) +@kernel +def unpack_list(xs: List[np.int64]) -> Tuple[bool, np.int64]: + print(xs) + return (len(xs) == 1, xs[0]) + +@kernel +def entrypoint(): + print(foo(0, 2)) + print(foo(1, 3)) + bar(3) + print(unpack_list([1, 2, 3])) diff --git a/artiq/test/lit/embedding/error_specialized_annot.py b/artiq/test/lit/embedding/error_specialized_annot.py index 2f5955043..5d901d0d4 100644 --- a/artiq/test/lit/embedding/error_specialized_annot.py +++ b/artiq/test/lit/embedding/error_specialized_annot.py @@ -4,14 +4,14 @@ from artiq.experiment import * class c(): -# CHECK-L: ${LINE:+2}: error: type annotation for argument 'x', '', is not an ARTIQ type +# CHECK-L: ${LINE:+2}: error: type annotation for argument 'x', '', is not an ARTIQ type @kernel - def hello(self, x: float): + def hello(self, x: list): pass @kernel def run(self): - self.hello(2) + self.hello([]) i = c() @kernel