mirror of https://github.com/m-labs/artiq.git
Allow using Python types in type annotations
This maps basic Python types (float, str, bool, np.int32, np.int64) as well as some generics (list, tuple) to ARTIQ's own type instances. Signed-off-by: Jonathan Coates <jonathan.coates@oxionics.com>
This commit is contained in:
parent
586d97c6cb
commit
6eb81494c5
|
@ -25,6 +25,8 @@ Highlights:
|
|||
support legacy installations, but may be removed in a future release.
|
||||
* Added channel names to RTIO errors.
|
||||
* Full Python 3.10 support.
|
||||
* Python's built-in types (such as `float`, or `List[...]`) can now be used in type annotations on
|
||||
kernel functions.
|
||||
* Distributed DMA is now supported, allowing DMA to be run directly on satellites for corresponding
|
||||
RTIO events, increasing bandwidth in scenarios with heavy satellite usage.
|
||||
* API extensions have been implemented, enabling applets to directly modify datasets.
|
||||
|
|
|
@ -5,6 +5,7 @@ the references to the host objects and translates the functions
|
|||
annotated as ``@kernel`` when they are referenced.
|
||||
"""
|
||||
|
||||
import typing
|
||||
import os, re, linecache, inspect, textwrap, types as pytypes, numpy
|
||||
from collections import OrderedDict, defaultdict
|
||||
|
||||
|
@ -1071,9 +1072,6 @@ class Stitcher:
|
|||
return function_node
|
||||
|
||||
def _extract_annot(self, function, annot, kind, call_loc, fn_kind):
|
||||
if annot is None:
|
||||
annot = builtins.TNone()
|
||||
|
||||
if isinstance(function, SpecializedFunction):
|
||||
host_function = function.host_function
|
||||
else:
|
||||
|
@ -1087,9 +1085,20 @@ class Stitcher:
|
|||
if isinstance(embedded_function, str):
|
||||
embedded_function = host_function
|
||||
|
||||
return self._to_artiq_type(
|
||||
annot,
|
||||
function=function,
|
||||
kind=kind,
|
||||
eval_in_scope=lambda x: eval(x, embedded_function.__globals__),
|
||||
call_loc=call_loc,
|
||||
fn_kind=fn_kind)
|
||||
|
||||
def _to_artiq_type(
|
||||
self, annot, *, function, kind: str, eval_in_scope, call_loc: str, fn_kind: str
|
||||
) -> types.Type:
|
||||
if isinstance(annot, str):
|
||||
try:
|
||||
annot = eval(annot, embedded_function.__globals__)
|
||||
annot = eval_in_scope(annot)
|
||||
except Exception:
|
||||
diag = diagnostic.Diagnostic(
|
||||
"error",
|
||||
|
@ -1099,18 +1108,68 @@ class Stitcher:
|
|||
notes=self._call_site_note(call_loc, fn_kind))
|
||||
self.engine.process(diag)
|
||||
|
||||
if not isinstance(annot, types.Type):
|
||||
diag = diagnostic.Diagnostic("error",
|
||||
"type annotation for {kind}, '{annot}', is not an ARTIQ type",
|
||||
{"kind": kind, "annot": repr(annot)},
|
||||
self._function_loc(function),
|
||||
notes=self._call_site_note(call_loc, fn_kind))
|
||||
self.engine.process(diag)
|
||||
|
||||
return types.TVar()
|
||||
else:
|
||||
if isinstance(annot, types.Type):
|
||||
return annot
|
||||
|
||||
# Convert built-in Python types to ARTIQ ones.
|
||||
if annot is None:
|
||||
return builtins.TNone()
|
||||
elif annot is numpy.int64:
|
||||
return builtins.TInt64()
|
||||
elif annot is numpy.int32:
|
||||
return builtins.TInt32()
|
||||
elif annot is float:
|
||||
return builtins.TFloat()
|
||||
elif annot is bool:
|
||||
return builtins.TBool()
|
||||
elif annot is str:
|
||||
return builtins.TStr()
|
||||
elif annot is bytes:
|
||||
return builtins.TBytes()
|
||||
elif annot is bytearray:
|
||||
return builtins.TByteArray()
|
||||
|
||||
# Convert generic Python types to ARTIQ ones.
|
||||
generic_ty = typing.get_origin(annot)
|
||||
if generic_ty is not None:
|
||||
type_args = typing.get_args(annot)
|
||||
artiq_args = [
|
||||
self._to_artiq_type(
|
||||
x,
|
||||
function=function,
|
||||
kind=kind,
|
||||
eval_in_scope=eval_in_scope,
|
||||
call_loc=call_loc,
|
||||
fn_kind=fn_kind)
|
||||
for x in type_args
|
||||
]
|
||||
|
||||
if generic_ty is list and len(artiq_args) == 1:
|
||||
return builtins.TList(artiq_args[0])
|
||||
elif generic_ty is tuple:
|
||||
return types.TTuple(artiq_args)
|
||||
|
||||
# Otherwise report an unknown type and just use a fresh tyvar.
|
||||
|
||||
if annot is int:
|
||||
message = (
|
||||
"type annotation for {kind}, 'int' cannot be used as an ARTIQ type. "
|
||||
"Use numpy's int32 or int64 instead."
|
||||
)
|
||||
ty = builtins.TInt()
|
||||
else:
|
||||
message = "type annotation for {kind}, '{annot}', is not an ARTIQ type"
|
||||
ty = types.TVar()
|
||||
|
||||
diag = diagnostic.Diagnostic("error",
|
||||
message,
|
||||
{"kind": kind, "annot": repr(annot)},
|
||||
self._function_loc(function),
|
||||
notes=self._call_site_note(call_loc, fn_kind))
|
||||
self.engine.process(diag)
|
||||
|
||||
return ty
|
||||
|
||||
def _quote_syscall(self, function, loc):
|
||||
signature = inspect.signature(function)
|
||||
|
||||
|
|
|
@ -0,0 +1,34 @@
|
|||
# RUN: env ARTIQ_DUMP_LLVM=%t %python -m artiq.compiler.testbench.embedding +compile %s
|
||||
# RUN: OutputCheck %s --file-to-check=%t.ll
|
||||
|
||||
from typing import List, Tuple
|
||||
|
||||
import numpy as np
|
||||
|
||||
from artiq.language.core import *
|
||||
from artiq.language.types import *
|
||||
|
||||
# CHECK-L: i64 @_Z13testbench.foozz(i64 %ARG.x, { i1, i32 } %ARG.y)
|
||||
|
||||
@kernel
|
||||
def foo(x: np.int64, y: np.int32 = 1) -> np.int64:
|
||||
print(x + y)
|
||||
return x + y
|
||||
|
||||
# CHECK-L: void @_Z13testbench.barzz()
|
||||
@kernel
|
||||
def bar(x: np.int32) -> None:
|
||||
print(x)
|
||||
|
||||
# CHECK-L: @_Z21testbench.unpack_listzz({ i1, i64 }* nocapture writeonly sret({ i1, i64 }) %.1, { i64*, i32 }* %ARG.xs)
|
||||
@kernel
|
||||
def unpack_list(xs: List[np.int64]) -> Tuple[bool, np.int64]:
|
||||
print(xs)
|
||||
return (len(xs) == 1, xs[0])
|
||||
|
||||
@kernel
|
||||
def entrypoint():
|
||||
print(foo(0, 2))
|
||||
print(foo(1, 3))
|
||||
bar(3)
|
||||
print(unpack_list([1, 2, 3]))
|
|
@ -4,14 +4,14 @@
|
|||
from artiq.experiment import *
|
||||
|
||||
class c():
|
||||
# CHECK-L: ${LINE:+2}: error: type annotation for argument 'x', '<class 'float'>', is not an ARTIQ type
|
||||
# CHECK-L: ${LINE:+2}: error: type annotation for argument 'x', '<class 'list'>', is not an ARTIQ type
|
||||
@kernel
|
||||
def hello(self, x: float):
|
||||
def hello(self, x: list):
|
||||
pass
|
||||
|
||||
@kernel
|
||||
def run(self):
|
||||
self.hello(2)
|
||||
self.hello([])
|
||||
|
||||
i = c()
|
||||
@kernel
|
||||
|
|
Loading…
Reference in New Issue