mirror of
https://github.com/m-labs/artiq.git
synced 2025-02-05 23:30:20 +08:00
Allow using Python types in type annotations
This maps basic Python types (float, str, bool, np.int32, np.int64) as well as some generics (list, tuple) to ARTIQ's own type instances. Signed-off-by: Jonathan Coates <jonathan.coates@oxionics.com>
This commit is contained in:
parent
586d97c6cb
commit
6eb81494c5
@ -25,6 +25,8 @@ Highlights:
|
||||
support legacy installations, but may be removed in a future release.
|
||||
* Added channel names to RTIO errors.
|
||||
* Full Python 3.10 support.
|
||||
* Python's built-in types (such as `float`, or `List[...]`) can now be used in type annotations on
|
||||
kernel functions.
|
||||
* Distributed DMA is now supported, allowing DMA to be run directly on satellites for corresponding
|
||||
RTIO events, increasing bandwidth in scenarios with heavy satellite usage.
|
||||
* API extensions have been implemented, enabling applets to directly modify datasets.
|
||||
|
@ -5,6 +5,7 @@ the references to the host objects and translates the functions
|
||||
annotated as ``@kernel`` when they are referenced.
|
||||
"""
|
||||
|
||||
import typing
|
||||
import os, re, linecache, inspect, textwrap, types as pytypes, numpy
|
||||
from collections import OrderedDict, defaultdict
|
||||
|
||||
@ -1071,9 +1072,6 @@ class Stitcher:
|
||||
return function_node
|
||||
|
||||
def _extract_annot(self, function, annot, kind, call_loc, fn_kind):
|
||||
if annot is None:
|
||||
annot = builtins.TNone()
|
||||
|
||||
if isinstance(function, SpecializedFunction):
|
||||
host_function = function.host_function
|
||||
else:
|
||||
@ -1087,9 +1085,20 @@ class Stitcher:
|
||||
if isinstance(embedded_function, str):
|
||||
embedded_function = host_function
|
||||
|
||||
return self._to_artiq_type(
|
||||
annot,
|
||||
function=function,
|
||||
kind=kind,
|
||||
eval_in_scope=lambda x: eval(x, embedded_function.__globals__),
|
||||
call_loc=call_loc,
|
||||
fn_kind=fn_kind)
|
||||
|
||||
def _to_artiq_type(
|
||||
self, annot, *, function, kind: str, eval_in_scope, call_loc: str, fn_kind: str
|
||||
) -> types.Type:
|
||||
if isinstance(annot, str):
|
||||
try:
|
||||
annot = eval(annot, embedded_function.__globals__)
|
||||
annot = eval_in_scope(annot)
|
||||
except Exception:
|
||||
diag = diagnostic.Diagnostic(
|
||||
"error",
|
||||
@ -1099,18 +1108,68 @@ class Stitcher:
|
||||
notes=self._call_site_note(call_loc, fn_kind))
|
||||
self.engine.process(diag)
|
||||
|
||||
if not isinstance(annot, types.Type):
|
||||
diag = diagnostic.Diagnostic("error",
|
||||
"type annotation for {kind}, '{annot}', is not an ARTIQ type",
|
||||
{"kind": kind, "annot": repr(annot)},
|
||||
self._function_loc(function),
|
||||
notes=self._call_site_note(call_loc, fn_kind))
|
||||
self.engine.process(diag)
|
||||
|
||||
return types.TVar()
|
||||
else:
|
||||
if isinstance(annot, types.Type):
|
||||
return annot
|
||||
|
||||
# Convert built-in Python types to ARTIQ ones.
|
||||
if annot is None:
|
||||
return builtins.TNone()
|
||||
elif annot is numpy.int64:
|
||||
return builtins.TInt64()
|
||||
elif annot is numpy.int32:
|
||||
return builtins.TInt32()
|
||||
elif annot is float:
|
||||
return builtins.TFloat()
|
||||
elif annot is bool:
|
||||
return builtins.TBool()
|
||||
elif annot is str:
|
||||
return builtins.TStr()
|
||||
elif annot is bytes:
|
||||
return builtins.TBytes()
|
||||
elif annot is bytearray:
|
||||
return builtins.TByteArray()
|
||||
|
||||
# Convert generic Python types to ARTIQ ones.
|
||||
generic_ty = typing.get_origin(annot)
|
||||
if generic_ty is not None:
|
||||
type_args = typing.get_args(annot)
|
||||
artiq_args = [
|
||||
self._to_artiq_type(
|
||||
x,
|
||||
function=function,
|
||||
kind=kind,
|
||||
eval_in_scope=eval_in_scope,
|
||||
call_loc=call_loc,
|
||||
fn_kind=fn_kind)
|
||||
for x in type_args
|
||||
]
|
||||
|
||||
if generic_ty is list and len(artiq_args) == 1:
|
||||
return builtins.TList(artiq_args[0])
|
||||
elif generic_ty is tuple:
|
||||
return types.TTuple(artiq_args)
|
||||
|
||||
# Otherwise report an unknown type and just use a fresh tyvar.
|
||||
|
||||
if annot is int:
|
||||
message = (
|
||||
"type annotation for {kind}, 'int' cannot be used as an ARTIQ type. "
|
||||
"Use numpy's int32 or int64 instead."
|
||||
)
|
||||
ty = builtins.TInt()
|
||||
else:
|
||||
message = "type annotation for {kind}, '{annot}', is not an ARTIQ type"
|
||||
ty = types.TVar()
|
||||
|
||||
diag = diagnostic.Diagnostic("error",
|
||||
message,
|
||||
{"kind": kind, "annot": repr(annot)},
|
||||
self._function_loc(function),
|
||||
notes=self._call_site_note(call_loc, fn_kind))
|
||||
self.engine.process(diag)
|
||||
|
||||
return ty
|
||||
|
||||
def _quote_syscall(self, function, loc):
|
||||
signature = inspect.signature(function)
|
||||
|
||||
|
34
artiq/test/lit/embedding/annotation_py.py
Normal file
34
artiq/test/lit/embedding/annotation_py.py
Normal file
@ -0,0 +1,34 @@
|
||||
# RUN: env ARTIQ_DUMP_LLVM=%t %python -m artiq.compiler.testbench.embedding +compile %s
|
||||
# RUN: OutputCheck %s --file-to-check=%t.ll
|
||||
|
||||
from typing import List, Tuple
|
||||
|
||||
import numpy as np
|
||||
|
||||
from artiq.language.core import *
|
||||
from artiq.language.types import *
|
||||
|
||||
# CHECK-L: i64 @_Z13testbench.foozz(i64 %ARG.x, { i1, i32 } %ARG.y)
|
||||
|
||||
@kernel
|
||||
def foo(x: np.int64, y: np.int32 = 1) -> np.int64:
|
||||
print(x + y)
|
||||
return x + y
|
||||
|
||||
# CHECK-L: void @_Z13testbench.barzz()
|
||||
@kernel
|
||||
def bar(x: np.int32) -> None:
|
||||
print(x)
|
||||
|
||||
# CHECK-L: @_Z21testbench.unpack_listzz({ i1, i64 }* nocapture writeonly sret({ i1, i64 }) %.1, { i64*, i32 }* %ARG.xs)
|
||||
@kernel
|
||||
def unpack_list(xs: List[np.int64]) -> Tuple[bool, np.int64]:
|
||||
print(xs)
|
||||
return (len(xs) == 1, xs[0])
|
||||
|
||||
@kernel
|
||||
def entrypoint():
|
||||
print(foo(0, 2))
|
||||
print(foo(1, 3))
|
||||
bar(3)
|
||||
print(unpack_list([1, 2, 3]))
|
@ -4,14 +4,14 @@
|
||||
from artiq.experiment import *
|
||||
|
||||
class c():
|
||||
# CHECK-L: ${LINE:+2}: error: type annotation for argument 'x', '<class 'float'>', is not an ARTIQ type
|
||||
# CHECK-L: ${LINE:+2}: error: type annotation for argument 'x', '<class 'list'>', is not an ARTIQ type
|
||||
@kernel
|
||||
def hello(self, x: float):
|
||||
def hello(self, x: list):
|
||||
pass
|
||||
|
||||
@kernel
|
||||
def run(self):
|
||||
self.hello(2)
|
||||
self.hello([])
|
||||
|
||||
i = c()
|
||||
@kernel
|
||||
|
Loading…
Reference in New Issue
Block a user