forked from M-Labs/artiq
compiler: refactor type annotations recognizing in kernels.
The new implementation is much more generic, more robust, and shares code with the same for syscalls as well as RPCs. Fixes #713.
This commit is contained in:
parent
99196986c0
commit
3a1fc729cf
|
@ -423,49 +423,9 @@ class StitchingASTTypedRewriter(ASTTypedRewriter):
|
||||||
self.host_environment = host_environment
|
self.host_environment = host_environment
|
||||||
self.quote = quote
|
self.quote = quote
|
||||||
|
|
||||||
def match_annotation(self, annot):
|
|
||||||
if isinstance(annot, ast.Name):
|
|
||||||
if annot.id == "TNone":
|
|
||||||
return builtins.TNone()
|
|
||||||
if annot.id == "TBool":
|
|
||||||
return builtins.TBool()
|
|
||||||
if annot.id == "TInt32":
|
|
||||||
return builtins.TInt(types.TValue(32))
|
|
||||||
if annot.id == "TInt64":
|
|
||||||
return builtins.TInt(types.TValue(64))
|
|
||||||
if annot.id == "TFloat":
|
|
||||||
return builtins.TFloat()
|
|
||||||
if annot.id == "TStr":
|
|
||||||
return builtins.TStr()
|
|
||||||
if annot.id == "TRange32":
|
|
||||||
return builtins.TRange(builtins.TInt(types.TValue(32)))
|
|
||||||
if annot.id == "TRange64":
|
|
||||||
return builtins.TRange(builtins.TInt(types.TValue(64)))
|
|
||||||
if annot.id == "TVar":
|
|
||||||
return types.TVar()
|
|
||||||
elif (isinstance(annot, ast.Call) and
|
|
||||||
annot.keywords == [] and
|
|
||||||
annot.starargs is None and
|
|
||||||
annot.kwargs is None and
|
|
||||||
isinstance(annot.func, ast.Name)):
|
|
||||||
if annot.func.id == "TList" and len(annot.args) == 1:
|
|
||||||
elttyp = self.match_annotation(annot.args[0])
|
|
||||||
if elttyp is not None:
|
|
||||||
return builtins.TList(elttyp)
|
|
||||||
else:
|
|
||||||
return None
|
|
||||||
|
|
||||||
if annot is not None:
|
|
||||||
diag = diagnostic.Diagnostic("error",
|
|
||||||
"unrecognized type annotation", {},
|
|
||||||
annot.loc)
|
|
||||||
self.engine.process(diag)
|
|
||||||
|
|
||||||
def visit_arg(self, node):
|
def visit_arg(self, node):
|
||||||
typ = self._find_name(node.arg, node.loc)
|
typ = self._find_name(node.arg, node.loc)
|
||||||
annottyp = self.match_annotation(node.annotation)
|
# ignore annotations; these are handled in _quote_function
|
||||||
if annottyp is not None:
|
|
||||||
typ.unify(annottyp)
|
|
||||||
return asttyped.argT(type=typ,
|
return asttyped.argT(type=typ,
|
||||||
arg=node.arg, annotation=None,
|
arg=node.arg, annotation=None,
|
||||||
arg_loc=node.arg_loc, colon_loc=node.colon_loc, loc=node.loc)
|
arg_loc=node.arg_loc, colon_loc=node.colon_loc, loc=node.loc)
|
||||||
|
@ -795,6 +755,79 @@ class Stitcher:
|
||||||
value_map=self.value_map,
|
value_map=self.value_map,
|
||||||
quote_function=self._quote_function)
|
quote_function=self._quote_function)
|
||||||
|
|
||||||
|
def _function_loc(self, function):
|
||||||
|
filename = function.__code__.co_filename
|
||||||
|
line = function.__code__.co_firstlineno
|
||||||
|
name = function.__code__.co_name
|
||||||
|
|
||||||
|
source_line = linecache.getline(filename, line).lstrip()
|
||||||
|
while source_line.startswith("@") or source_line == "":
|
||||||
|
line += 1
|
||||||
|
source_line = linecache.getline(filename, line).lstrip()
|
||||||
|
|
||||||
|
if "<lambda>" in function.__qualname__:
|
||||||
|
column = 0 # can't get column of lambda
|
||||||
|
else:
|
||||||
|
column = re.search("def", source_line).start(0)
|
||||||
|
source_buffer = source.Buffer(source_line, filename, line)
|
||||||
|
return source.Range(source_buffer, column, column)
|
||||||
|
|
||||||
|
def _call_site_note(self, call_loc, fn_kind):
|
||||||
|
if call_loc:
|
||||||
|
if fn_kind == 'syscall':
|
||||||
|
return [diagnostic.Diagnostic("note",
|
||||||
|
"in system call here", {},
|
||||||
|
call_loc)]
|
||||||
|
elif fn_kind == 'rpc':
|
||||||
|
return [diagnostic.Diagnostic("note",
|
||||||
|
"in function called remotely here", {},
|
||||||
|
call_loc)]
|
||||||
|
elif fn_kind == 'kernel':
|
||||||
|
return [diagnostic.Diagnostic("note",
|
||||||
|
"in kernel function here", {},
|
||||||
|
call_loc)]
|
||||||
|
else:
|
||||||
|
assert False
|
||||||
|
else:
|
||||||
|
return []
|
||||||
|
|
||||||
|
def _type_of_param(self, function, loc, param, fn_kind):
|
||||||
|
if param.annotation is not inspect.Parameter.empty:
|
||||||
|
# Type specified explicitly.
|
||||||
|
return self._extract_annot(function, param.annotation,
|
||||||
|
"argument '{}'".format(param.name), loc,
|
||||||
|
fn_kind)
|
||||||
|
elif fn_kind == 'syscall':
|
||||||
|
# Syscalls must be entirely annotated.
|
||||||
|
diag = diagnostic.Diagnostic("error",
|
||||||
|
"system call argument '{argument}' must have a type annotation",
|
||||||
|
{"argument": param.name},
|
||||||
|
self._function_loc(function),
|
||||||
|
notes=self._call_site_note(loc, fn_kind))
|
||||||
|
self.engine.process(diag)
|
||||||
|
elif fn_kind == 'rpc' and param.default is not inspect.Parameter.empty:
|
||||||
|
notes = []
|
||||||
|
notes.append(diagnostic.Diagnostic("note",
|
||||||
|
"expanded from here while trying to infer a type for an"
|
||||||
|
" unannotated optional argument '{argument}' from its default value",
|
||||||
|
{"argument": param.name},
|
||||||
|
self._function_loc(function)))
|
||||||
|
if loc is not None:
|
||||||
|
notes.append(self._call_site_note(loc, fn_kind))
|
||||||
|
|
||||||
|
with self.engine.context(*notes):
|
||||||
|
# Try and infer the type from the default value.
|
||||||
|
# This is tricky, because the default value might not have
|
||||||
|
# a well-defined type in APython.
|
||||||
|
# In this case, we bail out, but mention why we do it.
|
||||||
|
ast = self._quote(param.default, None)
|
||||||
|
Inferencer(engine=self.engine).visit(ast)
|
||||||
|
IntMonomorphizer(engine=self.engine).visit(ast)
|
||||||
|
return ast.type
|
||||||
|
else:
|
||||||
|
# Let the rest of the program decide.
|
||||||
|
return types.TVar()
|
||||||
|
|
||||||
def _quote_embedded_function(self, function, flags):
|
def _quote_embedded_function(self, function, flags):
|
||||||
if isinstance(function, SpecializedFunction):
|
if isinstance(function, SpecializedFunction):
|
||||||
host_function = function.host_function
|
host_function = function.host_function
|
||||||
|
@ -811,6 +844,34 @@ class Stitcher:
|
||||||
module_name = embedded_function.__globals__['__name__']
|
module_name = embedded_function.__globals__['__name__']
|
||||||
first_line = embedded_function.__code__.co_firstlineno
|
first_line = embedded_function.__code__.co_firstlineno
|
||||||
|
|
||||||
|
# Extract function annotation.
|
||||||
|
signature = inspect.signature(embedded_function)
|
||||||
|
loc = self._function_loc(embedded_function)
|
||||||
|
|
||||||
|
arg_types = OrderedDict()
|
||||||
|
optarg_types = OrderedDict()
|
||||||
|
for param in signature.parameters.values():
|
||||||
|
if param.kind == inspect.Parameter.VAR_POSITIONAL or \
|
||||||
|
param.kind == inspect.Parameter.VAR_KEYWORD:
|
||||||
|
diag = diagnostic.Diagnostic("error",
|
||||||
|
"variadic arguments are not supported; '{argument}' is variadic",
|
||||||
|
{"argument": param.name},
|
||||||
|
self._function_loc(function),
|
||||||
|
notes=self._call_site_note(loc, fn_kind='kernel'))
|
||||||
|
self.engine.process(diag)
|
||||||
|
|
||||||
|
arg_type = self._type_of_param(function, loc, param, fn_kind='kernel')
|
||||||
|
if param.default is inspect.Parameter.empty:
|
||||||
|
arg_types[param.name] = arg_type
|
||||||
|
else:
|
||||||
|
optarg_types[param.name] = arg_type
|
||||||
|
|
||||||
|
if signature.return_annotation is not inspect.Signature.empty:
|
||||||
|
ret_type = self._extract_annot(function, signature.return_annotation,
|
||||||
|
"return type", loc, fn_kind='kernel')
|
||||||
|
else:
|
||||||
|
ret_type = types.TVar()
|
||||||
|
|
||||||
# Extract function environment.
|
# Extract function environment.
|
||||||
host_environment = dict()
|
host_environment = dict()
|
||||||
host_environment.update(embedded_function.__globals__)
|
host_environment.update(embedded_function.__globals__)
|
||||||
|
@ -846,9 +907,9 @@ class Stitcher:
|
||||||
# can handle quoting it.
|
# can handle quoting it.
|
||||||
self.embedding_map.store_function(function, function_node.name)
|
self.embedding_map.store_function(function, function_node.name)
|
||||||
|
|
||||||
# Memoize the function type before typing it to handle recursive
|
# Fill in the function type before typing it to handle recursive
|
||||||
# invocations.
|
# invocations.
|
||||||
self.functions[function] = types.TVar()
|
self.functions[function] = types.TFunction(arg_types, optarg_types, ret_type)
|
||||||
|
|
||||||
# Rewrite into typed form.
|
# Rewrite into typed form.
|
||||||
asttyped_rewriter = StitchingASTTypedRewriter(
|
asttyped_rewriter = StitchingASTTypedRewriter(
|
||||||
|
@ -866,86 +927,19 @@ class Stitcher:
|
||||||
|
|
||||||
return function_node
|
return function_node
|
||||||
|
|
||||||
def _function_loc(self, function):
|
def _extract_annot(self, function, annot, kind, call_loc, fn_kind):
|
||||||
filename = function.__code__.co_filename
|
|
||||||
line = function.__code__.co_firstlineno
|
|
||||||
name = function.__code__.co_name
|
|
||||||
|
|
||||||
source_line = linecache.getline(filename, line).lstrip()
|
|
||||||
while source_line.startswith("@") or source_line == "":
|
|
||||||
line += 1
|
|
||||||
source_line = linecache.getline(filename, line).lstrip()
|
|
||||||
|
|
||||||
if "<lambda>" in function.__qualname__:
|
|
||||||
column = 0 # can't get column of lambda
|
|
||||||
else:
|
|
||||||
column = re.search("def", source_line).start(0)
|
|
||||||
source_buffer = source.Buffer(source_line, filename, line)
|
|
||||||
return source.Range(source_buffer, column, column)
|
|
||||||
|
|
||||||
def _call_site_note(self, call_loc, is_syscall):
|
|
||||||
if call_loc:
|
|
||||||
if is_syscall:
|
|
||||||
return [diagnostic.Diagnostic("note",
|
|
||||||
"in system call here", {},
|
|
||||||
call_loc)]
|
|
||||||
else:
|
|
||||||
return [diagnostic.Diagnostic("note",
|
|
||||||
"in function called remotely here", {},
|
|
||||||
call_loc)]
|
|
||||||
else:
|
|
||||||
return []
|
|
||||||
|
|
||||||
def _extract_annot(self, function, annot, kind, call_loc, is_syscall):
|
|
||||||
if not isinstance(annot, types.Type):
|
if not isinstance(annot, types.Type):
|
||||||
diag = diagnostic.Diagnostic("error",
|
diag = diagnostic.Diagnostic("error",
|
||||||
"type annotation for {kind}, '{annot}', is not an ARTIQ type",
|
"type annotation for {kind}, '{annot}', is not an ARTIQ type",
|
||||||
{"kind": kind, "annot": repr(annot)},
|
{"kind": kind, "annot": repr(annot)},
|
||||||
self._function_loc(function),
|
self._function_loc(function),
|
||||||
notes=self._call_site_note(call_loc, is_syscall))
|
notes=self._call_site_note(call_loc, fn_kind))
|
||||||
self.engine.process(diag)
|
self.engine.process(diag)
|
||||||
|
|
||||||
return types.TVar()
|
return types.TVar()
|
||||||
else:
|
else:
|
||||||
return annot
|
return annot
|
||||||
|
|
||||||
def _type_of_param(self, function, loc, param, is_syscall):
|
|
||||||
if param.annotation is not inspect.Parameter.empty:
|
|
||||||
# Type specified explicitly.
|
|
||||||
return self._extract_annot(function, param.annotation,
|
|
||||||
"argument '{}'".format(param.name), loc,
|
|
||||||
is_syscall)
|
|
||||||
elif is_syscall:
|
|
||||||
# Syscalls must be entirely annotated.
|
|
||||||
diag = diagnostic.Diagnostic("error",
|
|
||||||
"system call argument '{argument}' must have a type annotation",
|
|
||||||
{"argument": param.name},
|
|
||||||
self._function_loc(function),
|
|
||||||
notes=self._call_site_note(loc, is_syscall))
|
|
||||||
self.engine.process(diag)
|
|
||||||
elif param.default is not inspect.Parameter.empty:
|
|
||||||
notes = []
|
|
||||||
notes.append(diagnostic.Diagnostic("note",
|
|
||||||
"expanded from here while trying to infer a type for an"
|
|
||||||
" unannotated optional argument '{argument}' from its default value",
|
|
||||||
{"argument": param.name},
|
|
||||||
self._function_loc(function)))
|
|
||||||
if loc is not None:
|
|
||||||
notes.append(self._call_site_note(loc, is_syscall))
|
|
||||||
|
|
||||||
with self.engine.context(*notes):
|
|
||||||
# Try and infer the type from the default value.
|
|
||||||
# This is tricky, because the default value might not have
|
|
||||||
# a well-defined type in APython.
|
|
||||||
# In this case, we bail out, but mention why we do it.
|
|
||||||
ast = self._quote(param.default, None)
|
|
||||||
Inferencer(engine=self.engine).visit(ast)
|
|
||||||
IntMonomorphizer(engine=self.engine).visit(ast)
|
|
||||||
return ast.type
|
|
||||||
else:
|
|
||||||
# Let the rest of the program decide.
|
|
||||||
return types.TVar()
|
|
||||||
|
|
||||||
def _quote_syscall(self, function, loc):
|
def _quote_syscall(self, function, loc):
|
||||||
signature = inspect.signature(function)
|
signature = inspect.signature(function)
|
||||||
|
|
||||||
|
@ -957,27 +951,28 @@ class Stitcher:
|
||||||
"system calls must only use positional arguments; '{argument}' isn't",
|
"system calls must only use positional arguments; '{argument}' isn't",
|
||||||
{"argument": param.name},
|
{"argument": param.name},
|
||||||
self._function_loc(function),
|
self._function_loc(function),
|
||||||
notes=self._call_site_note(loc, is_syscall=True))
|
notes=self._call_site_note(loc, fn_kind='syscall'))
|
||||||
self.engine.process(diag)
|
self.engine.process(diag)
|
||||||
|
|
||||||
if param.default is inspect.Parameter.empty:
|
if param.default is inspect.Parameter.empty:
|
||||||
arg_types[param.name] = self._type_of_param(function, loc, param, is_syscall=True)
|
arg_types[param.name] = self._type_of_param(function, loc, param,
|
||||||
|
fn_kind='syscall')
|
||||||
else:
|
else:
|
||||||
diag = diagnostic.Diagnostic("error",
|
diag = diagnostic.Diagnostic("error",
|
||||||
"system call argument '{argument}' must not have a default value",
|
"system call argument '{argument}' must not have a default value",
|
||||||
{"argument": param.name},
|
{"argument": param.name},
|
||||||
self._function_loc(function),
|
self._function_loc(function),
|
||||||
notes=self._call_site_note(loc, is_syscall=True))
|
notes=self._call_site_note(loc, fn_kind='syscall'))
|
||||||
self.engine.process(diag)
|
self.engine.process(diag)
|
||||||
|
|
||||||
if signature.return_annotation is not inspect.Signature.empty:
|
if signature.return_annotation is not inspect.Signature.empty:
|
||||||
ret_type = self._extract_annot(function, signature.return_annotation,
|
ret_type = self._extract_annot(function, signature.return_annotation,
|
||||||
"return type", loc, is_syscall=True)
|
"return type", loc, fn_kind='syscall')
|
||||||
else:
|
else:
|
||||||
diag = diagnostic.Diagnostic("error",
|
diag = diagnostic.Diagnostic("error",
|
||||||
"system call must have a return type annotation", {},
|
"system call must have a return type annotation", {},
|
||||||
self._function_loc(function),
|
self._function_loc(function),
|
||||||
notes=self._call_site_note(loc, is_syscall=True))
|
notes=self._call_site_note(loc, fn_kind='syscall'))
|
||||||
self.engine.process(diag)
|
self.engine.process(diag)
|
||||||
ret_type = types.TVar()
|
ret_type = types.TVar()
|
||||||
|
|
||||||
|
@ -1005,7 +1000,7 @@ class Stitcher:
|
||||||
signature = inspect.signature(host_function.__func__)
|
signature = inspect.signature(host_function.__func__)
|
||||||
if signature.return_annotation is not inspect.Signature.empty:
|
if signature.return_annotation is not inspect.Signature.empty:
|
||||||
ret_type = self._extract_annot(host_function, signature.return_annotation,
|
ret_type = self._extract_annot(host_function, signature.return_annotation,
|
||||||
"return type", loc, is_syscall=False)
|
"return type", loc, fn_kind='rpc')
|
||||||
else:
|
else:
|
||||||
assert False
|
assert False
|
||||||
|
|
||||||
|
@ -1077,7 +1072,7 @@ class Stitcher:
|
||||||
diag = diagnostic.Diagnostic("fatal",
|
diag = diagnostic.Diagnostic("fatal",
|
||||||
"this function cannot be called as an RPC", {},
|
"this function cannot be called as an RPC", {},
|
||||||
self._function_loc(host_function),
|
self._function_loc(host_function),
|
||||||
notes=self._call_site_note(loc, is_syscall=False))
|
notes=self._call_site_note(loc, fn_kind='rpc'))
|
||||||
self.engine.process(diag)
|
self.engine.process(diag)
|
||||||
else:
|
else:
|
||||||
assert False
|
assert False
|
||||||
|
|
|
@ -0,0 +1,17 @@
|
||||||
|
# RUN: env ARTIQ_DUMP_LLVM=%t %python -m artiq.compiler.testbench.embedding +compile %s
|
||||||
|
# RUN: OutputCheck %s --file-to-check=%t.ll
|
||||||
|
|
||||||
|
from artiq.language.core import *
|
||||||
|
from artiq.language.types import *
|
||||||
|
|
||||||
|
# CHECK: i64 @_Z13testbench.foozz\(i64 %ARG.x, \{ i1, i64 \} %ARG.y\)
|
||||||
|
|
||||||
|
@kernel
|
||||||
|
def foo(x: TInt64, y: TInt64 = 1) -> TInt64:
|
||||||
|
print(x+y)
|
||||||
|
return x+y
|
||||||
|
|
||||||
|
@kernel
|
||||||
|
def entrypoint():
|
||||||
|
print(foo(0, 2))
|
||||||
|
print(foo(1, 3))
|
Loading…
Reference in New Issue