Allow type annotations on remotely called functions.

This commit is contained in:
whitequark 2015-08-10 17:48:35 +03:00
parent b28a874274
commit 435559fe50
3 changed files with 57 additions and 8 deletions

View File

@ -232,7 +232,7 @@ class Stitcher:
quote_function=self._quote_function) quote_function=self._quote_function)
return asttyped_rewriter.visit(function_node) return asttyped_rewriter.visit(function_node)
def _function_def_note(self, function): def _function_loc(self, function):
filename = function.__code__.co_filename filename = function.__code__.co_filename
line = function.__code__.co_firstlineno line = function.__code__.co_firstlineno
name = function.__code__.co_name name = function.__code__.co_name
@ -240,14 +240,36 @@ class Stitcher:
source_line = linecache.getline(filename, line) source_line = linecache.getline(filename, line)
column = re.search("def", source_line).start(0) column = re.search("def", source_line).start(0)
source_buffer = source.Buffer(source_line, filename, line) source_buffer = source.Buffer(source_line, filename, line)
loc = source.Range(source_buffer, column, column) return source.Range(source_buffer, column, column)
def _function_def_note(self, function):
return diagnostic.Diagnostic("note", return diagnostic.Diagnostic("note",
"definition of function '{function}'", "definition of function '{function}'",
{"function": name}, {"function": function.__name__},
loc) self._function_loc(function))
def _extract_annot(self, function, annot, kind, call_loc):
if not isinstance(annot, types.Type):
note = diagnostic.Diagnostic("note",
"in function called remotely here", {},
call_loc)
diag = diagnostic.Diagnostic("error",
"type annotation for {kind}, '{annot}', is not an ARTIQ type",
{"kind": kind, "annot": repr(annot)},
self._function_loc(function),
notes=[note])
self.engine.process(diag)
return types.TVar()
else:
return annot
def _type_of_param(self, function, loc, param): def _type_of_param(self, function, loc, param):
if param.default is not inspect.Parameter.empty: if param.annotation is not inspect.Parameter.empty:
# Type specified explicitly.
return self._extract_annot(function, param.annotation,
"argument {}".format(param.name), loc)
elif param.default is not inspect.Parameter.empty:
# Try and infer the type from the default value. # Try and infer the type from the default value.
# This is tricky, because the default value might not have # This is tricky, because the default value might not have
# a well-defined type in APython. # a well-defined type in APython.
@ -300,8 +322,14 @@ class Stitcher:
else: else:
optarg_types[param.name] = self._type_of_param(function, loc, param) optarg_types[param.name] = self._type_of_param(function, loc, param)
# Fixed for now. if signature.return_annotation is not inspect.Signature.empty:
ret_type = builtins.TInt(types.TValue(32)) ret_type = self._extract_annot(function, signature.return_annotation,
"return type", loc)
else:
diag = diagnostic.Diagnostic("fatal",
"function must have a return type specified to be called remotely", {},
self._function_loc(function))
self.engine.process(diag)
rpc_type = types.TRPCFunction(arg_types, optarg_types, ret_type, rpc_type = types.TRPCFunction(arg_types, optarg_types, ret_type,
service=self._map(function)) service=self._map(function))

View File

@ -1,5 +1,6 @@
from artiq.language import core, environment, units, scan from artiq.language import core, types, environment, units, scan
from artiq.language.core import * from artiq.language.core import *
from artiq.language.types import *
from artiq.language.environment import * from artiq.language.environment import *
from artiq.language.units import * from artiq.language.units import *
from artiq.language.scan import * from artiq.language.scan import *
@ -7,6 +8,7 @@ from artiq.language.scan import *
__all__ = [] __all__ = []
__all__.extend(core.__all__) __all__.extend(core.__all__)
__all__.extend(types.__all__)
__all__.extend(environment.__all__) __all__.extend(environment.__all__)
__all__.extend(units.__all__) __all__.extend(units.__all__)
__all__.extend(scan.__all__) __all__.extend(scan.__all__)

19
artiq/language/types.py Normal file
View File

@ -0,0 +1,19 @@
"""
Values representing ARTIQ types, to be used in function type
annotations.
"""
from artiq.compiler import types, builtins
__all__ = ["TNone", "TBool", "TInt32", "TInt64", "TFloat",
"TStr", "TList", "TRange32", "TRange64"]
TNone = builtins.TNone()
TBool = builtins.TBool()
TInt32 = builtins.TInt(types.TValue(32))
TInt64 = builtins.TInt(types.TValue(64))
TFloat = builtins.TFloat()
TStr = builtins.TStr()
TList = builtins.TList
TRange32 = builtins.TRange(builtins.TInt(types.TValue(32)))
TRange64 = builtins.TRange(builtins.TInt(types.TValue(64)))