mirror of
https://github.com/m-labs/artiq.git
synced 2025-01-24 09:28:13 +08:00
compiler.embedding: implement type annotations for function arguments.
Fixes #318.
This commit is contained in:
parent
8a243d322f
commit
5a2306ae5a
@ -415,6 +415,53 @@ class StitchingASTTypedRewriter(ASTTypedRewriter):
|
||||
self.host_environment = host_environment
|
||||
self.quote = quote
|
||||
|
||||
def match_annotation(self, annot):
|
||||
if isinstance(annot, ast.Name):
|
||||
if annot.id == "TNone":
|
||||
return builtins.TNone()
|
||||
if annot.id == "TBool":
|
||||
return builtins.TBool()
|
||||
if annot.id == "TInt32":
|
||||
return builtins.TInt(types.TValue(32))
|
||||
if annot.id == "TInt64":
|
||||
return builtins.TInt(types.TValue(64))
|
||||
if annot.id == "TFloat":
|
||||
return builtins.TFloat()
|
||||
if annot.id == "TStr":
|
||||
return builtins.TStr()
|
||||
if annot.id == "TRange32":
|
||||
return builtins.TRange(builtins.TInt(types.TValue(32)))
|
||||
if annot.id == "TRange64":
|
||||
return builtins.TRange(builtins.TInt(types.TValue(64)))
|
||||
if annot.id == "TVar":
|
||||
return types.TVar()
|
||||
elif (isinstance(annot, ast.Call) and
|
||||
annot.keywords is None and
|
||||
annot.starargs is None and
|
||||
annot.kwargs is None and
|
||||
isinstance(annot.func, ast.Name)):
|
||||
if annot.func.id == "TList" and len(annot.args) == 1:
|
||||
elttyp = self.match_annotation(annot.args[0])
|
||||
if elttyp is not None:
|
||||
return builtins.TList()
|
||||
else:
|
||||
return None
|
||||
|
||||
if annot is not None:
|
||||
diag = diagnostic.Diagnostic("error",
|
||||
"unrecognized type annotation", {},
|
||||
annot.loc)
|
||||
self.engine.process(diag)
|
||||
|
||||
def visit_arg(self, node):
|
||||
typ = self._find_name(node.arg, node.loc)
|
||||
annottyp = self.match_annotation(node.annotation)
|
||||
if annottyp is not None:
|
||||
typ.unify(annottyp)
|
||||
return asttyped.argT(type=typ,
|
||||
arg=node.arg, annotation=None,
|
||||
arg_loc=node.arg_loc, colon_loc=node.colon_loc, loc=node.loc)
|
||||
|
||||
def visit_quoted_function(self, node, function):
|
||||
extractor = LocalExtractor(env_stack=self.env_stack, engine=self.engine)
|
||||
extractor.visit(node)
|
||||
|
@ -307,8 +307,14 @@ class ASTTypedRewriter(algorithm.Transformer):
|
||||
self.in_class = old_in_class
|
||||
|
||||
def visit_arg(self, node):
|
||||
if node.annotation is not None:
|
||||
diag = diagnostic.Diagnostic("fatal",
|
||||
"type annotations are not supported here", {},
|
||||
node.annotation.loc)
|
||||
self.engine.process(diag)
|
||||
|
||||
return asttyped.argT(type=self._find_name(node.arg, node.loc),
|
||||
arg=node.arg, annotation=self.visit(node.annotation),
|
||||
arg=node.arg, annotation=None,
|
||||
arg_loc=node.arg_loc, colon_loc=node.colon_loc, loc=node.loc)
|
||||
|
||||
def visit_Num(self, node):
|
||||
|
@ -137,6 +137,21 @@ class RPCTest(ExperimentCase):
|
||||
exp.builtin()
|
||||
|
||||
|
||||
class _Annotation(EnvExperiment):
|
||||
def build(self):
|
||||
self.setattr_device("core")
|
||||
|
||||
@kernel
|
||||
def overflow(self, x: TInt64) -> TBool:
|
||||
return (x << 32) != 0
|
||||
|
||||
|
||||
class AnnotationTest(ExperimentCase):
|
||||
def test_annotation(self):
|
||||
exp = self.create(_Annotation)
|
||||
self.assertEqual(exp.overflow(1), True)
|
||||
|
||||
|
||||
class _Payload1MB(EnvExperiment):
|
||||
def build(self):
|
||||
self.setattr_device("core")
|
||||
|
Loading…
Reference in New Issue
Block a user