forked from M-Labs/artiq
compiler.embedding: implement type annotations for function arguments.
Fixes #318.
This commit is contained in:
parent
8a243d322f
commit
5a2306ae5a
|
@ -415,6 +415,53 @@ class StitchingASTTypedRewriter(ASTTypedRewriter):
|
||||||
self.host_environment = host_environment
|
self.host_environment = host_environment
|
||||||
self.quote = quote
|
self.quote = quote
|
||||||
|
|
||||||
|
def match_annotation(self, annot):
|
||||||
|
if isinstance(annot, ast.Name):
|
||||||
|
if annot.id == "TNone":
|
||||||
|
return builtins.TNone()
|
||||||
|
if annot.id == "TBool":
|
||||||
|
return builtins.TBool()
|
||||||
|
if annot.id == "TInt32":
|
||||||
|
return builtins.TInt(types.TValue(32))
|
||||||
|
if annot.id == "TInt64":
|
||||||
|
return builtins.TInt(types.TValue(64))
|
||||||
|
if annot.id == "TFloat":
|
||||||
|
return builtins.TFloat()
|
||||||
|
if annot.id == "TStr":
|
||||||
|
return builtins.TStr()
|
||||||
|
if annot.id == "TRange32":
|
||||||
|
return builtins.TRange(builtins.TInt(types.TValue(32)))
|
||||||
|
if annot.id == "TRange64":
|
||||||
|
return builtins.TRange(builtins.TInt(types.TValue(64)))
|
||||||
|
if annot.id == "TVar":
|
||||||
|
return types.TVar()
|
||||||
|
elif (isinstance(annot, ast.Call) and
|
||||||
|
annot.keywords is None and
|
||||||
|
annot.starargs is None and
|
||||||
|
annot.kwargs is None and
|
||||||
|
isinstance(annot.func, ast.Name)):
|
||||||
|
if annot.func.id == "TList" and len(annot.args) == 1:
|
||||||
|
elttyp = self.match_annotation(annot.args[0])
|
||||||
|
if elttyp is not None:
|
||||||
|
return builtins.TList()
|
||||||
|
else:
|
||||||
|
return None
|
||||||
|
|
||||||
|
if annot is not None:
|
||||||
|
diag = diagnostic.Diagnostic("error",
|
||||||
|
"unrecognized type annotation", {},
|
||||||
|
annot.loc)
|
||||||
|
self.engine.process(diag)
|
||||||
|
|
||||||
|
def visit_arg(self, node):
|
||||||
|
typ = self._find_name(node.arg, node.loc)
|
||||||
|
annottyp = self.match_annotation(node.annotation)
|
||||||
|
if annottyp is not None:
|
||||||
|
typ.unify(annottyp)
|
||||||
|
return asttyped.argT(type=typ,
|
||||||
|
arg=node.arg, annotation=None,
|
||||||
|
arg_loc=node.arg_loc, colon_loc=node.colon_loc, loc=node.loc)
|
||||||
|
|
||||||
def visit_quoted_function(self, node, function):
|
def visit_quoted_function(self, node, function):
|
||||||
extractor = LocalExtractor(env_stack=self.env_stack, engine=self.engine)
|
extractor = LocalExtractor(env_stack=self.env_stack, engine=self.engine)
|
||||||
extractor.visit(node)
|
extractor.visit(node)
|
||||||
|
|
|
@ -307,8 +307,14 @@ class ASTTypedRewriter(algorithm.Transformer):
|
||||||
self.in_class = old_in_class
|
self.in_class = old_in_class
|
||||||
|
|
||||||
def visit_arg(self, node):
|
def visit_arg(self, node):
|
||||||
|
if node.annotation is not None:
|
||||||
|
diag = diagnostic.Diagnostic("fatal",
|
||||||
|
"type annotations are not supported here", {},
|
||||||
|
node.annotation.loc)
|
||||||
|
self.engine.process(diag)
|
||||||
|
|
||||||
return asttyped.argT(type=self._find_name(node.arg, node.loc),
|
return asttyped.argT(type=self._find_name(node.arg, node.loc),
|
||||||
arg=node.arg, annotation=self.visit(node.annotation),
|
arg=node.arg, annotation=None,
|
||||||
arg_loc=node.arg_loc, colon_loc=node.colon_loc, loc=node.loc)
|
arg_loc=node.arg_loc, colon_loc=node.colon_loc, loc=node.loc)
|
||||||
|
|
||||||
def visit_Num(self, node):
|
def visit_Num(self, node):
|
||||||
|
|
|
@ -137,6 +137,21 @@ class RPCTest(ExperimentCase):
|
||||||
exp.builtin()
|
exp.builtin()
|
||||||
|
|
||||||
|
|
||||||
|
class _Annotation(EnvExperiment):
|
||||||
|
def build(self):
|
||||||
|
self.setattr_device("core")
|
||||||
|
|
||||||
|
@kernel
|
||||||
|
def overflow(self, x: TInt64) -> TBool:
|
||||||
|
return (x << 32) != 0
|
||||||
|
|
||||||
|
|
||||||
|
class AnnotationTest(ExperimentCase):
|
||||||
|
def test_annotation(self):
|
||||||
|
exp = self.create(_Annotation)
|
||||||
|
self.assertEqual(exp.overflow(1), True)
|
||||||
|
|
||||||
|
|
||||||
class _Payload1MB(EnvExperiment):
|
class _Payload1MB(EnvExperiment):
|
||||||
def build(self):
|
def build(self):
|
||||||
self.setattr_device("core")
|
self.setattr_device("core")
|
||||||
|
|
Loading…
Reference in New Issue