mirror of https://github.com/m-labs/artiq.git
Allow type annotations on remotely called functions.
This commit is contained in:
parent
b28a874274
commit
435559fe50
|
@ -232,7 +232,7 @@ class Stitcher:
|
|||
quote_function=self._quote_function)
|
||||
return asttyped_rewriter.visit(function_node)
|
||||
|
||||
def _function_def_note(self, function):
|
||||
def _function_loc(self, function):
|
||||
filename = function.__code__.co_filename
|
||||
line = function.__code__.co_firstlineno
|
||||
name = function.__code__.co_name
|
||||
|
@ -240,14 +240,36 @@ class Stitcher:
|
|||
source_line = linecache.getline(filename, line)
|
||||
column = re.search("def", source_line).start(0)
|
||||
source_buffer = source.Buffer(source_line, filename, line)
|
||||
loc = source.Range(source_buffer, column, column)
|
||||
return source.Range(source_buffer, column, column)
|
||||
|
||||
def _function_def_note(self, function):
|
||||
return diagnostic.Diagnostic("note",
|
||||
"definition of function '{function}'",
|
||||
{"function": name},
|
||||
loc)
|
||||
{"function": function.__name__},
|
||||
self._function_loc(function))
|
||||
|
||||
def _extract_annot(self, function, annot, kind, call_loc):
|
||||
if not isinstance(annot, types.Type):
|
||||
note = diagnostic.Diagnostic("note",
|
||||
"in function called remotely here", {},
|
||||
call_loc)
|
||||
diag = diagnostic.Diagnostic("error",
|
||||
"type annotation for {kind}, '{annot}', is not an ARTIQ type",
|
||||
{"kind": kind, "annot": repr(annot)},
|
||||
self._function_loc(function),
|
||||
notes=[note])
|
||||
self.engine.process(diag)
|
||||
|
||||
return types.TVar()
|
||||
else:
|
||||
return annot
|
||||
|
||||
def _type_of_param(self, function, loc, param):
|
||||
if param.default is not inspect.Parameter.empty:
|
||||
if param.annotation is not inspect.Parameter.empty:
|
||||
# Type specified explicitly.
|
||||
return self._extract_annot(function, param.annotation,
|
||||
"argument {}".format(param.name), loc)
|
||||
elif param.default is not inspect.Parameter.empty:
|
||||
# Try and infer the type from the default value.
|
||||
# This is tricky, because the default value might not have
|
||||
# a well-defined type in APython.
|
||||
|
@ -300,8 +322,14 @@ class Stitcher:
|
|||
else:
|
||||
optarg_types[param.name] = self._type_of_param(function, loc, param)
|
||||
|
||||
# Fixed for now.
|
||||
ret_type = builtins.TInt(types.TValue(32))
|
||||
if signature.return_annotation is not inspect.Signature.empty:
|
||||
ret_type = self._extract_annot(function, signature.return_annotation,
|
||||
"return type", loc)
|
||||
else:
|
||||
diag = diagnostic.Diagnostic("fatal",
|
||||
"function must have a return type specified to be called remotely", {},
|
||||
self._function_loc(function))
|
||||
self.engine.process(diag)
|
||||
|
||||
rpc_type = types.TRPCFunction(arg_types, optarg_types, ret_type,
|
||||
service=self._map(function))
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
from artiq.language import core, environment, units, scan
|
||||
from artiq.language import core, types, environment, units, scan
|
||||
from artiq.language.core import *
|
||||
from artiq.language.types import *
|
||||
from artiq.language.environment import *
|
||||
from artiq.language.units import *
|
||||
from artiq.language.scan import *
|
||||
|
@ -7,6 +8,7 @@ from artiq.language.scan import *
|
|||
|
||||
__all__ = []
|
||||
__all__.extend(core.__all__)
|
||||
__all__.extend(types.__all__)
|
||||
__all__.extend(environment.__all__)
|
||||
__all__.extend(units.__all__)
|
||||
__all__.extend(scan.__all__)
|
||||
|
|
|
@ -0,0 +1,19 @@
|
|||
"""
|
||||
Values representing ARTIQ types, to be used in function type
|
||||
annotations.
|
||||
"""
|
||||
|
||||
from artiq.compiler import types, builtins
|
||||
|
||||
__all__ = ["TNone", "TBool", "TInt32", "TInt64", "TFloat",
|
||||
"TStr", "TList", "TRange32", "TRange64"]
|
||||
|
||||
TNone = builtins.TNone()
|
||||
TBool = builtins.TBool()
|
||||
TInt32 = builtins.TInt(types.TValue(32))
|
||||
TInt64 = builtins.TInt(types.TValue(64))
|
||||
TFloat = builtins.TFloat()
|
||||
TStr = builtins.TStr()
|
||||
TList = builtins.TList
|
||||
TRange32 = builtins.TRange(builtins.TInt(types.TValue(32)))
|
||||
TRange64 = builtins.TRange(builtins.TInt(types.TValue(64)))
|
Loading…
Reference in New Issue