diff --git a/artiq/compiler/embedding.py b/artiq/compiler/embedding.py index 943550072..f55d9af7f 100644 --- a/artiq/compiler/embedding.py +++ b/artiq/compiler/embedding.py @@ -415,6 +415,53 @@ class StitchingASTTypedRewriter(ASTTypedRewriter): self.host_environment = host_environment self.quote = quote + def match_annotation(self, annot): + if isinstance(annot, ast.Name): + if annot.id == "TNone": + return builtins.TNone() + if annot.id == "TBool": + return builtins.TBool() + if annot.id == "TInt32": + return builtins.TInt(types.TValue(32)) + if annot.id == "TInt64": + return builtins.TInt(types.TValue(64)) + if annot.id == "TFloat": + return builtins.TFloat() + if annot.id == "TStr": + return builtins.TStr() + if annot.id == "TRange32": + return builtins.TRange(builtins.TInt(types.TValue(32))) + if annot.id == "TRange64": + return builtins.TRange(builtins.TInt(types.TValue(64))) + if annot.id == "TVar": + return types.TVar() + elif (isinstance(annot, ast.Call) and + annot.keywords is None and + annot.starargs is None and + annot.kwargs is None and + isinstance(annot.func, ast.Name)): + if annot.func.id == "TList" and len(annot.args) == 1: + elttyp = self.match_annotation(annot.args[0]) + if elttyp is not None: + return builtins.TList() + else: + return None + + if annot is not None: + diag = diagnostic.Diagnostic("error", + "unrecognized type annotation", {}, + annot.loc) + self.engine.process(diag) + + def visit_arg(self, node): + typ = self._find_name(node.arg, node.loc) + annottyp = self.match_annotation(node.annotation) + if annottyp is not None: + typ.unify(annottyp) + return asttyped.argT(type=typ, + arg=node.arg, annotation=None, + arg_loc=node.arg_loc, colon_loc=node.colon_loc, loc=node.loc) + def visit_quoted_function(self, node, function): extractor = LocalExtractor(env_stack=self.env_stack, engine=self.engine) extractor.visit(node) diff --git a/artiq/compiler/transforms/asttyped_rewriter.py b/artiq/compiler/transforms/asttyped_rewriter.py index b0714bb6f..199ea2ef9 100644 --- a/artiq/compiler/transforms/asttyped_rewriter.py +++ b/artiq/compiler/transforms/asttyped_rewriter.py @@ -307,8 +307,14 @@ class ASTTypedRewriter(algorithm.Transformer): self.in_class = old_in_class def visit_arg(self, node): + if node.annotation is not None: + diag = diagnostic.Diagnostic("fatal", + "type annotations are not supported here", {}, + node.annotation.loc) + self.engine.process(diag) + return asttyped.argT(type=self._find_name(node.arg, node.loc), - arg=node.arg, annotation=self.visit(node.annotation), + arg=node.arg, annotation=None, arg_loc=node.arg_loc, colon_loc=node.colon_loc, loc=node.loc) def visit_Num(self, node): diff --git a/artiq/test/coredevice/test_embedding.py b/artiq/test/coredevice/test_embedding.py index 1edb73fa4..988d23e73 100644 --- a/artiq/test/coredevice/test_embedding.py +++ b/artiq/test/coredevice/test_embedding.py @@ -137,6 +137,21 @@ class RPCTest(ExperimentCase): exp.builtin() +class _Annotation(EnvExperiment): + def build(self): + self.setattr_device("core") + + @kernel + def overflow(self, x: TInt64) -> TBool: + return (x << 32) != 0 + + +class AnnotationTest(ExperimentCase): + def test_annotation(self): + exp = self.create(_Annotation) + self.assertEqual(exp.overflow(1), True) + + class _Payload1MB(EnvExperiment): def build(self): self.setattr_device("core")