compiler.embedding: implement type annotations for function arguments.

Fixes #318.
This commit is contained in:
whitequark 2016-08-08 03:28:25 +00:00
parent 8a243d322f
commit 5a2306ae5a
3 changed files with 69 additions and 1 deletions

View File

@ -415,6 +415,53 @@ class StitchingASTTypedRewriter(ASTTypedRewriter):
self.host_environment = host_environment self.host_environment = host_environment
self.quote = quote self.quote = quote
def match_annotation(self, annot):
if isinstance(annot, ast.Name):
if annot.id == "TNone":
return builtins.TNone()
if annot.id == "TBool":
return builtins.TBool()
if annot.id == "TInt32":
return builtins.TInt(types.TValue(32))
if annot.id == "TInt64":
return builtins.TInt(types.TValue(64))
if annot.id == "TFloat":
return builtins.TFloat()
if annot.id == "TStr":
return builtins.TStr()
if annot.id == "TRange32":
return builtins.TRange(builtins.TInt(types.TValue(32)))
if annot.id == "TRange64":
return builtins.TRange(builtins.TInt(types.TValue(64)))
if annot.id == "TVar":
return types.TVar()
elif (isinstance(annot, ast.Call) and
annot.keywords is None and
annot.starargs is None and
annot.kwargs is None and
isinstance(annot.func, ast.Name)):
if annot.func.id == "TList" and len(annot.args) == 1:
elttyp = self.match_annotation(annot.args[0])
if elttyp is not None:
return builtins.TList()
else:
return None
if annot is not None:
diag = diagnostic.Diagnostic("error",
"unrecognized type annotation", {},
annot.loc)
self.engine.process(diag)
def visit_arg(self, node):
typ = self._find_name(node.arg, node.loc)
annottyp = self.match_annotation(node.annotation)
if annottyp is not None:
typ.unify(annottyp)
return asttyped.argT(type=typ,
arg=node.arg, annotation=None,
arg_loc=node.arg_loc, colon_loc=node.colon_loc, loc=node.loc)
def visit_quoted_function(self, node, function): def visit_quoted_function(self, node, function):
extractor = LocalExtractor(env_stack=self.env_stack, engine=self.engine) extractor = LocalExtractor(env_stack=self.env_stack, engine=self.engine)
extractor.visit(node) extractor.visit(node)

View File

@ -307,8 +307,14 @@ class ASTTypedRewriter(algorithm.Transformer):
self.in_class = old_in_class self.in_class = old_in_class
def visit_arg(self, node): def visit_arg(self, node):
if node.annotation is not None:
diag = diagnostic.Diagnostic("fatal",
"type annotations are not supported here", {},
node.annotation.loc)
self.engine.process(diag)
return asttyped.argT(type=self._find_name(node.arg, node.loc), return asttyped.argT(type=self._find_name(node.arg, node.loc),
arg=node.arg, annotation=self.visit(node.annotation), arg=node.arg, annotation=None,
arg_loc=node.arg_loc, colon_loc=node.colon_loc, loc=node.loc) arg_loc=node.arg_loc, colon_loc=node.colon_loc, loc=node.loc)
def visit_Num(self, node): def visit_Num(self, node):

View File

@ -137,6 +137,21 @@ class RPCTest(ExperimentCase):
exp.builtin() exp.builtin()
class _Annotation(EnvExperiment):
def build(self):
self.setattr_device("core")
@kernel
def overflow(self, x: TInt64) -> TBool:
return (x << 32) != 0
class AnnotationTest(ExperimentCase):
def test_annotation(self):
exp = self.create(_Annotation)
self.assertEqual(exp.overflow(1), True)
class _Payload1MB(EnvExperiment): class _Payload1MB(EnvExperiment):
def build(self): def build(self):
self.setattr_device("core") self.setattr_device("core")