forked from M-Labs/artiq
compiler: refactor type annotations recognizing in kernels.
The new implementation is much more generic, more robust, and shares code with the same for syscalls as well as RPCs. Fixes #713.
This commit is contained in:
parent
99196986c0
commit
3a1fc729cf
|
@ -423,49 +423,9 @@ class StitchingASTTypedRewriter(ASTTypedRewriter):
|
|||
self.host_environment = host_environment
|
||||
self.quote = quote
|
||||
|
||||
def match_annotation(self, annot):
|
||||
if isinstance(annot, ast.Name):
|
||||
if annot.id == "TNone":
|
||||
return builtins.TNone()
|
||||
if annot.id == "TBool":
|
||||
return builtins.TBool()
|
||||
if annot.id == "TInt32":
|
||||
return builtins.TInt(types.TValue(32))
|
||||
if annot.id == "TInt64":
|
||||
return builtins.TInt(types.TValue(64))
|
||||
if annot.id == "TFloat":
|
||||
return builtins.TFloat()
|
||||
if annot.id == "TStr":
|
||||
return builtins.TStr()
|
||||
if annot.id == "TRange32":
|
||||
return builtins.TRange(builtins.TInt(types.TValue(32)))
|
||||
if annot.id == "TRange64":
|
||||
return builtins.TRange(builtins.TInt(types.TValue(64)))
|
||||
if annot.id == "TVar":
|
||||
return types.TVar()
|
||||
elif (isinstance(annot, ast.Call) and
|
||||
annot.keywords == [] and
|
||||
annot.starargs is None and
|
||||
annot.kwargs is None and
|
||||
isinstance(annot.func, ast.Name)):
|
||||
if annot.func.id == "TList" and len(annot.args) == 1:
|
||||
elttyp = self.match_annotation(annot.args[0])
|
||||
if elttyp is not None:
|
||||
return builtins.TList(elttyp)
|
||||
else:
|
||||
return None
|
||||
|
||||
if annot is not None:
|
||||
diag = diagnostic.Diagnostic("error",
|
||||
"unrecognized type annotation", {},
|
||||
annot.loc)
|
||||
self.engine.process(diag)
|
||||
|
||||
def visit_arg(self, node):
|
||||
typ = self._find_name(node.arg, node.loc)
|
||||
annottyp = self.match_annotation(node.annotation)
|
||||
if annottyp is not None:
|
||||
typ.unify(annottyp)
|
||||
# ignore annotations; these are handled in _quote_function
|
||||
return asttyped.argT(type=typ,
|
||||
arg=node.arg, annotation=None,
|
||||
arg_loc=node.arg_loc, colon_loc=node.colon_loc, loc=node.loc)
|
||||
|
@ -795,6 +755,79 @@ class Stitcher:
|
|||
value_map=self.value_map,
|
||||
quote_function=self._quote_function)
|
||||
|
||||
def _function_loc(self, function):
|
||||
filename = function.__code__.co_filename
|
||||
line = function.__code__.co_firstlineno
|
||||
name = function.__code__.co_name
|
||||
|
||||
source_line = linecache.getline(filename, line).lstrip()
|
||||
while source_line.startswith("@") or source_line == "":
|
||||
line += 1
|
||||
source_line = linecache.getline(filename, line).lstrip()
|
||||
|
||||
if "<lambda>" in function.__qualname__:
|
||||
column = 0 # can't get column of lambda
|
||||
else:
|
||||
column = re.search("def", source_line).start(0)
|
||||
source_buffer = source.Buffer(source_line, filename, line)
|
||||
return source.Range(source_buffer, column, column)
|
||||
|
||||
def _call_site_note(self, call_loc, fn_kind):
|
||||
if call_loc:
|
||||
if fn_kind == 'syscall':
|
||||
return [diagnostic.Diagnostic("note",
|
||||
"in system call here", {},
|
||||
call_loc)]
|
||||
elif fn_kind == 'rpc':
|
||||
return [diagnostic.Diagnostic("note",
|
||||
"in function called remotely here", {},
|
||||
call_loc)]
|
||||
elif fn_kind == 'kernel':
|
||||
return [diagnostic.Diagnostic("note",
|
||||
"in kernel function here", {},
|
||||
call_loc)]
|
||||
else:
|
||||
assert False
|
||||
else:
|
||||
return []
|
||||
|
||||
def _type_of_param(self, function, loc, param, fn_kind):
|
||||
if param.annotation is not inspect.Parameter.empty:
|
||||
# Type specified explicitly.
|
||||
return self._extract_annot(function, param.annotation,
|
||||
"argument '{}'".format(param.name), loc,
|
||||
fn_kind)
|
||||
elif fn_kind == 'syscall':
|
||||
# Syscalls must be entirely annotated.
|
||||
diag = diagnostic.Diagnostic("error",
|
||||
"system call argument '{argument}' must have a type annotation",
|
||||
{"argument": param.name},
|
||||
self._function_loc(function),
|
||||
notes=self._call_site_note(loc, fn_kind))
|
||||
self.engine.process(diag)
|
||||
elif fn_kind == 'rpc' and param.default is not inspect.Parameter.empty:
|
||||
notes = []
|
||||
notes.append(diagnostic.Diagnostic("note",
|
||||
"expanded from here while trying to infer a type for an"
|
||||
" unannotated optional argument '{argument}' from its default value",
|
||||
{"argument": param.name},
|
||||
self._function_loc(function)))
|
||||
if loc is not None:
|
||||
notes.append(self._call_site_note(loc, fn_kind))
|
||||
|
||||
with self.engine.context(*notes):
|
||||
# Try and infer the type from the default value.
|
||||
# This is tricky, because the default value might not have
|
||||
# a well-defined type in APython.
|
||||
# In this case, we bail out, but mention why we do it.
|
||||
ast = self._quote(param.default, None)
|
||||
Inferencer(engine=self.engine).visit(ast)
|
||||
IntMonomorphizer(engine=self.engine).visit(ast)
|
||||
return ast.type
|
||||
else:
|
||||
# Let the rest of the program decide.
|
||||
return types.TVar()
|
||||
|
||||
def _quote_embedded_function(self, function, flags):
|
||||
if isinstance(function, SpecializedFunction):
|
||||
host_function = function.host_function
|
||||
|
@ -811,6 +844,34 @@ class Stitcher:
|
|||
module_name = embedded_function.__globals__['__name__']
|
||||
first_line = embedded_function.__code__.co_firstlineno
|
||||
|
||||
# Extract function annotation.
|
||||
signature = inspect.signature(embedded_function)
|
||||
loc = self._function_loc(embedded_function)
|
||||
|
||||
arg_types = OrderedDict()
|
||||
optarg_types = OrderedDict()
|
||||
for param in signature.parameters.values():
|
||||
if param.kind == inspect.Parameter.VAR_POSITIONAL or \
|
||||
param.kind == inspect.Parameter.VAR_KEYWORD:
|
||||
diag = diagnostic.Diagnostic("error",
|
||||
"variadic arguments are not supported; '{argument}' is variadic",
|
||||
{"argument": param.name},
|
||||
self._function_loc(function),
|
||||
notes=self._call_site_note(loc, fn_kind='kernel'))
|
||||
self.engine.process(diag)
|
||||
|
||||
arg_type = self._type_of_param(function, loc, param, fn_kind='kernel')
|
||||
if param.default is inspect.Parameter.empty:
|
||||
arg_types[param.name] = arg_type
|
||||
else:
|
||||
optarg_types[param.name] = arg_type
|
||||
|
||||
if signature.return_annotation is not inspect.Signature.empty:
|
||||
ret_type = self._extract_annot(function, signature.return_annotation,
|
||||
"return type", loc, fn_kind='kernel')
|
||||
else:
|
||||
ret_type = types.TVar()
|
||||
|
||||
# Extract function environment.
|
||||
host_environment = dict()
|
||||
host_environment.update(embedded_function.__globals__)
|
||||
|
@ -846,9 +907,9 @@ class Stitcher:
|
|||
# can handle quoting it.
|
||||
self.embedding_map.store_function(function, function_node.name)
|
||||
|
||||
# Memoize the function type before typing it to handle recursive
|
||||
# Fill in the function type before typing it to handle recursive
|
||||
# invocations.
|
||||
self.functions[function] = types.TVar()
|
||||
self.functions[function] = types.TFunction(arg_types, optarg_types, ret_type)
|
||||
|
||||
# Rewrite into typed form.
|
||||
asttyped_rewriter = StitchingASTTypedRewriter(
|
||||
|
@ -866,86 +927,19 @@ class Stitcher:
|
|||
|
||||
return function_node
|
||||
|
||||
def _function_loc(self, function):
|
||||
filename = function.__code__.co_filename
|
||||
line = function.__code__.co_firstlineno
|
||||
name = function.__code__.co_name
|
||||
|
||||
source_line = linecache.getline(filename, line).lstrip()
|
||||
while source_line.startswith("@") or source_line == "":
|
||||
line += 1
|
||||
source_line = linecache.getline(filename, line).lstrip()
|
||||
|
||||
if "<lambda>" in function.__qualname__:
|
||||
column = 0 # can't get column of lambda
|
||||
else:
|
||||
column = re.search("def", source_line).start(0)
|
||||
source_buffer = source.Buffer(source_line, filename, line)
|
||||
return source.Range(source_buffer, column, column)
|
||||
|
||||
def _call_site_note(self, call_loc, is_syscall):
|
||||
if call_loc:
|
||||
if is_syscall:
|
||||
return [diagnostic.Diagnostic("note",
|
||||
"in system call here", {},
|
||||
call_loc)]
|
||||
else:
|
||||
return [diagnostic.Diagnostic("note",
|
||||
"in function called remotely here", {},
|
||||
call_loc)]
|
||||
else:
|
||||
return []
|
||||
|
||||
def _extract_annot(self, function, annot, kind, call_loc, is_syscall):
|
||||
def _extract_annot(self, function, annot, kind, call_loc, fn_kind):
|
||||
if not isinstance(annot, types.Type):
|
||||
diag = diagnostic.Diagnostic("error",
|
||||
"type annotation for {kind}, '{annot}', is not an ARTIQ type",
|
||||
{"kind": kind, "annot": repr(annot)},
|
||||
self._function_loc(function),
|
||||
notes=self._call_site_note(call_loc, is_syscall))
|
||||
notes=self._call_site_note(call_loc, fn_kind))
|
||||
self.engine.process(diag)
|
||||
|
||||
return types.TVar()
|
||||
else:
|
||||
return annot
|
||||
|
||||
def _type_of_param(self, function, loc, param, is_syscall):
|
||||
if param.annotation is not inspect.Parameter.empty:
|
||||
# Type specified explicitly.
|
||||
return self._extract_annot(function, param.annotation,
|
||||
"argument '{}'".format(param.name), loc,
|
||||
is_syscall)
|
||||
elif is_syscall:
|
||||
# Syscalls must be entirely annotated.
|
||||
diag = diagnostic.Diagnostic("error",
|
||||
"system call argument '{argument}' must have a type annotation",
|
||||
{"argument": param.name},
|
||||
self._function_loc(function),
|
||||
notes=self._call_site_note(loc, is_syscall))
|
||||
self.engine.process(diag)
|
||||
elif param.default is not inspect.Parameter.empty:
|
||||
notes = []
|
||||
notes.append(diagnostic.Diagnostic("note",
|
||||
"expanded from here while trying to infer a type for an"
|
||||
" unannotated optional argument '{argument}' from its default value",
|
||||
{"argument": param.name},
|
||||
self._function_loc(function)))
|
||||
if loc is not None:
|
||||
notes.append(self._call_site_note(loc, is_syscall))
|
||||
|
||||
with self.engine.context(*notes):
|
||||
# Try and infer the type from the default value.
|
||||
# This is tricky, because the default value might not have
|
||||
# a well-defined type in APython.
|
||||
# In this case, we bail out, but mention why we do it.
|
||||
ast = self._quote(param.default, None)
|
||||
Inferencer(engine=self.engine).visit(ast)
|
||||
IntMonomorphizer(engine=self.engine).visit(ast)
|
||||
return ast.type
|
||||
else:
|
||||
# Let the rest of the program decide.
|
||||
return types.TVar()
|
||||
|
||||
def _quote_syscall(self, function, loc):
|
||||
signature = inspect.signature(function)
|
||||
|
||||
|
@ -957,27 +951,28 @@ class Stitcher:
|
|||
"system calls must only use positional arguments; '{argument}' isn't",
|
||||
{"argument": param.name},
|
||||
self._function_loc(function),
|
||||
notes=self._call_site_note(loc, is_syscall=True))
|
||||
notes=self._call_site_note(loc, fn_kind='syscall'))
|
||||
self.engine.process(diag)
|
||||
|
||||
if param.default is inspect.Parameter.empty:
|
||||
arg_types[param.name] = self._type_of_param(function, loc, param, is_syscall=True)
|
||||
arg_types[param.name] = self._type_of_param(function, loc, param,
|
||||
fn_kind='syscall')
|
||||
else:
|
||||
diag = diagnostic.Diagnostic("error",
|
||||
"system call argument '{argument}' must not have a default value",
|
||||
{"argument": param.name},
|
||||
self._function_loc(function),
|
||||
notes=self._call_site_note(loc, is_syscall=True))
|
||||
notes=self._call_site_note(loc, fn_kind='syscall'))
|
||||
self.engine.process(diag)
|
||||
|
||||
if signature.return_annotation is not inspect.Signature.empty:
|
||||
ret_type = self._extract_annot(function, signature.return_annotation,
|
||||
"return type", loc, is_syscall=True)
|
||||
"return type", loc, fn_kind='syscall')
|
||||
else:
|
||||
diag = diagnostic.Diagnostic("error",
|
||||
"system call must have a return type annotation", {},
|
||||
self._function_loc(function),
|
||||
notes=self._call_site_note(loc, is_syscall=True))
|
||||
notes=self._call_site_note(loc, fn_kind='syscall'))
|
||||
self.engine.process(diag)
|
||||
ret_type = types.TVar()
|
||||
|
||||
|
@ -1005,7 +1000,7 @@ class Stitcher:
|
|||
signature = inspect.signature(host_function.__func__)
|
||||
if signature.return_annotation is not inspect.Signature.empty:
|
||||
ret_type = self._extract_annot(host_function, signature.return_annotation,
|
||||
"return type", loc, is_syscall=False)
|
||||
"return type", loc, fn_kind='rpc')
|
||||
else:
|
||||
assert False
|
||||
|
||||
|
@ -1077,7 +1072,7 @@ class Stitcher:
|
|||
diag = diagnostic.Diagnostic("fatal",
|
||||
"this function cannot be called as an RPC", {},
|
||||
self._function_loc(host_function),
|
||||
notes=self._call_site_note(loc, is_syscall=False))
|
||||
notes=self._call_site_note(loc, fn_kind='rpc'))
|
||||
self.engine.process(diag)
|
||||
else:
|
||||
assert False
|
||||
|
|
|
@ -0,0 +1,17 @@
|
|||
# RUN: env ARTIQ_DUMP_LLVM=%t %python -m artiq.compiler.testbench.embedding +compile %s
|
||||
# RUN: OutputCheck %s --file-to-check=%t.ll
|
||||
|
||||
from artiq.language.core import *
|
||||
from artiq.language.types import *
|
||||
|
||||
# CHECK: i64 @_Z13testbench.foozz\(i64 %ARG.x, \{ i1, i64 \} %ARG.y\)
|
||||
|
||||
@kernel
|
||||
def foo(x: TInt64, y: TInt64 = 1) -> TInt64:
|
||||
print(x+y)
|
||||
return x+y
|
||||
|
||||
@kernel
|
||||
def entrypoint():
|
||||
print(foo(0, 2))
|
||||
print(foo(1, 3))
|
Loading…
Reference in New Issue