forked from M-Labs/artiq
1
0
Fork 0

Allow using Python types in type annotations

This maps basic Python types (float, str, bool, np.int32, np.int64) as well as
some generics (list, tuple) to ARTIQ's own type instances.

Signed-off-by: Jonathan Coates <jonathan.coates@oxionics.com>
This commit is contained in:
Jonathan Coates 2023-09-22 11:01:40 +01:00 committed by David Nadlinger
parent 586d97c6cb
commit 6eb81494c5
4 changed files with 112 additions and 17 deletions

View File

@ -25,6 +25,8 @@ Highlights:
support legacy installations, but may be removed in a future release. support legacy installations, but may be removed in a future release.
* Added channel names to RTIO errors. * Added channel names to RTIO errors.
* Full Python 3.10 support. * Full Python 3.10 support.
* Python's built-in types (such as `float`, or `List[...]`) can now be used in type annotations on
kernel functions.
* Distributed DMA is now supported, allowing DMA to be run directly on satellites for corresponding * Distributed DMA is now supported, allowing DMA to be run directly on satellites for corresponding
RTIO events, increasing bandwidth in scenarios with heavy satellite usage. RTIO events, increasing bandwidth in scenarios with heavy satellite usage.
* API extensions have been implemented, enabling applets to directly modify datasets. * API extensions have been implemented, enabling applets to directly modify datasets.

View File

@ -5,6 +5,7 @@ the references to the host objects and translates the functions
annotated as ``@kernel`` when they are referenced. annotated as ``@kernel`` when they are referenced.
""" """
import typing
import os, re, linecache, inspect, textwrap, types as pytypes, numpy import os, re, linecache, inspect, textwrap, types as pytypes, numpy
from collections import OrderedDict, defaultdict from collections import OrderedDict, defaultdict
@ -1071,9 +1072,6 @@ class Stitcher:
return function_node return function_node
def _extract_annot(self, function, annot, kind, call_loc, fn_kind): def _extract_annot(self, function, annot, kind, call_loc, fn_kind):
if annot is None:
annot = builtins.TNone()
if isinstance(function, SpecializedFunction): if isinstance(function, SpecializedFunction):
host_function = function.host_function host_function = function.host_function
else: else:
@ -1087,9 +1085,20 @@ class Stitcher:
if isinstance(embedded_function, str): if isinstance(embedded_function, str):
embedded_function = host_function embedded_function = host_function
return self._to_artiq_type(
annot,
function=function,
kind=kind,
eval_in_scope=lambda x: eval(x, embedded_function.__globals__),
call_loc=call_loc,
fn_kind=fn_kind)
def _to_artiq_type(
self, annot, *, function, kind: str, eval_in_scope, call_loc: str, fn_kind: str
) -> types.Type:
if isinstance(annot, str): if isinstance(annot, str):
try: try:
annot = eval(annot, embedded_function.__globals__) annot = eval_in_scope(annot)
except Exception: except Exception:
diag = diagnostic.Diagnostic( diag = diagnostic.Diagnostic(
"error", "error",
@ -1099,17 +1108,67 @@ class Stitcher:
notes=self._call_site_note(call_loc, fn_kind)) notes=self._call_site_note(call_loc, fn_kind))
self.engine.process(diag) self.engine.process(diag)
if not isinstance(annot, types.Type): if isinstance(annot, types.Type):
return annot
# Convert built-in Python types to ARTIQ ones.
if annot is None:
return builtins.TNone()
elif annot is numpy.int64:
return builtins.TInt64()
elif annot is numpy.int32:
return builtins.TInt32()
elif annot is float:
return builtins.TFloat()
elif annot is bool:
return builtins.TBool()
elif annot is str:
return builtins.TStr()
elif annot is bytes:
return builtins.TBytes()
elif annot is bytearray:
return builtins.TByteArray()
# Convert generic Python types to ARTIQ ones.
generic_ty = typing.get_origin(annot)
if generic_ty is not None:
type_args = typing.get_args(annot)
artiq_args = [
self._to_artiq_type(
x,
function=function,
kind=kind,
eval_in_scope=eval_in_scope,
call_loc=call_loc,
fn_kind=fn_kind)
for x in type_args
]
if generic_ty is list and len(artiq_args) == 1:
return builtins.TList(artiq_args[0])
elif generic_ty is tuple:
return types.TTuple(artiq_args)
# Otherwise report an unknown type and just use a fresh tyvar.
if annot is int:
message = (
"type annotation for {kind}, 'int' cannot be used as an ARTIQ type. "
"Use numpy's int32 or int64 instead."
)
ty = builtins.TInt()
else:
message = "type annotation for {kind}, '{annot}', is not an ARTIQ type"
ty = types.TVar()
diag = diagnostic.Diagnostic("error", diag = diagnostic.Diagnostic("error",
"type annotation for {kind}, '{annot}', is not an ARTIQ type", message,
{"kind": kind, "annot": repr(annot)}, {"kind": kind, "annot": repr(annot)},
self._function_loc(function), self._function_loc(function),
notes=self._call_site_note(call_loc, fn_kind)) notes=self._call_site_note(call_loc, fn_kind))
self.engine.process(diag) self.engine.process(diag)
return types.TVar() return ty
else:
return annot
def _quote_syscall(self, function, loc): def _quote_syscall(self, function, loc):
signature = inspect.signature(function) signature = inspect.signature(function)

View File

@ -0,0 +1,34 @@
# RUN: env ARTIQ_DUMP_LLVM=%t %python -m artiq.compiler.testbench.embedding +compile %s
# RUN: OutputCheck %s --file-to-check=%t.ll
from typing import List, Tuple
import numpy as np
from artiq.language.core import *
from artiq.language.types import *
# CHECK-L: i64 @_Z13testbench.foozz(i64 %ARG.x, { i1, i32 } %ARG.y)
@kernel
def foo(x: np.int64, y: np.int32 = 1) -> np.int64:
print(x + y)
return x + y
# CHECK-L: void @_Z13testbench.barzz()
@kernel
def bar(x: np.int32) -> None:
print(x)
# CHECK-L: @_Z21testbench.unpack_listzz({ i1, i64 }* nocapture writeonly sret({ i1, i64 }) %.1, { i64*, i32 }* %ARG.xs)
@kernel
def unpack_list(xs: List[np.int64]) -> Tuple[bool, np.int64]:
print(xs)
return (len(xs) == 1, xs[0])
@kernel
def entrypoint():
print(foo(0, 2))
print(foo(1, 3))
bar(3)
print(unpack_list([1, 2, 3]))

View File

@ -4,14 +4,14 @@
from artiq.experiment import * from artiq.experiment import *
class c(): class c():
# CHECK-L: ${LINE:+2}: error: type annotation for argument 'x', '<class 'float'>', is not an ARTIQ type # CHECK-L: ${LINE:+2}: error: type annotation for argument 'x', '<class 'list'>', is not an ARTIQ type
@kernel @kernel
def hello(self, x: float): def hello(self, x: list):
pass pass
@kernel @kernel
def run(self): def run(self):
self.hello(2) self.hello([])
i = c() i = c()
@kernel @kernel