compiler: refactor type annotations recognizing in kernels.

The new implementation is much more generic, more robust,
and shares code with the same for syscalls as well as RPCs.

Fixes #713.
This commit is contained in:
whitequark 2017-04-13 08:01:46 +00:00
parent 99196986c0
commit 3a1fc729cf
2 changed files with 131 additions and 119 deletions

View File

@ -423,49 +423,9 @@ class StitchingASTTypedRewriter(ASTTypedRewriter):
self.host_environment = host_environment self.host_environment = host_environment
self.quote = quote self.quote = quote
def match_annotation(self, annot):
if isinstance(annot, ast.Name):
if annot.id == "TNone":
return builtins.TNone()
if annot.id == "TBool":
return builtins.TBool()
if annot.id == "TInt32":
return builtins.TInt(types.TValue(32))
if annot.id == "TInt64":
return builtins.TInt(types.TValue(64))
if annot.id == "TFloat":
return builtins.TFloat()
if annot.id == "TStr":
return builtins.TStr()
if annot.id == "TRange32":
return builtins.TRange(builtins.TInt(types.TValue(32)))
if annot.id == "TRange64":
return builtins.TRange(builtins.TInt(types.TValue(64)))
if annot.id == "TVar":
return types.TVar()
elif (isinstance(annot, ast.Call) and
annot.keywords == [] and
annot.starargs is None and
annot.kwargs is None and
isinstance(annot.func, ast.Name)):
if annot.func.id == "TList" and len(annot.args) == 1:
elttyp = self.match_annotation(annot.args[0])
if elttyp is not None:
return builtins.TList(elttyp)
else:
return None
if annot is not None:
diag = diagnostic.Diagnostic("error",
"unrecognized type annotation", {},
annot.loc)
self.engine.process(diag)
def visit_arg(self, node): def visit_arg(self, node):
typ = self._find_name(node.arg, node.loc) typ = self._find_name(node.arg, node.loc)
annottyp = self.match_annotation(node.annotation) # ignore annotations; these are handled in _quote_function
if annottyp is not None:
typ.unify(annottyp)
return asttyped.argT(type=typ, return asttyped.argT(type=typ,
arg=node.arg, annotation=None, arg=node.arg, annotation=None,
arg_loc=node.arg_loc, colon_loc=node.colon_loc, loc=node.loc) arg_loc=node.arg_loc, colon_loc=node.colon_loc, loc=node.loc)
@ -795,6 +755,79 @@ class Stitcher:
value_map=self.value_map, value_map=self.value_map,
quote_function=self._quote_function) quote_function=self._quote_function)
def _function_loc(self, function):
filename = function.__code__.co_filename
line = function.__code__.co_firstlineno
name = function.__code__.co_name
source_line = linecache.getline(filename, line).lstrip()
while source_line.startswith("@") or source_line == "":
line += 1
source_line = linecache.getline(filename, line).lstrip()
if "<lambda>" in function.__qualname__:
column = 0 # can't get column of lambda
else:
column = re.search("def", source_line).start(0)
source_buffer = source.Buffer(source_line, filename, line)
return source.Range(source_buffer, column, column)
def _call_site_note(self, call_loc, fn_kind):
if call_loc:
if fn_kind == 'syscall':
return [diagnostic.Diagnostic("note",
"in system call here", {},
call_loc)]
elif fn_kind == 'rpc':
return [diagnostic.Diagnostic("note",
"in function called remotely here", {},
call_loc)]
elif fn_kind == 'kernel':
return [diagnostic.Diagnostic("note",
"in kernel function here", {},
call_loc)]
else:
assert False
else:
return []
def _type_of_param(self, function, loc, param, fn_kind):
if param.annotation is not inspect.Parameter.empty:
# Type specified explicitly.
return self._extract_annot(function, param.annotation,
"argument '{}'".format(param.name), loc,
fn_kind)
elif fn_kind == 'syscall':
# Syscalls must be entirely annotated.
diag = diagnostic.Diagnostic("error",
"system call argument '{argument}' must have a type annotation",
{"argument": param.name},
self._function_loc(function),
notes=self._call_site_note(loc, fn_kind))
self.engine.process(diag)
elif fn_kind == 'rpc' and param.default is not inspect.Parameter.empty:
notes = []
notes.append(diagnostic.Diagnostic("note",
"expanded from here while trying to infer a type for an"
" unannotated optional argument '{argument}' from its default value",
{"argument": param.name},
self._function_loc(function)))
if loc is not None:
notes.append(self._call_site_note(loc, fn_kind))
with self.engine.context(*notes):
# Try and infer the type from the default value.
# This is tricky, because the default value might not have
# a well-defined type in APython.
# In this case, we bail out, but mention why we do it.
ast = self._quote(param.default, None)
Inferencer(engine=self.engine).visit(ast)
IntMonomorphizer(engine=self.engine).visit(ast)
return ast.type
else:
# Let the rest of the program decide.
return types.TVar()
def _quote_embedded_function(self, function, flags): def _quote_embedded_function(self, function, flags):
if isinstance(function, SpecializedFunction): if isinstance(function, SpecializedFunction):
host_function = function.host_function host_function = function.host_function
@ -811,6 +844,34 @@ class Stitcher:
module_name = embedded_function.__globals__['__name__'] module_name = embedded_function.__globals__['__name__']
first_line = embedded_function.__code__.co_firstlineno first_line = embedded_function.__code__.co_firstlineno
# Extract function annotation.
signature = inspect.signature(embedded_function)
loc = self._function_loc(embedded_function)
arg_types = OrderedDict()
optarg_types = OrderedDict()
for param in signature.parameters.values():
if param.kind == inspect.Parameter.VAR_POSITIONAL or \
param.kind == inspect.Parameter.VAR_KEYWORD:
diag = diagnostic.Diagnostic("error",
"variadic arguments are not supported; '{argument}' is variadic",
{"argument": param.name},
self._function_loc(function),
notes=self._call_site_note(loc, fn_kind='kernel'))
self.engine.process(diag)
arg_type = self._type_of_param(function, loc, param, fn_kind='kernel')
if param.default is inspect.Parameter.empty:
arg_types[param.name] = arg_type
else:
optarg_types[param.name] = arg_type
if signature.return_annotation is not inspect.Signature.empty:
ret_type = self._extract_annot(function, signature.return_annotation,
"return type", loc, fn_kind='kernel')
else:
ret_type = types.TVar()
# Extract function environment. # Extract function environment.
host_environment = dict() host_environment = dict()
host_environment.update(embedded_function.__globals__) host_environment.update(embedded_function.__globals__)
@ -846,9 +907,9 @@ class Stitcher:
# can handle quoting it. # can handle quoting it.
self.embedding_map.store_function(function, function_node.name) self.embedding_map.store_function(function, function_node.name)
# Memoize the function type before typing it to handle recursive # Fill in the function type before typing it to handle recursive
# invocations. # invocations.
self.functions[function] = types.TVar() self.functions[function] = types.TFunction(arg_types, optarg_types, ret_type)
# Rewrite into typed form. # Rewrite into typed form.
asttyped_rewriter = StitchingASTTypedRewriter( asttyped_rewriter = StitchingASTTypedRewriter(
@ -866,86 +927,19 @@ class Stitcher:
return function_node return function_node
def _function_loc(self, function): def _extract_annot(self, function, annot, kind, call_loc, fn_kind):
filename = function.__code__.co_filename
line = function.__code__.co_firstlineno
name = function.__code__.co_name
source_line = linecache.getline(filename, line).lstrip()
while source_line.startswith("@") or source_line == "":
line += 1
source_line = linecache.getline(filename, line).lstrip()
if "<lambda>" in function.__qualname__:
column = 0 # can't get column of lambda
else:
column = re.search("def", source_line).start(0)
source_buffer = source.Buffer(source_line, filename, line)
return source.Range(source_buffer, column, column)
def _call_site_note(self, call_loc, is_syscall):
if call_loc:
if is_syscall:
return [diagnostic.Diagnostic("note",
"in system call here", {},
call_loc)]
else:
return [diagnostic.Diagnostic("note",
"in function called remotely here", {},
call_loc)]
else:
return []
def _extract_annot(self, function, annot, kind, call_loc, is_syscall):
if not isinstance(annot, types.Type): if not isinstance(annot, types.Type):
diag = diagnostic.Diagnostic("error", diag = diagnostic.Diagnostic("error",
"type annotation for {kind}, '{annot}', is not an ARTIQ type", "type annotation for {kind}, '{annot}', is not an ARTIQ type",
{"kind": kind, "annot": repr(annot)}, {"kind": kind, "annot": repr(annot)},
self._function_loc(function), self._function_loc(function),
notes=self._call_site_note(call_loc, is_syscall)) notes=self._call_site_note(call_loc, fn_kind))
self.engine.process(diag) self.engine.process(diag)
return types.TVar() return types.TVar()
else: else:
return annot return annot
def _type_of_param(self, function, loc, param, is_syscall):
if param.annotation is not inspect.Parameter.empty:
# Type specified explicitly.
return self._extract_annot(function, param.annotation,
"argument '{}'".format(param.name), loc,
is_syscall)
elif is_syscall:
# Syscalls must be entirely annotated.
diag = diagnostic.Diagnostic("error",
"system call argument '{argument}' must have a type annotation",
{"argument": param.name},
self._function_loc(function),
notes=self._call_site_note(loc, is_syscall))
self.engine.process(diag)
elif param.default is not inspect.Parameter.empty:
notes = []
notes.append(diagnostic.Diagnostic("note",
"expanded from here while trying to infer a type for an"
" unannotated optional argument '{argument}' from its default value",
{"argument": param.name},
self._function_loc(function)))
if loc is not None:
notes.append(self._call_site_note(loc, is_syscall))
with self.engine.context(*notes):
# Try and infer the type from the default value.
# This is tricky, because the default value might not have
# a well-defined type in APython.
# In this case, we bail out, but mention why we do it.
ast = self._quote(param.default, None)
Inferencer(engine=self.engine).visit(ast)
IntMonomorphizer(engine=self.engine).visit(ast)
return ast.type
else:
# Let the rest of the program decide.
return types.TVar()
def _quote_syscall(self, function, loc): def _quote_syscall(self, function, loc):
signature = inspect.signature(function) signature = inspect.signature(function)
@ -957,27 +951,28 @@ class Stitcher:
"system calls must only use positional arguments; '{argument}' isn't", "system calls must only use positional arguments; '{argument}' isn't",
{"argument": param.name}, {"argument": param.name},
self._function_loc(function), self._function_loc(function),
notes=self._call_site_note(loc, is_syscall=True)) notes=self._call_site_note(loc, fn_kind='syscall'))
self.engine.process(diag) self.engine.process(diag)
if param.default is inspect.Parameter.empty: if param.default is inspect.Parameter.empty:
arg_types[param.name] = self._type_of_param(function, loc, param, is_syscall=True) arg_types[param.name] = self._type_of_param(function, loc, param,
fn_kind='syscall')
else: else:
diag = diagnostic.Diagnostic("error", diag = diagnostic.Diagnostic("error",
"system call argument '{argument}' must not have a default value", "system call argument '{argument}' must not have a default value",
{"argument": param.name}, {"argument": param.name},
self._function_loc(function), self._function_loc(function),
notes=self._call_site_note(loc, is_syscall=True)) notes=self._call_site_note(loc, fn_kind='syscall'))
self.engine.process(diag) self.engine.process(diag)
if signature.return_annotation is not inspect.Signature.empty: if signature.return_annotation is not inspect.Signature.empty:
ret_type = self._extract_annot(function, signature.return_annotation, ret_type = self._extract_annot(function, signature.return_annotation,
"return type", loc, is_syscall=True) "return type", loc, fn_kind='syscall')
else: else:
diag = diagnostic.Diagnostic("error", diag = diagnostic.Diagnostic("error",
"system call must have a return type annotation", {}, "system call must have a return type annotation", {},
self._function_loc(function), self._function_loc(function),
notes=self._call_site_note(loc, is_syscall=True)) notes=self._call_site_note(loc, fn_kind='syscall'))
self.engine.process(diag) self.engine.process(diag)
ret_type = types.TVar() ret_type = types.TVar()
@ -1005,7 +1000,7 @@ class Stitcher:
signature = inspect.signature(host_function.__func__) signature = inspect.signature(host_function.__func__)
if signature.return_annotation is not inspect.Signature.empty: if signature.return_annotation is not inspect.Signature.empty:
ret_type = self._extract_annot(host_function, signature.return_annotation, ret_type = self._extract_annot(host_function, signature.return_annotation,
"return type", loc, is_syscall=False) "return type", loc, fn_kind='rpc')
else: else:
assert False assert False
@ -1077,7 +1072,7 @@ class Stitcher:
diag = diagnostic.Diagnostic("fatal", diag = diagnostic.Diagnostic("fatal",
"this function cannot be called as an RPC", {}, "this function cannot be called as an RPC", {},
self._function_loc(host_function), self._function_loc(host_function),
notes=self._call_site_note(loc, is_syscall=False)) notes=self._call_site_note(loc, fn_kind='rpc'))
self.engine.process(diag) self.engine.process(diag)
else: else:
assert False assert False

View File

@ -0,0 +1,17 @@
# RUN: env ARTIQ_DUMP_LLVM=%t %python -m artiq.compiler.testbench.embedding +compile %s
# RUN: OutputCheck %s --file-to-check=%t.ll
from artiq.language.core import *
from artiq.language.types import *
# CHECK: i64 @_Z13testbench.foozz\(i64 %ARG.x, \{ i1, i64 \} %ARG.y\)
@kernel
def foo(x: TInt64, y: TInt64 = 1) -> TInt64:
print(x+y)
return x+y
@kernel
def entrypoint():
print(foo(0, 2))
print(foo(1, 3))